diff --git a/Documentation/admin-guide/cgroup-v2.rst b/Documentation/admin-guide/cgroup-v2.rst
index 86311c2907cd3afdc54e47e34fa0f1536aa65cd8..f0499884124d2dacf920b0eb4d71b18e5e9df834 100644
--- a/Documentation/admin-guide/cgroup-v2.rst
+++ b/Documentation/admin-guide/cgroup-v2.rst
@@ -1333,11 +1333,14 @@ The following nested keys are defined.
 	all the existing limitations and potential future extensions.
 
   memory.peak
-	A read-only single value file which exists on non-root
-	cgroups.
+	A read-write single value file which exists on non-root cgroups.
+
+	The max memory usage recorded for the cgroup and its descendants since
+	either the creation of the cgroup or the most recent reset for that FD.
 
-	The max memory usage recorded for the cgroup and its
-	descendants since the creation of the cgroup.
+	A write of any non-empty string to this file resets it to the
+	current memory usage for subsequent reads through the same
+	file descriptor.
 
   memory.oom.group
 	A read-write single value file which exists on non-root
@@ -1663,11 +1666,14 @@ The following nested keys are defined.
 	Healthy workloads are not expected to reach this limit.
 
   memory.swap.peak
-	A read-only single value file which exists on non-root
-	cgroups.
+	A read-write single value file which exists on non-root cgroups.
+
+	The max swap usage recorded for the cgroup and its descendants since
+	the creation of the cgroup or the most recent reset for that FD.
 
-	The max swap usage recorded for the cgroup and its
-	descendants since the creation of the cgroup.
+	A write of any non-empty string to this file resets it to the
+	current memory usage for subsequent reads through the same
+	file descriptor.
 
   memory.swap.max
 	A read-write single value file which exists on non-root
diff --git a/include/linux/cgroup-defs.h b/include/linux/cgroup-defs.h
index ae04035b6cbe58453760fb175fe2be0f87a37ff0..7fc2d0195f5603c77863f83a953503bf320b2f88 100644
--- a/include/linux/cgroup-defs.h
+++ b/include/linux/cgroup-defs.h
@@ -775,6 +775,11 @@ struct cgroup_subsys {
 
 extern struct percpu_rw_semaphore cgroup_threadgroup_rwsem;
 
+struct cgroup_of_peak {
+	unsigned long		value;
+	struct list_head	list;
+};
+
 /**
  * cgroup_threadgroup_change_begin - threadgroup exclusion for cgroups
  * @tsk: target task
diff --git a/include/linux/cgroup.h b/include/linux/cgroup.h
index c60ba0ab14627ede2a541cca11c9f8b1c3327621..3e0563753cc3e576bb678e7e4b9e0d3a84838cf4 100644
--- a/include/linux/cgroup.h
+++ b/include/linux/cgroup.h
@@ -11,6 +11,7 @@
 
 #include <linux/sched.h>
 #include <linux/nodemask.h>
+#include <linux/list.h>
 #include <linux/rculist.h>
 #include <linux/cgroupstats.h>
 #include <linux/fs.h>
@@ -854,4 +855,6 @@ static inline void cgroup_bpf_put(struct cgroup *cgrp) {}
 
 struct cgroup *task_get_cgroup1(struct task_struct *tsk, int hierarchy_id);
 
+struct cgroup_of_peak *of_peak(struct kernfs_open_file *of);
+
 #endif /* _LINUX_CGROUP_H */
diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index af7da7bd00af9795fc381cff0f81491d32b088a2..1b79760af68529a5d99a5af6f835a943685d2aeb 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -193,6 +193,11 @@ struct mem_cgroup {
 		struct page_counter memsw;	/* v1 only */
 	};
 
+	/* registered local peak watchers */
+	struct list_head memory_peaks;
+	struct list_head swap_peaks;
+	spinlock_t	 peaks_lock;
+
 	/* Range enforcement for interrupt charges */
 	struct work_struct high_work;
 
diff --git a/include/linux/page_counter.h b/include/linux/page_counter.h
index 66ebf9a73158da1729c477d1e70581109273ceb5..79dbd8bc35a72f072fc6ff954551be279ae05e33 100644
--- a/include/linux/page_counter.h
+++ b/include/linux/page_counter.h
@@ -26,6 +26,8 @@ struct page_counter {
 	atomic_long_t children_low_usage;
 
 	unsigned long watermark;
+	/* Latest cg2 reset watermark */
+	unsigned long local_watermark;
 	unsigned long failcnt;
 
 	/* Keep all the read most fields in a separete cacheline. */
@@ -84,7 +86,14 @@ int page_counter_memparse(const char *buf, const char *max,
 
 static inline void page_counter_reset_watermark(struct page_counter *counter)
 {
-	counter->watermark = page_counter_read(counter);
+	unsigned long usage = page_counter_read(counter);
+
+	/*
+	 * Update local_watermark first, so it's always <= watermark
+	 * (modulo CPU/compiler re-ordering)
+	 */
+	counter->local_watermark = usage;
+	counter->watermark = usage;
 }
 
 #ifdef CONFIG_MEMCG
diff --git a/kernel/cgroup/cgroup-internal.h b/kernel/cgroup/cgroup-internal.h
index 520b90dd97ecae50a3244286459be895185e944e..c964dd7ff967a030b5e1aed3c543e82efc20aa0f 100644
--- a/kernel/cgroup/cgroup-internal.h
+++ b/kernel/cgroup/cgroup-internal.h
@@ -81,6 +81,8 @@ struct cgroup_file_ctx {
 	struct {
 		struct cgroup_pidlist	*pidlist;
 	} procs1;
+
+	struct cgroup_of_peak peak;
 };
 
 /*
diff --git a/kernel/cgroup/cgroup.c b/kernel/cgroup/cgroup.c
index c8e4b62b436a48ae74d6242236b4f29a7f43c159..0a97cb2ef12459789a710af2f2fc48e8b3e2172a 100644
--- a/kernel/cgroup/cgroup.c
+++ b/kernel/cgroup/cgroup.c
@@ -1972,6 +1972,13 @@ static int cgroup2_parse_param(struct fs_context *fc, struct fs_parameter *param
 	return -EINVAL;
 }
 
+struct cgroup_of_peak *of_peak(struct kernfs_open_file *of)
+{
+	struct cgroup_file_ctx *ctx = of->priv;
+
+	return &ctx->peak;
+}
+
 static void apply_cgroup_root_flags(unsigned int root_flags)
 {
 	if (current->nsproxy->cgroup_ns == &init_cgroup_ns) {
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index f92578f13b2e87144530a04ccdd4959c1722fed1..8971d3473a7bfc4e100d50088f1395616a935c58 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -25,6 +25,7 @@
  * Copyright (C) 2020 Alibaba, Inc, Alex Shi
  */
 
+#include <linux/cgroup-defs.h>
 #include <linux/page_counter.h>
 #include <linux/memcontrol.h>
 #include <linux/cgroup.h>
@@ -41,6 +42,7 @@
 #include <linux/rcupdate.h>
 #include <linux/limits.h>
 #include <linux/export.h>
+#include <linux/list.h>
 #include <linux/mutex.h>
 #include <linux/rbtree.h>
 #include <linux/slab.h>
@@ -3550,6 +3552,9 @@ static struct mem_cgroup *mem_cgroup_alloc(struct mem_cgroup *parent)
 
 	INIT_WORK(&memcg->high_work, high_work_func);
 	vmpressure_init(&memcg->vmpressure);
+	INIT_LIST_HEAD(&memcg->memory_peaks);
+	INIT_LIST_HEAD(&memcg->swap_peaks);
+	spin_lock_init(&memcg->peaks_lock);
 	memcg->socket_pressure = jiffies;
 	memcg1_memcg_init(memcg);
 	memcg->kmemcg_id = -1;
@@ -3944,14 +3949,91 @@ static u64 memory_current_read(struct cgroup_subsys_state *css,
 	return (u64)page_counter_read(&memcg->memory) * PAGE_SIZE;
 }
 
-static u64 memory_peak_read(struct cgroup_subsys_state *css,
-			    struct cftype *cft)
+#define OFP_PEAK_UNSET (((-1UL)))
+
+static int peak_show(struct seq_file *sf, void *v, struct page_counter *pc)
 {
-	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+	struct cgroup_of_peak *ofp = of_peak(sf->private);
+	u64 fd_peak = READ_ONCE(ofp->value), peak;
+
+	/* User wants global or local peak? */
+	if (fd_peak == OFP_PEAK_UNSET)
+		peak = pc->watermark;
+	else
+		peak = max(fd_peak, READ_ONCE(pc->local_watermark));
+
+	seq_printf(sf, "%llu\n", peak * PAGE_SIZE);
+	return 0;
+}
+
+static int memory_peak_show(struct seq_file *sf, void *v)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(sf));
+
+	return peak_show(sf, v, &memcg->memory);
+}
+
+static int peak_open(struct kernfs_open_file *of)
+{
+	struct cgroup_of_peak *ofp = of_peak(of);
+
+	ofp->value = OFP_PEAK_UNSET;
+	return 0;
+}
+
+static void peak_release(struct kernfs_open_file *of)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+	struct cgroup_of_peak *ofp = of_peak(of);
+
+	if (ofp->value == OFP_PEAK_UNSET) {
+		/* fast path (no writes on this fd) */
+		return;
+	}
+	spin_lock(&memcg->peaks_lock);
+	list_del(&ofp->list);
+	spin_unlock(&memcg->peaks_lock);
+}
+
+static ssize_t peak_write(struct kernfs_open_file *of, char *buf, size_t nbytes,
+			  loff_t off, struct page_counter *pc,
+			  struct list_head *watchers)
+{
+	unsigned long usage;
+	struct cgroup_of_peak *peer_ctx;
+	struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+	struct cgroup_of_peak *ofp = of_peak(of);
+
+	spin_lock(&memcg->peaks_lock);
+
+	usage = page_counter_read(pc);
+	WRITE_ONCE(pc->local_watermark, usage);
+
+	list_for_each_entry(peer_ctx, watchers, list)
+		if (usage > peer_ctx->value)
+			WRITE_ONCE(peer_ctx->value, usage);
+
+	/* initial write, register watcher */
+	if (ofp->value == -1)
+		list_add(&ofp->list, watchers);
+
+	WRITE_ONCE(ofp->value, usage);
+	spin_unlock(&memcg->peaks_lock);
+
+	return nbytes;
+}
+
+static ssize_t memory_peak_write(struct kernfs_open_file *of, char *buf,
+				 size_t nbytes, loff_t off)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
 
-	return (u64)memcg->memory.watermark * PAGE_SIZE;
+	return peak_write(of, buf, nbytes, off, &memcg->memory,
+			  &memcg->memory_peaks);
 }
 
+#undef OFP_PEAK_UNSET
+
 static int memory_min_show(struct seq_file *m, void *v)
 {
 	return seq_puts_memcg_tunable(m,
@@ -4301,7 +4383,10 @@ static struct cftype memory_files[] = {
 	{
 		.name = "peak",
 		.flags = CFTYPE_NOT_ON_ROOT,
-		.read_u64 = memory_peak_read,
+		.open = peak_open,
+		.release = peak_release,
+		.seq_show = memory_peak_show,
+		.write = memory_peak_write,
 	},
 	{
 		.name = "min",
@@ -5093,12 +5178,20 @@ static u64 swap_current_read(struct cgroup_subsys_state *css,
 	return (u64)page_counter_read(&memcg->swap) * PAGE_SIZE;
 }
 
-static u64 swap_peak_read(struct cgroup_subsys_state *css,
-			  struct cftype *cft)
+static int swap_peak_show(struct seq_file *sf, void *v)
 {
-	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+	struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(sf));
+
+	return peak_show(sf, v, &memcg->swap);
+}
+
+static ssize_t swap_peak_write(struct kernfs_open_file *of, char *buf,
+			       size_t nbytes, loff_t off)
+{
+	struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
 
-	return (u64)memcg->swap.watermark * PAGE_SIZE;
+	return peak_write(of, buf, nbytes, off, &memcg->swap,
+			  &memcg->swap_peaks);
 }
 
 static int swap_high_show(struct seq_file *m, void *v)
@@ -5182,7 +5275,10 @@ static struct cftype swap_files[] = {
 	{
 		.name = "swap.peak",
 		.flags = CFTYPE_NOT_ON_ROOT,
-		.read_u64 = swap_peak_read,
+		.open = peak_open,
+		.release = peak_release,
+		.seq_show = swap_peak_show,
+		.write = swap_peak_write,
 	},
 	{
 		.name = "swap.events",
diff --git a/mm/page_counter.c b/mm/page_counter.c
index 3887bd152b756c8bddd8291c1144fdd04f9a168e..b249d15af9dd80f8a7ba8ff8521e745e1e949b70 100644
--- a/mm/page_counter.c
+++ b/mm/page_counter.c
@@ -87,9 +87,22 @@ void page_counter_charge(struct page_counter *counter, unsigned long nr_pages)
 		/*
 		 * This is indeed racy, but we can live with some
 		 * inaccuracy in the watermark.
+		 *
+		 * Notably, we have two watermarks to allow for both a globally
+		 * visible peak and one that can be reset at a smaller scope.
+		 *
+		 * Since we reset both watermarks when the global reset occurs,
+		 * we can guarantee that watermark >= local_watermark, so we
+		 * don't need to do both comparisons every time.
+		 *
+		 * On systems with branch predictors, the inner condition should
+		 * be almost free.
 		 */
-		if (new > READ_ONCE(c->watermark))
-			WRITE_ONCE(c->watermark, new);
+		if (new > READ_ONCE(c->local_watermark)) {
+			WRITE_ONCE(c->local_watermark, new);
+			if (new > READ_ONCE(c->watermark))
+				WRITE_ONCE(c->watermark, new);
+		}
 	}
 }
 
@@ -140,12 +153,12 @@ bool page_counter_try_charge(struct page_counter *counter,
 		if (protection)
 			propagate_protected_usage(c, new);
 
-		/*
-		 * Just like with failcnt, we can live with some
-		 * inaccuracy in the watermark.
-		 */
-		if (new > READ_ONCE(c->watermark))
-			WRITE_ONCE(c->watermark, new);
+		/* see comment on page_counter_charge */
+		if (new > READ_ONCE(c->local_watermark)) {
+			WRITE_ONCE(c->local_watermark, new);
+			if (new > READ_ONCE(c->watermark))
+				WRITE_ONCE(c->watermark, new);
+		}
 	}
 	return true;