diff --git a/tools/testing/selftests/vm/hugepage-mremap.c b/tools/testing/selftests/vm/hugepage-mremap.c
index 257df94697a52a8c02ea7a706ae14602b17dee08..2a7c33631a298874e2e8991a72ac3998d9317c83 100644
--- a/tools/testing/selftests/vm/hugepage-mremap.c
+++ b/tools/testing/selftests/vm/hugepage-mremap.c
@@ -4,7 +4,11 @@
  *
  * Example of remapping huge page memory in a user application using the
  * mremap system call.  Code assumes a hugetlbfs filesystem is mounted
- * at './huge'.  The code will use 10MB worth of huge pages.
+ * at './huge'.  The amount of memory used by this test is decided by a command
+ * line argument in MBs. If missing, the default amount is 10MB.
+ *
+ * To make sure the test triggers pmd sharing and goes through the 'unshare'
+ * path in the mremap code use 1GB (1024) or more.
  */
 
 #define _GNU_SOURCE
@@ -18,8 +22,10 @@
 #include <linux/userfaultfd.h>
 #include <sys/ioctl.h>
 
-#define LENGTH (1UL * 1024 * 1024 * 1024)
+#define DEFAULT_LENGTH_MB 10UL
+#define MB_TO_BYTES(x) (x * 1024 * 1024)
 
+#define FILE_NAME "huge/hugepagefile"
 #define PROTECTION (PROT_READ | PROT_WRITE | PROT_EXEC)
 #define FLAGS (MAP_SHARED | MAP_ANONYMOUS)
 
@@ -28,20 +34,20 @@ static void check_bytes(char *addr)
 	printf("First hex is %x\n", *((unsigned int *)addr));
 }
 
-static void write_bytes(char *addr)
+static void write_bytes(char *addr, size_t len)
 {
 	unsigned long i;
 
-	for (i = 0; i < LENGTH; i++)
+	for (i = 0; i < len; i++)
 		*(addr + i) = (char)i;
 }
 
-static int read_bytes(char *addr)
+static int read_bytes(char *addr, size_t len)
 {
 	unsigned long i;
 
 	check_bytes(addr);
-	for (i = 0; i < LENGTH; i++)
+	for (i = 0; i < len; i++)
 		if (*(addr + i) != (char)i) {
 			printf("Mismatch at %lu\n", i);
 			return 1;
@@ -99,11 +105,19 @@ static void register_region_with_uffd(char *addr, size_t len)
 	}
 }
 
-int main(void)
+int main(int argc, char *argv[])
 {
+	/* Read memory length as the first arg if valid, otherwise fallback to
+	 * the default length. Any additional args are ignored.
+	 */
+	size_t length = argc > 1 ? (size_t)atoi(argv[1]) : 0UL;
+
+	length = length > 0 ? length : DEFAULT_LENGTH_MB;
+	length = MB_TO_BYTES(length);
+
 	int ret = 0;
 
-	int fd = open("/huge/test", O_CREAT | O_RDWR, 0755);
+	int fd = open(FILE_NAME, O_CREAT | O_RDWR, 0755);
 
 	if (fd < 0) {
 		perror("Open failed");
@@ -112,7 +126,7 @@ int main(void)
 
 	/* mmap to a PUD aligned address to hopefully trigger pmd sharing. */
 	unsigned long suggested_addr = 0x7eaa40000000;
-	void *haddr = mmap((void *)suggested_addr, LENGTH, PROTECTION,
+	void *haddr = mmap((void *)suggested_addr, length, PROTECTION,
 			   MAP_HUGETLB | MAP_SHARED | MAP_POPULATE, fd, 0);
 	printf("Map haddr: Returned address is %p\n", haddr);
 	if (haddr == MAP_FAILED) {
@@ -122,7 +136,7 @@ int main(void)
 
 	/* mmap again to a dummy address to hopefully trigger pmd sharing. */
 	suggested_addr = 0x7daa40000000;
-	void *daddr = mmap((void *)suggested_addr, LENGTH, PROTECTION,
+	void *daddr = mmap((void *)suggested_addr, length, PROTECTION,
 			   MAP_HUGETLB | MAP_SHARED | MAP_POPULATE, fd, 0);
 	printf("Map daddr: Returned address is %p\n", daddr);
 	if (daddr == MAP_FAILED) {
@@ -132,16 +146,16 @@ int main(void)
 
 	suggested_addr = 0x7faa40000000;
 	void *vaddr =
-		mmap((void *)suggested_addr, LENGTH, PROTECTION, FLAGS, -1, 0);
+		mmap((void *)suggested_addr, length, PROTECTION, FLAGS, -1, 0);
 	printf("Map vaddr: Returned address is %p\n", vaddr);
 	if (vaddr == MAP_FAILED) {
 		perror("mmap2");
 		exit(1);
 	}
 
-	register_region_with_uffd(haddr, LENGTH);
+	register_region_with_uffd(haddr, length);
 
-	void *addr = mremap(haddr, LENGTH, LENGTH,
+	void *addr = mremap(haddr, length, length,
 			    MREMAP_MAYMOVE | MREMAP_FIXED, vaddr);
 	if (addr == MAP_FAILED) {
 		perror("mremap");
@@ -150,10 +164,10 @@ int main(void)
 
 	printf("Mremap: Returned address is %p\n", addr);
 	check_bytes(addr);
-	write_bytes(addr);
-	ret = read_bytes(addr);
+	write_bytes(addr, length);
+	ret = read_bytes(addr, length);
 
-	munmap(addr, LENGTH);
+	munmap(addr, length);
 
 	return ret;
 }
diff --git a/tools/testing/selftests/vm/run_vmtests.sh b/tools/testing/selftests/vm/run_vmtests.sh
index a24d30af30949a1a9dbc21a7efe4f6a6d0c01e38..75d4017413944660dcc213afe7a69d71da4d7e5d 100755
--- a/tools/testing/selftests/vm/run_vmtests.sh
+++ b/tools/testing/selftests/vm/run_vmtests.sh
@@ -111,7 +111,7 @@ fi
 echo "-----------------------"
 echo "running hugepage-mremap"
 echo "-----------------------"
-./hugepage-mremap
+./hugepage-mremap 256
 if [ $? -ne 0 ]; then
 	echo "[FAIL]"
 	exitcode=1