diff --git a/drivers/mfd/qcom-spmi-pmic.c b/drivers/mfd/qcom-spmi-pmic.c
index 47738f7e492ca989c291833c85fa5e91eac0b274..8e449cff5cec40d74edc9c59861552bb1979e2eb 100644
--- a/drivers/mfd/qcom-spmi-pmic.c
+++ b/drivers/mfd/qcom-spmi-pmic.c
@@ -30,6 +30,8 @@ struct qcom_spmi_dev {
 	struct qcom_spmi_pmic pmic;
 };
 
+static DEFINE_MUTEX(pmic_spmi_revid_lock);
+
 #define N_USIDS(n)		((void *)n)
 
 static const struct of_device_id pmic_spmi_id_table[] = {
@@ -76,24 +78,21 @@ static const struct of_device_id pmic_spmi_id_table[] = {
  *
  * This only supports PMICs with 1 or 2 USIDs.
  */
-static struct spmi_device *qcom_pmic_get_base_usid(struct device *dev)
+static struct spmi_device *qcom_pmic_get_base_usid(struct spmi_device *sdev, struct qcom_spmi_dev *ctx)
 {
-	struct spmi_device *sdev;
-	struct qcom_spmi_dev *ctx;
 	struct device_node *spmi_bus;
 	struct device_node *child;
 	int function_parent_usid, ret;
 	u32 pmic_addr;
 
-	sdev = to_spmi_device(dev);
-	ctx = dev_get_drvdata(&sdev->dev);
-
 	/*
 	 * Quick return if the function device is already in the base
 	 * USID. This will always be hit for PMICs with only 1 USID.
 	 */
-	if (sdev->usid % ctx->num_usids == 0)
+	if (sdev->usid % ctx->num_usids == 0) {
+		get_device(&sdev->dev);
 		return sdev;
+	}
 
 	function_parent_usid = sdev->usid;
 
@@ -118,10 +117,8 @@ static struct spmi_device *qcom_pmic_get_base_usid(struct device *dev)
 			sdev = spmi_device_from_of(child);
 			if (!sdev) {
 				/*
-				 * If the base USID for this PMIC hasn't probed yet
-				 * but the secondary USID has, then we need to defer
-				 * the function driver so that it will attempt to
-				 * probe again when the base USID is ready.
+				 * If the base USID for this PMIC hasn't been
+				 * registered yet then we need to defer.
 				 */
 				sdev = ERR_PTR(-EPROBE_DEFER);
 			}
@@ -135,6 +132,35 @@ static struct spmi_device *qcom_pmic_get_base_usid(struct device *dev)
 	return sdev;
 }
 
+static int pmic_spmi_get_base_revid(struct spmi_device *sdev, struct qcom_spmi_dev *ctx)
+{
+	struct qcom_spmi_dev *base_ctx;
+	struct spmi_device *base;
+	int ret = 0;
+
+	base = qcom_pmic_get_base_usid(sdev, ctx);
+	if (IS_ERR(base))
+		return PTR_ERR(base);
+
+	/*
+	 * Copy revid info from base device if it has probed and is still
+	 * bound to its driver.
+	 */
+	mutex_lock(&pmic_spmi_revid_lock);
+	base_ctx = spmi_device_get_drvdata(base);
+	if (!base_ctx) {
+		ret = -EPROBE_DEFER;
+		goto out_unlock;
+	}
+	memcpy(&ctx->pmic, &base_ctx->pmic, sizeof(ctx->pmic));
+out_unlock:
+	mutex_unlock(&pmic_spmi_revid_lock);
+
+	put_device(&base->dev);
+
+	return ret;
+}
+
 static int pmic_spmi_load_revid(struct regmap *map, struct device *dev,
 				 struct qcom_spmi_pmic *pmic)
 {
@@ -210,11 +236,7 @@ const struct qcom_spmi_pmic *qcom_pmic_get(struct device *dev)
 	if (!of_match_device(pmic_spmi_id_table, dev->parent))
 		return ERR_PTR(-EINVAL);
 
-	sdev = qcom_pmic_get_base_usid(dev->parent);
-
-	if (IS_ERR(sdev))
-		return ERR_CAST(sdev);
-
+	sdev = to_spmi_device(dev->parent);
 	spmi = dev_get_drvdata(&sdev->dev);
 
 	return &spmi->pmic;
@@ -249,16 +271,31 @@ static int pmic_spmi_probe(struct spmi_device *sdev)
 		ret = pmic_spmi_load_revid(regmap, &sdev->dev, &ctx->pmic);
 		if (ret < 0)
 			return ret;
+	} else {
+		ret = pmic_spmi_get_base_revid(sdev, ctx);
+		if (ret)
+			return ret;
 	}
+
+	mutex_lock(&pmic_spmi_revid_lock);
 	spmi_device_set_drvdata(sdev, ctx);
+	mutex_unlock(&pmic_spmi_revid_lock);
 
 	return devm_of_platform_populate(&sdev->dev);
 }
 
+static void pmic_spmi_remove(struct spmi_device *sdev)
+{
+	mutex_lock(&pmic_spmi_revid_lock);
+	spmi_device_set_drvdata(sdev, NULL);
+	mutex_unlock(&pmic_spmi_revid_lock);
+}
+
 MODULE_DEVICE_TABLE(of, pmic_spmi_id_table);
 
 static struct spmi_driver pmic_spmi_driver = {
 	.probe = pmic_spmi_probe,
+	.remove = pmic_spmi_remove,
 	.driver = {
 		.name = "pmic-spmi",
 		.of_match_table = pmic_spmi_id_table,