diff --git a/fs/io-wq.c b/fs/io-wq.c
index 5361a9b4b47b5f25904f0dbaa8112b79ec6611a8..b3e8624a37d0904c73151fc00a4b3bb64b60dc70 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -979,13 +979,16 @@ static bool io_task_work_match(struct callback_head *cb, void *data)
 	return cwd->wqe->wq == data;
 }
 
+void io_wq_exit_start(struct io_wq *wq)
+{
+	set_bit(IO_WQ_BIT_EXIT, &wq->state);
+}
+
 static void io_wq_exit_workers(struct io_wq *wq)
 {
 	struct callback_head *cb;
 	int node;
 
-	set_bit(IO_WQ_BIT_EXIT, &wq->state);
-
 	if (!wq->task)
 		return;
 
@@ -1003,13 +1006,16 @@ static void io_wq_exit_workers(struct io_wq *wq)
 		struct io_wqe *wqe = wq->wqes[node];
 
 		io_wq_for_each_worker(wqe, io_wq_worker_wake, NULL);
-		spin_lock_irq(&wq->hash->wait.lock);
-		list_del_init(&wq->wqes[node]->wait.entry);
-		spin_unlock_irq(&wq->hash->wait.lock);
 	}
 	rcu_read_unlock();
 	io_worker_ref_put(wq);
 	wait_for_completion(&wq->worker_done);
+
+	for_each_node(node) {
+		spin_lock_irq(&wq->hash->wait.lock);
+		list_del_init(&wq->wqes[node]->wait.entry);
+		spin_unlock_irq(&wq->hash->wait.lock);
+	}
 	put_task_struct(wq->task);
 	wq->task = NULL;
 }
@@ -1020,8 +1026,6 @@ static void io_wq_destroy(struct io_wq *wq)
 
 	cpuhp_state_remove_instance_nocalls(io_wq_online, &wq->cpuhp_node);
 
-	io_wq_exit_workers(wq);
-
 	for_each_node(node) {
 		struct io_wqe *wqe = wq->wqes[node];
 		struct io_cb_cancel_data match = {
@@ -1036,16 +1040,13 @@ static void io_wq_destroy(struct io_wq *wq)
 	kfree(wq);
 }
 
-void io_wq_put(struct io_wq *wq)
-{
-	if (refcount_dec_and_test(&wq->refs))
-		io_wq_destroy(wq);
-}
-
 void io_wq_put_and_exit(struct io_wq *wq)
 {
+	WARN_ON_ONCE(!test_bit(IO_WQ_BIT_EXIT, &wq->state));
+
 	io_wq_exit_workers(wq);
-	io_wq_put(wq);
+	if (refcount_dec_and_test(&wq->refs))
+		io_wq_destroy(wq);
 }
 
 static bool io_wq_worker_affinity(struct io_worker *worker, void *data)
diff --git a/fs/io-wq.h b/fs/io-wq.h
index 0e6d310999e899a9d416459138be22a23f02bf56..af2df0680ee22b2d6363df48ee0016d439f96056 100644
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -122,7 +122,7 @@ struct io_wq_data {
 };
 
 struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data);
-void io_wq_put(struct io_wq *wq);
+void io_wq_exit_start(struct io_wq *wq);
 void io_wq_put_and_exit(struct io_wq *wq);
 
 void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work);
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 5f82954004f613ee95a6cf08acc62af8dfcde2a6..903458afd56c17de177e29e58974aac45b7dbcba 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -9039,11 +9039,16 @@ static void io_uring_clean_tctx(struct io_uring_task *tctx)
 	struct io_tctx_node *node;
 	unsigned long index;
 
-	tctx->io_wq = NULL;
 	xa_for_each(&tctx->xa, index, node)
 		io_uring_del_task_file(index);
-	if (wq)
+	if (wq) {
+		/*
+		 * Must be after io_uring_del_task_file() (removes nodes under
+		 * uring_lock) to avoid race with io_uring_try_cancel_iowq().
+		 */
+		tctx->io_wq = NULL;
 		io_wq_put_and_exit(wq);
+	}
 }
 
 static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
@@ -9078,6 +9083,9 @@ static void io_uring_cancel_sqpoll(struct io_sq_data *sqd)
 
 	if (!current->io_uring)
 		return;
+	if (tctx->io_wq)
+		io_wq_exit_start(tctx->io_wq);
+
 	WARN_ON_ONCE(!sqd || sqd->thread != current);
 
 	atomic_inc(&tctx->in_idle);
@@ -9112,6 +9120,9 @@ void __io_uring_cancel(struct files_struct *files)
 	DEFINE_WAIT(wait);
 	s64 inflight;
 
+	if (tctx->io_wq)
+		io_wq_exit_start(tctx->io_wq);
+
 	/* make sure overflow events are dropped */
 	atomic_inc(&tctx->in_idle);
 	do {