diff --git a/drivers/scsi/hisi_sas/hisi_sas.h b/drivers/scsi/hisi_sas/hisi_sas.h
index 912d2342a5fe3d531d5ca6d0e2e582b5debe367f..af291947a54d86809b7a202abac074e9f164e9c7 100644
--- a/drivers/scsi/hisi_sas/hisi_sas.h
+++ b/drivers/scsi/hisi_sas/hisi_sas.h
@@ -69,6 +69,12 @@
 #define HISI_SAS_SATA_PROTOCOL_FPDMA		0x8
 #define HISI_SAS_SATA_PROTOCOL_ATAPI		0x10
 
+#define HISI_SAS_DIF_PROT_MASK (SHOST_DIF_TYPE1_PROTECTION | \
+				SHOST_DIF_TYPE2_PROTECTION | \
+				SHOST_DIF_TYPE3_PROTECTION)
+
+#define HISI_SAS_PROT_MASK (HISI_SAS_DIF_PROT_MASK)
+
 struct hisi_hba;
 
 enum {
@@ -268,6 +274,8 @@ struct hisi_hba {
 	struct pci_dev *pci_dev;
 	struct device *dev;
 
+	int prot_mask;
+
 	void __iomem *regs;
 	void __iomem *sgpio_regs;
 	struct regmap *ctrl;
diff --git a/drivers/scsi/hisi_sas/hisi_sas_v3_hw.c b/drivers/scsi/hisi_sas/hisi_sas_v3_hw.c
index 6acca892d95fe0ca86d4ec78d6071d3985d629e4..add761648fcccf9b47a1f710c36206becabef0e2 100644
--- a/drivers/scsi/hisi_sas/hisi_sas_v3_hw.c
+++ b/drivers/scsi/hisi_sas/hisi_sas_v3_hw.c
@@ -127,6 +127,8 @@
 #define PHY_CTRL			(PORT_BASE + 0x14)
 #define PHY_CTRL_RESET_OFF		0
 #define PHY_CTRL_RESET_MSK		(0x1 << PHY_CTRL_RESET_OFF)
+#define CMD_HDR_PIR_OFF			8
+#define CMD_HDR_PIR_MSK			(0x1 << CMD_HDR_PIR_OFF)
 #define SL_CFG				(PORT_BASE + 0x84)
 #define AIP_LIMIT			(PORT_BASE + 0x90)
 #define SL_CONTROL			(PORT_BASE + 0x94)
@@ -333,6 +335,16 @@
 #define ITCT_HDR_RTOLT_OFF		48
 #define ITCT_HDR_RTOLT_MSK		(0xffffULL << ITCT_HDR_RTOLT_OFF)
 
+struct hisi_sas_protect_iu_v3_hw {
+	u32 dw0;
+	u32 lbrtcv;
+	u32 lbrtgv;
+	u32 dw3;
+	u32 dw4;
+	u32 dw5;
+	u32 rsv;
+};
+
 struct hisi_sas_complete_v3_hdr {
 	__le32 dw0;
 	__le32 dw1;
@@ -372,9 +384,28 @@ struct hisi_sas_err_record_v3 {
 	((fis.command == ATA_CMD_DEV_RESET) && \
 	((fis.control & ATA_SRST) != 0)))
 
+#define T10_INSRT_EN_OFF    0
+#define T10_INSRT_EN_MSK    (1 << T10_INSRT_EN_OFF)
+#define T10_RMV_EN_OFF	    1
+#define T10_RMV_EN_MSK	    (1 << T10_RMV_EN_OFF)
+#define T10_RPLC_EN_OFF	    2
+#define T10_RPLC_EN_MSK	    (1 << T10_RPLC_EN_OFF)
+#define T10_CHK_EN_OFF	    3
+#define T10_CHK_EN_MSK	    (1 << T10_CHK_EN_OFF)
+#define INCR_LBRT_OFF	    5
+#define INCR_LBRT_MSK	    (1 << INCR_LBRT_OFF)
+#define USR_DATA_BLOCK_SZ_OFF	20
+#define USR_DATA_BLOCK_SZ_MSK	(0x3 << USR_DATA_BLOCK_SZ_OFF)
+#define T10_CHK_MSK_OFF	    16
+
 static bool hisi_sas_intr_conv;
 MODULE_PARM_DESC(intr_conv, "interrupt converge enable (0-1)");
 
+/* permit overriding the host protection capabilities mask (EEDP/T10 PI) */
+static int prot_mask;
+module_param(prot_mask, int, 0);
+MODULE_PARM_DESC(prot_mask, " host protection capabilities mask, def=0x0 ");
+
 static u32 hisi_sas_read32(struct hisi_hba *hisi_hba, u32 off)
 {
 	void __iomem *regs = hisi_hba->regs + off;
@@ -941,6 +972,58 @@ static void prep_prd_sge_v3_hw(struct hisi_hba *hisi_hba,
 	hdr->sg_len = cpu_to_le32(n_elem << CMD_HDR_DATA_SGL_LEN_OFF);
 }
 
+static u32 get_prot_chk_msk_v3_hw(struct scsi_cmnd *scsi_cmnd)
+{
+	unsigned char prot_flags = scsi_cmnd->prot_flags;
+
+	if (prot_flags & SCSI_PROT_TRANSFER_PI) {
+		if (prot_flags & SCSI_PROT_REF_CHECK)
+			return 0xc << 16;
+		return 0xfc << 16;
+	}
+	return 0;
+}
+
+static void fill_prot_v3_hw(struct scsi_cmnd *scsi_cmnd,
+			    struct hisi_sas_protect_iu_v3_hw *prot)
+{
+	unsigned char prot_op = scsi_get_prot_op(scsi_cmnd);
+	unsigned int interval = scsi_prot_interval(scsi_cmnd);
+	u32 lbrt_chk_val = t10_pi_ref_tag(scsi_cmnd->request);
+
+	switch (prot_op) {
+	case SCSI_PROT_READ_STRIP:
+		prot->dw0 |= (T10_RMV_EN_MSK | T10_CHK_EN_MSK);
+		prot->lbrtcv = lbrt_chk_val;
+		prot->dw4 |= get_prot_chk_msk_v3_hw(scsi_cmnd);
+		break;
+	case SCSI_PROT_WRITE_INSERT:
+		prot->dw0 |= T10_INSRT_EN_MSK;
+		prot->lbrtgv = lbrt_chk_val;
+		break;
+	default:
+		WARN(1, "prot_op(0x%x) is not valid\n", prot_op);
+		break;
+	}
+
+	switch (interval) {
+	case 512:
+		break;
+	case 4096:
+		prot->dw0 |= (0x1 << USR_DATA_BLOCK_SZ_OFF);
+		break;
+	case 520:
+		prot->dw0 |= (0x2 << USR_DATA_BLOCK_SZ_OFF);
+		break;
+	default:
+		WARN(1, "protection interval (0x%x) invalid\n",
+		     interval);
+		break;
+	}
+
+	prot->dw0 |= INCR_LBRT_MSK;
+}
+
 static void prep_ssp_v3_hw(struct hisi_hba *hisi_hba,
 			  struct hisi_sas_slot *slot)
 {
@@ -952,9 +1035,10 @@ static void prep_ssp_v3_hw(struct hisi_hba *hisi_hba,
 	struct sas_ssp_task *ssp_task = &task->ssp_task;
 	struct scsi_cmnd *scsi_cmnd = ssp_task->cmd;
 	struct hisi_sas_tmf_task *tmf = slot->tmf;
+	unsigned char prot_op = scsi_get_prot_op(scsi_cmnd);
 	int has_data = 0, priority = !!tmf;
 	u8 *buf_cmd;
-	u32 dw1 = 0, dw2 = 0;
+	u32 dw1 = 0, dw2 = 0, len = 0;
 
 	hdr->dw0 = cpu_to_le32((1 << CMD_HDR_RESP_REPORT_OFF) |
 			       (2 << CMD_HDR_TLR_CTRL_OFF) |
@@ -984,7 +1068,6 @@ static void prep_ssp_v3_hw(struct hisi_hba *hisi_hba,
 
 	/* map itct entry */
 	dw1 |= sas_dev->device_id << CMD_HDR_DEV_ID_OFF;
-	hdr->dw1 = cpu_to_le32(dw1);
 
 	dw2 = (((sizeof(struct ssp_command_iu) + sizeof(struct ssp_frame_hdr)
 	      + 3) / 4) << CMD_HDR_CFL_OFF) |
@@ -997,7 +1080,6 @@ static void prep_ssp_v3_hw(struct hisi_hba *hisi_hba,
 		prep_prd_sge_v3_hw(hisi_hba, slot, hdr, task->scatter,
 					slot->n_elem);
 
-	hdr->data_transfer_len = cpu_to_le32(task->total_xfer_len);
 	hdr->cmd_table_addr = cpu_to_le64(hisi_sas_cmd_hdr_addr_dma(slot));
 	hdr->sts_buffer_addr = cpu_to_le64(hisi_sas_status_buf_addr_dma(slot));
 
@@ -1022,6 +1104,38 @@ static void prep_ssp_v3_hw(struct hisi_hba *hisi_hba,
 			break;
 		}
 	}
+
+	if (has_data && (prot_op != SCSI_PROT_NORMAL)) {
+		struct hisi_sas_protect_iu_v3_hw prot;
+		u8 *buf_cmd_prot;
+
+		hdr->dw7 |= cpu_to_le32(1 << CMD_HDR_ADDR_MODE_SEL_OFF);
+		dw1 |= CMD_HDR_PIR_MSK;
+		buf_cmd_prot = hisi_sas_cmd_hdr_addr_mem(slot) +
+			       sizeof(struct ssp_frame_hdr) +
+			       sizeof(struct ssp_command_iu);
+
+		memset(&prot, 0, sizeof(struct hisi_sas_protect_iu_v3_hw));
+		fill_prot_v3_hw(scsi_cmnd, &prot);
+		memcpy(buf_cmd_prot, &prot,
+		       sizeof(struct hisi_sas_protect_iu_v3_hw));
+
+		/*
+		 * For READ, we need length of info read to memory, while for
+		 * WRITE we need length of data written to the disk.
+		 */
+		if (prot_op == SCSI_PROT_WRITE_INSERT) {
+			unsigned int interval = scsi_prot_interval(scsi_cmnd);
+			unsigned int ilog2_interval = ilog2(interval);
+
+			len = (task->total_xfer_len >> ilog2_interval) * 8;
+		}
+
+	}
+
+	hdr->dw1 = cpu_to_le32(dw1);
+
+	hdr->data_transfer_len = cpu_to_le32(task->total_xfer_len + len);
 }
 
 static void prep_smp_v3_hw(struct hisi_hba *hisi_hba,
@@ -2291,6 +2405,12 @@ hisi_sas_shost_alloc_pci(struct pci_dev *pdev)
 	hisi_hba->shost = shost;
 	SHOST_TO_SAS_HA(shost) = &hisi_hba->sha;
 
+	if (prot_mask & ~HISI_SAS_PROT_MASK)
+		dev_err(dev, "unsupported protection mask 0x%x, using default (0x0)\n",
+			prot_mask);
+	else
+		hisi_hba->prot_mask = prot_mask;
+
 	timer_setup(&hisi_hba->timer, NULL, 0);
 
 	if (hisi_sas_get_fw_info(hisi_hba) < 0)
@@ -2401,6 +2521,12 @@ hisi_sas_v3_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 	if (rc)
 		goto err_out_register_ha;
 
+	if (hisi_hba->prot_mask) {
+		dev_info(dev, "Registering for DIF/DIX prot_mask=0x%x\n",
+			 prot_mask);
+		scsi_host_set_prot(hisi_hba->shost, prot_mask);
+	}
+
 	scsi_scan_host(shost);
 
 	return 0;