diff options
Diffstat (limited to 'net/tls/tls_main.c')
| -rw-r--r-- | net/tls/tls_main.c | 139 | 
1 files changed, 87 insertions, 52 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 20cd93be6236..301f22430469 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -51,12 +51,12 @@ enum {  	TLSV6,  	TLS_NUM_PROTS,  }; -  enum {  	TLS_BASE, -	TLS_SW_TX, -	TLS_SW_RX, -	TLS_SW_RXTX, +	TLS_SW, +#ifdef CONFIG_TLS_DEVICE +	TLS_HW, +#endif  	TLS_HW_RECORD,  	TLS_NUM_CONFIG,  }; @@ -65,14 +65,14 @@ static struct proto *saved_tcpv6_prot;  static DEFINE_MUTEX(tcpv6_prot_mutex);  static LIST_HEAD(device_list);  static DEFINE_MUTEX(device_mutex); -static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; +static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];  static struct proto_ops tls_sw_proto_ops; -static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) +static void update_sk_prot(struct sock *sk, struct tls_context *ctx)  {  	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; -	sk->sk_prot = &tls_prots[ip_ver][ctx->conf]; +	sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];  }  int wait_on_pending_writer(struct sock *sk, long *timeo) @@ -254,7 +254,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)  	lock_sock(sk);  	sk_proto_close = ctx->sk_proto_close; -	if (ctx->conf == TLS_BASE || ctx->conf == TLS_HW_RECORD) { +	if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) || +	    (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) {  		free_ctx = true;  		goto skip_tx_cleanup;  	} @@ -275,15 +276,26 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)  		}  	} -	kfree(ctx->tx.rec_seq); -	kfree(ctx->tx.iv); -	kfree(ctx->rx.rec_seq); -	kfree(ctx->rx.iv); +	/* We need these for tls_sw_fallback handling of other packets */ +	if (ctx->tx_conf == TLS_SW) { +		kfree(ctx->tx.rec_seq); +		kfree(ctx->tx.iv); +		tls_sw_free_resources_tx(sk); +	} -	if (ctx->conf == TLS_SW_TX || -	    ctx->conf == TLS_SW_RX || -	    ctx->conf == TLS_SW_RXTX) { -		tls_sw_free_resources(sk); +	if (ctx->rx_conf == TLS_SW) { +		kfree(ctx->rx.rec_seq); +		kfree(ctx->rx.iv); +		tls_sw_free_resources_rx(sk); +	} + +#ifdef CONFIG_TLS_DEVICE +	if (ctx->tx_conf != TLS_HW) { +#else +	{ +#endif +		kfree(ctx); +		ctx = NULL;  	}  skip_tx_cleanup: @@ -446,25 +458,29 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,  		goto err_crypto_info;  	} -	/* currently SW is default, we will have ethtool in future */  	if (tx) { -		rc = tls_set_sw_offload(sk, ctx, 1); -		if (ctx->conf == TLS_SW_RX) -			conf = TLS_SW_RXTX; -		else -			conf = TLS_SW_TX; +#ifdef CONFIG_TLS_DEVICE +		rc = tls_set_device_offload(sk, ctx); +		conf = TLS_HW; +		if (rc) { +#else +		{ +#endif +			rc = tls_set_sw_offload(sk, ctx, 1); +			conf = TLS_SW; +		}  	} else {  		rc = tls_set_sw_offload(sk, ctx, 0); -		if (ctx->conf == TLS_SW_TX) -			conf = TLS_SW_RXTX; -		else -			conf = TLS_SW_RX; +		conf = TLS_SW;  	}  	if (rc)  		goto err_crypto_info; -	ctx->conf = conf; +	if (tx) +		ctx->tx_conf = conf; +	else +		ctx->rx_conf = conf;  	update_sk_prot(sk, ctx);  	if (tx) {  		ctx->sk_write_space = sk->sk_write_space; @@ -540,7 +556,8 @@ static int tls_hw_prot(struct sock *sk)  			ctx->hash = sk->sk_prot->hash;  			ctx->unhash = sk->sk_prot->unhash;  			ctx->sk_proto_close = sk->sk_prot->close; -			ctx->conf = TLS_HW_RECORD; +			ctx->rx_conf = TLS_HW_RECORD; +			ctx->tx_conf = TLS_HW_RECORD;  			update_sk_prot(sk, ctx);  			rc = 1;  			break; @@ -584,29 +601,40 @@ static int tls_hw_hash(struct sock *sk)  	return err;  } -static void build_protos(struct proto *prot, struct proto *base) +static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], +			 struct proto *base)  { -	prot[TLS_BASE] = *base; -	prot[TLS_BASE].setsockopt	= tls_setsockopt; -	prot[TLS_BASE].getsockopt	= tls_getsockopt; -	prot[TLS_BASE].close		= tls_sk_proto_close; - -	prot[TLS_SW_TX] = prot[TLS_BASE]; -	prot[TLS_SW_TX].sendmsg		= tls_sw_sendmsg; -	prot[TLS_SW_TX].sendpage	= tls_sw_sendpage; - -	prot[TLS_SW_RX] = prot[TLS_BASE]; -	prot[TLS_SW_RX].recvmsg		= tls_sw_recvmsg; -	prot[TLS_SW_RX].close		= tls_sk_proto_close; - -	prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; -	prot[TLS_SW_RXTX].recvmsg	= tls_sw_recvmsg; -	prot[TLS_SW_RXTX].close		= tls_sk_proto_close; - -	prot[TLS_HW_RECORD] = *base; -	prot[TLS_HW_RECORD].hash	= tls_hw_hash; -	prot[TLS_HW_RECORD].unhash	= tls_hw_unhash; -	prot[TLS_HW_RECORD].close	= tls_sk_proto_close; +	prot[TLS_BASE][TLS_BASE] = *base; +	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt; +	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt; +	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close; + +	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; +	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg; +	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage; + +	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; +	prot[TLS_BASE][TLS_SW].recvmsg		= tls_sw_recvmsg; +	prot[TLS_BASE][TLS_SW].close		= tls_sk_proto_close; + +	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; +	prot[TLS_SW][TLS_SW].recvmsg	= tls_sw_recvmsg; +	prot[TLS_SW][TLS_SW].close	= tls_sk_proto_close; + +#ifdef CONFIG_TLS_DEVICE +	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; +	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg; +	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage; + +	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; +	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg; +	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage; +#endif + +	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; +	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash; +	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash; +	prot[TLS_HW_RECORD][TLS_HW_RECORD].close	= tls_sk_proto_close;  }  static int tls_init(struct sock *sk) @@ -637,7 +665,7 @@ static int tls_init(struct sock *sk)  	ctx->getsockopt = sk->sk_prot->getsockopt;  	ctx->sk_proto_close = sk->sk_prot->close; -	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */ +	/* Build IPv6 TLS whenever the address of tcpv6	_prot changes */  	if (ip_ver == TLSV6 &&  	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {  		mutex_lock(&tcpv6_prot_mutex); @@ -648,7 +676,8 @@ static int tls_init(struct sock *sk)  		mutex_unlock(&tcpv6_prot_mutex);  	} -	ctx->conf = TLS_BASE; +	ctx->tx_conf = TLS_BASE; +	ctx->rx_conf = TLS_BASE;  	update_sk_prot(sk, ctx);  out:  	return rc; @@ -686,6 +715,9 @@ static int __init tls_register(void)  	tls_sw_proto_ops.poll = tls_sw_poll;  	tls_sw_proto_ops.splice_read = tls_sw_splice_read; +#ifdef CONFIG_TLS_DEVICE +	tls_device_init(); +#endif  	tcp_register_ulp(&tcp_tls_ulp_ops);  	return 0; @@ -694,6 +726,9 @@ static int __init tls_register(void)  static void __exit tls_unregister(void)  {  	tcp_unregister_ulp(&tcp_tls_ulp_ops); +#ifdef CONFIG_TLS_DEVICE +	tls_device_cleanup(); +#endif  }  module_init(tls_register);  |