diff options
Diffstat (limited to 'net/netlink/af_netlink.c')
| -rw-r--r-- | net/netlink/af_netlink.c | 313 | 
1 files changed, 120 insertions, 193 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index e6fac7e3db52..c416725d28c4 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -58,7 +58,9 @@  #include <linux/mutex.h>  #include <linux/vmalloc.h>  #include <linux/if_arp.h> +#include <linux/rhashtable.h>  #include <asm/cacheflush.h> +#include <linux/hash.h>  #include <net/net_namespace.h>  #include <net/sock.h> @@ -100,6 +102,19 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);  #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); +/* Protects netlink socket hash table mutations */ +DEFINE_MUTEX(nl_sk_hash_lock); +EXPORT_SYMBOL_GPL(nl_sk_hash_lock); + +static int lockdep_nl_sk_hash_is_held(void) +{ +#ifdef CONFIG_LOCKDEP +	return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1; +#else +	return 1; +#endif +} +  static ATOMIC_NOTIFIER_HEAD(netlink_chain);  static DEFINE_SPINLOCK(netlink_tap_lock); @@ -110,11 +125,6 @@ static inline u32 netlink_group_mask(u32 group)  	return group ? 1 << (group - 1) : 0;  } -static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid) -{ -	return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask]; -} -  int netlink_add_tap(struct netlink_tap *nt)  {  	if (unlikely(nt->dev->type != ARPHRD_NETLINK)) @@ -170,7 +180,6 @@ EXPORT_SYMBOL_GPL(netlink_remove_tap);  static bool netlink_filter_tap(const struct sk_buff *skb)  {  	struct sock *sk = skb->sk; -	bool pass = false;  	/* We take the more conservative approach and  	 * whitelist socket protocols that may pass. @@ -184,11 +193,10 @@ static bool netlink_filter_tap(const struct sk_buff *skb)  	case NETLINK_FIB_LOOKUP:  	case NETLINK_NETFILTER:  	case NETLINK_GENERIC: -		pass = true; -		break; +		return true;  	} -	return pass; +	return false;  }  static int __netlink_deliver_tap_skb(struct sk_buff *skb, @@ -205,7 +213,7 @@ static int __netlink_deliver_tap_skb(struct sk_buff *skb,  		nskb->protocol = htons((u16) sk->sk_protocol);  		nskb->pkt_type = netlink_is_kernel(sk) ?  				 PACKET_KERNEL : PACKET_USER; - +		skb_reset_network_header(nskb);  		ret = dev_queue_xmit(nskb);  		if (unlikely(ret > 0))  			ret = net_xmit_errno(ret); @@ -376,7 +384,7 @@ static int netlink_set_ring(struct sock *sk, struct nl_mmap_req *req,  		if ((int)req->nm_block_size <= 0)  			return -EINVAL; -		if (!IS_ALIGNED(req->nm_block_size, PAGE_SIZE)) +		if (!PAGE_ALIGNED(req->nm_block_size))  			return -EINVAL;  		if (req->nm_frame_size < NL_MMAP_HDRLEN)  			return -EINVAL; @@ -985,105 +993,48 @@ netlink_unlock_table(void)  		wake_up(&nl_table_wait);  } -static bool netlink_compare(struct net *net, struct sock *sk) +struct netlink_compare_arg  { -	return net_eq(sock_net(sk), net); -} - -static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) -{ -	struct netlink_table *table = &nl_table[protocol]; -	struct nl_portid_hash *hash = &table->hash; -	struct hlist_head *head; -	struct sock *sk; - -	read_lock(&nl_table_lock); -	head = nl_portid_hashfn(hash, portid); -	sk_for_each(sk, head) { -		if (table->compare(net, sk) && -		    (nlk_sk(sk)->portid == portid)) { -			sock_hold(sk); -			goto found; -		} -	} -	sk = NULL; -found: -	read_unlock(&nl_table_lock); -	return sk; -} +	struct net *net; +	u32 portid; +}; -static struct hlist_head *nl_portid_hash_zalloc(size_t size) +static bool netlink_compare(void *ptr, void *arg)  { -	if (size <= PAGE_SIZE) -		return kzalloc(size, GFP_ATOMIC); -	else -		return (struct hlist_head *) -			__get_free_pages(GFP_ATOMIC | __GFP_ZERO, -					 get_order(size)); -} +	struct netlink_compare_arg *x = arg; +	struct sock *sk = ptr; -static void nl_portid_hash_free(struct hlist_head *table, size_t size) -{ -	if (size <= PAGE_SIZE) -		kfree(table); -	else -		free_pages((unsigned long)table, get_order(size)); +	return nlk_sk(sk)->portid == x->portid && +	       net_eq(sock_net(sk), x->net);  } -static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow) +static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, +				     struct net *net)  { -	unsigned int omask, mask, shift; -	size_t osize, size; -	struct hlist_head *otable, *table; -	int i; - -	omask = mask = hash->mask; -	osize = size = (mask + 1) * sizeof(*table); -	shift = hash->shift; - -	if (grow) { -		if (++shift > hash->max_shift) -			return 0; -		mask = mask * 2 + 1; -		size *= 2; -	} - -	table = nl_portid_hash_zalloc(size); -	if (!table) -		return 0; - -	otable = hash->table; -	hash->table = table; -	hash->mask = mask; -	hash->shift = shift; -	get_random_bytes(&hash->rnd, sizeof(hash->rnd)); +	struct netlink_compare_arg arg = { +		.net = net, +		.portid = portid, +	}; +	u32 hash; -	for (i = 0; i <= omask; i++) { -		struct sock *sk; -		struct hlist_node *tmp; - -		sk_for_each_safe(sk, tmp, &otable[i]) -			__sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid)); -	} +	hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid)); -	nl_portid_hash_free(otable, osize); -	hash->rehash_time = jiffies + 10 * 60 * HZ; -	return 1; +	return rhashtable_lookup_compare(&table->hash, hash, +					 &netlink_compare, &arg);  } -static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len) +static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)  { -	int avg = hash->entries >> hash->shift; - -	if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1)) -		return 1; +	struct netlink_table *table = &nl_table[protocol]; +	struct sock *sk; -	if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) { -		nl_portid_hash_rehash(hash, 0); -		return 1; -	} +	rcu_read_lock(); +	sk = __netlink_lookup(table, portid, net); +	if (sk) +		sock_hold(sk); +	rcu_read_unlock(); -	return 0; +	return sk;  }  static const struct proto_ops netlink_ops; @@ -1115,22 +1066,10 @@ netlink_update_listeners(struct sock *sk)  static int netlink_insert(struct sock *sk, struct net *net, u32 portid)  {  	struct netlink_table *table = &nl_table[sk->sk_protocol]; -	struct nl_portid_hash *hash = &table->hash; -	struct hlist_head *head;  	int err = -EADDRINUSE; -	struct sock *osk; -	int len; -	netlink_table_grab(); -	head = nl_portid_hashfn(hash, portid); -	len = 0; -	sk_for_each(osk, head) { -		if (table->compare(net, osk) && -		    (nlk_sk(osk)->portid == portid)) -			break; -		len++; -	} -	if (osk) +	mutex_lock(&nl_sk_hash_lock); +	if (__netlink_lookup(table, portid, net))  		goto err;  	err = -EBUSY; @@ -1138,26 +1077,31 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)  		goto err;  	err = -ENOMEM; -	if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX)) +	if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))  		goto err; -	if (len && nl_portid_hash_dilute(hash, len)) -		head = nl_portid_hashfn(hash, portid); -	hash->entries++;  	nlk_sk(sk)->portid = portid; -	sk_add_node(sk, head); +	sock_hold(sk); +	rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);  	err = 0; -  err: -	netlink_table_ungrab(); +	mutex_unlock(&nl_sk_hash_lock);  	return err;  }  static void netlink_remove(struct sock *sk)  { +	struct netlink_table *table; + +	mutex_lock(&nl_sk_hash_lock); +	table = &nl_table[sk->sk_protocol]; +	if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) { +		WARN_ON(atomic_read(&sk->sk_refcnt) == 1); +		__sock_put(sk); +	} +	mutex_unlock(&nl_sk_hash_lock); +  	netlink_table_grab(); -	if (sk_del_node_init(sk)) -		nl_table[sk->sk_protocol].hash.entries--;  	if (nlk_sk(sk)->subscriptions)  		__sk_del_bind_node(sk);  	netlink_table_ungrab(); @@ -1313,6 +1257,9 @@ static int netlink_release(struct socket *sock)  	}  	netlink_table_ungrab(); +	/* Wait for readers to complete */ +	synchronize_net(); +  	kfree(nlk->groups);  	nlk->groups = NULL; @@ -1328,30 +1275,22 @@ static int netlink_autobind(struct socket *sock)  	struct sock *sk = sock->sk;  	struct net *net = sock_net(sk);  	struct netlink_table *table = &nl_table[sk->sk_protocol]; -	struct nl_portid_hash *hash = &table->hash; -	struct hlist_head *head; -	struct sock *osk;  	s32 portid = task_tgid_vnr(current);  	int err;  	static s32 rover = -4097;  retry:  	cond_resched(); -	netlink_table_grab(); -	head = nl_portid_hashfn(hash, portid); -	sk_for_each(osk, head) { -		if (!table->compare(net, osk)) -			continue; -		if (nlk_sk(osk)->portid == portid) { -			/* Bind collision, search negative portid values. */ -			portid = rover--; -			if (rover > -4097) -				rover = -4097; -			netlink_table_ungrab(); -			goto retry; -		} +	rcu_read_lock(); +	if (__netlink_lookup(table, portid, net)) { +		/* Bind collision, search negative portid values. */ +		portid = rover--; +		if (rover > -4097) +			rover = -4097; +		rcu_read_unlock(); +		goto retry;  	} -	netlink_table_ungrab(); +	rcu_read_unlock();  	err = netlink_insert(sk, net, portid);  	if (err == -EADDRINUSE) @@ -1961,25 +1900,25 @@ struct netlink_broadcast_data {  	void *tx_data;  }; -static int do_one_broadcast(struct sock *sk, -				   struct netlink_broadcast_data *p) +static void do_one_broadcast(struct sock *sk, +				    struct netlink_broadcast_data *p)  {  	struct netlink_sock *nlk = nlk_sk(sk);  	int val;  	if (p->exclude_sk == sk) -		goto out; +		return;  	if (nlk->portid == p->portid || p->group - 1 >= nlk->ngroups ||  	    !test_bit(p->group - 1, nlk->groups)) -		goto out; +		return;  	if (!net_eq(sock_net(sk), p->net)) -		goto out; +		return;  	if (p->failure) {  		netlink_overrun(sk); -		goto out; +		return;  	}  	sock_hold(sk); @@ -2017,9 +1956,6 @@ static int do_one_broadcast(struct sock *sk,  		p->skb2 = NULL;  	}  	sock_put(sk); - -out: -	return 0;  }  int netlink_broadcast_filtered(struct sock *ssk, struct sk_buff *skb, u32 portid, @@ -2958,14 +2894,18 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)  {  	struct nl_seq_iter *iter = seq->private;  	int i, j; +	struct netlink_sock *nlk;  	struct sock *s;  	loff_t off = 0;  	for (i = 0; i < MAX_LINKS; i++) { -		struct nl_portid_hash *hash = &nl_table[i].hash; +		struct rhashtable *ht = &nl_table[i].hash; +		const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); + +		for (j = 0; j < tbl->size; j++) { +			rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) { +				s = (struct sock *)nlk; -		for (j = 0; j <= hash->mask; j++) { -			sk_for_each(s, &hash->table[j]) {  				if (sock_net(s) != seq_file_net(seq))  					continue;  				if (off == pos) { @@ -2981,15 +2921,15 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)  }  static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) -	__acquires(nl_table_lock) +	__acquires(RCU)  { -	read_lock(&nl_table_lock); +	rcu_read_lock();  	return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;  }  static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)  { -	struct sock *s; +	struct netlink_sock *nlk;  	struct nl_seq_iter *iter;  	struct net *net;  	int i, j; @@ -3001,28 +2941,26 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)  	net = seq_file_net(seq);  	iter = seq->private; -	s = v; -	do { -		s = sk_next(s); -	} while (s && !nl_table[s->sk_protocol].compare(net, s)); -	if (s) -		return s; +	nlk = v; + +	rht_for_each_entry_rcu(nlk, nlk->node.next, node) +		if (net_eq(sock_net((struct sock *)nlk), net)) +			return nlk;  	i = iter->link;  	j = iter->hash_idx + 1;  	do { -		struct nl_portid_hash *hash = &nl_table[i].hash; - -		for (; j <= hash->mask; j++) { -			s = sk_head(&hash->table[j]); +		struct rhashtable *ht = &nl_table[i].hash; +		const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); -			while (s && !nl_table[s->sk_protocol].compare(net, s)) -				s = sk_next(s); -			if (s) { -				iter->link = i; -				iter->hash_idx = j; -				return s; +		for (; j < tbl->size; j++) { +			rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) { +				if (net_eq(sock_net((struct sock *)nlk), net)) { +					iter->link = i; +					iter->hash_idx = j; +					return nlk; +				}  			}  		} @@ -3033,9 +2971,9 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)  }  static void netlink_seq_stop(struct seq_file *seq, void *v) -	__releases(nl_table_lock) +	__releases(RCU)  { -	read_unlock(&nl_table_lock); +	rcu_read_unlock();  } @@ -3173,9 +3111,17 @@ static struct pernet_operations __net_initdata netlink_net_ops = {  static int __init netlink_proto_init(void)  {  	int i; -	unsigned long limit; -	unsigned int order;  	int err = proto_register(&netlink_proto, 0); +	struct rhashtable_params ht_params = { +		.head_offset = offsetof(struct netlink_sock, node), +		.key_offset = offsetof(struct netlink_sock, portid), +		.key_len = sizeof(u32), /* portid */ +		.hashfn = arch_fast_hash, +		.max_shift = 16, /* 64K */ +		.grow_decision = rht_grow_above_75, +		.shrink_decision = rht_shrink_below_30, +		.mutex_is_held = lockdep_nl_sk_hash_is_held, +	};  	if (err != 0)  		goto out; @@ -3186,32 +3132,13 @@ static int __init netlink_proto_init(void)  	if (!nl_table)  		goto panic; -	if (totalram_pages >= (128 * 1024)) -		limit = totalram_pages >> (21 - PAGE_SHIFT); -	else -		limit = totalram_pages >> (23 - PAGE_SHIFT); - -	order = get_bitmask_order(limit) - 1 + PAGE_SHIFT; -	limit = (1UL << order) / sizeof(struct hlist_head); -	order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1; -  	for (i = 0; i < MAX_LINKS; i++) { -		struct nl_portid_hash *hash = &nl_table[i].hash; - -		hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table)); -		if (!hash->table) { -			while (i-- > 0) -				nl_portid_hash_free(nl_table[i].hash.table, -						 1 * sizeof(*hash->table)); +		if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) { +			while (--i > 0) +				rhashtable_destroy(&nl_table[i].hash);  			kfree(nl_table);  			goto panic;  		} -		hash->max_shift = order; -		hash->shift = 0; -		hash->mask = 0; -		hash->rehash_time = jiffies; - -		nl_table[i].compare = netlink_compare;  	}  	INIT_LIST_HEAD(&netlink_tap_all);  |