diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_device.c | 19 | ||||
-rw-r--r-- | net/tls/tls_main.c | 17 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 8 |
3 files changed, 29 insertions, 15 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index ec6f4b699a2b..9975df34d9c2 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -97,13 +97,16 @@ static void tls_device_queue_ctx_destruction(struct tls_context *ctx) unsigned long flags; spin_lock_irqsave(&tls_device_lock, flags); + if (unlikely(!refcount_dec_and_test(&ctx->refcount))) + goto unlock; + list_move_tail(&ctx->list, &tls_device_gc_list); /* schedule_work inside the spinlock * to make sure tls_device_down waits for that work. */ schedule_work(&tls_device_gc_work); - +unlock: spin_unlock_irqrestore(&tls_device_lock, flags); } @@ -194,8 +197,7 @@ void tls_device_sk_destruct(struct sock *sk) clean_acked_data_disable(inet_csk(sk)); } - if (refcount_dec_and_test(&tls_ctx->refcount)) - tls_device_queue_ctx_destruction(tls_ctx); + tls_device_queue_ctx_destruction(tls_ctx); } EXPORT_SYMBOL_GPL(tls_device_sk_destruct); @@ -1374,8 +1376,13 @@ static int tls_device_down(struct net_device *netdev) * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW. * Now release the ref taken above. */ - if (refcount_dec_and_test(&ctx->refcount)) + if (refcount_dec_and_test(&ctx->refcount)) { + /* sk_destruct ran after tls_device_down took a ref, and + * it returned early. Complete the destruction here. + */ + list_del(&ctx->list); tls_device_free_ctx(ctx); + } } up_write(&device_offload_lock); @@ -1419,9 +1426,9 @@ static struct notifier_block tls_dev_notifier = { .notifier_call = tls_dev_event, }; -void __init tls_device_init(void) +int __init tls_device_init(void) { - register_netdevice_notifier(&tls_dev_notifier); + return register_netdevice_notifier(&tls_dev_notifier); } void __exit tls_device_cleanup(void) diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index b91ddc110786..d80ab3d1764e 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -544,7 +544,7 @@ static int do_tls_getsockopt(struct sock *sk, int optname, rc = do_tls_getsockopt_conf(sk, optval, optlen, optname == TLS_TX); break; - case TLS_TX_ZEROCOPY_SENDFILE: + case TLS_TX_ZEROCOPY_RO: rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); break; default: @@ -731,7 +731,7 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, optname == TLS_TX); release_sock(sk); break; - case TLS_TX_ZEROCOPY_SENDFILE: + case TLS_TX_ZEROCOPY_RO: lock_sock(sk); rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); release_sock(sk); @@ -921,6 +921,8 @@ static void tls_update(struct sock *sk, struct proto *p, { struct tls_context *ctx; + WARN_ON_ONCE(sk->sk_prot == p); + ctx = tls_get_ctx(sk); if (likely(ctx)) { ctx->sk_write_space = write_space; @@ -970,7 +972,7 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb) goto nla_failure; if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { - err = nla_put_flag(skb, TLS_INFO_ZC_SENDFILE); + err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); if (err) goto nla_failure; } @@ -994,7 +996,7 @@ static size_t tls_get_info_size(const struct sock *sk) nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ - nla_total_size(0) + /* TLS_INFO_ZC_SENDFILE */ + nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ 0; return size; @@ -1046,7 +1048,12 @@ static int __init tls_register(void) if (err) return err; - tls_device_init(); + err = tls_device_init(); + if (err) { + unregister_pernet_subsys(&tls_proc_ops); + return err; + } + tcp_register_ulp(&tcp_tls_ulp_ops); return 0; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 0513f82b8537..e30649f6dde5 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -267,9 +267,6 @@ static int tls_do_decryption(struct sock *sk, } darg->async = false; - if (ret == -EBADMSG) - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); - return ret; } @@ -1579,8 +1576,11 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, } err = decrypt_internal(sk, skb, dest, NULL, darg); - if (err < 0) + if (err < 0) { + if (err == -EBADMSG) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); return err; + } if (darg->async) goto decrypt_next; |