diff --git a/fs/dax.c b/fs/dax.c
index 32f020c9cedf05106c8324771cdd600174508fe7..93ae87297ffa09ed6e25b04da9f985ea52ba5d73 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -1387,6 +1387,16 @@ static int dax_iomap_pmd_fault(struct vm_fault *vmf,
 	if ((pgoff | PG_PMD_COLOUR) > max_pgoff)
 		goto fallback;
 
+	/*
+	 * grab_mapping_entry() will make sure we get a 2M empty entry, a DAX
+	 * PMD or a HZP entry.  If it can't (because a 4k page is already in
+	 * the tree, for instance), it will return -EEXIST and we just fall
+	 * back to 4k entries.
+	 */
+	entry = grab_mapping_entry(mapping, pgoff, RADIX_DAX_PMD);
+	if (IS_ERR(entry))
+		goto fallback;
+
 	/*
 	 * Note that we don't use iomap_apply here.  We aren't doing I/O, only
 	 * setting up a mapping, so really we're using iomap_begin() as a way
@@ -1395,21 +1405,11 @@ static int dax_iomap_pmd_fault(struct vm_fault *vmf,
 	pos = (loff_t)pgoff << PAGE_SHIFT;
 	error = ops->iomap_begin(inode, pos, PMD_SIZE, iomap_flags, &iomap);
 	if (error)
-		goto fallback;
+		goto unlock_entry;
 
 	if (iomap.offset + iomap.length < pos + PMD_SIZE)
 		goto finish_iomap;
 
-	/*
-	 * grab_mapping_entry() will make sure we get a 2M empty entry, a DAX
-	 * PMD or a HZP entry.  If it can't (because a 4k page is already in
-	 * the tree, for instance), it will return -EEXIST and we just fall
-	 * back to 4k entries.
-	 */
-	entry = grab_mapping_entry(mapping, pgoff, RADIX_DAX_PMD);
-	if (IS_ERR(entry))
-		goto finish_iomap;
-
 	switch (iomap.type) {
 	case IOMAP_MAPPED:
 		result = dax_pmd_insert_mapping(vmf, &iomap, pos, &entry);
@@ -1417,7 +1417,7 @@ static int dax_iomap_pmd_fault(struct vm_fault *vmf,
 	case IOMAP_UNWRITTEN:
 	case IOMAP_HOLE:
 		if (WARN_ON_ONCE(write))
-			goto unlock_entry;
+			break;
 		result = dax_pmd_load_hole(vmf, &iomap, &entry);
 		break;
 	default:
@@ -1425,8 +1425,6 @@ static int dax_iomap_pmd_fault(struct vm_fault *vmf,
 		break;
 	}
 
- unlock_entry:
-	put_locked_mapping_entry(mapping, pgoff, entry);
  finish_iomap:
 	if (ops->iomap_end) {
 		int copied = PMD_SIZE;
@@ -1442,6 +1440,8 @@ static int dax_iomap_pmd_fault(struct vm_fault *vmf,
 		ops->iomap_end(inode, pos, PMD_SIZE, copied, iomap_flags,
 				&iomap);
 	}
+ unlock_entry:
+	put_locked_mapping_entry(mapping, pgoff, entry);
  fallback:
 	if (result == VM_FAULT_FALLBACK) {
 		split_huge_pmd(vma, vmf->pmd, vmf->address);