diff --git a/arch/s390/include/asm/set_memory.h b/arch/s390/include/asm/set_memory.h
index a22a5a81811cf289449512f250dc446787831fcf..950d87bd997a7b32a0d055038ad567b5a0418be5 100644
--- a/arch/s390/include/asm/set_memory.h
+++ b/arch/s390/include/asm/set_memory.h
@@ -10,6 +10,7 @@ extern struct mutex cpa_mutex;
 #define SET_MEMORY_RW	2UL
 #define SET_MEMORY_NX	4UL
 #define SET_MEMORY_X	8UL
+#define SET_MEMORY_4K  16UL
 
 int __set_memory(unsigned long addr, int numpages, unsigned long flags);
 
@@ -33,4 +34,9 @@ static inline int set_memory_x(unsigned long addr, int numpages)
 	return __set_memory(addr, numpages, SET_MEMORY_X);
 }
 
+static inline int set_memory_4k(unsigned long addr, int numpages)
+{
+	return __set_memory(addr, numpages, SET_MEMORY_4K);
+}
+
 #endif
diff --git a/arch/s390/mm/pageattr.c b/arch/s390/mm/pageattr.c
index ed8e5b3575d59dc709cfa091d0713ff5519d491f..2ad95e5c79c1bf34781eec8398e11c921a605729 100644
--- a/arch/s390/mm/pageattr.c
+++ b/arch/s390/mm/pageattr.c
@@ -85,6 +85,8 @@ static int walk_pte_level(pmd_t *pmdp, unsigned long addr, unsigned long end,
 {
 	pte_t *ptep, new;
 
+	if ((flags & SET_MEMORY_4K) == SET_MEMORY_4K)
+		return 0;
 	ptep = pte_offset_kernel(pmdp, addr);
 	do {
 		new = *ptep;
@@ -155,6 +157,7 @@ static int walk_pmd_level(pud_t *pudp, unsigned long addr, unsigned long end,
 			  unsigned long flags)
 {
 	unsigned long next;
+	int need_split;
 	pmd_t *pmdp;
 	int rc = 0;
 
@@ -164,7 +167,10 @@ static int walk_pmd_level(pud_t *pudp, unsigned long addr, unsigned long end,
 			return -EINVAL;
 		next = pmd_addr_end(addr, end);
 		if (pmd_large(*pmdp)) {
-			if (addr & ~PMD_MASK || addr + PMD_SIZE > next) {
+			need_split  = !!(flags & SET_MEMORY_4K);
+			need_split |= !!(addr & ~PMD_MASK);
+			need_split |= !!(addr + PMD_SIZE > next);
+			if (need_split) {
 				rc = split_pmd_page(pmdp, addr);
 				if (rc)
 					return rc;
@@ -232,6 +238,7 @@ static int walk_pud_level(p4d_t *p4d, unsigned long addr, unsigned long end,
 			  unsigned long flags)
 {
 	unsigned long next;
+	int need_split;
 	pud_t *pudp;
 	int rc = 0;
 
@@ -241,7 +248,10 @@ static int walk_pud_level(p4d_t *p4d, unsigned long addr, unsigned long end,
 			return -EINVAL;
 		next = pud_addr_end(addr, end);
 		if (pud_large(*pudp)) {
-			if (addr & ~PUD_MASK || addr + PUD_SIZE > next) {
+			need_split  = !!(flags & SET_MEMORY_4K);
+			need_split |= !!(addr & ~PUD_MASK);
+			need_split |= !!(addr + PUD_SIZE > next);
+			if (need_split) {
 				rc = split_pud_page(pudp, addr);
 				if (rc)
 					break;