diff --git a/drivers/pci/msi.c b/drivers/pci/msi.c
index 8c61304cbb37e8b9f32343c0fb93cd73bd276d27..ccb1974464fb11011b8cbbbb788ec4dc3bc2737d 100644
--- a/drivers/pci/msi.c
+++ b/drivers/pci/msi.c
@@ -70,12 +70,10 @@ arch_teardown_msi_irqs(struct pci_dev *dev)
 	}
 }
 
-static void msi_set_enable(struct pci_dev *dev, int enable)
+static void __msi_set_enable(struct pci_dev *dev, int pos, int enable)
 {
-	int pos;
 	u16 control;
 
-	pos = pci_find_capability(dev, PCI_CAP_ID_MSI);
 	if (pos) {
 		pci_read_config_word(dev, pos + PCI_MSI_FLAGS, &control);
 		control &= ~PCI_MSI_FLAGS_ENABLE;
@@ -85,6 +83,11 @@ static void msi_set_enable(struct pci_dev *dev, int enable)
 	}
 }
 
+static void msi_set_enable(struct pci_dev *dev, int enable)
+{
+	__msi_set_enable(dev, pci_find_capability(dev, PCI_CAP_ID_MSI), enable);
+}
+
 static void msix_set_enable(struct pci_dev *dev, int enable)
 {
 	int pos;
@@ -141,7 +144,8 @@ static void msi_set_mask_bits(unsigned int irq, u32 mask, u32 flag)
 			mask_bits |= flag & mask;
 			pci_write_config_dword(entry->dev, pos, mask_bits);
 		} else {
-			msi_set_enable(entry->dev, !flag);
+			__msi_set_enable(entry->dev, entry->msi_attrib.pos,
+					 !flag);
 		}
 		break;
 	case PCI_CAP_ID_MSIX: