diff options
Diffstat (limited to 'net/ipv4/bpf_tcp_ca.c')
| -rw-r--r-- | net/ipv4/bpf_tcp_ca.c | 40 | 
1 files changed, 35 insertions, 5 deletions
diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c index 574972bc7299..e3939f76b024 100644 --- a/net/ipv4/bpf_tcp_ca.c +++ b/net/ipv4/bpf_tcp_ca.c @@ -7,6 +7,7 @@  #include <linux/btf.h>  #include <linux/filter.h>  #include <net/tcp.h> +#include <net/bpf_sk_storage.h>  static u32 optional_ops[] = {  	offsetof(struct tcp_congestion_ops, init), @@ -27,6 +28,27 @@ static u32 unsupported_ops[] = {  static const struct btf_type *tcp_sock_type;  static u32 tcp_sock_id, sock_id; +static int btf_sk_storage_get_ids[5]; +static struct bpf_func_proto btf_sk_storage_get_proto __read_mostly; + +static int btf_sk_storage_delete_ids[5]; +static struct bpf_func_proto btf_sk_storage_delete_proto __read_mostly; + +static void convert_sk_func_proto(struct bpf_func_proto *to, int *to_btf_ids, +				  const struct bpf_func_proto *from) +{ +	int i; + +	*to = *from; +	to->btf_id = to_btf_ids; +	for (i = 0; i < ARRAY_SIZE(to->arg_type); i++) { +		if (to->arg_type[i] == ARG_PTR_TO_SOCKET) { +			to->arg_type[i] = ARG_PTR_TO_BTF_ID; +			to->btf_id[i] = tcp_sock_id; +		} +	} +} +  static int bpf_tcp_ca_init(struct btf *btf)  {  	s32 type_id; @@ -42,6 +64,13 @@ static int bpf_tcp_ca_init(struct btf *btf)  	tcp_sock_id = type_id;  	tcp_sock_type = btf_type_by_id(btf, tcp_sock_id); +	convert_sk_func_proto(&btf_sk_storage_get_proto, +			      btf_sk_storage_get_ids, +			      &bpf_sk_storage_get_proto); +	convert_sk_func_proto(&btf_sk_storage_delete_proto, +			      btf_sk_storage_delete_ids, +			      &bpf_sk_storage_delete_proto); +  	return 0;  } @@ -167,6 +196,10 @@ bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,  	switch (func_id) {  	case BPF_FUNC_tcp_send_ack:  		return &bpf_tcp_send_ack_proto; +	case BPF_FUNC_sk_storage_get: +		return &btf_sk_storage_get_proto; +	case BPF_FUNC_sk_storage_delete: +		return &btf_sk_storage_delete_proto;  	default:  		return bpf_base_func_proto(func_id);  	} @@ -184,7 +217,6 @@ static int bpf_tcp_ca_init_member(const struct btf_type *t,  {  	const struct tcp_congestion_ops *utcp_ca;  	struct tcp_congestion_ops *tcp_ca; -	size_t tcp_ca_name_len;  	int prog_fd;  	u32 moff; @@ -199,13 +231,11 @@ static int bpf_tcp_ca_init_member(const struct btf_type *t,  		tcp_ca->flags = utcp_ca->flags;  		return 1;  	case offsetof(struct tcp_congestion_ops, name): -		tcp_ca_name_len = strnlen(utcp_ca->name, sizeof(utcp_ca->name)); -		if (!tcp_ca_name_len || -		    tcp_ca_name_len == sizeof(utcp_ca->name)) +		if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name, +				     sizeof(tcp_ca->name)) <= 0)  			return -EINVAL;  		if (tcp_ca_find(utcp_ca->name))  			return -EEXIST; -		memcpy(tcp_ca->name, utcp_ca->name, sizeof(tcp_ca->name));  		return 1;  	}  |