diff options
Diffstat (limited to 'tools/testing/selftests/net/lib')
-rw-r--r-- | tools/testing/selftests/net/lib/csum.c | 16 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/ksft.py | 60 |
2 files changed, 72 insertions, 4 deletions
diff --git a/tools/testing/selftests/net/lib/csum.c b/tools/testing/selftests/net/lib/csum.c index b9f3fc3c3426..e0a34e5e8dd5 100644 --- a/tools/testing/selftests/net/lib/csum.c +++ b/tools/testing/selftests/net/lib/csum.c @@ -654,10 +654,16 @@ static int recv_verify_packet_ipv4(void *nh, int len) { struct iphdr *iph = nh; uint16_t proto = cfg_encap ? IPPROTO_UDP : cfg_proto; + uint16_t ip_len; if (len < sizeof(*iph) || iph->protocol != proto) return -1; + ip_len = ntohs(iph->tot_len); + if (ip_len > len || ip_len < sizeof(*iph)) + return -1; + + len = ip_len; iph_addr_p = &iph->saddr; if (proto == IPPROTO_TCP) return recv_verify_packet_tcp(iph + 1, len - sizeof(*iph)); @@ -669,16 +675,22 @@ static int recv_verify_packet_ipv6(void *nh, int len) { struct ipv6hdr *ip6h = nh; uint16_t proto = cfg_encap ? IPPROTO_UDP : cfg_proto; + uint16_t ip_len; if (len < sizeof(*ip6h) || ip6h->nexthdr != proto) return -1; + ip_len = ntohs(ip6h->payload_len); + if (ip_len > len - sizeof(*ip6h)) + return -1; + + len = ip_len; iph_addr_p = &ip6h->saddr; if (proto == IPPROTO_TCP) - return recv_verify_packet_tcp(ip6h + 1, len - sizeof(*ip6h)); + return recv_verify_packet_tcp(ip6h + 1, len); else - return recv_verify_packet_udp(ip6h + 1, len - sizeof(*ip6h)); + return recv_verify_packet_udp(ip6h + 1, len); } /* return whether auxdata includes TP_STATUS_CSUM_VALID */ diff --git a/tools/testing/selftests/net/lib/py/ksft.py b/tools/testing/selftests/net/lib/py/ksft.py index f26c20df9db4..477ae76de93d 100644 --- a/tools/testing/selftests/net/lib/py/ksft.py +++ b/tools/testing/selftests/net/lib/py/ksft.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: GPL-2.0 import builtins +import functools import inspect import sys import time @@ -10,6 +11,7 @@ from .utils import global_defer_queue KSFT_RESULT = None KSFT_RESULT_ALL = True +KSFT_DISRUPTIVE = True class KsftFailEx(Exception): @@ -32,8 +34,18 @@ def _fail(*args): global KSFT_RESULT KSFT_RESULT = False - frame = inspect.stack()[2] - ksft_pr("At " + frame.filename + " line " + str(frame.lineno) + ":") + stack = inspect.stack() + started = False + for frame in reversed(stack[2:]): + # Start printing from the test case function + if not started: + if frame.function == 'ksft_run': + started = True + continue + + ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) + + ", in " + frame.function + ":") + ksft_pr("Check| " + frame.code_context[0].strip()) ksft_pr(*args) @@ -43,6 +55,12 @@ def ksft_eq(a, b, comment=""): _fail("Check failed", a, "!=", b, comment) +def ksft_ne(a, b, comment=""): + global KSFT_RESULT + if a == b: + _fail("Check failed", a, "==", b, comment) + + def ksft_true(a, comment=""): if not a: _fail("Check failed", a, "does not eval to True", comment) @@ -127,6 +145,44 @@ def ksft_flush_defer(): KSFT_RESULT = False +def ksft_disruptive(func): + """ + Decorator that marks the test as disruptive (e.g. the test + that can down the interface). Disruptive tests can be skipped + by passing DISRUPTIVE=False environment variable. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not KSFT_DISRUPTIVE: + raise KsftSkipEx(f"marked as disruptive") + return func(*args, **kwargs) + return wrapper + + +def ksft_setup(env): + """ + Setup test framework global state from the environment. + """ + + def get_bool(env, name): + value = env.get(name, "").lower() + if value in ["yes", "true"]: + return True + if value in ["no", "false"]: + return False + try: + return bool(int(value)) + except: + raise Exception(f"failed to parse {name}") + + if "DISRUPTIVE" in env: + global KSFT_DISRUPTIVE + KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE") + + return env + + def ksft_run(cases=None, globs=None, case_pfx=None, args=()): cases = cases or [] |