aboutsummaryrefslogtreecommitdiff
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c300
1 files changed, 183 insertions, 117 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 20cd93be6236..78cb4a584080 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -45,6 +45,7 @@
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
+MODULE_ALIAS_TCP_ULP("tls");
enum {
TLSV4,
@@ -52,27 +53,20 @@ enum {
TLS_NUM_PROTS,
};
-enum {
- TLS_BASE,
- TLS_SW_TX,
- TLS_SW_RX,
- TLS_SW_RXTX,
- TLS_HW_RECORD,
- TLS_NUM_CONFIG,
-};
-
static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
+static struct proto *saved_tcpv4_prot;
+static DEFINE_MUTEX(tcpv4_prot_mutex);
static LIST_HEAD(device_list);
-static DEFINE_MUTEX(device_mutex);
-static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
+static DEFINE_SPINLOCK(device_spinlock);
+static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;
-static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
+static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
- sk->sk_prot = &tls_prots[ip_ver][ctx->conf];
+ sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
}
int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -149,7 +143,6 @@ retry:
size = sg->length;
}
- clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
ctx->in_tcp_sendpages = false;
ctx->sk_write_space(sk);
@@ -201,15 +194,12 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
return rc;
}
-int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
- int flags, long *timeo)
+int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
+ int flags)
{
struct scatterlist *sg;
u16 offset;
- if (!tls_is_partially_sent_record(ctx))
- return ctx->push_pending_record(sk, flags);
-
sg = ctx->partially_sent_record;
offset = ctx->partially_sent_offset;
@@ -217,33 +207,53 @@ int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
return tls_push_sg(sk, ctx, sg, offset, flags);
}
+int tls_push_pending_closed_record(struct sock *sk,
+ struct tls_context *tls_ctx,
+ int flags, long *timeo)
+{
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+
+ if (tls_is_partially_sent_record(tls_ctx) ||
+ !list_empty(&ctx->tx_list))
+ return tls_tx_records(sk, flags);
+ else
+ return tls_ctx->push_pending_record(sk, flags);
+}
+
static void tls_write_space(struct sock *sk)
{
struct tls_context *ctx = tls_get_ctx(sk);
+ struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
- /* We are already sending pages, ignore notification */
- if (ctx->in_tcp_sendpages)
+ /* If in_tcp_sendpages call lower protocol write space handler
+ * to ensure we wake up any waiting operations there. For example
+ * if do_tcp_sendpages where to call sk_wait_event.
+ */
+ if (ctx->in_tcp_sendpages) {
+ ctx->sk_write_space(sk);
return;
+ }
- if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
- gfp_t sk_allocation = sk->sk_allocation;
- int rc;
- long timeo = 0;
-
- sk->sk_allocation = GFP_ATOMIC;
- rc = tls_push_pending_closed_record(sk, ctx,
- MSG_DONTWAIT |
- MSG_NOSIGNAL,
- &timeo);
- sk->sk_allocation = sk_allocation;
-
- if (rc < 0)
- return;
+ /* Schedule the transmission if tx list is ready */
+ if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) {
+ /* Schedule the transmission */
+ if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
+ schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}
ctx->sk_write_space(sk);
}
+static void tls_ctx_free(struct tls_context *ctx)
+{
+ if (!ctx)
+ return;
+
+ memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
+ memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
+ kfree(ctx);
+}
+
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
struct tls_context *ctx = tls_get_ctx(sk);
@@ -254,7 +264,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
lock_sock(sk);
sk_proto_close = ctx->sk_proto_close;
- if (ctx->conf == TLS_BASE || ctx->conf == TLS_HW_RECORD) {
+ if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) ||
+ (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) {
free_ctx = true;
goto skip_tx_cleanup;
}
@@ -262,28 +273,29 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
tls_handle_open_record(sk, 0);
- if (ctx->partially_sent_record) {
- struct scatterlist *sg = ctx->partially_sent_record;
-
- while (1) {
- put_page(sg_page(sg));
- sk_mem_uncharge(sk, sg->length);
+ /* We need these for tls_sw_fallback handling of other packets */
+ if (ctx->tx_conf == TLS_SW) {
+ kfree(ctx->tx.rec_seq);
+ kfree(ctx->tx.iv);
+ tls_sw_free_resources_tx(sk);
+ }
- if (sg_is_last(sg))
- break;
- sg++;
- }
+ if (ctx->rx_conf == TLS_SW) {
+ kfree(ctx->rx.rec_seq);
+ kfree(ctx->rx.iv);
+ tls_sw_free_resources_rx(sk);
}
- kfree(ctx->tx.rec_seq);
- kfree(ctx->tx.iv);
- kfree(ctx->rx.rec_seq);
- kfree(ctx->rx.iv);
+#ifdef CONFIG_TLS_DEVICE
+ if (ctx->rx_conf == TLS_HW)
+ tls_device_offload_cleanup_rx(sk);
- if (ctx->conf == TLS_SW_TX ||
- ctx->conf == TLS_SW_RX ||
- ctx->conf == TLS_SW_RXTX) {
- tls_sw_free_resources(sk);
+ if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
+#else
+ {
+#endif
+ tls_ctx_free(ctx);
+ ctx = NULL;
}
skip_tx_cleanup:
@@ -293,7 +305,7 @@ skip_tx_cleanup:
* for sk->sk_prot->unhash [tls_hw_unhash]
*/
if (free_ctx)
- kfree(ctx);
+ tls_ctx_free(ctx);
}
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
@@ -318,7 +330,7 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
}
/* get user crypto info */
- crypto_info = &ctx->crypto_send;
+ crypto_info = &ctx->crypto_send.info;
if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
rc = -EBUSY;
@@ -405,9 +417,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
}
if (tx)
- crypto_info = &ctx->crypto_send;
+ crypto_info = &ctx->crypto_send.info;
else
- crypto_info = &ctx->crypto_recv;
+ crypto_info = &ctx->crypto_recv.info;
/* Currently we don't support set crypto info more than one time */
if (TLS_CRYPTO_INFO_READY(crypto_info)) {
@@ -446,25 +458,37 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto err_crypto_info;
}
- /* currently SW is default, we will have ethtool in future */
if (tx) {
- rc = tls_set_sw_offload(sk, ctx, 1);
- if (ctx->conf == TLS_SW_RX)
- conf = TLS_SW_RXTX;
- else
- conf = TLS_SW_TX;
+#ifdef CONFIG_TLS_DEVICE
+ rc = tls_set_device_offload(sk, ctx);
+ conf = TLS_HW;
+ if (rc) {
+#else
+ {
+#endif
+ rc = tls_set_sw_offload(sk, ctx, 1);
+ conf = TLS_SW;
+ }
} else {
- rc = tls_set_sw_offload(sk, ctx, 0);
- if (ctx->conf == TLS_SW_TX)
- conf = TLS_SW_RXTX;
- else
- conf = TLS_SW_RX;
+#ifdef CONFIG_TLS_DEVICE
+ rc = tls_set_device_offload_rx(sk, ctx);
+ conf = TLS_HW;
+ if (rc) {
+#else
+ {
+#endif
+ rc = tls_set_sw_offload(sk, ctx, 0);
+ conf = TLS_SW;
+ }
}
if (rc)
goto err_crypto_info;
- ctx->conf = conf;
+ if (tx)
+ ctx->tx_conf = conf;
+ else
+ ctx->rx_conf = conf;
update_sk_prot(sk, ctx);
if (tx) {
ctx->sk_write_space = sk->sk_write_space;
@@ -475,7 +499,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto out;
err_crypto_info:
- memset(crypto_info, 0, sizeof(*crypto_info));
+ memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
out:
return rc;
}
@@ -516,11 +540,14 @@ static struct tls_context *create_ctx(struct sock *sk)
struct inet_connection_sock *icsk = inet_csk(sk);
struct tls_context *ctx;
- ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+ ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
if (!ctx)
return NULL;
icsk->icsk_ulp_data = ctx;
+ ctx->setsockopt = sk->sk_prot->setsockopt;
+ ctx->getsockopt = sk->sk_prot->getsockopt;
+ ctx->sk_proto_close = sk->sk_prot->close;
return ctx;
}
@@ -530,7 +557,7 @@ static int tls_hw_prot(struct sock *sk)
struct tls_device *dev;
int rc = 0;
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
if (dev->feature && dev->feature(dev)) {
ctx = create_ctx(sk);
@@ -540,14 +567,15 @@ static int tls_hw_prot(struct sock *sk)
ctx->hash = sk->sk_prot->hash;
ctx->unhash = sk->sk_prot->unhash;
ctx->sk_proto_close = sk->sk_prot->close;
- ctx->conf = TLS_HW_RECORD;
+ ctx->rx_conf = TLS_HW_RECORD;
+ ctx->tx_conf = TLS_HW_RECORD;
update_sk_prot(sk, ctx);
rc = 1;
break;
}
}
out:
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
return rc;
}
@@ -556,12 +584,17 @@ static void tls_hw_unhash(struct sock *sk)
struct tls_context *ctx = tls_get_ctx(sk);
struct tls_device *dev;
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
- if (dev->unhash)
+ if (dev->unhash) {
+ kref_get(&dev->kref);
+ spin_unlock_bh(&device_spinlock);
dev->unhash(dev, sk);
+ kref_put(&dev->kref, dev->release);
+ spin_lock_bh(&device_spinlock);
+ }
}
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
ctx->unhash(sk);
}
@@ -572,41 +605,65 @@ static int tls_hw_hash(struct sock *sk)
int err;
err = ctx->hash(sk);
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
- if (dev->hash)
+ if (dev->hash) {
+ kref_get(&dev->kref);
+ spin_unlock_bh(&device_spinlock);
err |= dev->hash(dev, sk);
+ kref_put(&dev->kref, dev->release);
+ spin_lock_bh(&device_spinlock);
+ }
}
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
if (err)
tls_hw_unhash(sk);
return err;
}
-static void build_protos(struct proto *prot, struct proto *base)
+static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+ struct proto *base)
{
- prot[TLS_BASE] = *base;
- prot[TLS_BASE].setsockopt = tls_setsockopt;
- prot[TLS_BASE].getsockopt = tls_getsockopt;
- prot[TLS_BASE].close = tls_sk_proto_close;
-
- prot[TLS_SW_TX] = prot[TLS_BASE];
- prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg;
- prot[TLS_SW_TX].sendpage = tls_sw_sendpage;
-
- prot[TLS_SW_RX] = prot[TLS_BASE];
- prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg;
- prot[TLS_SW_RX].close = tls_sk_proto_close;
-
- prot[TLS_SW_RXTX] = prot[TLS_SW_TX];
- prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg;
- prot[TLS_SW_RXTX].close = tls_sk_proto_close;
-
- prot[TLS_HW_RECORD] = *base;
- prot[TLS_HW_RECORD].hash = tls_hw_hash;
- prot[TLS_HW_RECORD].unhash = tls_hw_unhash;
- prot[TLS_HW_RECORD].close = tls_sk_proto_close;
+ prot[TLS_BASE][TLS_BASE] = *base;
+ prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
+ prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
+ prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
+
+ prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
+ prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;
+
+ prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
+ prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
+ prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
+
+ prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
+ prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
+ prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read;
+ prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
+
+#ifdef CONFIG_TLS_DEVICE
+ prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
+ prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
+
+ prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+ prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
+ prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
+
+ prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+
+ prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+
+ prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
+#endif
+
+ prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
+ prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
+ prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
+ prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
}
static int tls_init(struct sock *sk)
@@ -633,11 +690,8 @@ static int tls_init(struct sock *sk)
rc = -ENOMEM;
goto out;
}
- ctx->setsockopt = sk->sk_prot->setsockopt;
- ctx->getsockopt = sk->sk_prot->getsockopt;
- ctx->sk_proto_close = sk->sk_prot->close;
- /* Build IPv6 TLS whenever the address of tcpv6_prot changes */
+ /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
if (ip_ver == TLSV6 &&
unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
mutex_lock(&tcpv6_prot_mutex);
@@ -648,7 +702,18 @@ static int tls_init(struct sock *sk)
mutex_unlock(&tcpv6_prot_mutex);
}
- ctx->conf = TLS_BASE;
+ if (ip_ver == TLSV4 &&
+ unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+ mutex_lock(&tcpv4_prot_mutex);
+ if (likely(sk->sk_prot != saved_tcpv4_prot)) {
+ build_protos(tls_prots[TLSV4], sk->sk_prot);
+ smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+ }
+ mutex_unlock(&tcpv4_prot_mutex);
+ }
+
+ ctx->tx_conf = TLS_BASE;
+ ctx->rx_conf = TLS_BASE;
update_sk_prot(sk, ctx);
out:
return rc;
@@ -656,36 +721,34 @@ out:
void tls_register_device(struct tls_device *device)
{
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_add_tail(&device->dev_list, &device_list);
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
}
EXPORT_SYMBOL(tls_register_device);
void tls_unregister_device(struct tls_device *device)
{
- mutex_lock(&device_mutex);
+ spin_lock_bh(&device_spinlock);
list_del(&device->dev_list);
- mutex_unlock(&device_mutex);
+ spin_unlock_bh(&device_spinlock);
}
EXPORT_SYMBOL(tls_unregister_device);
static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
.name = "tls",
- .uid = TCP_ULP_TLS,
- .user_visible = true,
.owner = THIS_MODULE,
.init = tls_init,
};
static int __init tls_register(void)
{
- build_protos(tls_prots[TLSV4], &tcp_prot);
-
tls_sw_proto_ops = inet_stream_ops;
- tls_sw_proto_ops.poll = tls_sw_poll;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;
+#ifdef CONFIG_TLS_DEVICE
+ tls_device_init();
+#endif
tcp_register_ulp(&tcp_tls_ulp_ops);
return 0;
@@ -694,6 +757,9 @@ static int __init tls_register(void)
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
+#ifdef CONFIG_TLS_DEVICE
+ tls_device_cleanup();
+#endif
}
module_init(tls_register);