diff --git a/fs/io_uring.c b/fs/io_uring.c
index 4a088581b0f2dfca2845eb23e72125349d1c4eea..d55c9ab6314a15157c186f551e69aac462b9411d 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -6793,9 +6793,9 @@ static int io_sq_thread(void *data)
 		ctx->sqo_exec = 1;
 		io_ring_set_wakeup_flag(ctx);
 	}
-	mutex_unlock(&sqd->lock);
 
 	complete(&sqd->exited);
+	mutex_unlock(&sqd->lock);
 	do_exit(0);
 }
 
@@ -7118,13 +7118,19 @@ static bool io_sq_thread_park(struct io_sq_data *sqd)
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
-	if (!sqd->thread)
+	if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
 		return;
-
-	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-	WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state));
-	wake_up_process(sqd->thread);
-	wait_for_completion(&sqd->exited);
+	mutex_lock(&sqd->lock);
+	if (sqd->thread) {
+		set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+		WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state));
+		wake_up_process(sqd->thread);
+		mutex_unlock(&sqd->lock);
+		wait_for_completion(&sqd->exited);
+		WARN_ON_ONCE(sqd->thread);
+	} else {
+		mutex_unlock(&sqd->lock);
+	}
 }
 
 static void io_put_sq_data(struct io_sq_data *sqd)
@@ -8867,6 +8873,11 @@ static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 	if (!io_sq_thread_park(sqd))
 		return;
 	tctx = ctx->sq_data->thread->io_uring;
+	/* can happen on fork/alloc failure, just ignore that state */
+	if (!tctx) {
+		io_sq_thread_unpark(sqd);
+		return;
+	}
 
 	atomic_inc(&tctx->in_idle);
 	do {