diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index 06fd598ee8a6fa6c3ddd2e4b4def4266653ec54b..935b84f40f65d9198766ff4ff8eee8a4b4f53642 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -804,18 +804,22 @@ struct vcpu_vmx {
 	/*
 	 * loaded_vmcs points to the VMCS currently used in this vcpu. For a
 	 * non-nested (L1) guest, it always points to vmcs01. For a nested
-	 * guest (L2), it points to a different VMCS.
+	 * guest (L2), it points to a different VMCS.  loaded_cpu_state points
+	 * to the VMCS whose state is loaded into the CPU registers that only
+	 * need to be switched when transitioning to/from the kernel; a NULL
+	 * value indicates that host state is loaded.
 	 */
 	struct loaded_vmcs    vmcs01;
 	struct loaded_vmcs   *loaded_vmcs;
+	struct loaded_vmcs   *loaded_cpu_state;
 	bool                  __launched; /* temporary, used in vmx_vcpu_run */
 	struct msr_autoload {
 		unsigned nr;
 		struct vmx_msr_entry guest[NR_AUTOLOAD_MSRS];
 		struct vmx_msr_entry host[NR_AUTOLOAD_MSRS];
 	} msr_autoload;
+
 	struct {
-		int           loaded;
 		u16           fs_sel, gs_sel, ldt_sel;
 #ifdef CONFIG_X86_64
 		u16           ds_sel, es_sel;
@@ -2667,10 +2671,11 @@ static void vmx_save_host_state(struct kvm_vcpu *vcpu)
 	u16 fs_sel, gs_sel;
 	int i;
 
-	if (vmx->host_state.loaded)
+	if (vmx->loaded_cpu_state)
 		return;
 
-	vmx->host_state.loaded = 1;
+	vmx->loaded_cpu_state = vmx->loaded_vmcs;
+
 	/*
 	 * Set host fs and gs selectors.  Unfortunately, 22.2.3 does not
 	 * allow segment selectors with cpl > 0 or ti == 1.
@@ -2732,11 +2737,14 @@ static void vmx_save_host_state(struct kvm_vcpu *vcpu)
 
 static void __vmx_load_host_state(struct vcpu_vmx *vmx)
 {
-	if (!vmx->host_state.loaded)
+	if (!vmx->loaded_cpu_state)
 		return;
 
+	WARN_ON_ONCE(vmx->loaded_cpu_state != vmx->loaded_vmcs);
+
 	++vmx->vcpu.stat.host_state_reload;
-	vmx->host_state.loaded = 0;
+	vmx->loaded_cpu_state = NULL;
+
 #ifdef CONFIG_X86_64
 	if (is_long_mode(&vmx->vcpu))
 		rdmsrl(MSR_KERNEL_GS_BASE, vmx->msr_guest_kernel_gs_base);
@@ -10596,8 +10604,8 @@ static void vmx_switch_vmcs(struct kvm_vcpu *vcpu, struct loaded_vmcs *vmcs)
 		return;
 
 	cpu = get_cpu();
-	vmx->loaded_vmcs = vmcs;
 	vmx_vcpu_put(vcpu);
+	vmx->loaded_vmcs = vmcs;
 	vmx_vcpu_load(vcpu, cpu);
 	put_cpu();
 }