diff options
Diffstat (limited to 'net/tls/tls_sw.c')
| -rw-r--r-- | net/tls/tls_sw.c | 180 | 
1 files changed, 104 insertions, 76 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index e1c93ce74e0f..4618f1c31137 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -52,7 +52,7 @@ static int tls_do_decryption(struct sock *sk,  			     gfp_t flags)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	struct strp_msg *rxm = strp_msg(skb);  	struct aead_request *aead_req; @@ -122,7 +122,7 @@ out:  static void trim_both_sgl(struct sock *sk, int target_size)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	trim_sg(sk, ctx->sg_plaintext_data,  		&ctx->sg_plaintext_num_elem, @@ -141,7 +141,7 @@ static void trim_both_sgl(struct sock *sk, int target_size)  static int alloc_encrypted_sg(struct sock *sk, int len)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	int rc = 0;  	rc = sk_alloc_sg(sk, len, @@ -155,7 +155,7 @@ static int alloc_encrypted_sg(struct sock *sk, int len)  static int alloc_plaintext_sg(struct sock *sk, int len)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	int rc = 0;  	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, @@ -181,7 +181,7 @@ static void free_sg(struct sock *sk, struct scatterlist *sg,  static void tls_free_both_sg(struct sock *sk)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,  		&ctx->sg_encrypted_size); @@ -191,18 +191,12 @@ static void tls_free_both_sg(struct sock *sk)  }  static int tls_do_encryption(struct tls_context *tls_ctx, -			     struct tls_sw_context *ctx, size_t data_len, -			     gfp_t flags) +			     struct tls_sw_context_tx *ctx, +			     struct aead_request *aead_req, +			     size_t data_len)  { -	unsigned int req_size = sizeof(struct aead_request) + -		crypto_aead_reqsize(ctx->aead_send); -	struct aead_request *aead_req;  	int rc; -	aead_req = kzalloc(req_size, flags); -	if (!aead_req) -		return -ENOMEM; -  	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;  	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; @@ -219,7 +213,6 @@ static int tls_do_encryption(struct tls_context *tls_ctx,  	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;  	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; -	kfree(aead_req);  	return rc;  } @@ -227,9 +220,15 @@ static int tls_push_record(struct sock *sk, int flags,  			   unsigned char record_type)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); +	struct aead_request *req;  	int rc; +	req = kzalloc(sizeof(struct aead_request) + +		      crypto_aead_reqsize(ctx->aead_send), sk->sk_allocation); +	if (!req) +		return -ENOMEM; +  	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);  	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); @@ -245,15 +244,14 @@ static int tls_push_record(struct sock *sk, int flags,  	tls_ctx->pending_open_record_frags = 0;  	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags); -	rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size, -			       sk->sk_allocation); +	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);  	if (rc < 0) {  		/* If we are called from write_space and  		 * we fail, we need to set this SOCK_NOSPACE  		 * to trigger another write_space in the future.  		 */  		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); -		return rc; +		goto out_req;  	}  	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, @@ -268,6 +266,8 @@ static int tls_push_record(struct sock *sk, int flags,  		tls_err_abort(sk, EBADMSG);  	tls_advance_record_sn(sk, &tls_ctx->tx); +out_req: +	kfree(req);  	return rc;  } @@ -339,7 +339,7 @@ static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,  			     int bytes)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	struct scatterlist *sg = ctx->sg_plaintext_data;  	int copy, i, rc = 0; @@ -367,7 +367,7 @@ out:  int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	int ret = 0;  	int required_size;  	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); @@ -440,7 +440,7 @@ alloc_encrypted:  			ret = tls_push_record(sk, msg->msg_flags, record_type);  			if (!ret)  				continue; -			if (ret == -EAGAIN) +			if (ret < 0)  				goto send_end;  			copied -= try_to_copy; @@ -522,7 +522,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,  		    int offset, size_t size, int flags)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	int ret = 0;  	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);  	bool eor; @@ -636,7 +636,7 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,  				     long timeo, int *err)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	struct sk_buff *skb;  	DEFINE_WAIT_FUNC(wait, woken_wake_function); @@ -674,7 +674,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,  		       struct scatterlist *sgout)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];  	struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];  	struct scatterlist *sgin = &sgin_arr[0]; @@ -692,8 +692,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,  	if (!sgout) {  		nsg = skb_cow_data(skb, 0, &unused) + 1;  		sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); -		if (!sgout) -			sgout = sgin; +		sgout = sgin;  	}  	sg_init_table(sgin, nsg); @@ -702,6 +701,10 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,  	nsg = skb_to_sgvec(skb, &sgin[1],  			   rxm->offset + tls_ctx->rx.prepend_size,  			   rxm->full_len - tls_ctx->rx.prepend_size); +	if (nsg < 0) { +		ret = nsg; +		goto out; +	}  	tls_make_aad(ctx->rx_aad_ciphertext,  		     rxm->full_len - tls_ctx->rx.overhead_size, @@ -713,6 +716,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,  				rxm->full_len - tls_ctx->rx.overhead_size,  				skb, sk->sk_allocation); +out:  	if (sgin != &sgin_arr[0])  		kfree(sgin); @@ -723,7 +727,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,  			       unsigned int len)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	struct strp_msg *rxm = strp_msg(skb);  	if (len < rxm->full_len) { @@ -736,7 +740,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,  	/* Finished with message */  	ctx->recv_pkt = NULL;  	kfree_skb(skb); -	strp_unpause(&ctx->strp); +	__strp_unpause(&ctx->strp);  	return true;  } @@ -749,13 +753,13 @@ int tls_sw_recvmsg(struct sock *sk,  		   int *addr_len)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	unsigned char control;  	struct strp_msg *rxm;  	struct sk_buff *skb;  	ssize_t copied = 0;  	bool cmsg = false; -	int err = 0; +	int target, err = 0;  	long timeo;  	flags |= nonblock; @@ -765,6 +769,7 @@ int tls_sw_recvmsg(struct sock *sk,  	lock_sock(sk); +	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);  	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);  	do {  		bool zc = false; @@ -857,6 +862,9 @@ fallback_to_reg_recv:  					goto recv_end;  			}  		} +		/* If we have a new message from strparser, continue now. */ +		if (copied >= target && !ctx->recv_pkt) +			break;  	} while (len);  recv_end: @@ -869,7 +877,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,  			   size_t len, unsigned int flags)  {  	struct tls_context *tls_ctx = tls_get_ctx(sock->sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	struct strp_msg *rxm = NULL;  	struct sock *sk = sock->sk;  	struct sk_buff *skb; @@ -922,7 +930,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket *sock,  	unsigned int ret;  	struct sock *sk = sock->sk;  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	/* Grab POLLOUT and POLLHUP from the underlying socket */  	ret = ctx->sk_poll(file, sock, wait); @@ -938,7 +946,7 @@ unsigned int tls_sw_poll(struct file *file, struct socket *sock,  static int tls_read_size(struct strparser *strp, struct sk_buff *skb)  {  	struct tls_context *tls_ctx = tls_get_ctx(strp->sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	char header[tls_ctx->rx.prepend_size];  	struct strp_msg *rxm = strp_msg(skb);  	size_t cipher_overhead; @@ -987,7 +995,7 @@ read_failure:  static void tls_queue(struct strparser *strp, struct sk_buff *skb)  {  	struct tls_context *tls_ctx = tls_get_ctx(strp->sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	struct strp_msg *rxm;  	rxm = strp_msg(skb); @@ -1003,18 +1011,28 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)  static void tls_data_ready(struct sock *sk)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);  	strp_data_ready(&ctx->strp);  } -void tls_sw_free_resources(struct sock *sk) +void tls_sw_free_resources_tx(struct sock *sk)  {  	struct tls_context *tls_ctx = tls_get_ctx(sk); -	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); +	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);  	if (ctx->aead_send)  		crypto_free_aead(ctx->aead_send); +	tls_free_both_sg(sk); + +	kfree(ctx); +} + +void tls_sw_free_resources_rx(struct sock *sk) +{ +	struct tls_context *tls_ctx = tls_get_ctx(sk); +	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); +  	if (ctx->aead_recv) {  		if (ctx->recv_pkt) {  			kfree_skb(ctx->recv_pkt); @@ -1030,10 +1048,7 @@ void tls_sw_free_resources(struct sock *sk)  		lock_sock(sk);  	} -	tls_free_both_sg(sk); -  	kfree(ctx); -	kfree(tls_ctx);  }  int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) @@ -1041,7 +1056,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)  	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];  	struct tls_crypto_info *crypto_info;  	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; -	struct tls_sw_context *sw_ctx; +	struct tls_sw_context_tx *sw_ctx_tx = NULL; +	struct tls_sw_context_rx *sw_ctx_rx = NULL;  	struct cipher_context *cctx;  	struct crypto_aead **aead;  	struct strp_callbacks cb; @@ -1054,27 +1070,32 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)  		goto out;  	} -	if (!ctx->priv_ctx) { -		sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); -		if (!sw_ctx) { +	if (tx) { +		sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); +		if (!sw_ctx_tx) {  			rc = -ENOMEM;  			goto out;  		} -		crypto_init_wait(&sw_ctx->async_wait); +		crypto_init_wait(&sw_ctx_tx->async_wait); +		ctx->priv_ctx_tx = sw_ctx_tx;  	} else { -		sw_ctx = ctx->priv_ctx; +		sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); +		if (!sw_ctx_rx) { +			rc = -ENOMEM; +			goto out; +		} +		crypto_init_wait(&sw_ctx_rx->async_wait); +		ctx->priv_ctx_rx = sw_ctx_rx;  	} -	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; -  	if (tx) {  		crypto_info = &ctx->crypto_send;  		cctx = &ctx->tx; -		aead = &sw_ctx->aead_send; +		aead = &sw_ctx_tx->aead_send;  	} else {  		crypto_info = &ctx->crypto_recv;  		cctx = &ctx->rx; -		aead = &sw_ctx->aead_recv; +		aead = &sw_ctx_rx->aead_recv;  	}  	switch (crypto_info->cipher_type) { @@ -1121,22 +1142,24 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)  	}  	memcpy(cctx->rec_seq, rec_seq, rec_seq_size); -	if (tx) { -		sg_init_table(sw_ctx->sg_encrypted_data, -			      ARRAY_SIZE(sw_ctx->sg_encrypted_data)); -		sg_init_table(sw_ctx->sg_plaintext_data, -			      ARRAY_SIZE(sw_ctx->sg_plaintext_data)); - -		sg_init_table(sw_ctx->sg_aead_in, 2); -		sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, -			   sizeof(sw_ctx->aad_space)); -		sg_unmark_end(&sw_ctx->sg_aead_in[1]); -		sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); -		sg_init_table(sw_ctx->sg_aead_out, 2); -		sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, -			   sizeof(sw_ctx->aad_space)); -		sg_unmark_end(&sw_ctx->sg_aead_out[1]); -		sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); +	if (sw_ctx_tx) { +		sg_init_table(sw_ctx_tx->sg_encrypted_data, +			      ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data)); +		sg_init_table(sw_ctx_tx->sg_plaintext_data, +			      ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data)); + +		sg_init_table(sw_ctx_tx->sg_aead_in, 2); +		sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space, +			   sizeof(sw_ctx_tx->aad_space)); +		sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]); +		sg_chain(sw_ctx_tx->sg_aead_in, 2, +			 sw_ctx_tx->sg_plaintext_data); +		sg_init_table(sw_ctx_tx->sg_aead_out, 2); +		sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space, +			   sizeof(sw_ctx_tx->aad_space)); +		sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]); +		sg_chain(sw_ctx_tx->sg_aead_out, 2, +			 sw_ctx_tx->sg_encrypted_data);  	}  	if (!*aead) { @@ -1161,22 +1184,22 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)  	if (rc)  		goto free_aead; -	if (!tx) { +	if (sw_ctx_rx) {  		/* Set up strparser */  		memset(&cb, 0, sizeof(cb));  		cb.rcv_msg = tls_queue;  		cb.parse_msg = tls_read_size; -		strp_init(&sw_ctx->strp, sk, &cb); +		strp_init(&sw_ctx_rx->strp, sk, &cb);  		write_lock_bh(&sk->sk_callback_lock); -		sw_ctx->saved_data_ready = sk->sk_data_ready; +		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;  		sk->sk_data_ready = tls_data_ready;  		write_unlock_bh(&sk->sk_callback_lock); -		sw_ctx->sk_poll = sk->sk_socket->ops->poll; +		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; -		strp_check_rcv(&sw_ctx->strp); +		strp_check_rcv(&sw_ctx_rx->strp);  	}  	goto out; @@ -1188,11 +1211,16 @@ free_rec_seq:  	kfree(cctx->rec_seq);  	cctx->rec_seq = NULL;  free_iv: -	kfree(ctx->tx.iv); -	ctx->tx.iv = NULL; +	kfree(cctx->iv); +	cctx->iv = NULL;  free_priv: -	kfree(ctx->priv_ctx); -	ctx->priv_ctx = NULL; +	if (tx) { +		kfree(ctx->priv_ctx_tx); +		ctx->priv_ctx_tx = NULL; +	} else { +		kfree(ctx->priv_ctx_rx); +		ctx->priv_ctx_rx = NULL; +	}  out:  	return rc;  }  |