diff --git a/kernel/cpuset.c b/kernel/cpuset.c
index c7fd2778ed50edc52b7c0e5f1f651dcb141c9495..c27e53326befe9f33ffc605a75a48cf8e9ae6624 100644
--- a/kernel/cpuset.c
+++ b/kernel/cpuset.c
@@ -2069,6 +2069,20 @@ static void cpuset_bind(struct cgroup_subsys_state *root_css)
 	mutex_unlock(&cpuset_mutex);
 }
 
+/*
+ * Make sure the new task conform to the current state of its parent,
+ * which could have been changed by cpuset just after it inherits the
+ * state from the parent and before it sits on the cgroup's task list.
+ */
+void cpuset_fork(struct task_struct *task)
+{
+	if (task_css_is_root(task, cpuset_cgrp_id))
+		return;
+
+	set_cpus_allowed_ptr(task, &current->cpus_allowed);
+	task->mems_allowed = current->mems_allowed;
+}
+
 struct cgroup_subsys cpuset_cgrp_subsys = {
 	.css_alloc	= cpuset_css_alloc,
 	.css_online	= cpuset_css_online,
@@ -2079,6 +2093,7 @@ struct cgroup_subsys cpuset_cgrp_subsys = {
 	.attach		= cpuset_attach,
 	.post_attach	= cpuset_post_attach,
 	.bind		= cpuset_bind,
+	.fork		= cpuset_fork,
 	.legacy_cftypes	= files,
 	.early_init	= true,
 };
diff --git a/kernel/fork.c b/kernel/fork.c
index 52e725d4a866b4ac30f16b4db075b9b197e4e53d..aaf782327bf330e22c4ef2ca1881c35f4970baf4 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1404,7 +1404,6 @@ static struct task_struct *copy_process(unsigned long clone_flags,
 	p->real_start_time = ktime_get_boot_ns();
 	p->io_context = NULL;
 	p->audit_context = NULL;
-	threadgroup_change_begin(current);
 	cgroup_fork(p);
 #ifdef CONFIG_NUMA
 	p->mempolicy = mpol_dup(p->mempolicy);
@@ -1556,6 +1555,7 @@ static struct task_struct *copy_process(unsigned long clone_flags,
 	INIT_LIST_HEAD(&p->thread_group);
 	p->task_works = NULL;
 
+	threadgroup_change_begin(current);
 	/*
 	 * Ensure that the cgroup subsystem policies allow the new process to be
 	 * forked. It should be noted the the new process's css_set can be changed
@@ -1656,6 +1656,7 @@ static struct task_struct *copy_process(unsigned long clone_flags,
 bad_fork_cancel_cgroup:
 	cgroup_cancel_fork(p);
 bad_fork_free_pid:
+	threadgroup_change_end(current);
 	if (pid != &init_struct_pid)
 		free_pid(pid);
 bad_fork_cleanup_thread:
@@ -1688,7 +1689,6 @@ static struct task_struct *copy_process(unsigned long clone_flags,
 	mpol_put(p->mempolicy);
 bad_fork_cleanup_threadgroup_lock:
 #endif
-	threadgroup_change_end(current);
 	delayacct_tsk_free(p);
 bad_fork_cleanup_count:
 	atomic_dec(&p->cred->user->processes);