diff --git a/include/linux/mempolicy.h b/include/linux/mempolicy.h
index 4429d255c8ab6c7524436d2ee36fdcea57304447..5e5b2969d93167a8337823dcc74b15ce53870307 100644
--- a/include/linux/mempolicy.h
+++ b/include/linux/mempolicy.h
@@ -195,6 +195,7 @@ static inline bool vma_migratable(struct vm_area_struct *vma)
 }
 
 extern int mpol_misplaced(struct page *, struct vm_area_struct *, unsigned long);
+extern void mpol_put_task_policy(struct task_struct *);
 
 #else
 
@@ -297,5 +298,8 @@ static inline int mpol_misplaced(struct page *page, struct vm_area_struct *vma,
 	return -1; /* no node preference */
 }
 
+static inline void mpol_put_task_policy(struct task_struct *task)
+{
+}
 #endif /* CONFIG_NUMA */
 #endif
diff --git a/kernel/exit.c b/kernel/exit.c
index 2f974ae042a677a90c7d34a78b66793cbe8584c5..091a78be3b09d5669d9c10b98f6300e4171d2413 100644
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -848,12 +848,7 @@ void do_exit(long code)
 	TASKS_RCU(preempt_enable());
 	exit_notify(tsk, group_dead);
 	proc_exit_connector(tsk);
-#ifdef CONFIG_NUMA
-	task_lock(tsk);
-	mpol_put(tsk->mempolicy);
-	tsk->mempolicy = NULL;
-	task_unlock(tsk);
-#endif
+	mpol_put_task_policy(tsk);
 #ifdef CONFIG_FUTEX
 	if (unlikely(current->pi_state_cache))
 		kfree(current->pi_state_cache);
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index d8c4e38fb5f4be1d9748dc77f214c8a285374f8d..2da72a5b6ecc04f87bd9168fc31d613360e40f2e 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -2336,6 +2336,23 @@ int mpol_misplaced(struct page *page, struct vm_area_struct *vma, unsigned long
 	return ret;
 }
 
+/*
+ * Drop the (possibly final) reference to task->mempolicy.  It needs to be
+ * dropped after task->mempolicy is set to NULL so that any allocation done as
+ * part of its kmem_cache_free(), such as by KASAN, doesn't reference a freed
+ * policy.
+ */
+void mpol_put_task_policy(struct task_struct *task)
+{
+	struct mempolicy *pol;
+
+	task_lock(task);
+	pol = task->mempolicy;
+	task->mempolicy = NULL;
+	task_unlock(task);
+	mpol_put(pol);
+}
+
 static void sp_delete(struct shared_policy *sp, struct sp_node *n)
 {
 	pr_debug("deleting %lx-l%lx\n", n->start, n->end);