diff --git a/drivers/xen/gntdev.c b/drivers/xen/gntdev.c
index 57390c7666e5dd8d44bfe9bdf1e503afb13de189..b0b02a5011672b6670e136728b2c2a8a8f2ee68e 100644
--- a/drivers/xen/gntdev.c
+++ b/drivers/xen/gntdev.c
@@ -492,12 +492,19 @@ static bool in_range(struct gntdev_grant_map *map,
 	return true;
 }
 
-static void unmap_if_in_range(struct gntdev_grant_map *map,
-			      unsigned long start, unsigned long end)
+static int unmap_if_in_range(struct gntdev_grant_map *map,
+			      unsigned long start, unsigned long end,
+			      bool blockable)
 {
 	unsigned long mstart, mend;
 	int err;
 
+	if (!in_range(map, start, end))
+		return 0;
+
+	if (!blockable)
+		return -EAGAIN;
+
 	mstart = max(start, map->vma->vm_start);
 	mend   = min(end,   map->vma->vm_end);
 	pr_debug("map %d+%d (%lx %lx), range %lx %lx, mrange %lx %lx\n",
@@ -508,6 +515,8 @@ static void unmap_if_in_range(struct gntdev_grant_map *map,
 				(mstart - map->vma->vm_start) >> PAGE_SHIFT,
 				(mend - mstart) >> PAGE_SHIFT);
 	WARN_ON(err);
+
+	return 0;
 }
 
 static int mn_invl_range_start(struct mmu_notifier *mn,
@@ -519,25 +528,20 @@ static int mn_invl_range_start(struct mmu_notifier *mn,
 	struct gntdev_grant_map *map;
 	int ret = 0;
 
-	/* TODO do we really need a mutex here? */
 	if (blockable)
 		mutex_lock(&priv->lock);
 	else if (!mutex_trylock(&priv->lock))
 		return -EAGAIN;
 
 	list_for_each_entry(map, &priv->maps, next) {
-		if (in_range(map, start, end)) {
-			ret = -EAGAIN;
+		ret = unmap_if_in_range(map, start, end, blockable);
+		if (ret)
 			goto out_unlock;
-		}
-		unmap_if_in_range(map, start, end);
 	}
 	list_for_each_entry(map, &priv->freeable_maps, next) {
-		if (in_range(map, start, end)) {
-			ret = -EAGAIN;
+		ret = unmap_if_in_range(map, start, end, blockable);
+		if (ret)
 			goto out_unlock;
-		}
-		unmap_if_in_range(map, start, end);
 	}
 
 out_unlock: