diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 3d822a97dfa1660ccf7f5196b322c006a8761066..9b73cfcef917a745093b245b3031f5a030bd2944 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -4054,7 +4054,7 @@ static bool fast_cr3_switch(struct kvm_vcpu *vcpu, gpa_t new_cr3,
 			return false;
 
 		swap(mmu->root_hpa, mmu->prev_root.hpa);
-		mmu->prev_root.cr3 = kvm_read_cr3(vcpu);
+		mmu->prev_root.cr3 = mmu->get_cr3(vcpu);
 
 		if (new_cr3 == prev_cr3 &&
 		    VALID_PAGE(mmu->root_hpa) &&
@@ -4091,6 +4091,7 @@ void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3)
 {
 	__kvm_mmu_new_cr3(vcpu, new_cr3, kvm_mmu_calc_root_page_role(vcpu));
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_new_cr3);
 
 static unsigned long get_cr3(struct kvm_vcpu *vcpu)
 {
@@ -4725,12 +4726,13 @@ kvm_calc_shadow_ept_root_page_role(struct kvm_vcpu *vcpu, bool accessed_dirty)
 }
 
 void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
-			     bool accessed_dirty)
+			     bool accessed_dirty, gpa_t new_eptp)
 {
 	struct kvm_mmu *context = &vcpu->arch.mmu;
 	union kvm_mmu_page_role root_page_role =
 		kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty);
 
+	__kvm_mmu_new_cr3(vcpu, new_eptp, root_page_role);
 	context->shadow_root_level = PT64_ROOT_4LEVEL;
 
 	context->nx = true;
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index 11ab3d62ad65a0f7f5156aae2afa27ccc84092e1..3177127cee96f89635176e550aa10f26646f0b2c 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -64,7 +64,7 @@ reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context);
 void kvm_init_mmu(struct kvm_vcpu *vcpu, bool reset_roots);
 void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu);
 void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
-			     bool accessed_dirty);
+			     bool accessed_dirty, gpa_t new_eptp);
 bool kvm_can_do_async_pf(struct kvm_vcpu *vcpu);
 int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
 				u64 fault_address, char *insn, int insn_len);
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index cbe31b9d2b81b47053da3fe2d858d5a8fb91289b..c6ea892a610e8aa235a6bf615cc4fbc124b74229 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -10741,11 +10741,11 @@ static int nested_ept_init_mmu_context(struct kvm_vcpu *vcpu)
 	if (!valid_ept_address(vcpu, nested_ept_get_cr3(vcpu)))
 		return 1;
 
-	kvm_mmu_unload(vcpu);
 	kvm_init_shadow_ept_mmu(vcpu,
 			to_vmx(vcpu)->nested.msrs.ept_caps &
 			VMX_EPT_EXECUTE_ONLY_BIT,
-			nested_ept_ad_enabled(vcpu));
+			nested_ept_ad_enabled(vcpu),
+			nested_ept_get_cr3(vcpu));
 	vcpu->arch.mmu.set_cr3           = vmx_set_cr3;
 	vcpu->arch.mmu.get_cr3           = nested_ept_get_cr3;
 	vcpu->arch.mmu.inject_page_fault = nested_ept_inject_page_fault;
@@ -11342,12 +11342,16 @@ static int nested_vmx_load_cr3(struct kvm_vcpu *vcpu, unsigned long cr3, bool ne
 				return 1;
 			}
 		}
-
-		vcpu->arch.cr3 = cr3;
-		__set_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail);
 	}
 
-	kvm_mmu_reset_context(vcpu);
+	if (!nested_ept)
+		kvm_mmu_new_cr3(vcpu, cr3);
+
+	vcpu->arch.cr3 = cr3;
+	__set_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail);
+
+	kvm_init_mmu(vcpu, false);
+
 	return 0;
 }
 
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index f519eb8d06b142240b511a247a9c7a0b47a3393b..8f461e0ed38219a5e953c4dc9412f877dbafd5f6 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -2127,16 +2127,22 @@ static void shrink_halt_poll_ns(struct kvm_vcpu *vcpu)
 
 static int kvm_vcpu_check_block(struct kvm_vcpu *vcpu)
 {
+	int ret = -EINTR;
+	int idx = srcu_read_lock(&vcpu->kvm->srcu);
+
 	if (kvm_arch_vcpu_runnable(vcpu)) {
 		kvm_make_request(KVM_REQ_UNHALT, vcpu);
-		return -EINTR;
+		goto out;
 	}
 	if (kvm_cpu_has_pending_timer(vcpu))
-		return -EINTR;
+		goto out;
 	if (signal_pending(current))
-		return -EINTR;
+		goto out;
 
-	return 0;
+	ret = 0;
+out:
+	srcu_read_unlock(&vcpu->kvm->srcu, idx);
+	return ret;
 }
 
 /*