diff --git a/fs/dax.c b/fs/dax.c
index 9bcce89ea18ef458b25e786ce21b6f1c6122a178..48132eca3761de2b4cdf7c6c75ab8efda8cf7a26 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -232,6 +232,34 @@ static void *get_unlocked_entry(struct xa_state *xas)
 	}
 }
 
+/*
+ * The only thing keeping the address space around is the i_pages lock
+ * (it's cycled in clear_inode() after removing the entries from i_pages)
+ * After we call xas_unlock_irq(), we cannot touch xas->xa.
+ */
+static void wait_entry_unlocked(struct xa_state *xas, void *entry)
+{
+	struct wait_exceptional_entry_queue ewait;
+	wait_queue_head_t *wq;
+
+	init_wait(&ewait.wait);
+	ewait.wait.func = wake_exceptional_entry_func;
+
+	wq = dax_entry_waitqueue(xas, entry, &ewait.key);
+	prepare_to_wait_exclusive(wq, &ewait.wait, TASK_UNINTERRUPTIBLE);
+	xas_unlock_irq(xas);
+	schedule();
+	finish_wait(wq, &ewait.wait);
+
+	/*
+	 * Entry lock waits are exclusive. Wake up the next waiter since
+	 * we aren't sure we will acquire the entry lock and thus wake
+	 * the next waiter up on unlock.
+	 */
+	if (waitqueue_active(wq))
+		__wake_up(wq, TASK_NORMAL, 1, &ewait.key);
+}
+
 static void put_unlocked_entry(struct xa_state *xas, void *entry)
 {
 	/* If we were the only waiter woken, wake the next one */
@@ -351,21 +379,21 @@ static struct page *dax_busy_page(void *entry)
  * @page: The page whose entry we want to lock
  *
  * Context: Process context.
- * Return: %true if the entry was locked or does not need to be locked.
+ * Return: A cookie to pass to dax_unlock_page() or 0 if the entry could
+ * not be locked.
  */
-bool dax_lock_mapping_entry(struct page *page)
+dax_entry_t dax_lock_page(struct page *page)
 {
 	XA_STATE(xas, NULL, 0);
 	void *entry;
-	bool locked;
 
 	/* Ensure page->mapping isn't freed while we look at it */
 	rcu_read_lock();
 	for (;;) {
 		struct address_space *mapping = READ_ONCE(page->mapping);
 
-		locked = false;
-		if (!dax_mapping(mapping))
+		entry = NULL;
+		if (!mapping || !dax_mapping(mapping))
 			break;
 
 		/*
@@ -375,7 +403,7 @@ bool dax_lock_mapping_entry(struct page *page)
 		 * otherwise we would not have a valid pfn_to_page()
 		 * translation.
 		 */
-		locked = true;
+		entry = (void *)~0UL;
 		if (S_ISCHR(mapping->host->i_mode))
 			break;
 
@@ -389,9 +417,7 @@ bool dax_lock_mapping_entry(struct page *page)
 		entry = xas_load(&xas);
 		if (dax_is_locked(entry)) {
 			rcu_read_unlock();
-			entry = get_unlocked_entry(&xas);
-			xas_unlock_irq(&xas);
-			put_unlocked_entry(&xas, entry);
+			wait_entry_unlocked(&xas, entry);
 			rcu_read_lock();
 			continue;
 		}
@@ -400,23 +426,18 @@ bool dax_lock_mapping_entry(struct page *page)
 		break;
 	}
 	rcu_read_unlock();
-	return locked;
+	return (dax_entry_t)entry;
 }
 
-void dax_unlock_mapping_entry(struct page *page)
+void dax_unlock_page(struct page *page, dax_entry_t cookie)
 {
 	struct address_space *mapping = page->mapping;
 	XA_STATE(xas, &mapping->i_pages, page->index);
-	void *entry;
 
 	if (S_ISCHR(mapping->host->i_mode))
 		return;
 
-	rcu_read_lock();
-	entry = xas_load(&xas);
-	rcu_read_unlock();
-	entry = dax_make_entry(page_to_pfn_t(page), dax_is_pmd_entry(entry));
-	dax_unlock_entry(&xas, entry);
+	dax_unlock_entry(&xas, (void *)cookie);
 }
 
 /*
diff --git a/include/linux/dax.h b/include/linux/dax.h
index 450b28db95331ffbe19963c804391d7249cfb2c3..0dd316a74a295132ea6b6c04f914356c5c4064d6 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -7,6 +7,8 @@
 #include <linux/radix-tree.h>
 #include <asm/pgtable.h>
 
+typedef unsigned long dax_entry_t;
+
 struct iomap_ops;
 struct dax_device;
 struct dax_operations {
@@ -88,8 +90,8 @@ int dax_writeback_mapping_range(struct address_space *mapping,
 		struct block_device *bdev, struct writeback_control *wbc);
 
 struct page *dax_layout_busy_page(struct address_space *mapping);
-bool dax_lock_mapping_entry(struct page *page);
-void dax_unlock_mapping_entry(struct page *page);
+dax_entry_t dax_lock_page(struct page *page);
+void dax_unlock_page(struct page *page, dax_entry_t cookie);
 #else
 static inline bool bdev_dax_supported(struct block_device *bdev,
 		int blocksize)
@@ -122,14 +124,14 @@ static inline int dax_writeback_mapping_range(struct address_space *mapping,
 	return -EOPNOTSUPP;
 }
 
-static inline bool dax_lock_mapping_entry(struct page *page)
+static inline dax_entry_t dax_lock_page(struct page *page)
 {
 	if (IS_DAX(page->mapping->host))
-		return true;
-	return false;
+		return ~0UL;
+	return 0;
 }
 
-static inline void dax_unlock_mapping_entry(struct page *page)
+static inline void dax_unlock_page(struct page *page, dax_entry_t cookie)
 {
 }
 #endif
diff --git a/mm/memory-failure.c b/mm/memory-failure.c
index 0cd3de3550f0830f507d286b0499789d7961171e..7c72f2a95785e0d3d5df615ea33477b0bdcc5278 100644
--- a/mm/memory-failure.c
+++ b/mm/memory-failure.c
@@ -1161,6 +1161,7 @@ static int memory_failure_dev_pagemap(unsigned long pfn, int flags,
 	LIST_HEAD(tokill);
 	int rc = -EBUSY;
 	loff_t start;
+	dax_entry_t cookie;
 
 	/*
 	 * Prevent the inode from being freed while we are interrogating
@@ -1169,7 +1170,8 @@ static int memory_failure_dev_pagemap(unsigned long pfn, int flags,
 	 * also prevents changes to the mapping of this pfn until
 	 * poison signaling is complete.
 	 */
-	if (!dax_lock_mapping_entry(page))
+	cookie = dax_lock_page(page);
+	if (!cookie)
 		goto out;
 
 	if (hwpoison_filter(page)) {
@@ -1220,7 +1222,7 @@ static int memory_failure_dev_pagemap(unsigned long pfn, int flags,
 	kill_procs(&tokill, flags & MF_MUST_KILL, !unmap_success, pfn, flags);
 	rc = 0;
 unlock:
-	dax_unlock_mapping_entry(page);
+	dax_unlock_page(page, cookie);
 out:
 	/* drop pgmap ref acquired in caller */
 	put_dev_pagemap(pgmap);