diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index aeb59551446193e715c76c80bcf7b701d2451514..a910a5b4c274db7d66108137d29ad6984fc842eb 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -5914,12 +5914,15 @@ static void prev_balance(struct rq *rq, struct task_struct *prev,
 
 #ifdef CONFIG_SCHED_CLASS_EXT
 	/*
-	 * SCX requires a balance() call before every pick_next_task() including
-	 * when waking up from SCHED_IDLE. If @start_class is below SCX, start
-	 * from SCX instead.
+	 * SCX requires a balance() call before every pick_task() including when
+	 * waking up from SCHED_IDLE. If @start_class is below SCX, start from
+	 * SCX instead. Also, set a flag to detect missing balance() call.
 	 */
-	if (scx_enabled() && sched_class_above(&ext_sched_class, start_class))
-		start_class = &ext_sched_class;
+	if (scx_enabled()) {
+		rq->scx.flags |= SCX_RQ_BAL_PENDING;
+		if (sched_class_above(&ext_sched_class, start_class))
+			start_class = &ext_sched_class;
+	}
 #endif
 
 	/*
diff --git a/kernel/sched/ext.c b/kernel/sched/ext.c
index 3bdb08fc2056aa7341442831b4d7fd0f0055a5dc..19f9cb3a4190a1864faa6b897d5f0ab8cdef4134 100644
--- a/kernel/sched/ext.c
+++ b/kernel/sched/ext.c
@@ -2634,7 +2634,7 @@ static int balance_one(struct rq *rq, struct task_struct *prev)
 
 	lockdep_assert_rq_held(rq);
 	rq->scx.flags |= SCX_RQ_IN_BALANCE;
-	rq->scx.flags &= ~SCX_RQ_BAL_KEEP;
+	rq->scx.flags &= ~(SCX_RQ_BAL_PENDING | SCX_RQ_BAL_KEEP);
 
 	if (static_branch_unlikely(&scx_ops_cpu_preempt) &&
 	    unlikely(rq->scx.cpu_released)) {
@@ -2948,12 +2948,11 @@ static struct task_struct *pick_task_scx(struct rq *rq)
 {
 	struct task_struct *prev = rq->curr;
 	struct task_struct *p;
+	bool prev_on_scx = prev->sched_class == &ext_sched_class;
+	bool keep_prev = rq->scx.flags & SCX_RQ_BAL_KEEP;
+	bool kick_idle = false;
 
 	/*
-	 * If balance_scx() is telling us to keep running @prev, replenish slice
-	 * if necessary and keep running @prev. Otherwise, pop the first one
-	 * from the local DSQ.
-	 *
 	 * WORKAROUND:
 	 *
 	 * %SCX_RQ_BAL_KEEP should be set iff $prev is on SCX as it must just
@@ -2962,22 +2961,41 @@ static struct task_struct *pick_task_scx(struct rq *rq)
 	 * which then ends up calling pick_task_scx() without preceding
 	 * balance_scx().
 	 *
-	 * For now, ignore cases where $prev is not on SCX. This isn't great and
-	 * can theoretically lead to stalls. However, for switch_all cases, this
-	 * happens only while a BPF scheduler is being loaded or unloaded, and,
-	 * for partial cases, fair will likely keep triggering this CPU.
+	 * Keep running @prev if possible and avoid stalling from entering idle
+	 * without balancing.
 	 *
-	 * Once fair is fixed, restore WARN_ON_ONCE().
+	 * Once fair is fixed, remove the workaround and trigger WARN_ON_ONCE()
+	 * if pick_task_scx() is called without preceding balance_scx().
 	 */
-	if ((rq->scx.flags & SCX_RQ_BAL_KEEP) &&
-	    prev->sched_class == &ext_sched_class) {
+	if (unlikely(rq->scx.flags & SCX_RQ_BAL_PENDING)) {
+		if (prev_on_scx) {
+			keep_prev = true;
+		} else {
+			keep_prev = false;
+			kick_idle = true;
+		}
+	} else if (unlikely(keep_prev && !prev_on_scx)) {
+		/* only allowed during transitions */
+		WARN_ON_ONCE(scx_ops_enable_state() == SCX_OPS_ENABLED);
+		keep_prev = false;
+	}
+
+	/*
+	 * If balance_scx() is telling us to keep running @prev, replenish slice
+	 * if necessary and keep running @prev. Otherwise, pop the first one
+	 * from the local DSQ.
+	 */
+	if (keep_prev) {
 		p = prev;
 		if (!p->scx.slice)
 			p->scx.slice = SCX_SLICE_DFL;
 	} else {
 		p = first_local_task(rq);
-		if (!p)
+		if (!p) {
+			if (kick_idle)
+				scx_bpf_kick_cpu(cpu_of(rq), SCX_KICK_IDLE);
 			return NULL;
+		}
 
 		if (unlikely(!p->scx.slice)) {
 			if (!scx_rq_bypassing(rq) && !scx_warned_zero_slice) {
diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h
index 6085ef50febfafaa4db8587fab05f64dba5e4299..4d79804631e4978e7ad1b374d18d142345e8665c 100644
--- a/kernel/sched/sched.h
+++ b/kernel/sched/sched.h
@@ -751,8 +751,9 @@ enum scx_rq_flags {
 	 */
 	SCX_RQ_ONLINE		= 1 << 0,
 	SCX_RQ_CAN_STOP_TICK	= 1 << 1,
-	SCX_RQ_BAL_KEEP		= 1 << 2, /* balance decided to keep current */
-	SCX_RQ_BYPASSING	= 1 << 3,
+	SCX_RQ_BAL_PENDING	= 1 << 2, /* balance hasn't run yet */
+	SCX_RQ_BAL_KEEP		= 1 << 3, /* balance decided to keep current */
+	SCX_RQ_BYPASSING	= 1 << 4,
 
 	SCX_RQ_IN_WAKEUP	= 1 << 16,
 	SCX_RQ_IN_BALANCE	= 1 << 17,