aboutsummaryrefslogtreecommitdiff
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c158
1 files changed, 99 insertions, 59 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 0d379970960e..a127d61e8af9 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -51,12 +51,12 @@ enum {
TLSV6,
TLS_NUM_PROTS,
};
-
enum {
TLS_BASE,
- TLS_SW_TX,
- TLS_SW_RX,
- TLS_SW_RXTX,
+ TLS_SW,
+#ifdef CONFIG_TLS_DEVICE
+ TLS_HW,
+#endif
TLS_HW_RECORD,
TLS_NUM_CONFIG,
};
@@ -65,14 +65,14 @@ static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static LIST_HEAD(device_list);
static DEFINE_MUTEX(device_mutex);
-static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
+static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;
-static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
+static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
- sk->sk_prot = &tls_prots[ip_ver][ctx->conf];
+ sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
}
int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -114,6 +114,7 @@ int tls_push_sg(struct sock *sk,
size = sg->length - offset;
offset += sg->offset;
+ ctx->in_tcp_sendpages = true;
while (1) {
if (sg_is_last(sg))
sendpage_flags = flags;
@@ -134,6 +135,7 @@ retry:
offset -= sg->offset;
ctx->partially_sent_offset = offset;
ctx->partially_sent_record = (void *)sg;
+ ctx->in_tcp_sendpages = false;
return ret;
}
@@ -148,6 +150,8 @@ retry:
}
clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
+ ctx->in_tcp_sendpages = false;
+ ctx->sk_write_space(sk);
return 0;
}
@@ -217,6 +221,10 @@ static void tls_write_space(struct sock *sk)
{
struct tls_context *ctx = tls_get_ctx(sk);
+ /* We are already sending pages, ignore notification */
+ if (ctx->in_tcp_sendpages)
+ return;
+
if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
gfp_t sk_allocation = sk->sk_allocation;
int rc;
@@ -241,16 +249,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
struct tls_context *ctx = tls_get_ctx(sk);
long timeo = sock_sndtimeo(sk, 0);
void (*sk_proto_close)(struct sock *sk, long timeout);
+ bool free_ctx = false;
lock_sock(sk);
sk_proto_close = ctx->sk_proto_close;
- if (ctx->conf == TLS_HW_RECORD)
- goto skip_tx_cleanup;
-
- if (ctx->conf == TLS_BASE) {
- kfree(ctx);
- ctx = NULL;
+ if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) ||
+ (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) {
+ free_ctx = true;
goto skip_tx_cleanup;
}
@@ -270,15 +276,26 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
}
}
- kfree(ctx->tx.rec_seq);
- kfree(ctx->tx.iv);
- kfree(ctx->rx.rec_seq);
- kfree(ctx->rx.iv);
+ /* We need these for tls_sw_fallback handling of other packets */
+ if (ctx->tx_conf == TLS_SW) {
+ kfree(ctx->tx.rec_seq);
+ kfree(ctx->tx.iv);
+ tls_sw_free_resources_tx(sk);
+ }
+
+ if (ctx->rx_conf == TLS_SW) {
+ kfree(ctx->rx.rec_seq);
+ kfree(ctx->rx.iv);
+ tls_sw_free_resources_rx(sk);
+ }
- if (ctx->conf == TLS_SW_TX ||
- ctx->conf == TLS_SW_RX ||
- ctx->conf == TLS_SW_RXTX) {
- tls_sw_free_resources(sk);
+#ifdef CONFIG_TLS_DEVICE
+ if (ctx->tx_conf != TLS_HW) {
+#else
+ {
+#endif
+ kfree(ctx);
+ ctx = NULL;
}
skip_tx_cleanup:
@@ -287,7 +304,7 @@ skip_tx_cleanup:
/* free ctx for TLS_HW_RECORD, used by tcp_set_state
* for sk->sk_prot->unhash [tls_hw_unhash]
*/
- if (ctx && ctx->conf == TLS_HW_RECORD)
+ if (free_ctx)
kfree(ctx);
}
@@ -441,25 +458,29 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto err_crypto_info;
}
- /* currently SW is default, we will have ethtool in future */
if (tx) {
- rc = tls_set_sw_offload(sk, ctx, 1);
- if (ctx->conf == TLS_SW_RX)
- conf = TLS_SW_RXTX;
- else
- conf = TLS_SW_TX;
+#ifdef CONFIG_TLS_DEVICE
+ rc = tls_set_device_offload(sk, ctx);
+ conf = TLS_HW;
+ if (rc) {
+#else
+ {
+#endif
+ rc = tls_set_sw_offload(sk, ctx, 1);
+ conf = TLS_SW;
+ }
} else {
rc = tls_set_sw_offload(sk, ctx, 0);
- if (ctx->conf == TLS_SW_TX)
- conf = TLS_SW_RXTX;
- else
- conf = TLS_SW_RX;
+ conf = TLS_SW;
}
if (rc)
goto err_crypto_info;
- ctx->conf = conf;
+ if (tx)
+ ctx->tx_conf = conf;
+ else
+ ctx->rx_conf = conf;
update_sk_prot(sk, ctx);
if (tx) {
ctx->sk_write_space = sk->sk_write_space;
@@ -535,7 +556,8 @@ static int tls_hw_prot(struct sock *sk)
ctx->hash = sk->sk_prot->hash;
ctx->unhash = sk->sk_prot->unhash;
ctx->sk_proto_close = sk->sk_prot->close;
- ctx->conf = TLS_HW_RECORD;
+ ctx->rx_conf = TLS_HW_RECORD;
+ ctx->tx_conf = TLS_HW_RECORD;
update_sk_prot(sk, ctx);
rc = 1;
break;
@@ -579,29 +601,40 @@ static int tls_hw_hash(struct sock *sk)
return err;
}
-static void build_protos(struct proto *prot, struct proto *base)
+static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+ struct proto *base)
{
- prot[TLS_BASE] = *base;
- prot[TLS_BASE].setsockopt = tls_setsockopt;
- prot[TLS_BASE].getsockopt = tls_getsockopt;
- prot[TLS_BASE].close = tls_sk_proto_close;
-
- prot[TLS_SW_TX] = prot[TLS_BASE];
- prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg;
- prot[TLS_SW_TX].sendpage = tls_sw_sendpage;
-
- prot[TLS_SW_RX] = prot[TLS_BASE];
- prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg;
- prot[TLS_SW_RX].close = tls_sk_proto_close;
-
- prot[TLS_SW_RXTX] = prot[TLS_SW_TX];
- prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg;
- prot[TLS_SW_RXTX].close = tls_sk_proto_close;
-
- prot[TLS_HW_RECORD] = *base;
- prot[TLS_HW_RECORD].hash = tls_hw_hash;
- prot[TLS_HW_RECORD].unhash = tls_hw_unhash;
- prot[TLS_HW_RECORD].close = tls_sk_proto_close;
+ prot[TLS_BASE][TLS_BASE] = *base;
+ prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
+ prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
+ prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
+
+ prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
+ prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;
+
+ prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
+ prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
+
+ prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
+ prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
+ prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
+
+#ifdef CONFIG_TLS_DEVICE
+ prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
+ prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
+
+ prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+ prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
+ prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
+#endif
+
+ prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
+ prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
+ prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
+ prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
}
static int tls_init(struct sock *sk)
@@ -632,7 +665,7 @@ static int tls_init(struct sock *sk)
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
- /* Build IPv6 TLS whenever the address of tcpv6_prot changes */
+ /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
if (ip_ver == TLSV6 &&
unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
mutex_lock(&tcpv6_prot_mutex);
@@ -643,7 +676,8 @@ static int tls_init(struct sock *sk)
mutex_unlock(&tcpv6_prot_mutex);
}
- ctx->conf = TLS_BASE;
+ ctx->tx_conf = TLS_BASE;
+ ctx->rx_conf = TLS_BASE;
update_sk_prot(sk, ctx);
out:
return rc;
@@ -678,9 +712,12 @@ static int __init tls_register(void)
build_protos(tls_prots[TLSV4], &tcp_prot);
tls_sw_proto_ops = inet_stream_ops;
- tls_sw_proto_ops.poll = tls_sw_poll;
+ tls_sw_proto_ops.poll_mask = tls_sw_poll_mask;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;
+#ifdef CONFIG_TLS_DEVICE
+ tls_device_init();
+#endif
tcp_register_ulp(&tcp_tls_ulp_ops);
return 0;
@@ -689,6 +726,9 @@ static int __init tls_register(void)
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
+#ifdef CONFIG_TLS_DEVICE
+ tls_device_cleanup();
+#endif
}
module_init(tls_register);