diff --git a/Documentation/virtual/kvm/mmu.txt b/Documentation/virtual/kvm/mmu.txt
index 290894176142f452ba794eab9f0d739901516865..53838d9c6295792501f2825175ac9c314c1a2fc7 100644
--- a/Documentation/virtual/kvm/mmu.txt
+++ b/Documentation/virtual/kvm/mmu.txt
@@ -425,6 +425,20 @@ fault through the slow path.
 Since only 19 bits are used to store generation-number on mmio spte, all
 pages are zapped when there is an overflow.
 
+Unfortunately, a single memory access might access kvm_memslots(kvm) multiple
+times, the last one happening when the generation number is retrieved and
+stored into the MMIO spte.  Thus, the MMIO spte might be created based on
+out-of-date information, but with an up-to-date generation number.
+
+To avoid this, the generation number is incremented again after synchronize_srcu
+returns; thus, the low bit of kvm_memslots(kvm)->generation is only 1 during a
+memslot update, while some SRCU readers might be using the old copy.  We do not
+want to use an MMIO sptes created with an odd generation number, and we can do
+this without losing a bit in the MMIO spte.  The low bit of the generation
+is not stored in MMIO spte, and presumed zero when it is extracted out of the
+spte.  If KVM is unlucky and creates an MMIO spte while the low bit is 1,
+the next access to the spte will always be a cache miss.
+
 
 Further reading
 ===============
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 323c3f5f5c84968eb35561978392286405323afe..96515957ba824c0bcecaa3756679258102ae7366 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -199,16 +199,20 @@ void kvm_mmu_set_mmio_spte_mask(u64 mmio_mask)
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mmio_spte_mask);
 
 /*
- * spte bits of bit 3 ~ bit 11 are used as low 9 bits of generation number,
- * the bits of bits 52 ~ bit 61 are used as high 10 bits of generation
- * number.
+ * the low bit of the generation number is always presumed to be zero.
+ * This disables mmio caching during memslot updates.  The concept is
+ * similar to a seqcount but instead of retrying the access we just punt
+ * and ignore the cache.
+ *
+ * spte bits 3-11 are used as bits 1-9 of the generation number,
+ * the bits 52-61 are used as bits 10-19 of the generation number.
  */
-#define MMIO_SPTE_GEN_LOW_SHIFT		3
+#define MMIO_SPTE_GEN_LOW_SHIFT		2
 #define MMIO_SPTE_GEN_HIGH_SHIFT	52
 
-#define MMIO_GEN_SHIFT			19
-#define MMIO_GEN_LOW_SHIFT		9
-#define MMIO_GEN_LOW_MASK		((1 << MMIO_GEN_LOW_SHIFT) - 1)
+#define MMIO_GEN_SHIFT			20
+#define MMIO_GEN_LOW_SHIFT		10
+#define MMIO_GEN_LOW_MASK		((1 << MMIO_GEN_LOW_SHIFT) - 2)
 #define MMIO_GEN_MASK			((1 << MMIO_GEN_SHIFT) - 1)
 #define MMIO_MAX_GEN			((1 << MMIO_GEN_SHIFT) - 1)
 
@@ -4428,7 +4432,7 @@ void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm)
 	 * The very rare case: if the generation-number is round,
 	 * zap all shadow pages.
 	 */
-	if (unlikely(kvm_current_mmio_generation(kvm) >= MMIO_MAX_GEN)) {
+	if (unlikely(kvm_current_mmio_generation(kvm) == 0)) {
 		printk_ratelimited(KERN_INFO "kvm: zapping shadow pages for mmio generation wraparound\n");
 		kvm_mmu_invalidate_zap_all_pages(kvm);
 	}
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 0bfdb673db264498de7fa5af73c9af1e45057b23..bb8641b5d83ba2b80dc344a165da4db265f08efe 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -95,8 +95,6 @@ static int hardware_enable_all(void);
 static void hardware_disable_all(void);
 
 static void kvm_io_bus_destroy(struct kvm_io_bus *bus);
-static void update_memslots(struct kvm_memslots *slots,
-			    struct kvm_memory_slot *new, u64 last_generation);
 
 static void kvm_release_pfn_dirty(pfn_t pfn);
 static void mark_page_dirty_in_slot(struct kvm *kvm,
@@ -695,8 +693,7 @@ static void sort_memslots(struct kvm_memslots *slots)
 }
 
 static void update_memslots(struct kvm_memslots *slots,
-			    struct kvm_memory_slot *new,
-			    u64 last_generation)
+			    struct kvm_memory_slot *new)
 {
 	if (new) {
 		int id = new->id;
@@ -707,8 +704,6 @@ static void update_memslots(struct kvm_memslots *slots,
 		if (new->npages != npages)
 			sort_memslots(slots);
 	}
-
-	slots->generation = last_generation + 1;
 }
 
 static int check_memory_region_flags(struct kvm_userspace_memory_region *mem)
@@ -730,10 +725,24 @@ static struct kvm_memslots *install_new_memslots(struct kvm *kvm,
 {
 	struct kvm_memslots *old_memslots = kvm->memslots;
 
-	update_memslots(slots, new, kvm->memslots->generation);
+	/*
+	 * Set the low bit in the generation, which disables SPTE caching
+	 * until the end of synchronize_srcu_expedited.
+	 */
+	WARN_ON(old_memslots->generation & 1);
+	slots->generation = old_memslots->generation + 1;
+
+	update_memslots(slots, new);
 	rcu_assign_pointer(kvm->memslots, slots);
 	synchronize_srcu_expedited(&kvm->srcu);
 
+	/*
+	 * Increment the new memslot generation a second time. This prevents
+	 * vm exits that race with memslot updates from caching a memslot
+	 * generation that will (potentially) be valid forever.
+	 */
+	slots->generation++;
+
 	kvm_arch_memslots_updated(kvm);
 
 	return old_memslots;