diff options
-rw-r--r-- | net/ipv4/ip_output.c | 1 | ||||
-rw-r--r-- | net/l2tp/l2tp_core.c | 199 | ||||
-rw-r--r-- | net/l2tp/l2tp_core.h | 14 | ||||
-rw-r--r-- | net/l2tp/l2tp_eth.c | 2 | ||||
-rw-r--r-- | net/l2tp/l2tp_ip.c | 13 | ||||
-rw-r--r-- | net/l2tp/l2tp_ip6.c | 7 | ||||
-rw-r--r-- | net/l2tp/l2tp_netlink.c | 4 | ||||
-rw-r--r-- | net/l2tp/l2tp_ppp.c | 107 |
8 files changed, 179 insertions, 168 deletions
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c index b90d0f78ac80..8a10a7c67834 100644 --- a/net/ipv4/ip_output.c +++ b/net/ipv4/ip_output.c @@ -1534,6 +1534,7 @@ void ip_flush_pending_frames(struct sock *sk) { __ip_flush_pending_frames(sk, &sk->sk_write_queue, &inet_sk(sk)->cork.base); } +EXPORT_SYMBOL_GPL(ip_flush_pending_frames); struct sk_buff *ip_make_skb(struct sock *sk, struct flowi4 *fl4, diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index c80ab3f26084..5d2068b6c778 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -137,9 +137,28 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net) static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel) { + struct sock *sk = tunnel->sock; + trace_free_tunnel(tunnel); - sock_put(tunnel->sock); - /* the tunnel is freed in the socket destructor */ + + if (sk) { + /* Disable udp encapsulation */ + switch (tunnel->encap) { + case L2TP_ENCAPTYPE_UDP: + /* No longer an encapsulation socket. See net/ipv4/udp.c */ + WRITE_ONCE(udp_sk(sk)->encap_type, 0); + udp_sk(sk)->encap_rcv = NULL; + udp_sk(sk)->encap_destroy = NULL; + break; + case L2TP_ENCAPTYPE_IP: + break; + } + + tunnel->sock = NULL; + sock_put(sk); + } + + kfree_rcu(tunnel, rcu); } static void l2tp_session_free(struct l2tp_session *session) @@ -147,18 +166,29 @@ static void l2tp_session_free(struct l2tp_session *session) trace_free_session(session); if (session->tunnel) l2tp_tunnel_dec_refcount(session->tunnel); - kfree(session); + kfree_rcu(session, rcu); } -struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk) +struct l2tp_tunnel *l2tp_sk_to_tunnel(const struct sock *sk) { - struct l2tp_tunnel *tunnel = sk->sk_user_data; + const struct net *net = sock_net(sk); + unsigned long tunnel_id, tmp; + struct l2tp_tunnel *tunnel; + struct l2tp_net *pn; - if (tunnel) - if (WARN_ON(tunnel->magic != L2TP_TUNNEL_MAGIC)) - return NULL; + rcu_read_lock_bh(); + pn = l2tp_pernet(net); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel && + tunnel->sock == sk && + refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; + } + } + rcu_read_unlock_bh(); - return tunnel; + return NULL; } EXPORT_SYMBOL_GPL(l2tp_sk_to_tunnel); @@ -249,7 +279,14 @@ struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session, hlist, key) { - if (session->tunnel->sock == sk && + /* session->tunnel may be NULL if another thread is in + * l2tp_session_register and has added an item to + * l2tp_v3_session_htable but hasn't yet added the + * session to its tunnel's session_list. + */ + struct l2tp_tunnel *tunnel = READ_ONCE(session->tunnel); + + if (tunnel && tunnel->sock == sk && refcount_inc_not_zero(&session->ref_count)) { rcu_read_unlock_bh(); return session; @@ -382,12 +419,12 @@ static int l2tp_session_collision_add(struct l2tp_net *pn, /* If existing session isn't already in the session hlist, add it. */ if (!hash_hashed(&session2->hlist)) - hash_add(pn->l2tp_v3_session_htable, &session2->hlist, - session2->hlist_key); + hash_add_rcu(pn->l2tp_v3_session_htable, &session2->hlist, + session2->hlist_key); /* Add new session to the hlist and collision list */ - hash_add(pn->l2tp_v3_session_htable, &session1->hlist, - session1->hlist_key); + hash_add_rcu(pn->l2tp_v3_session_htable, &session1->hlist, + session1->hlist_key); refcount_inc(&clist->ref_count); l2tp_session_coll_list_add(clist, session1); @@ -403,7 +440,7 @@ static void l2tp_session_collision_del(struct l2tp_net *pn, lockdep_assert_held(&pn->l2tp_session_idr_lock); - hash_del(&session->hlist); + hash_del_rcu(&session->hlist); if (clist) { /* Remove session from its collision list. If there @@ -437,6 +474,7 @@ int l2tp_session_register(struct l2tp_session *session, { struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net); struct l2tp_session *other_session = NULL; + void *old = NULL; u32 session_key; int err; @@ -477,15 +515,22 @@ int l2tp_session_register(struct l2tp_session *session, } l2tp_tunnel_inc_refcount(tunnel); - list_add(&session->list, &tunnel->session_list); + WRITE_ONCE(session->tunnel, tunnel); + list_add_rcu(&session->list, &tunnel->session_list); + /* this makes session available to lockless getters */ if (tunnel->version == L2TP_HDR_VER_3) { if (!other_session) - idr_replace(&pn->l2tp_v3_session_idr, session, session_key); + old = idr_replace(&pn->l2tp_v3_session_idr, session, session_key); } else { - idr_replace(&pn->l2tp_v2_session_idr, session, session_key); + old = idr_replace(&pn->l2tp_v2_session_idr, session, session_key); } + /* old should be NULL, unless something removed or modified + * the IDR entry after our idr_alloc_32 above (which shouldn't + * happen). + */ + WARN_ON_ONCE(old); out: spin_unlock_bh(&pn->l2tp_session_idr_lock); spin_unlock_bh(&tunnel->list_lock); @@ -792,7 +837,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, if (!session->lns_mode && !session->send_seq) { trace_session_seqnum_lns_enable(session); session->send_seq = 1; - l2tp_session_set_header_len(session, tunnel->version); + l2tp_session_set_header_len(session, tunnel->version, + tunnel->encap); } } else { /* No sequence numbers. @@ -813,7 +859,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, if (!session->lns_mode && session->send_seq) { trace_session_seqnum_lns_disable(session); session->send_seq = 0; - l2tp_session_set_header_len(session, tunnel->version); + l2tp_session_set_header_len(session, tunnel->version, + tunnel->encap); } else if (session->send_seq) { pr_debug_ratelimited("%s: recv data has no seq numbers when required. Discarding.\n", session->name); @@ -1207,44 +1254,6 @@ EXPORT_SYMBOL_GPL(l2tp_xmit_skb); * Tinnel and session create/destroy. *****************************************************************************/ -/* Tunnel socket destruct hook. - * The tunnel context is deleted only when all session sockets have been - * closed. - */ -static void l2tp_tunnel_destruct(struct sock *sk) -{ - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); - - if (!tunnel) - goto end; - - /* Disable udp encapsulation */ - switch (tunnel->encap) { - case L2TP_ENCAPTYPE_UDP: - /* No longer an encapsulation socket. See net/ipv4/udp.c */ - WRITE_ONCE(udp_sk(sk)->encap_type, 0); - udp_sk(sk)->encap_rcv = NULL; - udp_sk(sk)->encap_destroy = NULL; - break; - case L2TP_ENCAPTYPE_IP: - break; - } - - /* Remove hooks into tunnel socket */ - write_lock_bh(&sk->sk_callback_lock); - sk->sk_destruct = tunnel->old_sk_destruct; - sk->sk_user_data = NULL; - write_unlock_bh(&sk->sk_callback_lock); - - /* Call the original destructor */ - if (sk->sk_destruct) - (*sk->sk_destruct)(sk); - - kfree_rcu(tunnel, rcu); -end: - return; -} - /* Remove an l2tp session from l2tp_core's lists. */ static void l2tp_session_unhash(struct l2tp_session *session) { @@ -1277,8 +1286,6 @@ static void l2tp_session_unhash(struct l2tp_session *session) spin_unlock_bh(&pn->l2tp_session_idr_lock); spin_unlock_bh(&tunnel->list_lock); - - synchronize_rcu(); } } @@ -1290,28 +1297,21 @@ static void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) spin_lock_bh(&tunnel->list_lock); tunnel->acpt_newsess = false; - for (;;) { - session = list_first_entry_or_null(&tunnel->session_list, - struct l2tp_session, list); - if (!session) - break; - l2tp_session_inc_refcount(session); - list_del_init(&session->list); - spin_unlock_bh(&tunnel->list_lock); + list_for_each_entry(session, &tunnel->session_list, list) l2tp_session_delete(session); - spin_lock_bh(&tunnel->list_lock); - l2tp_session_dec_refcount(session); - } spin_unlock_bh(&tunnel->list_lock); } /* Tunnel socket destroy hook for UDP encapsulation */ static void l2tp_udp_encap_destroy(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct l2tp_tunnel *tunnel; - if (tunnel) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { l2tp_tunnel_delete(tunnel); + l2tp_tunnel_dec_refcount(tunnel); + } } static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel) @@ -1494,7 +1494,6 @@ int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id, tunnel->tunnel_id = tunnel_id; tunnel->peer_tunnel_id = peer_tunnel_id; - tunnel->magic = L2TP_TUNNEL_MAGIC; sprintf(&tunnel->name[0], "tunl %u", tunnel_id); spin_lock_init(&tunnel->list_lock); tunnel->acpt_newsess = true; @@ -1520,6 +1519,8 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_create); static int l2tp_validate_socket(const struct sock *sk, const struct net *net, enum l2tp_encap_type encap) { + struct l2tp_tunnel *tunnel; + if (!net_eq(sock_net(sk), net)) return -EINVAL; @@ -1533,8 +1534,11 @@ static int l2tp_validate_socket(const struct sock *sk, const struct net *net, (encap == L2TP_ENCAPTYPE_IP && sk->sk_protocol != IPPROTO_L2TP)) return -EPROTONOSUPPORT; - if (sk->sk_user_data) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { + l2tp_tunnel_dec_refcount(tunnel); return -EBUSY; + } return 0; } @@ -1573,12 +1577,10 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, ret = l2tp_validate_socket(sk, net, tunnel->encap); if (ret < 0) goto err_inval_sock; - rcu_assign_sk_user_data(sk, tunnel); write_unlock_bh(&sk->sk_callback_lock); if (tunnel->encap == L2TP_ENCAPTYPE_UDP) { struct udp_tunnel_sock_cfg udp_cfg = { - .sk_user_data = tunnel, .encap_type = UDP_ENCAP_L2TPINUDP, .encap_rcv = l2tp_udp_encap_recv, .encap_err_rcv = l2tp_udp_encap_err_recv, @@ -1588,8 +1590,6 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, setup_udp_tunnel_sock(net, sock, &udp_cfg); } - tunnel->old_sk_destruct = sk->sk_destruct; - sk->sk_destruct = &l2tp_tunnel_destruct; sk->sk_allocation = GFP_ATOMIC; release_sock(sk); @@ -1636,23 +1636,37 @@ EXPORT_SYMBOL_GPL(l2tp_tunnel_delete); void l2tp_session_delete(struct l2tp_session *session) { - if (test_and_set_bit(0, &session->dead)) - return; + if (!test_and_set_bit(0, &session->dead)) { + trace_delete_session(session); + l2tp_session_inc_refcount(session); + queue_work(l2tp_wq, &session->del_work); + } +} +EXPORT_SYMBOL_GPL(l2tp_session_delete); + +/* Workqueue session deletion function */ +static void l2tp_session_del_work(struct work_struct *work) +{ + struct l2tp_session *session = container_of(work, struct l2tp_session, + del_work); - trace_delete_session(session); l2tp_session_unhash(session); l2tp_session_queue_purge(session); if (session->session_close) (*session->session_close)(session); + /* drop initial ref */ + l2tp_session_dec_refcount(session); + + /* drop workqueue ref */ l2tp_session_dec_refcount(session); } -EXPORT_SYMBOL_GPL(l2tp_session_delete); /* We come here whenever a session's send_seq, cookie_len or * l2specific_type parameters are set. */ -void l2tp_session_set_header_len(struct l2tp_session *session, int version) +void l2tp_session_set_header_len(struct l2tp_session *session, int version, + enum l2tp_encap_type encap) { if (version == L2TP_HDR_VER_2) { session->hdr_len = 6; @@ -1661,7 +1675,7 @@ void l2tp_session_set_header_len(struct l2tp_session *session, int version) } else { session->hdr_len = 4 + session->cookie_len; session->hdr_len += l2tp_get_l2specific_len(session); - if (session->tunnel->encap == L2TP_ENCAPTYPE_UDP) + if (encap == L2TP_ENCAPTYPE_UDP) session->hdr_len += 4; } } @@ -1675,7 +1689,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn session = kzalloc(sizeof(*session) + priv_size, GFP_KERNEL); if (session) { session->magic = L2TP_SESSION_MAGIC; - session->tunnel = tunnel; session->session_id = session_id; session->peer_session_id = peer_session_id; @@ -1699,6 +1712,7 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn INIT_HLIST_NODE(&session->hlist); INIT_LIST_HEAD(&session->clist); INIT_LIST_HEAD(&session->list); + INIT_WORK(&session->del_work, l2tp_session_del_work); if (cfg) { session->pwtype = cfg->pw_type; @@ -1713,7 +1727,7 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn memcpy(&session->peer_cookie[0], &cfg->peer_cookie[0], cfg->peer_cookie_len); } - l2tp_session_set_header_len(session, tunnel->version); + l2tp_session_set_header_len(session, tunnel->version, tunnel->encap); refcount_set(&session->ref_count, 1); @@ -1742,7 +1756,7 @@ static __net_init int l2tp_init_net(struct net *net) return 0; } -static __net_exit void l2tp_exit_net(struct net *net) +static __net_exit void l2tp_pre_exit_net(struct net *net) { struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_tunnel *tunnel = NULL; @@ -1756,8 +1770,12 @@ static __net_exit void l2tp_exit_net(struct net *net) rcu_read_unlock_bh(); if (l2tp_wq) - flush_workqueue(l2tp_wq); - rcu_barrier(); + drain_workqueue(l2tp_wq); +} + +static __net_exit void l2tp_exit_net(struct net *net) +{ + struct l2tp_net *pn = l2tp_pernet(net); idr_destroy(&pn->l2tp_v2_session_idr); idr_destroy(&pn->l2tp_v3_session_idr); @@ -1767,6 +1785,7 @@ static __net_exit void l2tp_exit_net(struct net *net) static struct pernet_operations l2tp_net_ops = { .init = l2tp_init_net, .exit = l2tp_exit_net, + .pre_exit = l2tp_pre_exit_net, .id = &l2tp_net_id, .size = sizeof(struct l2tp_net), }; diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h index 8ac81bc1bc6f..c907687705b9 100644 --- a/net/l2tp/l2tp_core.h +++ b/net/l2tp/l2tp_core.h @@ -16,7 +16,6 @@ #endif /* Random numbers used for internal consistency checks of tunnel and session structures */ -#define L2TP_TUNNEL_MAGIC 0x42114DDA #define L2TP_SESSION_MAGIC 0x0C04EB7D struct sk_buff; @@ -67,6 +66,7 @@ struct l2tp_session_coll_list { struct l2tp_session { int magic; /* should be L2TP_SESSION_MAGIC */ long dead; + struct rcu_head rcu; struct l2tp_tunnel *tunnel; /* back pointer to tunnel context */ u32 session_id; @@ -103,6 +103,7 @@ struct l2tp_session { int reorder_skip; /* set if skip to next nr */ enum l2tp_pwtype pwtype; struct l2tp_stats stats; + struct work_struct del_work; /* Session receive handler for data packets. * Each pseudowire implementation should implement this callback in order to @@ -155,8 +156,6 @@ struct l2tp_tunnel_cfg { */ #define L2TP_TUNNEL_NAME_MAX 20 struct l2tp_tunnel { - int magic; /* Should be L2TP_TUNNEL_MAGIC */ - unsigned long dead; struct rcu_head rcu; @@ -176,7 +175,6 @@ struct l2tp_tunnel { struct net *l2tp_net; /* the net we belong to */ refcount_t ref_count; - void (*old_sk_destruct)(struct sock *sk); struct sock *sock; /* parent socket */ int fd; /* parent fd, if tunnel socket was created * by userspace @@ -260,7 +258,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb); /* Transmit path helpers for sending packets over the tunnel socket. */ -void l2tp_session_set_header_len(struct l2tp_session *session, int version); +void l2tp_session_set_header_len(struct l2tp_session *session, int version, + enum l2tp_encap_type encap); int l2tp_xmit_skb(struct l2tp_session *session, struct sk_buff *skb); /* Pseudowire management. @@ -273,10 +272,7 @@ void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type); /* IOCTL helper for IP encap modules. */ int l2tp_ioctl(struct sock *sk, int cmd, int *karg); -/* Extract the tunnel structure from a socket's sk_user_data pointer, - * validating the tunnel magic feather. - */ -struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk); +struct l2tp_tunnel *l2tp_sk_to_tunnel(const struct sock *sk); static inline int l2tp_get_l2specific_len(struct l2tp_session *session) { diff --git a/net/l2tp/l2tp_eth.c b/net/l2tp/l2tp_eth.c index 8ba00ad433c2..cc8a3ce716e9 100644 --- a/net/l2tp/l2tp_eth.c +++ b/net/l2tp/l2tp_eth.c @@ -322,7 +322,7 @@ err_sess_dev: l2tp_session_dec_refcount(session); free_netdev(dev); err_sess: - kfree(session); + l2tp_session_dec_refcount(session); err: return rc; } diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c index e48aa177d74c..f21dcbf3efd5 100644 --- a/net/l2tp/l2tp_ip.c +++ b/net/l2tp/l2tp_ip.c @@ -235,14 +235,17 @@ static void l2tp_ip_close(struct sock *sk, long timeout) static void l2tp_ip_destroy_sock(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); - struct sk_buff *skb; + struct l2tp_tunnel *tunnel; - while ((skb = __skb_dequeue_tail(&sk->sk_write_queue)) != NULL) - kfree_skb(skb); + lock_sock(sk); + ip_flush_pending_frames(sk); + release_sock(sk); - if (tunnel) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { l2tp_tunnel_delete(tunnel); + l2tp_tunnel_dec_refcount(tunnel); + } } static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index d217ff1f229e..3b0465f2d60d 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -246,14 +246,17 @@ static void l2tp_ip6_close(struct sock *sk, long timeout) static void l2tp_ip6_destroy_sock(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct l2tp_tunnel *tunnel; lock_sock(sk); ip6_flush_pending_frames(sk); release_sock(sk); - if (tunnel) + tunnel = l2tp_sk_to_tunnel(sk); + if (tunnel) { l2tp_tunnel_delete(tunnel); + l2tp_tunnel_dec_refcount(tunnel); + } } static int l2tp_ip6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) diff --git a/net/l2tp/l2tp_netlink.c b/net/l2tp/l2tp_netlink.c index d105030520f9..fc43ecbd128c 100644 --- a/net/l2tp/l2tp_netlink.c +++ b/net/l2tp/l2tp_netlink.c @@ -692,8 +692,10 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf session->recv_seq = nla_get_u8(info->attrs[L2TP_ATTR_RECV_SEQ]); if (info->attrs[L2TP_ATTR_SEND_SEQ]) { + struct l2tp_tunnel *tunnel = session->tunnel; + session->send_seq = nla_get_u8(info->attrs[L2TP_ATTR_SEND_SEQ]); - l2tp_session_set_header_len(session, session->tunnel->version); + l2tp_session_set_header_len(session, tunnel->version, tunnel->encap); } if (info->attrs[L2TP_ATTR_LNS_MODE]) diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c index 3596290047b2..90bf3a8ccab6 100644 --- a/net/l2tp/l2tp_ppp.c +++ b/net/l2tp/l2tp_ppp.c @@ -119,7 +119,6 @@ struct pppol2tp_session { struct mutex sk_lock; /* Protects .sk */ struct sock __rcu *sk; /* Pointer to the session PPPoX socket */ struct sock *__sk; /* Copy of .sk, for cleanup */ - struct rcu_head rcu; /* For asynchronous release */ }; static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb); @@ -157,20 +156,16 @@ static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk) if (!sk) return NULL; - sock_hold(sk); - session = (struct l2tp_session *)(sk->sk_user_data); - if (!session) { - sock_put(sk); - goto out; - } - if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) { - session = NULL; - sock_put(sk); - goto out; + rcu_read_lock(); + session = rcu_dereference_sk_user_data(sk); + if (session && refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock(); + WARN_ON_ONCE(session->magic != L2TP_SESSION_MAGIC); + return session; } + rcu_read_unlock(); -out: - return session; + return NULL; } /***************************************************************************** @@ -318,12 +313,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m, l2tp_xmit_skb(session, skb); local_bh_enable(); - sock_put(sk); + l2tp_session_dec_refcount(session); return total_len; error_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); error: return error; } @@ -377,12 +372,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) l2tp_xmit_skb(session, skb); local_bh_enable(); - sock_put(sk); + l2tp_session_dec_refcount(session); return 1; abort_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); abort: /* Free the original skb */ kfree_skb(skb); @@ -393,28 +388,31 @@ abort: * Session (and tunnel control) socket create/destroy. *****************************************************************************/ -static void pppol2tp_put_sk(struct rcu_head *head) -{ - struct pppol2tp_session *ps; - - ps = container_of(head, typeof(*ps), rcu); - sock_put(ps->__sk); -} - /* Really kill the session socket. (Called from sock_put() if * refcnt == 0.) */ static void pppol2tp_session_destruct(struct sock *sk) { - struct l2tp_session *session = sk->sk_user_data; - skb_queue_purge(&sk->sk_receive_queue); skb_queue_purge(&sk->sk_write_queue); +} - if (session) { - sk->sk_user_data = NULL; - if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) - return; +static void pppol2tp_session_close(struct l2tp_session *session) +{ + struct pppol2tp_session *ps; + + ps = l2tp_session_priv(session); + mutex_lock(&ps->sk_lock); + ps->__sk = rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock)); + RCU_INIT_POINTER(ps->sk, NULL); + mutex_unlock(&ps->sk_lock); + if (ps->__sk) { + /* detach socket */ + rcu_assign_sk_user_data(ps->__sk, NULL); + sock_put(ps->__sk); + + /* drop ref taken when we referenced socket via sk_user_data */ l2tp_session_dec_refcount(session); } } @@ -444,30 +442,13 @@ static int pppol2tp_release(struct socket *sock) session = pppol2tp_sock_to_session(sk); if (session) { - struct pppol2tp_session *ps; - l2tp_session_delete(session); - - ps = l2tp_session_priv(session); - mutex_lock(&ps->sk_lock); - ps->__sk = rcu_dereference_protected(ps->sk, - lockdep_is_held(&ps->sk_lock)); - RCU_INIT_POINTER(ps->sk, NULL); - mutex_unlock(&ps->sk_lock); - call_rcu(&ps->rcu, pppol2tp_put_sk); - - /* Rely on the sock_put() call at the end of the function for - * dropping the reference held by pppol2tp_sock_to_session(). - * The last reference will be dropped by pppol2tp_put_sk(). - */ + /* drop ref taken by pppol2tp_sock_to_session */ + l2tp_session_dec_refcount(session); } release_sock(sk); - /* This will delete the session context via - * pppol2tp_session_destruct() if the socket's refcnt drops to - * zero. - */ sock_put(sk); return 0; @@ -506,6 +487,7 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern) goto out; sock_init_data(sock, sk); + sock_set_flag(sk, SOCK_RCU_FREE); sock->state = SS_UNCONNECTED; sock->ops = &pppol2tp_ops; @@ -542,6 +524,7 @@ static void pppol2tp_session_init(struct l2tp_session *session) struct pppol2tp_session *ps; session->recv_skb = pppol2tp_recv; + session->session_close = pppol2tp_session_close; if (IS_ENABLED(CONFIG_L2TP_DEBUGFS)) session->show = pppol2tp_show; @@ -787,6 +770,8 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, goto end; } + drop_refcnt = true; + pppol2tp_session_init(session); ps = l2tp_session_priv(session); l2tp_session_inc_refcount(session); @@ -795,10 +780,10 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, error = l2tp_session_register(session, tunnel); if (error < 0) { mutex_unlock(&ps->sk_lock); - kfree(session); + l2tp_session_dec_refcount(session); goto end; } - drop_refcnt = true; + new_session = true; } @@ -830,12 +815,13 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, out_no_ppp: /* This is how we get the session context from the socket. */ - sk->sk_user_data = session; + sock_hold(sk); + rcu_assign_sk_user_data(sk, session); rcu_assign_pointer(ps->sk, sk); mutex_unlock(&ps->sk_lock); /* Keep the reference we've grabbed on the session: sk doesn't expect - * the session to disappear. pppol2tp_session_destruct() is responsible + * the session to disappear. pppol2tp_session_close() is responsible * for dropping it. */ drop_refcnt = false; @@ -891,7 +877,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel, return 0; err_sess: - kfree(session); + l2tp_session_dec_refcount(session); err: return error; } @@ -1002,7 +988,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr, error = len; - sock_put(sk); + l2tp_session_dec_refcount(session); end: return error; } @@ -1205,7 +1191,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk, po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ : PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; } - l2tp_session_set_header_len(session, session->tunnel->version); + l2tp_session_set_header_len(session, session->tunnel->version, + session->tunnel->encap); break; case PPPOL2TP_SO_LNSMODE: @@ -1274,7 +1261,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, err = pppol2tp_session_setsockopt(sk, session, optname, val); } - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; } @@ -1395,7 +1382,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname, err = 0; end_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; } @@ -1513,7 +1500,7 @@ static void pppol2tp_seq_tunnel_show(struct seq_file *m, void *v) seq_printf(m, "\nTUNNEL '%s', %c %d\n", tunnel->name, - (tunnel == tunnel->sock->sk_user_data) ? 'Y' : 'N', + tunnel->sock ? 'Y' : 'N', refcount_read(&tunnel->ref_count) - 1); seq_printf(m, " %08x %ld/%ld/%ld %ld/%ld/%ld\n", 0, |