diff options
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 237 | ||||
-rw-r--r-- | include/linux/bpf.h | 3 | ||||
-rw-r--r-- | include/linux/bpf_verifier.h | 1 | ||||
-rw-r--r-- | kernel/bpf/arraymap.c | 40 | ||||
-rw-r--r-- | kernel/bpf/core.c | 2 | ||||
-rw-r--r-- | kernel/bpf/verifier.c | 16 |
6 files changed, 244 insertions, 55 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 7b0ff169c9a0..26f43279b78b 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -221,14 +221,48 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 +/* Number of bytes that will be skipped on tailcall */ +#define X86_TAIL_CALL_OFFSET 11 -#define PROLOGUE_SIZE 25 +static void push_callee_regs(u8 **pprog, bool *callee_regs_used) +{ + u8 *prog = *pprog; + int cnt = 0; + + if (callee_regs_used[0]) + EMIT1(0x53); /* push rbx */ + if (callee_regs_used[1]) + EMIT2(0x41, 0x55); /* push r13 */ + if (callee_regs_used[2]) + EMIT2(0x41, 0x56); /* push r14 */ + if (callee_regs_used[3]) + EMIT2(0x41, 0x57); /* push r15 */ + *pprog = prog; +} + +static void pop_callee_regs(u8 **pprog, bool *callee_regs_used) +{ + u8 *prog = *pprog; + int cnt = 0; + + if (callee_regs_used[3]) + EMIT2(0x41, 0x5F); /* pop r15 */ + if (callee_regs_used[2]) + EMIT2(0x41, 0x5E); /* pop r14 */ + if (callee_regs_used[1]) + EMIT2(0x41, 0x5D); /* pop r13 */ + if (callee_regs_used[0]) + EMIT1(0x5B); /* pop rbx */ + *pprog = prog; +} /* - * Emit x86-64 prologue code for BPF program and check its size. - * bpf_tail_call helper will skip it while jumping into another program + * Emit x86-64 prologue code for BPF program. + * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes + * while jumping to another program */ -static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) +static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, + bool tail_call_reachable, bool is_subprog) { u8 *prog = *pprog; int cnt = X86_PATCH_SIZE; @@ -238,19 +272,18 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) */ memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt); prog += cnt; + if (!ebpf_from_cbpf) { + if (tail_call_reachable && !is_subprog) + EMIT2(0x31, 0xC0); /* xor eax, eax */ + else + EMIT2(0x66, 0x90); /* nop2 */ + } EMIT1(0x55); /* push rbp */ EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ /* sub rsp, rounded_stack_depth */ EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); - EMIT1(0x53); /* push rbx */ - EMIT2(0x41, 0x55); /* push r13 */ - EMIT2(0x41, 0x56); /* push r14 */ - EMIT2(0x41, 0x57); /* push r15 */ - if (!ebpf_from_cbpf) { - /* zero init tail_call_cnt */ - EMIT2(0x6a, 0x00); - BUILD_BUG_ON(cnt != PROLOGUE_SIZE); - } + if (tail_call_reachable) + EMIT1(0x50); /* push rax */ *pprog = prog; } @@ -314,13 +347,14 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, mutex_lock(&text_mutex); if (memcmp(ip, old_insn, X86_PATCH_SIZE)) goto out; + ret = 1; if (memcmp(ip, new_insn, X86_PATCH_SIZE)) { if (text_live) text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); else memcpy(ip, new_insn, X86_PATCH_SIZE); + ret = 0; } - ret = 0; out: mutex_unlock(&text_mutex); return ret; @@ -337,6 +371,22 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true); } +static int get_pop_bytes(bool *callee_regs_used) +{ + int bytes = 0; + + if (callee_regs_used[3]) + bytes += 2; + if (callee_regs_used[2]) + bytes += 2; + if (callee_regs_used[1]) + bytes += 2; + if (callee_regs_used[0]) + bytes += 1; + + return bytes; +} + /* * Generate the following code: * @@ -351,12 +401,26 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, * goto *(prog->bpf_func + prologue_size); * out: */ -static void emit_bpf_tail_call_indirect(u8 **pprog) +static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used, + u32 stack_depth) { + int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog; - int label1, label2, label3; + int pop_bytes = 0; + int off1 = 49; + int off2 = 38; + int off3 = 16; int cnt = 0; + /* count the additional bytes used for popping callee regs from stack + * that need to be taken into account for each of the offsets that + * are used for bailing out of the tail call + */ + pop_bytes = get_pop_bytes(callee_regs_used); + off1 += pop_bytes; + off2 += pop_bytes; + off3 += pop_bytes; + /* * rdi - pointer to ctx * rsi - pointer to bpf_array @@ -370,21 +434,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog) EMIT2(0x89, 0xD2); /* mov edx, edx */ EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */ offsetof(struct bpf_array, map.max_entries)); -#define OFFSET1 (41 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */ +#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */ EMIT2(X86_JBE, OFFSET1); /* jbe out */ - label1 = cnt; /* * if (tail_call_cnt > MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ -#define OFFSET2 (30 + RETPOLINE_RCX_BPF_JIT_SIZE) +#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE) EMIT2(X86_JA, OFFSET2); /* ja out */ - label2 = cnt; EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ @@ -394,48 +456,84 @@ static void emit_bpf_tail_call_indirect(u8 **pprog) * if (prog == NULL) * goto out; */ - EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */ -#define OFFSET3 (8 + RETPOLINE_RCX_BPF_JIT_SIZE) + EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */ +#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE) EMIT2(X86_JE, OFFSET3); /* je out */ - label3 = cnt; - /* goto *(prog->bpf_func + prologue_size); */ + *pprog = prog; + pop_callee_regs(pprog, callee_regs_used); + prog = *pprog; + + EMIT1(0x58); /* pop rax */ + EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ + round_up(stack_depth, 8)); + + /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */ EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */ offsetof(struct bpf_prog, bpf_func)); - EMIT4(0x48, 0x83, 0xC1, PROLOGUE_SIZE); /* add rcx, prologue_size */ - + EMIT4(0x48, 0x83, 0xC1, /* add rcx, X86_TAIL_CALL_OFFSET */ + X86_TAIL_CALL_OFFSET); /* * Now we're ready to jump into next BPF program * rdi == ctx (1st arg) - * rcx == prog->bpf_func + prologue_size + * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET */ RETPOLINE_RCX_BPF_JIT(); /* out: */ - BUILD_BUG_ON(cnt - label1 != OFFSET1); - BUILD_BUG_ON(cnt - label2 != OFFSET2); - BUILD_BUG_ON(cnt - label3 != OFFSET3); *pprog = prog; } static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke, - u8 **pprog, int addr, u8 *image) + u8 **pprog, int addr, u8 *image, + bool *callee_regs_used, u32 stack_depth) { + int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog; + int pop_bytes = 0; + int off1 = 27; + int poke_off; int cnt = 0; + /* count the additional bytes used for popping callee regs to stack + * that need to be taken into account for jump offset that is used for + * bailing out from of the tail call when limit is reached + */ + pop_bytes = get_pop_bytes(callee_regs_used); + off1 += pop_bytes; + + /* + * total bytes for: + * - nop5/ jmpq $off + * - pop callee regs + * - sub rsp, $val + * - pop rax + */ + poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1; + /* * if (tail_call_cnt > MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ - EMIT2(X86_JA, 14); /* ja out */ + EMIT2(X86_JA, off1); /* ja out */ EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ + poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE); + poke->adj_off = X86_TAIL_CALL_OFFSET; poke->tailcall_target = image + (addr - X86_PATCH_SIZE); - poke->adj_off = PROLOGUE_SIZE; + poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE; + + emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE, + poke->tailcall_bypass); + + *pprog = prog; + pop_callee_regs(pprog, callee_regs_used); + prog = *pprog; + EMIT1(0x58); /* pop rax */ + EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE); prog += X86_PATCH_SIZE; @@ -476,6 +574,11 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) (u8 *)target->bpf_func + poke->adj_off, false); BUG_ON(ret < 0); + ret = __bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + (u8 *)poke->tailcall_target + + X86_PATCH_SIZE, NULL, false); + BUG_ON(ret < 0); } WRITE_ONCE(poke->tailcall_target_stable, true); mutex_unlock(&array->aux->poke_mutex); @@ -654,19 +757,49 @@ static bool ex_handler_bpf(const struct exception_table_entry *x, return true; } +static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt, + bool *regs_used, bool *tail_call_seen) +{ + int i; + + for (i = 1; i <= insn_cnt; i++, insn++) { + if (insn->code == (BPF_JMP | BPF_TAIL_CALL)) + *tail_call_seen = true; + if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6) + regs_used[0] = true; + if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7) + regs_used[1] = true; + if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8) + regs_used[2] = true; + if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9) + regs_used[3] = true; + } +} + static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, int oldproglen, struct jit_context *ctx) { + bool tail_call_reachable = bpf_prog->aux->tail_call_reachable; struct bpf_insn *insn = bpf_prog->insnsi; + bool callee_regs_used[4] = {}; int insn_cnt = bpf_prog->len; + bool tail_call_seen = false; bool seen_exit = false; u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY]; int i, cnt = 0, excnt = 0; int proglen = 0; u8 *prog = temp; + detect_reg_usage(insn, insn_cnt, callee_regs_used, + &tail_call_seen); + + /* tail call's presence in current prog implies it is reachable */ + tail_call_reachable |= tail_call_seen; + emit_prologue(&prog, bpf_prog->aux->stack_depth, - bpf_prog_was_classic(bpf_prog)); + bpf_prog_was_classic(bpf_prog), tail_call_reachable, + bpf_prog->aux->func_idx != 0); + push_callee_regs(&prog, callee_regs_used); addrs[0] = prog - temp; for (i = 1; i <= insn_cnt; i++, insn++) { @@ -1104,16 +1237,27 @@ xadd: if (is_imm8(insn->off)) /* call */ case BPF_JMP | BPF_CALL: func = (u8 *) __bpf_call_base + imm32; - if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) - return -EINVAL; + if (tail_call_reachable) { + EMIT3_off32(0x48, 0x8B, 0x85, + -(bpf_prog->aux->stack_depth + 8)); + if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7)) + return -EINVAL; + } else { + if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) + return -EINVAL; + } break; case BPF_JMP | BPF_TAIL_CALL: if (imm32) emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1], - &prog, addrs[i], image); + &prog, addrs[i], image, + callee_regs_used, + bpf_prog->aux->stack_depth); else - emit_bpf_tail_call_indirect(&prog); + emit_bpf_tail_call_indirect(&prog, + callee_regs_used, + bpf_prog->aux->stack_depth); break; /* cond jump */ @@ -1296,12 +1440,9 @@ emit_jmp: seen_exit = true; /* Update cleanup_addr */ ctx->cleanup_addr = proglen; - if (!bpf_prog_was_classic(bpf_prog)) - EMIT1(0x5B); /* get rid of tail_call_cnt */ - EMIT2(0x41, 0x5F); /* pop r15 */ - EMIT2(0x41, 0x5E); /* pop r14 */ - EMIT2(0x41, 0x5D); /* pop r13 */ - EMIT1(0x5B); /* pop rbx */ + pop_callee_regs(&prog, callee_regs_used); + if (tail_call_reachable) + EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */ EMIT1(0xC9); /* leave */ EMIT1(0xC3); /* ret */ break; diff --git a/include/linux/bpf.h b/include/linux/bpf.h index f3790c9cf542..d7c5a6ed87e3 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -698,6 +698,8 @@ enum bpf_jit_poke_reason { /* Descriptor of pokes pointing /into/ the JITed image. */ struct bpf_jit_poke_descriptor { void *tailcall_target; + void *tailcall_bypass; + void *bypass_addr; union { struct { struct bpf_map *map; @@ -738,6 +740,7 @@ struct bpf_prog_aux { bool attach_btf_trace; /* true if attaching to BTF-enabled raw tp */ bool func_proto_unreliable; bool sleepable; + bool tail_call_reachable; enum bpf_tramp_prog_type trampoline_prog_type; struct bpf_trampoline *trampoline; struct hlist_node tramp_hlist; diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h index 5026b75db972..fbc964526ba3 100644 --- a/include/linux/bpf_verifier.h +++ b/include/linux/bpf_verifier.h @@ -359,6 +359,7 @@ struct bpf_subprog_info { u32 linfo_idx; /* The idx to the main_prog->aux->linfo */ u16 stack_depth; /* max. stack depth used by this function */ bool has_tail_call; + bool tail_call_reachable; }; /* single container for all structs diff --git a/kernel/bpf/arraymap.c b/kernel/bpf/arraymap.c index 60abf7fe12de..e5fd31268ae0 100644 --- a/kernel/bpf/arraymap.c +++ b/kernel/bpf/arraymap.c @@ -898,6 +898,7 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key, struct bpf_prog *old, struct bpf_prog *new) { + u8 *old_addr, *new_addr, *old_bypass_addr; struct prog_poke_elem *elem; struct bpf_array_aux *aux; @@ -949,12 +950,39 @@ static void prog_array_map_poke_run(struct bpf_map *map, u32 key, poke->tail_call.key != key) continue; - ret = bpf_arch_text_poke(poke->tailcall_target, BPF_MOD_JUMP, - old ? (u8 *)old->bpf_func + - poke->adj_off : NULL, - new ? (u8 *)new->bpf_func + - poke->adj_off : NULL); - BUG_ON(ret < 0 && ret != -EINVAL); + old_bypass_addr = old ? NULL : poke->bypass_addr; + old_addr = old ? (u8 *)old->bpf_func + poke->adj_off : NULL; + new_addr = new ? (u8 *)new->bpf_func + poke->adj_off : NULL; + + if (new) { + ret = bpf_arch_text_poke(poke->tailcall_target, + BPF_MOD_JUMP, + old_addr, new_addr); + BUG_ON(ret < 0 && ret != -EINVAL); + if (!old) { + ret = bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + poke->bypass_addr, + NULL); + BUG_ON(ret < 0 && ret != -EINVAL); + } + } else { + ret = bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + old_bypass_addr, + poke->bypass_addr); + BUG_ON(ret < 0 && ret != -EINVAL); + /* let other CPUs finish the execution of program + * so that it will not possible to expose them + * to invalid nop, stack unwind, nop state + */ + if (!ret) + synchronize_rcu(); + ret = bpf_arch_text_poke(poke->tailcall_target, + BPF_MOD_JUMP, + old_addr, NULL); + BUG_ON(ret < 0 && ret != -EINVAL); + } } } } diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c index 2e00ac028d38..c4811b139caa 100644 --- a/kernel/bpf/core.c +++ b/kernel/bpf/core.c @@ -776,7 +776,7 @@ int bpf_jit_add_poke_descriptor(struct bpf_prog *prog, if (size > poke_tab_max) return -ENOSPC; if (poke->tailcall_target || poke->tailcall_target_stable || - poke->adj_off) + poke->tailcall_bypass || poke->adj_off || poke->bypass_addr) return -EINVAL; switch (poke->reason) { diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 0958fba48d59..172e12df9eaa 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -2983,8 +2983,10 @@ static int check_max_stack_depth(struct bpf_verifier_env *env) int depth = 0, frame = 0, idx = 0, i = 0, subprog_end; struct bpf_subprog_info *subprog = env->subprog_info; struct bpf_insn *insn = env->prog->insnsi; + bool tail_call_reachable = false; int ret_insn[MAX_CALL_FRAMES]; int ret_prog[MAX_CALL_FRAMES]; + int j; process_func: /* protect against potential stack overflow that might happen when @@ -3040,6 +3042,10 @@ continue_func: i); return -EFAULT; } + + if (subprog[idx].has_tail_call) + tail_call_reachable = true; + frame++; if (frame >= MAX_CALL_FRAMES) { verbose(env, "the call stack of %d frames is too deep !\n", @@ -3048,6 +3054,15 @@ continue_func: } goto process_func; } + /* if tail call got detected across bpf2bpf calls then mark each of the + * currently present subprog frames as tail call reachable subprogs; + * this info will be utilized by JIT so that we will be preserving the + * tail call counter throughout bpf2bpf calls combined with tailcalls + */ + if (tail_call_reachable) + for (j = 0; j < frame; j++) + subprog[ret_prog[j]].tail_call_reachable = true; + /* end of for() loop means the last insn of the 'subprog' * was reached. Doesn't matter whether it was JA or EXIT */ @@ -10322,6 +10337,7 @@ static int jit_subprogs(struct bpf_verifier_env *env) num_exentries++; } func[i]->aux->num_exentries = num_exentries; + func[i]->aux->tail_call_reachable = env->subprog_info[i].tail_call_reachable; func[i] = bpf_int_jit_compile(func[i]); if (!func[i]->jited) { err = -ENOTSUPP; |