diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index a263918043cebdb91041519bf1d873b28495e51a..796506dcfc42e86484a8eb27b69ce8a14c2776c0 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -281,7 +281,8 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
 	EMIT1(0x55);             /* push rbp */
 	EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
 	/* sub rsp, rounded_stack_depth */
-	EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
+	if (stack_depth)
+		EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
 	if (tail_call_reachable)
 		EMIT1(0x50);         /* push rax */
 	*pprog = prog;
@@ -407,9 +408,9 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
 	int tcc_off = -4 - round_up(stack_depth, 8);
 	u8 *prog = *pprog;
 	int pop_bytes = 0;
-	int off1 = 49;
-	int off2 = 38;
-	int off3 = 16;
+	int off1 = 42;
+	int off2 = 31;
+	int off3 = 9;
 	int cnt = 0;
 
 	/* count the additional bytes used for popping callee regs from stack
@@ -421,6 +422,12 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
 	off2 += pop_bytes;
 	off3 += pop_bytes;
 
+	if (stack_depth) {
+		off1 += 7;
+		off2 += 7;
+		off3 += 7;
+	}
+
 	/*
 	 * rdi - pointer to ctx
 	 * rsi - pointer to bpf_array
@@ -465,8 +472,9 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
 	prog = *pprog;
 
 	EMIT1(0x58);                              /* pop rax */
-	EMIT3_off32(0x48, 0x81, 0xC4,             /* add rsp, sd */
-		    round_up(stack_depth, 8));
+	if (stack_depth)
+		EMIT3_off32(0x48, 0x81, 0xC4,     /* add rsp, sd */
+			    round_up(stack_depth, 8));
 
 	/* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */
 	EMIT4(0x48, 0x8B, 0x49,                   /* mov rcx, qword ptr [rcx + 32] */
@@ -491,7 +499,7 @@ static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
 	int tcc_off = -4 - round_up(stack_depth, 8);
 	u8 *prog = *pprog;
 	int pop_bytes = 0;
-	int off1 = 27;
+	int off1 = 20;
 	int poke_off;
 	int cnt = 0;
 
@@ -506,10 +514,14 @@ static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
 	 * total bytes for:
 	 * - nop5/ jmpq $off
 	 * - pop callee regs
-	 * - sub rsp, $val
+	 * - sub rsp, $val if depth > 0
 	 * - pop rax
 	 */
-	poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1;
+	poke_off = X86_PATCH_SIZE + pop_bytes + 1;
+	if (stack_depth) {
+		poke_off += 7;
+		off1 += 7;
+	}
 
 	/*
 	 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
@@ -533,7 +545,8 @@ static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
 	pop_callee_regs(pprog, callee_regs_used);
 	prog = *pprog;
 	EMIT1(0x58);                                  /* pop rax */
-	EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
+	if (stack_depth)
+		EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
 
 	memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
 	prog += X86_PATCH_SIZE;