diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 6d9f763a7bb9d5db422ea5625b2c28420bd14f26..427d1daf06d06a028e80417ee8c78524cc927992 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -26,6 +26,7 @@
 #include <linux/irqbypass.h>
 #include <linux/hyperv.h>
 #include <linux/kfifo.h>
+#include <linux/sched/vhost_task.h>
 
 #include <asm/apic.h>
 #include <asm/pvclock-abi.h>
@@ -1443,7 +1444,8 @@ struct kvm_arch {
 	bool sgx_provisioning_allowed;
 
 	struct kvm_x86_pmu_event_filter __rcu *pmu_event_filter;
-	struct task_struct *nx_huge_page_recovery_thread;
+	struct vhost_task *nx_huge_page_recovery_thread;
+	u64 nx_huge_page_last;
 
 #ifdef CONFIG_X86_64
 	/* The number of TDP MMU pages across all roots. */
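The per-VM recovery thread changes type from a bare task_struct pointer to a vhost_task, and nx_huge_page_last is added because the worker is no longer a long-running loop with local state: it is invoked as a one-shot callback, so the timestamp of the last recovery pass has to live in struct kvm_arch. For reference, the vhost_task interface this patch relies on, as declared in <linux/sched/vhost_task.h> in the tree this applies to, looks roughly like the sketch below (illustrative only, not part of the patch):

struct vhost_task;

/*
 * Sketch of the vhost_task API assumed here; see <linux/sched/vhost_task.h>
 * in the target tree for the authoritative declarations.
 *
 * @fn is called repeatedly by the vhost_task core with the task already in
 * TASK_INTERRUPTIBLE: returning true asks to be called again, returning
 * false parks the task until vhost_task_wake().  @handle_sigkill runs once
 * if the owning process is SIGKILLed before vhost_task_stop() is called.
 */
struct vhost_task *vhost_task_create(bool (*fn)(void *data),
				     void (*handle_sigkill)(void *data),
				     void *arg, const char *name);
void vhost_task_start(struct vhost_task *vtsk);
void vhost_task_wake(struct vhost_task *vtsk);	/* wake_up_process() on the worker */
void vhost_task_stop(struct vhost_task *vtsk);	/* wake the worker and wait for it to exit */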
diff --git a/arch/x86/kvm/Kconfig b/arch/x86/kvm/Kconfig
index f09f13c01c6bbd28fa37fdf50547abf4403658c9..b387d61af44f182731ab6239e64e4869bedbaaa4 100644
--- a/arch/x86/kvm/Kconfig
+++ b/arch/x86/kvm/Kconfig
@@ -29,6 +29,7 @@ config KVM_X86
 	select HAVE_KVM_IRQ_BYPASS
 	select HAVE_KVM_IRQ_ROUTING
 	select HAVE_KVM_READONLY_MEM
+	select VHOST_TASK
 	select KVM_ASYNC_PF
 	select USER_RETURN_NOTIFIER
 	select KVM_MMIO
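KVM_X86 selects VHOST_TASK so that kernel/vhost_task.c is built even on configurations without the vhost drivers, previously its only users. A bare select is safe here because VHOST_TASK is a promptless helper symbol with no dependencies of its own; its Kconfig entry (in drivers/vhost/Kconfig, quoted roughly and only for illustration) is simply:

config VHOST_TASK
	bool
	default n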
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 8e853a5fc867b742321a78076d5bd81764abbc55..3e353ed1f767361220c85888ba1dd00afbb30d30 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -7281,7 +7281,7 @@ static int set_nx_huge_pages(const char *val, const struct kernel_param *kp)
 			kvm_mmu_zap_all_fast(kvm);
 			mutex_unlock(&kvm->slots_lock);
 
-			wake_up_process(kvm->arch.nx_huge_page_recovery_thread);
+			vhost_task_wake(kvm->arch.nx_huge_page_recovery_thread);
 		}
 		mutex_unlock(&kvm_lock);
 	}
@@ -7427,7 +7427,7 @@ static int set_nx_huge_pages_recovery_param(const char *val, const struct kernel
 		mutex_lock(&kvm_lock);
 
 		list_for_each_entry(kvm, &vm_list, vm_list)
-			wake_up_process(kvm->arch.nx_huge_page_recovery_thread);
+			vhost_task_wake(kvm->arch.nx_huge_page_recovery_thread);
 
 		mutex_unlock(&kvm_lock);
 	}
@@ -7530,62 +7530,56 @@ static void kvm_recover_nx_huge_pages(struct kvm *kvm)
 	srcu_read_unlock(&kvm->srcu, rcu_idx);
 }
 
-static long get_nx_huge_page_recovery_timeout(u64 start_time)
+static void kvm_nx_huge_page_recovery_worker_kill(void *data)
 {
-	bool enabled;
-	uint period;
-
-	enabled = calc_nx_huge_pages_recovery_period(&period);
-
-	return enabled ? start_time + msecs_to_jiffies(period) - get_jiffies_64()
-		       : MAX_SCHEDULE_TIMEOUT;
 }
 
-static int kvm_nx_huge_page_recovery_worker(struct kvm *kvm, uintptr_t data)
+static bool kvm_nx_huge_page_recovery_worker(void *data)
 {
-	u64 start_time;
+	struct kvm *kvm = data;
+	bool enabled;
+	uint period;
 	long remaining_time;
 
-	while (true) {
-		start_time = get_jiffies_64();
-		remaining_time = get_nx_huge_page_recovery_timeout(start_time);
-
-		set_current_state(TASK_INTERRUPTIBLE);
-		while (!kthread_should_stop() && remaining_time > 0) {
-			schedule_timeout(remaining_time);
-			remaining_time = get_nx_huge_page_recovery_timeout(start_time);
-			set_current_state(TASK_INTERRUPTIBLE);
-		}
-
-		set_current_state(TASK_RUNNING);
-
-		if (kthread_should_stop())
-			return 0;
+	enabled = calc_nx_huge_pages_recovery_period(&period);
+	if (!enabled)
+		return false;
 
-		kvm_recover_nx_huge_pages(kvm);
+	remaining_time = kvm->arch.nx_huge_page_last + msecs_to_jiffies(period)
+		- get_jiffies_64();
+	if (remaining_time > 0) {
+		schedule_timeout(remaining_time);
+		/* check for signals and come back */
+		return true;
 	}
+
+	__set_current_state(TASK_RUNNING);
+	kvm_recover_nx_huge_pages(kvm);
+	kvm->arch.nx_huge_page_last = get_jiffies_64();
+	return true;
 }
 
 int kvm_mmu_post_init_vm(struct kvm *kvm)
 {
-	int err;
-
 	if (nx_hugepage_mitigation_hard_disabled)
 		return 0;
 
-	err = kvm_vm_create_worker_thread(kvm, kvm_nx_huge_page_recovery_worker, 0,
-					  "kvm-nx-lpage-recovery",
-					  &kvm->arch.nx_huge_page_recovery_thread);
-	if (!err)
-		kthread_unpark(kvm->arch.nx_huge_page_recovery_thread);
+	kvm->arch.nx_huge_page_last = get_jiffies_64();
+	kvm->arch.nx_huge_page_recovery_thread = vhost_task_create(
+		kvm_nx_huge_page_recovery_worker, kvm_nx_huge_page_recovery_worker_kill,
+		kvm, "kvm-nx-lpage-recovery");
 
-	return err;
+	if (!kvm->arch.nx_huge_page_recovery_thread)
+		return -ENOMEM;
+
+	vhost_task_start(kvm->arch.nx_huge_page_recovery_thread);
+	return 0;
 }
 
 void kvm_mmu_pre_destroy_vm(struct kvm *kvm)
 {
 	if (kvm->arch.nx_huge_page_recovery_thread)
-		kthread_stop(kvm->arch.nx_huge_page_recovery_thread);
+		vhost_task_stop(kvm->arch.nx_huge_page_recovery_thread);
 }
 
 #ifdef CONFIG_KVM_GENERIC_MEMORY_ATTRIBUTES
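The worker changes shape: the open-coded kthread loop, which handled its own sleeping, parking and stop checks, becomes a one-shot callback driven by the vhost_task core. The sketch below is a simplified rendering of that driving loop (based on vhost_task_fn() in kernel/vhost_task.c; signal handling and the SIGKILL/stop race are elided, so treat it as illustrative) and shows the contract the new code depends on: the callback is entered in TASK_INTERRUPTIBLE, which is why the bare schedule_timeout() above actually sleeps; returning true re-enters the loop so SIGKILL and vhost_task_stop() are noticed between naps; returning false, when the mitigation is disabled, parks the task until one of the module-parameter handlers calls vhost_task_wake().

for (;;) {
	bool did_work;

	if (fatal_signal_pending(current))
		break;				/* the VM process was killed */

	set_current_state(TASK_INTERRUPTIBLE);
	if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
		__set_current_state(TASK_RUNNING);
		break;				/* vhost_task_stop() */
	}

	did_work = vtsk->fn(vtsk->data);	/* kvm_nx_huge_page_recovery_worker() */
	if (!did_work)
		schedule();			/* parked until vhost_task_wake() */
}

This also explains the two smaller changes in the hunk: get_nx_huge_page_recovery_timeout() disappears because the callback recomputes the deadline from kvm->arch.nx_huge_page_last on every invocation, and kvm_nx_huge_page_recovery_worker_kill() is deliberately an empty stub since there is no pending work to flush when the VM process is killed; normal teardown still goes through the synchronous vhost_task_stop() in kvm_mmu_pre_destroy_vm().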
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 45be36e5285ffb5de3c161ff450275997aac7232..85fe9d0ebb9152c02d4e253326d04b694b4ab1c1 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -2382,12 +2382,6 @@ static inline int kvm_arch_vcpu_run_pid_change(struct kvm_vcpu *vcpu)
 }
 #endif /* CONFIG_HAVE_KVM_VCPU_RUN_PID_CHANGE */
 
-typedef int (*kvm_vm_thread_fn_t)(struct kvm *kvm, uintptr_t data);
-
-int kvm_vm_create_worker_thread(struct kvm *kvm, kvm_vm_thread_fn_t thread_fn,
-				uintptr_t data, const char *name,
-				struct task_struct **thread_ptr);
-
 #ifdef CONFIG_KVM_XFER_TO_GUEST_WORK
 static inline void kvm_handle_signal_exit(struct kvm_vcpu *vcpu)
 {
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 6ca7a1045bbb75dfbd9195b348c3d03a83578a86..279e03029ce149b90745707ff3063aec50d41b4b 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -6561,106 +6561,3 @@ void kvm_exit(void)
 	kvm_irqfd_exit();
 }
 EXPORT_SYMBOL_GPL(kvm_exit);
-
-struct kvm_vm_worker_thread_context {
-	struct kvm *kvm;
-	struct task_struct *parent;
-	struct completion init_done;
-	kvm_vm_thread_fn_t thread_fn;
-	uintptr_t data;
-	int err;
-};
-
-static int kvm_vm_worker_thread(void *context)
-{
-	/*
-	 * The init_context is allocated on the stack of the parent thread, so
-	 * we have to locally copy anything that is needed beyond initialization
-	 */
-	struct kvm_vm_worker_thread_context *init_context = context;
-	struct task_struct *parent;
-	struct kvm *kvm = init_context->kvm;
-	kvm_vm_thread_fn_t thread_fn = init_context->thread_fn;
-	uintptr_t data = init_context->data;
-	int err;
-
-	err = kthread_park(current);
-	/* kthread_park(current) is never supposed to return an error */
-	WARN_ON(err != 0);
-	if (err)
-		goto init_complete;
-
-	err = cgroup_attach_task_all(init_context->parent, current);
-	if (err) {
-		kvm_err("%s: cgroup_attach_task_all failed with err %d\n",
-			__func__, err);
-		goto init_complete;
-	}
-
-	set_user_nice(current, task_nice(init_context->parent));
-
-init_complete:
-	init_context->err = err;
-	complete(&init_context->init_done);
-	init_context = NULL;
-
-	if (err)
-		goto out;
-
-	/* Wait to be woken up by the spawner before proceeding. */
-	kthread_parkme();
-
-	if (!kthread_should_stop())
-		err = thread_fn(kvm, data);
-
-out:
-	/*
-	 * Move kthread back to its original cgroup to prevent it lingering in
-	 * the cgroup of the VM process, after the latter finishes its
-	 * execution.
-	 *
-	 * kthread_stop() waits on the 'exited' completion condition which is
-	 * set in exit_mm(), via mm_release(), in do_exit(). However, the
-	 * kthread is removed from the cgroup in the cgroup_exit() which is
-	 * called after the exit_mm(). This causes the kthread_stop() to return
-	 * before the kthread actually quits the cgroup.
-	 */
-	rcu_read_lock();
-	parent = rcu_dereference(current->real_parent);
-	get_task_struct(parent);
-	rcu_read_unlock();
-	cgroup_attach_task_all(parent, current);
-	put_task_struct(parent);
-
-	return err;
-}
-
-int kvm_vm_create_worker_thread(struct kvm *kvm, kvm_vm_thread_fn_t thread_fn,
-				uintptr_t data, const char *name,
-				struct task_struct **thread_ptr)
-{
-	struct kvm_vm_worker_thread_context init_context = {};
-	struct task_struct *thread;
-
-	*thread_ptr = NULL;
-	init_context.kvm = kvm;
-	init_context.parent = current;
-	init_context.thread_fn = thread_fn;
-	init_context.data = data;
-	init_completion(&init_context.init_done);
-
-	thread = kthread_run(kvm_vm_worker_thread, &init_context,
-			     "%s-%d", name, task_pid_nr(current));
-	if (IS_ERR(thread))
-		return PTR_ERR(thread);
-
-	/* kthread_run is never supposed to return NULL */
-	WARN_ON(thread == NULL);
-
-	wait_for_completion(&init_context.init_done);
-
-	if (!init_context.err)
-		*thread_ptr = thread;
-
-	return init_context.err;
-}
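The helper removed above had exactly one user. Most of its bulk existed to make a kthread look as if it belonged to the VM process: park it until setup finished, attach it to the caller's cgroups, copy the caller's nice level, and migrate it back to its original cgroup on exit so it would not linger in the VM's cgroup after the process died. A vhost task needs none of that because it is cloned as a user worker thread inside the VM process itself; the sketch below shows roughly what vhost_task_create() does in kernel/vhost_task.c (field and flag list reproduced for illustration, not verbatim):

	struct kernel_clone_args args = {
		.flags		= CLONE_FS | CLONE_UNTRACED | CLONE_VM |
				  CLONE_THREAD | CLONE_SIGHAND,
		.fn		= vhost_task_fn,
		.name		= name,		/* "kvm-nx-lpage-recovery" for this caller */
		.user_worker	= 1,
		.no_files	= 1,
	};

	tsk = copy_process(NULL, 0, NUMA_NO_NODE, &args);

Because the worker shares the VM's thread group, it is charged to the VM's cgroups and inherits its scheduling attributes automatically, it shows up under the VMM's pid rather than as a free-floating kernel thread, and it exits with the process, so both the cgroup round trip and the kthread_park()/kthread_stop() choreography go away.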