diff options
Diffstat (limited to 'net/tls')
| -rw-r--r-- | net/tls/tls_main.c | 47 | ||||
| -rw-r--r-- | net/tls/tls_sw.c | 82 |
2 files changed, 88 insertions, 41 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index acfba9f1ba72..6bc2879ba637 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -61,7 +61,7 @@ static DEFINE_MUTEX(tcpv6_prot_mutex); static const struct proto *saved_tcpv4_prot; static DEFINE_MUTEX(tcpv4_prot_mutex); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; -static struct proto_ops tls_sw_proto_ops; +static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], const struct proto *base); @@ -71,6 +71,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx) WRITE_ONCE(sk->sk_prot, &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); + WRITE_ONCE(sk->sk_socket->ops, + &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); } int wait_on_pending_writer(struct sock *sk, long *timeo) @@ -669,8 +671,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, if (tx) { ctx->sk_write_space = sk->sk_write_space; sk->sk_write_space = tls_write_space; - } else { - sk->sk_socket->ops = &tls_sw_proto_ops; } goto out; @@ -728,6 +728,39 @@ struct tls_context *tls_ctx_create(struct sock *sk) return ctx; } +static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + const struct proto_ops *base) +{ + ops[TLS_BASE][TLS_BASE] = *base; + + ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; + ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked; + + ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; + ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; + + ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; + ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; + +#ifdef CONFIG_TLS_DEVICE + ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; + ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL; + + ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; + ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL; + + ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; + + ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; + + ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; + ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL; +#endif +#ifdef CONFIG_TLS_TOE + ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; +#endif +} + static void tls_build_proto(struct sock *sk) { int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; @@ -739,6 +772,8 @@ static void tls_build_proto(struct sock *sk) mutex_lock(&tcpv6_prot_mutex); if (likely(prot != saved_tcpv6_prot)) { build_protos(tls_prots[TLSV6], prot); + build_proto_ops(tls_proto_ops[TLSV6], + sk->sk_socket->ops); smp_store_release(&saved_tcpv6_prot, prot); } mutex_unlock(&tcpv6_prot_mutex); @@ -749,6 +784,8 @@ static void tls_build_proto(struct sock *sk) mutex_lock(&tcpv4_prot_mutex); if (likely(prot != saved_tcpv4_prot)) { build_protos(tls_prots[TLSV4], prot); + build_proto_ops(tls_proto_ops[TLSV4], + sk->sk_socket->ops); smp_store_release(&saved_tcpv4_prot, prot); } mutex_unlock(&tcpv4_prot_mutex); @@ -959,10 +996,6 @@ static int __init tls_register(void) if (err) return err; - tls_sw_proto_ops = inet_stream_ops; - tls_sw_proto_ops.splice_read = tls_sw_splice_read; - tls_sw_proto_ops.sendpage_locked = tls_sw_sendpage_locked; - tls_device_init(); tcp_register_ulp(&tcp_tls_ulp_ops); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index d81564078557..efc84845bb6b 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -521,7 +521,7 @@ static int tls_do_encryption(struct sock *sk, memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, prot->iv_size + prot->salt_size); - xor_iv_with_seq(prot, rec->iv_data, tls_ctx->tx.rec_seq); + xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq); sge->offset += prot->prepend_size; sge->length -= prot->prepend_size; @@ -1499,7 +1499,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, else memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size); - xor_iv_with_seq(prot, iv, tls_ctx->rx.rec_seq); + xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq); /* Prepare AAD */ tls_make_aad(aad, rxm->full_len - prot->overhead_size + @@ -1990,6 +1990,7 @@ recv_end: end: release_sock(sk); + sk_defer_free_flush(sk); if (psock) sk_psock_put(sk, psock); return copied ? : err; @@ -2005,6 +2006,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, struct sock *sk = sock->sk; struct sk_buff *skb; ssize_t copied = 0; + bool from_queue; int err = 0; long timeo; int chunk; @@ -2014,25 +2016,28 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK); - skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err); - if (!skb) - goto splice_read_end; - - if (!ctx->decrypted) { - err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); - - /* splice does not support reading control messages */ - if (ctx->control != TLS_RECORD_TYPE_DATA) { - err = -EINVAL; + from_queue = !skb_queue_empty(&ctx->rx_list); + if (from_queue) { + skb = __skb_dequeue(&ctx->rx_list); + } else { + skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, + &err); + if (!skb) goto splice_read_end; - } + err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); if (err < 0) { tls_err_abort(sk, -EBADMSG); goto splice_read_end; } - ctx->decrypted = 1; } + + /* splice does not support reading control messages */ + if (ctx->control != TLS_RECORD_TYPE_DATA) { + err = -EINVAL; + goto splice_read_end; + } + rxm = strp_msg(skb); chunk = min_t(unsigned int, rxm->full_len, len); @@ -2040,10 +2045,21 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, if (copied < 0) goto splice_read_end; - tls_sw_advance_skb(sk, skb, copied); + if (!from_queue) { + ctx->recv_pkt = NULL; + __strp_unpause(&ctx->strp); + } + if (chunk < rxm->full_len) { + __skb_queue_head(&ctx->rx_list, skb); + rxm->offset += len; + rxm->full_len -= len; + } else { + consume_skb(skb); + } splice_read_end: release_sock(sk); + sk_defer_free_flush(sk); return copied ? : err; } @@ -2314,10 +2330,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_crypto_info *crypto_info; - struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; - struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; - struct tls12_crypto_info_aes_ccm_128 *ccm_128_info; - struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; struct cipher_context *cctx; @@ -2380,15 +2392,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) switch (crypto_info->cipher_type) { case TLS_CIPHER_AES_GCM_128: { + struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; + + gcm_128_info = (void *)crypto_info; nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; - iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; + iv = gcm_128_info->iv; rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; - rec_seq = - ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; - gcm_128_info = - (struct tls12_crypto_info_aes_gcm_128 *)crypto_info; + rec_seq = gcm_128_info->rec_seq; keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE; key = gcm_128_info->key; salt = gcm_128_info->salt; @@ -2397,15 +2409,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) break; } case TLS_CIPHER_AES_GCM_256: { + struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; + + gcm_256_info = (void *)crypto_info; nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE; iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; - iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv; + iv = gcm_256_info->iv; rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE; - rec_seq = - ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq; - gcm_256_info = - (struct tls12_crypto_info_aes_gcm_256 *)crypto_info; + rec_seq = gcm_256_info->rec_seq; keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE; key = gcm_256_info->key; salt = gcm_256_info->salt; @@ -2414,15 +2426,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) break; } case TLS_CIPHER_AES_CCM_128: { + struct tls12_crypto_info_aes_ccm_128 *ccm_128_info; + + ccm_128_info = (void *)crypto_info; nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE; iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; - iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv; + iv = ccm_128_info->iv; rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE; - rec_seq = - ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq; - ccm_128_info = - (struct tls12_crypto_info_aes_ccm_128 *)crypto_info; + rec_seq = ccm_128_info->rec_seq; keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE; key = ccm_128_info->key; salt = ccm_128_info->salt; @@ -2431,6 +2443,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) break; } case TLS_CIPHER_CHACHA20_POLY1305: { + struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info; + chacha20_poly1305_info = (void *)crypto_info; nonce_size = 0; tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE; |