diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index e9fc1c1ae66c6a328700720d1151118188b9aa79..c6c61ea6bb8ca827ea308dcf71ca51d398337a0f 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -591,16 +591,29 @@ static long do_set_mempolicy(unsigned short mode, unsigned short flags,
 			     nodemask_t *nodes)
 {
 	struct mempolicy *new;
+	struct mm_struct *mm = current->mm;
 
 	new = mpol_new(mode, flags, nodes);
 	if (IS_ERR(new))
 		return PTR_ERR(new);
+
+	/*
+	 * prevent changing our mempolicy while show_numa_maps()
+	 * is using it.
+	 * Note:  do_set_mempolicy() can be called at init time
+	 * with no 'mm'.
+	 */
+	if (mm)
+		down_write(&mm->mmap_sem);
 	mpol_put(current->mempolicy);
 	current->mempolicy = new;
 	mpol_set_task_struct_flag();
 	if (new && new->policy == MPOL_INTERLEAVE &&
 	    nodes_weight(new->v.nodes))
 		current->il_next = first_node(new->v.nodes);
+	if (mm)
+		up_write(&mm->mmap_sem);
+
 	return 0;
 }