diff --git a/virt/kvm/vfio.c b/virt/kvm/vfio.c
index 1dd087da6f31ae2f38c70042213ee6e5159cee10..d32f239eb47133ae3f57ab0ed52a71d5c0478b8e 100644
--- a/virt/kvm/vfio.c
+++ b/virt/kvm/vfio.c
@@ -60,6 +60,19 @@ static void kvm_vfio_group_put_external_user(struct vfio_group *vfio_group)
 	symbol_put(vfio_group_put_external_user);
 }
 
+static void kvm_vfio_group_set_kvm(struct vfio_group *group, struct kvm *kvm)
+{
+	void (*fn)(struct vfio_group *, struct kvm *);
+
+	fn = symbol_get(vfio_group_set_kvm);
+	if (!fn)
+		return;
+
+	fn(group, kvm);
+
+	symbol_put(vfio_group_set_kvm);
+}
+
 static bool kvm_vfio_group_is_coherent(struct vfio_group *vfio_group)
 {
 	long (*fn)(struct vfio_group *, unsigned long);
@@ -159,6 +172,8 @@ static int kvm_vfio_set_group(struct kvm_device *dev, long attr, u64 arg)
 
 		mutex_unlock(&kv->lock);
 
+		kvm_vfio_group_set_kvm(vfio_group, dev->kvm);
+
 		kvm_vfio_update_coherency(dev);
 
 		return 0;
@@ -196,6 +211,8 @@ static int kvm_vfio_set_group(struct kvm_device *dev, long attr, u64 arg)
 
 		mutex_unlock(&kv->lock);
 
+		kvm_vfio_group_set_kvm(vfio_group, NULL);
+
 		kvm_vfio_group_put_external_user(vfio_group);
 
 		kvm_vfio_update_coherency(dev);
@@ -240,6 +257,7 @@ static void kvm_vfio_destroy(struct kvm_device *dev)
 	struct kvm_vfio_group *kvg, *tmp;
 
 	list_for_each_entry_safe(kvg, tmp, &kv->group_list, node) {
+		kvm_vfio_group_set_kvm(kvg->vfio_group, NULL);
 		kvm_vfio_group_put_external_user(kvg->vfio_group);
 		list_del(&kvg->node);
 		kfree(kvg);