diff --git a/fs/io_uring.c b/fs/io_uring.c
index e3eb37304e241e6ae9db01c4e4099ab6b3bdca61..d6c2ff6124fd0d42c2812f383413983796173add 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -8105,12 +8105,9 @@ static struct io_wq_work *io_free_work(struct io_wq_work *work)
 	return req ? &req->work : NULL;
 }
 
-static int io_init_wq_offload(struct io_ring_ctx *ctx,
-			      struct io_uring_params *p)
+static int io_init_wq_offload(struct io_ring_ctx *ctx)
 {
 	struct io_wq_data data;
-	struct fd f;
-	struct io_ring_ctx *ctx_attach;
 	unsigned int concurrency;
 	int ret = 0;
 
@@ -8118,37 +8115,15 @@ static int io_init_wq_offload(struct io_ring_ctx *ctx,
 	data.free_work = io_free_work;
 	data.do_work = io_wq_submit_work;
 
-	if (!(p->flags & IORING_SETUP_ATTACH_WQ)) {
-		/* Do QD, or 4 * CPUS, whatever is smallest */
-		concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
-
-		ctx->io_wq = io_wq_create(concurrency, &data);
-		if (IS_ERR(ctx->io_wq)) {
-			ret = PTR_ERR(ctx->io_wq);
-			ctx->io_wq = NULL;
-		}
-		return ret;
-	}
-
-	f = fdget(p->wq_fd);
-	if (!f.file)
-		return -EBADF;
-
-	if (f.file->f_op != &io_uring_fops) {
-		ret = -EINVAL;
-		goto out_fput;
-	}
+	/* Do QD, or 4 * CPUS, whatever is smallest */
+	concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
 
-	ctx_attach = f.file->private_data;
-	/* @io_wq is protected by holding the fd */
-	if (!io_wq_get(ctx_attach->io_wq, &data)) {
-		ret = -EINVAL;
-		goto out_fput;
+	ctx->io_wq = io_wq_create(concurrency, &data);
+	if (IS_ERR(ctx->io_wq)) {
+		ret = PTR_ERR(ctx->io_wq);
+		ctx->io_wq = NULL;
 	}
 
-	ctx->io_wq = ctx_attach->io_wq;
-out_fput:
-	fdput(f);
 	return ret;
 }
 
@@ -8200,6 +8175,20 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 {
 	int ret;
 
+	/* Retain compatibility with failing for an invalid attach attempt */
+	if ((ctx->flags & (IORING_SETUP_ATTACH_WQ | IORING_SETUP_SQPOLL)) ==
+				IORING_SETUP_ATTACH_WQ) {
+		struct fd f;
+
+		f = fdget(p->wq_fd);
+		if (!f.file)
+			return -ENXIO;
+		if (f.file->f_op != &io_uring_fops) {
+			fdput(f);
+			return -EINVAL;
+		}
+		fdput(f);
+	}
 	if (ctx->flags & IORING_SETUP_SQPOLL) {
 		struct io_sq_data *sqd;
 
@@ -8257,7 +8246,7 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 	}
 
 done:
-	ret = io_init_wq_offload(ctx, p);
+	ret = io_init_wq_offload(ctx);
 	if (ret)
 		goto err;