diff --git a/fs/aio.c b/fs/aio.c
index 99bd790e8cd253479befd5013d6b38a5781d3927..976e33d9741303d0d1643ddaf865918ebeca27c9 100644
--- a/fs/aio.c
+++ b/fs/aio.c
@@ -92,8 +92,10 @@ static void aio_free_ring(struct kioctx *ctx)
 	for (i=0; i<info->nr_pages; i++)
 		put_page(info->ring_pages[i]);
 
-	if (info->mmap_size)
+	if (info->mmap_size) {
+		BUG_ON(ctx->mm != current->mm);
 		vm_munmap(ctx->mm, info->mmap_base, info->mmap_size);
+	}
 
 	if (info->ring_pages && info->ring_pages != info->internal_pages)
 		kfree(info->ring_pages);
@@ -386,6 +388,17 @@ void exit_aio(struct mm_struct *mm)
 				"exit_aio:ioctx still alive: %d %d %d\n",
 				atomic_read(&ctx->users), ctx->dead,
 				ctx->reqs_active);
+		/*
+		 * We don't need to bother with munmap() here -
+		 * exit_mmap(mm) is coming and it'll unmap everything.
+		 * Since aio_free_ring() uses non-zero ->mmap_size
+		 * as indicator that it needs to unmap the area,
+		 * just set it to 0; aio_free_ring() is the only
+		 * place that uses ->mmap_size, so it's safe.
+		 * That way we get all munmap done to current->mm -
+		 * all other callers have ctx->mm == current->mm.
+		 */
+		ctx->ring_info.mmap_size = 0;
 		put_ioctx(ctx);
 	}
 }