diff options
Diffstat (limited to 'net/packet/af_packet.c')
| -rw-r--r-- | net/packet/af_packet.c | 438 | 
1 file changed, 374 insertions, 64 deletions
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c index c0c3cda19712..c698cec0a445 100644 --- a/net/packet/af_packet.c +++ b/net/packet/af_packet.c @@ -187,9 +187,11 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg);  static void packet_flush_mclist(struct sock *sk); +struct packet_fanout;  struct packet_sock {  	/* struct sock has to be the first member of packet_sock */  	struct sock		sk; +	struct packet_fanout	*fanout;  	struct tpacket_stats	stats;  	struct packet_ring_buffer	rx_ring;  	struct packet_ring_buffer	tx_ring; @@ -212,6 +214,24 @@ struct packet_sock {  	struct packet_type	prot_hook ____cacheline_aligned_in_smp;  }; +#define PACKET_FANOUT_MAX	256 + +struct packet_fanout { +#ifdef CONFIG_NET_NS +	struct net		*net; +#endif +	unsigned int		num_members; +	u16			id; +	u8			type; +	u8			defrag; +	atomic_t		rr_cur; +	struct list_head	list; +	struct sock		*arr[PACKET_FANOUT_MAX]; +	spinlock_t		lock; +	atomic_t		sk_ref; +	struct packet_type	prot_hook ____cacheline_aligned_in_smp; +}; +  struct packet_skb_cb {  	unsigned int origlen;  	union { @@ -222,6 +242,64 @@ struct packet_skb_cb {  #define PACKET_SKB_CB(__skb)	((struct packet_skb_cb *)((__skb)->cb)) +static inline struct packet_sock *pkt_sk(struct sock *sk) +{ +	return (struct packet_sock *)sk; +} + +static void __fanout_unlink(struct sock *sk, struct packet_sock *po); +static void __fanout_link(struct sock *sk, struct packet_sock *po); + +/* register_prot_hook must be invoked with the po->bind_lock held, + * or from a context in which asynchronous accesses to the packet + * socket is not possible (packet_create()). + */ +static void register_prot_hook(struct sock *sk) +{ +	struct packet_sock *po = pkt_sk(sk); +	if (!po->running) { +		if (po->fanout) +			__fanout_link(sk, po); +		else +			dev_add_pack(&po->prot_hook); +		sock_hold(sk); +		po->running = 1; +	} +} + +/* {,__}unregister_prot_hook() must be invoked with the po->bind_lock + * held.   
If the sync parameter is true, we will temporarily drop + * the po->bind_lock and do a synchronize_net to make sure no + * asynchronous packet processing paths still refer to the elements + * of po->prot_hook.  If the sync parameter is false, it is the + * callers responsibility to take care of this. + */ +static void __unregister_prot_hook(struct sock *sk, bool sync) +{ +	struct packet_sock *po = pkt_sk(sk); + +	po->running = 0; +	if (po->fanout) +		__fanout_unlink(sk, po); +	else +		__dev_remove_pack(&po->prot_hook); +	__sock_put(sk); + +	if (sync) { +		spin_unlock(&po->bind_lock); +		synchronize_net(); +		spin_lock(&po->bind_lock); +	} +} + +static void unregister_prot_hook(struct sock *sk, bool sync) +{ +	struct packet_sock *po = pkt_sk(sk); + +	if (po->running) +		__unregister_prot_hook(sk, sync); +} +  static inline __pure struct page *pgv_to_page(void *addr)  {  	if (is_vmalloc_addr(addr)) @@ -324,11 +402,6 @@ static inline void packet_increment_head(struct packet_ring_buffer *buff)  	buff->head = buff->head != buff->frame_max ? 
buff->head+1 : 0;  } -static inline struct packet_sock *pkt_sk(struct sock *sk) -{ -	return (struct packet_sock *)sk; -} -  static void packet_sock_destruct(struct sock *sk)  {  	skb_queue_purge(&sk->sk_error_queue); @@ -344,6 +417,240 @@ static void packet_sock_destruct(struct sock *sk)  	sk_refcnt_debug_dec(sk);  } +static int fanout_rr_next(struct packet_fanout *f, unsigned int num) +{ +	int x = atomic_read(&f->rr_cur) + 1; + +	if (x >= num) +		x = 0; + +	return x; +} + +static struct sock *fanout_demux_hash(struct packet_fanout *f, struct sk_buff *skb, unsigned int num) +{ +	u32 idx, hash = skb->rxhash; + +	idx = ((u64)hash * num) >> 32; + +	return f->arr[idx]; +} + +static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb, unsigned int num) +{ +	int cur, old; + +	cur = atomic_read(&f->rr_cur); +	while ((old = atomic_cmpxchg(&f->rr_cur, cur, +				     fanout_rr_next(f, num))) != cur) +		cur = old; +	return f->arr[cur]; +} + +static struct sock *fanout_demux_cpu(struct packet_fanout *f, struct sk_buff *skb, unsigned int num) +{ +	unsigned int cpu = smp_processor_id(); + +	return f->arr[cpu % num]; +} + +static struct sk_buff *fanout_check_defrag(struct sk_buff *skb) +{ +#ifdef CONFIG_INET +	const struct iphdr *iph; +	u32 len; + +	if (skb->protocol != htons(ETH_P_IP)) +		return skb; + +	if (!pskb_may_pull(skb, sizeof(struct iphdr))) +		return skb; + +	iph = ip_hdr(skb); +	if (iph->ihl < 5 || iph->version != 4) +		return skb; +	if (!pskb_may_pull(skb, iph->ihl*4)) +		return skb; +	iph = ip_hdr(skb); +	len = ntohs(iph->tot_len); +	if (skb->len < len || len < (iph->ihl * 4)) +		return skb; + +	if (ip_is_fragment(ip_hdr(skb))) { +		skb = skb_share_check(skb, GFP_ATOMIC); +		if (skb) { +			if (pskb_trim_rcsum(skb, len)) +				return skb; +			memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); +			if (ip_defrag(skb, IP_DEFRAG_AF_PACKET)) +				return NULL; +			skb->rxhash = 0; +		} +	} +#endif +	return skb; +} + +static int 
packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev, +			     struct packet_type *pt, struct net_device *orig_dev) +{ +	struct packet_fanout *f = pt->af_packet_priv; +	unsigned int num = f->num_members; +	struct packet_sock *po; +	struct sock *sk; + +	if (!net_eq(dev_net(dev), read_pnet(&f->net)) || +	    !num) { +		kfree_skb(skb); +		return 0; +	} + +	switch (f->type) { +	case PACKET_FANOUT_HASH: +	default: +		if (f->defrag) { +			skb = fanout_check_defrag(skb); +			if (!skb) +				return 0; +		} +		skb_get_rxhash(skb); +		sk = fanout_demux_hash(f, skb, num); +		break; +	case PACKET_FANOUT_LB: +		sk = fanout_demux_lb(f, skb, num); +		break; +	case PACKET_FANOUT_CPU: +		sk = fanout_demux_cpu(f, skb, num); +		break; +	} + +	po = pkt_sk(sk); + +	return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev); +} + +static DEFINE_MUTEX(fanout_mutex); +static LIST_HEAD(fanout_list); + +static void __fanout_link(struct sock *sk, struct packet_sock *po) +{ +	struct packet_fanout *f = po->fanout; + +	spin_lock(&f->lock); +	f->arr[f->num_members] = sk; +	smp_wmb(); +	f->num_members++; +	spin_unlock(&f->lock); +} + +static void __fanout_unlink(struct sock *sk, struct packet_sock *po) +{ +	struct packet_fanout *f = po->fanout; +	int i; + +	spin_lock(&f->lock); +	for (i = 0; i < f->num_members; i++) { +		if (f->arr[i] == sk) +			break; +	} +	BUG_ON(i >= f->num_members); +	f->arr[i] = f->arr[f->num_members - 1]; +	f->num_members--; +	spin_unlock(&f->lock); +} + +static int fanout_add(struct sock *sk, u16 id, u16 type_flags) +{ +	struct packet_sock *po = pkt_sk(sk); +	struct packet_fanout *f, *match; +	u8 type = type_flags & 0xff; +	u8 defrag = (type_flags & PACKET_FANOUT_FLAG_DEFRAG) ? 
1 : 0; +	int err; + +	switch (type) { +	case PACKET_FANOUT_HASH: +	case PACKET_FANOUT_LB: +	case PACKET_FANOUT_CPU: +		break; +	default: +		return -EINVAL; +	} + +	if (!po->running) +		return -EINVAL; + +	if (po->fanout) +		return -EALREADY; + +	mutex_lock(&fanout_mutex); +	match = NULL; +	list_for_each_entry(f, &fanout_list, list) { +		if (f->id == id && +		    read_pnet(&f->net) == sock_net(sk)) { +			match = f; +			break; +		} +	} +	err = -EINVAL; +	if (match && match->defrag != defrag) +		goto out; +	if (!match) { +		err = -ENOMEM; +		match = kzalloc(sizeof(*match), GFP_KERNEL); +		if (!match) +			goto out; +		write_pnet(&match->net, sock_net(sk)); +		match->id = id; +		match->type = type; +		match->defrag = defrag; +		atomic_set(&match->rr_cur, 0); +		INIT_LIST_HEAD(&match->list); +		spin_lock_init(&match->lock); +		atomic_set(&match->sk_ref, 0); +		match->prot_hook.type = po->prot_hook.type; +		match->prot_hook.dev = po->prot_hook.dev; +		match->prot_hook.func = packet_rcv_fanout; +		match->prot_hook.af_packet_priv = match; +		dev_add_pack(&match->prot_hook); +		list_add(&match->list, &fanout_list); +	} +	err = -EINVAL; +	if (match->type == type && +	    match->prot_hook.type == po->prot_hook.type && +	    match->prot_hook.dev == po->prot_hook.dev) { +		err = -ENOSPC; +		if (atomic_read(&match->sk_ref) < PACKET_FANOUT_MAX) { +			__dev_remove_pack(&po->prot_hook); +			po->fanout = match; +			atomic_inc(&match->sk_ref); +			__fanout_link(sk, po); +			err = 0; +		} +	} +out: +	mutex_unlock(&fanout_mutex); +	return err; +} + +static void fanout_release(struct sock *sk) +{ +	struct packet_sock *po = pkt_sk(sk); +	struct packet_fanout *f; + +	f = po->fanout; +	if (!f) +		return; + +	po->fanout = NULL; + +	mutex_lock(&fanout_mutex); +	if (atomic_dec_and_test(&f->sk_ref)) { +		list_del(&f->list); +		dev_remove_pack(&f->prot_hook); +		kfree(f); +	} +	mutex_unlock(&fanout_mutex); +}  static const struct proto_ops packet_ops; @@ -822,7 +1129,6 @@ static int 
tpacket_rcv(struct sk_buff *skb, struct net_device *dev,  	else  		sll->sll_ifindex = dev->ifindex; -	__packet_set_status(po, h.raw, status);  	smp_mb();  #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1  	{ @@ -831,8 +1137,10 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,  		end = (u8 *)PAGE_ALIGN((unsigned long)h.raw + macoff + snaplen);  		for (start = h.raw; start < end; start += PAGE_SIZE)  			flush_dcache_page(pgv_to_page(start)); +		smp_wmb();  	}  #endif +	__packet_set_status(po, h.raw, status);  	sk->sk_data_ready(sk, 0); @@ -975,7 +1283,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)  	struct sk_buff *skb;  	struct net_device *dev;  	__be16 proto; -	int ifindex, err, reserve = 0; +	bool need_rls_dev = false; +	int err, reserve = 0;  	void *ph;  	struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;  	int tp_len, size_max; @@ -987,7 +1296,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)  	err = -EBUSY;  	if (saddr == NULL) { -		ifindex	= po->ifindex; +		dev = po->prot_hook.dev;  		proto	= po->num;  		addr	= NULL;  	} else { @@ -998,12 +1307,12 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)  					+ offsetof(struct sockaddr_ll,  						sll_addr)))  			goto out; -		ifindex	= saddr->sll_ifindex;  		proto	= saddr->sll_protocol;  		addr	= saddr->sll_addr; +		dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex); +		need_rls_dev = true;  	} -	dev = dev_get_by_index(sock_net(&po->sk), ifindex);  	err = -ENXIO;  	if (unlikely(dev == NULL))  		goto out; @@ -1089,7 +1398,8 @@ out_status:  	__packet_set_status(po, ph, status);  	kfree_skb(skb);  out_put: -	dev_put(dev); +	if (need_rls_dev) +		dev_put(dev);  out:  	mutex_unlock(&po->pg_vec_lock);  	return err; @@ -1127,8 +1437,9 @@ static int packet_snd(struct socket *sock,  	struct sk_buff *skb;  	struct net_device *dev;  	__be16 proto; +	bool need_rls_dev = false;  	unsigned char *addr; -	int ifindex, err, 
reserve = 0; +	int err, reserve = 0;  	struct virtio_net_hdr vnet_hdr = { 0 };  	int offset = 0;  	int vnet_hdr_len; @@ -1140,7 +1451,7 @@ static int packet_snd(struct socket *sock,  	 */  	if (saddr == NULL) { -		ifindex	= po->ifindex; +		dev = po->prot_hook.dev;  		proto	= po->num;  		addr	= NULL;  	} else { @@ -1149,13 +1460,12 @@ static int packet_snd(struct socket *sock,  			goto out;  		if (msg->msg_namelen < (saddr->sll_halen + offsetof(struct sockaddr_ll, sll_addr)))  			goto out; -		ifindex	= saddr->sll_ifindex;  		proto	= saddr->sll_protocol;  		addr	= saddr->sll_addr; +		dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex); +		need_rls_dev = true;  	} - -	dev = dev_get_by_index(sock_net(sk), ifindex);  	err = -ENXIO;  	if (dev == NULL)  		goto out_unlock; @@ -1286,14 +1596,15 @@ static int packet_snd(struct socket *sock,  	if (err > 0 && (err = net_xmit_errno(err)) != 0)  		goto out_unlock; -	dev_put(dev); +	if (need_rls_dev) +		dev_put(dev);  	return len;  out_free:  	kfree_skb(skb);  out_unlock: -	if (dev) +	if (dev && need_rls_dev)  		dev_put(dev);  out:  	return err; @@ -1334,14 +1645,10 @@ static int packet_release(struct socket *sock)  	spin_unlock_bh(&net->packet.sklist_lock);  	spin_lock(&po->bind_lock); -	if (po->running) { -		/* -		 * Remove from protocol table -		 */ -		po->running = 0; -		po->num = 0; -		__dev_remove_pack(&po->prot_hook); -		__sock_put(sk); +	unregister_prot_hook(sk, false); +	if (po->prot_hook.dev) { +		dev_put(po->prot_hook.dev); +		po->prot_hook.dev = NULL;  	}  	spin_unlock(&po->bind_lock); @@ -1355,6 +1662,8 @@ static int packet_release(struct socket *sock)  	if (po->tx_ring.pg_vec)  		packet_set_ring(sk, &req, 1, 1); +	fanout_release(sk); +  	synchronize_net();  	/*  	 *	Now the socket is dead. No more input will appear. 
@@ -1378,24 +1687,18 @@ static int packet_release(struct socket *sock)  static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protocol)  {  	struct packet_sock *po = pkt_sk(sk); -	/* -	 *	Detach an existing hook if present. -	 */ + +	if (po->fanout) +		return -EINVAL;  	lock_sock(sk);  	spin_lock(&po->bind_lock); -	if (po->running) { -		__sock_put(sk); -		po->running = 0; -		po->num = 0; -		spin_unlock(&po->bind_lock); -		dev_remove_pack(&po->prot_hook); -		spin_lock(&po->bind_lock); -	} - +	unregister_prot_hook(sk, true);  	po->num = protocol;  	po->prot_hook.type = protocol; +	if (po->prot_hook.dev) +		dev_put(po->prot_hook.dev);  	po->prot_hook.dev = dev;  	po->ifindex = dev ? dev->ifindex : 0; @@ -1404,9 +1707,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc  		goto out_unlock;  	if (!dev || (dev->flags & IFF_UP)) { -		dev_add_pack(&po->prot_hook); -		sock_hold(sk); -		po->running = 1; +		register_prot_hook(sk);  	} else {  		sk->sk_err = ENETDOWN;  		if (!sock_flag(sk, SOCK_DEAD)) @@ -1440,10 +1741,8 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,  	strlcpy(name, uaddr->sa_data, sizeof(name));  	dev = dev_get_by_name(sock_net(sk), name); -	if (dev) { +	if (dev)  		err = packet_do_bind(sk, dev, pkt_sk(sk)->num); -		dev_put(dev); -	}  	return err;  } @@ -1471,8 +1770,6 @@ static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len  			goto out;  	}  	err = packet_do_bind(sk, dev, sll->sll_protocol ? 
: pkt_sk(sk)->num); -	if (dev) -		dev_put(dev);  out:  	return err; @@ -1537,9 +1834,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,  	if (proto) {  		po->prot_hook.type = proto; -		dev_add_pack(&po->prot_hook); -		sock_hold(sk); -		po->running = 1; +		register_prot_hook(sk);  	}  	spin_lock_bh(&net->packet.sklist_lock); @@ -1681,6 +1976,8 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,  			vnet_hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;  			vnet_hdr.csum_start = skb_checksum_start_offset(skb);  			vnet_hdr.csum_offset = skb->csum_offset; +		} else if (skb->ip_summed == CHECKSUM_UNNECESSARY) { +			vnet_hdr.flags = VIRTIO_NET_HDR_F_DATA_VALID;  		} /* else everything is zero */  		err = memcpy_toiovec(msg->msg_iov, (void *)&vnet_hdr, @@ -2102,6 +2399,17 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv  		po->tp_tstamp = val;  		return 0;  	} +	case PACKET_FANOUT: +	{ +		int val; + +		if (optlen != sizeof(val)) +			return -EINVAL; +		if (copy_from_user(&val, optval, sizeof(val))) +			return -EFAULT; + +		return fanout_add(sk, val & 0xffff, val >> 16); +	}  	default:  		return -ENOPROTOOPT;  	} @@ -2200,6 +2508,15 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,  		val = po->tp_tstamp;  		data = &val;  		break; +	case PACKET_FANOUT: +		if (len > sizeof(int)) +			len = sizeof(int); +		val = (po->fanout ? 
+		       ((u32)po->fanout->id | +			((u32)po->fanout->type << 16)) : +		       0); +		data = &val; +		break;  	default:  		return -ENOPROTOOPT;  	} @@ -2233,15 +2550,15 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void  			if (dev->ifindex == po->ifindex) {  				spin_lock(&po->bind_lock);  				if (po->running) { -					__dev_remove_pack(&po->prot_hook); -					__sock_put(sk); -					po->running = 0; +					__unregister_prot_hook(sk, false);  					sk->sk_err = ENETDOWN;  					if (!sock_flag(sk, SOCK_DEAD))  						sk->sk_error_report(sk);  				}  				if (msg == NETDEV_UNREGISTER) {  					po->ifindex = -1; +					if (po->prot_hook.dev) +						dev_put(po->prot_hook.dev);  					po->prot_hook.dev = NULL;  				}  				spin_unlock(&po->bind_lock); @@ -2250,11 +2567,8 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void  		case NETDEV_UP:  			if (dev->ifindex == po->ifindex) {  				spin_lock(&po->bind_lock); -				if (po->num && !po->running) { -					dev_add_pack(&po->prot_hook); -					sock_hold(sk); -					po->running = 1; -				} +				if (po->num) +					register_prot_hook(sk);  				spin_unlock(&po->bind_lock);  			}  			break; @@ -2521,10 +2835,8 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,  	was_running = po->running;  	num = po->num;  	if (was_running) { -		__dev_remove_pack(&po->prot_hook);  		po->num = 0; -		po->running = 0; -		__sock_put(sk); +		__unregister_prot_hook(sk, false);  	}  	spin_unlock(&po->bind_lock); @@ -2555,11 +2867,9 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,  	mutex_unlock(&po->pg_vec_lock);  	spin_lock(&po->bind_lock); -	if (was_running && !po->running) { -		sock_hold(sk); -		po->running = 1; +	if (was_running) {  		po->num = num; -		dev_add_pack(&po->prot_hook); +		register_prot_hook(sk);  	}  	spin_unlock(&po->bind_lock);  |