diff --git a/mm/memory.c b/mm/memory.c
index 75e0abda86427c918e6e7c7c5f3a33104d49d0eb..230bcd12a3362dc06d7ca1d2dbf060e10ca10c56 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -5127,18 +5127,18 @@ static vm_fault_t do_numa_page(struct vm_fault *vmf)
 	int flags = 0;
 
 	/*
-	 * The "pte" at this point cannot be used safely without
-	 * validation through pte_unmap_same(). It's of NUMA type but
-	 * the pfn may be screwed if the read is non atomic.
+	 * The pte cannot be used safely until we verify, while holding the page
+	 * table lock, that its contents have not changed during fault handling.
 	 */
 	spin_lock(vmf->ptl);
-	if (unlikely(!pte_same(ptep_get(vmf->pte), vmf->orig_pte))) {
+	/* Read the live PTE from the page tables: */
+	old_pte = ptep_get(vmf->pte);
+
+	if (unlikely(!pte_same(old_pte, vmf->orig_pte))) {
 		pte_unmap_unlock(vmf->pte, vmf->ptl);
 		goto out;
 	}
 
-	/* Get the normal PTE  */
-	old_pte = ptep_get(vmf->pte);
 	pte = pte_modify(old_pte, vma->vm_page_prot);
 
 	/*