diff options
Diffstat (limited to 'net/mptcp/protocol.c')
-rw-r--r-- | net/mptcp/protocol.c | 435 |
1 files changed, 225 insertions, 210 deletions
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 3980fbb6f31e..fa137a9c42d1 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -52,18 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk) return msk->subflow; } -static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk) -{ - return msk->first && !sk_is_mptcp(msk->first); -} - -static struct socket *mptcp_is_tcpsk(struct sock *sk) +static bool mptcp_is_tcpsk(struct sock *sk) { struct socket *sock = sk->sk_socket; - if (sock->sk != sk) - return NULL; - if (unlikely(sk->sk_prot == &tcp_prot)) { /* we are being invoked after mptcp_accept() has * accepted a non-mp-capable flow: sk is a tcp_sk, @@ -73,59 +65,37 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk) * bypass mptcp. */ sock->ops = &inet_stream_ops; - return sock; + return true; #if IS_ENABLED(CONFIG_MPTCP_IPV6) } else if (unlikely(sk->sk_prot == &tcpv6_prot)) { sock->ops = &inet6_stream_ops; - return sock; + return true; #endif } - return NULL; + return false; } -static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk) +static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk) { - struct socket *sock; - sock_owned_by_me((const struct sock *)msk); - sock = mptcp_is_tcpsk((struct sock *)msk); - if (unlikely(sock)) - return sock; - - if (likely(!__mptcp_needs_tcp_fallback(msk))) + if (likely(!__mptcp_check_fallback(msk))) return NULL; - return msk->subflow; -} - -static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk) -{ - return !msk->first; + return msk->first; } -static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state) +static int __mptcp_socket_create(struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow; struct sock *sk = (struct sock *)msk; struct socket *ssock; int err; - ssock = __mptcp_tcp_fallback(msk); - if (unlikely(ssock)) - return ssock; - - ssock = __mptcp_nmpc_socket(msk); - if (ssock) - goto set_state; - - if (!__mptcp_can_create_subflow(msk)) - return ERR_PTR(-EINVAL); - err = mptcp_subflow_create_socket(sk, &ssock); if (err) - return ERR_PTR(err); + return err; msk->first = ssock->sk; msk->subflow = ssock; @@ -133,10 +103,12 @@ static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state) list_add(&subflow->node, &msk->conn_list); subflow->request_mptcp = 1; -set_state: - if (state != MPTCP_SAME_STATE) - inet_sk_state_store(sk, state); - return ssock; + /* accept() will wait on first subflow sk_wq, and we always wakes up + * via msk->sk_socket + */ + RCU_INIT_POINTER(msk->first->sk_wq, &sk->sk_socket->wq); + + return 0; } static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk, @@ -207,13 +179,6 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk, return false; } - if (!(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) { - int rcvbuf = max(ssk->sk_rcvbuf, sk->sk_rcvbuf); - - if (rcvbuf > sk->sk_rcvbuf) - sk->sk_rcvbuf = rcvbuf; - } - tp = tcp_sk(ssk); do { u32 map_remaining, offset; @@ -229,6 +194,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk, if (!skb) break; + if (__mptcp_check_fallback(msk)) { + /* if we are running under the workqueue, TCP could have + * collapsed skbs between dummy map creation and now + * be sure to adjust the size + */ + map_remaining = skb->len; + subflow->map_data_len = skb->len; + } + offset = seq - TCP_SKB_CB(skb)->seq; fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN; if (fin) { @@ -466,8 +440,15 @@ static void mptcp_clean_una(struct sock *sk) { struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_data_frag *dtmp, *dfrag; - u64 snd_una = atomic64_read(&msk->snd_una); bool cleaned = false; + u64 snd_una; + + /* on fallback we just need to ignore snd_una, as this is really + * plain TCP + */ + if (__mptcp_check_fallback(msk)) + atomic64_set(&msk->snd_una, msk->write_seq); + snd_una = atomic64_read(&msk->snd_una); list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) { if (after64(dfrag->data_seq + dfrag->data_len, snd_una)) @@ -740,7 +721,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) int mss_now = 0, size_goal = 0, ret = 0; struct mptcp_sock *msk = mptcp_sk(sk); struct page_frag *pfrag; - struct socket *ssock; size_t copied = 0; struct sock *ssk; bool tx_ok; @@ -759,15 +739,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) goto out; } -fallback: - ssock = __mptcp_tcp_fallback(msk); - if (unlikely(ssock)) { - release_sock(sk); - pr_debug("fallback passthrough"); - ret = sock_sendmsg(ssock, msg); - return ret >= 0 ? ret + copied : (copied ? copied : ret); - } - pfrag = sk_page_frag(sk); restart: mptcp_clean_una(sk); @@ -819,17 +790,6 @@ wait_for_sndbuf: } break; } - if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) { - /* Can happen for passive sockets: - * 3WHS negotiated MPTCP, but first packet after is - * plain TCP (e.g. due to middlebox filtering unknown - * options). - * - * Fall back to TCP. - */ - release_sock(ssk); - goto fallback; - } copied += ret; @@ -949,6 +909,100 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk, return copied; } +/* receive buffer autotuning. See tcp_rcv_space_adjust for more information. + * + * Only difference: Use highest rtt estimate of the subflows in use. + */ +static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + u32 time, advmss = 1; + u64 rtt_us, mstamp; + + sock_owned_by_me(sk); + + if (copied <= 0) + return; + + msk->rcvq_space.copied += copied; + + mstamp = div_u64(tcp_clock_ns(), NSEC_PER_USEC); + time = tcp_stamp_us_delta(mstamp, msk->rcvq_space.time); + + rtt_us = msk->rcvq_space.rtt_us; + if (rtt_us && time < (rtt_us >> 3)) + return; + + rtt_us = 0; + mptcp_for_each_subflow(msk, subflow) { + const struct tcp_sock *tp; + u64 sf_rtt_us; + u32 sf_advmss; + + tp = tcp_sk(mptcp_subflow_tcp_sock(subflow)); + + sf_rtt_us = READ_ONCE(tp->rcv_rtt_est.rtt_us); + sf_advmss = READ_ONCE(tp->advmss); + + rtt_us = max(sf_rtt_us, rtt_us); + advmss = max(sf_advmss, advmss); + } + + msk->rcvq_space.rtt_us = rtt_us; + if (time < (rtt_us >> 3) || rtt_us == 0) + return; + + if (msk->rcvq_space.copied <= msk->rcvq_space.space) + goto new_measure; + + if (sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf && + !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) { + int rcvmem, rcvbuf; + u64 rcvwin, grow; + + rcvwin = ((u64)msk->rcvq_space.copied << 1) + 16 * advmss; + + grow = rcvwin * (msk->rcvq_space.copied - msk->rcvq_space.space); + + do_div(grow, msk->rcvq_space.space); + rcvwin += (grow << 1); + + rcvmem = SKB_TRUESIZE(advmss + MAX_TCP_HEADER); + while (tcp_win_from_space(sk, rcvmem) < advmss) + rcvmem += 128; + + do_div(rcvwin, advmss); + rcvbuf = min_t(u64, rcvwin * rcvmem, + sock_net(sk)->ipv4.sysctl_tcp_rmem[2]); + + if (rcvbuf > sk->sk_rcvbuf) { + u32 window_clamp; + + window_clamp = tcp_win_from_space(sk, rcvbuf); + WRITE_ONCE(sk->sk_rcvbuf, rcvbuf); + + /* Make subflows follow along. If we do not do this, we + * get drops at subflow level if skbs can't be moved to + * the mptcp rx queue fast enough (announced rcv_win can + * exceed ssk->sk_rcvbuf). + */ + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk; + + ssk = mptcp_subflow_tcp_sock(subflow); + WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf); + tcp_sk(ssk)->window_clamp = window_clamp; + } + } + } + + msk->rcvq_space.space = msk->rcvq_space.copied; +new_measure: + msk->rcvq_space.copied = 0; + msk->rcvq_space.time = mstamp; +} + static bool __mptcp_move_skbs(struct mptcp_sock *msk) { unsigned int moved = 0; @@ -972,7 +1026,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, int flags, int *addr_len) { struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *ssock; int copied = 0; int target; long timeo; @@ -981,16 +1034,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, return -EOPNOTSUPP; lock_sock(sk); - ssock = __mptcp_tcp_fallback(msk); - if (unlikely(ssock)) { -fallback: - release_sock(sk); - pr_debug("fallback-read subflow=%p", - mptcp_subflow_ctx(ssock->sk)); - copied = sock_recvmsg(ssock, msg, flags); - return copied; - } - timeo = sock_rcvtimeo(sk, nonblock); len = min_t(size_t, len, INT_MAX); @@ -1056,9 +1099,6 @@ fallback: pr_debug("block timeout %ld", timeo); mptcp_wait_data(sk, &timeo); - ssock = __mptcp_tcp_fallback(msk); - if (unlikely(ssock)) - goto fallback; } if (skb_queue_empty(&sk->sk_receive_queue)) { @@ -1075,6 +1115,8 @@ fallback: set_bit(MPTCP_DATA_READY, &msk->flags); } out_err: + mptcp_rcv_space_adjust(msk, copied); + release_sock(sk); return copied; } @@ -1283,7 +1325,12 @@ static int mptcp_init_sock(struct sock *sk) if (ret) return ret; + ret = __mptcp_socket_create(mptcp_sk(sk)); + if (ret) + return ret; + sk_sockets_allocated_inc(sk); + sk->sk_rcvbuf = sock_net(sk)->ipv4.sysctl_tcp_rmem[1]; sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[2]; return 0; @@ -1335,8 +1382,6 @@ static void mptcp_subflow_shutdown(struct sock *ssk, int how, break; } - /* Wake up anyone sleeping in poll. */ - ssk->sk_state_change(ssk); release_sock(ssk); } @@ -1448,20 +1493,6 @@ struct sock *mptcp_sk_clone(const struct sock *sk, msk->token = subflow_req->token; msk->subflow = NULL; - if (unlikely(mptcp_token_new_accept(subflow_req->token, nsk))) { - nsk->sk_state = TCP_CLOSE; - bh_unlock_sock(nsk); - - /* we can't call into mptcp_close() here - possible BH context - * free the sock directly. - * sk_clone_lock() sets nsk refcnt to two, hence call sk_free() - * too. - */ - sk_common_release(nsk); - sk_free(nsk); - return NULL; - } - msk->write_seq = subflow_req->idsn + 1; atomic64_set(&msk->snd_una, msk->write_seq); if (mp_opt->mp_capable) { @@ -1482,6 +1513,22 @@ struct sock *mptcp_sk_clone(const struct sock *sk, return nsk; } +void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk) +{ + const struct tcp_sock *tp = tcp_sk(ssk); + + msk->rcvq_space.copied = 0; + msk->rcvq_space.rtt_us = 0; + + msk->rcvq_space.time = tp->tcp_mstamp; + + /* initial rcv_space offering made to peer */ + msk->rcvq_space.space = min_t(u32, tp->rcv_wnd, + TCP_INIT_CWND * tp->advmss); + if (msk->rcvq_space.space == 0) + msk->rcvq_space.space = TCP_INIT_CWND * TCP_MSS_DEFAULT; +} + static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, bool kern) { @@ -1501,7 +1548,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, return NULL; pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk)); - if (sk_is_mptcp(newsk)) { struct mptcp_subflow_context *subflow; struct sock *new_mptcp_sock; @@ -1531,6 +1577,7 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, list_add(&subflow->node, &msk->conn_list); inet_sk_state_store(newsk, TCP_ESTABLISHED); + mptcp_rcv_space_init(msk, ssk); bh_unlock_sock(new_mptcp_sock); __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK); @@ -1547,7 +1594,7 @@ static void mptcp_destroy(struct sock *sk) { struct mptcp_sock *msk = mptcp_sk(sk); - mptcp_token_destroy(msk->token); + mptcp_token_destroy(msk); if (msk->cached_ext) __skb_ext_put(msk->cached_ext); @@ -1558,7 +1605,7 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname, char __user *optval, unsigned int optlen) { struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *ssock; + struct sock *ssk; pr_debug("msk=%p", msk); @@ -1569,11 +1616,10 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname, * to the one remaining subflow. */ lock_sock(sk); - ssock = __mptcp_tcp_fallback(msk); + ssk = __mptcp_tcp_fallback(msk); release_sock(sk); - if (ssock) - return tcp_setsockopt(ssock->sk, level, optname, optval, - optlen); + if (ssk) + return tcp_setsockopt(ssk, level, optname, optval, optlen); return -EOPNOTSUPP; } @@ -1582,7 +1628,7 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, int __user *option) { struct mptcp_sock *msk = mptcp_sk(sk); - struct socket *ssock; + struct sock *ssk; pr_debug("msk=%p", msk); @@ -1593,11 +1639,10 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname, * to the one remaining subflow. */ lock_sock(sk); - ssock = __mptcp_tcp_fallback(msk); + ssk = __mptcp_tcp_fallback(msk); release_sock(sk); - if (ssock) - return tcp_getsockopt(ssock->sk, level, optname, optval, - option); + if (ssk) + return tcp_getsockopt(ssk, level, optname, optval, option); return -EOPNOTSUPP; } @@ -1636,6 +1681,20 @@ static void mptcp_release_cb(struct sock *sk) } } +static int mptcp_hash(struct sock *sk) +{ + /* should never be called, + * we hash the TCP subflows not the master socket + */ + WARN_ON_ONCE(1); + return 0; +} + +static void mptcp_unhash(struct sock *sk) +{ + /* called from sk_common_release(), but nothing to do here */ +} + static int mptcp_get_port(struct sock *sk, unsigned short snum) { struct mptcp_sock *msk = mptcp_sk(sk); @@ -1660,12 +1719,6 @@ void mptcp_finish_connect(struct sock *ssk) sk = subflow->conn; msk = mptcp_sk(sk); - if (!subflow->mp_capable) { - MPTCP_INC_STATS(sock_net(sk), - MPTCP_MIB_MPCAPABLEACTIVEFALLBACK); - return; - } - pr_debug("msk=%p, token=%u", sk, subflow->token); mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq); @@ -1679,13 +1732,14 @@ void mptcp_finish_connect(struct sock *ssk) */ WRITE_ONCE(msk->remote_key, subflow->remote_key); WRITE_ONCE(msk->local_key, subflow->local_key); - WRITE_ONCE(msk->token, subflow->token); WRITE_ONCE(msk->write_seq, subflow->idsn + 1); WRITE_ONCE(msk->ack_seq, ack_seq); WRITE_ONCE(msk->can_ack, 1); atomic64_set(&msk->snd_una, msk->write_seq); mptcp_pm_new_connection(msk, 0); + + mptcp_rcv_space_init(msk, ssk); } static void mptcp_sock_graft(struct sock *sk, struct socket *parent) @@ -1761,8 +1815,8 @@ static struct proto mptcp_prot = { .sendmsg = mptcp_sendmsg, .recvmsg = mptcp_recvmsg, .release_cb = mptcp_release_cb, - .hash = inet_hash, - .unhash = inet_unhash, + .hash = mptcp_hash, + .unhash = mptcp_unhash, .get_port = mptcp_get_port, .sockets_allocated = &mptcp_sockets_allocated, .memory_allocated = &tcp_memory_allocated, @@ -1771,6 +1825,7 @@ static struct proto mptcp_prot = { .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem), .sysctl_mem = sysctl_tcp_mem, .obj_size = sizeof(struct mptcp_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, .no_autobind = true, }; @@ -1781,9 +1836,9 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) int err; lock_sock(sock->sk); - ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE); - if (IS_ERR(ssock)) { - err = PTR_ERR(ssock); + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + err = -EINVAL; goto unlock; } @@ -1800,6 +1855,7 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags) { struct mptcp_sock *msk = mptcp_sk(sock->sk); + struct mptcp_subflow_context *subflow; struct socket *ssock; int err; @@ -1812,19 +1868,24 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, goto do_connect; } - ssock = __mptcp_socket_create(msk, TCP_SYN_SENT); - if (IS_ERR(ssock)) { - err = PTR_ERR(ssock); + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + err = -EINVAL; goto unlock; } + mptcp_token_destroy(msk); + inet_sk_state_store(sock->sk, TCP_SYN_SENT); + subflow = mptcp_subflow_ctx(ssock->sk); #ifdef CONFIG_TCP_MD5SIG /* no MPTCP if MD5SIG is enabled on this socket or we may run out of * TCP option space. */ if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info)) - mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0; + subflow->request_mptcp = 0; #endif + if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk)) + subflow->request_mptcp = 0; do_connect: err = ssock->ops->connect(ssock, uaddr, addr_len, flags); @@ -1843,42 +1904,6 @@ unlock: return err; } -static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr, - int peer) -{ - if (sock->sk->sk_prot == &tcp_prot) { - /* we are being invoked from __sys_accept4, after - * mptcp_accept() has just accepted a non-mp-capable - * flow: sk is a tcp_sk, not an mptcp one. - * - * Hand the socket over to tcp so all further socket ops - * bypass mptcp. - */ - sock->ops = &inet_stream_ops; - } - - return inet_getname(sock, uaddr, peer); -} - -#if IS_ENABLED(CONFIG_MPTCP_IPV6) -static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr, - int peer) -{ - if (sock->sk->sk_prot == &tcpv6_prot) { - /* we are being invoked from __sys_accept4 after - * mptcp_accept() has accepted a non-mp-capable - * subflow: sk is a tcp_sk, not mptcp. - * - * Hand the socket over to tcp so all further - * socket ops bypass mptcp. - */ - sock->ops = &inet6_stream_ops; - } - - return inet6_getname(sock, uaddr, peer); -} -#endif - static int mptcp_listen(struct socket *sock, int backlog) { struct mptcp_sock *msk = mptcp_sk(sock->sk); @@ -1888,12 +1913,14 @@ static int mptcp_listen(struct socket *sock, int backlog) pr_debug("msk=%p", msk); lock_sock(sock->sk); - ssock = __mptcp_socket_create(msk, TCP_LISTEN); - if (IS_ERR(ssock)) { - err = PTR_ERR(ssock); + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + err = -EINVAL; goto unlock; } + mptcp_token_destroy(msk); + inet_sk_state_store(sock->sk, TCP_LISTEN); sock_set_flag(sock->sk, SOCK_RCU_FREE); err = ssock->ops->listen(ssock, backlog); @@ -1906,15 +1933,6 @@ unlock: return err; } -static bool is_tcp_proto(const struct proto *p) -{ -#if IS_ENABLED(CONFIG_MPTCP_IPV6) - return p == &tcp_prot || p == &tcpv6_prot; -#else - return p == &tcp_prot; -#endif -} - static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, int flags, bool kern) { @@ -1932,11 +1950,12 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, if (!ssock) goto unlock_fail; + clear_bit(MPTCP_DATA_READY, &msk->flags); sock_hold(ssock->sk); release_sock(sock->sk); err = ssock->ops->accept(sock, newsock, flags, kern); - if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) { + if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) { struct mptcp_sock *msk = mptcp_sk(newsock->sk); struct mptcp_subflow_context *subflow; @@ -1952,6 +1971,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, } } + if (inet_csk_listen_poll(ssock->sk)) + set_bit(MPTCP_DATA_READY, &msk->flags); sock_put(ssock->sk); return err; @@ -1960,39 +1981,36 @@ unlock_fail: return -EINVAL; } +static __poll_t mptcp_check_readable(struct mptcp_sock *msk) +{ + return test_bit(MPTCP_DATA_READY, &msk->flags) ? EPOLLIN | EPOLLRDNORM : + 0; +} + static __poll_t mptcp_poll(struct file *file, struct socket *sock, struct poll_table_struct *wait) { struct sock *sk = sock->sk; struct mptcp_sock *msk; - struct socket *ssock; __poll_t mask = 0; + int state; msk = mptcp_sk(sk); - lock_sock(sk); - ssock = __mptcp_tcp_fallback(msk); - if (!ssock) - ssock = __mptcp_nmpc_socket(msk); - if (ssock) { - mask = ssock->ops->poll(file, ssock, wait); - release_sock(sk); - return mask; - } - - release_sock(sk); sock_poll_wait(file, sock, wait); - lock_sock(sk); - if (test_bit(MPTCP_DATA_READY, &msk->flags)) - mask = EPOLLIN | EPOLLRDNORM; - if (sk_stream_is_writeable(sk) && - test_bit(MPTCP_SEND_SPACE, &msk->flags)) - mask |= EPOLLOUT | EPOLLWRNORM; + state = inet_sk_state_load(sk); + if (state == TCP_LISTEN) + return mptcp_check_readable(msk); + + if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) { + mask |= mptcp_check_readable(msk); + if (sk_stream_is_writeable(sk) && + test_bit(MPTCP_SEND_SPACE, &msk->flags)) + mask |= EPOLLOUT | EPOLLWRNORM; + } if (sk->sk_shutdown & RCV_SHUTDOWN) mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; - release_sock(sk); - return mask; } @@ -2000,18 +2018,11 @@ static int mptcp_shutdown(struct socket *sock, int how) { struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_subflow_context *subflow; - struct socket *ssock; int ret = 0; pr_debug("sk=%p, how=%d", msk, how); lock_sock(sock->sk); - ssock = __mptcp_tcp_fallback(msk); - if (ssock) { - release_sock(sock->sk); - return inet_shutdown(ssock, how); - } - if (how == SHUT_WR || how == SHUT_RDWR) inet_sk_state_store(sock->sk, TCP_FIN_WAIT1); @@ -2037,6 +2048,9 @@ static int mptcp_shutdown(struct socket *sock, int how) mptcp_subflow_shutdown(tcp_sk, how, 1, msk->write_seq); } + /* Wake up anyone sleeping in poll. */ + sock->sk->sk_state_change(sock->sk); + out_unlock: release_sock(sock->sk); @@ -2051,7 +2065,7 @@ static const struct proto_ops mptcp_stream_ops = { .connect = mptcp_stream_connect, .socketpair = sock_no_socketpair, .accept = mptcp_stream_accept, - .getname = mptcp_v4_getname, + .getname = inet_getname, .poll = mptcp_poll, .ioctl = inet_ioctl, .gettstamp = sock_gettstamp, @@ -2077,7 +2091,7 @@ static struct inet_protosw mptcp_protosw = { .flags = INET_PROTOSW_ICSK, }; -void mptcp_proto_init(void) +void __init mptcp_proto_init(void) { mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo; @@ -2086,6 +2100,7 @@ void mptcp_proto_init(void) mptcp_subflow_init(); mptcp_pm_init(); + mptcp_token_init(); if (proto_register(&mptcp_prot, 1) != 0) panic("Failed to register MPTCP proto.\n"); @@ -2104,7 +2119,7 @@ static const struct proto_ops mptcp_v6_stream_ops = { .connect = mptcp_stream_connect, .socketpair = sock_no_socketpair, .accept = mptcp_stream_accept, - .getname = mptcp_v6_getname, + .getname = inet6_getname, .poll = mptcp_poll, .ioctl = inet6_ioctl, .gettstamp = sock_gettstamp, @@ -2139,7 +2154,7 @@ static struct inet_protosw mptcp_v6_protosw = { .flags = INET_PROTOSW_ICSK, }; -int mptcp_proto_v6_init(void) +int __init mptcp_proto_v6_init(void) { int err; |