diff --git a/drivers/vfio/vfio_main.c b/drivers/vfio/vfio_main.c
index 9f830d0a25b7a9a3c2fccb71652e52a2ab48fc27..911ee1abdff0749589b3b1e5cf4afcabb30fc2c3 100644
--- a/drivers/vfio/vfio_main.c
+++ b/drivers/vfio/vfio_main.c
@@ -1552,8 +1552,9 @@ static const struct file_operations vfio_device_fops = {
  * vfio_file_iommu_group - Return the struct iommu_group for the vfio group file
  * @file: VFIO group file
  *
- * The returned iommu_group is valid as long as a ref is held on the file.
- * This function is deprecated, only the SPAPR path in kvm should call it.
+ * The returned iommu_group is valid as long as a ref is held on the file. This
+ * returns a reference on the group. This function is deprecated, only the SPAPR
+ * path in kvm should call it.
  */
 struct iommu_group *vfio_file_iommu_group(struct file *file)
 {
@@ -1564,6 +1565,7 @@ struct iommu_group *vfio_file_iommu_group(struct file *file)
 
 	if (!vfio_file_is_group(file))
 		return NULL;
+	iommu_group_ref_get(group->iommu_group);
 	return group->iommu_group;
 }
 EXPORT_SYMBOL_GPL(vfio_file_iommu_group);
diff --git a/virt/kvm/vfio.c b/virt/kvm/vfio.c
index 54aec3b0559c706996d664b381de3be52df143eb..495ceabffe88bb7270083e404d6576c7258634a0 100644
--- a/virt/kvm/vfio.c
+++ b/virt/kvm/vfio.c
@@ -24,6 +24,9 @@
 struct kvm_vfio_group {
 	struct list_head node;
 	struct file *file;
+#ifdef CONFIG_SPAPR_TCE_IOMMU
+	struct iommu_group *iommu_group;
+#endif
 };
 
 struct kvm_vfio {
@@ -97,12 +100,12 @@ static struct iommu_group *kvm_vfio_file_iommu_group(struct file *file)
 static void kvm_spapr_tce_release_vfio_group(struct kvm *kvm,
 					     struct kvm_vfio_group *kvg)
 {
-	struct iommu_group *grp = kvm_vfio_file_iommu_group(kvg->file);
-
-	if (WARN_ON_ONCE(!grp))
+	if (WARN_ON_ONCE(!kvg->iommu_group))
 		return;
 
-	kvm_spapr_tce_release_iommu_group(kvm, grp);
+	kvm_spapr_tce_release_iommu_group(kvm, kvg->iommu_group);
+	iommu_group_put(kvg->iommu_group);
+	kvg->iommu_group = NULL;
 }
 #endif
 
@@ -252,19 +255,19 @@ static int kvm_vfio_group_set_spapr_tce(struct kvm_device *dev,
 	mutex_lock(&kv->lock);
 
 	list_for_each_entry(kvg, &kv->group_list, node) {
-		struct iommu_group *grp;
-
 		if (kvg->file != f.file)
 			continue;
 
-		grp = kvm_vfio_file_iommu_group(kvg->file);
-		if (WARN_ON_ONCE(!grp)) {
-			ret = -EIO;
-			goto err_fdput;
+		if (!kvg->iommu_group) {
+			kvg->iommu_group = kvm_vfio_file_iommu_group(kvg->file);
+			if (WARN_ON_ONCE(!kvg->iommu_group)) {
+				ret = -EIO;
+				goto err_fdput;
+			}
 		}
 
 		ret = kvm_spapr_tce_attach_iommu_group(dev->kvm, param.tablefd,
-						       grp);
+						       kvg->iommu_group);
 		break;
 	}