diff options
Diffstat (limited to 'net/l2tp/l2tp_ppp.c')
-rw-r--r-- | net/l2tp/l2tp_ppp.c | 154 |
1 files changed, 70 insertions, 84 deletions
diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c index 3596290047b2..53baf2dd5d5d 100644 --- a/net/l2tp/l2tp_ppp.c +++ b/net/l2tp/l2tp_ppp.c @@ -119,7 +119,6 @@ struct pppol2tp_session { struct mutex sk_lock; /* Protects .sk */ struct sock __rcu *sk; /* Pointer to the session PPPoX socket */ struct sock *__sk; /* Copy of .sk, for cleanup */ - struct rcu_head rcu; /* For asynchronous release */ }; static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb); @@ -150,27 +149,23 @@ static struct sock *pppol2tp_session_get_sock(struct l2tp_session *session) /* Helpers to obtain tunnel/session contexts from sockets. */ -static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk) +static struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk) { struct l2tp_session *session; if (!sk) return NULL; - sock_hold(sk); - session = (struct l2tp_session *)(sk->sk_user_data); - if (!session) { - sock_put(sk); - goto out; - } - if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) { - session = NULL; - sock_put(sk); - goto out; + rcu_read_lock(); + session = rcu_dereference_sk_user_data(sk); + if (session && refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock(); + WARN_ON_ONCE(session->magic != L2TP_SESSION_MAGIC); + return session; } + rcu_read_unlock(); -out: - return session; + return NULL; } /***************************************************************************** @@ -318,12 +313,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m, l2tp_xmit_skb(session, skb); local_bh_enable(); - sock_put(sk); + l2tp_session_put(session); return total_len; error_put_sess: - sock_put(sk); + l2tp_session_put(session); error: return error; } @@ -377,12 +372,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) l2tp_xmit_skb(session, skb); local_bh_enable(); - sock_put(sk); + l2tp_session_put(session); return 1; abort_put_sess: - sock_put(sk); + l2tp_session_put(session); abort: /* Free the original skb */ kfree_skb(skb); @@ -393,29 +388,32 @@ abort: * Session (and tunnel control) socket create/destroy. *****************************************************************************/ -static void pppol2tp_put_sk(struct rcu_head *head) -{ - struct pppol2tp_session *ps; - - ps = container_of(head, typeof(*ps), rcu); - sock_put(ps->__sk); -} - /* Really kill the session socket. (Called from sock_put() if * refcnt == 0.) */ static void pppol2tp_session_destruct(struct sock *sk) { - struct l2tp_session *session = sk->sk_user_data; - skb_queue_purge(&sk->sk_receive_queue); skb_queue_purge(&sk->sk_write_queue); +} - if (session) { - sk->sk_user_data = NULL; - if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) - return; - l2tp_session_dec_refcount(session); +static void pppol2tp_session_close(struct l2tp_session *session) +{ + struct pppol2tp_session *ps; + + ps = l2tp_session_priv(session); + mutex_lock(&ps->sk_lock); + ps->__sk = rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock)); + RCU_INIT_POINTER(ps->sk, NULL); + mutex_unlock(&ps->sk_lock); + if (ps->__sk) { + /* detach socket */ + rcu_assign_sk_user_data(ps->__sk, NULL); + sock_put(ps->__sk); + + /* drop ref taken when we referenced socket via sk_user_data */ + l2tp_session_put(session); } } @@ -444,30 +442,13 @@ static int pppol2tp_release(struct socket *sock) session = pppol2tp_sock_to_session(sk); if (session) { - struct pppol2tp_session *ps; - l2tp_session_delete(session); - - ps = l2tp_session_priv(session); - mutex_lock(&ps->sk_lock); - ps->__sk = rcu_dereference_protected(ps->sk, - lockdep_is_held(&ps->sk_lock)); - RCU_INIT_POINTER(ps->sk, NULL); - mutex_unlock(&ps->sk_lock); - call_rcu(&ps->rcu, pppol2tp_put_sk); - - /* Rely on the sock_put() call at the end of the function for - * dropping the reference held by pppol2tp_sock_to_session(). - * The last reference will be dropped by pppol2tp_put_sk(). - */ + /* drop ref taken by pppol2tp_sock_to_session */ + l2tp_session_put(session); } release_sock(sk); - /* This will delete the session context via - * pppol2tp_session_destruct() if the socket's refcnt drops to - * zero. - */ sock_put(sk); return 0; @@ -506,6 +487,7 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern) goto out; sock_init_data(sock, sk); + sock_set_flag(sk, SOCK_RCU_FREE); sock->state = SS_UNCONNECTED; sock->ops = &pppol2tp_ops; @@ -542,6 +524,7 @@ static void pppol2tp_session_init(struct l2tp_session *session) struct pppol2tp_session *ps; session->recv_skb = pppol2tp_recv; + session->session_close = pppol2tp_session_close; if (IS_ENABLED(CONFIG_L2TP_DEBUGFS)) session->show = pppol2tp_show; @@ -685,7 +668,7 @@ static struct l2tp_tunnel *pppol2tp_tunnel_get(struct net *net, if (error < 0) return ERR_PTR(error); - l2tp_tunnel_inc_refcount(tunnel); + refcount_inc(&tunnel->ref_count); error = l2tp_tunnel_register(tunnel, net, &tcfg); if (error < 0) { kfree(tunnel); @@ -701,7 +684,7 @@ static struct l2tp_tunnel *pppol2tp_tunnel_get(struct net *net, /* Error if socket is not prepped */ if (!tunnel->sock) { - l2tp_tunnel_dec_refcount(tunnel); + l2tp_tunnel_put(tunnel); return ERR_PTR(-ENOENT); } } @@ -787,18 +770,20 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, goto end; } + drop_refcnt = true; + pppol2tp_session_init(session); ps = l2tp_session_priv(session); - l2tp_session_inc_refcount(session); + refcount_inc(&session->ref_count); mutex_lock(&ps->sk_lock); error = l2tp_session_register(session, tunnel); if (error < 0) { mutex_unlock(&ps->sk_lock); - kfree(session); + l2tp_session_put(session); goto end; } - drop_refcnt = true; + new_session = true; } @@ -830,12 +815,13 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, out_no_ppp: /* This is how we get the session context from the socket. */ - sk->sk_user_data = session; + sock_hold(sk); + rcu_assign_sk_user_data(sk, session); rcu_assign_pointer(ps->sk, sk); mutex_unlock(&ps->sk_lock); /* Keep the reference we've grabbed on the session: sk doesn't expect - * the session to disappear. pppol2tp_session_destruct() is responsible + * the session to disappear. pppol2tp_session_close() is responsible * for dropping it. */ drop_refcnt = false; @@ -850,8 +836,8 @@ end: l2tp_tunnel_delete(tunnel); } if (drop_refcnt) - l2tp_session_dec_refcount(session); - l2tp_tunnel_dec_refcount(tunnel); + l2tp_session_put(session); + l2tp_tunnel_put(tunnel); release_sock(sk); return error; @@ -891,7 +877,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel, return 0; err_sess: - kfree(session); + l2tp_session_put(session); err: return error; } @@ -1002,7 +988,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr, error = len; - sock_put(sk); + l2tp_session_put(session); end: return error; } @@ -1052,12 +1038,12 @@ static int pppol2tp_tunnel_copy_stats(struct pppol2tp_ioc_stats *stats, return -EBADR; if (session->pwtype != L2TP_PWTYPE_PPP) { - l2tp_session_dec_refcount(session); + l2tp_session_put(session); return -EBADR; } pppol2tp_copy_stats(stats, &session->stats); - l2tp_session_dec_refcount(session); + l2tp_session_put(session); return 0; } @@ -1205,7 +1191,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk, po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ : PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; } - l2tp_session_set_header_len(session, session->tunnel->version); + l2tp_session_set_header_len(session, session->tunnel->version, + session->tunnel->encap); break; case PPPOL2TP_SO_LNSMODE: @@ -1274,7 +1261,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, err = pppol2tp_session_setsockopt(sk, session, optname, val); } - sock_put(sk); + l2tp_session_put(session); end: return err; } @@ -1395,7 +1382,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname, err = 0; end_put_sess: - sock_put(sk); + l2tp_session_put(session); end: return err; } @@ -1406,14 +1393,12 @@ end: * L2TPv2, we dump only L2TPv2 tunnels and sessions here. *****************************************************************************/ -static unsigned int pppol2tp_net_id; - #ifdef CONFIG_PROC_FS struct pppol2tp_seq_data { struct seq_net_private p; - int tunnel_idx; /* current tunnel */ - int session_idx; /* index of session within current tunnel */ + unsigned long tkey; /* lookup key of current tunnel */ + unsigned long skey; /* lookup key of current session */ struct l2tp_tunnel *tunnel; struct l2tp_session *session; /* NULL means get next tunnel */ }; @@ -1422,17 +1407,17 @@ static void pppol2tp_next_tunnel(struct net *net, struct pppol2tp_seq_data *pd) { /* Drop reference taken during previous invocation */ if (pd->tunnel) - l2tp_tunnel_dec_refcount(pd->tunnel); + l2tp_tunnel_put(pd->tunnel); for (;;) { - pd->tunnel = l2tp_tunnel_get_nth(net, pd->tunnel_idx); - pd->tunnel_idx++; + pd->tunnel = l2tp_tunnel_get_next(net, &pd->tkey); + pd->tkey++; /* Only accept L2TPv2 tunnels */ if (!pd->tunnel || pd->tunnel->version == 2) return; - l2tp_tunnel_dec_refcount(pd->tunnel); + l2tp_tunnel_put(pd->tunnel); } } @@ -1440,13 +1425,15 @@ static void pppol2tp_next_session(struct net *net, struct pppol2tp_seq_data *pd) { /* Drop reference taken during previous invocation */ if (pd->session) - l2tp_session_dec_refcount(pd->session); + l2tp_session_put(pd->session); - pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx); - pd->session_idx++; + pd->session = l2tp_session_get_next(net, pd->tunnel->sock, + pd->tunnel->version, + pd->tunnel->tunnel_id, &pd->skey); + pd->skey++; if (!pd->session) { - pd->session_idx = 0; + pd->skey = 0; pppol2tp_next_tunnel(net, pd); } } @@ -1498,11 +1485,11 @@ static void pppol2tp_seq_stop(struct seq_file *p, void *v) * or pppol2tp_next_tunnel(). */ if (pd->session) { - l2tp_session_dec_refcount(pd->session); + l2tp_session_put(pd->session); pd->session = NULL; } if (pd->tunnel) { - l2tp_tunnel_dec_refcount(pd->tunnel); + l2tp_tunnel_put(pd->tunnel); pd->tunnel = NULL; } } @@ -1513,7 +1500,7 @@ static void pppol2tp_seq_tunnel_show(struct seq_file *m, void *v) seq_printf(m, "\nTUNNEL '%s', %c %d\n", tunnel->name, - (tunnel == tunnel->sock->sk_user_data) ? 'Y' : 'N', + tunnel->sock ? 'Y' : 'N', refcount_read(&tunnel->ref_count) - 1); seq_printf(m, " %08x %ld/%ld/%ld %ld/%ld/%ld\n", 0, @@ -1641,7 +1628,6 @@ static __net_exit void pppol2tp_exit_net(struct net *net) static struct pernet_operations pppol2tp_net_ops = { .init = pppol2tp_init_net, .exit = pppol2tp_exit_net, - .id = &pppol2tp_net_id, }; /***************************************************************************** |