diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 908ea5464a518097958083156eb5b0b94f4d021c..fb8b376bf28cb3e04a6bb903900f32838ab02a14 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -720,7 +720,7 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
 	}
 }
 
-static void set_spte_track_bits(u64 *sptep, u64 new_spte)
+static int set_spte_track_bits(u64 *sptep, u64 new_spte)
 {
 	pfn_t pfn;
 	u64 old_spte = *sptep;
@@ -731,19 +731,20 @@ static void set_spte_track_bits(u64 *sptep, u64 new_spte)
 		old_spte = __xchg_spte(sptep, new_spte);
 
 	if (!is_rmap_spte(old_spte))
-		return;
+		return 0;
 
 	pfn = spte_to_pfn(old_spte);
 	if (!shadow_accessed_mask || old_spte & shadow_accessed_mask)
 		kvm_set_pfn_accessed(pfn);
 	if (!shadow_dirty_mask || (old_spte & shadow_dirty_mask))
 		kvm_set_pfn_dirty(pfn);
+	return 1;
 }
 
 static void drop_spte(struct kvm *kvm, u64 *sptep, u64 new_spte)
 {
-	set_spte_track_bits(sptep, new_spte);
-	rmap_remove(kvm, sptep);
+	if (set_spte_track_bits(sptep, new_spte))
+		rmap_remove(kvm, sptep);
 }
 
 static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)