net/smc: Fix slab-out-of-bounds issue in fallback

syzbot reported a slab-out-of-bounds/use-after-free issue,
which was caused by accessing an already freed smc sock in
fallback-specific callback functions of clcsock.

This patch fixes the issue by restoring fallback-specific
callback functions to original ones and resetting clcsock
sk_user_data to NULL before freeing smc sock.

Meanwhile, this patch introduces sk_callback_lock to make
the access and assignment to sk_user_data mutually exclusive.

Reported-by: syzbot+b425899ed22c6943e00b@syzkaller.appspotmail.com
Fixes: 341adeec9a ("net/smc: Forward wakeup to smc socket waitqueue after fallback")
Link: https://lore.kernel.org/r/00000000000013ca8105d7ae3ada@google.com/
Signed-off-by: Wen Gu <guwen@linux.alibaba.com>
Acked-by: Karsten Graul <kgraul@linux.ibm.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Wen Gu 2022-04-22 15:56:19 +08:00 committed by Jakub Kicinski
parent 97b9af7a70
commit 0558226ceb
2 changed files with 59 additions and 23 deletions

View file

@ -243,11 +243,27 @@ struct proto smc_proto6 = {
}; };
EXPORT_SYMBOL_GPL(smc_proto6); EXPORT_SYMBOL_GPL(smc_proto6);
static void smc_fback_restore_callbacks(struct smc_sock *smc)
{
struct sock *clcsk = smc->clcsock->sk;
write_lock_bh(&clcsk->sk_callback_lock);
clcsk->sk_user_data = NULL;
smc_clcsock_restore_cb(&clcsk->sk_state_change, &smc->clcsk_state_change);
smc_clcsock_restore_cb(&clcsk->sk_data_ready, &smc->clcsk_data_ready);
smc_clcsock_restore_cb(&clcsk->sk_write_space, &smc->clcsk_write_space);
smc_clcsock_restore_cb(&clcsk->sk_error_report, &smc->clcsk_error_report);
write_unlock_bh(&clcsk->sk_callback_lock);
}
static void smc_restore_fallback_changes(struct smc_sock *smc) static void smc_restore_fallback_changes(struct smc_sock *smc)
{ {
if (smc->clcsock->file) { /* non-accepted sockets have no file yet */ if (smc->clcsock->file) { /* non-accepted sockets have no file yet */
smc->clcsock->file->private_data = smc->sk.sk_socket; smc->clcsock->file->private_data = smc->sk.sk_socket;
smc->clcsock->file = NULL; smc->clcsock->file = NULL;
smc_fback_restore_callbacks(smc);
} }
} }
@ -745,48 +761,57 @@ out:
static void smc_fback_state_change(struct sock *clcsk) static void smc_fback_state_change(struct sock *clcsk)
{ {
struct smc_sock *smc = struct smc_sock *smc;
smc_clcsock_user_data(clcsk);
if (!smc) read_lock_bh(&clcsk->sk_callback_lock);
return; smc = smc_clcsock_user_data(clcsk);
smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_state_change); if (smc)
smc_fback_forward_wakeup(smc, clcsk,
smc->clcsk_state_change);
read_unlock_bh(&clcsk->sk_callback_lock);
} }
static void smc_fback_data_ready(struct sock *clcsk) static void smc_fback_data_ready(struct sock *clcsk)
{ {
struct smc_sock *smc = struct smc_sock *smc;
smc_clcsock_user_data(clcsk);
if (!smc) read_lock_bh(&clcsk->sk_callback_lock);
return; smc = smc_clcsock_user_data(clcsk);
smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_data_ready); if (smc)
smc_fback_forward_wakeup(smc, clcsk,
smc->clcsk_data_ready);
read_unlock_bh(&clcsk->sk_callback_lock);
} }
static void smc_fback_write_space(struct sock *clcsk) static void smc_fback_write_space(struct sock *clcsk)
{ {
struct smc_sock *smc = struct smc_sock *smc;
smc_clcsock_user_data(clcsk);
if (!smc) read_lock_bh(&clcsk->sk_callback_lock);
return; smc = smc_clcsock_user_data(clcsk);
smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_write_space); if (smc)
smc_fback_forward_wakeup(smc, clcsk,
smc->clcsk_write_space);
read_unlock_bh(&clcsk->sk_callback_lock);
} }
static void smc_fback_error_report(struct sock *clcsk) static void smc_fback_error_report(struct sock *clcsk)
{ {
struct smc_sock *smc = struct smc_sock *smc;
smc_clcsock_user_data(clcsk);
if (!smc) read_lock_bh(&clcsk->sk_callback_lock);
return; smc = smc_clcsock_user_data(clcsk);
smc_fback_forward_wakeup(smc, clcsk, smc->clcsk_error_report); if (smc)
smc_fback_forward_wakeup(smc, clcsk,
smc->clcsk_error_report);
read_unlock_bh(&clcsk->sk_callback_lock);
} }
static void smc_fback_replace_callbacks(struct smc_sock *smc) static void smc_fback_replace_callbacks(struct smc_sock *smc)
{ {
struct sock *clcsk = smc->clcsock->sk; struct sock *clcsk = smc->clcsock->sk;
write_lock_bh(&clcsk->sk_callback_lock);
clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change, smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change,
@ -797,6 +822,8 @@ static void smc_fback_replace_callbacks(struct smc_sock *smc)
&smc->clcsk_write_space); &smc->clcsk_write_space);
smc_clcsock_replace_cb(&clcsk->sk_error_report, smc_fback_error_report, smc_clcsock_replace_cb(&clcsk->sk_error_report, smc_fback_error_report,
&smc->clcsk_error_report); &smc->clcsk_error_report);
write_unlock_bh(&clcsk->sk_callback_lock);
} }
static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code) static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code)
@ -2370,17 +2397,20 @@ out:
static void smc_clcsock_data_ready(struct sock *listen_clcsock) static void smc_clcsock_data_ready(struct sock *listen_clcsock)
{ {
struct smc_sock *lsmc = struct smc_sock *lsmc;
smc_clcsock_user_data(listen_clcsock);
read_lock_bh(&listen_clcsock->sk_callback_lock);
lsmc = smc_clcsock_user_data(listen_clcsock);
if (!lsmc) if (!lsmc)
return; goto out;
lsmc->clcsk_data_ready(listen_clcsock); lsmc->clcsk_data_ready(listen_clcsock);
if (lsmc->sk.sk_state == SMC_LISTEN) { if (lsmc->sk.sk_state == SMC_LISTEN) {
sock_hold(&lsmc->sk); /* sock_put in smc_tcp_listen_work() */ sock_hold(&lsmc->sk); /* sock_put in smc_tcp_listen_work() */
if (!queue_work(smc_tcp_ls_wq, &lsmc->tcp_listen_work)) if (!queue_work(smc_tcp_ls_wq, &lsmc->tcp_listen_work))
sock_put(&lsmc->sk); sock_put(&lsmc->sk);
} }
out:
read_unlock_bh(&listen_clcsock->sk_callback_lock);
} }
static int smc_listen(struct socket *sock, int backlog) static int smc_listen(struct socket *sock, int backlog)
@ -2412,10 +2442,12 @@ static int smc_listen(struct socket *sock, int backlog)
/* save original sk_data_ready function and establish /* save original sk_data_ready function and establish
* smc-specific sk_data_ready function * smc-specific sk_data_ready function
*/ */
write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
smc->clcsock->sk->sk_user_data = smc->clcsock->sk->sk_user_data =
(void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY);
smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready, smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready,
smc_clcsock_data_ready, &smc->clcsk_data_ready); smc_clcsock_data_ready, &smc->clcsk_data_ready);
write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
/* save original ops */ /* save original ops */
smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops; smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops;
@ -2430,9 +2462,11 @@ static int smc_listen(struct socket *sock, int backlog)
rc = kernel_listen(smc->clcsock, backlog); rc = kernel_listen(smc->clcsock, backlog);
if (rc) { if (rc) {
write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready,
&smc->clcsk_data_ready); &smc->clcsk_data_ready);
smc->clcsock->sk->sk_user_data = NULL; smc->clcsock->sk->sk_user_data = NULL;
write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
goto out; goto out;
} }
sk->sk_max_ack_backlog = backlog; sk->sk_max_ack_backlog = backlog;

View file

@ -214,9 +214,11 @@ again:
sk->sk_state = SMC_CLOSED; sk->sk_state = SMC_CLOSED;
sk->sk_state_change(sk); /* wake up accept */ sk->sk_state_change(sk); /* wake up accept */
if (smc->clcsock && smc->clcsock->sk) { if (smc->clcsock && smc->clcsock->sk) {
write_lock_bh(&smc->clcsock->sk->sk_callback_lock);
smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready,
&smc->clcsk_data_ready); &smc->clcsk_data_ready);
smc->clcsock->sk->sk_user_data = NULL; smc->clcsock->sk->sk_user_data = NULL;
write_unlock_bh(&smc->clcsock->sk->sk_callback_lock);
rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR); rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR);
} }
smc_close_cleanup_listen(sk); smc_close_cleanup_listen(sk);