diff options
Diffstat (limited to 'net/mptcp/protocol.c')
| -rw-r--r-- | net/mptcp/protocol.c | 50 | 
1 files changed, 34 insertions, 16 deletions
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index b7ad030dfe89..3ad9c46202fc 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -98,7 +98,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)  	struct socket *ssock;  	int err; -	err = mptcp_subflow_create_socket(sk, &ssock); +	err = mptcp_subflow_create_socket(sk, sk->sk_family, &ssock);  	if (err)  		return err; @@ -923,9 +923,8 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)  static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)  {  	struct mptcp_subflow_context *subflow; -	struct sock *sk = (struct sock *)msk; -	sock_owned_by_me(sk); +	msk_owned_by_me(msk);  	mptcp_for_each_subflow(msk, subflow) {  		if (READ_ONCE(subflow->data_avail)) @@ -1408,7 +1407,7 @@ static struct sock *mptcp_subflow_get_send(struct mptcp_sock *msk)  	u64 linger_time;  	long tout = 0; -	sock_owned_by_me(sk); +	msk_owned_by_me(msk);  	if (__mptcp_check_fallback(msk)) {  		if (!msk->first) @@ -1890,7 +1889,7 @@ static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)  	u32 time, advmss = 1;  	u64 rtt_us, mstamp; -	sock_owned_by_me(sk); +	msk_owned_by_me(msk);  	if (copied <= 0)  		return; @@ -2217,7 +2216,7 @@ static struct sock *mptcp_subflow_get_retrans(struct mptcp_sock *msk)  	struct mptcp_subflow_context *subflow;  	int min_stale_count = INT_MAX; -	sock_owned_by_me((const struct sock *)msk); +	msk_owned_by_me(msk);  	if (__mptcp_check_fallback(msk))  		return NULL; @@ -2724,8 +2723,8 @@ static int mptcp_init_sock(struct sock *sk)  	mptcp_ca_reset(sk);  	sk_sockets_allocated_inc(sk); -	sk->sk_rcvbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[1]); -	sk->sk_sndbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[1]); +	sk->sk_rcvbuf = READ_ONCE(net->ipv4.sysctl_tcp_rmem[1]); +	sk->sk_sndbuf = READ_ONCE(net->ipv4.sysctl_tcp_wmem[1]);  	return 0;  } @@ -2876,7 +2875,6 @@ static void __mptcp_destroy_sock(struct sock *sk)  	sk_stream_kill_queues(sk);  	xfrm_sk_free_policy(sk); -	sk_refcnt_debug_release(sk);  	sock_put(sk);  } @@ -2892,15 +2890,23 @@ static __poll_t mptcp_check_readable(struct mptcp_sock *msk)  	return EPOLLIN | EPOLLRDNORM;  } +static void mptcp_listen_inuse_dec(struct sock *sk) +{ +	if (inet_sk_state_load(sk) == TCP_LISTEN) +		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); +} +  bool __mptcp_close(struct sock *sk, long timeout)  {  	struct mptcp_subflow_context *subflow;  	struct mptcp_sock *msk = mptcp_sk(sk);  	bool do_cancel_work = false; +	int subflows_alive = 0;  	sk->sk_shutdown = SHUTDOWN_MASK;  	if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) { +		mptcp_listen_inuse_dec(sk);  		inet_sk_state_store(sk, TCP_CLOSE);  		goto cleanup;  	} @@ -2922,6 +2928,8 @@ cleanup:  		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);  		bool slow = lock_sock_fast_nested(ssk); +		subflows_alive += ssk->sk_state != TCP_CLOSE; +  		/* since the close timeout takes precedence on the fail one,  		 * cancel the latter  		 */ @@ -2937,6 +2945,12 @@ cleanup:  	}  	sock_orphan(sk); +	/* all the subflows are closed, only timeout can change the msk +	 * state, let's not keep resources busy for no reasons +	 */ +	if (subflows_alive == 0) +		inet_sk_state_store(sk, TCP_CLOSE); +  	sock_hold(sk);  	pr_debug("msk=%p state=%d", sk, sk->sk_state);  	if (msk->token) @@ -3001,6 +3015,7 @@ static int mptcp_disconnect(struct sock *sk, int flags)  	if (msk->fastopening)  		return 0; +	mptcp_listen_inuse_dec(sk);  	inet_sk_state_store(sk, TCP_CLOSE);  	mptcp_stop_timer(sk); @@ -3639,12 +3654,13 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,  static int mptcp_listen(struct socket *sock, int backlog)  {  	struct mptcp_sock *msk = mptcp_sk(sock->sk); +	struct sock *sk = sock->sk;  	struct socket *ssock;  	int err;  	pr_debug("msk=%p", msk); -	lock_sock(sock->sk); +	lock_sock(sk);  	ssock = __mptcp_nmpc_socket(msk);  	if (!ssock) {  		err = -EINVAL; @@ -3652,18 +3668,20 @@ static int mptcp_listen(struct socket *sock, int backlog)  	}  	mptcp_token_destroy(msk); -	inet_sk_state_store(sock->sk, TCP_LISTEN); -	sock_set_flag(sock->sk, SOCK_RCU_FREE); +	inet_sk_state_store(sk, TCP_LISTEN); +	sock_set_flag(sk, SOCK_RCU_FREE);  	err = ssock->ops->listen(ssock, backlog); -	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk)); -	if (!err) -		mptcp_copy_inaddrs(sock->sk, ssock->sk); +	inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); +	if (!err) { +		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); +		mptcp_copy_inaddrs(sk, ssock->sk); +	}  	mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED);  unlock: -	release_sock(sock->sk); +	release_sock(sk);  	return err;  }  |