aboutsummaryrefslogtreecommitdiff
path: root/net/ipv4/tcp.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/tcp.c')
-rw-r--r--net/ipv4/tcp.c307
1 files changed, 286 insertions, 21 deletions
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index e03a342c9162..4f77bd862e95 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -285,6 +285,8 @@
#include <trace/events/tcp.h>
#include <net/rps.h>
+#include "../core/devmem.h"
+
/* Track pending CMSGs. */
enum {
TCP_CMSG_INQ = 1,
@@ -471,6 +473,7 @@ void tcp_init_sock(struct sock *sk)
set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
sk_sockets_allocated_inc(sk);
+ xa_init_flags(&sk->sk_user_frags, XA_FLAGS_ALLOC1);
}
EXPORT_SYMBOL(tcp_init_sock);
@@ -2160,6 +2163,9 @@ static int tcp_zerocopy_receive(struct sock *sk,
skb = tcp_recv_skb(sk, seq, &offset);
}
+ if (!skb_frags_readable(skb))
+ break;
+
if (TCP_SKB_CB(skb)->has_rxtstamp) {
tcp_update_recv_tstamps(skb, tss);
zc->msg_flags |= TCP_CMSG_TS;
@@ -2177,6 +2183,9 @@ static int tcp_zerocopy_receive(struct sock *sk,
break;
}
page = skb_frag_page(frags);
+ if (WARN_ON_ONCE(!page))
+ break;
+
prefetchw(page);
pages[pages_to_map++] = page;
length += PAGE_SIZE;
@@ -2235,6 +2244,7 @@ void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
struct scm_timestamping_internal *tss)
{
int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW);
+ u32 tsflags = READ_ONCE(sk->sk_tsflags);
bool has_timestamping = false;
if (tss->ts[0].tv_sec || tss->ts[0].tv_nsec) {
@@ -2274,14 +2284,18 @@ void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
}
}
- if (READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_SOFTWARE)
+ if (tsflags & SOF_TIMESTAMPING_SOFTWARE &&
+ (tsflags & SOF_TIMESTAMPING_RX_SOFTWARE ||
+ !(tsflags & SOF_TIMESTAMPING_OPT_RX_FILTER)))
has_timestamping = true;
else
tss->ts[0] = (struct timespec64) {0};
}
if (tss->ts[2].tv_sec || tss->ts[2].tv_nsec) {
- if (READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_RAW_HARDWARE)
+ if (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE &&
+ (tsflags & SOF_TIMESTAMPING_RX_HARDWARE ||
+ !(tsflags & SOF_TIMESTAMPING_OPT_RX_FILTER)))
has_timestamping = true;
else
tss->ts[2] = (struct timespec64) {0};
@@ -2317,6 +2331,220 @@ static int tcp_inq_hint(struct sock *sk)
return inq;
}
+/* batch __xa_alloc() calls and reduce xa_lock()/xa_unlock() overhead. */
+struct tcp_xa_pool {
+ u8 max; /* max <= MAX_SKB_FRAGS */
+ u8 idx; /* idx <= max */
+ __u32 tokens[MAX_SKB_FRAGS];
+ netmem_ref netmems[MAX_SKB_FRAGS];
+};
+
+static void tcp_xa_pool_commit_locked(struct sock *sk, struct tcp_xa_pool *p)
+{
+ int i;
+
+ /* Commit part that has been copied to user space. */
+ for (i = 0; i < p->idx; i++)
+ __xa_cmpxchg(&sk->sk_user_frags, p->tokens[i], XA_ZERO_ENTRY,
+ (__force void *)p->netmems[i], GFP_KERNEL);
+ /* Rollback what has been pre-allocated and is no longer needed. */
+ for (; i < p->max; i++)
+ __xa_erase(&sk->sk_user_frags, p->tokens[i]);
+
+ p->max = 0;
+ p->idx = 0;
+}
+
+static void tcp_xa_pool_commit(struct sock *sk, struct tcp_xa_pool *p)
+{
+ if (!p->max)
+ return;
+
+ xa_lock_bh(&sk->sk_user_frags);
+
+ tcp_xa_pool_commit_locked(sk, p);
+
+ xa_unlock_bh(&sk->sk_user_frags);
+}
+
+static int tcp_xa_pool_refill(struct sock *sk, struct tcp_xa_pool *p,
+ unsigned int max_frags)
+{
+ int err, k;
+
+ if (p->idx < p->max)
+ return 0;
+
+ xa_lock_bh(&sk->sk_user_frags);
+
+ tcp_xa_pool_commit_locked(sk, p);
+
+ for (k = 0; k < max_frags; k++) {
+ err = __xa_alloc(&sk->sk_user_frags, &p->tokens[k],
+ XA_ZERO_ENTRY, xa_limit_31b, GFP_KERNEL);
+ if (err)
+ break;
+ }
+
+ xa_unlock_bh(&sk->sk_user_frags);
+
+ p->max = k;
+ p->idx = 0;
+ return k ? 0 : err;
+}
+
+/* On error, returns the -errno. On success, returns number of bytes sent to the
+ * user. May not consume all of @remaining_len.
+ */
+static int tcp_recvmsg_dmabuf(struct sock *sk, const struct sk_buff *skb,
+ unsigned int offset, struct msghdr *msg,
+ int remaining_len)
+{
+ struct dmabuf_cmsg dmabuf_cmsg = { 0 };
+ struct tcp_xa_pool tcp_xa_pool;
+ unsigned int start;
+ int i, copy, n;
+ int sent = 0;
+ int err = 0;
+
+ tcp_xa_pool.max = 0;
+ tcp_xa_pool.idx = 0;
+ do {
+ start = skb_headlen(skb);
+
+ if (skb_frags_readable(skb)) {
+ err = -ENODEV;
+ goto out;
+ }
+
+ /* Copy header. */
+ copy = start - offset;
+ if (copy > 0) {
+ copy = min(copy, remaining_len);
+
+ n = copy_to_iter(skb->data + offset, copy,
+ &msg->msg_iter);
+ if (n != copy) {
+ err = -EFAULT;
+ goto out;
+ }
+
+ offset += copy;
+ remaining_len -= copy;
+
+ /* First a dmabuf_cmsg for # bytes copied to user
+ * buffer.
+ */
+ memset(&dmabuf_cmsg, 0, sizeof(dmabuf_cmsg));
+ dmabuf_cmsg.frag_size = copy;
+ err = put_cmsg(msg, SOL_SOCKET, SO_DEVMEM_LINEAR,
+ sizeof(dmabuf_cmsg), &dmabuf_cmsg);
+ if (err || msg->msg_flags & MSG_CTRUNC) {
+ msg->msg_flags &= ~MSG_CTRUNC;
+ if (!err)
+ err = -ETOOSMALL;
+ goto out;
+ }
+
+ sent += copy;
+
+ if (remaining_len == 0)
+ goto out;
+ }
+
+ /* after that, send information of dmabuf pages through a
+ * sequence of cmsg
+ */
+ for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+ skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
+ struct net_iov *niov;
+ u64 frag_offset;
+ int end;
+
+ /* !skb_frags_readable() should indicate that ALL the
+ * frags in this skb are dmabuf net_iovs. We're checking
+ * for that flag above, but also check individual frags
+ * here. If the tcp stack is not setting
+ * skb_frags_readable() correctly, we still don't want
+ * to crash here.
+ */
+ if (!skb_frag_net_iov(frag)) {
+ net_err_ratelimited("Found non-dmabuf skb with net_iov");
+ err = -ENODEV;
+ goto out;
+ }
+
+ niov = skb_frag_net_iov(frag);
+ end = start + skb_frag_size(frag);
+ copy = end - offset;
+
+ if (copy > 0) {
+ copy = min(copy, remaining_len);
+
+ frag_offset = net_iov_virtual_addr(niov) +
+ skb_frag_off(frag) + offset -
+ start;
+ dmabuf_cmsg.frag_offset = frag_offset;
+ dmabuf_cmsg.frag_size = copy;
+ err = tcp_xa_pool_refill(sk, &tcp_xa_pool,
+ skb_shinfo(skb)->nr_frags - i);
+ if (err)
+ goto out;
+
+ /* Will perform the exchange later */
+ dmabuf_cmsg.frag_token = tcp_xa_pool.tokens[tcp_xa_pool.idx];
+ dmabuf_cmsg.dmabuf_id = net_iov_binding_id(niov);
+
+ offset += copy;
+ remaining_len -= copy;
+
+ err = put_cmsg(msg, SOL_SOCKET,
+ SO_DEVMEM_DMABUF,
+ sizeof(dmabuf_cmsg),
+ &dmabuf_cmsg);
+ if (err || msg->msg_flags & MSG_CTRUNC) {
+ msg->msg_flags &= ~MSG_CTRUNC;
+ if (!err)
+ err = -ETOOSMALL;
+ goto out;
+ }
+
+ atomic_long_inc(&niov->pp_ref_count);
+ tcp_xa_pool.netmems[tcp_xa_pool.idx++] = skb_frag_netmem(frag);
+
+ sent += copy;
+
+ if (remaining_len == 0)
+ goto out;
+ }
+ start = end;
+ }
+
+ tcp_xa_pool_commit(sk, &tcp_xa_pool);
+ if (!remaining_len)
+ goto out;
+
+ /* if remaining_len is not satisfied yet, we need to go to the
+ * next frag in the frag_list to satisfy remaining_len.
+ */
+ skb = skb_shinfo(skb)->frag_list ?: skb->next;
+
+ offset = offset - start;
+ } while (skb);
+
+ if (remaining_len) {
+ err = -EFAULT;
+ goto out;
+ }
+
+out:
+ tcp_xa_pool_commit(sk, &tcp_xa_pool);
+ if (!sent)
+ sent = err;
+
+ return sent;
+}
+
/*
* This routine copies from a sock struct into the user buffer.
*
@@ -2330,6 +2558,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
int *cmsg_flags)
{
struct tcp_sock *tp = tcp_sk(sk);
+ int last_copied_dmabuf = -1; /* uninitialized */
int copied = 0;
u32 peek_seq;
u32 *seq;
@@ -2509,15 +2738,44 @@ found_ok_skb:
}
if (!(flags & MSG_TRUNC)) {
- err = skb_copy_datagram_msg(skb, offset, msg, used);
- if (err) {
- /* Exception. Bailout! */
- if (!copied)
- copied = -EFAULT;
+ if (last_copied_dmabuf != -1 &&
+ last_copied_dmabuf != !skb_frags_readable(skb))
break;
+
+ if (skb_frags_readable(skb)) {
+ err = skb_copy_datagram_msg(skb, offset, msg,
+ used);
+ if (err) {
+ /* Exception. Bailout! */
+ if (!copied)
+ copied = -EFAULT;
+ break;
+ }
+ } else {
+ if (!(flags & MSG_SOCK_DEVMEM)) {
+ /* dmabuf skbs can only be received
+ * with the MSG_SOCK_DEVMEM flag.
+ */
+ if (!copied)
+ copied = -EFAULT;
+
+ break;
+ }
+
+ err = tcp_recvmsg_dmabuf(sk, skb, offset, msg,
+ used);
+ if (err <= 0) {
+ if (!copied)
+ copied = -EFAULT;
+
+ break;
+ }
+ used = err;
}
}
+ last_copied_dmabuf = !skb_frags_readable(skb);
+
WRITE_ONCE(*seq, *seq + used);
copied += used;
len -= used;
@@ -2833,7 +3091,7 @@ void __tcp_close(struct sock *sk, long timeout)
NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONCLOSE);
tcp_set_state(sk, TCP_CLOSE);
tcp_send_active_reset(sk, sk->sk_allocation,
- SK_RST_REASON_NOT_SPECIFIED);
+ SK_RST_REASON_TCP_ABORT_ON_CLOSE);
} else if (sock_flag(sk, SOCK_LINGER) && !sk->sk_lingertime) {
/* Check zero linger _after_ checking for unread data. */
sk->sk_prot->disconnect(sk, 0);
@@ -2908,7 +3166,7 @@ adjudge_to_death:
if (READ_ONCE(tp->linger2) < 0) {
tcp_set_state(sk, TCP_CLOSE);
tcp_send_active_reset(sk, GFP_ATOMIC,
- SK_RST_REASON_NOT_SPECIFIED);
+ SK_RST_REASON_TCP_ABORT_ON_LINGER);
__NET_INC_STATS(sock_net(sk),
LINUX_MIB_TCPABORTONLINGER);
} else {
@@ -2927,7 +3185,7 @@ adjudge_to_death:
if (tcp_check_oom(sk, 0)) {
tcp_set_state(sk, TCP_CLOSE);
tcp_send_active_reset(sk, GFP_ATOMIC,
- SK_RST_REASON_NOT_SPECIFIED);
+ SK_RST_REASON_TCP_ABORT_ON_MEMORY);
__NET_INC_STATS(sock_net(sk),
LINUX_MIB_TCPABORTONMEMORY);
} else if (!check_net(sock_net(sk))) {
@@ -3025,13 +3283,16 @@ int tcp_disconnect(struct sock *sk, int flags)
inet_csk_listen_stop(sk);
} else if (unlikely(tp->repair)) {
WRITE_ONCE(sk->sk_err, ECONNABORTED);
- } else if (tcp_need_reset(old_state) ||
- (tp->snd_nxt != tp->write_seq &&
- (1 << old_state) & (TCPF_CLOSING | TCPF_LAST_ACK))) {
+ } else if (tcp_need_reset(old_state)) {
+ tcp_send_active_reset(sk, gfp_any(), SK_RST_REASON_TCP_STATE);
+ WRITE_ONCE(sk->sk_err, ECONNRESET);
+ } else if (tp->snd_nxt != tp->write_seq &&
+ (1 << old_state) & (TCPF_CLOSING | TCPF_LAST_ACK)) {
/* The last check adjusts for discrepancy of Linux wrt. RFC
* states
*/
- tcp_send_active_reset(sk, gfp_any(), SK_RST_REASON_NOT_SPECIFIED);
+ tcp_send_active_reset(sk, gfp_any(),
+ SK_RST_REASON_TCP_DISCONNECT_WITH_DATA);
WRITE_ONCE(sk->sk_err, ECONNRESET);
} else if (old_state == TCP_SYN_SENT)
WRITE_ONCE(sk->sk_err, ECONNRESET);
@@ -4637,6 +4898,13 @@ int tcp_abort(struct sock *sk, int err)
/* Don't race with userspace socket closes such as tcp_close. */
lock_sock(sk);
+ /* Avoid closing the same socket twice. */
+ if (sk->sk_state == TCP_CLOSE) {
+ if (!has_current_bpf_ctx())
+ release_sock(sk);
+ return -ENOENT;
+ }
+
if (sk->sk_state == TCP_LISTEN) {
tcp_set_state(sk, TCP_CLOSE);
inet_csk_listen_stop(sk);
@@ -4646,16 +4914,13 @@ int tcp_abort(struct sock *sk, int err)
local_bh_disable();
bh_lock_sock(sk);
- if (!sock_flag(sk, SOCK_DEAD)) {
- if (tcp_need_reset(sk->sk_state))
- tcp_send_active_reset(sk, GFP_ATOMIC,
- SK_RST_REASON_NOT_SPECIFIED);
- tcp_done_with_error(sk, err);
- }
+ if (tcp_need_reset(sk->sk_state))
+ tcp_send_active_reset(sk, GFP_ATOMIC,
+ SK_RST_REASON_TCP_STATE);
+ tcp_done_with_error(sk, err);
bh_unlock_sock(sk);
local_bh_enable();
- tcp_write_queue_purge(sk);
if (!has_current_bpf_ctx())
release_sock(sk);
return 0;