aboutsummaryrefslogtreecommitdiff
path: root/net/l2tp/l2tp_core.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/l2tp/l2tp_core.c')
-rw-r--r--net/l2tp/l2tp_core.c240
1 files changed, 175 insertions, 65 deletions
diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index 69f8c9f5cdc7..d6bffdb16466 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -107,11 +107,17 @@ struct l2tp_net {
/* Lock for write access to l2tp_tunnel_idr */
spinlock_t l2tp_tunnel_idr_lock;
struct idr l2tp_tunnel_idr;
- struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2];
- /* Lock for write access to l2tp_session_hlist */
- spinlock_t l2tp_session_hlist_lock;
+ /* Lock for write access to l2tp_v3_session_idr/htable */
+ spinlock_t l2tp_session_idr_lock;
+ struct idr l2tp_v3_session_idr;
+ struct hlist_head l2tp_v3_session_htable[16];
};
+static inline unsigned long l2tp_v3_session_hashkey(struct sock *sk, u32 session_id)
+{
+ return ((unsigned long)sk) + session_id;
+}
+
#if IS_ENABLED(CONFIG_IPV6)
static bool l2tp_sk_is_v6(struct sock *sk)
{
@@ -125,17 +131,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net)
return net_generic(net, l2tp_net_id);
}
-/* Session hash global list for L2TPv3.
- * The session_id SHOULD be random according to RFC3931, but several
- * L2TP implementations use incrementing session_ids. So we do a real
- * hash on the session_id, rather than a simple bitmask.
- */
-static inline struct hlist_head *
-l2tp_session_id_hash_2(struct l2tp_net *pn, u32 session_id)
-{
- return &pn->l2tp_session_hlist[hash_32(session_id, L2TP_HASH_BITS_2)];
-}
-
/* Session hash list.
* The session_id SHOULD be random according to RFC2661, but several
* L2TP implementations (Cisco and Microsoft) use incrementing
@@ -262,26 +257,40 @@ struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel,
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_get_session);
-struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id)
+struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id)
{
- struct hlist_head *session_list;
+ const struct l2tp_net *pn = l2tp_pernet(net);
struct l2tp_session *session;
- session_list = l2tp_session_id_hash_2(l2tp_pernet(net), session_id);
-
rcu_read_lock_bh();
- hlist_for_each_entry_rcu(session, session_list, global_hlist)
- if (session->session_id == session_id) {
- l2tp_session_inc_refcount(session);
- rcu_read_unlock_bh();
+ session = idr_find(&pn->l2tp_v3_session_idr, session_id);
+ if (session && !hash_hashed(&session->hlist) &&
+ refcount_inc_not_zero(&session->ref_count)) {
+ rcu_read_unlock_bh();
+ return session;
+ }
- return session;
+ /* If we get here and session is non-NULL, the session_id
+ * collides with one in another tunnel. If sk is non-NULL,
+ * find the session matching sk.
+ */
+ if (session && sk) {
+ unsigned long key = l2tp_v3_session_hashkey(sk, session->session_id);
+
+ hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session,
+ hlist, key) {
+ if (session->tunnel->sock == sk &&
+ refcount_inc_not_zero(&session->ref_count)) {
+ rcu_read_unlock_bh();
+ return session;
+ }
}
+ }
rcu_read_unlock_bh();
return NULL;
}
-EXPORT_SYMBOL_GPL(l2tp_session_get);
+EXPORT_SYMBOL_GPL(l2tp_v3_session_get);
struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth)
{
@@ -313,12 +322,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
const char *ifname)
{
struct l2tp_net *pn = l2tp_pernet(net);
- int hash;
+ unsigned long session_id, tmp;
struct l2tp_session *session;
rcu_read_lock_bh();
- for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
- hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
+ idr_for_each_entry_ul(&pn->l2tp_v3_session_idr, session, tmp, session_id) {
+ if (session) {
if (!strcmp(session->ifname, ifname)) {
l2tp_session_inc_refcount(session);
rcu_read_unlock_bh();
@@ -334,13 +343,106 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
}
EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
+static void l2tp_session_coll_list_add(struct l2tp_session_coll_list *clist,
+ struct l2tp_session *session)
+{
+ l2tp_session_inc_refcount(session);
+ WARN_ON_ONCE(session->coll_list);
+ session->coll_list = clist;
+ spin_lock(&clist->lock);
+ list_add(&session->clist, &clist->list);
+ spin_unlock(&clist->lock);
+}
+
+static int l2tp_session_collision_add(struct l2tp_net *pn,
+ struct l2tp_session *session1,
+ struct l2tp_session *session2)
+{
+ struct l2tp_session_coll_list *clist;
+
+ lockdep_assert_held(&pn->l2tp_session_idr_lock);
+
+ if (!session2)
+ return -EEXIST;
+
+ /* If existing session is in IP-encap tunnel, refuse new session */
+ if (session2->tunnel->encap == L2TP_ENCAPTYPE_IP)
+ return -EEXIST;
+
+ clist = session2->coll_list;
+ if (!clist) {
+ /* First collision. Allocate list to manage the collided sessions
+ * and add the existing session to the list.
+ */
+ clist = kmalloc(sizeof(*clist), GFP_ATOMIC);
+ if (!clist)
+ return -ENOMEM;
+
+ spin_lock_init(&clist->lock);
+ INIT_LIST_HEAD(&clist->list);
+ refcount_set(&clist->ref_count, 1);
+ l2tp_session_coll_list_add(clist, session2);
+ }
+
+ /* If existing session isn't already in the session hlist, add it. */
+ if (!hash_hashed(&session2->hlist))
+ hash_add(pn->l2tp_v3_session_htable, &session2->hlist,
+ session2->hlist_key);
+
+ /* Add new session to the hlist and collision list */
+ hash_add(pn->l2tp_v3_session_htable, &session1->hlist,
+ session1->hlist_key);
+ refcount_inc(&clist->ref_count);
+ l2tp_session_coll_list_add(clist, session1);
+
+ return 0;
+}
+
+static void l2tp_session_collision_del(struct l2tp_net *pn,
+ struct l2tp_session *session)
+{
+ struct l2tp_session_coll_list *clist = session->coll_list;
+ unsigned long session_key = session->session_id;
+ struct l2tp_session *session2;
+
+ lockdep_assert_held(&pn->l2tp_session_idr_lock);
+
+ hash_del(&session->hlist);
+
+ if (clist) {
+ /* Remove session from its collision list. If there
+ * are other sessions with the same ID, replace this
+ * session's IDR entry with that session, otherwise
+ * remove the IDR entry. If this is the last session,
+ * the collision list data is freed.
+ */
+ spin_lock(&clist->lock);
+ list_del_init(&session->clist);
+ session2 = list_first_entry_or_null(&clist->list, struct l2tp_session, clist);
+ if (session2) {
+ void *old = idr_replace(&pn->l2tp_v3_session_idr, session2, session_key);
+
+ WARN_ON_ONCE(IS_ERR_VALUE(old));
+ } else {
+ void *removed = idr_remove(&pn->l2tp_v3_session_idr, session_key);
+
+ WARN_ON_ONCE(removed != session);
+ }
+ session->coll_list = NULL;
+ spin_unlock(&clist->lock);
+ if (refcount_dec_and_test(&clist->ref_count))
+ kfree(clist);
+ l2tp_session_dec_refcount(session);
+ }
+}
+
int l2tp_session_register(struct l2tp_session *session,
struct l2tp_tunnel *tunnel)
{
+ struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
struct l2tp_session *session_walk;
- struct hlist_head *g_head;
struct hlist_head *head;
- struct l2tp_net *pn;
+ u32 session_key;
int err;
head = l2tp_session_id_hash(tunnel, session->session_id);
@@ -358,39 +460,45 @@ int l2tp_session_register(struct l2tp_session *session,
}
if (tunnel->version == L2TP_HDR_VER_3) {
- pn = l2tp_pernet(tunnel->l2tp_net);
- g_head = l2tp_session_id_hash_2(pn, session->session_id);
-
- spin_lock_bh(&pn->l2tp_session_hlist_lock);
-
+ session_key = session->session_id;
+ spin_lock_bh(&pn->l2tp_session_idr_lock);
+ err = idr_alloc_u32(&pn->l2tp_v3_session_idr, NULL,
+ &session_key, session_key, GFP_ATOMIC);
/* IP encap expects session IDs to be globally unique, while
- * UDP encap doesn't.
+ * UDP encap doesn't. This isn't per the RFC, which says that
+ * sessions are identified only by the session ID, but is to
+ * support existing userspace which depends on it.
*/
- hlist_for_each_entry(session_walk, g_head, global_hlist)
- if (session_walk->session_id == session->session_id &&
- (session_walk->tunnel->encap == L2TP_ENCAPTYPE_IP ||
- tunnel->encap == L2TP_ENCAPTYPE_IP)) {
- err = -EEXIST;
- goto err_tlock_pnlock;
- }
+ if (err == -ENOSPC && tunnel->encap == L2TP_ENCAPTYPE_UDP) {
+ struct l2tp_session *session2;
- l2tp_tunnel_inc_refcount(tunnel);
- hlist_add_head_rcu(&session->global_hlist, g_head);
-
- spin_unlock_bh(&pn->l2tp_session_hlist_lock);
- } else {
- l2tp_tunnel_inc_refcount(tunnel);
+ session2 = idr_find(&pn->l2tp_v3_session_idr,
+ session_key);
+ err = l2tp_session_collision_add(pn, session, session2);
+ }
+ spin_unlock_bh(&pn->l2tp_session_idr_lock);
+ if (err == -ENOSPC)
+ err = -EEXIST;
}
+ if (err)
+ goto err_tlock;
+
+ l2tp_tunnel_inc_refcount(tunnel);
+
hlist_add_head_rcu(&session->hlist, head);
spin_unlock_bh(&tunnel->hlist_lock);
+ if (tunnel->version == L2TP_HDR_VER_3) {
+ spin_lock_bh(&pn->l2tp_session_idr_lock);
+ idr_replace(&pn->l2tp_v3_session_idr, session, session_key);
+ spin_unlock_bh(&pn->l2tp_session_idr_lock);
+ }
+
trace_register_session(session);
return 0;
-err_tlock_pnlock:
- spin_unlock_bh(&pn->l2tp_session_hlist_lock);
err_tlock:
spin_unlock_bh(&tunnel->hlist_lock);
@@ -1218,13 +1326,19 @@ static void l2tp_session_unhash(struct l2tp_session *session)
hlist_del_init_rcu(&session->hlist);
spin_unlock_bh(&tunnel->hlist_lock);
- /* For L2TPv3 we have a per-net hash: remove from there, too */
- if (tunnel->version != L2TP_HDR_VER_2) {
+ /* For L2TPv3 we have a per-net IDR: remove from there, too */
+ if (tunnel->version == L2TP_HDR_VER_3) {
struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
-
- spin_lock_bh(&pn->l2tp_session_hlist_lock);
- hlist_del_init_rcu(&session->global_hlist);
- spin_unlock_bh(&pn->l2tp_session_hlist_lock);
+ struct l2tp_session *removed = session;
+
+ spin_lock_bh(&pn->l2tp_session_idr_lock);
+ if (hash_hashed(&session->hlist))
+ l2tp_session_collision_del(pn, session);
+ else
+ removed = idr_remove(&pn->l2tp_v3_session_idr,
+ session->session_id);
+ WARN_ON_ONCE(removed && removed != session);
+ spin_unlock_bh(&pn->l2tp_session_idr_lock);
}
synchronize_rcu();
@@ -1649,8 +1763,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
skb_queue_head_init(&session->reorder_q);
+ session->hlist_key = l2tp_v3_session_hashkey(tunnel->sock, session->session_id);
INIT_HLIST_NODE(&session->hlist);
- INIT_HLIST_NODE(&session->global_hlist);
+ INIT_LIST_HEAD(&session->clist);
if (cfg) {
session->pwtype = cfg->pw_type;
@@ -1683,15 +1798,12 @@ EXPORT_SYMBOL_GPL(l2tp_session_create);
static __net_init int l2tp_init_net(struct net *net)
{
struct l2tp_net *pn = net_generic(net, l2tp_net_id);
- int hash;
idr_init(&pn->l2tp_tunnel_idr);
spin_lock_init(&pn->l2tp_tunnel_idr_lock);
- for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
- INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]);
-
- spin_lock_init(&pn->l2tp_session_hlist_lock);
+ idr_init(&pn->l2tp_v3_session_idr);
+ spin_lock_init(&pn->l2tp_session_idr_lock);
return 0;
}
@@ -1701,7 +1813,6 @@ static __net_exit void l2tp_exit_net(struct net *net)
struct l2tp_net *pn = l2tp_pernet(net);
struct l2tp_tunnel *tunnel = NULL;
unsigned long tunnel_id, tmp;
- int hash;
rcu_read_lock_bh();
idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) {
@@ -1714,8 +1825,7 @@ static __net_exit void l2tp_exit_net(struct net *net)
flush_workqueue(l2tp_wq);
rcu_barrier();
- for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
- WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
+ idr_destroy(&pn->l2tp_v3_session_idr);
idr_destroy(&pn->l2tp_tunnel_idr);
}