diff options
Diffstat (limited to 'net/packet/af_packet.c')
| -rw-r--r-- | net/packet/af_packet.c | 37 | 
1 files changed, 15 insertions, 22 deletions
| diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c index 737092ca9b4e..da215e5c1399 100644 --- a/net/packet/af_packet.c +++ b/net/packet/af_packet.c @@ -1687,7 +1687,6 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)  		atomic_long_set(&rollover->num, 0);  		atomic_long_set(&rollover->num_huge, 0);  		atomic_long_set(&rollover->num_failed, 0); -		po->rollover = rollover;  	}  	if (type_flags & PACKET_FANOUT_FLAG_UNIQUEID) { @@ -1745,6 +1744,8 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)  		if (refcount_read(&match->sk_ref) < PACKET_FANOUT_MAX) {  			__dev_remove_pack(&po->prot_hook);  			po->fanout = match; +			po->rollover = rollover; +			rollover = NULL;  			refcount_set(&match->sk_ref, refcount_read(&match->sk_ref) + 1);  			__fanout_link(sk, po);  			err = 0; @@ -1758,10 +1759,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)  	}  out: -	if (err && rollover) { -		kfree_rcu(rollover, rcu); -		po->rollover = NULL; -	} +	kfree(rollover);  	mutex_unlock(&fanout_mutex);  	return err;  } @@ -1785,11 +1783,6 @@ static struct packet_fanout *fanout_release(struct sock *sk)  			list_del(&f->list);  		else  			f = NULL; - -		if (po->rollover) { -			kfree_rcu(po->rollover, rcu); -			po->rollover = NULL; -		}  	}  	mutex_unlock(&fanout_mutex); @@ -3029,6 +3022,7 @@ static int packet_release(struct socket *sock)  	synchronize_net();  	if (f) { +		kfree(po->rollover);  		fanout_release_data(f);  		kfree(f);  	} @@ -3097,6 +3091,10 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,  	if (need_rehook) {  		if (po->running) {  			rcu_read_unlock(); +			/* prevents packet_notifier() from calling +			 * register_prot_hook() +			 */ +			po->num = 0;  			__unregister_prot_hook(sk, true);  			rcu_read_lock();  			dev_curr = po->prot_hook.dev; @@ -3105,6 +3103,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,  								 dev->ifindex);  		} +		BUG_ON(po->running);  		po->num = proto;  		po->prot_hook.type = proto; @@ -3843,7 +3842,6 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,  	void *data = &val;  	union tpacket_stats_u st;  	struct tpacket_rollover_stats rstats; -	struct packet_rollover *rollover;  	if (level != SOL_PACKET)  		return -ENOPROTOOPT; @@ -3922,18 +3920,13 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,  		       0);  		break;  	case PACKET_ROLLOVER_STATS: -		rcu_read_lock(); -		rollover = rcu_dereference(po->rollover); -		if (rollover) { -			rstats.tp_all = atomic_long_read(&rollover->num); -			rstats.tp_huge = atomic_long_read(&rollover->num_huge); -			rstats.tp_failed = atomic_long_read(&rollover->num_failed); -			data = &rstats; -			lv = sizeof(rstats); -		} -		rcu_read_unlock(); -		if (!rollover) +		if (!po->rollover)  			return -EINVAL; +		rstats.tp_all = atomic_long_read(&po->rollover->num); +		rstats.tp_huge = atomic_long_read(&po->rollover->num_huge); +		rstats.tp_failed = atomic_long_read(&po->rollover->num_failed); +		data = &rstats; +		lv = sizeof(rstats);  		break;  	case PACKET_TX_HAS_OFF:  		val = po->tp_tx_has_off; |