diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 300 |
1 files changed, 183 insertions, 117 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 20cd93be6236..78cb4a584080 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -45,6 +45,7 @@ MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS_TCP_ULP("tls"); enum { TLSV4, @@ -52,27 +53,20 @@ enum { TLS_NUM_PROTS, }; -enum { - TLS_BASE, - TLS_SW_TX, - TLS_SW_RX, - TLS_SW_RXTX, - TLS_HW_RECORD, - TLS_NUM_CONFIG, -}; - static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); +static struct proto *saved_tcpv4_prot; +static DEFINE_MUTEX(tcpv4_prot_mutex); static LIST_HEAD(device_list); -static DEFINE_MUTEX(device_mutex); -static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; +static DEFINE_SPINLOCK(device_spinlock); +static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto_ops tls_sw_proto_ops; -static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) +static void update_sk_prot(struct sock *sk, struct tls_context *ctx) { int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; - sk->sk_prot = &tls_prots[ip_ver][ctx->conf]; + sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]; } int wait_on_pending_writer(struct sock *sk, long *timeo) @@ -149,7 +143,6 @@ retry: size = sg->length; } - clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); ctx->in_tcp_sendpages = false; ctx->sk_write_space(sk); @@ -201,15 +194,12 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg, return rc; } -int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, - int flags, long *timeo) +int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, + int flags) { struct scatterlist *sg; u16 offset; - if (!tls_is_partially_sent_record(ctx)) - return ctx->push_pending_record(sk, flags); - sg = ctx->partially_sent_record; offset = ctx->partially_sent_offset; @@ -217,33 +207,53 @@ int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx, return tls_push_sg(sk, ctx, sg, offset, flags); } +int tls_push_pending_closed_record(struct sock *sk, + struct tls_context *tls_ctx, + int flags, long *timeo) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + if (tls_is_partially_sent_record(tls_ctx) || + !list_empty(&ctx->tx_list)) + return tls_tx_records(sk, flags); + else + return tls_ctx->push_pending_record(sk, flags); +} + static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); - /* We are already sending pages, ignore notification */ - if (ctx->in_tcp_sendpages) + /* If in_tcp_sendpages call lower protocol write space handler + * to ensure we wake up any waiting operations there. For example + * if do_tcp_sendpages where to call sk_wait_event. + */ + if (ctx->in_tcp_sendpages) { + ctx->sk_write_space(sk); return; + } - if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { - gfp_t sk_allocation = sk->sk_allocation; - int rc; - long timeo = 0; - - sk->sk_allocation = GFP_ATOMIC; - rc = tls_push_pending_closed_record(sk, ctx, - MSG_DONTWAIT | - MSG_NOSIGNAL, - &timeo); - sk->sk_allocation = sk_allocation; - - if (rc < 0) - return; + /* Schedule the transmission if tx list is ready */ + if (is_tx_ready(tx_ctx) && !sk->sk_write_pending) { + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); } ctx->sk_write_space(sk); } +static void tls_ctx_free(struct tls_context *ctx) +{ + if (!ctx) + return; + + memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); + memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); + kfree(ctx); +} + static void tls_sk_proto_close(struct sock *sk, long timeout) { struct tls_context *ctx = tls_get_ctx(sk); @@ -254,7 +264,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) lock_sock(sk); sk_proto_close = ctx->sk_proto_close; - if (ctx->conf == TLS_BASE || ctx->conf == TLS_HW_RECORD) { + if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) || + (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) { free_ctx = true; goto skip_tx_cleanup; } @@ -262,28 +273,29 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) tls_handle_open_record(sk, 0); - if (ctx->partially_sent_record) { - struct scatterlist *sg = ctx->partially_sent_record; - - while (1) { - put_page(sg_page(sg)); - sk_mem_uncharge(sk, sg->length); + /* We need these for tls_sw_fallback handling of other packets */ + if (ctx->tx_conf == TLS_SW) { + kfree(ctx->tx.rec_seq); + kfree(ctx->tx.iv); + tls_sw_free_resources_tx(sk); + } - if (sg_is_last(sg)) - break; - sg++; - } + if (ctx->rx_conf == TLS_SW) { + kfree(ctx->rx.rec_seq); + kfree(ctx->rx.iv); + tls_sw_free_resources_rx(sk); } - kfree(ctx->tx.rec_seq); - kfree(ctx->tx.iv); - kfree(ctx->rx.rec_seq); - kfree(ctx->rx.iv); +#ifdef CONFIG_TLS_DEVICE + if (ctx->rx_conf == TLS_HW) + tls_device_offload_cleanup_rx(sk); - if (ctx->conf == TLS_SW_TX || - ctx->conf == TLS_SW_RX || - ctx->conf == TLS_SW_RXTX) { - tls_sw_free_resources(sk); + if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) { +#else + { +#endif + tls_ctx_free(ctx); + ctx = NULL; } skip_tx_cleanup: @@ -293,7 +305,7 @@ skip_tx_cleanup: * for sk->sk_prot->unhash [tls_hw_unhash] */ if (free_ctx) - kfree(ctx); + tls_ctx_free(ctx); } static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, @@ -318,7 +330,7 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, } /* get user crypto info */ - crypto_info = &ctx->crypto_send; + crypto_info = &ctx->crypto_send.info; if (!TLS_CRYPTO_INFO_READY(crypto_info)) { rc = -EBUSY; @@ -405,9 +417,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, } if (tx) - crypto_info = &ctx->crypto_send; + crypto_info = &ctx->crypto_send.info; else - crypto_info = &ctx->crypto_recv; + crypto_info = &ctx->crypto_recv.info; /* Currently we don't support set crypto info more than one time */ if (TLS_CRYPTO_INFO_READY(crypto_info)) { @@ -446,25 +458,37 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto err_crypto_info; } - /* currently SW is default, we will have ethtool in future */ if (tx) { - rc = tls_set_sw_offload(sk, ctx, 1); - if (ctx->conf == TLS_SW_RX) - conf = TLS_SW_RXTX; - else - conf = TLS_SW_TX; +#ifdef CONFIG_TLS_DEVICE + rc = tls_set_device_offload(sk, ctx); + conf = TLS_HW; + if (rc) { +#else + { +#endif + rc = tls_set_sw_offload(sk, ctx, 1); + conf = TLS_SW; + } } else { - rc = tls_set_sw_offload(sk, ctx, 0); - if (ctx->conf == TLS_SW_TX) - conf = TLS_SW_RXTX; - else - conf = TLS_SW_RX; +#ifdef CONFIG_TLS_DEVICE + rc = tls_set_device_offload_rx(sk, ctx); + conf = TLS_HW; + if (rc) { +#else + { +#endif + rc = tls_set_sw_offload(sk, ctx, 0); + conf = TLS_SW; + } } if (rc) goto err_crypto_info; - ctx->conf = conf; + if (tx) + ctx->tx_conf = conf; + else + ctx->rx_conf = conf; update_sk_prot(sk, ctx); if (tx) { ctx->sk_write_space = sk->sk_write_space; @@ -475,7 +499,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, goto out; err_crypto_info: - memset(crypto_info, 0, sizeof(*crypto_info)); + memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); out: return rc; } @@ -516,11 +540,14 @@ static struct tls_context *create_ctx(struct sock *sk) struct inet_connection_sock *icsk = inet_csk(sk); struct tls_context *ctx; - ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); if (!ctx) return NULL; icsk->icsk_ulp_data = ctx; + ctx->setsockopt = sk->sk_prot->setsockopt; + ctx->getsockopt = sk->sk_prot->getsockopt; + ctx->sk_proto_close = sk->sk_prot->close; return ctx; } @@ -530,7 +557,7 @@ static int tls_hw_prot(struct sock *sk) struct tls_device *dev; int rc = 0; - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_for_each_entry(dev, &device_list, dev_list) { if (dev->feature && dev->feature(dev)) { ctx = create_ctx(sk); @@ -540,14 +567,15 @@ static int tls_hw_prot(struct sock *sk) ctx->hash = sk->sk_prot->hash; ctx->unhash = sk->sk_prot->unhash; ctx->sk_proto_close = sk->sk_prot->close; - ctx->conf = TLS_HW_RECORD; + ctx->rx_conf = TLS_HW_RECORD; + ctx->tx_conf = TLS_HW_RECORD; update_sk_prot(sk, ctx); rc = 1; break; } } out: - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); return rc; } @@ -556,12 +584,17 @@ static void tls_hw_unhash(struct sock *sk) struct tls_context *ctx = tls_get_ctx(sk); struct tls_device *dev; - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_for_each_entry(dev, &device_list, dev_list) { - if (dev->unhash) + if (dev->unhash) { + kref_get(&dev->kref); + spin_unlock_bh(&device_spinlock); dev->unhash(dev, sk); + kref_put(&dev->kref, dev->release); + spin_lock_bh(&device_spinlock); + } } - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); ctx->unhash(sk); } @@ -572,41 +605,65 @@ static int tls_hw_hash(struct sock *sk) int err; err = ctx->hash(sk); - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_for_each_entry(dev, &device_list, dev_list) { - if (dev->hash) + if (dev->hash) { + kref_get(&dev->kref); + spin_unlock_bh(&device_spinlock); err |= dev->hash(dev, sk); + kref_put(&dev->kref, dev->release); + spin_lock_bh(&device_spinlock); + } } - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); if (err) tls_hw_unhash(sk); return err; } -static void build_protos(struct proto *prot, struct proto *base) +static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + struct proto *base) { - prot[TLS_BASE] = *base; - prot[TLS_BASE].setsockopt = tls_setsockopt; - prot[TLS_BASE].getsockopt = tls_getsockopt; - prot[TLS_BASE].close = tls_sk_proto_close; - - prot[TLS_SW_TX] = prot[TLS_BASE]; - prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; - prot[TLS_SW_TX].sendpage = tls_sw_sendpage; - - prot[TLS_SW_RX] = prot[TLS_BASE]; - prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg; - prot[TLS_SW_RX].close = tls_sk_proto_close; - - prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; - prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg; - prot[TLS_SW_RXTX].close = tls_sk_proto_close; - - prot[TLS_HW_RECORD] = *base; - prot[TLS_HW_RECORD].hash = tls_hw_hash; - prot[TLS_HW_RECORD].unhash = tls_hw_unhash; - prot[TLS_HW_RECORD].close = tls_sk_proto_close; + prot[TLS_BASE][TLS_BASE] = *base; + prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; + prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; + prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; + + prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; + prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; + + prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read; + prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; + + prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; + prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read; + prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; + +#ifdef CONFIG_TLS_DEVICE + prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage; + + prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; + prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; + + prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; + + prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; + + prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; +#endif + + prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; + prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close; } static int tls_init(struct sock *sk) @@ -633,11 +690,8 @@ static int tls_init(struct sock *sk) rc = -ENOMEM; goto out; } - ctx->setsockopt = sk->sk_prot->setsockopt; - ctx->getsockopt = sk->sk_prot->getsockopt; - ctx->sk_proto_close = sk->sk_prot->close; - /* Build IPv6 TLS whenever the address of tcpv6_prot changes */ + /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ if (ip_ver == TLSV6 && unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { mutex_lock(&tcpv6_prot_mutex); @@ -648,7 +702,18 @@ static int tls_init(struct sock *sk) mutex_unlock(&tcpv6_prot_mutex); } - ctx->conf = TLS_BASE; + if (ip_ver == TLSV4 && + unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) { + mutex_lock(&tcpv4_prot_mutex); + if (likely(sk->sk_prot != saved_tcpv4_prot)) { + build_protos(tls_prots[TLSV4], sk->sk_prot); + smp_store_release(&saved_tcpv4_prot, sk->sk_prot); + } + mutex_unlock(&tcpv4_prot_mutex); + } + + ctx->tx_conf = TLS_BASE; + ctx->rx_conf = TLS_BASE; update_sk_prot(sk, ctx); out: return rc; @@ -656,36 +721,34 @@ out: void tls_register_device(struct tls_device *device) { - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_add_tail(&device->dev_list, &device_list); - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); } EXPORT_SYMBOL(tls_register_device); void tls_unregister_device(struct tls_device *device) { - mutex_lock(&device_mutex); + spin_lock_bh(&device_spinlock); list_del(&device->dev_list); - mutex_unlock(&device_mutex); + spin_unlock_bh(&device_spinlock); } EXPORT_SYMBOL(tls_unregister_device); static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .name = "tls", - .uid = TCP_ULP_TLS, - .user_visible = true, .owner = THIS_MODULE, .init = tls_init, }; static int __init tls_register(void) { - build_protos(tls_prots[TLSV4], &tcp_prot); - tls_sw_proto_ops = inet_stream_ops; - tls_sw_proto_ops.poll = tls_sw_poll; tls_sw_proto_ops.splice_read = tls_sw_splice_read; +#ifdef CONFIG_TLS_DEVICE + tls_device_init(); +#endif tcp_register_ulp(&tcp_tls_ulp_ops); return 0; @@ -694,6 +757,9 @@ static int __init tls_register(void) static void __exit tls_unregister(void) { tcp_unregister_ulp(&tcp_tls_ulp_ops); +#ifdef CONFIG_TLS_DEVICE + tls_device_cleanup(); +#endif } module_init(tls_register); |