Diffstat (limited to 'net/mptcp/protocol.c')
-rw-r--r--  net/mptcp/protocol.c  868
1 file changed, 592 insertions, 276 deletions
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index c0abe738e7d3..49b815023986 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -16,6 +16,7 @@
 #include <net/inet_hashtables.h>
 #include <net/protocol.h>
 #include <net/tcp.h>
+#include <net/tcp_states.h>
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 #include <net/transp_v6.h>
 #endif
@@ -52,18 +53,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
 	return msk->subflow;
 }
 
-static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk)
-{
-	return msk->first && !sk_is_mptcp(msk->first);
-}
-
-static struct socket *mptcp_is_tcpsk(struct sock *sk)
+static bool mptcp_is_tcpsk(struct sock *sk)
 {
 	struct socket *sock = sk->sk_socket;
 
-	if (sock->sk != sk)
-		return NULL;
-
 	if (unlikely(sk->sk_prot == &tcp_prot)) {
 		/* we are being invoked after mptcp_accept() has
 		 * accepted a non-mp-capable flow: sk is a tcp_sk,
@@ -73,59 +66,37 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
 		 * bypass mptcp.
 		 */
 		sock->ops = &inet_stream_ops;
-		return sock;
+		return true;
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 	} else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
 		sock->ops = &inet6_stream_ops;
-		return sock;
+		return true;
 #endif
 	}
 
-	return NULL;
+	return false;
 }
 
-static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
+static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk)
 {
-	struct socket *sock;
-
 	sock_owned_by_me((const struct sock *)msk);
 
-	sock = mptcp_is_tcpsk((struct sock *)msk);
-	if (unlikely(sock))
-		return sock;
-
-	if (likely(!__mptcp_needs_tcp_fallback(msk)))
+	if (likely(!__mptcp_check_fallback(msk)))
 		return NULL;
 
-	return msk->subflow;
-}
-
-static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
-{
-	return !msk->first;
+	return msk->first;
 }
 
-static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
+static int __mptcp_socket_create(struct mptcp_sock *msk)
 {
 	struct mptcp_subflow_context *subflow;
 	struct sock *sk = (struct sock *)msk;
 	struct socket *ssock;
 	int err;
 
-	ssock = __mptcp_tcp_fallback(msk);
-	if (unlikely(ssock))
-		return ssock;
-
-	ssock = __mptcp_nmpc_socket(msk);
-	if (ssock)
-		goto set_state;
-
-	if (!__mptcp_can_create_subflow(msk))
-		return ERR_PTR(-EINVAL);
-
 	err = mptcp_subflow_create_socket(sk, &ssock);
 	if (err)
-		return ERR_PTR(err);
+		return err;
 
 	msk->first = ssock->sk;
 	msk->subflow = ssock;
@@ -133,10 +104,12 @@ static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
 	list_add(&subflow->node, &msk->conn_list);
 	subflow->request_mptcp = 1;
 
-set_state:
-	if (state != MPTCP_SAME_STATE)
-		inet_sk_state_store(sk, state);
-	return ssock;
+	/* accept() will wait on first subflow sk_wq, and we always wake up
+	 * via msk->sk_socket
+	 */
+	RCU_INIT_POINTER(msk->first->sk_wq, &sk->sk_socket->wq);
+
+	return 0;
 }
 
 static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
@@ -170,6 +143,14 @@ static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
 	MPTCP_SKB_CB(skb)->offset = offset;
 }
 
+static void mptcp_stop_timer(struct sock *sk)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+
+	sk_stop_timer(sk, &icsk->icsk_retransmit_timer);
+	mptcp_sk(sk)->timer_ival = 0;
+}
+
 /* both sockets must be locked */
 static bool mptcp_subflow_dsn_valid(const struct mptcp_sock *msk,
 				    struct sock *ssk)
@@ -191,6 +172,138 @@ static bool mptcp_subflow_dsn_valid(const struct mptcp_sock *msk,
 	return mptcp_subflow_data_available(ssk);
 }
 
+static void mptcp_check_data_fin_ack(struct sock *sk)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+
+	if (__mptcp_check_fallback(msk))
+		return;
+
+	/* Look for an acknowledged DATA_FIN */
+	if (((1 << sk->sk_state) &
+	     (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK)) &&
+	    msk->write_seq == atomic64_read(&msk->snd_una)) {
+		mptcp_stop_timer(sk);
+
+		WRITE_ONCE(msk->snd_data_fin_enable, 0);
+
+		switch (sk->sk_state) {
+		case TCP_FIN_WAIT1:
+			inet_sk_state_store(sk, TCP_FIN_WAIT2);
+			sk->sk_state_change(sk);
+			break;
+		case TCP_CLOSING:
+		case TCP_LAST_ACK:
+			inet_sk_state_store(sk, TCP_CLOSE);
+			sk->sk_state_change(sk);
+			break;
+		}
+
+		if (sk->sk_shutdown == SHUTDOWN_MASK ||
+		    sk->sk_state == TCP_CLOSE)
+			sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP);
+		else
+			sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
+	}
+}
+
+static bool mptcp_pending_data_fin(struct sock *sk, u64 *seq)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+
+	if (READ_ONCE(msk->rcv_data_fin) &&
+	    ((1 << sk->sk_state) &
+	     (TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2))) {
+		u64 rcv_data_fin_seq = READ_ONCE(msk->rcv_data_fin_seq);
+
+		if (msk->ack_seq == rcv_data_fin_seq) {
+			if (seq)
+				*seq = rcv_data_fin_seq;
+
+			return true;
+		}
+	}
+
+	return false;
+}
+
+static void mptcp_set_timeout(const struct sock *sk, const struct sock *ssk)
+{
+	long tout = ssk && inet_csk(ssk)->icsk_pending ?
+				      inet_csk(ssk)->icsk_timeout - jiffies : 0;
+
+	if (tout <= 0)
+		tout = mptcp_sk(sk)->timer_ival;
+	mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN;
+}
+
+static void mptcp_check_data_fin(struct sock *sk)
+{
+	struct mptcp_sock *msk = mptcp_sk(sk);
+	u64 rcv_data_fin_seq;
+
+	if (__mptcp_check_fallback(msk) || !msk->first)
+		return;
+
+	/* Need to ack a DATA_FIN received from a peer while this side
+	 * of the connection is in ESTABLISHED, FIN_WAIT1, or FIN_WAIT2.
+	 * msk->rcv_data_fin was set when parsing the incoming options
+	 * at the subflow level and the msk lock was not held, so this
+	 * is the first opportunity to act on the DATA_FIN and change
+	 * the msk state.
+	 *
+	 * If we are caught up to the sequence number of the incoming
+	 * DATA_FIN, send the DATA_ACK now and do state transition.  If
+	 * not caught up, do nothing and let the recv code send DATA_ACK
+	 * when catching up.
+	 */
+
+	if (mptcp_pending_data_fin(sk, &rcv_data_fin_seq)) {
+		struct mptcp_subflow_context *subflow;
+
+		msk->ack_seq++;
+		WRITE_ONCE(msk->rcv_data_fin, 0);
+
+		sk->sk_shutdown |= RCV_SHUTDOWN;
+		smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
+		set_bit(MPTCP_DATA_READY, &msk->flags);
+
+		switch (sk->sk_state) {
+		case TCP_ESTABLISHED:
+			inet_sk_state_store(sk, TCP_CLOSE_WAIT);
+			break;
+		case TCP_FIN_WAIT1:
+			inet_sk_state_store(sk, TCP_CLOSING);
+			break;
+		case TCP_FIN_WAIT2:
+			inet_sk_state_store(sk, TCP_CLOSE);
+			// @@ Close subflows now?
+			break;
+		default:
+			/* Other states not expected */
+			WARN_ON_ONCE(1);
+			break;
+		}
+
+		mptcp_set_timeout(sk, NULL);
+		mptcp_for_each_subflow(msk, subflow) {
+			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
+
+			lock_sock(ssk);
+			tcp_send_ack(ssk);
+			release_sock(ssk);
+		}
+
+		sk->sk_state_change(sk);
+
+		if (sk->sk_shutdown == SHUTDOWN_MASK ||
+		    sk->sk_state == TCP_CLOSE)
+			sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP);
+		else
+			sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
+	}
+}
+
 static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
 					   struct sock *ssk,
 					   unsigned int *bytes)
@@ -207,13 +320,6 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
 		return false;
 	}
 
-	if (!(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
-		int rcvbuf = max(ssk->sk_rcvbuf, sk->sk_rcvbuf);
-
-		if (rcvbuf > sk->sk_rcvbuf)
-			sk->sk_rcvbuf = rcvbuf;
-	}
-
 	tp = tcp_sk(ssk);
 	do {
 		u32 map_remaining, offset;
@@ -229,6 +335,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
 		if (!skb)
 			break;
 
+		if (__mptcp_check_fallback(msk)) {
+			/* if we are running under the workqueue, TCP could have
+			 * collapsed skbs between dummy map creation and now;
+			 * be sure to adjust the size
+			 */
+			map_remaining = skb->len;
+			subflow->map_data_len = skb->len;
+		}
+
 		offset = seq - TCP_SKB_CB(skb)->seq;
 		fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN;
 		if (fin) {
@@ -265,6 +380,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
 
 	*bytes = moved;
 
+	/* If the moves have caught up with the DATA_FIN sequence number
+	 * it's time to ack the DATA_FIN and change socket state, but
+	 * this is not a good place to change state. Let the workqueue
+	 * do it.
+	 */
+	if (mptcp_pending_data_fin(sk, NULL) &&
+	    schedule_work(&msk->work))
+		sock_hold(sk);
+
 	return done;
 }
 
@@ -329,16 +453,6 @@ static void __mptcp_flush_join_list(struct mptcp_sock *msk)
 	spin_unlock_bh(&msk->join_list_lock);
 }
 
-static void mptcp_set_timeout(const struct sock *sk, const struct sock *ssk)
-{
-	long tout = ssk && inet_csk(ssk)->icsk_pending ?
-				      inet_csk(ssk)->icsk_timeout - jiffies : 0;
-
-	if (tout <= 0)
-		tout = mptcp_sk(sk)->timer_ival;
-	mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN;
-}
-
 static bool mptcp_timer_pending(struct sock *sk)
 {
 	return timer_pending(&inet_csk(sk)->icsk_retransmit_timer);
@@ -360,7 +474,8 @@ void mptcp_data_acked(struct sock *sk)
 {
 	mptcp_reset_timer(sk);
 
-	if (!sk_stream_is_writeable(sk) &&
+	if ((!sk_stream_is_writeable(sk) ||
+	     (inet_sk_state_load(sk) != TCP_ESTABLISHED)) &&
 	    schedule_work(&mptcp_sk(sk)->work))
 		sock_hold(sk);
 }
@@ -395,14 +510,6 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
 	}
 }
 
-static void mptcp_stop_timer(struct sock *sk)
-{
-	struct inet_connection_sock *icsk = inet_csk(sk);
-
-	sk_stop_timer(sk, &icsk->icsk_retransmit_timer);
-	mptcp_sk(sk)->timer_ival = 0;
-}
-
 static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
 {
 	const struct sock *sk = (const struct sock *)msk;
@@ -466,8 +573,15 @@ static void mptcp_clean_una(struct sock *sk)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
 	struct mptcp_data_frag *dtmp, *dfrag;
-	u64 snd_una = atomic64_read(&msk->snd_una);
 	bool cleaned = false;
+	u64 snd_una;
+
+	/* on fallback we just need to ignore snd_una, as this is really
+	 * plain TCP
+	 */
+	if (__mptcp_check_fallback(msk))
+		atomic64_set(&msk->snd_una, msk->write_seq);
+	snd_una = atomic64_read(&msk->snd_una);
 
 	list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) {
 		if (after64(dfrag->data_seq + dfrag->data_len, snd_una))
@@ -479,15 +593,20 @@ static void mptcp_clean_una(struct sock *sk)
 
 	dfrag = mptcp_rtx_head(sk);
 	if (dfrag && after64(snd_una, dfrag->data_seq)) {
-		u64 delta = dfrag->data_seq + dfrag->data_len - snd_una;
+		u64 delta = snd_una - dfrag->data_seq;
+
+		if (WARN_ON_ONCE(delta > dfrag->data_len))
+			goto out;
 
 		dfrag->data_seq += delta;
+		dfrag->offset += delta;
 		dfrag->data_len -= delta;
 
 		dfrag_uncharge(sk, delta);
 		cleaned = true;
 	}
 
+out:
 	if (cleaned) {
 		sk_mem_reclaim_partial(sk);
 
@@ -605,8 +724,10 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
 		if (!psize)
 			return -EINVAL;
 
-		if (!sk_wmem_schedule(sk, psize + dfrag->overhead))
+		if (!sk_wmem_schedule(sk, psize + dfrag->overhead)) {
+			iov_iter_revert(&msg->msg_iter, psize);
 			return -ENOMEM;
+		}
 	} else {
 		offset = dfrag->offset;
 		psize = min_t(size_t, dfrag->data_len, avail_size);
@@ -617,8 +738,11 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
 	 */
 	ret = do_tcp_sendpages(ssk, page, offset, psize,
 			       msg->msg_flags | MSG_SENDPAGE_NOTLAST | MSG_DONTWAIT);
-	if (ret <= 0)
+	if (ret <= 0) {
+		if (!retransmission)
+			iov_iter_revert(&msg->msg_iter, psize);
 		return ret;
+	}
 
 	frag_truesize += ret;
 	if (!retransmission) {
@@ -673,7 +797,7 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
 out:
 	if (!retransmission)
 		pfrag->offset += frag_truesize;
-	*write_seq += ret;
+	WRITE_ONCE(*write_seq, *write_seq + ret);
 	mptcp_subflow_ctx(ssk)->rel_write_seq += ret;
 
 	return ret;
@@ -740,7 +864,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 	int mss_now = 0, size_goal = 0, ret = 0;
 	struct mptcp_sock *msk = mptcp_sk(sk);
 	struct page_frag *pfrag;
-	struct socket *ssock;
 	size_t copied = 0;
 	struct sock *ssk;
 	bool tx_ok;
@@ -759,19 +882,15 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
 			goto out;
 	}
 
-fallback:
-	ssock = __mptcp_tcp_fallback(msk);
-	if (unlikely(ssock)) {
-		release_sock(sk);
-		pr_debug("fallback passthrough");
-		ret = sock_sendmsg(ssock, msg);
-		return ret >= 0 ? ret + copied : (copied ? copied : ret);
-	}
-
 	pfrag = sk_page_frag(sk);
 restart:
 	mptcp_clean_una(sk);
 
+	if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) {
+		ret = -EPIPE;
+		goto out;
+	}
+
 wait_for_sndbuf:
 	__mptcp_flush_join_list(msk);
 	ssk = mptcp_subflow_get_send(msk);
@@ -819,17 +938,6 @@ wait_for_sndbuf:
 			}
 			break;
 		}
 
-		if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
-			/* Can happen for passive sockets:
-			 * 3WHS negotiated MPTCP, but first packet after is
-			 * plain TCP (e.g. due to middlebox filtering unknown
-			 * options).
-			 *
-			 * Fall back to TCP.
-			 */
-			release_sock(ssk);
-			goto fallback;
-		}
 
 		copied += ret;
@@ -880,7 +988,6 @@ wait_for_sndbuf:
 	mptcp_set_timeout(sk, ssk);
 	if (copied) {
-		ret = copied;
 		tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
 			 size_goal);
 
@@ -893,7 +1000,7 @@ wait_for_sndbuf:
 	release_sock(ssk);
 out:
 	release_sock(sk);
-	return ret;
+	return copied ? : ret;
 }
 
 static void mptcp_wait_data(struct sock *sk, long *timeo)
@@ -949,6 +1056,100 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
 	return copied;
 }
 
+/* receive buffer autotuning.  See tcp_rcv_space_adjust for more information.
+ *
+ * Only difference: Use highest rtt estimate of the subflows in use.
+ */
+static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
+{
+	struct mptcp_subflow_context *subflow;
+	struct sock *sk = (struct sock *)msk;
+	u32 time, advmss = 1;
+	u64 rtt_us, mstamp;
+
+	sock_owned_by_me(sk);
+
+	if (copied <= 0)
+		return;
+
+	msk->rcvq_space.copied += copied;
+
+	mstamp = div_u64(tcp_clock_ns(), NSEC_PER_USEC);
+	time = tcp_stamp_us_delta(mstamp, msk->rcvq_space.time);
+
+	rtt_us = msk->rcvq_space.rtt_us;
+	if (rtt_us && time < (rtt_us >> 3))
+		return;
+
+	rtt_us = 0;
+	mptcp_for_each_subflow(msk, subflow) {
+		const struct tcp_sock *tp;
+		u64 sf_rtt_us;
+		u32 sf_advmss;
+
+		tp = tcp_sk(mptcp_subflow_tcp_sock(subflow));
+
+		sf_rtt_us = READ_ONCE(tp->rcv_rtt_est.rtt_us);
+		sf_advmss = READ_ONCE(tp->advmss);
+
+		rtt_us = max(sf_rtt_us, rtt_us);
+		advmss = max(sf_advmss, advmss);
+	}
+
+	msk->rcvq_space.rtt_us = rtt_us;
+	if (time < (rtt_us >> 3) || rtt_us == 0)
+		return;
+
+	if (msk->rcvq_space.copied <= msk->rcvq_space.space)
+		goto new_measure;
+
+	if (sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf &&
+	    !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
+		int rcvmem, rcvbuf;
+		u64 rcvwin, grow;
+
+		rcvwin = ((u64)msk->rcvq_space.copied << 1) + 16 * advmss;
+
+		grow = rcvwin * (msk->rcvq_space.copied - msk->rcvq_space.space);
+
+		do_div(grow, msk->rcvq_space.space);
+		rcvwin += (grow << 1);
+
+		rcvmem = SKB_TRUESIZE(advmss + MAX_TCP_HEADER);
+		while (tcp_win_from_space(sk, rcvmem) < advmss)
+			rcvmem += 128;
+
+		do_div(rcvwin, advmss);
+		rcvbuf = min_t(u64, rcvwin * rcvmem,
+			       sock_net(sk)->ipv4.sysctl_tcp_rmem[2]);
+
+		if (rcvbuf > sk->sk_rcvbuf) {
+			u32 window_clamp;
+
+			window_clamp = tcp_win_from_space(sk, rcvbuf);
+			WRITE_ONCE(sk->sk_rcvbuf, rcvbuf);
+
+			/* Make subflows follow along.  If we do not do this, we
+			 * get drops at subflow level if skbs can't be moved to
+			 * the mptcp rx queue fast enough (announced rcv_win can
+			 * exceed ssk->sk_rcvbuf).
+			 */
+			mptcp_for_each_subflow(msk, subflow) {
+				struct sock *ssk;
+
+				ssk = mptcp_subflow_tcp_sock(subflow);
+				WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf);
+				tcp_sk(ssk)->window_clamp = window_clamp;
+			}
+		}
+	}
+
+	msk->rcvq_space.space = msk->rcvq_space.copied;
+new_measure:
+	msk->rcvq_space.copied = 0;
+	msk->rcvq_space.time = mstamp;
+}
+
 static bool __mptcp_move_skbs(struct mptcp_sock *msk)
 {
 	unsigned int moved = 0;
@@ -972,7 +1173,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 			 int nonblock, int flags, int *addr_len)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
-	struct socket *ssock;
 	int copied = 0;
 	int target;
 	long timeo;
@@ -981,16 +1181,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 		return -EOPNOTSUPP;
 
 	lock_sock(sk);
-	ssock = __mptcp_tcp_fallback(msk);
-	if (unlikely(ssock)) {
-fallback:
-		release_sock(sk);
-		pr_debug("fallback-read subflow=%p",
-			 mptcp_subflow_ctx(ssock->sk));
-		copied = sock_recvmsg(ssock, msg, flags);
-		return copied;
-	}
-
 	timeo = sock_rcvtimeo(sk, nonblock);
 
 	len = min_t(size_t, len, INT_MAX);
@@ -1056,9 +1246,6 @@ fallback:
 		pr_debug("block timeout %ld", timeo);
 		mptcp_wait_data(sk, &timeo);
-		ssock = __mptcp_tcp_fallback(msk);
-		if (unlikely(ssock))
-			goto fallback;
 	}
 
 	if (skb_queue_empty(&sk->sk_receive_queue)) {
@@ -1075,6 +1262,8 @@ fallback:
 		set_bit(MPTCP_DATA_READY, &msk->flags);
 	}
 out_err:
+	mptcp_rcv_space_adjust(msk, copied);
+
 	release_sock(sk);
 	return copied;
 }
@@ -1083,7 +1272,7 @@ static void mptcp_retransmit_handler(struct sock *sk)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
 
-	if (atomic64_read(&msk->snd_una) == msk->write_seq) {
+	if (atomic64_read(&msk->snd_una) == READ_ONCE(msk->write_seq)) {
 		mptcp_stop_timer(sk);
 	} else {
 		set_bit(MPTCP_WORK_RTX, &msk->flags);
@@ -1172,6 +1361,29 @@ static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu)
 	return 0;
 }
 
+static void pm_work(struct mptcp_sock *msk)
+{
+	struct mptcp_pm_data *pm = &msk->pm;
+
+	spin_lock_bh(&msk->pm.lock);
+
+	pr_debug("msk=%p status=%x", msk, pm->status);
+	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
+		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
+		mptcp_pm_nl_add_addr_received(msk);
+	}
+	if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
+		pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
+		mptcp_pm_nl_fully_established(msk);
+	}
+	if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
+		pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
+		mptcp_pm_nl_subflow_established(msk);
+	}
+
+	spin_unlock_bh(&msk->pm.lock);
+}
+
 static void mptcp_worker(struct work_struct *work)
 {
 	struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work);
@@ -1180,17 +1392,25 @@ static void mptcp_worker(struct work_struct *work)
 	struct mptcp_data_frag *dfrag;
 	u64 orig_write_seq;
 	size_t copied = 0;
-	struct msghdr msg;
+	struct msghdr msg = {
+		.msg_flags = MSG_DONTWAIT,
+	};
 	long timeo = 0;
 
 	lock_sock(sk);
 	mptcp_clean_una(sk);
+	mptcp_check_data_fin_ack(sk);
 	__mptcp_flush_join_list(msk);
 	__mptcp_move_skbs(msk);
 
+	if (msk->pm.status)
+		pm_work(msk);
+
 	if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
 		mptcp_check_for_eof(msk);
 
+	mptcp_check_data_fin(sk);
+
 	if (!test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags))
 		goto unlock;
 
@@ -1207,7 +1427,6 @@ static void mptcp_worker(struct work_struct *work)
 
 	lock_sock(ssk);
 
-	msg.msg_flags = MSG_DONTWAIT;
 	orig_len = dfrag->data_len;
 	orig_offset = dfrag->offset;
 	orig_write_seq = dfrag->data_seq;
@@ -1283,7 +1502,12 @@ static int mptcp_init_sock(struct sock *sk)
 	if (ret)
 		return ret;
 
+	ret = __mptcp_socket_create(mptcp_sk(sk));
+	if (ret)
+		return ret;
+
 	sk_sockets_allocated_inc(sk);
+	sk->sk_rcvbuf = sock_net(sk)->ipv4.sysctl_tcp_rmem[1];
 	sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[2];
 
 	return 0;
@@ -1308,8 +1532,7 @@ static void mptcp_cancel_work(struct sock *sk)
 		sock_put(sk);
 }
 
-static void mptcp_subflow_shutdown(struct sock *ssk, int how,
-				   bool data_fin_tx_enable, u64 data_fin_tx_seq)
+static void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
 {
 	lock_sock(ssk);
 
@@ -1317,41 +1540,89 @@ static void mptcp_subflow_shutdown(struct sock *ssk, int how,
 	case TCP_LISTEN:
 		if (!(how & RCV_SHUTDOWN))
 			break;
-		/* fall through */
+		fallthrough;
 	case TCP_SYN_SENT:
 		tcp_disconnect(ssk, O_NONBLOCK);
 		break;
 	default:
-		if (data_fin_tx_enable) {
-			struct mptcp_subflow_context *subflow;
-
-			subflow = mptcp_subflow_ctx(ssk);
-			subflow->data_fin_tx_seq = data_fin_tx_seq;
-			subflow->data_fin_tx_enable = 1;
+		if (__mptcp_check_fallback(mptcp_sk(sk))) {
+			pr_debug("Fallback");
+			ssk->sk_shutdown |= how;
+			tcp_shutdown(ssk, how);
+		} else {
+			pr_debug("Sending DATA_FIN on subflow %p", ssk);
+			mptcp_set_timeout(sk, ssk);
+			tcp_send_ack(ssk);
 		}
-
-		ssk->sk_shutdown |= how;
-		tcp_shutdown(ssk, how);
 		break;
 	}
 
-	/* Wake up anyone sleeping in poll. */
-	ssk->sk_state_change(ssk);
 	release_sock(ssk);
 }
 
-/* Called with msk lock held, releases such lock before returning */
+static const unsigned char new_state[16] = {
+	/* current state:     new state:      action:	*/
+	[0 /* (Invalid) */] = TCP_CLOSE,
+	[TCP_ESTABLISHED]   = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
+	[TCP_SYN_SENT]      = TCP_CLOSE,
+	[TCP_SYN_RECV]      = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
+	[TCP_FIN_WAIT1]     = TCP_FIN_WAIT1,
+	[TCP_FIN_WAIT2]     = TCP_FIN_WAIT2,
+	[TCP_TIME_WAIT]     = TCP_CLOSE,	/* should not happen ! */
+	[TCP_CLOSE]         = TCP_CLOSE,
+	[TCP_CLOSE_WAIT]    = TCP_LAST_ACK  | TCP_ACTION_FIN,
+	[TCP_LAST_ACK]      = TCP_LAST_ACK,
+	[TCP_LISTEN]        = TCP_CLOSE,
+	[TCP_CLOSING]       = TCP_CLOSING,
+	[TCP_NEW_SYN_RECV]  = TCP_CLOSE,	/* should not happen ! */
+};
+
+static int mptcp_close_state(struct sock *sk)
+{
+	int next = (int)new_state[sk->sk_state];
+	int ns = next & TCP_STATE_MASK;
+
+	inet_sk_state_store(sk, ns);
+
+	return next & TCP_ACTION_FIN;
+}
+
 static void mptcp_close(struct sock *sk, long timeout)
 {
 	struct mptcp_subflow_context *subflow, *tmp;
 	struct mptcp_sock *msk = mptcp_sk(sk);
 	LIST_HEAD(conn_list);
-	u64 data_fin_tx_seq;
 
 	lock_sock(sk);
+	sk->sk_shutdown = SHUTDOWN_MASK;
+
+	if (sk->sk_state == TCP_LISTEN) {
+		inet_sk_state_store(sk, TCP_CLOSE);
+		goto cleanup;
+	} else if (sk->sk_state == TCP_CLOSE) {
+		goto cleanup;
+	}
+
+	if (__mptcp_check_fallback(msk)) {
+		goto update_state;
+	} else if (mptcp_close_state(sk)) {
+		pr_debug("Sending DATA_FIN sk=%p", sk);
+		WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
+		WRITE_ONCE(msk->snd_data_fin_enable, 1);
+
+		mptcp_for_each_subflow(msk, subflow) {
+			struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+
+			mptcp_subflow_shutdown(sk, tcp_sk, SHUTDOWN_MASK);
+		}
+	}
+
+	sk_stream_wait_close(sk, timeout);
+
+update_state:
 	inet_sk_state_store(sk, TCP_CLOSE);
 
+cleanup:
 	/* be sure to always acquire the join list lock, to sync vs
 	 * mptcp_finish_join().
 	 */
@@ -1360,22 +1631,16 @@ static void mptcp_close(struct sock *sk, long timeout)
 	spin_unlock_bh(&msk->join_list_lock);
 	list_splice_init(&msk->conn_list, &conn_list);
 
-	data_fin_tx_seq = msk->write_seq;
-
 	__mptcp_clear_xmit(sk);
 
 	release_sock(sk);
 
 	list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
 		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-
-		subflow->data_fin_tx_seq = data_fin_tx_seq;
-		subflow->data_fin_tx_enable = 1;
 		__mptcp_close_ssk(sk, ssk, subflow, timeout);
 	}
 
 	mptcp_cancel_work(sk);
-	mptcp_pm_close(msk);
 
 	__skb_queue_purge(&sk->sk_receive_queue);
@@ -1447,20 +1712,7 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
 	msk->local_key = subflow_req->local_key;
 	msk->token = subflow_req->token;
 	msk->subflow = NULL;
-
-	if (unlikely(mptcp_token_new_accept(subflow_req->token, nsk))) {
-		nsk->sk_state = TCP_CLOSE;
-		bh_unlock_sock(nsk);
-
-		/* we can't call into mptcp_close() here - possible BH context
-		 * free the sock directly.
-		 * sk_clone_lock() sets nsk refcnt to two, hence call sk_free()
-		 * too.
-		 */
-		sk_common_release(nsk);
-		sk_free(nsk);
-		return NULL;
-	}
+	WRITE_ONCE(msk->fully_established, false);
 
 	msk->write_seq = subflow_req->idsn + 1;
 	atomic64_set(&msk->snd_una, msk->write_seq);
@@ -1482,6 +1734,22 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
 	return nsk;
 }
 
+void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk)
+{
+	const struct tcp_sock *tp = tcp_sk(ssk);
+
+	msk->rcvq_space.copied = 0;
+	msk->rcvq_space.rtt_us = 0;
+
+	msk->rcvq_space.time = tp->tcp_mstamp;
+
+	/* initial rcv_space offering made to peer */
+	msk->rcvq_space.space = min_t(u32, tp->rcv_wnd,
+				      TCP_INIT_CWND * tp->advmss);
+	if (msk->rcvq_space.space == 0)
+		msk->rcvq_space.space = TCP_INIT_CWND * TCP_MSS_DEFAULT;
+}
+
 static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 				 bool kern)
 {
@@ -1501,7 +1769,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 		return NULL;
 
 	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
-
 	if (sk_is_mptcp(newsk)) {
 		struct mptcp_subflow_context *subflow;
 		struct sock *new_mptcp_sock;
@@ -1529,8 +1796,8 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 		newsk = new_mptcp_sock;
 		mptcp_copy_inaddrs(newsk, ssk);
 		list_add(&subflow->node, &msk->conn_list);
-		inet_sk_state_store(newsk, TCP_ESTABLISHED);
 
+		mptcp_rcv_space_init(msk, ssk);
 		bh_unlock_sock(new_mptcp_sock);
 
 		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
@@ -1547,21 +1814,82 @@ static void mptcp_destroy(struct sock *sk)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
 
-	mptcp_token_destroy(msk->token);
+	mptcp_token_destroy(msk);
 	if (msk->cached_ext)
 		__skb_ext_put(msk->cached_ext);
 
 	sk_sockets_allocated_dec(sk);
 }
 
+static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
+				       sockptr_t optval, unsigned int optlen)
+{
+	struct sock *sk = (struct sock *)msk;
+	struct socket *ssock;
+	int ret;
+
+	switch (optname) {
+	case SO_REUSEPORT:
+	case SO_REUSEADDR:
+		lock_sock(sk);
+		ssock = __mptcp_nmpc_socket(msk);
+		if (!ssock) {
+			release_sock(sk);
+			return -EINVAL;
+		}
+
+		ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen);
+		if (ret == 0) {
+			if (optname == SO_REUSEPORT)
+				sk->sk_reuseport = ssock->sk->sk_reuseport;
+			else if (optname == SO_REUSEADDR)
+				sk->sk_reuse = ssock->sk->sk_reuse;
+		}
+		release_sock(sk);
+		return ret;
+	}
+
+	return sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, optval, optlen);
+}
+
+static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
+			       sockptr_t optval, unsigned int optlen)
+{
+	struct sock *sk = (struct sock *)msk;
+	int ret = -EOPNOTSUPP;
+	struct socket *ssock;
+
+	switch (optname) {
+	case IPV6_V6ONLY:
+		lock_sock(sk);
+		ssock = __mptcp_nmpc_socket(msk);
+		if (!ssock) {
+			release_sock(sk);
+			return -EINVAL;
+		}
+
+		ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen);
+		if (ret == 0)
+			sk->sk_ipv6only = ssock->sk->sk_ipv6only;
+
+		release_sock(sk);
+		break;
+	}
+
+	return ret;
+}
+
 static int mptcp_setsockopt(struct sock *sk, int level, int optname,
-			    char __user *optval, unsigned int optlen)
+			    sockptr_t optval, unsigned int optlen)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
-	struct socket *ssock;
+	struct sock *ssk;
 
 	pr_debug("msk=%p", msk);
 
+	if (level == SOL_SOCKET)
+		return mptcp_setsockopt_sol_socket(msk, optname, optval, optlen);
+
 	/* @@ the meaning of setsockopt() when the socket is connected and
 	 * there are multiple subflows is not yet defined. It is up to the
 	 * MPTCP-level socket to configure the subflows until the subflow
@@ -1569,11 +1897,13 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname,
 	 * to the one remaining subflow.
 	 */
 	lock_sock(sk);
-	ssock = __mptcp_tcp_fallback(msk);
+	ssk = __mptcp_tcp_fallback(msk);
 	release_sock(sk);
-	if (ssock)
-		return tcp_setsockopt(ssock->sk, level, optname, optval,
-				      optlen);
+	if (ssk)
+		return tcp_setsockopt(ssk, level, optname, optval, optlen);
+
+	if (level == SOL_IPV6)
+		return mptcp_setsockopt_v6(msk, optname, optval, optlen);
 
 	return -EOPNOTSUPP;
 }
@@ -1582,7 +1912,7 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
 			    char __user *optval, int __user *option)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
-	struct socket *ssock;
+	struct sock *ssk;
 
 	pr_debug("msk=%p", msk);
 
@@ -1593,11 +1923,10 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
 	 * to the one remaining subflow.
 	 */
 	lock_sock(sk);
-	ssock = __mptcp_tcp_fallback(msk);
+	ssk = __mptcp_tcp_fallback(msk);
 	release_sock(sk);
-	if (ssock)
-		return tcp_getsockopt(ssock->sk, level, optname, optval,
-				      option);
+	if (ssk)
+		return tcp_getsockopt(ssk, level, optname, optval, option);
 
 	return -EOPNOTSUPP;
 }
@@ -1636,6 +1965,20 @@ static void mptcp_release_cb(struct sock *sk)
 	}
 }
 
+static int mptcp_hash(struct sock *sk)
+{
+	/* should never be called,
+	 * we hash the TCP subflows not the master socket
+	 */
+	WARN_ON_ONCE(1);
+	return 0;
+}
+
+static void mptcp_unhash(struct sock *sk)
+{
+	/* called from sk_common_release(), but nothing to do here */
+}
+
 static int mptcp_get_port(struct sock *sk, unsigned short snum)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
@@ -1660,32 +2003,26 @@ void mptcp_finish_connect(struct sock *ssk)
 	sk = subflow->conn;
 	msk = mptcp_sk(sk);
 
-	if (!subflow->mp_capable) {
-		MPTCP_INC_STATS(sock_net(sk),
-				MPTCP_MIB_MPCAPABLEACTIVEFALLBACK);
-		return;
-	}
-
 	pr_debug("msk=%p, token=%u", sk, subflow->token);
 
 	mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
 	ack_seq++;
 	subflow->map_seq = ack_seq;
 	subflow->map_subflow_seq = 1;
-	subflow->rel_write_seq = 1;
 
 	/* the socket is not connected yet, no msk/subflow ops can access/race
 	 * accessing the field below
 	 */
 	WRITE_ONCE(msk->remote_key, subflow->remote_key);
 	WRITE_ONCE(msk->local_key, subflow->local_key);
-	WRITE_ONCE(msk->token, subflow->token);
 	WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
 	WRITE_ONCE(msk->ack_seq, ack_seq);
 	WRITE_ONCE(msk->can_ack, 1);
 	atomic64_set(&msk->snd_una, msk->write_seq);
 
 	mptcp_pm_new_connection(msk, 0);
+
+	mptcp_rcv_space_init(msk, ssk);
 }
 
 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
@@ -1708,7 +2045,7 @@ bool mptcp_finish_join(struct sock *sk)
 	pr_debug("msk=%p, subflow=%p", msk, subflow);
 
 	/* mptcp socket already closing? */
-	if (inet_sk_state_load(parent) != TCP_ESTABLISHED)
+	if (!mptcp_is_fully_established(parent))
 		return false;
 
 	if (!msk->pm.server_side)
@@ -1761,8 +2098,8 @@ static struct proto mptcp_prot = {
 	.sendmsg	= mptcp_sendmsg,
 	.recvmsg	= mptcp_recvmsg,
 	.release_cb	= mptcp_release_cb,
-	.hash		= inet_hash,
-	.unhash		= inet_unhash,
+	.hash		= mptcp_hash,
+	.unhash		= mptcp_unhash,
 	.get_port	= mptcp_get_port,
 	.sockets_allocated	= &mptcp_sockets_allocated,
 	.memory_allocated	= &tcp_memory_allocated,
@@ -1771,6 +2108,7 @@ static struct proto mptcp_prot = {
 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_tcp_wmem),
 	.sysctl_mem	= sysctl_tcp_mem,
 	.obj_size	= sizeof(struct mptcp_sock),
+	.slab_flags	= SLAB_TYPESAFE_BY_RCU,
 	.no_autobind	= true,
 };
 
@@ -1781,9 +2119,9 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
 	int err;
 
 	lock_sock(sock->sk);
-	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
-	if (IS_ERR(ssock)) {
-		err = PTR_ERR(ssock);
+	ssock = __mptcp_nmpc_socket(msk);
+	if (!ssock) {
+		err = -EINVAL;
 		goto unlock;
 	}
 
@@ -1796,10 +2134,18 @@ unlock:
 	return err;
 }
 
+static void mptcp_subflow_early_fallback(struct mptcp_sock *msk,
+					 struct mptcp_subflow_context *subflow)
+{
+	subflow->request_mptcp = 0;
+	__mptcp_do_fallback(msk);
+}
+
 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
 				int addr_len, int flags)
 {
 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
+	struct mptcp_subflow_context *subflow;
 	struct socket *ssock;
 	int err;
 
@@ -1812,19 +2158,24 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
 		goto do_connect;
 	}
 
-	ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
-	if (IS_ERR(ssock)) {
-		err = PTR_ERR(ssock);
+	ssock = __mptcp_nmpc_socket(msk);
+	if (!ssock) {
+		err = -EINVAL;
 		goto unlock;
 	}
 
+	mptcp_token_destroy(msk);
+	inet_sk_state_store(sock->sk, TCP_SYN_SENT);
+	subflow = mptcp_subflow_ctx(ssock->sk);
 #ifdef CONFIG_TCP_MD5SIG
 	/* no MPTCP if MD5SIG is enabled on this socket or we may run out of
 	 * TCP option space.
 	 */
 	if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
-		mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
+		mptcp_subflow_early_fallback(msk, subflow);
 #endif
+	if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk))
+		mptcp_subflow_early_fallback(msk, subflow);
 
 do_connect:
 	err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
@@ -1843,42 +2194,6 @@ unlock:
 	return err;
 }
 
-static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
-			    int peer)
-{
-	if (sock->sk->sk_prot == &tcp_prot) {
-		/* we are being invoked from __sys_accept4, after
-		 * mptcp_accept() has just accepted a non-mp-capable
-		 * flow: sk is a tcp_sk, not an mptcp one.
-		 *
-		 * Hand the socket over to tcp so all further socket ops
-		 * bypass mptcp.
-		 */
-		sock->ops = &inet_stream_ops;
-	}
-
-	return inet_getname(sock, uaddr, peer);
-}
-
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
-			    int peer)
-{
-	if (sock->sk->sk_prot == &tcpv6_prot) {
-		/* we are being invoked from __sys_accept4 after
-		 * mptcp_accept() has accepted a non-mp-capable
-		 * subflow: sk is a tcp_sk, not mptcp.
-		 *
-		 * Hand the socket over to tcp so all further
-		 * socket ops bypass mptcp.
-		 */
-		sock->ops = &inet6_stream_ops;
-	}
-
-	return inet6_getname(sock, uaddr, peer);
-}
-#endif
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -1888,12 +2203,14 @@ static int mptcp_listen(struct socket *sock, int backlog)
 	pr_debug("msk=%p", msk);
 
 	lock_sock(sock->sk);
-	ssock = __mptcp_socket_create(msk, TCP_LISTEN);
-	if (IS_ERR(ssock)) {
-		err = PTR_ERR(ssock);
+	ssock = __mptcp_nmpc_socket(msk);
+	if (!ssock) {
+		err = -EINVAL;
 		goto unlock;
 	}
 
+	mptcp_token_destroy(msk);
+	inet_sk_state_store(sock->sk, TCP_LISTEN);
 	sock_set_flag(sock->sk, SOCK_RCU_FREE);
 
 	err = ssock->ops->listen(ssock, backlog);
@@ -1906,15 +2223,6 @@ unlock:
 	return err;
 }
 
-static bool is_tcp_proto(const struct proto *p)
-{
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-	return p == &tcp_prot || p == &tcpv6_prot;
-#else
-	return p == &tcp_prot;
-#endif
-}
-
 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 			       int flags, bool kern)
 {
@@ -1932,11 +2240,12 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 	if (!ssock)
 		goto unlock_fail;
 
+	clear_bit(MPTCP_DATA_READY, &msk->flags);
 	sock_hold(ssock->sk);
 	release_sock(sock->sk);
 
 	err = ssock->ops->accept(sock, newsock, flags, kern);
-	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
+	if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
 		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
 		struct mptcp_subflow_context *subflow;
 
@@ -1944,7 +2253,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 		 * This is needed so NOSPACE flag can be set from tcp stack.
 		 */
 		__mptcp_flush_join_list(msk);
-		list_for_each_entry(subflow, &msk->conn_list, node) {
+		mptcp_for_each_subflow(msk, subflow) {
 			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 
 			if (!ssk->sk_socket)
@@ -1952,6 +2261,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 		}
 	}
 
+	if (inet_csk_listen_poll(ssock->sk))
+		set_bit(MPTCP_DATA_READY, &msk->flags);
 	sock_put(ssock->sk);
 	return err;
 
@@ -1960,39 +2271,36 @@ unlock_fail:
 	return -EINVAL;
 }
 
+static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
+{
+	return test_bit(MPTCP_DATA_READY, &msk->flags) ? EPOLLIN | EPOLLRDNORM :
+	       0;
+}
+
 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
 			   struct poll_table_struct *wait)
 {
 	struct sock *sk = sock->sk;
 	struct mptcp_sock *msk;
-	struct socket *ssock;
 	__poll_t mask = 0;
+	int state;
 
 	msk = mptcp_sk(sk);
-	lock_sock(sk);
-	ssock = __mptcp_tcp_fallback(msk);
-	if (!ssock)
-		ssock = __mptcp_nmpc_socket(msk);
-	if (ssock) {
-		mask = ssock->ops->poll(file, ssock, wait);
-		release_sock(sk);
-		return mask;
-	}
-
-	release_sock(sk);
 	sock_poll_wait(file, sock, wait);
-	lock_sock(sk);
 
-	if (test_bit(MPTCP_DATA_READY, &msk->flags))
-		mask = EPOLLIN | EPOLLRDNORM;
-	if (sk_stream_is_writeable(sk) &&
-	    test_bit(MPTCP_SEND_SPACE, &msk->flags))
-		mask |= EPOLLOUT | EPOLLWRNORM;
+	state = inet_sk_state_load(sk);
+	if (state == TCP_LISTEN)
+		return mptcp_check_readable(msk);
+
+	if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
+		mask |= mptcp_check_readable(msk);
+		if (sk_stream_is_writeable(sk) &&
+		    test_bit(MPTCP_SEND_SPACE, &msk->flags))
+			mask |= EPOLLOUT | EPOLLWRNORM;
+	}
 	if (sk->sk_shutdown & RCV_SHUTDOWN)
 		mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
 
-	release_sock(sk);
-
 	return mask;
}
 
@@ -2000,23 +2308,13 @@ static int mptcp_shutdown(struct socket *sock, int how)
 {
 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
 	struct mptcp_subflow_context *subflow;
-	struct socket *ssock;
 	int ret = 0;
 
 	pr_debug("sk=%p, how=%d", msk, how);
 
 	lock_sock(sock->sk);
-	ssock = __mptcp_tcp_fallback(msk);
-	if (ssock) {
-		release_sock(sock->sk);
-		return inet_shutdown(ssock, how);
-	}
-
-	if (how == SHUT_WR || how == SHUT_RDWR)
-		inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
 
 	how++;
-
 	if ((how & ~SHUTDOWN_MASK) || !how) {
 		ret = -EINVAL;
 		goto out_unlock;
@@ -2030,13 +2328,36 @@ static int mptcp_shutdown(struct socket *sock, int how)
 			sock->state = SS_CONNECTED;
 	}
 
-	__mptcp_flush_join_list(msk);
-	mptcp_for_each_subflow(msk, subflow) {
-		struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+	/* If we've already sent a FIN, or it's a closed state, skip this. */
+	if (__mptcp_check_fallback(msk)) {
+		if (how == SHUT_WR || how == SHUT_RDWR)
+			inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
 
-		mptcp_subflow_shutdown(tcp_sk, how, 1, msk->write_seq);
+		mptcp_for_each_subflow(msk, subflow) {
+			struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+
+			mptcp_subflow_shutdown(sock->sk, tcp_sk, how);
+		}
+	} else if ((how & SEND_SHUTDOWN) &&
+		   ((1 << sock->sk->sk_state) &
+		    (TCPF_ESTABLISHED | TCPF_SYN_SENT |
+		     TCPF_SYN_RECV | TCPF_CLOSE_WAIT)) &&
+		   mptcp_close_state(sock->sk)) {
+		__mptcp_flush_join_list(msk);
+
+		WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
+		WRITE_ONCE(msk->snd_data_fin_enable, 1);
+
+		mptcp_for_each_subflow(msk, subflow) {
+			struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+
+			mptcp_subflow_shutdown(sock->sk, tcp_sk, how);
+		}
 	}
 
+	/* Wake up anyone sleeping in poll. */
+	sock->sk->sk_state_change(sock->sk);
+
 out_unlock:
 	release_sock(sock->sk);
 
@@ -2051,7 +2372,7 @@ static const struct proto_ops mptcp_stream_ops = {
 	.connect	   = mptcp_stream_connect,
 	.socketpair	   = sock_no_socketpair,
 	.accept		   = mptcp_stream_accept,
-	.getname	   = mptcp_v4_getname,
+	.getname	   = inet_getname,
 	.poll		   = mptcp_poll,
 	.ioctl		   = inet_ioctl,
 	.gettstamp	   = sock_gettstamp,
@@ -2063,10 +2384,6 @@ static const struct proto_ops mptcp_stream_ops = {
 	.recvmsg	   = inet_recvmsg,
 	.mmap		   = sock_no_mmap,
 	.sendpage	   = inet_sendpage,
-#ifdef CONFIG_COMPAT
-	.compat_setsockopt = compat_sock_common_setsockopt,
-	.compat_getsockopt = compat_sock_common_getsockopt,
-#endif
 };
 
 static struct inet_protosw mptcp_protosw = {
@@ -2077,7 +2394,7 @@ static struct inet_protosw mptcp_protosw = {
 	.flags		= INET_PROTOSW_ICSK,
 };
 
-void mptcp_proto_init(void)
+void __init mptcp_proto_init(void)
 {
 	mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
 
@@ -2086,6 +2403,7 @@ void mptcp_proto_init(void)
 
 	mptcp_subflow_init();
 	mptcp_pm_init();
+	mptcp_token_init();
 
 	if (proto_register(&mptcp_prot, 1) != 0)
 		panic("Failed to register MPTCP proto.\n");
@@ -2104,7 +2422,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
 	.connect	   = mptcp_stream_connect,
 	.socketpair	   = sock_no_socketpair,
 	.accept		   = mptcp_stream_accept,
-	.getname	   = mptcp_v6_getname,
+	.getname	   = inet6_getname,
 	.poll		   = mptcp_poll,
 	.ioctl		   = inet6_ioctl,
 	.gettstamp	   = sock_gettstamp,
@@ -2118,8 +2436,6 @@ static const struct proto_ops mptcp_v6_stream_ops = {
 	.sendpage	   = inet_sendpage,
 #ifdef CONFIG_COMPAT
 	.compat_ioctl	   = inet6_compat_ioctl,
-	.compat_setsockopt = compat_sock_common_setsockopt,
-	.compat_getsockopt = compat_sock_common_getsockopt,
 #endif
 };
 
@@ -2139,7 +2455,7 @@ static struct inet_protosw mptcp_v6_protosw = {
 	.flags		= INET_PROTOSW_ICSK,
 };
 
-int mptcp_proto_v6_init(void)
+int __init mptcp_proto_v6_init(void)
 {
 	int err;
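
The diff above packs several distinct ideas; a few hedged, self-contained sketches follow. First, mptcp_pending_data_fin() only acts on a received DATA_FIN once the receiver has caught up to its sequence number, using the classic "(1 << state) & MASK" idiom to test several TCP states at once. A minimal userspace sketch of that gate (the struct, state values, and main() harness are invented stand-ins, not kernel definitions):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

enum { TCP_ESTABLISHED = 1, TCP_FIN_WAIT1 = 4, TCP_FIN_WAIT2 = 5 };
#define TCPF(st) (1 << (st))

struct msk {                       /* tiny stand-in for mptcp_sock */
	int state;
	bool rcv_data_fin;         /* DATA_FIN seen by option parsing */
	uint64_t rcv_data_fin_seq; /* its MPTCP-level sequence number */
	uint64_t ack_seq;          /* what we have received in order */
};

static bool pending_data_fin(const struct msk *m, uint64_t *seq)
{
	if (m->rcv_data_fin &&
	    ((1 << m->state) &
	     (TCPF(TCP_ESTABLISHED) | TCPF(TCP_FIN_WAIT1) | TCPF(TCP_FIN_WAIT2))) &&
	    m->ack_seq == m->rcv_data_fin_seq) {
		if (seq)
			*seq = m->rcv_data_fin_seq;
		return true;       /* caught up: ack the DATA_FIN now */
	}
	return false;              /* not caught up (or no DATA_FIN):
				    * the recv path will ack it later */
}

int main(void)
{
	struct msk m = { TCP_ESTABLISHED, true, 5000, 4000 };

	printf("%d\n", pending_data_fin(&m, NULL)); /* 0: 1000 bytes behind */
	m.ack_seq = 5000;
	printf("%d\n", pending_data_fin(&m, NULL)); /* 1: time to ack */
	return 0;
}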
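mptcp_set_timeout(), moved earlier in the file so the DATA_FIN paths can reach it, derives the MPTCP-level retransmit interval from a subflow's armed TCP timer when one exists, and never goes below TCP_RTO_MIN. A sketch of just that selection, assuming a caller that passes "icsk_timeout - jiffies" (or 0) and the previous interval; the TCP_RTO_MIN value here is a stand-in:

#include <stdio.h>

#define TCP_RTO_MIN 200            /* stand-in: ~200ms in jiffies at HZ=1000 */

/* timer_ival selection as in mptcp_set_timeout(); pending_timeout is
 * icsk_timeout - jiffies when the subflow has an armed timer, else 0.
 */
static long pick_timer_ival(long pending_timeout, long cur_ival)
{
	long tout = pending_timeout;

	if (tout <= 0)
		tout = cur_ival;                /* keep the previous interval */
	return tout > 0 ? tout : TCP_RTO_MIN;   /* but never below RTO_MIN */
}

int main(void)
{
	printf("%ld\n", pick_timer_ival(350, 0)); /* follow the subflow: 350 */
	printf("%ld\n", pick_timer_ival(0, 0));   /* nothing armed: 200 */
	return 0;
}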
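The mptcp_clean_una() hunk fixes the partial-ack trim: the old code computed "dfrag->data_seq + dfrag->data_len - snd_una" (the bytes still in flight) where it needed "snd_una - dfrag->data_seq" (the bytes actually acknowledged), and it also starts advancing dfrag->offset together with data_seq. A self-contained sketch of the corrected arithmetic (the struct and sample numbers are invented for illustration):

#include <stdint.h>
#include <stdio.h>

/* simplified stand-in for struct mptcp_data_frag */
struct data_frag {
	uint64_t data_seq;   /* MPTCP-level sequence of the first byte */
	uint32_t offset;     /* offset of that byte inside the page */
	uint32_t data_len;   /* bytes still queued for (re)transmit */
};

/* drop the prefix of dfrag that snd_una has already acknowledged */
static void trim_acked(struct data_frag *dfrag, uint64_t snd_una)
{
	uint64_t delta;

	if (snd_una <= dfrag->data_seq)      /* nothing of it acked yet */
		return;

	delta = snd_una - dfrag->data_seq;   /* bytes acked, per the fix */
	if (delta > dfrag->data_len)         /* WARN_ON_ONCE() in the hunk */
		delta = dfrag->data_len;

	dfrag->data_seq += delta;
	dfrag->offset += delta;              /* also added by this change */
	dfrag->data_len -= delta;
}

int main(void)
{
	struct data_frag f = { .data_seq = 1000, .offset = 0, .data_len = 500 };

	trim_acked(&f, 1200);                /* peer acked 200 of 500 bytes */
	printf("seq=%llu off=%u len=%u\n",
	       (unsigned long long)f.data_seq, f.offset, f.data_len);
	return 0;                            /* prints seq=1200 off=200 len=300 */
}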
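The new mptcp_rcv_space_adjust() reuses tcp_rcv_space_adjust()'s growth formula, feeding it the largest RTT estimate and advmss across all subflows. The sizing arithmetic in isolation, as a hedged sketch: the kernel's do_div() calls become plain 64-bit division, the SKB_TRUESIZE estimate is replaced by a flat per-skb overhead, and the constants and sample inputs below are assumptions, not kernel values:

#include <stdint.h>
#include <stdio.h>

/* mirror of the rcvbuf sizing math in mptcp_rcv_space_adjust() */
static uint64_t size_rcvbuf(uint64_t copied, uint64_t prev_space,
			    uint32_t advmss, uint64_t rmem_max)
{
	uint64_t rcvwin, grow, rcvmem;

	/* window that would have held what we copied in one RTT, doubled */
	rcvwin = (copied << 1) + 16 * advmss;

	/* grow proportionally to the measured increase vs the last window */
	grow = rcvwin * (copied - prev_space);
	grow /= prev_space;
	rcvwin += grow << 1;

	/* rough truesize per advmss-sized skb (stand-in for SKB_TRUESIZE) */
	rcvmem = advmss + 256;

	rcvwin /= advmss;                 /* number of skbs in the window */
	rcvwin *= rcvmem;                 /* bytes of memory they cost */
	return rcvwin < rmem_max ? rcvwin : rmem_max;
}

int main(void)
{
	/* copied 128 KB this RTT, previous measure 64 KB, 1460-byte MSS */
	printf("rcvbuf -> %llu bytes\n",
	       (unsigned long long)size_rcvbuf(131072, 65536, 1460, 6291456));
	return 0;
}

Note how the hunk then propagates the grown rcvbuf and window_clamp to every subflow, so the announced MPTCP-level window can never exceed what a subflow is willing to queue.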
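The close path now reuses TCP's table-driven shutdown logic: new_state[] maps the current socket state to its successor, and the TCP_ACTION_FIN bit tells mptcp_close()/mptcp_shutdown() whether a DATA_FIN must be sent. A minimal userspace sketch of the same pattern (the table values mirror the hunk above; the state numbering, the TCP_ACTION_FIN/TCP_STATE_MASK values, and the main() harness are assumptions made for illustration):

#include <stdio.h>

/* state numbering and action flag modeled on include/net/tcp_states.h */
enum { TCP_ESTABLISHED = 1, TCP_SYN_SENT, TCP_SYN_RECV, TCP_FIN_WAIT1,
       TCP_FIN_WAIT2, TCP_TIME_WAIT, TCP_CLOSE, TCP_CLOSE_WAIT,
       TCP_LAST_ACK, TCP_LISTEN, TCP_CLOSING, TCP_NEW_SYN_RECV };
#define TCP_ACTION_FIN (1 << 7)
#define TCP_STATE_MASK 0xF

static const unsigned char new_state[16] = {
	[0]                = TCP_CLOSE,
	[TCP_ESTABLISHED]  = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
	[TCP_SYN_SENT]     = TCP_CLOSE,
	[TCP_SYN_RECV]     = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
	[TCP_FIN_WAIT1]    = TCP_FIN_WAIT1,
	[TCP_FIN_WAIT2]    = TCP_FIN_WAIT2,
	[TCP_TIME_WAIT]    = TCP_CLOSE,
	[TCP_CLOSE]        = TCP_CLOSE,
	[TCP_CLOSE_WAIT]   = TCP_LAST_ACK | TCP_ACTION_FIN,
	[TCP_LAST_ACK]     = TCP_LAST_ACK,
	[TCP_LISTEN]       = TCP_CLOSE,
	[TCP_CLOSING]      = TCP_CLOSING,
	[TCP_NEW_SYN_RECV] = TCP_CLOSE,
};

/* returns non-zero when the transition requires sending a (DATA_)FIN */
static int close_state(int *state)
{
	int next = new_state[*state];

	*state = next & TCP_STATE_MASK;
	return next & TCP_ACTION_FIN;
}

int main(void)
{
	int st = TCP_ESTABLISHED;

	if (close_state(&st))            /* ESTABLISHED -> FIN_WAIT1 */
		printf("send DATA_FIN, now in state %d\n", st);
	return 0;
}

Packing the next state and the "send FIN" action into one byte keeps the transition table trivially auditable against RFC 793's diagram, which is presumably why the MPTCP code copies the scheme rather than open-coding the switch.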
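Finally, mptcp_shutdown() keeps the classic "how++" trick visible in the hunk above: SHUT_RD/SHUT_WR/SHUT_RDWR are 0/1/2, so adding one turns them into the RCV_SHUTDOWN/SEND_SHUTDOWN bitmask (1/2/3) that the rest of the stack tests with "how & SEND_SHUTDOWN". A sketch of the validation, assuming the usual constant values (the return convention here is invented; the kernel returns -EINVAL):

#include <stdio.h>

#define RCV_SHUTDOWN   1
#define SEND_SHUTDOWN  2
#define SHUTDOWN_MASK  3

static int shutdown_bits(int how)
{
	how++;                            /* SHUT_RD=0 -> 1, SHUT_WR=1 -> 2,
					   * SHUT_RDWR=2 -> 3 */
	if ((how & ~SHUTDOWN_MASK) || !how)
		return -1;                /* -EINVAL in the kernel */
	return how;
}

int main(void)
{
	printf("SHUT_RD   -> %d\n", shutdown_bits(0)); /* 1: RCV_SHUTDOWN */
	printf("SHUT_WR   -> %d\n", shutdown_bits(1)); /* 2: SEND_SHUTDOWN */
	printf("SHUT_RDWR -> %d\n", shutdown_bits(2)); /* 3: both */
	printf("bogus     -> %d\n", shutdown_bits(7)); /* -1: rejected */
	return 0;
}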