aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h28
-rw-r--r--include/net/tls.h3
-rw-r--r--net/core/filter.c8
-rw-r--r--net/core/skmsg.c2
-rw-r--r--net/ipv4/tcp_bpf.c2
-rw-r--r--net/tls/tls_main.c13
-rw-r--r--net/tls/tls_sw.c32
-rw-r--r--tools/testing/selftests/bpf/test_sockmap.c47
-rw-r--r--tools/testing/selftests/bpf/xdping.c2
-rw-r--r--tools/testing/selftests/net/tls.c60
10 files changed, 131 insertions, 66 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 6cb077b646a5..ef7031f8a304 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -14,6 +14,7 @@
#include <net/strparser.h>
#define MAX_MSG_FRAGS MAX_SKB_FRAGS
+#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
enum __sk_action {
__SK_DROP = 0,
@@ -29,13 +30,15 @@ struct sk_msg_sg {
u32 size;
u32 copybreak;
unsigned long copy;
- /* The extra element is used for chaining the front and sections when
- * the list becomes partitioned (e.g. end < start). The crypto APIs
- * require the chaining.
+ /* The extra two elements:
+ * 1) used for chaining the front and sections when the list becomes
+ * partitioned (e.g. end < start). The crypto APIs require the
+ * chaining;
+ * 2) to chain tailer SG entries after the message.
*/
- struct scatterlist data[MAX_MSG_FRAGS + 1];
+ struct scatterlist data[MAX_MSG_FRAGS + 2];
};
-static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS);
+static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
@@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
- return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
+ return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
#define sk_msg_iter_var_prev(var) \
do { \
if (var == 0) \
- var = MAX_MSG_FRAGS - 1; \
+ var = NR_MSG_FRAG_IDS - 1; \
else \
var--; \
} while (0)
@@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end)
#define sk_msg_iter_var_next(var) \
do { \
var++; \
- if (var == MAX_MSG_FRAGS) \
+ if (var == NR_MSG_FRAG_IDS) \
var = 0; \
} while (0)
@@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
static inline void sk_msg_init(struct sk_msg *msg)
{
- BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
+ BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
memset(msg, 0, sizeof(*msg));
- sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
+ sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
@@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
static inline bool sk_msg_full(const struct sk_msg *msg)
{
- return (msg->sg.end == msg->sg.start) && msg->sg.size;
+ return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}
static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
- if (sk_msg_full(msg))
- return MAX_MSG_FRAGS;
-
return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}
diff --git a/include/net/tls.h b/include/net/tls.h
index 6ed91e82edd0..df630f5fc723 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -100,7 +100,6 @@ struct tls_rec {
struct list_head list;
int tx_ready;
int tx_flags;
- int inplace_crypto;
struct sk_msg msg_plaintext;
struct sk_msg msg_encrypted;
@@ -377,7 +376,7 @@ int tls_push_sg(struct sock *sk, struct tls_context *ctx,
int flags);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags);
-bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
+void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
diff --git a/net/core/filter.c b/net/core/filter.c
index b0ed048585ba..f1e703eed3d2 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
WARN_ON_ONCE(last_sge == first_sge);
shift = last_sge > first_sge ?
last_sge - first_sge - 1 :
- MAX_SKB_FRAGS - first_sge + last_sge - 1;
+ NR_MSG_FRAG_IDS - first_sge + last_sge - 1;
if (!shift)
goto out;
@@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
do {
u32 move_from;
- if (i + shift >= MAX_MSG_FRAGS)
- move_from = i + shift - MAX_MSG_FRAGS;
+ if (i + shift >= NR_MSG_FRAG_IDS)
+ move_from = i + shift - NR_MSG_FRAG_IDS;
else
move_from = i + shift;
if (move_from == msg->sg.end)
@@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
} while (1);
msg->sg.end = msg->sg.end - shift > msg->sg.end ?
- msg->sg.end - shift + MAX_MSG_FRAGS :
+ msg->sg.end - shift + NR_MSG_FRAG_IDS :
msg->sg.end - shift;
out:
msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index a469d2124f3f..ded2d5227678 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
copied = skb->len;
msg->sg.start = 0;
msg->sg.size = copied;
- msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
+ msg->sg.end = num_sge;
msg->skb = skb;
sk_psock_queue_msg(psock, msg);
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 8a56e09cfb0e..e38705165ac9 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
struct sk_msg *msg, int *copied, int flags)
{
- bool cork = false, enospc = msg->sg.start == msg->sg.end;
+ bool cork = false, enospc = sk_msg_full(msg);
struct sock *sk_redir;
u32 tosend, delta = 0;
int ret;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index bdca31ffe6da..b3da6c5ab999 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -209,24 +209,15 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
return tls_push_sg(sk, ctx, sg, offset, flags);
}
-bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
+void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
struct scatterlist *sg;
- sg = ctx->partially_sent_record;
- if (!sg)
- return false;
-
- while (1) {
+ for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
put_page(sg_page(sg));
sk_mem_uncharge(sk, sg->length);
-
- if (sg_is_last(sg))
- break;
- sg++;
}
ctx->partially_sent_record = NULL;
- return true;
}
static void tls_write_space(struct sock *sk)
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index da9f9ce51e7b..2b2d0bae14a9 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -710,8 +710,7 @@ static int tls_push_record(struct sock *sk, int flags,
}
i = msg_pl->sg.start;
- sg_chain(rec->sg_aead_in, 2, rec->inplace_crypto ?
- &msg_en->sg.data[i] : &msg_pl->sg.data[i]);
+ sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
i = msg_en->sg.end;
sk_msg_iter_var_prev(i);
@@ -771,8 +770,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
policy = !(flags & MSG_SENDPAGE_NOPOLICY);
psock = sk_psock_get(sk);
- if (!psock || !policy)
- return tls_push_record(sk, flags, record_type);
+ if (!psock || !policy) {
+ err = tls_push_record(sk, flags, record_type);
+ if (err) {
+ *copied -= sk_msg_free(sk, msg);
+ tls_free_open_rec(sk);
+ }
+ return err;
+ }
more_data:
enospc = sk_msg_full(msg);
if (psock->eval == __SK_NONE) {
@@ -970,8 +975,6 @@ alloc_encrypted:
if (ret)
goto fallback_to_reg_send;
- rec->inplace_crypto = 0;
-
num_zc++;
copied += try_to_copy;
@@ -984,7 +987,7 @@ alloc_encrypted:
num_async++;
else if (ret == -ENOMEM)
goto wait_for_memory;
- else if (ret == -ENOSPC)
+ else if (ctx->open_rec && ret == -ENOSPC)
goto rollback_iter;
else if (ret != -EAGAIN)
goto send_end;
@@ -1053,11 +1056,12 @@ wait_for_memory:
ret = sk_stream_wait_memory(sk, &timeo);
if (ret) {
trim_sgl:
- tls_trim_both_msgs(sk, orig_size);
+ if (ctx->open_rec)
+ tls_trim_both_msgs(sk, orig_size);
goto send_end;
}
- if (msg_en->sg.size < required_size)
+ if (ctx->open_rec && msg_en->sg.size < required_size)
goto alloc_encrypted;
}
@@ -1169,7 +1173,6 @@ alloc_payload:
tls_ctx->pending_open_record_frags = true;
if (full_record || eor || sk_msg_full(msg_pl)) {
- rec->inplace_crypto = 0;
ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
record_type, &copied, flags);
if (ret) {
@@ -1190,11 +1193,13 @@ wait_for_sndbuf:
wait_for_memory:
ret = sk_stream_wait_memory(sk, &timeo);
if (ret) {
- tls_trim_both_msgs(sk, msg_pl->sg.size);
+ if (ctx->open_rec)
+ tls_trim_both_msgs(sk, msg_pl->sg.size);
goto sendpage_end;
}
- goto alloc_payload;
+ if (ctx->open_rec)
+ goto alloc_payload;
}
if (num_async) {
@@ -2084,7 +2089,8 @@ void tls_sw_release_resources_tx(struct sock *sk)
/* Free up un-sent records in tx_list. First, free
* the partially sent record if any at head of tx_list.
*/
- if (tls_free_partial_record(sk, tls_ctx)) {
+ if (tls_ctx->partially_sent_record) {
+ tls_free_partial_record(sk, tls_ctx);
rec = list_first_entry(&ctx->tx_list,
struct tls_rec, list);
list_del(&rec->list);
diff --git a/tools/testing/selftests/bpf/test_sockmap.c b/tools/testing/selftests/bpf/test_sockmap.c
index 3845144e2c91..4a851513c842 100644
--- a/tools/testing/selftests/bpf/test_sockmap.c
+++ b/tools/testing/selftests/bpf/test_sockmap.c
@@ -240,14 +240,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT);
err = bind(s1, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0) {
- perror("bind s1 failed()\n");
+ perror("bind s1 failed()");
return errno;
}
addr.sin_port = htons(S2_PORT);
err = bind(s2, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0) {
- perror("bind s2 failed()\n");
+ perror("bind s2 failed()");
return errno;
}
@@ -255,14 +255,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT);
err = listen(s1, 32);
if (err < 0) {
- perror("listen s1 failed()\n");
+ perror("listen s1 failed()");
return errno;
}
addr.sin_port = htons(S2_PORT);
err = listen(s2, 32);
if (err < 0) {
- perror("listen s1 failed()\n");
+ perror("listen s1 failed()");
return errno;
}
@@ -270,14 +270,14 @@ static int sockmap_init_sockets(int verbose)
addr.sin_port = htons(S1_PORT);
err = connect(c1, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0 && errno != EINPROGRESS) {
- perror("connect c1 failed()\n");
+ perror("connect c1 failed()");
return errno;
}
addr.sin_port = htons(S2_PORT);
err = connect(c2, (struct sockaddr *)&addr, sizeof(addr));
if (err < 0 && errno != EINPROGRESS) {
- perror("connect c2 failed()\n");
+ perror("connect c2 failed()");
return errno;
} else if (err < 0) {
err = 0;
@@ -286,13 +286,13 @@ static int sockmap_init_sockets(int verbose)
/* Accept Connecrtions */
p1 = accept(s1, NULL, NULL);
if (p1 < 0) {
- perror("accept s1 failed()\n");
+ perror("accept s1 failed()");
return errno;
}
p2 = accept(s2, NULL, NULL);
if (p2 < 0) {
- perror("accept s1 failed()\n");
+ perror("accept s1 failed()");
return errno;
}
@@ -332,6 +332,10 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
int i, fp;
file = fopen(".sendpage_tst.tmp", "w+");
+ if (!file) {
+ perror("create file for sendpage");
+ return 1;
+ }
for (i = 0; i < iov_length * cnt; i++, k++)
fwrite(&k, sizeof(char), 1, file);
fflush(file);
@@ -339,12 +343,17 @@ static int msg_loop_sendpage(int fd, int iov_length, int cnt,
fclose(file);
fp = open(".sendpage_tst.tmp", O_RDONLY);
+ if (fp < 0) {
+ perror("reopen file for sendpage");
+ return 1;
+ }
+
clock_gettime(CLOCK_MONOTONIC, &s->start);
for (i = 0; i < cnt; i++) {
int sent = sendfile(fd, fp, NULL, iov_length);
if (!drop && sent < 0) {
- perror("send loop error:");
+ perror("send loop error");
close(fp);
return sent;
} else if (drop && sent >= 0) {
@@ -463,7 +472,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
int sent = sendmsg(fd, &msg, flags);
if (!drop && sent < 0) {
- perror("send loop error:");
+ perror("send loop error");
goto out_errno;
} else if (drop && sent >= 0) {
printf("send loop error expected: %i\n", sent);
@@ -499,7 +508,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
total_bytes -= txmsg_pop_total;
err = clock_gettime(CLOCK_MONOTONIC, &s->start);
if (err < 0)
- perror("recv start time: ");
+ perror("recv start time");
while (s->bytes_recvd < total_bytes) {
if (txmsg_cork) {
timeout.tv_sec = 0;
@@ -543,7 +552,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
if (recv < 0) {
if (errno != EWOULDBLOCK) {
clock_gettime(CLOCK_MONOTONIC, &s->end);
- perror("recv failed()\n");
+ perror("recv failed()");
goto out_errno;
}
}
@@ -557,7 +566,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
errno = msg_verify_data(&msg, recv, chunk_sz);
if (errno) {
- perror("data verify msg failed\n");
+ perror("data verify msg failed");
goto out_errno;
}
if (recvp) {
@@ -565,7 +574,7 @@ static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
recvp,
chunk_sz);
if (errno) {
- perror("data verify msg_peek failed\n");
+ perror("data verify msg_peek failed");
goto out_errno;
}
}
@@ -654,7 +663,7 @@ static int sendmsg_test(struct sockmap_options *opt)
err = 0;
exit(err ? 1 : 0);
} else if (rxpid == -1) {
- perror("msg_loop_rx: ");
+ perror("msg_loop_rx");
return errno;
}
@@ -681,7 +690,7 @@ static int sendmsg_test(struct sockmap_options *opt)
s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
exit(err ? 1 : 0);
} else if (txpid == -1) {
- perror("msg_loop_tx: ");
+ perror("msg_loop_tx");
return errno;
}
@@ -715,7 +724,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
/* Ping/Pong data from client to server */
sc = send(c1, buf, sizeof(buf), 0);
if (sc < 0) {
- perror("send failed()\n");
+ perror("send failed()");
return sc;
}
@@ -748,7 +757,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
rc = recv(i, buf, sizeof(buf), 0);
if (rc < 0) {
if (errno != EWOULDBLOCK) {
- perror("recv failed()\n");
+ perror("recv failed()");
return rc;
}
}
@@ -760,7 +769,7 @@ static int forever_ping_pong(int rate, struct sockmap_options *opt)
sc = send(i, buf, rc, 0);
if (sc < 0) {
- perror("send failed()\n");
+ perror("send failed()");
return sc;
}
}
diff --git a/tools/testing/selftests/bpf/xdping.c b/tools/testing/selftests/bpf/xdping.c
index d60a343b1371..842d9155d36c 100644
--- a/tools/testing/selftests/bpf/xdping.c
+++ b/tools/testing/selftests/bpf/xdping.c
@@ -45,7 +45,7 @@ static int get_stats(int fd, __u16 count, __u32 raddr)
printf("\nXDP RTT data:\n");
if (bpf_map_lookup_elem(fd, &raddr, &pinginfo)) {
- perror("bpf_map_lookup elem: ");
+ perror("bpf_map_lookup elem");
return 1;
}
diff --git a/tools/testing/selftests/net/tls.c b/tools/testing/selftests/net/tls.c
index 1c8f194d6556..46abcae47dee 100644
--- a/tools/testing/selftests/net/tls.c
+++ b/tools/testing/selftests/net/tls.c
@@ -268,6 +268,38 @@ TEST_F(tls, sendmsg_single)
EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
}
+#define MAX_FRAGS 64
+#define SEND_LEN 13
+TEST_F(tls, sendmsg_fragmented)
+{
+ char const *test_str = "test_sendmsg";
+ char buf[SEND_LEN * MAX_FRAGS];
+ struct iovec vec[MAX_FRAGS];
+ struct msghdr msg;
+ int i, frags;
+
+ for (frags = 1; frags <= MAX_FRAGS; frags++) {
+ for (i = 0; i < frags; i++) {
+ vec[i].iov_base = (char *)test_str;
+ vec[i].iov_len = SEND_LEN;
+ }
+
+ memset(&msg, 0, sizeof(struct msghdr));
+ msg.msg_iov = vec;
+ msg.msg_iovlen = frags;
+
+ EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
+ EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
+ SEND_LEN * frags);
+
+ for (i = 0; i < frags; i++)
+ EXPECT_EQ(memcmp(buf + SEND_LEN * i,
+ test_str, SEND_LEN), 0);
+ }
+}
+#undef MAX_FRAGS
+#undef SEND_LEN
+
TEST_F(tls, sendmsg_large)
{
void *mem = malloc(16384);
@@ -694,6 +726,34 @@ TEST_F(tls, recv_lowat)
EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
}
+TEST_F(tls, recv_rcvbuf)
+{
+ char send_mem[4096];
+ char recv_mem[4096];
+ int rcv_buf = 1024;
+
+ memset(send_mem, 0x1c, sizeof(send_mem));
+
+ EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVBUF,
+ &rcv_buf, sizeof(rcv_buf)), 0);
+
+ EXPECT_EQ(send(self->fd, send_mem, 512, 0), 512);
+ memset(recv_mem, 0, sizeof(recv_mem));
+ EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), 512);
+ EXPECT_EQ(memcmp(send_mem, recv_mem, 512), 0);
+
+ if (self->notls)
+ return;
+
+ EXPECT_EQ(send(self->fd, send_mem, 4096, 0), 4096);
+ memset(recv_mem, 0, sizeof(recv_mem));
+ EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1);
+ EXPECT_EQ(errno, EMSGSIZE);
+
+ EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0), -1);
+ EXPECT_EQ(errno, EMSGSIZE);
+}
+
TEST_F(tls, bidir)
{
char const *test_str = "test_read";