diff --git a/drivers/iommu/amd_iommu_types.h b/drivers/iommu/amd_iommu_types.h
index 294a409e283b7ae4b52c59350756983711bba9d1..d6b873b57054b44d2f0227630798cda83a637256 100644
--- a/drivers/iommu/amd_iommu_types.h
+++ b/drivers/iommu/amd_iommu_types.h
@@ -574,7 +574,9 @@ struct amd_iommu {
 
 static inline struct amd_iommu *dev_to_amd_iommu(struct device *dev)
 {
-	return container_of(dev, struct amd_iommu, iommu.dev);
+	struct iommu_device *iommu = dev_to_iommu_device(dev);
+
+	return container_of(iommu, struct amd_iommu, iommu);
 }
 
 #define ACPIHID_UID_LEN 256
diff --git a/drivers/iommu/intel-iommu.c b/drivers/iommu/intel-iommu.c
index 687f18f65cea58d2a5f22725a2c36c78621dd3cd..3e8636f1220ea9bdbcf668bac193e21cee408063 100644
--- a/drivers/iommu/intel-iommu.c
+++ b/drivers/iommu/intel-iommu.c
@@ -4736,7 +4736,9 @@ static void intel_disable_iommus(void)
 
 static inline struct intel_iommu *dev_to_intel_iommu(struct device *dev)
 {
-	return container_of(dev, struct intel_iommu, iommu.dev);
+	struct iommu_device *iommu_dev = dev_to_iommu_device(dev);
+
+	return container_of(iommu_dev, struct intel_iommu, iommu);
 }
 
 static ssize_t intel_iommu_show_version(struct device *dev,
diff --git a/drivers/iommu/iommu-sysfs.c b/drivers/iommu/iommu-sysfs.c
index c58351ed61c14309c7a72346bd56f659a7039637..36d1a7ce7fc4cc922cade9642b5d64c05e5d8e19 100644
--- a/drivers/iommu/iommu-sysfs.c
+++ b/drivers/iommu/iommu-sysfs.c
@@ -62,32 +62,40 @@ int iommu_device_sysfs_add(struct iommu_device *iommu,
 	va_list vargs;
 	int ret;
 
-	device_initialize(&iommu->dev);
+	iommu->dev = kzalloc(sizeof(*iommu->dev), GFP_KERNEL);
+	if (!iommu->dev)
+		return -ENOMEM;
 
-	iommu->dev.class = &iommu_class;
-	iommu->dev.parent = parent;
-	iommu->dev.groups = groups;
+	device_initialize(iommu->dev);
+
+	iommu->dev->class = &iommu_class;
+	iommu->dev->parent = parent;
+	iommu->dev->groups = groups;
 
 	va_start(vargs, fmt);
-	ret = kobject_set_name_vargs(&iommu->dev.kobj, fmt, vargs);
+	ret = kobject_set_name_vargs(&iommu->dev->kobj, fmt, vargs);
 	va_end(vargs);
 	if (ret)
 		goto error;
 
-	ret = device_add(&iommu->dev);
+	ret = device_add(iommu->dev);
 	if (ret)
 		goto error;
 
+	dev_set_drvdata(iommu->dev, iommu);
+
 	return 0;
 
 error:
-	put_device(&iommu->dev);
+	put_device(iommu->dev);
 	return ret;
 }
 
 void iommu_device_sysfs_remove(struct iommu_device *iommu)
 {
-	device_unregister(&iommu->dev);
+	dev_set_drvdata(iommu->dev, NULL);
+	device_unregister(iommu->dev);
+	iommu->dev = NULL;
 }
 /*
  * IOMMU drivers can indicate a device is managed by a given IOMMU using
@@ -102,14 +110,14 @@ int iommu_device_link(struct iommu_device *iommu, struct device *link)
 	if (!iommu || IS_ERR(iommu))
 		return -ENODEV;
 
-	ret = sysfs_add_link_to_group(&iommu->dev.kobj, "devices",
+	ret = sysfs_add_link_to_group(&iommu->dev->kobj, "devices",
 				      &link->kobj, dev_name(link));
 	if (ret)
 		return ret;
 
-	ret = sysfs_create_link_nowarn(&link->kobj, &iommu->dev.kobj, "iommu");
+	ret = sysfs_create_link_nowarn(&link->kobj, &iommu->dev->kobj, "iommu");
 	if (ret)
-		sysfs_remove_link_from_group(&iommu->dev.kobj, "devices",
+		sysfs_remove_link_from_group(&iommu->dev->kobj, "devices",
 					     dev_name(link));
 
 	return ret;
@@ -121,5 +129,5 @@ void iommu_device_unlink(struct iommu_device *iommu, struct device *link)
 		return;
 
 	sysfs_remove_link(&link->kobj, "iommu");
-	sysfs_remove_link_from_group(&iommu->dev.kobj, "devices", dev_name(link));
+	sysfs_remove_link_from_group(&iommu->dev->kobj, "devices", dev_name(link));
 }
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 2cb54adc4a334aa3f3a1732c35eaf9749d64b9e8..176f7569d87408c1e57bf09846809bfef5b0b7c8 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -240,7 +240,7 @@ struct iommu_device {
 	struct list_head list;
 	const struct iommu_ops *ops;
 	struct fwnode_handle *fwnode;
-	struct device dev;
+	struct device *dev;
 };
 
 int  iommu_device_register(struct iommu_device *iommu);
@@ -265,6 +265,11 @@ static inline void iommu_device_set_fwnode(struct iommu_device *iommu,
 	iommu->fwnode = fwnode;
 }
 
+static inline struct iommu_device *dev_to_iommu_device(struct device *dev)
+{
+	return (struct iommu_device *)dev_get_drvdata(dev);
+}
+
 #define IOMMU_GROUP_NOTIFY_ADD_DEVICE		1 /* Device added */
 #define IOMMU_GROUP_NOTIFY_DEL_DEVICE		2 /* Pre Device removed */
 #define IOMMU_GROUP_NOTIFY_BIND_DRIVER		3 /* Pre Driver bind */
@@ -589,6 +594,11 @@ static inline void iommu_device_set_fwnode(struct iommu_device *iommu,
 {
 }
 
+static inline struct iommu_device *dev_to_iommu_device(struct device *dev)
+{
+	return NULL;
+}
+
 static inline void iommu_device_unregister(struct iommu_device *iommu)
 {
 }