diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
| -rw-r--r-- | net/netlink/af_netlink.c | 187 |
1 files changed, 86 insertions, 101 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index ef5f77b44ec7..01b702d63457 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -97,12 +97,12 @@ static int netlink_dump(struct sock *sk); static void netlink_skb_destructor(struct sk_buff *skb); /* nl_table locking explained: - * Lookup and traversal are protected with nl_sk_hash_lock or nl_table_lock - * combined with an RCU read-side lock. Insertion and removal are protected - * with nl_sk_hash_lock while using RCU list modification primitives and may - * run in parallel to nl_table_lock protected lookups. Destruction of the - * Netlink socket may only occur *after* nl_table_lock has been acquired - * either during or after the socket has been removed from the list. + * Lookup and traversal are protected with an RCU read-side lock. Insertion + * and removal are protected with per bucket lock while using RCU list + * modification primitives and may run in parallel to RCU protected lookups. + * Destruction of the Netlink socket may only occur *after* nl_table_lock has + * been acquired * either during or after the socket has been removed from + * the list and after an RCU grace period. */ DEFINE_RWLOCK(nl_table_lock); EXPORT_SYMBOL_GPL(nl_table_lock); @@ -110,19 +110,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); -/* Protects netlink socket hash table mutations */ -DEFINE_MUTEX(nl_sk_hash_lock); -EXPORT_SYMBOL_GPL(nl_sk_hash_lock); - -#ifdef CONFIG_PROVE_LOCKING -static int lockdep_nl_sk_hash_is_held(void *parent) -{ - if (debug_locks) - return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock); - return 1; -} -#endif - static ATOMIC_NOTIFIER_HEAD(netlink_chain); static DEFINE_SPINLOCK(netlink_tap_lock); @@ -525,14 +512,14 @@ out: return err; } -static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr) +static void netlink_frame_flush_dcache(const struct nl_mmap_hdr *hdr, unsigned int nm_len) { #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1 struct page *p_start, *p_end; /* First page is flushed through netlink_{get,set}_status */ p_start = pgvec_to_page(hdr + PAGE_SIZE); - p_end = pgvec_to_page((void *)hdr + NL_MMAP_HDRLEN + hdr->nm_len - 1); + p_end = pgvec_to_page((void *)hdr + NL_MMAP_HDRLEN + nm_len - 1); while (p_start <= p_end) { flush_dcache_page(p_start); p_start++; @@ -550,9 +537,9 @@ static enum nl_mmap_status netlink_get_status(const struct nl_mmap_hdr *hdr) static void netlink_set_status(struct nl_mmap_hdr *hdr, enum nl_mmap_status status) { + smp_mb(); hdr->nm_status = status; flush_dcache_page(pgvec_to_page(hdr)); - smp_wmb(); } static struct nl_mmap_hdr * @@ -714,24 +701,16 @@ static int netlink_mmap_sendmsg(struct sock *sk, struct msghdr *msg, struct nl_mmap_hdr *hdr; struct sk_buff *skb; unsigned int maxlen; - bool excl = true; int err = 0, len = 0; - /* Netlink messages are validated by the receiver before processing. - * In order to avoid userspace changing the contents of the message - * after validation, the socket and the ring may only be used by a - * single process, otherwise we fall back to copying. - */ - if (atomic_long_read(&sk->sk_socket->file->f_count) > 1 || - atomic_read(&nlk->mapped) > 1) - excl = false; - mutex_lock(&nlk->pg_vec_lock); ring = &nlk->tx_ring; maxlen = ring->frame_size - NL_MMAP_HDRLEN; do { + unsigned int nm_len; + hdr = netlink_current_frame(ring, NL_MMAP_STATUS_VALID); if (hdr == NULL) { if (!(msg->msg_flags & MSG_DONTWAIT) && @@ -739,35 +718,23 @@ static int netlink_mmap_sendmsg(struct sock *sk, struct msghdr *msg, schedule(); continue; } - if (hdr->nm_len > maxlen) { + + nm_len = ACCESS_ONCE(hdr->nm_len); + if (nm_len > maxlen) { err = -EINVAL; goto out; } - netlink_frame_flush_dcache(hdr); + netlink_frame_flush_dcache(hdr, nm_len); - if (likely(dst_portid == 0 && dst_group == 0 && excl)) { - skb = alloc_skb_head(GFP_KERNEL); - if (skb == NULL) { - err = -ENOBUFS; - goto out; - } - sock_hold(sk); - netlink_ring_setup_skb(skb, sk, ring, hdr); - NETLINK_CB(skb).flags |= NETLINK_SKB_TX; - __skb_put(skb, hdr->nm_len); - netlink_set_status(hdr, NL_MMAP_STATUS_RESERVED); - atomic_inc(&ring->pending); - } else { - skb = alloc_skb(hdr->nm_len, GFP_KERNEL); - if (skb == NULL) { - err = -ENOBUFS; - goto out; - } - __skb_put(skb, hdr->nm_len); - memcpy(skb->data, (void *)hdr + NL_MMAP_HDRLEN, hdr->nm_len); - netlink_set_status(hdr, NL_MMAP_STATUS_UNUSED); + skb = alloc_skb(nm_len, GFP_KERNEL); + if (skb == NULL) { + err = -ENOBUFS; + goto out; } + __skb_put(skb, nm_len); + memcpy(skb->data, (void *)hdr + NL_MMAP_HDRLEN, nm_len); + netlink_set_status(hdr, NL_MMAP_STATUS_UNUSED); netlink_increment_head(ring); @@ -813,7 +780,7 @@ static void netlink_queue_mmaped_skb(struct sock *sk, struct sk_buff *skb) hdr->nm_pid = NETLINK_CB(skb).creds.pid; hdr->nm_uid = from_kuid(sk_user_ns(sk), NETLINK_CB(skb).creds.uid); hdr->nm_gid = from_kgid(sk_user_ns(sk), NETLINK_CB(skb).creds.gid); - netlink_frame_flush_dcache(hdr); + netlink_frame_flush_dcache(hdr, hdr->nm_len); netlink_set_status(hdr, NL_MMAP_STATUS_VALID); NETLINK_CB(skb).flags |= NETLINK_SKB_DELIVERED; @@ -1022,26 +989,34 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, .net = net, .portid = portid, }; - u32 hash; - - hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid)); - return rhashtable_lookup_compare(&table->hash, hash, + return rhashtable_lookup_compare(&table->hash, &portid, &netlink_compare, &arg); } +static bool __netlink_insert(struct netlink_table *table, struct sock *sk, + struct net *net) +{ + struct netlink_compare_arg arg = { + .net = net, + .portid = nlk_sk(sk)->portid, + }; + + return rhashtable_lookup_compare_insert(&table->hash, + &nlk_sk(sk)->node, + &netlink_compare, &arg); +} + static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) { struct netlink_table *table = &nl_table[protocol]; struct sock *sk; - read_lock(&nl_table_lock); rcu_read_lock(); sk = __netlink_lookup(table, portid, net); if (sk) sock_hold(sk); rcu_read_unlock(); - read_unlock(&nl_table_lock); return sk; } @@ -1077,24 +1052,25 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) struct netlink_table *table = &nl_table[sk->sk_protocol]; int err = -EADDRINUSE; - mutex_lock(&nl_sk_hash_lock); - if (__netlink_lookup(table, portid, net)) - goto err; + lock_sock(sk); err = -EBUSY; if (nlk_sk(sk)->portid) goto err; err = -ENOMEM; - if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX)) + if (BITS_PER_LONG > 32 && + unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX)) goto err; nlk_sk(sk)->portid = portid; sock_hold(sk); - rhashtable_insert(&table->hash, &nlk_sk(sk)->node); - err = 0; + if (__netlink_insert(table, sk, net)) + err = 0; + else + sock_put(sk); err: - mutex_unlock(&nl_sk_hash_lock); + release_sock(sk); return err; } @@ -1102,17 +1078,17 @@ static void netlink_remove(struct sock *sk) { struct netlink_table *table; - mutex_lock(&nl_sk_hash_lock); table = &nl_table[sk->sk_protocol]; if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) { WARN_ON(atomic_read(&sk->sk_refcnt) == 1); __sock_put(sk); } - mutex_unlock(&nl_sk_hash_lock); netlink_table_grab(); - if (nlk_sk(sk)->subscriptions) + if (nlk_sk(sk)->subscriptions) { __sk_del_bind_node(sk); + netlink_update_listeners(sk); + } netlink_table_ungrab(); } @@ -1159,8 +1135,8 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, struct module *module = NULL; struct mutex *cb_mutex; struct netlink_sock *nlk; - int (*bind)(int group); - void (*unbind)(int group); + int (*bind)(struct net *net, int group); + void (*unbind)(struct net *net, int group); int err = 0; sock->state = SS_UNCONNECTED; @@ -1212,6 +1188,13 @@ out_module: goto out; } +static void deferred_put_nlk_sk(struct rcu_head *head) +{ + struct netlink_sock *nlk = container_of(head, struct netlink_sock, rcu); + + sock_put(&nlk->sk); +} + static int netlink_release(struct socket *sock) { struct sock *sk = sock->sk; @@ -1246,8 +1229,8 @@ static int netlink_release(struct socket *sock) module_put(nlk->module); - netlink_table_grab(); if (netlink_is_kernel(sk)) { + netlink_table_grab(); BUG_ON(nl_table[sk->sk_protocol].registered == 0); if (--nl_table[sk->sk_protocol].registered == 0) { struct listeners *old; @@ -1261,18 +1244,23 @@ static int netlink_release(struct socket *sock) nl_table[sk->sk_protocol].flags = 0; nl_table[sk->sk_protocol].registered = 0; } - } else if (nlk->subscriptions) { - netlink_update_listeners(sk); + netlink_table_ungrab(); } - netlink_table_ungrab(); + if (nlk->netlink_unbind) { + int i; + + for (i = 0; i < nlk->ngroups; i++) + if (test_bit(i, nlk->groups)) + nlk->netlink_unbind(sock_net(sk), i + 1); + } kfree(nlk->groups); nlk->groups = NULL; local_bh_disable(); sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1); local_bh_enable(); - sock_put(sk); + call_rcu(&nlk->rcu, deferred_put_nlk_sk); return 0; } @@ -1287,7 +1275,6 @@ static int netlink_autobind(struct socket *sock) retry: cond_resched(); - netlink_table_grab(); rcu_read_lock(); if (__netlink_lookup(table, portid, net)) { /* Bind collision, search negative portid values. */ @@ -1295,11 +1282,9 @@ retry: if (rover > -4097) rover = -4097; rcu_read_unlock(); - netlink_table_ungrab(); goto retry; } rcu_read_unlock(); - netlink_table_ungrab(); err = netlink_insert(sk, net, portid); if (err == -EADDRINUSE) @@ -1430,9 +1415,10 @@ static int netlink_realloc_groups(struct sock *sk) return err; } -static void netlink_unbind(int group, long unsigned int groups, - struct netlink_sock *nlk) +static void netlink_undo_bind(int group, long unsigned int groups, + struct sock *sk) { + struct netlink_sock *nlk = nlk_sk(sk); int undo; if (!nlk->netlink_unbind) @@ -1440,7 +1426,7 @@ static void netlink_unbind(int group, long unsigned int groups, for (undo = 0; undo < group; undo++) if (test_bit(undo, &groups)) - nlk->netlink_unbind(undo); + nlk->netlink_unbind(sock_net(sk), undo); } static int netlink_bind(struct socket *sock, struct sockaddr *addr, @@ -1478,10 +1464,10 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, for (group = 0; group < nlk->ngroups; group++) { if (!test_bit(group, &groups)) continue; - err = nlk->netlink_bind(group); + err = nlk->netlink_bind(net, group); if (!err) continue; - netlink_unbind(group, groups, nlk); + netlink_undo_bind(group, groups, sk); return err; } } @@ -1491,7 +1477,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, netlink_insert(sk, net, nladdr->nl_pid) : netlink_autobind(sock); if (err) { - netlink_unbind(nlk->ngroups, groups, nlk); + netlink_undo_bind(nlk->ngroups, groups, sk); return err; } } @@ -2142,7 +2128,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, if (!val || val - 1 >= nlk->ngroups) return -EINVAL; if (optname == NETLINK_ADD_MEMBERSHIP && nlk->netlink_bind) { - err = nlk->netlink_bind(val); + err = nlk->netlink_bind(sock_net(sk), val); if (err) return err; } @@ -2151,7 +2137,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname, optname == NETLINK_ADD_MEMBERSHIP); netlink_table_ungrab(); if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind) - nlk->netlink_unbind(val); + nlk->netlink_unbind(sock_net(sk), val); err = 0; break; @@ -2913,7 +2899,9 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); for (j = 0; j < tbl->size; j++) { - rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) { + struct rhash_head *node; + + rht_for_each_entry_rcu(nlk, node, tbl, j, node) { s = (struct sock *)nlk; if (sock_net(s) != seq_file_net(seq)) @@ -2931,9 +2919,8 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) } static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(nl_table_lock) __acquires(RCU) + __acquires(RCU) { - read_lock(&nl_table_lock); rcu_read_lock(); return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN; } @@ -2941,6 +2928,8 @@ static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) { struct rhashtable *ht; + const struct bucket_table *tbl; + struct rhash_head *node; struct netlink_sock *nlk; struct nl_seq_iter *iter; struct net *net; @@ -2957,17 +2946,17 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) i = iter->link; ht = &nl_table[i].hash; - rht_for_each_entry(nlk, nlk->node.next, ht, node) + tbl = rht_dereference_rcu(ht->tbl, ht); + rht_for_each_entry_rcu_continue(nlk, node, nlk->node.next, tbl, iter->hash_idx, node) if (net_eq(sock_net((struct sock *)nlk), net)) return nlk; j = iter->hash_idx + 1; do { - const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); for (; j < tbl->size; j++) { - rht_for_each_entry(nlk, tbl->buckets[j], ht, node) { + rht_for_each_entry_rcu(nlk, node, tbl, j, node) { if (net_eq(sock_net((struct sock *)nlk), net)) { iter->link = i; iter->hash_idx = j; @@ -2983,10 +2972,9 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static void netlink_seq_stop(struct seq_file *seq, void *v) - __releases(RCU) __releases(nl_table_lock) + __releases(RCU) { rcu_read_unlock(); - read_unlock(&nl_table_lock); } @@ -3133,9 +3121,6 @@ static int __init netlink_proto_init(void) .max_shift = 16, /* 64K */ .grow_decision = rht_grow_above_75, .shrink_decision = rht_shrink_below_30, -#ifdef CONFIG_PROVE_LOCKING - .mutex_is_held = lockdep_nl_sk_hash_is_held, -#endif }; if (err != 0) |