diff --git a/fs/io_uring.c b/fs/io_uring.c index 7242cc48e97b67d35431c5d54d675ac32d09c5c7..9c77fbc0c39549ce010b54b0cdefc4ce9bc72f14 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -9083,29 +9083,39 @@ void __io_uring_files_cancel(struct files_struct *files) static s64 tctx_inflight(struct io_uring_task *tctx) { - unsigned long index; - struct file *file; - s64 inflight; - - inflight = percpu_counter_sum(&tctx->inflight); - if (!tctx->sqpoll) - return inflight; + return percpu_counter_sum(&tctx->inflight); +} - /* - * If we have SQPOLL rings, then we need to iterate and find them, and - * add the pending count for those. - */ - xa_for_each(&tctx->xa, index, file) { - struct io_ring_ctx *ctx = file->private_data; +static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx) +{ + struct io_uring_task *tctx; + s64 inflight; + DEFINE_WAIT(wait); - if (ctx->flags & IORING_SETUP_SQPOLL) { - struct io_uring_task *__tctx = ctx->sqo_task->io_uring; + if (!ctx->sq_data) + return; + tctx = ctx->sq_data->thread->io_uring; + io_disable_sqo_submit(ctx); - inflight += percpu_counter_sum(&__tctx->inflight); - } - } + atomic_inc(&tctx->in_idle); + do { + /* read completions before cancelations */ + inflight = tctx_inflight(tctx); + if (!inflight) + break; + io_uring_cancel_task_requests(ctx, NULL); - return inflight; + prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE); + /* + * If we've seen completions, retry without waiting. This + * avoids a race where a completion comes in before we did + * prepare_to_wait(). + */ + if (inflight == tctx_inflight(tctx)) + schedule(); + finish_wait(&tctx->wait, &wait); + } while (1); + atomic_dec(&tctx->in_idle); } /* @@ -9122,8 +9132,13 @@ void __io_uring_task_cancel(void) atomic_inc(&tctx->in_idle); /* trigger io_disable_sqo_submit() */ - if (tctx->sqpoll) - __io_uring_files_cancel(NULL); + if (tctx->sqpoll) { + struct file *file; + unsigned long index; + + xa_for_each(&tctx->xa, index, file) + io_uring_cancel_sqpoll(file->private_data); + } do { /* read completions before cancelations */