diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 19df5d857411a7cc567fa7571985bfdcd65c28b0..6b75640ef5ab553a4e0292940bacc466db1a6437 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -139,48 +139,23 @@ static inline bool mem_cgroup_disabled(void)
 	return false;
 }
 
-void __mem_cgroup_begin_update_page_stat(struct page *page, bool *locked,
-					 unsigned long *flags);
-
-extern atomic_t memcg_moving;
-
-static inline void mem_cgroup_begin_update_page_stat(struct page *page,
-					bool *locked, unsigned long *flags)
-{
-	if (mem_cgroup_disabled())
-		return;
-	rcu_read_lock();
-	*locked = false;
-	if (atomic_read(&memcg_moving))
-		__mem_cgroup_begin_update_page_stat(page, locked, flags);
-}
-
-void __mem_cgroup_end_update_page_stat(struct page *page,
-				unsigned long *flags);
-static inline void mem_cgroup_end_update_page_stat(struct page *page,
-					bool *locked, unsigned long *flags)
-{
-	if (mem_cgroup_disabled())
-		return;
-	if (*locked)
-		__mem_cgroup_end_update_page_stat(page, flags);
-	rcu_read_unlock();
-}
-
-void mem_cgroup_update_page_stat(struct page *page,
-				 enum mem_cgroup_stat_index idx,
-				 int val);
-
-static inline void mem_cgroup_inc_page_stat(struct page *page,
+struct mem_cgroup *mem_cgroup_begin_page_stat(struct page *page, bool *locked,
+					      unsigned long *flags);
+void mem_cgroup_end_page_stat(struct mem_cgroup *memcg, bool locked,
+			      unsigned long flags);
+void mem_cgroup_update_page_stat(struct mem_cgroup *memcg,
+				 enum mem_cgroup_stat_index idx, int val);
+
+static inline void mem_cgroup_inc_page_stat(struct mem_cgroup *memcg,
 					    enum mem_cgroup_stat_index idx)
 {
-	mem_cgroup_update_page_stat(page, idx, 1);
+	mem_cgroup_update_page_stat(memcg, idx, 1);
 }
 
-static inline void mem_cgroup_dec_page_stat(struct page *page,
+static inline void mem_cgroup_dec_page_stat(struct mem_cgroup *memcg,
 					    enum mem_cgroup_stat_index idx)
 {
-	mem_cgroup_update_page_stat(page, idx, -1);
+	mem_cgroup_update_page_stat(memcg, idx, -1);
 }
 
 unsigned long mem_cgroup_soft_limit_reclaim(struct zone *zone, int order,
@@ -315,13 +290,14 @@ mem_cgroup_print_oom_info(struct mem_cgroup *memcg, struct task_struct *p)
 {
 }
 
-static inline void mem_cgroup_begin_update_page_stat(struct page *page,
+static inline struct mem_cgroup *mem_cgroup_begin_page_stat(struct page *page,
 					bool *locked, unsigned long *flags)
 {
+	return NULL;
 }
 
-static inline void mem_cgroup_end_update_page_stat(struct page *page,
-					bool *locked, unsigned long *flags)
+static inline void mem_cgroup_end_page_stat(struct mem_cgroup *memcg,
+					bool locked, unsigned long flags)
 {
 }
 
@@ -343,12 +319,12 @@ static inline bool mem_cgroup_oom_synchronize(bool wait)
 	return false;
 }
 
-static inline void mem_cgroup_inc_page_stat(struct page *page,
+static inline void mem_cgroup_inc_page_stat(struct mem_cgroup *memcg,
 					    enum mem_cgroup_stat_index idx)
 {
 }
 
-static inline void mem_cgroup_dec_page_stat(struct page *page,
+static inline void mem_cgroup_dec_page_stat(struct mem_cgroup *memcg,
 					    enum mem_cgroup_stat_index idx)
 {
 }
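
For orientation, a minimal caller sketch of the reworked interface declared
above. This is an editor's illustration, not part of the patch; the function
name example_clear_writeback() is hypothetical, but the pattern matches the
writeback callers converted below.

/* Sketch: a page-state update written against the new three-call interface. */
static int example_clear_writeback(struct page *page)
{
	struct mem_cgroup *memcg;
	unsigned long flags;
	bool locked;
	int ret;

	/* pin the page's memcg via RCU; move_lock only if a move is in flight */
	memcg = mem_cgroup_begin_page_stat(page, &locked, &flags);
	ret = TestClearPageWriteback(page);
	if (ret)
		mem_cgroup_dec_page_stat(memcg, MEM_CGROUP_STAT_WRITEBACK);
	/* drop move_lock if it was taken, then the RCU read lock */
	mem_cgroup_end_page_stat(memcg, locked, flags);
	return ret;
}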
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 23976fd885fd588a7687bf5631f98ea38f2235b4..d6ac0e33e150339cba37ecb3db98be3fecd3702a 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -1536,12 +1536,8 @@ int mem_cgroup_swappiness(struct mem_cgroup *memcg)
  *         start move here.
  */
 
-/* for quick checking without looking up memcg */
-atomic_t memcg_moving __read_mostly;
-
 static void mem_cgroup_start_move(struct mem_cgroup *memcg)
 {
-	atomic_inc(&memcg_moving);
 	atomic_inc(&memcg->moving_account);
 	synchronize_rcu();
 }
@@ -1552,10 +1548,8 @@ static void mem_cgroup_end_move(struct mem_cgroup *memcg)
 	 * Now, mem_cgroup_clear_mc() may call this function with NULL.
 	 * We check NULL in callee rather than caller.
 	 */
-	if (memcg) {
-		atomic_dec(&memcg_moving);
+	if (memcg)
 		atomic_dec(&memcg->moving_account);
-	}
 }
 
 /*
@@ -2204,41 +2198,52 @@ bool mem_cgroup_oom_synchronize(bool handle)
 	return true;
 }
 
-/*
- * Used to update mapped file or writeback or other statistics.
+/**
+ * mem_cgroup_begin_page_stat - begin a page state statistics transaction
+ * @page: page that is going to change accounted state
+ * @locked: &memcg->move_lock slowpath was taken
+ * @flags: IRQ-state flags for &memcg->move_lock
  *
- * Notes: Race condition
+ * This function must mark the beginning of an accounted page state
+ * change to prevent double accounting when the page is concurrently
+ * being moved to another memcg:
  *
- * Charging occurs during page instantiation, while the page is
- * unmapped and locked in page migration, or while the page table is
- * locked in THP migration.  No race is possible.
+ *   memcg = mem_cgroup_begin_page_stat(page, &locked, &flags);
+ *   if (TestClearPageState(page))
+ *     mem_cgroup_update_page_stat(memcg, state, -1);
+ *   mem_cgroup_end_page_stat(memcg, locked, flags);
  *
- * Uncharge happens to pages with zero references, no race possible.
+ * The RCU lock is held throughout the transaction.  The fast path can
+ * get away without acquiring the memcg->move_lock (@locked is false)
+ * because page moving starts with an RCU grace period.
  *
- * Charge moving between groups is protected by checking mm->moving
- * account and taking the move_lock in the slowpath.
+ * The RCU lock also protects the memcg from being freed when the page
+ * state that is going to change is the only thing preventing the page
+ * from being uncharged.  E.g. end-writeback clearing PageWriteback(),
+ * which allows migration to go ahead and uncharge the page before the
+ * accounting transaction is complete.
  */
-
-void __mem_cgroup_begin_update_page_stat(struct page *page,
-				bool *locked, unsigned long *flags)
+struct mem_cgroup *mem_cgroup_begin_page_stat(struct page *page,
+					      bool *locked,
+					      unsigned long *flags)
 {
 	struct mem_cgroup *memcg;
 	struct page_cgroup *pc;
 
+	rcu_read_lock();
+
+	if (mem_cgroup_disabled())
+		return NULL;
+
 	pc = lookup_page_cgroup(page);
 again:
 	memcg = pc->mem_cgroup;
 	if (unlikely(!memcg || !PageCgroupUsed(pc)))
-		return;
-	/*
-	 * If this memory cgroup is not under account moving, we don't
-	 * need to take move_lock_mem_cgroup(). Because we already hold
-	 * rcu_read_lock(), any calls to move_account will be delayed until
-	 * rcu_read_unlock().
-	 */
-	VM_BUG_ON(!rcu_read_lock_held());
+		return NULL;
+
+	*locked = false;
 	if (atomic_read(&memcg->moving_account) <= 0)
-		return;
+		return memcg;
 
 	move_lock_mem_cgroup(memcg, flags);
 	if (memcg != pc->mem_cgroup || !PageCgroupUsed(pc)) {
@@ -2246,36 +2251,40 @@ void __mem_cgroup_begin_update_page_stat(struct page *page,
 		goto again;
 	}
 	*locked = true;
+
+	return memcg;
 }
 
-void __mem_cgroup_end_update_page_stat(struct page *page, unsigned long *flags)
+/**
+ * mem_cgroup_end_page_stat - finish a page state statistics transaction
+ * @memcg: the memcg that was accounted against
+ * @locked: value received from mem_cgroup_begin_page_stat()
+ * @flags: value received from mem_cgroup_begin_page_stat()
+ */
+void mem_cgroup_end_page_stat(struct mem_cgroup *memcg, bool locked,
+			      unsigned long flags)
 {
-	struct page_cgroup *pc = lookup_page_cgroup(page);
+	if (memcg && locked)
+		move_unlock_mem_cgroup(memcg, &flags);
 
-	/*
-	 * It's guaranteed that pc->mem_cgroup never changes while
-	 * lock is held because a routine modifies pc->mem_cgroup
-	 * should take move_lock_mem_cgroup().
-	 */
-	move_unlock_mem_cgroup(pc->mem_cgroup, flags);
+	rcu_read_unlock();
 }
 
-void mem_cgroup_update_page_stat(struct page *page,
+/**
+ * mem_cgroup_update_page_stat - update page state statistics
+ * @memcg: memcg to account against
+ * @idx: page state item to account
+ * @val: number of pages (positive or negative)
+ *
+ * See mem_cgroup_begin_page_stat() for locking requirements.
+ */
+void mem_cgroup_update_page_stat(struct mem_cgroup *memcg,
 				 enum mem_cgroup_stat_index idx, int val)
 {
-	struct mem_cgroup *memcg;
-	struct page_cgroup *pc = lookup_page_cgroup(page);
-	unsigned long uninitialized_var(flags);
-
-	if (mem_cgroup_disabled())
-		return;
-
 	VM_BUG_ON(!rcu_read_lock_held());
-	memcg = pc->mem_cgroup;
-	if (unlikely(!memcg || !PageCgroupUsed(pc)))
-		return;
 
-	this_cpu_add(memcg->stat->count[idx], val);
+	if (memcg)
+		this_cpu_add(memcg->stat->count[idx], val);
 }
 
 /*
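
Condensing the synchronization that the hunks above implement: the mover
advertises itself through memcg->moving_account, waits out lockless readers
with synchronize_rcu(), and only then relies on move_lock. An editor's
sketch, not patch code; mover_side() is a hypothetical wrapper, the real
mover is mem_cgroup_move_account() bracketed by
mem_cgroup_start_move()/mem_cgroup_end_move():

static void mover_side(struct mem_cgroup *memcg)
{
	unsigned long flags;

	atomic_inc(&memcg->moving_account);	/* force stat updaters slow */
	synchronize_rcu();		/* wait for fast-path transactions
					 * already inside rcu_read_lock() */

	move_lock_mem_cgroup(memcg, &flags);	/* serialize with slow path */
	/* ... switch pc->mem_cgroup, transfer page state ... */
	move_unlock_mem_cgroup(memcg, &flags);

	atomic_dec(&memcg->moving_account);
}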
diff --git a/mm/page-writeback.c b/mm/page-writeback.c
index ff6a5b07211e625735c4806c4c7a7455e6872d3b..19ceae87522d9a87ea23457e4ae9eec50882b194 100644
--- a/mm/page-writeback.c
+++ b/mm/page-writeback.c
@@ -2327,11 +2327,12 @@ EXPORT_SYMBOL(clear_page_dirty_for_io);
 int test_clear_page_writeback(struct page *page)
 {
 	struct address_space *mapping = page_mapping(page);
-	int ret;
-	bool locked;
 	unsigned long memcg_flags;
+	struct mem_cgroup *memcg;
+	bool locked;
+	int ret;
 
-	mem_cgroup_begin_update_page_stat(page, &locked, &memcg_flags);
+	memcg = mem_cgroup_begin_page_stat(page, &locked, &memcg_flags);
 	if (mapping) {
 		struct backing_dev_info *bdi = mapping->backing_dev_info;
 		unsigned long flags;
@@ -2352,22 +2353,23 @@ int test_clear_page_writeback(struct page *page)
 		ret = TestClearPageWriteback(page);
 	}
 	if (ret) {
-		mem_cgroup_dec_page_stat(page, MEM_CGROUP_STAT_WRITEBACK);
+		mem_cgroup_dec_page_stat(memcg, MEM_CGROUP_STAT_WRITEBACK);
 		dec_zone_page_state(page, NR_WRITEBACK);
 		inc_zone_page_state(page, NR_WRITTEN);
 	}
-	mem_cgroup_end_update_page_stat(page, &locked, &memcg_flags);
+	mem_cgroup_end_page_stat(memcg, locked, memcg_flags);
 	return ret;
 }
 
 int __test_set_page_writeback(struct page *page, bool keep_write)
 {
 	struct address_space *mapping = page_mapping(page);
-	int ret;
-	bool locked;
 	unsigned long memcg_flags;
+	struct mem_cgroup *memcg;
+	bool locked;
+	int ret;
 
-	mem_cgroup_begin_update_page_stat(page, &locked, &memcg_flags);
+	memcg = mem_cgroup_begin_page_stat(page, &locked, &memcg_flags);
 	if (mapping) {
 		struct backing_dev_info *bdi = mapping->backing_dev_info;
 		unsigned long flags;
@@ -2394,10 +2396,10 @@ int __test_set_page_writeback(struct page *page, bool keep_write)
 		ret = TestSetPageWriteback(page);
 	}
 	if (!ret) {
-		mem_cgroup_inc_page_stat(page, MEM_CGROUP_STAT_WRITEBACK);
+		mem_cgroup_inc_page_stat(memcg, MEM_CGROUP_STAT_WRITEBACK);
 		inc_zone_page_state(page, NR_WRITEBACK);
 	}
-	mem_cgroup_end_update_page_stat(page, &locked, &memcg_flags);
+	mem_cgroup_end_page_stat(memcg, locked, memcg_flags);
 	return ret;
 
 }
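
Note that the converted writeback callers need no NULL handling of their own:
as implemented in mm/memcontrol.c above, mem_cgroup_begin_page_stat() returns
NULL when memcgs are disabled or the page_cgroup is unused,
mem_cgroup_update_page_stat() ignores a NULL memcg, and
mem_cgroup_end_page_stat() still drops the RCU lock. In short:

memcg = mem_cgroup_begin_page_stat(page, &locked, &flags); /* may be NULL */
mem_cgroup_dec_page_stat(memcg, MEM_CGROUP_STAT_WRITEBACK); /* no-op if NULL */
mem_cgroup_end_page_stat(memcg, locked, flags); /* always rcu_read_unlock() */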
diff --git a/mm/rmap.c b/mm/rmap.c
index 116a5053415b8dbc41c39622316e6b8b417a3727..f574046f77d461685722c2ef027126f7689c1f37 100644
--- a/mm/rmap.c
+++ b/mm/rmap.c
@@ -1042,15 +1042,16 @@ void page_add_new_anon_rmap(struct page *page,
  */
 void page_add_file_rmap(struct page *page)
 {
-	bool locked;
+	struct mem_cgroup *memcg;
 	unsigned long flags;
+	bool locked;
 
-	mem_cgroup_begin_update_page_stat(page, &locked, &flags);
+	memcg = mem_cgroup_begin_page_stat(page, &locked, &flags);
 	if (atomic_inc_and_test(&page->_mapcount)) {
 		__inc_zone_page_state(page, NR_FILE_MAPPED);
-		mem_cgroup_inc_page_stat(page, MEM_CGROUP_STAT_FILE_MAPPED);
+		mem_cgroup_inc_page_stat(memcg, MEM_CGROUP_STAT_FILE_MAPPED);
 	}
-	mem_cgroup_end_update_page_stat(page, &locked, &flags);
+	mem_cgroup_end_page_stat(memcg, locked, flags);
 }
 
 /**
@@ -1061,9 +1062,10 @@ void page_add_file_rmap(struct page *page)
  */
 void page_remove_rmap(struct page *page)
 {
+	struct mem_cgroup *uninitialized_var(memcg);
 	bool anon = PageAnon(page);
-	bool locked;
 	unsigned long flags;
+	bool locked;
 
 	/*
 	 * The anon case has no mem_cgroup page_stat to update; but may
@@ -1071,7 +1073,7 @@ void page_remove_rmap(struct page *page)
 	 * we hold the lock against page_stat move: so avoid it on anon.
 	 */
 	if (!anon)
-		mem_cgroup_begin_update_page_stat(page, &locked, &flags);
+		memcg = mem_cgroup_begin_page_stat(page, &locked, &flags);
 
 	/* page still mapped by someone else? */
 	if (!atomic_add_negative(-1, &page->_mapcount))
@@ -1096,8 +1098,7 @@ void page_remove_rmap(struct page *page)
 				-hpage_nr_pages(page));
 	} else {
 		__dec_zone_page_state(page, NR_FILE_MAPPED);
-		mem_cgroup_dec_page_stat(page, MEM_CGROUP_STAT_FILE_MAPPED);
-		mem_cgroup_end_update_page_stat(page, &locked, &flags);
+		mem_cgroup_dec_page_stat(memcg, MEM_CGROUP_STAT_FILE_MAPPED);
 	}
 	if (unlikely(PageMlocked(page)))
 		clear_page_mlock(page);
@@ -1110,10 +1111,9 @@ void page_remove_rmap(struct page *page)
 	 * Leaving it set also helps swapoff to reinstate ptes
 	 * faster for those pages still in swapcache.
 	 */
-	return;
 out:
 	if (!anon)
-		mem_cgroup_end_update_page_stat(page, &locked, &flags);
+		mem_cgroup_end_page_stat(memcg, locked, flags);
 }
 
 /*
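
To make the rmap restructuring easier to follow: the file-page path of
page_remove_rmap() now opens the transaction before the final mapcount drop
and closes it exactly once at the out: label, so the memcg stays pinned
across the stat update even when clearing the last mapping lets a concurrent
uncharge proceed. A condensed editor's sketch, with the anon and huge-page
branches omitted (file_rmap_remove_sketch() is hypothetical):

static void file_rmap_remove_sketch(struct page *page)
{
	struct mem_cgroup *memcg;
	unsigned long flags;
	bool locked;

	memcg = mem_cgroup_begin_page_stat(page, &locked, &flags);
	/* did the last mapping go away? (_mapcount drops from 0 to -1) */
	if (atomic_add_negative(-1, &page->_mapcount)) {
		__dec_zone_page_state(page, NR_FILE_MAPPED);
		mem_cgroup_dec_page_stat(memcg, MEM_CGROUP_STAT_FILE_MAPPED);
	}
	/* single exit: end the transaction whether or not stats changed */
	mem_cgroup_end_page_stat(memcg, locked, flags);
}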