diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 372e4f626138cf644927f59d510e9de2e0f99059..8549cb111627f0b4bc64956df7c2b1e945bbfdc2 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -380,10 +380,10 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
  * first page and all consecutive pages with the same locking.
  */
 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
-				  long npage, unsigned long *pfn_base)
+				  long npage, unsigned long *pfn_base,
+				  bool lock_cap, unsigned long limit)
 {
-	unsigned long pfn = 0, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
-	bool lock_cap = capable(CAP_IPC_LOCK);
+	unsigned long pfn = 0;
 	long ret, pinned = 0, lock_acct = 0;
 	bool rsvd;
 	dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
@@ -924,13 +924,15 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
 	unsigned long vaddr = dma->vaddr;
 	size_t size = map_size;
 	long npage;
-	unsigned long pfn;
+	unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+	bool lock_cap = capable(CAP_IPC_LOCK);
 	int ret = 0;
 
 	while (size) {
 		/* Pin a contiguous chunk of memory */
 		npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
-					      size >> PAGE_SHIFT, &pfn);
+					      size >> PAGE_SHIFT, &pfn,
+					      lock_cap, limit);
 		if (npage <= 0) {
 			WARN_ON(!npage);
 			ret = (int)npage;
@@ -1040,6 +1042,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 {
 	struct vfio_domain *d;
 	struct rb_node *n;
+	unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+	bool lock_cap = capable(CAP_IPC_LOCK);
 	int ret;
 
 	/* Arbitrarily pick the first domain in the list for lookups */
@@ -1086,7 +1090,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 
 				npage = vfio_pin_pages_remote(dma, vaddr,
 							      n >> PAGE_SHIFT,
-							      &pfn);
+							      &pfn, lock_cap,
+							      limit);
 				if (npage <= 0) {
 					WARN_ON(!npage);
 					ret = (int)npage;