diff --git a/fs/io_uring.c b/fs/io_uring.c
index 766df21769a839794c798e4c403c8760f677e3c5..62a73543ab8642c075cd9cabd9b9f7e9e905433a 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -1800,6 +1800,18 @@ static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
 	return __io_req_find_next(req);
 }
 
+static void ctx_flush_and_put(struct io_ring_ctx *ctx)
+{
+	if (!ctx)
+		return;
+	if (ctx->submit_state.comp.nr) {
+		mutex_lock(&ctx->uring_lock);
+		io_submit_flush_completions(&ctx->submit_state.comp, ctx);
+		mutex_unlock(&ctx->uring_lock);
+	}
+	percpu_ref_put(&ctx->refs);
+}
+
 static bool __tctx_task_work(struct io_uring_task *tctx)
 {
 	struct io_ring_ctx *ctx = NULL;
@@ -1817,30 +1829,20 @@ static bool __tctx_task_work(struct io_uring_task *tctx)
 	node = list.first;
 	while (node) {
 		struct io_wq_work_node *next = node->next;
-		struct io_ring_ctx *this_ctx;
 		struct io_kiocb *req;
 
 		req = container_of(node, struct io_kiocb, io_task_work.node);
-		this_ctx = req->ctx;
-		req->task_work.func(&req->task_work);
-		node = next;
-
-		if (!ctx) {
-			ctx = this_ctx;
-		} else if (ctx != this_ctx) {
-			mutex_lock(&ctx->uring_lock);
-			io_submit_flush_completions(&ctx->submit_state.comp, ctx);
-			mutex_unlock(&ctx->uring_lock);
-			ctx = this_ctx;
+		if (req->ctx != ctx) {
+			ctx_flush_and_put(ctx);
+			ctx = req->ctx;
+			percpu_ref_get(&ctx->refs);
 		}
-	}
 
-	if (ctx && ctx->submit_state.comp.nr) {
-		mutex_lock(&ctx->uring_lock);
-		io_submit_flush_completions(&ctx->submit_state.comp, ctx);
-		mutex_unlock(&ctx->uring_lock);
+		req->task_work.func(&req->task_work);
+		node = next;
 	}
 
+	ctx_flush_and_put(ctx);
 	return list.first != NULL;
 }