diff options
Diffstat (limited to 'net/packet/af_packet.c')
| -rw-r--r-- | net/packet/af_packet.c | 52 | 
1 files changed, 35 insertions, 17 deletions
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c index c26172995511..2986941164b1 100644 --- a/net/packet/af_packet.c +++ b/net/packet/af_packet.c @@ -1684,10 +1684,6 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)  	mutex_lock(&fanout_mutex); -	err = -EINVAL; -	if (!po->running) -		goto out; -  	err = -EALREADY;  	if (po->fanout)  		goto out; @@ -1749,7 +1745,10 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)  		list_add(&match->list, &fanout_list);  	}  	err = -EINVAL; -	if (match->type == type && + +	spin_lock(&po->bind_lock); +	if (po->running && +	    match->type == type &&  	    match->prot_hook.type == po->prot_hook.type &&  	    match->prot_hook.dev == po->prot_hook.dev) {  		err = -ENOSPC; @@ -1761,9 +1760,16 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)  			err = 0;  		}  	} +	spin_unlock(&po->bind_lock); + +	if (err && !refcount_read(&match->sk_ref)) { +		list_del(&match->list); +		kfree(match); +	} +  out:  	if (err && rollover) { -		kfree(rollover); +		kfree_rcu(rollover, rcu);  		po->rollover = NULL;  	}  	mutex_unlock(&fanout_mutex); @@ -1790,8 +1796,10 @@ static struct packet_fanout *fanout_release(struct sock *sk)  		else  			f = NULL; -		if (po->rollover) +		if (po->rollover) {  			kfree_rcu(po->rollover, rcu); +			po->rollover = NULL; +		}  	}  	mutex_unlock(&fanout_mutex); @@ -2834,6 +2842,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)  	struct virtio_net_hdr vnet_hdr = { 0 };  	int offset = 0;  	struct packet_sock *po = pkt_sk(sk); +	bool has_vnet_hdr = false;  	int hlen, tlen, linear;  	int extra_len = 0; @@ -2877,6 +2886,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)  		err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);  		if (err)  			goto out_unlock; +		has_vnet_hdr = true;  	}  	if (unlikely(sock_flag(sk, SOCK_NOFCS))) { @@ -2935,7 +2945,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)  	skb->priority = sk->sk_priority;  	skb->mark = sockc.mark; -	if (po->has_vnet_hdr) { +	if (has_vnet_hdr) {  		err = virtio_net_hdr_to_skb(skb, &vnet_hdr, vio_le());  		if (err)  			goto out_free; @@ -3063,13 +3073,15 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,  	int ret = 0;  	bool unlisted = false; -	if (po->fanout) -		return -EINVAL; -  	lock_sock(sk);  	spin_lock(&po->bind_lock);  	rcu_read_lock(); +	if (po->fanout) { +		ret = -EINVAL; +		goto out_unlock; +	} +  	if (name) {  		dev = dev_get_by_name_rcu(sock_net(sk), name);  		if (!dev) { @@ -3841,6 +3853,7 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,  	void *data = &val;  	union tpacket_stats_u st;  	struct tpacket_rollover_stats rstats; +	struct packet_rollover *rollover;  	if (level != SOL_PACKET)  		return -ENOPROTOOPT; @@ -3919,13 +3932,18 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,  		       0);  		break;  	case PACKET_ROLLOVER_STATS: -		if (!po->rollover) +		rcu_read_lock(); +		rollover = rcu_dereference(po->rollover); +		if (rollover) { +			rstats.tp_all = atomic_long_read(&rollover->num); +			rstats.tp_huge = atomic_long_read(&rollover->num_huge); +			rstats.tp_failed = atomic_long_read(&rollover->num_failed); +			data = &rstats; +			lv = sizeof(rstats); +		} +		rcu_read_unlock(); +		if (!rollover)  			return -EINVAL; -		rstats.tp_all = atomic_long_read(&po->rollover->num); -		rstats.tp_huge = atomic_long_read(&po->rollover->num_huge); -		rstats.tp_failed = atomic_long_read(&po->rollover->num_failed); -		data = &rstats; -		lv = sizeof(rstats);  		break;  	case PACKET_TX_HAS_OFF:  		val = po->tp_tx_has_off;  |