diff options
Diffstat (limited to 'net/mptcp/pm_netlink.c')
| -rw-r--r-- | net/mptcp/pm_netlink.c | 184 | 
1 files changed, 100 insertions, 84 deletions
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 287a60381eae..5c17d39146ea 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -8,19 +8,13 @@  #include <linux/inet.h>  #include <linux/kernel.h> -#include <net/tcp.h>  #include <net/inet_common.h>  #include <net/netns/generic.h>  #include <net/mptcp.h> -#include <net/genetlink.h> -#include <uapi/linux/mptcp.h>  #include "protocol.h"  #include "mib.h" -/* forward declaration */ -static struct genl_family mptcp_genl_family; -  static int pm_nl_pernet_id;  struct mptcp_pm_add_entry { @@ -396,19 +390,6 @@ void mptcp_pm_free_anno_list(struct mptcp_sock *msk)  	}  } -static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned int nr, -				  const struct mptcp_addr_info *addr) -{ -	int i; - -	for (i = 0; i < nr; i++) { -		if (addrs[i].id == addr->id) -			return true; -	} - -	return false; -} -  /* Fill all the remote addresses into the array addrs[],   * and return the array size.   */ @@ -440,18 +421,34 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk,  		msk->pm.subflows++;  		addrs[i++] = remote;  	} else { +		DECLARE_BITMAP(unavail_id, MPTCP_PM_MAX_ADDR_ID + 1); + +		/* Forbid creation of new subflows matching existing +		 * ones, possibly already created by incoming ADD_ADDR +		 */ +		bitmap_zero(unavail_id, MPTCP_PM_MAX_ADDR_ID + 1); +		mptcp_for_each_subflow(msk, subflow) +			if (READ_ONCE(subflow->local_id) == local->id) +				__set_bit(subflow->remote_id, unavail_id); +  		mptcp_for_each_subflow(msk, subflow) {  			ssk = mptcp_subflow_tcp_sock(subflow);  			remote_address((struct sock_common *)ssk, &addrs[i]); -			addrs[i].id = subflow->remote_id; +			addrs[i].id = READ_ONCE(subflow->remote_id);  			if (deny_id0 && !addrs[i].id)  				continue; +			if (test_bit(addrs[i].id, unavail_id)) +				continue; +  			if (!mptcp_pm_addr_families_match(sk, local, &addrs[i]))  				continue; -			if (!lookup_address_in_vec(addrs, i, &addrs[i]) && -			    msk->pm.subflows < subflows_max) { +			if (msk->pm.subflows < subflows_max) { +				/* forbid creating multiple address towards +				 * this id +				 */ +				__set_bit(addrs[i].id, unavail_id);  				msk->pm.subflows++;  				i++;  			} @@ -502,15 +499,12 @@ __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)  }  static struct mptcp_pm_addr_entry * -__lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info, -	      bool lookup_by_id) +__lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info)  {  	struct mptcp_pm_addr_entry *entry;  	list_for_each_entry(entry, &pernet->local_addr_list, list) { -		if ((!lookup_by_id && -		     mptcp_addresses_equal(&entry->addr, info, entry->addr.port)) || -		    (lookup_by_id && entry->addr.id == info->id)) +		if (mptcp_addresses_equal(&entry->addr, info, entry->addr.port))  			return entry;  	}  	return NULL; @@ -540,7 +534,7 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)  		mptcp_local_address((struct sock_common *)msk->first, &mpc_addr);  		rcu_read_lock(); -		entry = __lookup_addr(pernet, &mpc_addr, false); +		entry = __lookup_addr(pernet, &mpc_addr);  		if (entry) {  			__clear_bit(entry->addr.id, msk->pm.id_avail_bitmap);  			msk->mpc_endpoint_id = entry->addr.id; @@ -799,18 +793,18 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,  		mptcp_for_each_subflow_safe(msk, subflow, tmp) {  			struct sock *ssk = mptcp_subflow_tcp_sock(subflow); +			u8 remote_id = READ_ONCE(subflow->remote_id);  			int how = RCV_SHUTDOWN | SEND_SHUTDOWN; -			u8 id = subflow->local_id; +			u8 id = subflow_get_local_id(subflow); -			if (rm_type == MPTCP_MIB_RMADDR && subflow->remote_id != rm_id) +			if (rm_type == MPTCP_MIB_RMADDR && remote_id != rm_id)  				continue;  			if (rm_type == MPTCP_MIB_RMSUBFLOW && !mptcp_local_id_match(msk, id, rm_id))  				continue;  			pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u",  				 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", -				 i, rm_id, subflow->local_id, subflow->remote_id, -				 msk->mpc_endpoint_id); +				 i, rm_id, id, remote_id, msk->mpc_endpoint_id);  			spin_unlock_bh(&msk->pm.lock);  			mptcp_subflow_shutdown(sk, ssk, how); @@ -901,7 +895,8 @@ static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry)  }  static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, -					     struct mptcp_pm_addr_entry *entry) +					     struct mptcp_pm_addr_entry *entry, +					     bool needs_id)  {  	struct mptcp_pm_addr_entry *cur, *del_entry = NULL;  	unsigned int addr_max; @@ -949,7 +944,7 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,  		}  	} -	if (!entry->addr.id) { +	if (!entry->addr.id && needs_id) {  find_next:  		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,  						    MPTCP_PM_MAX_ADDR_ID + 1, @@ -960,7 +955,7 @@ find_next:  		}  	} -	if (!entry->addr.id) +	if (!entry->addr.id && needs_id)  		goto out;  	__set_bit(entry->addr.id, pernet->id_bitmap); @@ -1092,7 +1087,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc  	entry->ifindex = 0;  	entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;  	entry->lsk = NULL; -	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); +	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry, true);  	if (ret < 0)  		kfree(entry); @@ -1285,6 +1280,18 @@ next:  	return 0;  } +static bool mptcp_pm_has_addr_attr_id(const struct nlattr *attr, +				      struct genl_info *info) +{ +	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1]; + +	if (!nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr, +					 mptcp_pm_address_nl_policy, info->extack) && +	    tb[MPTCP_PM_ADDR_ATTR_ID]) +		return true; +	return false; +} +  int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info)  {  	struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR]; @@ -1326,7 +1333,8 @@ int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info)  			goto out_free;  		}  	} -	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); +	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry, +						!mptcp_pm_has_addr_attr_id(attr, info));  	if (ret < 0) {  		GENL_SET_ERR_MSG_FMT(info, "too many addresses or duplicate one: %d", ret);  		goto out_free; @@ -1533,8 +1541,8 @@ void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list)  	}  } -void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, -					struct list_head *rm_list) +static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, +					       struct list_head *rm_list)  {  	struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };  	struct mptcp_pm_addr_entry *entry; @@ -1619,8 +1627,8 @@ int mptcp_pm_nl_flush_addrs_doit(struct sk_buff *skb, struct genl_info *info)  	return 0;  } -static int mptcp_nl_fill_addr(struct sk_buff *skb, -			      struct mptcp_pm_addr_entry *entry) +int mptcp_nl_fill_addr(struct sk_buff *skb, +		       struct mptcp_pm_addr_entry *entry)  {  	struct mptcp_addr_info *addr = &entry->addr;  	struct nlattr *attr; @@ -1658,7 +1666,7 @@ nla_put_failure:  	return -EMSGSIZE;  } -int mptcp_pm_nl_get_addr_doit(struct sk_buff *skb, struct genl_info *info) +int mptcp_pm_nl_get_addr(struct sk_buff *skb, struct genl_info *info)  {  	struct nlattr *attr = info->attrs[MPTCP_PM_ENDPOINT_ADDR];  	struct pm_nl_pernet *pernet = genl_info_pm_nl(info); @@ -1708,8 +1716,13 @@ fail:  	return ret;  } -int mptcp_pm_nl_get_addr_dumpit(struct sk_buff *msg, -				struct netlink_callback *cb) +int mptcp_pm_nl_get_addr_doit(struct sk_buff *skb, struct genl_info *info) +{ +	return mptcp_pm_get_addr(skb, info); +} + +int mptcp_pm_nl_dump_addr(struct sk_buff *msg, +			  struct netlink_callback *cb)  {  	struct net *net = sock_net(msg->sk);  	struct mptcp_pm_addr_entry *entry; @@ -1751,6 +1764,12 @@ int mptcp_pm_nl_get_addr_dumpit(struct sk_buff *msg,  	return msg->len;  } +int mptcp_pm_nl_get_addr_dumpit(struct sk_buff *msg, +				struct netlink_callback *cb) +{ +	return mptcp_pm_dump_addr(msg, cb); +} +  static int parse_limit(struct genl_info *info, int id, unsigned int *limit)  {  	struct nlattr *attr = info->attrs[id]; @@ -1865,66 +1884,63 @@ next:  	return ret;  } -int mptcp_pm_nl_set_flags(struct net *net, struct mptcp_pm_addr_entry *addr, u8 bkup) +int mptcp_pm_nl_set_flags(struct sk_buff *skb, struct genl_info *info)  { -	struct pm_nl_pernet *pernet = pm_nl_get_pernet(net); +	struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }; +	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];  	u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP |  			   MPTCP_PM_ADDR_FLAG_FULLMESH; +	struct net *net = sock_net(skb->sk);  	struct mptcp_pm_addr_entry *entry; +	struct pm_nl_pernet *pernet;  	u8 lookup_by_id = 0; +	u8 bkup = 0; +	int ret; -	if (addr->addr.family == AF_UNSPEC) { +	pernet = pm_nl_get_pernet(net); + +	ret = mptcp_pm_parse_entry(attr, info, false, &addr); +	if (ret < 0) +		return ret; + +	if (addr.addr.family == AF_UNSPEC) {  		lookup_by_id = 1; -		if (!addr->addr.id) +		if (!addr.addr.id) { +			GENL_SET_ERR_MSG(info, "missing required inputs");  			return -EOPNOTSUPP; +		}  	} +	if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) +		bkup = 1; +  	spin_lock_bh(&pernet->lock); -	entry = __lookup_addr(pernet, &addr->addr, lookup_by_id); +	entry = lookup_by_id ? __lookup_addr_by_id(pernet, addr.addr.id) : +			       __lookup_addr(pernet, &addr.addr);  	if (!entry) {  		spin_unlock_bh(&pernet->lock); +		GENL_SET_ERR_MSG(info, "address not found");  		return -EINVAL;  	} -	if ((addr->flags & MPTCP_PM_ADDR_FLAG_FULLMESH) && +	if ((addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) &&  	    (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {  		spin_unlock_bh(&pernet->lock); +		GENL_SET_ERR_MSG(info, "invalid addr flags");  		return -EINVAL;  	} -	changed = (addr->flags ^ entry->flags) & mask; -	entry->flags = (entry->flags & ~mask) | (addr->flags & mask); -	*addr = *entry; +	changed = (addr.flags ^ entry->flags) & mask; +	entry->flags = (entry->flags & ~mask) | (addr.flags & mask); +	addr = *entry;  	spin_unlock_bh(&pernet->lock); -	mptcp_nl_set_flags(net, &addr->addr, bkup, changed); +	mptcp_nl_set_flags(net, &addr.addr, bkup, changed);  	return 0;  }  int mptcp_pm_nl_set_flags_doit(struct sk_buff *skb, struct genl_info *info)  { -	struct mptcp_pm_addr_entry remote = { .addr = { .family = AF_UNSPEC }, }; -	struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }; -	struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; -	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; -	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; -	struct net *net = sock_net(skb->sk); -	u8 bkup = 0; -	int ret; - -	ret = mptcp_pm_parse_entry(attr, info, false, &addr); -	if (ret < 0) -		return ret; - -	if (attr_rem) { -		ret = mptcp_pm_parse_entry(attr_rem, info, false, &remote); -		if (ret < 0) -			return ret; -	} - -	if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) -		bkup = 1; - -	return mptcp_pm_set_flags(net, token, &addr, &remote, bkup); +	return mptcp_pm_set_flags(skb, info);  }  static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp) @@ -1980,7 +1996,7 @@ static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)  	if (WARN_ON_ONCE(!sf))  		return -EINVAL; -	if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id)) +	if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, subflow_get_local_id(sf)))  		return -EMSGSIZE;  	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id)) @@ -1997,7 +2013,7 @@ static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,  	const struct mptcp_subflow_context *sf;  	u8 sk_err; -	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) +	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)))  		return -EMSGSIZE;  	if (mptcp_event_add_subflow(skb, ssk)) @@ -2055,7 +2071,7 @@ static int mptcp_event_created(struct sk_buff *skb,  			       const struct mptcp_sock *msk,  			       const struct sock *ssk)  { -	int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token); +	int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token));  	if (err)  		return err; @@ -2083,7 +2099,7 @@ void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)  	if (!nlh)  		goto nla_put_failure; -	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) +	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)))  		goto nla_put_failure;  	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id)) @@ -2118,7 +2134,7 @@ void mptcp_event_addr_announced(const struct sock *ssk,  	if (!nlh)  		goto nla_put_failure; -	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) +	if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)))  		goto nla_put_failure;  	if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id)) @@ -2234,7 +2250,7 @@ void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,  			goto nla_put_failure;  		break;  	case MPTCP_EVENT_CLOSED: -		if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0) +		if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, READ_ONCE(msk->token)) < 0)  			goto nla_put_failure;  		break;  	case MPTCP_EVENT_ANNOUNCED: @@ -2264,7 +2280,7 @@ nla_put_failure:  	nlmsg_free(skb);  } -static struct genl_family mptcp_genl_family __ro_after_init = { +struct genl_family mptcp_genl_family __ro_after_init = {  	.name		= MPTCP_PM_NAME,  	.version	= MPTCP_PM_VER,  	.netnsok	= true,  |