diff --git a/mm/mmap.c b/mm/mmap.c
index f7cd9cb966c0fb5892e28bb24ec98bc3d85d9768..330f12c17fa11d835afd80a17e7980f62e707637 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2687,8 +2687,8 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
  * work.  This now handles partial unmappings.
  * Jeremy Fitzhardinge <jeremy@goop.org>
  */
-int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
-	      struct list_head *uf)
+static int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
+		       struct list_head *uf, bool downgrade)
 {
 	unsigned long end;
 	struct vm_area_struct *vma, *prev, *last;
@@ -2770,25 +2770,47 @@ int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
 				mm->locked_vm -= vma_pages(tmp);
 				munlock_vma_pages_all(tmp);
 			}
+
+			/*
+			 * Unmapping vmas, which have VM_HUGETLB or VM_PFNMAP,
+			 * need get done with write mmap_sem held since they may
+			 * update vm_flags.
+			 */
+			if (downgrade &&
+			    (tmp->vm_flags & (VM_HUGETLB | VM_PFNMAP)))
+				downgrade = false;
+
 			tmp = tmp->vm_next;
 		}
 	}
 
-	/*
-	 * Remove the vma's, and unmap the actual pages
-	 */
+	/* Detach vmas from rbtree */
 	detach_vmas_to_be_unmapped(mm, vma, prev, end);
-	unmap_region(mm, vma, prev, start, end);
 
+	/*
+	 * mpx unmap needs to be called with mmap_sem held for write.
+	 * It is safe to call it before unmap_region().
+	 */
 	arch_unmap(mm, vma, start, end);
 
+	if (downgrade)
+		downgrade_write(&mm->mmap_sem);
+
+	unmap_region(mm, vma, prev, start, end);
+
 	/* Fix up all other VM information */
 	remove_vma_list(mm, vma);
 
-	return 0;
+	return downgrade ? 1 : 0;
 }
 
-int vm_munmap(unsigned long start, size_t len)
+int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
+	      struct list_head *uf)
+{
+	return __do_munmap(mm, start, len, uf, false);
+}
+
+static int __vm_munmap(unsigned long start, size_t len, bool downgrade)
 {
 	int ret;
 	struct mm_struct *mm = current->mm;
@@ -2797,17 +2819,32 @@ int vm_munmap(unsigned long start, size_t len)
 	if (down_write_killable(&mm->mmap_sem))
 		return -EINTR;
 
-	ret = do_munmap(mm, start, len, &uf);
-	up_write(&mm->mmap_sem);
+	ret = __do_munmap(mm, start, len, &uf, downgrade);
+	/*
+	 * Returning 1 indicates mmap_sem is downgraded.
+	 * But 1 is not legal return value of vm_munmap() and munmap(), reset
+	 * it to 0 before return.
+	 */
+	if (ret == 1) {
+		up_read(&mm->mmap_sem);
+		ret = 0;
+	} else
+		up_write(&mm->mmap_sem);
+
 	userfaultfd_unmap_complete(mm, &uf);
 	return ret;
 }
+
+int vm_munmap(unsigned long start, size_t len)
+{
+	return __vm_munmap(start, len, false);
+}
 EXPORT_SYMBOL(vm_munmap);
 
 SYSCALL_DEFINE2(munmap, unsigned long, addr, size_t, len)
 {
 	profile_munmap(addr);
-	return vm_munmap(addr, len);
+	return __vm_munmap(addr, len, true);
 }