diff options
Diffstat (limited to 'net/mptcp/pm_netlink.c')
-rw-r--r-- | net/mptcp/pm_netlink.c | 198 |
1 files changed, 113 insertions, 85 deletions
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 3e4ad801786f..64fe0e7d87d7 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -130,12 +130,15 @@ static bool lookup_subflow_by_daddr(const struct list_head *list, { struct mptcp_subflow_context *subflow; struct mptcp_addr_info cur; - struct sock_common *skc; list_for_each_entry(subflow, list, node) { - skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); - remote_address(skc, &cur); + if (!((1 << inet_sk_state_load(ssk)) & + (TCPF_ESTABLISHED | TCPF_SYN_SENT | TCPF_SYN_RECV))) + continue; + + remote_address((struct sock_common *)ssk, &cur); if (mptcp_addresses_equal(&cur, daddr, daddr->port)) return true; } @@ -146,7 +149,7 @@ static bool lookup_subflow_by_daddr(const struct list_head *list, static bool select_local_address(const struct pm_nl_pernet *pernet, const struct mptcp_sock *msk, - struct mptcp_pm_addr_entry *new_entry) + struct mptcp_pm_local *new_local) { struct mptcp_pm_addr_entry *entry; bool found = false; @@ -161,7 +164,9 @@ select_local_address(const struct pm_nl_pernet *pernet, if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap)) continue; - *new_entry = *entry; + new_local->addr = entry->addr; + new_local->flags = entry->flags; + new_local->ifindex = entry->ifindex; found = true; break; } @@ -172,7 +177,7 @@ select_local_address(const struct pm_nl_pernet *pernet, static bool select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk, - struct mptcp_pm_addr_entry *new_entry) + struct mptcp_pm_local *new_local) { struct mptcp_pm_addr_entry *entry; bool found = false; @@ -190,7 +195,9 @@ select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk, if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) continue; - *new_entry = *entry; + new_local->addr = entry->addr; + new_local->flags = entry->flags; + new_local->ifindex = entry->ifindex; found = true; break; } @@ -287,7 +294,7 @@ static void mptcp_pm_add_timer(struct timer_list *timer) struct mptcp_sock *msk = entry->sock; struct sock *sk = (struct sock *)msk; - pr_debug("msk=%p", msk); + pr_debug("msk=%p\n", msk); if (!msk) return; @@ -306,7 +313,7 @@ static void mptcp_pm_add_timer(struct timer_list *timer) spin_lock_bh(&msk->pm.lock); if (!mptcp_pm_should_add_signal_addr(msk)) { - pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id); + pr_debug("retransmit ADD_ADDR id=%d\n", entry->addr.id); mptcp_pm_announce_addr(msk, &entry->addr, false); mptcp_pm_add_addr_send_ack(msk); entry->retrans_times++; @@ -331,15 +338,21 @@ mptcp_pm_del_add_timer(struct mptcp_sock *msk, { struct mptcp_pm_add_entry *entry; struct sock *sk = (struct sock *)msk; + struct timer_list *add_timer = NULL; spin_lock_bh(&msk->pm.lock); entry = mptcp_lookup_anno_list_by_saddr(msk, addr); - if (entry && (!check_id || entry->addr.id == addr->id)) + if (entry && (!check_id || entry->addr.id == addr->id)) { entry->retrans_times = ADD_ADDR_RETRANS_MAX; + add_timer = &entry->add_timer; + } + if (!check_id && entry) + list_del(&entry->list); spin_unlock_bh(&msk->pm.lock); - if (entry && (!check_id || entry->addr.id == addr->id)) - sk_stop_timer_sync(sk, &entry->add_timer); + /* no lock, because sk_stop_timer_sync() is calling del_timer_sync() */ + if (add_timer) + sk_stop_timer_sync(sk, add_timer); return entry; } @@ -387,7 +400,7 @@ void mptcp_pm_free_anno_list(struct mptcp_sock *msk) struct sock *sk = (struct sock *)msk; LIST_HEAD(free_list); - pr_debug("msk=%p", msk); + pr_debug("msk=%p\n", msk); spin_lock_bh(&msk->pm.lock); list_splice_init(&msk->pm.anno_list, &free_list); @@ -473,7 +486,7 @@ static void __mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_con struct sock *ssk = mptcp_subflow_tcp_sock(subflow); bool slow; - pr_debug("send ack for %s", + pr_debug("send ack for %s\n", prio ? "mp_prio" : (mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr")); slow = lock_sock_fast(ssk); @@ -521,11 +534,11 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info) static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) { struct sock *sk = (struct sock *)msk; - struct mptcp_pm_addr_entry local; unsigned int add_addr_signal_max; bool signal_and_subflow = false; unsigned int local_addr_max; struct pm_nl_pernet *pernet; + struct mptcp_pm_local local; unsigned int subflows_max; pernet = pm_nl_get_pernet(sock_net(sk)); @@ -585,6 +598,11 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) __clear_bit(local.addr.id, msk->pm.id_avail_bitmap); msk->pm.add_addr_signaled++; + + /* Special case for ID0: set the correct ID */ + if (local.addr.id == msk->mpc_endpoint_id) + local.addr.id = 0; + mptcp_pm_announce_addr(msk, &local.addr, false); mptcp_pm_nl_addr_send_ack(msk); @@ -607,15 +625,21 @@ subflow: fullmesh = !!(local.flags & MPTCP_PM_ADDR_FLAG_FULLMESH); - msk->pm.local_addr_used++; __clear_bit(local.addr.id, msk->pm.id_avail_bitmap); + + /* Special case for ID0: set the correct ID */ + if (local.addr.id == msk->mpc_endpoint_id) + local.addr.id = 0; + else /* local_addr_used is not decr for ID 0 */ + msk->pm.local_addr_used++; + nr = fill_remote_addresses_vec(msk, &local.addr, fullmesh, addrs); if (nr == 0) continue; spin_unlock_bh(&msk->pm.lock); for (i = 0; i < nr; i++) - __mptcp_subflow_connect(sk, &local.addr, &addrs[i]); + __mptcp_subflow_connect(sk, &local, &addrs[i]); spin_lock_bh(&msk->pm.lock); } mptcp_pm_nl_check_work_pending(msk); @@ -636,7 +660,7 @@ static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk) */ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk, struct mptcp_addr_info *remote, - struct mptcp_addr_info *addrs) + struct mptcp_pm_local *locals) { struct sock *sk = (struct sock *)msk; struct mptcp_pm_addr_entry *entry; @@ -659,13 +683,15 @@ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk, continue; if (msk->pm.subflows < subflows_max) { - msk->pm.subflows++; - addrs[i] = entry->addr; + locals[i].addr = entry->addr; + locals[i].flags = entry->flags; + locals[i].ifindex = entry->ifindex; /* Special case for ID0: set the correct ID */ - if (mptcp_addresses_equal(&entry->addr, &mpc_addr, entry->addr.port)) - addrs[i].id = 0; + if (mptcp_addresses_equal(&locals[i].addr, &mpc_addr, locals[i].addr.port)) + locals[i].addr.id = 0; + msk->pm.subflows++; i++; } } @@ -675,21 +701,19 @@ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk, * 'IPADDRANY' local address */ if (!i) { - struct mptcp_addr_info local; - - memset(&local, 0, sizeof(local)); - local.family = + memset(&locals[i], 0, sizeof(locals[i])); + locals[i].addr.family = #if IS_ENABLED(CONFIG_MPTCP_IPV6) remote->family == AF_INET6 && ipv6_addr_v4mapped(&remote->addr6) ? AF_INET : #endif remote->family; - if (!mptcp_pm_addr_families_match(sk, &local, remote)) + if (!mptcp_pm_addr_families_match(sk, &locals[i].addr, remote)) return 0; msk->pm.subflows++; - addrs[i++] = local; + i++; } return i; @@ -697,7 +721,7 @@ static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk, static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) { - struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX]; + struct mptcp_pm_local locals[MPTCP_PM_ADDR_MAX]; struct sock *sk = (struct sock *)msk; unsigned int add_addr_accept_max; struct mptcp_addr_info remote; @@ -708,7 +732,7 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk); subflows_max = mptcp_pm_get_subflows_max(msk); - pr_debug("accepted %d:%d remote family %d", + pr_debug("accepted %d:%d remote family %d\n", msk->pm.add_addr_accepted, add_addr_accept_max, msk->pm.remote.family); @@ -726,24 +750,35 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) /* connect to the specified remote address, using whatever * local address the routing configuration will pick. */ - nr = fill_local_addresses_vec(msk, &remote, addrs); + nr = fill_local_addresses_vec(msk, &remote, locals); if (nr == 0) return; spin_unlock_bh(&msk->pm.lock); for (i = 0; i < nr; i++) - if (__mptcp_subflow_connect(sk, &addrs[i], &remote) == 0) + if (__mptcp_subflow_connect(sk, &locals[i], &remote) == 0) sf_created = true; spin_lock_bh(&msk->pm.lock); if (sf_created) { - msk->pm.add_addr_accepted++; + /* add_addr_accepted is not decr for ID 0 */ + if (remote.id) + msk->pm.add_addr_accepted++; if (msk->pm.add_addr_accepted >= add_addr_accept_max || msk->pm.subflows >= subflows_max) WRITE_ONCE(msk->pm.accept_addr, false); } } +bool mptcp_pm_nl_is_init_remote_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *remote) +{ + struct mptcp_addr_info mpc_remote; + + remote_address((struct sock_common *)msk, &mpc_remote); + return mptcp_addresses_equal(&mpc_remote, remote, remote->port); +} + void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow; @@ -755,9 +790,12 @@ void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk) !mptcp_pm_should_rm_signal(msk)) return; - subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node); - if (subflow) - mptcp_pm_send_ack(msk, subflow, false, false); + mptcp_for_each_subflow(msk, subflow) { + if (__mptcp_subflow_active(subflow)) { + mptcp_pm_send_ack(msk, subflow, false, false); + break; + } + } } int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, @@ -767,7 +805,7 @@ int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, { struct mptcp_subflow_context *subflow; - pr_debug("bkup=%d", bkup); + pr_debug("bkup=%d\n", bkup); mptcp_for_each_subflow(msk, subflow) { struct sock *ssk = mptcp_subflow_tcp_sock(subflow); @@ -790,11 +828,6 @@ int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, return -EINVAL; } -static bool mptcp_local_id_match(const struct mptcp_sock *msk, u8 local_id, u8 id) -{ - return local_id == id || (!local_id && msk->mpc_endpoint_id == id); -} - static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list, enum linux_mptcp_mib_field rm_type) @@ -803,7 +836,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, struct sock *sk = (struct sock *)msk; u8 i; - pr_debug("%s rm_list_nr %d", + pr_debug("%s rm_list_nr %d\n", rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr); msk_owned_by_me(msk); @@ -827,12 +860,14 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, int how = RCV_SHUTDOWN | SEND_SHUTDOWN; u8 id = subflow_get_local_id(subflow); + if (inet_sk_state_load(ssk) == TCP_CLOSE) + continue; if (rm_type == MPTCP_MIB_RMADDR && remote_id != rm_id) continue; - if (rm_type == MPTCP_MIB_RMSUBFLOW && !mptcp_local_id_match(msk, id, rm_id)) + if (rm_type == MPTCP_MIB_RMSUBFLOW && id != rm_id) continue; - pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u", + pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u\n", rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", i, rm_id, id, remote_id, msk->mpc_endpoint_id); spin_unlock_bh(&msk->pm.lock); @@ -889,7 +924,7 @@ void mptcp_pm_nl_work(struct mptcp_sock *msk) spin_lock_bh(&msk->pm.lock); - pr_debug("msk=%p status=%x", msk, pm->status); + pr_debug("msk=%p status=%x\n", msk, pm->status); if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) { pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED); mptcp_pm_nl_add_addr_received(msk); @@ -1307,20 +1342,27 @@ static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info) return pm_nl_get_pernet(genl_info_net(info)); } -static int mptcp_nl_add_subflow_or_signal_addr(struct net *net) +static int mptcp_nl_add_subflow_or_signal_addr(struct net *net, + struct mptcp_addr_info *addr) { struct mptcp_sock *msk; long s_slot = 0, s_num = 0; while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { struct sock *sk = (struct sock *)msk; + struct mptcp_addr_info mpc_addr; if (!READ_ONCE(msk->fully_established) || mptcp_pm_is_userspace(msk)) goto next; + /* if the endp linked to the init sf is re-added with a != ID */ + mptcp_local_address((struct sock_common *)msk, &mpc_addr); + lock_sock(sk); spin_lock_bh(&msk->pm.lock); + if (mptcp_addresses_equal(addr, &mpc_addr, addr->port)) + msk->mpc_endpoint_id = addr->id; mptcp_pm_create_subflow_or_signal_addr(msk); spin_unlock_bh(&msk->pm.lock); release_sock(sk); @@ -1393,7 +1435,7 @@ int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info) goto out_free; } - mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); + mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk), &entry->addr); return 0; out_free: @@ -1401,28 +1443,6 @@ out_free: return ret; } -int mptcp_pm_nl_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id, - u8 *flags, int *ifindex) -{ - struct mptcp_pm_addr_entry *entry; - struct sock *sk = (struct sock *)msk; - struct net *net = sock_net(sk); - - /* No entries with ID 0 */ - if (id == 0) - return 0; - - rcu_read_lock(); - entry = __lookup_addr_by_id(pm_nl_get_pernet(net), id); - if (entry) { - *flags = entry->flags; - *ifindex = entry->ifindex; - } - rcu_read_unlock(); - - return 0; -} - static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, const struct mptcp_addr_info *addr) { @@ -1430,7 +1450,6 @@ static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, entry = mptcp_pm_del_add_timer(msk, addr, false); if (entry) { - list_del(&entry->list); kfree(entry); return true; } @@ -1438,6 +1457,12 @@ static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, return false; } +static u8 mptcp_endp_get_local_id(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + return msk->mpc_endpoint_id == addr->id ? 0 : addr->id; +} + static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, const struct mptcp_addr_info *addr, bool force) @@ -1445,7 +1470,7 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, struct mptcp_rm_list list = { .nr = 0 }; bool ret; - list.ids[list.nr++] = addr->id; + list.ids[list.nr++] = mptcp_endp_get_local_id(msk, addr); ret = remove_anno_list_by_saddr(msk, addr); if (ret || force) { @@ -1472,13 +1497,11 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, const struct mptcp_pm_addr_entry *entry) { const struct mptcp_addr_info *addr = &entry->addr; - struct mptcp_rm_list list = { .nr = 0 }; + struct mptcp_rm_list list = { .nr = 1 }; long s_slot = 0, s_num = 0; struct mptcp_sock *msk; - pr_debug("remove_id=%d", addr->id); - - list.ids[list.nr++] = addr->id; + pr_debug("remove_id=%d\n", addr->id); while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { struct sock *sk = (struct sock *)msk; @@ -1497,6 +1520,7 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, mptcp_pm_remove_anno_addr(msk, addr, remove_subflow && !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)); + list.ids[0] = mptcp_endp_get_local_id(msk, addr); if (remove_subflow) { spin_lock_bh(&msk->pm.lock); mptcp_pm_nl_rm_subflow_received(msk, &list); @@ -1509,6 +1533,8 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, spin_unlock_bh(&msk->pm.lock); } + if (msk->mpc_endpoint_id == entry->addr.id) + msk->mpc_endpoint_id = 0; release_sock(sk); next: @@ -1603,6 +1629,7 @@ int mptcp_pm_nl_del_addr_doit(struct sk_buff *skb, struct genl_info *info) return ret; } +/* Called from the userspace PM only */ void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list) { struct mptcp_rm_list alist = { .nr = 0 }; @@ -1631,8 +1658,9 @@ void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list) } } -static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, - struct list_head *rm_list) +/* Called from the in-kernel PM only */ +static void mptcp_pm_flush_addrs_and_subflows(struct mptcp_sock *msk, + struct list_head *rm_list) { struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 }; struct mptcp_pm_addr_entry *entry; @@ -1640,11 +1668,11 @@ static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, list_for_each_entry(entry, rm_list, list) { if (slist.nr < MPTCP_RM_IDS_MAX && lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) - slist.ids[slist.nr++] = entry->addr.id; + slist.ids[slist.nr++] = mptcp_endp_get_local_id(msk, &entry->addr); if (alist.nr < MPTCP_RM_IDS_MAX && remove_anno_list_by_saddr(msk, &entry->addr)) - alist.ids[alist.nr++] = entry->addr.id; + alist.ids[alist.nr++] = mptcp_endp_get_local_id(msk, &entry->addr); } spin_lock_bh(&msk->pm.lock); @@ -1660,8 +1688,8 @@ static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, spin_unlock_bh(&msk->pm.lock); } -static void mptcp_nl_remove_addrs_list(struct net *net, - struct list_head *rm_list) +static void mptcp_nl_flush_addrs_list(struct net *net, + struct list_head *rm_list) { long s_slot = 0, s_num = 0; struct mptcp_sock *msk; @@ -1674,7 +1702,7 @@ static void mptcp_nl_remove_addrs_list(struct net *net, if (!mptcp_pm_is_userspace(msk)) { lock_sock(sk); - mptcp_pm_remove_addrs_and_subflows(msk, rm_list); + mptcp_pm_flush_addrs_and_subflows(msk, rm_list); release_sock(sk); } @@ -1715,7 +1743,7 @@ int mptcp_pm_nl_flush_addrs_doit(struct sk_buff *skb, struct genl_info *info) pernet->next_id = 1; bitmap_zero(pernet->id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); spin_unlock_bh(&pernet->lock); - mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list); + mptcp_nl_flush_addrs_list(sock_net(skb->sk), &free_list); synchronize_rcu(); __flush_addrs(&free_list); return 0; @@ -1941,7 +1969,7 @@ static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk, { struct mptcp_rm_list list = { .nr = 0 }; - list.ids[list.nr++] = addr->id; + list.ids[list.nr++] = mptcp_endp_get_local_id(msk, addr); spin_lock_bh(&msk->pm.lock); mptcp_pm_nl_rm_subflow_received(msk, &list); |