diff --git a/fs/io_uring.c b/fs/io_uring.c index 7426e4f23f9b..65a6978e1795 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -1714,7 +1714,6 @@ static void __io_req_task_submit(struct io_kiocb *req) { struct io_ring_ctx *ctx = req->ctx; - __set_current_state(TASK_RUNNING); if (!__io_sq_thread_acquire_mm(ctx)) { mutex_lock(&ctx->uring_lock); __io_queue_sqe(req, NULL, NULL); @@ -1899,6 +1898,17 @@ static int io_put_kbuf(struct io_kiocb *req) return cflags; } +static inline bool io_run_task_work(void) +{ + if (current->task_works) { + __set_current_state(TASK_RUNNING); + task_work_run(); + return true; + } + + return false; +} + static void io_iopoll_queue(struct list_head *again) { struct io_kiocb *req; @@ -2079,8 +2089,7 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, unsigned *nr_events, */ if (!(++iters & 7)) { mutex_unlock(&ctx->uring_lock); - if (current->task_works) - task_work_run(); + io_run_task_work(); mutex_lock(&ctx->uring_lock); } @@ -2176,8 +2185,6 @@ static void io_rw_resubmit(struct callback_head *cb) struct io_ring_ctx *ctx = req->ctx; int err; - __set_current_state(TASK_RUNNING); - err = io_sq_thread_acquire_mm(ctx, req); if (io_resubmit_prep(req, err)) { @@ -6361,8 +6368,7 @@ static int io_sq_thread(void *data) if (!list_empty(&ctx->poll_list) || need_resched() || (!time_after(jiffies, timeout) && ret != -EBUSY && !percpu_ref_is_dying(&ctx->refs))) { - if (current->task_works) - task_work_run(); + io_run_task_work(); cond_resched(); continue; } @@ -6394,8 +6400,7 @@ static int io_sq_thread(void *data) finish_wait(&ctx->sqo_wait, &wait); break; } - if (current->task_works) { - task_work_run(); + if (io_run_task_work()) { finish_wait(&ctx->sqo_wait, &wait); continue; } @@ -6420,8 +6425,7 @@ static int io_sq_thread(void *data) timeout = jiffies + ctx->sq_thread_idle; } - if (current->task_works) - task_work_run(); + io_run_task_work(); io_sq_thread_drop_mm(ctx); revert_creds(old_cred); @@ -6486,9 +6490,8 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, do { if (io_cqring_events(ctx, false) >= min_events) return 0; - if (!current->task_works) + if (!io_run_task_work()) break; - task_work_run(); } while (1); if (sig) { @@ -6510,8 +6513,8 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, prepare_to_wait_exclusive(&ctx->wait, &iowq.wq, TASK_INTERRUPTIBLE); /* make sure we run task_work before checking for signals */ - if (current->task_works) - task_work_run(); + if (io_run_task_work()) + continue; if (signal_pending(current)) { if (current->jobctl & JOBCTL_TASK_WORK) { spin_lock_irq(¤t->sighand->siglock); @@ -7953,8 +7956,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit, int submitted = 0; struct fd f; - if (current->task_works) - task_work_run(); + io_run_task_work(); if (flags & ~(IORING_ENTER_GETEVENTS | IORING_ENTER_SQ_WAKEUP)) return -EINVAL;