diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 0729973d486a6062985a04de86066a109c70a8f1..b8d654963df87b7b851d5f88abf8b67675964e65 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -2004,7 +2004,7 @@ static void __split_huge_zero_page_pmd(struct vm_area_struct *vma,
 {
 	struct mm_struct *mm = vma->vm_mm;
 	pgtable_t pgtable;
-	pmd_t _pmd;
+	pmd_t _pmd, old_pmd;
 	int i;
 
 	/*
@@ -2015,7 +2015,7 @@ static void __split_huge_zero_page_pmd(struct vm_area_struct *vma,
 	 *
 	 * See Documentation/mm/mmu_notifier.rst
 	 */
-	pmdp_huge_clear_flush(vma, haddr, pmd);
+	old_pmd = pmdp_huge_clear_flush(vma, haddr, pmd);
 
 	pgtable = pgtable_trans_huge_withdraw(mm, pmd);
 	pmd_populate(mm, &_pmd, pgtable);
@@ -2024,6 +2024,8 @@ static void __split_huge_zero_page_pmd(struct vm_area_struct *vma,
 		pte_t *pte, entry;
 		entry = pfn_pte(my_zero_pfn(haddr), vma->vm_page_prot);
 		entry = pte_mkspecial(entry);
+		if (pmd_uffd_wp(old_pmd))
+			entry = pte_mkuffd_wp(entry);
 		pte = pte_offset_map(&_pmd, haddr);
 		VM_BUG_ON(!pte_none(*pte));
 		set_pte_at(mm, haddr, pte, entry);