diff options
56 files changed, 3691 insertions, 708 deletions
diff --git a/arch/powerpc/platforms/pseries/Kconfig b/arch/powerpc/platforms/pseries/Kconfig index 4ebf2ef2845d..afc0f6a61337 100644 --- a/arch/powerpc/platforms/pseries/Kconfig +++ b/arch/powerpc/platforms/pseries/Kconfig @@ -164,6 +164,12 @@ config PSERIES_PLPKS # This option is selected by in-kernel consumers that require # access to the PKS. +config PSERIES_PLPKS_SED + depends on PPC_PSERIES + bool + # This option is selected by in-kernel consumers that require + # access to the SED PKS keystore. + config PAPR_SCM depends on PPC_PSERIES && MEMORY_HOTPLUG && LIBNVDIMM tristate "Support for the PAPR Storage Class Memory interface" diff --git a/arch/powerpc/platforms/pseries/Makefile b/arch/powerpc/platforms/pseries/Makefile index 53c3b91af2f7..1476c5e4433c 100644 --- a/arch/powerpc/platforms/pseries/Makefile +++ b/arch/powerpc/platforms/pseries/Makefile @@ -29,6 +29,7 @@ obj-$(CONFIG_PPC_SVM) += svm.o obj-$(CONFIG_FA_DUMP) += rtas-fadump.o obj-$(CONFIG_PSERIES_PLPKS) += plpks.o obj-$(CONFIG_PPC_SECURE_BOOT) += plpks-secvar.o +obj-$(CONFIG_PSERIES_PLPKS_SED) += plpks_sed_ops.o obj-$(CONFIG_SUSPEND) += suspend.o obj-$(CONFIG_PPC_VAS) += vas.o vas-sysfs.o diff --git a/arch/powerpc/platforms/pseries/plpks_sed_ops.c b/arch/powerpc/platforms/pseries/plpks_sed_ops.c new file mode 100644 index 000000000000..7c873c9589ef --- /dev/null +++ b/arch/powerpc/platforms/pseries/plpks_sed_ops.c @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * POWER Platform specific code for non-volatile SED key access + * Copyright (C) 2022 IBM Corporation + * + * Define operations for SED Opal to read/write keys + * from POWER LPAR Platform KeyStore(PLPKS). + * + * Self Encrypting Drives(SED) key storage using PLPKS + */ + +#include <linux/kernel.h> +#include <linux/slab.h> +#include <linux/string.h> +#include <linux/ioctl.h> +#include <linux/sed-opal-key.h> +#include <asm/plpks.h> + +static bool plpks_sed_initialized = false; +static bool plpks_sed_available = false; + +/* + * structure that contains all SED data + */ +struct plpks_sed_object_data { + u_char version; + u_char pad1[7]; + u_long authority; + u_long range; + u_int key_len; + u_char key[32]; +}; + +#define PLPKS_SED_OBJECT_DATA_V0 0 +#define PLPKS_SED_MANGLED_LABEL "/default/pri" +#define PLPKS_SED_COMPONENT "sed-opal" +#define PLPKS_SED_KEY "opal-boot-pin" + +/* + * authority is admin1 and range is global + */ +#define PLPKS_SED_AUTHORITY 0x0000000900010001 +#define PLPKS_SED_RANGE 0x0000080200000001 + +static void plpks_init_var(struct plpks_var *var, char *keyname) +{ + if (!plpks_sed_initialized) { + plpks_sed_initialized = true; + plpks_sed_available = plpks_is_available(); + if (!plpks_sed_available) + pr_err("SED: plpks not available\n"); + } + + var->name = keyname; + var->namelen = strlen(keyname); + if (strcmp(PLPKS_SED_KEY, keyname) == 0) { + var->name = PLPKS_SED_MANGLED_LABEL; + var->namelen = strlen(keyname); + } + var->policy = PLPKS_WORLDREADABLE; + var->os = PLPKS_VAR_COMMON; + var->data = NULL; + var->datalen = 0; + var->component = PLPKS_SED_COMPONENT; +} + +/* + * Read the SED Opal key from PLPKS given the label + */ +int sed_read_key(char *keyname, char *key, u_int *keylen) +{ + struct plpks_var var; + struct plpks_sed_object_data data; + int ret; + u_int len; + + plpks_init_var(&var, keyname); + + if (!plpks_sed_available) + return -EOPNOTSUPP; + + var.data = (u8 *)&data; + var.datalen = sizeof(data); + + ret = plpks_read_os_var(&var); + if (ret != 0) + return ret; + + len = min_t(u16, be32_to_cpu(data.key_len), var.datalen); + memcpy(key, data.key, len); + key[len] = '\0'; + *keylen = len; + + return 0; +} + +/* + * Write the SED Opal key to PLPKS given the label + */ +int sed_write_key(char *keyname, char *key, u_int keylen) +{ + struct plpks_var var; + struct plpks_sed_object_data data; + struct plpks_var_name vname; + + plpks_init_var(&var, keyname); + + if (!plpks_sed_available) + return -EOPNOTSUPP; + + var.datalen = sizeof(struct plpks_sed_object_data); + var.data = (u8 *)&data; + + /* initialize SED object */ + data.version = PLPKS_SED_OBJECT_DATA_V0; + data.authority = cpu_to_be64(PLPKS_SED_AUTHORITY); + data.range = cpu_to_be64(PLPKS_SED_RANGE); + memset(&data.pad1, '\0', sizeof(data.pad1)); + data.key_len = cpu_to_be32(keylen); + memcpy(data.key, (char *)key, keylen); + + /* + * Key update requires remove first. The return value + * is ignored since it's okay if the key doesn't exist. + */ + vname.namelen = var.namelen; + vname.name = var.name; + plpks_remove_var(var.component, var.os, vname); + + return plpks_write_var(var); +} diff --git a/block/Kconfig b/block/Kconfig index f1364d1c0d93..55ae2286a4de 100644 --- a/block/Kconfig +++ b/block/Kconfig @@ -186,6 +186,7 @@ config BLK_SED_OPAL bool "Logic for interfacing with Opal enabled SEDs" depends on KEYS select PSERIES_PLPKS if PPC_PSERIES + select PSERIES_PLPKS_SED if PPC_PSERIES help Builds Logic for interfacing with Opal enabled controllers. Enabling this option enables users to setup/unlock/lock diff --git a/block/badblocks.c b/block/badblocks.c index 3afb550c0f7b..fc92d4e18aa3 100644 --- a/block/badblocks.c +++ b/block/badblocks.c @@ -16,119 +16,830 @@ #include <linux/types.h> #include <linux/slab.h> -/** - * badblocks_check() - check a given range for bad sectors - * @bb: the badblocks structure that holds all badblock information - * @s: sector (start) at which to check for badblocks - * @sectors: number of sectors to check for badblocks - * @first_bad: pointer to store location of the first badblock - * @bad_sectors: pointer to store number of badblocks after @first_bad +/* + * The purpose of badblocks set/clear is to manage bad blocks ranges which are + * identified by LBA addresses. * - * We can record which blocks on each device are 'bad' and so just - * fail those blocks, or that stripe, rather than the whole device. - * Entries in the bad-block table are 64bits wide. This comprises: - * Length of bad-range, in sectors: 0-511 for lengths 1-512 - * Start of bad-range, sector offset, 54 bits (allows 8 exbibytes) - * A 'shift' can be set so that larger blocks are tracked and - * consequently larger devices can be covered. - * 'Acknowledged' flag - 1 bit. - the most significant bit. + * When the caller of badblocks_set() wants to set a range of bad blocks, the + * setting range can be acked or unacked. And the setting range may merge, + * overwrite, skip the overlapped already set range, depends on who they are + * overlapped or adjacent, and the acknowledgment type of the ranges. It can be + * more complicated when the setting range covers multiple already set bad block + * ranges, with restrictions of maximum length of each bad range and the bad + * table space limitation. * - * Locking of the bad-block table uses a seqlock so badblocks_check - * might need to retry if it is very unlucky. - * We will sometimes want to check for bad blocks in a bi_end_io function, - * so we use the write_seqlock_irq variant. + * It is difficult and unnecessary to take care of all the possible situations, + * for setting a large range of bad blocks, we can handle it by dividing the + * large range into smaller ones when encounter overlap, max range length or + * bad table full conditions. Every time only a smaller piece of the bad range + * is handled with a limited number of conditions how it is interacted with + * possible overlapped or adjacent already set bad block ranges. Then the hard + * complicated problem can be much simpler to handle in proper way. * - * When looking for a bad block we specify a range and want to - * know if any block in the range is bad. So we binary-search - * to the last range that starts at-or-before the given endpoint, - * (or "before the sector after the target range") - * then see if it ends after the given start. + * When setting a range of bad blocks to the bad table, the simplified situations + * to be considered are, (The already set bad blocks ranges are naming with + * prefix E, and the setting bad blocks range is naming with prefix S) * - * Return: - * 0: there are no known bad blocks in the range - * 1: there are known bad block which are all acknowledged - * -1: there are bad blocks which have not yet been acknowledged in metadata. - * plus the start/length of the first bad section we overlap. + * 1) A setting range is not overlapped or adjacent to any other already set bad + * block range. + * +--------+ + * | S | + * +--------+ + * +-------------+ +-------------+ + * | E1 | | E2 | + * +-------------+ +-------------+ + * For this situation if the bad blocks table is not full, just allocate a + * free slot from the bad blocks table to mark the setting range S. The + * result is, + * +-------------+ +--------+ +-------------+ + * | E1 | | S | | E2 | + * +-------------+ +--------+ +-------------+ + * 2) A setting range starts exactly at a start LBA of an already set bad blocks + * range. + * 2.1) The setting range size < already set range size + * +--------+ + * | S | + * +--------+ + * +-------------+ + * | E | + * +-------------+ + * 2.1.1) If S and E are both acked or unacked range, the setting range S can + * be merged into existing bad range E. The result is, + * +-------------+ + * | S | + * +-------------+ + * 2.1.2) If S is unacked setting and E is acked, the setting will be denied, and + * the result is, + * +-------------+ + * | E | + * +-------------+ + * 2.1.3) If S is acked setting and E is unacked, range S can overwrite on E. + * An extra slot from the bad blocks table will be allocated for S, and head + * of E will move to end of the inserted range S. The result is, + * +--------+----+ + * | S | E | + * +--------+----+ + * 2.2) The setting range size == already set range size + * 2.2.1) If S and E are both acked or unacked range, the setting range S can + * be merged into existing bad range E. The result is, + * +-------------+ + * | S | + * +-------------+ + * 2.2.2) If S is unacked setting and E is acked, the setting will be denied, and + * the result is, + * +-------------+ + * | E | + * +-------------+ + * 2.2.3) If S is acked setting and E is unacked, range S can overwrite all of + bad blocks range E. The result is, + * +-------------+ + * | S | + * +-------------+ + * 2.3) The setting range size > already set range size + * +-------------------+ + * | S | + * +-------------------+ + * +-------------+ + * | E | + * +-------------+ + * For such situation, the setting range S can be treated as two parts, the + * first part (S1) is as same size as the already set range E, the second + * part (S2) is the rest of setting range. + * +-------------+-----+ +-------------+ +-----+ + * | S1 | S2 | | S1 | | S2 | + * +-------------+-----+ ===> +-------------+ +-----+ + * +-------------+ +-------------+ + * | E | | E | + * +-------------+ +-------------+ + * Now we only focus on how to handle the setting range S1 and already set + * range E, which are already explained in 2.2), for the rest S2 it will be + * handled later in next loop. + * 3) A setting range starts before the start LBA of an already set bad blocks + * range. + * +-------------+ + * | S | + * +-------------+ + * +-------------+ + * | E | + * +-------------+ + * For this situation, the setting range S can be divided into two parts, the + * first (S1) ends at the start LBA of already set range E, the second part + * (S2) starts exactly at a start LBA of the already set range E. + * +----+---------+ +----+ +---------+ + * | S1 | S2 | | S1 | | S2 | + * +----+---------+ ===> +----+ +---------+ + * +-------------+ +-------------+ + * | E | | E | + * +-------------+ +-------------+ + * Now only the first part S1 should be handled in this loop, which is in + * similar condition as 1). The rest part S2 has exact same start LBA address + * of the already set range E, they will be handled in next loop in one of + * situations in 2). + * 4) A setting range starts after the start LBA of an already set bad blocks + * range. + * 4.1) If the setting range S exactly matches the tail part of already set bad + * blocks range E, like the following chart shows, + * +---------+ + * | S | + * +---------+ + * +-------------+ + * | E | + * +-------------+ + * 4.1.1) If range S and E have same acknowledge value (both acked or unacked), + * they will be merged into one, the result is, + * +-------------+ + * | S | + * +-------------+ + * 4.1.2) If range E is acked and the setting range S is unacked, the setting + * request of S will be rejected, the result is, + * +-------------+ + * | E | + * +-------------+ + * 4.1.3) If range E is unacked, and the setting range S is acked, then S may + * overwrite the overlapped range of E, the result is, + * +---+---------+ + * | E | S | + * +---+---------+ + * 4.2) If the setting range S stays in middle of an already set range E, like + * the following chart shows, + * +----+ + * | S | + * +----+ + * +--------------+ + * | E | + * +--------------+ + * 4.2.1) If range S and E have same acknowledge value (both acked or unacked), + * they will be merged into one, the result is, + * +--------------+ + * | S | + * +--------------+ + * 4.2.2) If range E is acked and the setting range S is unacked, the setting + * request of S will be rejected, the result is also, + * +--------------+ + * | E | + * +--------------+ + * 4.2.3) If range E is unacked, and the setting range S is acked, then S will + * inserted into middle of E and split previous range E into two parts (E1 + * and E2), the result is, + * +----+----+----+ + * | E1 | S | E2 | + * +----+----+----+ + * 4.3) If the setting bad blocks range S is overlapped with an already set bad + * blocks range E. The range S starts after the start LBA of range E, and + * ends after the end LBA of range E, as the following chart shows, + * +-------------------+ + * | S | + * +-------------------+ + * +-------------+ + * | E | + * +-------------+ + * For this situation the range S can be divided into two parts, the first + * part (S1) ends at end range E, and the second part (S2) has rest range of + * origin S. + * +---------+---------+ +---------+ +---------+ + * | S1 | S2 | | S1 | | S2 | + * +---------+---------+ ===> +---------+ +---------+ + * +-------------+ +-------------+ + * | E | | E | + * +-------------+ +-------------+ + * Now in this loop the setting range S1 and already set range E can be + * handled as the situations 4.1), the rest range S2 will be handled in next + * loop and ignored in this loop. + * 5) A setting bad blocks range S is adjacent to one or more already set bad + * blocks range(s), and they are all acked or unacked range. + * 5.1) Front merge: If the already set bad blocks range E is before setting + * range S and they are adjacent, + * +------+ + * | S | + * +------+ + * +-------+ + * | E | + * +-------+ + * 5.1.1) When total size of range S and E <= BB_MAX_LEN, and their acknowledge + * values are same, the setting range S can front merges into range E. The + * result is, + * +--------------+ + * | S | + * +--------------+ + * 5.1.2) Otherwise these two ranges cannot merge, just insert the setting + * range S right after already set range E into the bad blocks table. The + * result is, + * +--------+------+ + * | E | S | + * +--------+------+ + * 6) Special cases which above conditions cannot handle + * 6.1) Multiple already set ranges may merge into less ones in a full bad table + * +-------------------------------------------------------+ + * | S | + * +-------------------------------------------------------+ + * |<----- BB_MAX_LEN ----->| + * +-----+ +-----+ +-----+ + * | E1 | | E2 | | E3 | + * +-----+ +-----+ +-----+ + * In the above example, when the bad blocks table is full, inserting the + * first part of setting range S will fail because no more available slot + * can be allocated from bad blocks table. In this situation a proper + * setting method should be go though all the setting bad blocks range and + * look for chance to merge already set ranges into less ones. When there + * is available slot from bad blocks table, re-try again to handle more + * setting bad blocks ranges as many as possible. + * +------------------------+ + * | S3 | + * +------------------------+ + * |<----- BB_MAX_LEN ----->| + * +-----+-----+-----+---+-----+--+ + * | S1 | S2 | + * +-----+-----+-----+---+-----+--+ + * The above chart shows although the first part (S3) cannot be inserted due + * to no-space in bad blocks table, but the following E1, E2 and E3 ranges + * can be merged with rest part of S into less range S1 and S2. Now there is + * 1 free slot in bad blocks table. + * +------------------------+-----+-----+-----+---+-----+--+ + * | S3 | S1 | S2 | + * +------------------------+-----+-----+-----+---+-----+--+ + * Since the bad blocks table is not full anymore, re-try again for the + * origin setting range S. Now the setting range S3 can be inserted into the + * bad blocks table with previous freed slot from multiple ranges merge. + * 6.2) Front merge after overwrite + * In the following example, in bad blocks table, E1 is an acked bad blocks + * range and E2 is an unacked bad blocks range, therefore they are not able + * to merge into a larger range. The setting bad blocks range S is acked, + * therefore part of E2 can be overwritten by S. + * +--------+ + * | S | acknowledged + * +--------+ S: 1 + * +-------+-------------+ E1: 1 + * | E1 | E2 | E2: 0 + * +-------+-------------+ + * With previous simplified routines, after overwriting part of E2 with S, + * the bad blocks table should be (E3 is remaining part of E2 which is not + * overwritten by S), + * acknowledged + * +-------+--------+----+ S: 1 + * | E1 | S | E3 | E1: 1 + * +-------+--------+----+ E3: 0 + * The above result is correct but not perfect. Range E1 and S in the bad + * blocks table are all acked, merging them into a larger one range may + * occupy less bad blocks table space and make badblocks_check() faster. + * Therefore in such situation, after overwriting range S, the previous range + * E1 should be checked for possible front combination. Then the ideal + * result can be, + * +----------------+----+ acknowledged + * | E1 | E3 | E1: 1 + * +----------------+----+ E3: 0 + * 6.3) Behind merge: If the already set bad blocks range E is behind the setting + * range S and they are adjacent. Normally we don't need to care about this + * because front merge handles this while going though range S from head to + * tail, except for the tail part of range S. When the setting range S are + * fully handled, all the above simplified routine doesn't check whether the + * tail LBA of range S is adjacent to the next already set range and not + * merge them even it is possible. + * +------+ + * | S | + * +------+ + * +-------+ + * | E | + * +-------+ + * For the above special situation, when the setting range S are all handled + * and the loop ends, an extra check is necessary for whether next already + * set range E is right after S and mergeable. + * 6.3.1) When total size of range E and S <= BB_MAX_LEN, and their acknowledge + * values are same, the setting range S can behind merges into range E. The + * result is, + * +--------------+ + * | S | + * +--------------+ + * 6.3.2) Otherwise these two ranges cannot merge, just insert the setting range + * S in front of the already set range E in the bad blocks table. The result + * is, + * +------+-------+ + * | S | E | + * +------+-------+ + * + * All the above 5 simplified situations and 3 special cases may cover 99%+ of + * the bad block range setting conditions. Maybe there is some rare corner case + * is not considered and optimized, it won't hurt if badblocks_set() fails due + * to no space, or some ranges are not merged to save bad blocks table space. + * + * Inside badblocks_set() each loop starts by jumping to re_insert label, every + * time for the new loop prev_badblocks() is called to find an already set range + * which starts before or at current setting range. Since the setting bad blocks + * range is handled from head to tail, most of the cases it is unnecessary to do + * the binary search inside prev_badblocks(), it is possible to provide a hint + * to prev_badblocks() for a fast path, then the expensive binary search can be + * avoided. In my test with the hint to prev_badblocks(), except for the first + * loop, all rested calls to prev_badblocks() can go into the fast path and + * return correct bad blocks table index immediately. + * + * + * Clearing a bad blocks range from the bad block table has similar idea as + * setting does, but much more simpler. The only thing needs to be noticed is + * when the clearing range hits middle of a bad block range, the existing bad + * block range will split into two, and one more item should be added into the + * bad block table. The simplified situations to be considered are, (The already + * set bad blocks ranges in bad block table are naming with prefix E, and the + * clearing bad blocks range is naming with prefix C) + * + * 1) A clearing range is not overlapped to any already set ranges in bad block + * table. + * +-----+ | +-----+ | +-----+ + * | C | | | C | | | C | + * +-----+ or +-----+ or +-----+ + * +---+ | +----+ +----+ | +---+ + * | E | | | E1 | | E2 | | | E | + * +---+ | +----+ +----+ | +---+ + * For the above situations, no bad block to be cleared and no failure + * happens, simply returns 0. + * 2) The clearing range hits middle of an already setting bad blocks range in + * the bad block table. + * +---+ + * | C | + * +---+ + * +-----------------+ + * | E | + * +-----------------+ + * In this situation if the bad block table is not full, the range E will be + * split into two ranges E1 and E2. The result is, + * +------+ +------+ + * | E1 | | E2 | + * +------+ +------+ + * 3) The clearing range starts exactly at same LBA as an already set bad block range + * from the bad block table. + * 3.1) Partially covered at head part + * +------------+ + * | C | + * +------------+ + * +-----------------+ + * | E | + * +-----------------+ + * For this situation, the overlapped already set range will update the + * start LBA to end of C and shrink the range to BB_LEN(E) - BB_LEN(C). No + * item deleted from bad block table. The result is, + * +----+ + * | E1 | + * +----+ + * 3.2) Exact fully covered + * +-----------------+ + * | C | + * +-----------------+ + * +-----------------+ + * | E | + * +-----------------+ + * For this situation the whole bad blocks range E will be cleared and its + * corresponded item is deleted from the bad block table. + * 4) The clearing range exactly ends at same LBA as an already set bad block + * range. + * +-------+ + * | C | + * +-------+ + * +-----------------+ + * | E | + * +-----------------+ + * For the above situation, the already set range E is updated to shrink its + * end to the start of C, and reduce its length to BB_LEN(E) - BB_LEN(C). + * The result is, + * +---------+ + * | E | + * +---------+ + * 5) The clearing range is partially overlapped with an already set bad block + * range from the bad block table. + * 5.1) The already set bad block range is front overlapped with the clearing + * range. + * +----------+ + * | C | + * +----------+ + * +------------+ + * | E | + * +------------+ + * For such situation, the clearing range C can be treated as two parts. The + * first part ends at the start LBA of range E, and the second part starts at + * same LBA of range E. + * +----+-----+ +----+ +-----+ + * | C1 | C2 | | C1 | | C2 | + * +----+-----+ ===> +----+ +-----+ + * +------------+ +------------+ + * | E | | E | + * +------------+ +------------+ + * Now the first part C1 can be handled as condition 1), and the second part C2 can be + * handled as condition 3.1) in next loop. + * 5.2) The already set bad block range is behind overlaopped with the clearing + * range. + * +----------+ + * | C | + * +----------+ + * +------------+ + * | E | + * +------------+ + * For such situation, the clearing range C can be treated as two parts. The + * first part C1 ends at same end LBA of range E, and the second part starts + * at end LBA of range E. + * +----+-----+ +----+ +-----+ + * | C1 | C2 | | C1 | | C2 | + * +----+-----+ ===> +----+ +-----+ + * +------------+ +------------+ + * | E | | E | + * +------------+ +------------+ + * Now the first part clearing range C1 can be handled as condition 4), and + * the second part clearing range C2 can be handled as condition 1) in next + * loop. + * + * All bad blocks range clearing can be simplified into the above 5 situations + * by only handling the head part of the clearing range in each run of the + * while-loop. The idea is similar to bad blocks range setting but much + * simpler. */ -int badblocks_check(struct badblocks *bb, sector_t s, int sectors, - sector_t *first_bad, int *bad_sectors) + +/* + * Find the range starts at-or-before 's' from bad table. The search + * starts from index 'hint' and stops at index 'hint_end' from the bad + * table. + */ +static int prev_by_hint(struct badblocks *bb, sector_t s, int hint) { - int hi; - int lo; + int hint_end = hint + 2; u64 *p = bb->page; - int rv; - sector_t target = s + sectors; - unsigned seq; + int ret = -1; - if (bb->shift > 0) { - /* round the start down, and the end up */ - s >>= bb->shift; - target += (1<<bb->shift) - 1; - target >>= bb->shift; + while ((hint < hint_end) && ((hint + 1) <= bb->count) && + (BB_OFFSET(p[hint]) <= s)) { + if ((hint + 1) == bb->count || BB_OFFSET(p[hint + 1]) > s) { + ret = hint; + break; + } + hint++; + } + + return ret; +} + +/* + * Find the range starts at-or-before bad->start. If 'hint' is provided + * (hint >= 0) then search in the bad table from hint firstly. It is + * very probably the wanted bad range can be found from the hint index, + * then the unnecessary while-loop iteration can be avoided. + */ +static int prev_badblocks(struct badblocks *bb, struct badblocks_context *bad, + int hint) +{ + sector_t s = bad->start; + int ret = -1; + int lo, hi; + u64 *p; + + if (!bb->count) + goto out; + + if (hint >= 0) { + ret = prev_by_hint(bb, s, hint); + if (ret >= 0) + goto out; } - /* 'target' is now the first block after the bad range */ -retry: - seq = read_seqbegin(&bb->lock); lo = 0; - rv = 0; hi = bb->count; + p = bb->page; - /* Binary search between lo and hi for 'target' - * i.e. for the last range that starts before 'target' - */ - /* INVARIANT: ranges before 'lo' and at-or-after 'hi' - * are known not to be the last range before target. - * VARIANT: hi-lo is the number of possible - * ranges, and decreases until it reaches 1 - */ + /* The following bisect search might be unnecessary */ + if (BB_OFFSET(p[lo]) > s) + return -1; + if (BB_OFFSET(p[hi - 1]) <= s) + return hi - 1; + + /* Do bisect search in bad table */ while (hi - lo > 1) { - int mid = (lo + hi) / 2; + int mid = (lo + hi)/2; sector_t a = BB_OFFSET(p[mid]); - if (a < target) - /* This could still be the one, earlier ranges - * could not. - */ + if (a == s) { + ret = mid; + goto out; + } + + if (a < s) lo = mid; else - /* This and later ranges are definitely out. */ hi = mid; } - /* 'lo' might be the last that started before target, but 'hi' isn't */ - if (hi > lo) { - /* need to check all range that end after 's' to see if - * any are unacknowledged. + + if (BB_OFFSET(p[lo]) <= s) + ret = lo; +out: + return ret; +} + +/* + * Return 'true' if the range indicated by 'bad' can be backward merged + * with the bad range (from the bad table) index by 'behind'. + */ +static bool can_merge_behind(struct badblocks *bb, + struct badblocks_context *bad, int behind) +{ + sector_t sectors = bad->len; + sector_t s = bad->start; + u64 *p = bb->page; + + if ((s < BB_OFFSET(p[behind])) && + ((s + sectors) >= BB_OFFSET(p[behind])) && + ((BB_END(p[behind]) - s) <= BB_MAX_LEN) && + BB_ACK(p[behind]) == bad->ack) + return true; + return false; +} + +/* + * Do backward merge for range indicated by 'bad' and the bad range + * (from the bad table) indexed by 'behind'. The return value is merged + * sectors from bad->len. + */ +static int behind_merge(struct badblocks *bb, struct badblocks_context *bad, + int behind) +{ + sector_t sectors = bad->len; + sector_t s = bad->start; + u64 *p = bb->page; + int merged = 0; + + WARN_ON(s >= BB_OFFSET(p[behind])); + WARN_ON((s + sectors) < BB_OFFSET(p[behind])); + + if (s < BB_OFFSET(p[behind])) { + merged = BB_OFFSET(p[behind]) - s; + p[behind] = BB_MAKE(s, BB_LEN(p[behind]) + merged, bad->ack); + + WARN_ON((BB_LEN(p[behind]) + merged) >= BB_MAX_LEN); + } + + return merged; +} + +/* + * Return 'true' if the range indicated by 'bad' can be forward + * merged with the bad range (from the bad table) indexed by 'prev'. + */ +static bool can_merge_front(struct badblocks *bb, int prev, + struct badblocks_context *bad) +{ + sector_t s = bad->start; + u64 *p = bb->page; + + if (BB_ACK(p[prev]) == bad->ack && + (s < BB_END(p[prev]) || + (s == BB_END(p[prev]) && (BB_LEN(p[prev]) < BB_MAX_LEN)))) + return true; + return false; +} + +/* + * Do forward merge for range indicated by 'bad' and the bad range + * (from bad table) indexed by 'prev'. The return value is sectors + * merged from bad->len. + */ +static int front_merge(struct badblocks *bb, int prev, struct badblocks_context *bad) +{ + sector_t sectors = bad->len; + sector_t s = bad->start; + u64 *p = bb->page; + int merged = 0; + + WARN_ON(s > BB_END(p[prev])); + + if (s < BB_END(p[prev])) { + merged = min_t(sector_t, sectors, BB_END(p[prev]) - s); + } else { + merged = min_t(sector_t, sectors, BB_MAX_LEN - BB_LEN(p[prev])); + if ((prev + 1) < bb->count && + merged > (BB_OFFSET(p[prev + 1]) - BB_END(p[prev]))) { + merged = BB_OFFSET(p[prev + 1]) - BB_END(p[prev]); + } + + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + BB_LEN(p[prev]) + merged, bad->ack); + } + + return merged; +} + +/* + * 'Combine' is a special case which can_merge_front() is not able to + * handle: If a bad range (indexed by 'prev' from bad table) exactly + * starts as bad->start, and the bad range ahead of 'prev' (indexed by + * 'prev - 1' from bad table) exactly ends at where 'prev' starts, and + * the sum of their lengths does not exceed BB_MAX_LEN limitation, then + * these two bad range (from bad table) can be combined. + * + * Return 'true' if bad ranges indexed by 'prev' and 'prev - 1' from bad + * table can be combined. + */ +static bool can_combine_front(struct badblocks *bb, int prev, + struct badblocks_context *bad) +{ + u64 *p = bb->page; + + if ((prev > 0) && + (BB_OFFSET(p[prev]) == bad->start) && + (BB_END(p[prev - 1]) == BB_OFFSET(p[prev])) && + (BB_LEN(p[prev - 1]) + BB_LEN(p[prev]) <= BB_MAX_LEN) && + (BB_ACK(p[prev - 1]) == BB_ACK(p[prev]))) + return true; + return false; +} + +/* + * Combine the bad ranges indexed by 'prev' and 'prev - 1' (from bad + * table) into one larger bad range, and the new range is indexed by + * 'prev - 1'. + * The caller of front_combine() will decrease bb->count, therefore + * it is unnecessary to clear p[perv] after front merge. + */ +static void front_combine(struct badblocks *bb, int prev) +{ + u64 *p = bb->page; + + p[prev - 1] = BB_MAKE(BB_OFFSET(p[prev - 1]), + BB_LEN(p[prev - 1]) + BB_LEN(p[prev]), + BB_ACK(p[prev])); + if ((prev + 1) < bb->count) + memmove(p + prev, p + prev + 1, (bb->count - prev - 1) * 8); +} + +/* + * Return 'true' if the range indicated by 'bad' is exactly forward + * overlapped with the bad range (from bad table) indexed by 'front'. + * Exactly forward overlap means the bad range (from bad table) indexed + * by 'prev' does not cover the whole range indicated by 'bad'. + */ +static bool overlap_front(struct badblocks *bb, int front, + struct badblocks_context *bad) +{ + u64 *p = bb->page; + + if (bad->start >= BB_OFFSET(p[front]) && + bad->start < BB_END(p[front])) + return true; + return false; +} + +/* + * Return 'true' if the range indicated by 'bad' is exactly backward + * overlapped with the bad range (from bad table) indexed by 'behind'. + */ +static bool overlap_behind(struct badblocks *bb, struct badblocks_context *bad, + int behind) +{ + u64 *p = bb->page; + + if (bad->start < BB_OFFSET(p[behind]) && + (bad->start + bad->len) > BB_OFFSET(p[behind])) + return true; + return false; +} + +/* + * Return 'true' if the range indicated by 'bad' can overwrite the bad + * range (from bad table) indexed by 'prev'. + * + * The range indicated by 'bad' can overwrite the bad range indexed by + * 'prev' when, + * 1) The whole range indicated by 'bad' can cover partial or whole bad + * range (from bad table) indexed by 'prev'. + * 2) The ack value of 'bad' is larger or equal to the ack value of bad + * range 'prev'. + * + * If the overwriting doesn't cover the whole bad range (from bad table) + * indexed by 'prev', new range might be split from existing bad range, + * 1) The overwrite covers head or tail part of existing bad range, 1 + * extra bad range will be split and added into the bad table. + * 2) The overwrite covers middle of existing bad range, 2 extra bad + * ranges will be split (ahead and after the overwritten range) and + * added into the bad table. + * The number of extra split ranges of the overwriting is stored in + * 'extra' and returned for the caller. + */ +static bool can_front_overwrite(struct badblocks *bb, int prev, + struct badblocks_context *bad, int *extra) +{ + u64 *p = bb->page; + int len; + + WARN_ON(!overlap_front(bb, prev, bad)); + + if (BB_ACK(p[prev]) >= bad->ack) + return false; + + if (BB_END(p[prev]) <= (bad->start + bad->len)) { + len = BB_END(p[prev]) - bad->start; + if (BB_OFFSET(p[prev]) == bad->start) + *extra = 0; + else + *extra = 1; + + bad->len = len; + } else { + if (BB_OFFSET(p[prev]) == bad->start) + *extra = 1; + else + /* + * prev range will be split into two, beside the overwritten + * one, an extra slot needed from bad table. */ - while (lo >= 0 && - BB_OFFSET(p[lo]) + BB_LEN(p[lo]) > s) { - if (BB_OFFSET(p[lo]) < target) { - /* starts before the end, and finishes after - * the start, so they must overlap - */ - if (rv != -1 && BB_ACK(p[lo])) - rv = 1; - else - rv = -1; - *first_bad = BB_OFFSET(p[lo]); - *bad_sectors = BB_LEN(p[lo]); - } - lo--; + *extra = 2; + } + + if ((bb->count + (*extra)) >= MAX_BADBLOCKS) + return false; + + return true; +} + +/* + * Do the overwrite from the range indicated by 'bad' to the bad range + * (from bad table) indexed by 'prev'. + * The previously called can_front_overwrite() will provide how many + * extra bad range(s) might be split and added into the bad table. All + * the splitting cases in the bad table will be handled here. + */ +static int front_overwrite(struct badblocks *bb, int prev, + struct badblocks_context *bad, int extra) +{ + u64 *p = bb->page; + sector_t orig_end = BB_END(p[prev]); + int orig_ack = BB_ACK(p[prev]); + + switch (extra) { + case 0: + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), BB_LEN(p[prev]), + bad->ack); + break; + case 1: + if (BB_OFFSET(p[prev]) == bad->start) { + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + bad->len, bad->ack); + memmove(p + prev + 2, p + prev + 1, + (bb->count - prev - 1) * 8); + p[prev + 1] = BB_MAKE(bad->start + bad->len, + orig_end - BB_END(p[prev]), + orig_ack); + } else { + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + bad->start - BB_OFFSET(p[prev]), + orig_ack); + /* + * prev +2 -> prev + 1 + 1, which is for, + * 1) prev + 1: the slot index of the previous one + * 2) + 1: one more slot for extra being 1. + */ + memmove(p + prev + 2, p + prev + 1, + (bb->count - prev - 1) * 8); + p[prev + 1] = BB_MAKE(bad->start, bad->len, bad->ack); } + break; + case 2: + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + bad->start - BB_OFFSET(p[prev]), + orig_ack); + /* + * prev + 3 -> prev + 1 + 2, which is for, + * 1) prev + 1: the slot index of the previous one + * 2) + 2: two more slots for extra being 2. + */ + memmove(p + prev + 3, p + prev + 1, + (bb->count - prev - 1) * 8); + p[prev + 1] = BB_MAKE(bad->start, bad->len, bad->ack); + p[prev + 2] = BB_MAKE(BB_END(p[prev + 1]), + orig_end - BB_END(p[prev + 1]), + orig_ack); + break; + default: + break; } - if (read_seqretry(&bb->lock, seq)) - goto retry; + return bad->len; +} - return rv; +/* + * Explicitly insert a range indicated by 'bad' to the bad table, where + * the location is indexed by 'at'. + */ +static int insert_at(struct badblocks *bb, int at, struct badblocks_context *bad) +{ + u64 *p = bb->page; + int len; + + WARN_ON(badblocks_full(bb)); + + len = min_t(sector_t, bad->len, BB_MAX_LEN); + if (at < bb->count) + memmove(p + at + 1, p + at, (bb->count - at) * 8); + p[at] = BB_MAKE(bad->start, len, bad->ack); + + return len; } -EXPORT_SYMBOL_GPL(badblocks_check); static void badblocks_update_acked(struct badblocks *bb) { + bool unacked = false; u64 *p = bb->page; int i; - bool unacked = false; if (!bb->unacked_exist) return; @@ -144,281 +855,600 @@ static void badblocks_update_acked(struct badblocks *bb) bb->unacked_exist = 0; } -/** - * badblocks_set() - Add a range of bad blocks to the table. - * @bb: the badblocks structure that holds all badblock information - * @s: first sector to mark as bad - * @sectors: number of sectors to mark as bad - * @acknowledged: weather to mark the bad sectors as acknowledged - * - * This might extend the table, or might contract it if two adjacent ranges - * can be merged. We binary-search to find the 'insertion' point, then - * decide how best to handle it. - * - * Return: - * 0: success - * 1: failed to set badblocks (out of space) - */ -int badblocks_set(struct badblocks *bb, sector_t s, int sectors, - int acknowledged) +/* Do exact work to set bad block range into the bad block table */ +static int _badblocks_set(struct badblocks *bb, sector_t s, int sectors, + int acknowledged) { - u64 *p; - int lo, hi; - int rv = 0; + int retried = 0, space_desired = 0; + int orig_len, len = 0, added = 0; + struct badblocks_context bad; + int prev = -1, hint = -1; + sector_t orig_start; unsigned long flags; + int rv = 0; + u64 *p; if (bb->shift < 0) /* badblocks are disabled */ return 1; + if (sectors == 0) + /* Invalid sectors number */ + return 1; + if (bb->shift) { /* round the start down, and the end up */ sector_t next = s + sectors; - s >>= bb->shift; - next += (1<<bb->shift) - 1; - next >>= bb->shift; + rounddown(s, bb->shift); + roundup(next, bb->shift); sectors = next - s; } write_seqlock_irqsave(&bb->lock, flags); + orig_start = s; + orig_len = sectors; + bad.ack = acknowledged; p = bb->page; - lo = 0; - hi = bb->count; - /* Find the last range that starts at-or-before 's' */ - while (hi - lo > 1) { - int mid = (lo + hi) / 2; - sector_t a = BB_OFFSET(p[mid]); - if (a <= s) - lo = mid; - else - hi = mid; +re_insert: + bad.start = s; + bad.len = sectors; + len = 0; + + if (badblocks_empty(bb)) { + len = insert_at(bb, 0, &bad); + bb->count++; + added++; + goto update_sectors; } - if (hi > lo && BB_OFFSET(p[lo]) > s) - hi = lo; - if (hi > lo) { - /* we found a range that might merge with the start - * of our new range - */ - sector_t a = BB_OFFSET(p[lo]); - sector_t e = a + BB_LEN(p[lo]); - int ack = BB_ACK(p[lo]); - - if (e >= s) { - /* Yes, we can merge with a previous range */ - if (s == a && s + sectors >= e) - /* new range covers old */ - ack = acknowledged; - else - ack = ack && acknowledged; - - if (e < s + sectors) - e = s + sectors; - if (e - a <= BB_MAX_LEN) { - p[lo] = BB_MAKE(a, e-a, ack); - s = e; + prev = prev_badblocks(bb, &bad, hint); + + /* start before all badblocks */ + if (prev < 0) { + if (!badblocks_full(bb)) { + /* insert on the first */ + if (bad.len > (BB_OFFSET(p[0]) - bad.start)) + bad.len = BB_OFFSET(p[0]) - bad.start; + len = insert_at(bb, 0, &bad); + bb->count++; + added++; + hint = 0; + goto update_sectors; + } + + /* No sapce, try to merge */ + if (overlap_behind(bb, &bad, 0)) { + if (can_merge_behind(bb, &bad, 0)) { + len = behind_merge(bb, &bad, 0); + added++; } else { - /* does not all fit in one range, - * make p[lo] maximal - */ - if (BB_LEN(p[lo]) != BB_MAX_LEN) - p[lo] = BB_MAKE(a, BB_MAX_LEN, ack); - s = a + BB_MAX_LEN; + len = BB_OFFSET(p[0]) - s; + space_desired = 1; } - sectors = e - s; + hint = 0; + goto update_sectors; } + + /* no table space and give up */ + goto out; } - if (sectors && hi < bb->count) { - /* 'hi' points to the first range that starts after 's'. - * Maybe we can merge with the start of that range - */ - sector_t a = BB_OFFSET(p[hi]); - sector_t e = a + BB_LEN(p[hi]); - int ack = BB_ACK(p[hi]); - - if (a <= s + sectors) { - /* merging is possible */ - if (e <= s + sectors) { - /* full overlap */ - e = s + sectors; - ack = acknowledged; - } else - ack = ack && acknowledged; - - a = s; - if (e - a <= BB_MAX_LEN) { - p[hi] = BB_MAKE(a, e-a, ack); - s = e; - } else { - p[hi] = BB_MAKE(a, BB_MAX_LEN, ack); - s = a + BB_MAX_LEN; + + /* in case p[prev-1] can be merged with p[prev] */ + if (can_combine_front(bb, prev, &bad)) { + front_combine(bb, prev); + bb->count--; + added++; + hint = prev; + goto update_sectors; + } + + if (overlap_front(bb, prev, &bad)) { + if (can_merge_front(bb, prev, &bad)) { + len = front_merge(bb, prev, &bad); + added++; + } else { + int extra = 0; + + if (!can_front_overwrite(bb, prev, &bad, &extra)) { + len = min_t(sector_t, + BB_END(p[prev]) - s, sectors); + hint = prev; + goto update_sectors; + } + + len = front_overwrite(bb, prev, &bad, extra); + added++; + bb->count += extra; + + if (can_combine_front(bb, prev, &bad)) { + front_combine(bb, prev); + bb->count--; } - sectors = e - s; - lo = hi; - hi++; } + hint = prev; + goto update_sectors; + } + + if (can_merge_front(bb, prev, &bad)) { + len = front_merge(bb, prev, &bad); + added++; + hint = prev; + goto update_sectors; } - if (sectors == 0 && hi < bb->count) { - /* we might be able to combine lo and hi */ - /* Note: 's' is at the end of 'lo' */ - sector_t a = BB_OFFSET(p[hi]); - int lolen = BB_LEN(p[lo]); - int hilen = BB_LEN(p[hi]); - int newlen = lolen + hilen - (s - a); - - if (s >= a && newlen < BB_MAX_LEN) { - /* yes, we can combine them */ - int ack = BB_ACK(p[lo]) && BB_ACK(p[hi]); - - p[lo] = BB_MAKE(BB_OFFSET(p[lo]), newlen, ack); - memmove(p + hi, p + hi + 1, - (bb->count - hi - 1) * 8); - bb->count--; + + /* if no space in table, still try to merge in the covered range */ + if (badblocks_full(bb)) { + /* skip the cannot-merge range */ + if (((prev + 1) < bb->count) && + overlap_behind(bb, &bad, prev + 1) && + ((s + sectors) >= BB_END(p[prev + 1]))) { + len = BB_END(p[prev + 1]) - s; + hint = prev + 1; + goto update_sectors; } + + /* no retry any more */ + len = sectors; + space_desired = 1; + hint = -1; + goto update_sectors; } - while (sectors) { - /* didn't merge (it all). - * Need to add a range just before 'hi' - */ - if (bb->count >= MAX_BADBLOCKS) { - /* No room for more */ - rv = 1; - break; - } else { - int this_sectors = sectors; - memmove(p + hi + 1, p + hi, - (bb->count - hi) * 8); - bb->count++; + /* cannot merge and there is space in bad table */ + if ((prev + 1) < bb->count && + overlap_behind(bb, &bad, prev + 1)) + bad.len = min_t(sector_t, + bad.len, BB_OFFSET(p[prev + 1]) - bad.start); - if (this_sectors > BB_MAX_LEN) - this_sectors = BB_MAX_LEN; - p[hi] = BB_MAKE(s, this_sectors, acknowledged); - sectors -= this_sectors; - s += this_sectors; - } + len = insert_at(bb, prev + 1, &bad); + bb->count++; + added++; + hint = prev + 1; + +update_sectors: + s += len; + sectors -= len; + + if (sectors > 0) + goto re_insert; + + WARN_ON(sectors < 0); + + /* + * Check whether the following already set range can be + * merged. (prev < 0) condition is not handled here, + * because it's already complicated enough. + */ + if (prev >= 0 && + (prev + 1) < bb->count && + BB_END(p[prev]) == BB_OFFSET(p[prev + 1]) && + (BB_LEN(p[prev]) + BB_LEN(p[prev + 1])) <= BB_MAX_LEN && + BB_ACK(p[prev]) == BB_ACK(p[prev + 1])) { + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + BB_LEN(p[prev]) + BB_LEN(p[prev + 1]), + BB_ACK(p[prev])); + + if ((prev + 2) < bb->count) + memmove(p + prev + 1, p + prev + 2, + (bb->count - (prev + 2)) * 8); + bb->count--; + } + + if (space_desired && !badblocks_full(bb)) { + s = orig_start; + sectors = orig_len; + space_desired = 0; + if (retried++ < 3) + goto re_insert; + } + +out: + if (added) { + set_changed(bb); + + if (!acknowledged) + bb->unacked_exist = 1; + else + badblocks_update_acked(bb); } - bb->changed = 1; - if (!acknowledged) - bb->unacked_exist = 1; - else - badblocks_update_acked(bb); write_sequnlock_irqrestore(&bb->lock, flags); + if (!added) + rv = 1; + return rv; } -EXPORT_SYMBOL_GPL(badblocks_set); -/** - * badblocks_clear() - Remove a range of bad blocks to the table. - * @bb: the badblocks structure that holds all badblock information - * @s: first sector to mark as bad - * @sectors: number of sectors to mark as bad - * - * This may involve extending the table if we spilt a region, - * but it must not fail. So if the table becomes full, we just - * drop the remove request. - * - * Return: - * 0: success - * 1: failed to clear badblocks +/* + * Clear the bad block range from bad block table which is front overlapped + * with the clearing range. The return value is how many sectors from an + * already set bad block range are cleared. If the whole bad block range is + * covered by the clearing range and fully cleared, 'delete' is set as 1 for + * the caller to reduce bb->count. */ -int badblocks_clear(struct badblocks *bb, sector_t s, int sectors) +static int front_clear(struct badblocks *bb, int prev, + struct badblocks_context *bad, int *deleted) { - u64 *p; - int lo, hi; - sector_t target = s + sectors; + sector_t sectors = bad->len; + sector_t s = bad->start; + u64 *p = bb->page; + int cleared = 0; + + *deleted = 0; + if (s == BB_OFFSET(p[prev])) { + if (BB_LEN(p[prev]) > sectors) { + p[prev] = BB_MAKE(BB_OFFSET(p[prev]) + sectors, + BB_LEN(p[prev]) - sectors, + BB_ACK(p[prev])); + cleared = sectors; + } else { + /* BB_LEN(p[prev]) <= sectors */ + cleared = BB_LEN(p[prev]); + if ((prev + 1) < bb->count) + memmove(p + prev, p + prev + 1, + (bb->count - prev - 1) * 8); + *deleted = 1; + } + } else if (s > BB_OFFSET(p[prev])) { + if (BB_END(p[prev]) <= (s + sectors)) { + cleared = BB_END(p[prev]) - s; + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + s - BB_OFFSET(p[prev]), + BB_ACK(p[prev])); + } else { + /* Splitting is handled in front_splitting_clear() */ + BUG(); + } + } + + return cleared; +} + +/* + * Handle the condition that the clearing range hits middle of an already set + * bad block range from bad block table. In this condition the existing bad + * block range is split into two after the middle part is cleared. + */ +static int front_splitting_clear(struct badblocks *bb, int prev, + struct badblocks_context *bad) +{ + u64 *p = bb->page; + u64 end = BB_END(p[prev]); + int ack = BB_ACK(p[prev]); + sector_t sectors = bad->len; + sector_t s = bad->start; + + p[prev] = BB_MAKE(BB_OFFSET(p[prev]), + s - BB_OFFSET(p[prev]), + ack); + memmove(p + prev + 2, p + prev + 1, (bb->count - prev - 1) * 8); + p[prev + 1] = BB_MAKE(s + sectors, end - s - sectors, ack); + return sectors; +} + +/* Do the exact work to clear bad block range from the bad block table */ +static int _badblocks_clear(struct badblocks *bb, sector_t s, int sectors) +{ + struct badblocks_context bad; + int prev = -1, hint = -1; + int len = 0, cleared = 0; int rv = 0; + u64 *p; + + if (bb->shift < 0) + /* badblocks are disabled */ + return 1; + + if (sectors == 0) + /* Invalid sectors number */ + return 1; + + if (bb->shift) { + sector_t target; - if (bb->shift > 0) { /* When clearing we round the start up and the end down. * This should not matter as the shift should align with * the block size and no rounding should ever be needed. * However it is better the think a block is bad when it * isn't than to think a block is not bad when it is. */ - s += (1<<bb->shift) - 1; - s >>= bb->shift; - target >>= bb->shift; + target = s + sectors; + roundup(s, bb->shift); + rounddown(target, bb->shift); + sectors = target - s; } write_seqlock_irq(&bb->lock); + bad.ack = true; p = bb->page; - lo = 0; - hi = bb->count; - /* Find the last range that starts before 'target' */ - while (hi - lo > 1) { - int mid = (lo + hi) / 2; - sector_t a = BB_OFFSET(p[mid]); - if (a < target) - lo = mid; - else - hi = mid; +re_clear: + bad.start = s; + bad.len = sectors; + + if (badblocks_empty(bb)) { + len = sectors; + cleared++; + goto update_sectors; } - if (hi > lo) { - /* p[lo] is the last range that could overlap the - * current range. Earlier ranges could also overlap, - * but only this one can overlap the end of the range. - */ - if ((BB_OFFSET(p[lo]) + BB_LEN(p[lo]) > target) && - (BB_OFFSET(p[lo]) < target)) { - /* Partial overlap, leave the tail of this range */ - int ack = BB_ACK(p[lo]); - sector_t a = BB_OFFSET(p[lo]); - sector_t end = a + BB_LEN(p[lo]); - - if (a < s) { - /* we need to split this range */ - if (bb->count >= MAX_BADBLOCKS) { - rv = -ENOSPC; - goto out; - } - memmove(p+lo+1, p+lo, (bb->count - lo) * 8); - bb->count++; - p[lo] = BB_MAKE(a, s-a, ack); - lo++; - } - p[lo] = BB_MAKE(target, end - target, ack); - /* there is no longer an overlap */ - hi = lo; - lo--; - } - while (lo >= 0 && - (BB_OFFSET(p[lo]) + BB_LEN(p[lo]) > s) && - (BB_OFFSET(p[lo]) < target)) { - /* This range does overlap */ - if (BB_OFFSET(p[lo]) < s) { - /* Keep the early parts of this range. */ - int ack = BB_ACK(p[lo]); - sector_t start = BB_OFFSET(p[lo]); - - p[lo] = BB_MAKE(start, s - start, ack); - /* now low doesn't overlap, so.. */ - break; - } - lo--; + + + prev = prev_badblocks(bb, &bad, hint); + + /* Start before all badblocks */ + if (prev < 0) { + if (overlap_behind(bb, &bad, 0)) { + len = BB_OFFSET(p[0]) - s; + hint = 0; + } else { + len = sectors; } - /* 'lo' is strictly before, 'hi' is strictly after, - * anything between needs to be discarded + /* + * Both situations are to clear non-bad range, + * should be treated as successful */ - if (hi - lo > 1) { - memmove(p+lo+1, p+hi, (bb->count - hi) * 8); - bb->count -= (hi - lo - 1); + cleared++; + goto update_sectors; + } + + /* Start after all badblocks */ + if ((prev + 1) >= bb->count && !overlap_front(bb, prev, &bad)) { + len = sectors; + cleared++; + goto update_sectors; + } + + /* Clear will split a bad record but the table is full */ + if (badblocks_full(bb) && (BB_OFFSET(p[prev]) < bad.start) && + (BB_END(p[prev]) > (bad.start + sectors))) { + len = sectors; + goto update_sectors; + } + + if (overlap_front(bb, prev, &bad)) { + if ((BB_OFFSET(p[prev]) < bad.start) && + (BB_END(p[prev]) > (bad.start + bad.len))) { + /* Splitting */ + if ((bb->count + 1) < MAX_BADBLOCKS) { + len = front_splitting_clear(bb, prev, &bad); + bb->count += 1; + cleared++; + } else { + /* No space to split, give up */ + len = sectors; + } + } else { + int deleted = 0; + + len = front_clear(bb, prev, &bad, &deleted); + bb->count -= deleted; + cleared++; + hint = prev; } + + goto update_sectors; + } + + /* Not front overlap, but behind overlap */ + if ((prev + 1) < bb->count && overlap_behind(bb, &bad, prev + 1)) { + len = BB_OFFSET(p[prev + 1]) - bad.start; + hint = prev + 1; + /* Clear non-bad range should be treated as successful */ + cleared++; + goto update_sectors; + } + + /* Not cover any badblocks range in the table */ + len = sectors; + /* Clear non-bad range should be treated as successful */ + cleared++; + +update_sectors: + s += len; + sectors -= len; + + if (sectors > 0) + goto re_clear; + + WARN_ON(sectors < 0); + + if (cleared) { + badblocks_update_acked(bb); + set_changed(bb); } - badblocks_update_acked(bb); - bb->changed = 1; -out: write_sequnlock_irq(&bb->lock); + + if (!cleared) + rv = 1; + return rv; } + +/* Do the exact work to check bad blocks range from the bad block table */ +static int _badblocks_check(struct badblocks *bb, sector_t s, int sectors, + sector_t *first_bad, int *bad_sectors) +{ + int unacked_badblocks, acked_badblocks; + int prev = -1, hint = -1, set = 0; + struct badblocks_context bad; + unsigned int seq; + int len, rv; + u64 *p; + + WARN_ON(bb->shift < 0 || sectors == 0); + + if (bb->shift > 0) { + sector_t target; + + /* round the start down, and the end up */ + target = s + sectors; + rounddown(s, bb->shift); + roundup(target, bb->shift); + sectors = target - s; + } + +retry: + seq = read_seqbegin(&bb->lock); + + p = bb->page; + unacked_badblocks = 0; + acked_badblocks = 0; + +re_check: + bad.start = s; + bad.len = sectors; + + if (badblocks_empty(bb)) { + len = sectors; + goto update_sectors; + } + + prev = prev_badblocks(bb, &bad, hint); + + /* start after all badblocks */ + if ((prev + 1) >= bb->count && !overlap_front(bb, prev, &bad)) { + len = sectors; + goto update_sectors; + } + + if (overlap_front(bb, prev, &bad)) { + if (BB_ACK(p[prev])) + acked_badblocks++; + else + unacked_badblocks++; + + if (BB_END(p[prev]) >= (s + sectors)) + len = sectors; + else + len = BB_END(p[prev]) - s; + + if (set == 0) { + *first_bad = BB_OFFSET(p[prev]); + *bad_sectors = BB_LEN(p[prev]); + set = 1; + } + goto update_sectors; + } + + /* Not front overlap, but behind overlap */ + if ((prev + 1) < bb->count && overlap_behind(bb, &bad, prev + 1)) { + len = BB_OFFSET(p[prev + 1]) - bad.start; + hint = prev + 1; + goto update_sectors; + } + + /* not cover any badblocks range in the table */ + len = sectors; + +update_sectors: + s += len; + sectors -= len; + + if (sectors > 0) + goto re_check; + + WARN_ON(sectors < 0); + + if (unacked_badblocks > 0) + rv = -1; + else if (acked_badblocks > 0) + rv = 1; + else + rv = 0; + + if (read_seqretry(&bb->lock, seq)) + goto retry; + + return rv; +} + +/** + * badblocks_check() - check a given range for bad sectors + * @bb: the badblocks structure that holds all badblock information + * @s: sector (start) at which to check for badblocks + * @sectors: number of sectors to check for badblocks + * @first_bad: pointer to store location of the first badblock + * @bad_sectors: pointer to store number of badblocks after @first_bad + * + * We can record which blocks on each device are 'bad' and so just + * fail those blocks, or that stripe, rather than the whole device. + * Entries in the bad-block table are 64bits wide. This comprises: + * Length of bad-range, in sectors: 0-511 for lengths 1-512 + * Start of bad-range, sector offset, 54 bits (allows 8 exbibytes) + * A 'shift' can be set so that larger blocks are tracked and + * consequently larger devices can be covered. + * 'Acknowledged' flag - 1 bit. - the most significant bit. + * + * Locking of the bad-block table uses a seqlock so badblocks_check + * might need to retry if it is very unlucky. + * We will sometimes want to check for bad blocks in a bi_end_io function, + * so we use the write_seqlock_irq variant. + * + * When looking for a bad block we specify a range and want to + * know if any block in the range is bad. So we binary-search + * to the last range that starts at-or-before the given endpoint, + * (or "before the sector after the target range") + * then see if it ends after the given start. + * + * Return: + * 0: there are no known bad blocks in the range + * 1: there are known bad block which are all acknowledged + * -1: there are bad blocks which have not yet been acknowledged in metadata. + * plus the start/length of the first bad section we overlap. + */ +int badblocks_check(struct badblocks *bb, sector_t s, int sectors, + sector_t *first_bad, int *bad_sectors) +{ + return _badblocks_check(bb, s, sectors, first_bad, bad_sectors); +} +EXPORT_SYMBOL_GPL(badblocks_check); + +/** + * badblocks_set() - Add a range of bad blocks to the table. + * @bb: the badblocks structure that holds all badblock information + * @s: first sector to mark as bad + * @sectors: number of sectors to mark as bad + * @acknowledged: weather to mark the bad sectors as acknowledged + * + * This might extend the table, or might contract it if two adjacent ranges + * can be merged. We binary-search to find the 'insertion' point, then + * decide how best to handle it. + * + * Return: + * 0: success + * 1: failed to set badblocks (out of space) + */ +int badblocks_set(struct badblocks *bb, sector_t s, int sectors, + int acknowledged) +{ + return _badblocks_set(bb, s, sectors, acknowledged); +} +EXPORT_SYMBOL_GPL(badblocks_set); + +/** + * badblocks_clear() - Remove a range of bad blocks to the table. + * @bb: the badblocks structure that holds all badblock information + * @s: first sector to mark as bad + * @sectors: number of sectors to mark as bad + * + * This may involve extending the table if we spilt a region, + * but it must not fail. So if the table becomes full, we just + * drop the remove request. + * + * Return: + * 0: success + * 1: failed to clear badblocks + */ +int badblocks_clear(struct badblocks *bb, sector_t s, int sectors) +{ + return _badblocks_clear(bb, s, sectors); +} EXPORT_SYMBOL_GPL(badblocks_clear); /** diff --git a/block/partitions/ibm.c b/block/partitions/ibm.c index 403756dbd50d..82d9c4c3fb41 100644 --- a/block/partitions/ibm.c +++ b/block/partitions/ibm.c @@ -61,6 +61,47 @@ static sector_t cchhb2blk(struct vtoc_cchhb *ptr, struct hd_geometry *geo) ptr->b; } +/* Volume Label Type/ID Length */ +#define DASD_VOL_TYPE_LEN 4 +#define DASD_VOL_ID_LEN 6 + +/* Volume Label Types */ +#define DASD_VOLLBL_TYPE_VOL1 0 +#define DASD_VOLLBL_TYPE_LNX1 1 +#define DASD_VOLLBL_TYPE_CMS1 2 + +struct dasd_vollabel { + char *type; + int idx; +}; + +static struct dasd_vollabel dasd_vollabels[] = { + [DASD_VOLLBL_TYPE_VOL1] = { + .type = "VOL1", + .idx = DASD_VOLLBL_TYPE_VOL1, + }, + [DASD_VOLLBL_TYPE_LNX1] = { + .type = "LNX1", + .idx = DASD_VOLLBL_TYPE_LNX1, + }, + [DASD_VOLLBL_TYPE_CMS1] = { + .type = "CMS1", + .idx = DASD_VOLLBL_TYPE_CMS1, + }, +}; + +static int get_label_by_type(const char *type) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(dasd_vollabels); i++) { + if (!memcmp(type, dasd_vollabels[i].type, DASD_VOL_TYPE_LEN)) + return dasd_vollabels[i].idx; + } + + return -1; +} + static int find_label(struct parsed_partitions *state, dasd_information2_t *info, struct hd_geometry *geo, @@ -70,12 +111,10 @@ static int find_label(struct parsed_partitions *state, char type[], union label_t *label) { - Sector sect; - unsigned char *data; sector_t testsect[3]; - unsigned char temp[5]; - int found = 0; int i, testcount; + Sector sect; + void *data; /* There a three places where we may find a valid label: * - on an ECKD disk it's block 2 @@ -103,31 +142,27 @@ static int find_label(struct parsed_partitions *state, if (data == NULL) continue; memcpy(label, data, sizeof(*label)); - memcpy(temp, data, 4); - temp[4] = 0; - EBCASC(temp, 4); + memcpy(type, data, DASD_VOL_TYPE_LEN); + EBCASC(type, DASD_VOL_TYPE_LEN); put_dev_sector(sect); - if (!strcmp(temp, "VOL1") || - !strcmp(temp, "LNX1") || - !strcmp(temp, "CMS1")) { - if (!strcmp(temp, "VOL1")) { - strncpy(type, label->vol.vollbl, 4); - strncpy(name, label->vol.volid, 6); - } else { - strncpy(type, label->lnx.vollbl, 4); - strncpy(name, label->lnx.volid, 6); - } - EBCASC(type, 4); - EBCASC(name, 6); + switch (get_label_by_type(type)) { + case DASD_VOLLBL_TYPE_VOL1: + memcpy(name, label->vol.volid, DASD_VOL_ID_LEN); + EBCASC(name, DASD_VOL_ID_LEN); + *labelsect = testsect[i]; + return 1; + case DASD_VOLLBL_TYPE_LNX1: + case DASD_VOLLBL_TYPE_CMS1: + memcpy(name, label->lnx.volid, DASD_VOL_ID_LEN); + EBCASC(name, DASD_VOL_ID_LEN); *labelsect = testsect[i]; - found = 1; + return 1; + default: break; } } - if (!found) - memset(label, 0, sizeof(*label)); - return found; + return 0; } static int find_vol1_partitions(struct parsed_partitions *state, @@ -297,8 +332,8 @@ int ibm_partition(struct parsed_partitions *state) sector_t nr_sectors; dasd_information2_t *info; struct hd_geometry *geo; - char type[5] = {0,}; - char name[7] = {0,}; + char type[DASD_VOL_TYPE_LEN + 1] = ""; + char name[DASD_VOL_ID_LEN + 1] = ""; sector_t labelsect; union label_t *label; @@ -330,18 +365,21 @@ int ibm_partition(struct parsed_partitions *state) info = NULL; } - if (find_label(state, info, geo, blocksize, &labelsect, name, type, - label)) { - if (!strncmp(type, "VOL1", 4)) { + if (find_label(state, info, geo, blocksize, &labelsect, name, type, label)) { + switch (get_label_by_type(type)) { + case DASD_VOLLBL_TYPE_VOL1: res = find_vol1_partitions(state, geo, blocksize, name, label); - } else if (!strncmp(type, "LNX1", 4)) { + break; + case DASD_VOLLBL_TYPE_LNX1: res = find_lnx1_partitions(state, geo, blocksize, name, label, labelsect, nr_sectors, info); - } else if (!strncmp(type, "CMS1", 4)) { + break; + case DASD_VOLLBL_TYPE_CMS1: res = find_cms1_partitions(state, geo, blocksize, name, label, labelsect); + break; } } else if (info) { /* diff --git a/block/sed-opal.c b/block/sed-opal.c index 6d7f25d1711b..fa23a6a60485 100644 --- a/block/sed-opal.c +++ b/block/sed-opal.c @@ -18,6 +18,7 @@ #include <linux/uaccess.h> #include <uapi/linux/sed-opal.h> #include <linux/sed-opal.h> +#include <linux/sed-opal-key.h> #include <linux/string.h> #include <linux/kdev_t.h> #include <linux/key.h> @@ -3019,7 +3020,13 @@ static int opal_set_new_pw(struct opal_dev *dev, struct opal_new_pw *opal_pw) if (ret) return ret; - /* update keyring with new password */ + /* update keyring and key store with new password */ + ret = sed_write_key(OPAL_AUTH_KEY, + opal_pw->new_user_pw.opal_key.key, + opal_pw->new_user_pw.opal_key.key_len); + if (ret != -EOPNOTSUPP) + pr_warn("error updating SED key: %d\n", ret); + ret = update_sed_opal_key(OPAL_AUTH_KEY, opal_pw->new_user_pw.opal_key.key, opal_pw->new_user_pw.opal_key.key_len); @@ -3292,6 +3299,8 @@ EXPORT_SYMBOL_GPL(sed_ioctl); static int __init sed_opal_init(void) { struct key *kr; + char init_sed_key[OPAL_KEY_MAX]; + int keylen = OPAL_KEY_MAX - 1; kr = keyring_alloc(".sed_opal", GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, current_cred(), @@ -3304,6 +3313,11 @@ static int __init sed_opal_init(void) sed_opal_keyring = kr; - return 0; + if (sed_read_key(OPAL_AUTH_KEY, init_sed_key, &keylen) < 0) { + memset(init_sed_key, '\0', sizeof(init_sed_key)); + keylen = OPAL_KEY_MAX - 1; + } + + return update_sed_opal_key(OPAL_AUTH_KEY, init_sed_key, keylen); } late_initcall(sed_opal_init); diff --git a/drivers/block/aoe/aoenet.c b/drivers/block/aoe/aoenet.c index 63773a90581d..c51ea95bc2ce 100644 --- a/drivers/block/aoe/aoenet.c +++ b/drivers/block/aoe/aoenet.c @@ -39,8 +39,7 @@ static struct ktstate kts; #ifndef MODULE static int __init aoe_iflist_setup(char *str) { - strncpy(aoe_iflist, str, IFLISTSZ); - aoe_iflist[IFLISTSZ - 1] = '\0'; + strscpy(aoe_iflist, str, IFLISTSZ); return 1; } diff --git a/drivers/block/null_blk/main.c b/drivers/block/null_blk/main.c index 79d6cd3c3d41..22a3cf7f32e2 100644 --- a/drivers/block/null_blk/main.c +++ b/drivers/block/null_blk/main.c @@ -1966,7 +1966,7 @@ static int null_gendisk_register(struct nullb *nullb) else disk->fops = &null_bio_ops; disk->private_data = nullb; - strncpy(disk->disk_name, nullb->disk_name, DISK_NAME_LEN); + strscpy_pad(disk->disk_name, nullb->disk_name, DISK_NAME_LEN); if (nullb->dev->zoned) { int ret = null_register_zoned_dev(nullb); diff --git a/drivers/block/ublk_drv.c b/drivers/block/ublk_drv.c index 630ddfe6657b..83600b45e12a 100644 --- a/drivers/block/ublk_drv.c +++ b/drivers/block/ublk_drv.c @@ -75,6 +75,7 @@ struct ublk_rq_data { struct ublk_uring_cmd_pdu { struct ublk_queue *ubq; + u16 tag; }; /* @@ -115,6 +116,9 @@ struct ublk_uring_cmd_pdu { */ #define UBLK_IO_FLAG_NEED_GET_DATA 0x08 +/* atomic RW with ubq->cancel_lock */ +#define UBLK_IO_FLAG_CANCELED 0x80000000 + struct ublk_io { /* userspace buffer address from io cmd */ __u64 addr; @@ -138,13 +142,13 @@ struct ublk_queue { unsigned int max_io_sz; bool force_abort; bool timeout; + bool canceling; unsigned short nr_io_ready; /* how many ios setup */ + spinlock_t cancel_lock; struct ublk_device *dev; struct ublk_io ios[]; }; -#define UBLK_DAEMON_MONITOR_PERIOD (5 * HZ) - struct ublk_device { struct gendisk *ub_disk; @@ -166,7 +170,7 @@ struct ublk_device { struct mutex mutex; - spinlock_t mm_lock; + spinlock_t lock; struct mm_struct *mm; struct ublk_params params; @@ -175,11 +179,6 @@ struct ublk_device { unsigned int nr_queues_ready; unsigned int nr_privileged_daemon; - /* - * Our ubq->daemon may be killed without any notification, so - * monitor each queue's daemon periodically - */ - struct delayed_work monitor_work; struct work_struct quiesce_work; struct work_struct stop_work; }; @@ -190,10 +189,11 @@ struct ublk_params_header { __u32 types; }; +static bool ublk_abort_requests(struct ublk_device *ub, struct ublk_queue *ubq); + static inline unsigned int ublk_req_build_flags(struct request *req); static inline struct ublksrv_io_desc *ublk_get_iod(struct ublk_queue *ubq, int tag); - static inline bool ublk_dev_is_user_copy(const struct ublk_device *ub) { return ub->dev_info.flags & UBLK_F_USER_COPY; @@ -470,6 +470,7 @@ static DEFINE_MUTEX(ublk_ctl_mutex); * It can be extended to one per-user limit in future or even controlled * by cgroup. */ +#define UBLK_MAX_UBLKS UBLK_MINORS static unsigned int ublks_max = 64; static unsigned int ublks_added; /* protected by ublk_ctl_mutex */ @@ -1083,13 +1084,10 @@ static void __ublk_fail_req(struct ublk_queue *ubq, struct ublk_io *io, { WARN_ON_ONCE(io->flags & UBLK_IO_FLAG_ACTIVE); - if (!(io->flags & UBLK_IO_FLAG_ABORTED)) { - io->flags |= UBLK_IO_FLAG_ABORTED; - if (ublk_queue_can_use_recovery_reissue(ubq)) - blk_mq_requeue_request(req, false); - else - ublk_put_req_ref(ubq, req); - } + if (ublk_queue_can_use_recovery_reissue(ubq)) + blk_mq_requeue_request(req, false); + else + ublk_put_req_ref(ubq, req); } static void ubq_complete_io_cmd(struct ublk_io *io, int res, @@ -1118,8 +1116,6 @@ static inline void __ublk_abort_rq(struct ublk_queue *ubq, blk_mq_requeue_request(rq, false); else blk_mq_end_request(rq, BLK_STS_IOERR); - - mod_delayed_work(system_wq, &ubq->dev->monitor_work, 0); } static inline void __ublk_rq_task_work(struct request *req, @@ -1212,15 +1208,6 @@ static inline void ublk_forward_io_cmds(struct ublk_queue *ubq, __ublk_rq_task_work(blk_mq_rq_from_pdu(data), issue_flags); } -static inline void ublk_abort_io_cmds(struct ublk_queue *ubq) -{ - struct llist_node *io_cmds = llist_del_all(&ubq->io_cmds); - struct ublk_rq_data *data, *tmp; - - llist_for_each_entry_safe(data, tmp, io_cmds, node) - __ublk_abort_rq(ubq, blk_mq_rq_from_pdu(data)); -} - static void ublk_rq_task_work_cb(struct io_uring_cmd *cmd, unsigned issue_flags) { struct ublk_uring_cmd_pdu *pdu = ublk_get_uring_cmd_pdu(cmd); @@ -1232,38 +1219,19 @@ static void ublk_rq_task_work_cb(struct io_uring_cmd *cmd, unsigned issue_flags) static void ublk_queue_cmd(struct ublk_queue *ubq, struct request *rq) { struct ublk_rq_data *data = blk_mq_rq_to_pdu(rq); - struct ublk_io *io; - if (!llist_add(&data->node, &ubq->io_cmds)) - return; + if (llist_add(&data->node, &ubq->io_cmds)) { + struct ublk_io *io = &ubq->ios[rq->tag]; - io = &ubq->ios[rq->tag]; - /* - * If the check pass, we know that this is a re-issued request aborted - * previously in monitor_work because the ubq_daemon(cmd's task) is - * PF_EXITING. We cannot call io_uring_cmd_complete_in_task() anymore - * because this ioucmd's io_uring context may be freed now if no inflight - * ioucmd exists. Otherwise we may cause null-deref in ctx->fallback_work. - * - * Note: monitor_work sets UBLK_IO_FLAG_ABORTED and ends this request(releasing - * the tag). Then the request is re-started(allocating the tag) and we are here. - * Since releasing/allocating a tag implies smp_mb(), finding UBLK_IO_FLAG_ABORTED - * guarantees that here is a re-issued request aborted previously. - */ - if (unlikely(io->flags & UBLK_IO_FLAG_ABORTED)) { - ublk_abort_io_cmds(ubq); - } else { - struct io_uring_cmd *cmd = io->cmd; - struct ublk_uring_cmd_pdu *pdu = ublk_get_uring_cmd_pdu(cmd); - - pdu->ubq = ubq; - io_uring_cmd_complete_in_task(cmd, ublk_rq_task_work_cb); + io_uring_cmd_complete_in_task(io->cmd, ublk_rq_task_work_cb); } } static enum blk_eh_timer_return ublk_timeout(struct request *rq) { struct ublk_queue *ubq = rq->mq_hctx->driver_data; + unsigned int nr_inflight = 0; + int i; if (ubq->flags & UBLK_F_UNPRIVILEGED_DEV) { if (!ubq->timeout) { @@ -1274,6 +1242,29 @@ static enum blk_eh_timer_return ublk_timeout(struct request *rq) return BLK_EH_DONE; } + if (!ubq_daemon_is_dying(ubq)) + return BLK_EH_RESET_TIMER; + + for (i = 0; i < ubq->q_depth; i++) { + struct ublk_io *io = &ubq->ios[i]; + + if (!(io->flags & UBLK_IO_FLAG_ACTIVE)) + nr_inflight++; + } + + /* cancelable uring_cmd can't help us if all commands are in-flight */ + if (nr_inflight == ubq->q_depth) { + struct ublk_device *ub = ubq->dev; + + if (ublk_abort_requests(ub, ubq)) { + if (ublk_can_use_recovery(ub)) + schedule_work(&ub->quiesce_work); + else + schedule_work(&ub->stop_work); + } + return BLK_EH_DONE; + } + return BLK_EH_RESET_TIMER; } @@ -1301,13 +1292,12 @@ static blk_status_t ublk_queue_rq(struct blk_mq_hw_ctx *hctx, if (ublk_queue_can_use_recovery(ubq) && unlikely(ubq->force_abort)) return BLK_STS_IOERR; - blk_mq_start_request(bd->rq); - - if (unlikely(ubq_daemon_is_dying(ubq))) { + if (unlikely(ubq->canceling)) { __ublk_abort_rq(ubq, rq); return BLK_STS_OK; } + blk_mq_start_request(bd->rq); ublk_queue_cmd(ubq, rq); return BLK_STS_OK; @@ -1357,12 +1347,12 @@ static int ublk_ch_mmap(struct file *filp, struct vm_area_struct *vma) unsigned long pfn, end, phys_off = vma->vm_pgoff << PAGE_SHIFT; int q_id, ret = 0; - spin_lock(&ub->mm_lock); + spin_lock(&ub->lock); if (!ub->mm) ub->mm = current->mm; if (current->mm != ub->mm) ret = -EINVAL; - spin_unlock(&ub->mm_lock); + spin_unlock(&ub->lock); if (ret) return ret; @@ -1411,17 +1401,14 @@ static void ublk_commit_completion(struct ublk_device *ub, } /* - * When ->ubq_daemon is exiting, either new request is ended immediately, - * or any queued io command is drained, so it is safe to abort queue - * lockless + * Called from ubq_daemon context via cancel fn, meantime quiesce ublk + * blk-mq queue, so we are called exclusively with blk-mq and ubq_daemon + * context, so everything is serialized. */ static void ublk_abort_queue(struct ublk_device *ub, struct ublk_queue *ubq) { int i; - if (!ublk_get_device(ub)) - return; - for (i = 0; i < ubq->q_depth; i++) { struct ublk_io *io = &ubq->ios[i]; @@ -1433,72 +1420,114 @@ static void ublk_abort_queue(struct ublk_device *ub, struct ublk_queue *ubq) * will do it */ rq = blk_mq_tag_to_rq(ub->tag_set.tags[ubq->q_id], i); - if (rq) + if (rq && blk_mq_request_started(rq)) { + io->flags |= UBLK_IO_FLAG_ABORTED; __ublk_fail_req(ubq, io, rq); + } } } - ublk_put_device(ub); } -static void ublk_daemon_monitor_work(struct work_struct *work) +static bool ublk_abort_requests(struct ublk_device *ub, struct ublk_queue *ubq) { - struct ublk_device *ub = - container_of(work, struct ublk_device, monitor_work.work); - int i; + struct gendisk *disk; - for (i = 0; i < ub->dev_info.nr_hw_queues; i++) { - struct ublk_queue *ubq = ublk_get_queue(ub, i); + spin_lock(&ubq->cancel_lock); + if (ubq->canceling) { + spin_unlock(&ubq->cancel_lock); + return false; + } + ubq->canceling = true; + spin_unlock(&ubq->cancel_lock); - if (ubq_daemon_is_dying(ubq)) { - if (ublk_queue_can_use_recovery(ubq)) - schedule_work(&ub->quiesce_work); - else - schedule_work(&ub->stop_work); + spin_lock(&ub->lock); + disk = ub->ub_disk; + if (disk) + get_device(disk_to_dev(disk)); + spin_unlock(&ub->lock); - /* abort queue is for making forward progress */ - ublk_abort_queue(ub, ubq); - } - } + /* Our disk has been dead */ + if (!disk) + return false; - /* - * We can't schedule monitor work after ub's state is not UBLK_S_DEV_LIVE. - * after ublk_remove() or __ublk_quiesce_dev() is started. - * - * No need ub->mutex, monitor work are canceled after state is marked - * as not LIVE, so new state is observed reliably. - */ - if (ub->dev_info.state == UBLK_S_DEV_LIVE) - schedule_delayed_work(&ub->monitor_work, - UBLK_DAEMON_MONITOR_PERIOD); -} + /* Now we are serialized with ublk_queue_rq() */ + blk_mq_quiesce_queue(disk->queue); + /* abort queue is for making forward progress */ + ublk_abort_queue(ub, ubq); + blk_mq_unquiesce_queue(disk->queue); + put_device(disk_to_dev(disk)); -static inline bool ublk_queue_ready(struct ublk_queue *ubq) -{ - return ubq->nr_io_ready == ubq->q_depth; + return true; } -static void ublk_cmd_cancel_cb(struct io_uring_cmd *cmd, unsigned issue_flags) +static void ublk_cancel_cmd(struct ublk_queue *ubq, struct ublk_io *io, + unsigned int issue_flags) { - io_uring_cmd_done(cmd, UBLK_IO_RES_ABORT, 0, issue_flags); + bool done; + + if (!(io->flags & UBLK_IO_FLAG_ACTIVE)) + return; + + spin_lock(&ubq->cancel_lock); + done = !!(io->flags & UBLK_IO_FLAG_CANCELED); + if (!done) + io->flags |= UBLK_IO_FLAG_CANCELED; + spin_unlock(&ubq->cancel_lock); + + if (!done) + io_uring_cmd_done(io->cmd, UBLK_IO_RES_ABORT, 0, issue_flags); } -static void ublk_cancel_queue(struct ublk_queue *ubq) +/* + * The ublk char device won't be closed when calling cancel fn, so both + * ublk device and queue are guaranteed to be live + */ +static void ublk_uring_cmd_cancel_fn(struct io_uring_cmd *cmd, + unsigned int issue_flags) { - int i; + struct ublk_uring_cmd_pdu *pdu = ublk_get_uring_cmd_pdu(cmd); + struct ublk_queue *ubq = pdu->ubq; + struct task_struct *task; + struct ublk_device *ub; + bool need_schedule; + struct ublk_io *io; - if (!ublk_queue_ready(ubq)) + if (WARN_ON_ONCE(!ubq)) return; - for (i = 0; i < ubq->q_depth; i++) { - struct ublk_io *io = &ubq->ios[i]; + if (WARN_ON_ONCE(pdu->tag >= ubq->q_depth)) + return; + + task = io_uring_cmd_get_task(cmd); + if (WARN_ON_ONCE(task && task != ubq->ubq_daemon)) + return; + + ub = ubq->dev; + need_schedule = ublk_abort_requests(ub, ubq); + + io = &ubq->ios[pdu->tag]; + WARN_ON_ONCE(io->cmd != cmd); + ublk_cancel_cmd(ubq, io, issue_flags); - if (io->flags & UBLK_IO_FLAG_ACTIVE) - io_uring_cmd_complete_in_task(io->cmd, - ublk_cmd_cancel_cb); + if (need_schedule) { + if (ublk_can_use_recovery(ub)) + schedule_work(&ub->quiesce_work); + else + schedule_work(&ub->stop_work); } +} - /* all io commands are canceled */ - ubq->nr_io_ready = 0; +static inline bool ublk_queue_ready(struct ublk_queue *ubq) +{ + return ubq->nr_io_ready == ubq->q_depth; +} + +static void ublk_cancel_queue(struct ublk_queue *ubq) +{ + int i; + + for (i = 0; i < ubq->q_depth; i++) + ublk_cancel_cmd(ubq, &ubq->ios[i], IO_URING_F_UNLOCKED); } /* Cancel all pending commands, must be called after del_gendisk() returns */ @@ -1545,16 +1574,6 @@ static void __ublk_quiesce_dev(struct ublk_device *ub) blk_mq_quiesce_queue(ub->ub_disk->queue); ublk_wait_tagset_rqs_idle(ub); ub->dev_info.state = UBLK_S_DEV_QUIESCED; - ublk_cancel_dev(ub); - /* we are going to release task_struct of ubq_daemon and resets - * ->ubq_daemon to NULL. So in monitor_work, check on ubq_daemon causes UAF. - * Besides, monitor_work is not necessary in QUIESCED state since we have - * already scheduled quiesce_work and quiesced all ubqs. - * - * Do not let monitor_work schedule itself if state it QUIESCED. And we cancel - * it here and re-schedule it in END_USER_RECOVERY to avoid UAF. - */ - cancel_delayed_work_sync(&ub->monitor_work); } static void ublk_quiesce_work_fn(struct work_struct *work) @@ -1568,6 +1587,7 @@ static void ublk_quiesce_work_fn(struct work_struct *work) __ublk_quiesce_dev(ub); unlock: mutex_unlock(&ub->mutex); + ublk_cancel_dev(ub); } static void ublk_unquiesce_dev(struct ublk_device *ub) @@ -1593,6 +1613,8 @@ static void ublk_unquiesce_dev(struct ublk_device *ub) static void ublk_stop_dev(struct ublk_device *ub) { + struct gendisk *disk; + mutex_lock(&ub->mutex); if (ub->dev_info.state == UBLK_S_DEV_DEAD) goto unlock; @@ -1602,14 +1624,18 @@ static void ublk_stop_dev(struct ublk_device *ub) ublk_unquiesce_dev(ub); } del_gendisk(ub->ub_disk); + + /* Sync with ublk_abort_queue() by holding the lock */ + spin_lock(&ub->lock); + disk = ub->ub_disk; ub->dev_info.state = UBLK_S_DEV_DEAD; ub->dev_info.ublksrv_pid = -1; - put_disk(ub->ub_disk); ub->ub_disk = NULL; + spin_unlock(&ub->lock); + put_disk(disk); unlock: - ublk_cancel_dev(ub); mutex_unlock(&ub->mutex); - cancel_delayed_work_sync(&ub->monitor_work); + ublk_cancel_dev(ub); } /* device can only be started after all IOs are ready */ @@ -1660,6 +1686,21 @@ static inline void ublk_fill_io_cmd(struct ublk_io *io, io->addr = buf_addr; } +static inline void ublk_prep_cancel(struct io_uring_cmd *cmd, + unsigned int issue_flags, + struct ublk_queue *ubq, unsigned int tag) +{ + struct ublk_uring_cmd_pdu *pdu = ublk_get_uring_cmd_pdu(cmd); + + /* + * Safe to refer to @ubq since ublk_queue won't be died until its + * commands are completed + */ + pdu->ubq = ubq; + pdu->tag = tag; + io_uring_cmd_mark_cancelable(cmd, issue_flags); +} + static int __ublk_ch_uring_cmd(struct io_uring_cmd *cmd, unsigned int issue_flags, const struct ublksrv_io_cmd *ub_cmd) @@ -1775,6 +1816,7 @@ static int __ublk_ch_uring_cmd(struct io_uring_cmd *cmd, default: goto out; } + ublk_prep_cancel(cmd, issue_flags, ubq, tag); return -EIOCBQUEUED; out: @@ -1814,7 +1856,8 @@ fail_put: return NULL; } -static int ublk_ch_uring_cmd(struct io_uring_cmd *cmd, unsigned int issue_flags) +static inline int ublk_ch_uring_cmd_local(struct io_uring_cmd *cmd, + unsigned int issue_flags) { /* * Not necessary for async retry, but let's keep it simple and always @@ -1828,9 +1871,33 @@ static int ublk_ch_uring_cmd(struct io_uring_cmd *cmd, unsigned int issue_flags) .addr = READ_ONCE(ub_src->addr) }; + WARN_ON_ONCE(issue_flags & IO_URING_F_UNLOCKED); + return __ublk_ch_uring_cmd(cmd, issue_flags, &ub_cmd); } +static void ublk_ch_uring_cmd_cb(struct io_uring_cmd *cmd, + unsigned int issue_flags) +{ + ublk_ch_uring_cmd_local(cmd, issue_flags); +} + +static int ublk_ch_uring_cmd(struct io_uring_cmd *cmd, unsigned int issue_flags) +{ + if (unlikely(issue_flags & IO_URING_F_CANCEL)) { + ublk_uring_cmd_cancel_fn(cmd, issue_flags); + return 0; + } + + /* well-implemented server won't run into unlocked */ + if (unlikely(issue_flags & IO_URING_F_UNLOCKED)) { + io_uring_cmd_complete_in_task(cmd, ublk_ch_uring_cmd_cb); + return -EIOCBQUEUED; + } + + return ublk_ch_uring_cmd_local(cmd, issue_flags); +} + static inline bool ublk_check_ubuf_dir(const struct request *req, int ubuf_dir) { @@ -1962,6 +2029,7 @@ static int ublk_init_queue(struct ublk_device *ub, int q_id) void *ptr; int size; + spin_lock_init(&ubq->cancel_lock); ubq->flags = ub->dev_info.flags; ubq->q_id = q_id; ubq->q_depth = ub->dev_info.queue_depth; @@ -2026,7 +2094,8 @@ static int ublk_alloc_dev_number(struct ublk_device *ub, int idx) if (err == -ENOSPC) err = -EEXIST; } else { - err = idr_alloc(&ublk_index_idr, ub, 0, 0, GFP_NOWAIT); + err = idr_alloc(&ublk_index_idr, ub, 0, UBLK_MAX_UBLKS, + GFP_NOWAIT); } spin_unlock(&ublk_idr_lock); @@ -2151,8 +2220,6 @@ static int ublk_ctrl_start_dev(struct ublk_device *ub, struct io_uring_cmd *cmd) if (wait_for_completion_interruptible(&ub->completion) != 0) return -EINTR; - schedule_delayed_work(&ub->monitor_work, UBLK_DAEMON_MONITOR_PERIOD); - mutex_lock(&ub->mutex); if (ub->dev_info.state == UBLK_S_DEV_LIVE || test_bit(UB_STATE_USED, &ub->state)) { @@ -2305,6 +2372,12 @@ static int ublk_ctrl_add_dev(struct io_uring_cmd *cmd) return -EINVAL; } + if (header->dev_id != U32_MAX && header->dev_id >= UBLK_MAX_UBLKS) { + pr_warn("%s: dev id is too large. Max supported is %d\n", + __func__, UBLK_MAX_UBLKS - 1); + return -EINVAL; + } + ublk_dump_dev_info(&info); ret = mutex_lock_killable(&ublk_ctl_mutex); @@ -2320,10 +2393,9 @@ static int ublk_ctrl_add_dev(struct io_uring_cmd *cmd) if (!ub) goto out_unlock; mutex_init(&ub->mutex); - spin_lock_init(&ub->mm_lock); + spin_lock_init(&ub->lock); INIT_WORK(&ub->quiesce_work, ublk_quiesce_work_fn); INIT_WORK(&ub->stop_work, ublk_stop_work_fn); - INIT_DELAYED_WORK(&ub->monitor_work, ublk_daemon_monitor_work); ret = ublk_alloc_dev_number(ub, header->dev_id); if (ret < 0) @@ -2569,13 +2641,15 @@ static void ublk_queue_reinit(struct ublk_device *ub, struct ublk_queue *ubq) int i; WARN_ON_ONCE(!(ubq->ubq_daemon && ubq_daemon_is_dying(ubq))); + /* All old ioucmds have to be completed */ - WARN_ON_ONCE(ubq->nr_io_ready); + ubq->nr_io_ready = 0; /* old daemon is PF_EXITING, put it now */ put_task_struct(ubq->ubq_daemon); /* We have to reset it to NULL, otherwise ub won't accept new FETCH_REQ */ ubq->ubq_daemon = NULL; ubq->timeout = false; + ubq->canceling = false; for (i = 0; i < ubq->q_depth; i++) { struct ublk_io *io = &ubq->ios[i]; @@ -2661,7 +2735,6 @@ static int ublk_ctrl_end_recovery(struct ublk_device *ub, __func__, header->dev_id); blk_mq_kick_requeue_list(ub->ub_disk->queue); ub->dev_info.state = UBLK_S_DEV_LIVE; - schedule_delayed_work(&ub->monitor_work, UBLK_DAEMON_MONITOR_PERIOD); ret = 0; out_unlock: mutex_unlock(&ub->mutex); @@ -2932,7 +3005,22 @@ static void __exit ublk_exit(void) module_init(ublk_init); module_exit(ublk_exit); -module_param(ublks_max, int, 0444); +static int ublk_set_max_ublks(const char *buf, const struct kernel_param *kp) +{ + return param_set_uint_minmax(buf, kp, 0, UBLK_MAX_UBLKS); +} + +static int ublk_get_max_ublks(char *buf, const struct kernel_param *kp) +{ + return sysfs_emit(buf, "%u\n", ublks_max); +} + +static const struct kernel_param_ops ublk_max_ublks_ops = { + .set = ublk_set_max_ublks, + .get = ublk_get_max_ublks, +}; + +module_param_cb(ublks_max, &ublk_max_ublks_ops, &ublks_max, 0644); MODULE_PARM_DESC(ublks_max, "max number of ublk devices allowed to add(default: 64)"); MODULE_AUTHOR("Ming Lei <[email protected]>"); diff --git a/drivers/cdrom/cdrom.c b/drivers/cdrom/cdrom.c index cc2839805983..a5e07270e0d4 100644 --- a/drivers/cdrom/cdrom.c +++ b/drivers/cdrom/cdrom.c @@ -3655,7 +3655,6 @@ static struct ctl_table cdrom_table[] = { .mode = 0644, .proc_handler = cdrom_sysctl_handler }, - { } }; static struct ctl_table_header *cdrom_sysctl_header; diff --git a/drivers/nvme/common/Kconfig b/drivers/nvme/common/Kconfig index 4514f44362dd..06c8df00d1e2 100644 --- a/drivers/nvme/common/Kconfig +++ b/drivers/nvme/common/Kconfig @@ -2,3 +2,16 @@ config NVME_COMMON tristate + +config NVME_KEYRING + bool + select KEYS + +config NVME_AUTH + bool + select CRYPTO + select CRYPTO_HMAC + select CRYPTO_SHA256 + select CRYPTO_SHA512 + select CRYPTO_DH + select CRYPTO_DH_RFC7919_GROUPS diff --git a/drivers/nvme/common/Makefile b/drivers/nvme/common/Makefile index 720c625b8a52..0cbd0b0b8d49 100644 --- a/drivers/nvme/common/Makefile +++ b/drivers/nvme/common/Makefile @@ -4,4 +4,5 @@ ccflags-y += -I$(src) obj-$(CONFIG_NVME_COMMON) += nvme-common.o -nvme-common-y += auth.o +nvme-common-$(CONFIG_NVME_AUTH) += auth.o +nvme-common-$(CONFIG_NVME_KEYRING) += keyring.o diff --git a/drivers/nvme/common/auth.c b/drivers/nvme/common/auth.c index d90e4f0c08b7..a8e87dfbeab2 100644 --- a/drivers/nvme/common/auth.c +++ b/drivers/nvme/common/auth.c @@ -150,6 +150,14 @@ size_t nvme_auth_hmac_hash_len(u8 hmac_id) } EXPORT_SYMBOL_GPL(nvme_auth_hmac_hash_len); +u32 nvme_auth_key_struct_size(u32 key_len) +{ + struct nvme_dhchap_key key; + + return struct_size(&key, key, key_len); +} +EXPORT_SYMBOL_GPL(nvme_auth_key_struct_size); + struct nvme_dhchap_key *nvme_auth_extract_key(unsigned char *secret, u8 key_hash) { @@ -163,14 +171,9 @@ struct nvme_dhchap_key *nvme_auth_extract_key(unsigned char *secret, p = strrchr(secret, ':'); if (p) allocated_len = p - secret; - key = kzalloc(sizeof(*key), GFP_KERNEL); + key = nvme_auth_alloc_key(allocated_len, 0); if (!key) return ERR_PTR(-ENOMEM); - key->key = kzalloc(allocated_len, GFP_KERNEL); - if (!key->key) { - ret = -ENOMEM; - goto out_free_key; - } key_len = base64_decode(secret, allocated_len, key->key); if (key_len < 0) { @@ -187,14 +190,6 @@ struct nvme_dhchap_key *nvme_auth_extract_key(unsigned char *secret, goto out_free_secret; } - if (key_hash > 0 && - (key_len - 4) != nvme_auth_hmac_hash_len(key_hash)) { - pr_err("Mismatched key len %d for %s\n", key_len, - nvme_auth_hmac_name(key_hash)); - ret = -EINVAL; - goto out_free_secret; - } - /* The last four bytes is the CRC in little-endian format */ key_len -= 4; /* @@ -213,37 +208,51 @@ struct nvme_dhchap_key *nvme_auth_extract_key(unsigned char *secret, key->hash = key_hash; return key; out_free_secret: - kfree_sensitive(key->key); -out_free_key: - kfree(key); + nvme_auth_free_key(key); return ERR_PTR(ret); } EXPORT_SYMBOL_GPL(nvme_auth_extract_key); +struct nvme_dhchap_key *nvme_auth_alloc_key(u32 len, u8 hash) +{ + u32 num_bytes = nvme_auth_key_struct_size(len); + struct nvme_dhchap_key *key = kzalloc(num_bytes, GFP_KERNEL); + + if (key) { + key->len = len; + key->hash = hash; + } + return key; +} +EXPORT_SYMBOL_GPL(nvme_auth_alloc_key); + void nvme_auth_free_key(struct nvme_dhchap_key *key) { if (!key) return; - kfree_sensitive(key->key); - kfree(key); + kfree_sensitive(key); } EXPORT_SYMBOL_GPL(nvme_auth_free_key); -u8 *nvme_auth_transform_key(struct nvme_dhchap_key *key, char *nqn) +struct nvme_dhchap_key *nvme_auth_transform_key( + struct nvme_dhchap_key *key, char *nqn) { const char *hmac_name; struct crypto_shash *key_tfm; struct shash_desc *shash; - u8 *transformed_key; - int ret; + struct nvme_dhchap_key *transformed_key; + int ret, key_len; - if (!key || !key->key) { + if (!key) { pr_warn("No key specified\n"); return ERR_PTR(-ENOKEY); } if (key->hash == 0) { - transformed_key = kmemdup(key->key, key->len, GFP_KERNEL); - return transformed_key ? transformed_key : ERR_PTR(-ENOMEM); + key_len = nvme_auth_key_struct_size(key->len); + transformed_key = kmemdup(key, key_len, GFP_KERNEL); + if (!transformed_key) + return ERR_PTR(-ENOMEM); + return transformed_key; } hmac_name = nvme_auth_hmac_name(key->hash); if (!hmac_name) { @@ -253,7 +262,7 @@ u8 *nvme_auth_transform_key(struct nvme_dhchap_key *key, char *nqn) key_tfm = crypto_alloc_shash(hmac_name, 0, 0); if (IS_ERR(key_tfm)) - return (u8 *)key_tfm; + return ERR_CAST(key_tfm); shash = kmalloc(sizeof(struct shash_desc) + crypto_shash_descsize(key_tfm), @@ -263,7 +272,8 @@ u8 *nvme_auth_transform_key(struct nvme_dhchap_key *key, char *nqn) goto out_free_key; } - transformed_key = kzalloc(crypto_shash_digestsize(key_tfm), GFP_KERNEL); + key_len = crypto_shash_digestsize(key_tfm); + transformed_key = nvme_auth_alloc_key(key_len, key->hash); if (!transformed_key) { ret = -ENOMEM; goto out_free_shash; @@ -282,7 +292,7 @@ u8 *nvme_auth_transform_key(struct nvme_dhchap_key *key, char *nqn) ret = crypto_shash_update(shash, "NVMe-over-Fabrics", 17); if (ret < 0) goto out_free_transformed_key; - ret = crypto_shash_final(shash, transformed_key); + ret = crypto_shash_final(shash, transformed_key->key); if (ret < 0) goto out_free_transformed_key; @@ -292,7 +302,7 @@ u8 *nvme_auth_transform_key(struct nvme_dhchap_key *key, char *nqn) return transformed_key; out_free_transformed_key: - kfree_sensitive(transformed_key); + nvme_auth_free_key(transformed_key); out_free_shash: kfree(shash); out_free_key: diff --git a/drivers/nvme/common/keyring.c b/drivers/nvme/common/keyring.c new file mode 100644 index 000000000000..f8d9a208397b --- /dev/null +++ b/drivers/nvme/common/keyring.c @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2023 Hannes Reinecke, SUSE Labs + */ + +#include <linux/module.h> +#include <linux/seq_file.h> +#include <linux/key.h> +#include <linux/key-type.h> +#include <keys/user-type.h> +#include <linux/nvme.h> +#include <linux/nvme-tcp.h> +#include <linux/nvme-keyring.h> + +static struct key *nvme_keyring; + +key_serial_t nvme_keyring_id(void) +{ + return nvme_keyring->serial; +} +EXPORT_SYMBOL_GPL(nvme_keyring_id); + +static void nvme_tls_psk_describe(const struct key *key, struct seq_file *m) +{ + seq_puts(m, key->description); + seq_printf(m, ": %u", key->datalen); +} + +static bool nvme_tls_psk_match(const struct key *key, + const struct key_match_data *match_data) +{ + const char *match_id; + size_t match_len; + + if (!key->description) { + pr_debug("%s: no key description\n", __func__); + return false; + } + match_len = strlen(key->description); + pr_debug("%s: id %s len %zd\n", __func__, key->description, match_len); + + if (!match_data->raw_data) { + pr_debug("%s: no match data\n", __func__); + return false; + } + match_id = match_data->raw_data; + pr_debug("%s: match '%s' '%s' len %zd\n", + __func__, match_id, key->description, match_len); + return !memcmp(key->description, match_id, match_len); +} + +static int nvme_tls_psk_match_preparse(struct key_match_data *match_data) +{ + match_data->lookup_type = KEYRING_SEARCH_LOOKUP_ITERATE; + match_data->cmp = nvme_tls_psk_match; + return 0; +} + +static struct key_type nvme_tls_psk_key_type = { + .name = "psk", + .flags = KEY_TYPE_NET_DOMAIN, + .preparse = user_preparse, + .free_preparse = user_free_preparse, + .match_preparse = nvme_tls_psk_match_preparse, + .instantiate = generic_key_instantiate, + .revoke = user_revoke, + .destroy = user_destroy, + .describe = nvme_tls_psk_describe, + .read = user_read, +}; + +static struct key *nvme_tls_psk_lookup(struct key *keyring, + const char *hostnqn, const char *subnqn, + int hmac, bool generated) +{ + char *identity; + size_t identity_len = (NVMF_NQN_SIZE) * 2 + 11; + key_ref_t keyref; + key_serial_t keyring_id; + + identity = kzalloc(identity_len, GFP_KERNEL); + if (!identity) + return ERR_PTR(-ENOMEM); + + snprintf(identity, identity_len, "NVMe0%c%02d %s %s", + generated ? 'G' : 'R', hmac, hostnqn, subnqn); + + if (!keyring) + keyring = nvme_keyring; + keyring_id = key_serial(keyring); + pr_debug("keyring %x lookup tls psk '%s'\n", + keyring_id, identity); + keyref = keyring_search(make_key_ref(keyring, true), + &nvme_tls_psk_key_type, + identity, false); + if (IS_ERR(keyref)) { + pr_debug("lookup tls psk '%s' failed, error %ld\n", + identity, PTR_ERR(keyref)); + kfree(identity); + return ERR_PTR(-ENOKEY); + } + kfree(identity); + + return key_ref_to_ptr(keyref); +} + +/* + * NVMe PSK priority list + * + * 'Retained' PSKs (ie 'generated == false') + * should be preferred to 'generated' PSKs, + * and SHA-384 should be preferred to SHA-256. + */ +struct nvme_tls_psk_priority_list { + bool generated; + enum nvme_tcp_tls_cipher cipher; +} nvme_tls_psk_prio[] = { + { .generated = false, + .cipher = NVME_TCP_TLS_CIPHER_SHA384, }, + { .generated = false, + .cipher = NVME_TCP_TLS_CIPHER_SHA256, }, + { .generated = true, + .cipher = NVME_TCP_TLS_CIPHER_SHA384, }, + { .generated = true, + .cipher = NVME_TCP_TLS_CIPHER_SHA256, }, +}; + +/* + * nvme_tls_psk_default - Return the preferred PSK to use for TLS ClientHello + */ +key_serial_t nvme_tls_psk_default(struct key *keyring, + const char *hostnqn, const char *subnqn) +{ + struct key *tls_key; + key_serial_t tls_key_id; + int prio; + + for (prio = 0; prio < ARRAY_SIZE(nvme_tls_psk_prio); prio++) { + bool generated = nvme_tls_psk_prio[prio].generated; + enum nvme_tcp_tls_cipher cipher = nvme_tls_psk_prio[prio].cipher; + + tls_key = nvme_tls_psk_lookup(keyring, hostnqn, subnqn, + cipher, generated); + if (!IS_ERR(tls_key)) { + tls_key_id = tls_key->serial; + key_put(tls_key); + return tls_key_id; + } + } + return 0; +} +EXPORT_SYMBOL_GPL(nvme_tls_psk_default); + +int nvme_keyring_init(void) +{ + int err; + + nvme_keyring = keyring_alloc(".nvme", + GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, + current_cred(), + (KEY_POS_ALL & ~KEY_POS_SETATTR) | + (KEY_USR_ALL & ~KEY_USR_SETATTR), + KEY_ALLOC_NOT_IN_QUOTA, NULL, NULL); + if (IS_ERR(nvme_keyring)) + return PTR_ERR(nvme_keyring); + + err = register_key_type(&nvme_tls_psk_key_type); + if (err) { + key_put(nvme_keyring); + return err; + } + return 0; +} +EXPORT_SYMBOL_GPL(nvme_keyring_init); + +void nvme_keyring_exit(void) +{ + unregister_key_type(&nvme_tls_psk_key_type); + key_revoke(nvme_keyring); + key_put(nvme_keyring); +} +EXPORT_SYMBOL_GPL(nvme_keyring_exit); diff --git a/drivers/nvme/host/Kconfig b/drivers/nvme/host/Kconfig index 2f6a7f8c94e8..48f7d72de5e9 100644 --- a/drivers/nvme/host/Kconfig +++ b/drivers/nvme/host/Kconfig @@ -92,16 +92,26 @@ config NVME_TCP If unsure, say N. -config NVME_AUTH +config NVME_TCP_TLS + bool "NVMe over Fabrics TCP TLS encryption support" + depends on NVME_TCP + select NVME_COMMON + select NVME_KEYRING + select NET_HANDSHAKE + select KEYS + help + Enables TLS encryption for NVMe TCP using the netlink handshake API. + + The TLS handshake daemon is availble at + https://github.com/oracle/ktls-utils. + + If unsure, say N. + +config NVME_HOST_AUTH bool "NVM Express over Fabrics In-Band Authentication" depends on NVME_CORE select NVME_COMMON - select CRYPTO - select CRYPTO_HMAC - select CRYPTO_SHA256 - select CRYPTO_SHA512 - select CRYPTO_DH - select CRYPTO_DH_RFC7919_GROUPS + select NVME_AUTH help This provides support for NVMe over Fabrics In-Band Authentication. diff --git a/drivers/nvme/host/Makefile b/drivers/nvme/host/Makefile index c7c3cf202d12..6414ec968f99 100644 --- a/drivers/nvme/host/Makefile +++ b/drivers/nvme/host/Makefile @@ -17,7 +17,7 @@ nvme-core-$(CONFIG_NVME_MULTIPATH) += multipath.o nvme-core-$(CONFIG_BLK_DEV_ZONED) += zns.o nvme-core-$(CONFIG_FAULT_INJECTION_DEBUG_FS) += fault_inject.o nvme-core-$(CONFIG_NVME_HWMON) += hwmon.o -nvme-core-$(CONFIG_NVME_AUTH) += auth.o +nvme-core-$(CONFIG_NVME_HOST_AUTH) += auth.o nvme-y += pci.o diff --git a/drivers/nvme/host/auth.c b/drivers/nvme/host/auth.c index daf5d144a8ea..de1390d705dc 100644 --- a/drivers/nvme/host/auth.c +++ b/drivers/nvme/host/auth.c @@ -23,6 +23,7 @@ struct nvme_dhchap_queue_context { struct nvme_ctrl *ctrl; struct crypto_shash *shash_tfm; struct crypto_kpp *dh_tfm; + struct nvme_dhchap_key *transformed_key; void *buf; int qid; int error; @@ -36,7 +37,6 @@ struct nvme_dhchap_queue_context { u8 c1[64]; u8 c2[64]; u8 response[64]; - u8 *host_response; u8 *ctrl_key; u8 *host_key; u8 *sess_key; @@ -428,12 +428,12 @@ static int nvme_auth_dhchap_setup_host_response(struct nvme_ctrl *ctrl, dev_dbg(ctrl->device, "%s: qid %d host response seq %u transaction %d\n", __func__, chap->qid, chap->s1, chap->transaction); - if (!chap->host_response) { - chap->host_response = nvme_auth_transform_key(ctrl->host_key, + if (!chap->transformed_key) { + chap->transformed_key = nvme_auth_transform_key(ctrl->host_key, ctrl->opts->host->nqn); - if (IS_ERR(chap->host_response)) { - ret = PTR_ERR(chap->host_response); - chap->host_response = NULL; + if (IS_ERR(chap->transformed_key)) { + ret = PTR_ERR(chap->transformed_key); + chap->transformed_key = NULL; return ret; } } else { @@ -442,7 +442,7 @@ static int nvme_auth_dhchap_setup_host_response(struct nvme_ctrl *ctrl, } ret = crypto_shash_setkey(chap->shash_tfm, - chap->host_response, ctrl->host_key->len); + chap->transformed_key->key, chap->transformed_key->len); if (ret) { dev_warn(ctrl->device, "qid %d: failed to set key, error %d\n", chap->qid, ret); @@ -508,19 +508,19 @@ static int nvme_auth_dhchap_setup_ctrl_response(struct nvme_ctrl *ctrl, struct nvme_dhchap_queue_context *chap) { SHASH_DESC_ON_STACK(shash, chap->shash_tfm); - u8 *ctrl_response; + struct nvme_dhchap_key *transformed_key; u8 buf[4], *challenge = chap->c2; int ret; - ctrl_response = nvme_auth_transform_key(ctrl->ctrl_key, + transformed_key = nvme_auth_transform_key(ctrl->ctrl_key, ctrl->opts->subsysnqn); - if (IS_ERR(ctrl_response)) { - ret = PTR_ERR(ctrl_response); + if (IS_ERR(transformed_key)) { + ret = PTR_ERR(transformed_key); return ret; } ret = crypto_shash_setkey(chap->shash_tfm, - ctrl_response, ctrl->ctrl_key->len); + transformed_key->key, transformed_key->len); if (ret) { dev_warn(ctrl->device, "qid %d: failed to set key, error %d\n", chap->qid, ret); @@ -586,7 +586,7 @@ static int nvme_auth_dhchap_setup_ctrl_response(struct nvme_ctrl *ctrl, out: if (challenge != chap->c2) kfree(challenge); - kfree(ctrl_response); + nvme_auth_free_key(transformed_key); return ret; } @@ -648,8 +648,8 @@ gen_sesskey: static void nvme_auth_reset_dhchap(struct nvme_dhchap_queue_context *chap) { - kfree_sensitive(chap->host_response); - chap->host_response = NULL; + nvme_auth_free_key(chap->transformed_key); + chap->transformed_key = NULL; kfree_sensitive(chap->host_key); chap->host_key = NULL; chap->host_key_len = 0; diff --git a/drivers/nvme/host/core.c b/drivers/nvme/host/core.c index 21783aa2ee8e..62612f87aafa 100644 --- a/drivers/nvme/host/core.c +++ b/drivers/nvme/host/core.c @@ -25,6 +25,7 @@ #include "nvme.h" #include "fabrics.h" #include <linux/nvme-auth.h> +#include <linux/nvme-keyring.h> #define CREATE_TRACE_POINTS #include "trace.h" @@ -420,7 +421,7 @@ void nvme_complete_rq(struct request *req) nvme_failover_req(req); return; case AUTHENTICATE: -#ifdef CONFIG_NVME_AUTH +#ifdef CONFIG_NVME_HOST_AUTH queue_work(nvme_wq, &ctrl->dhchap_auth_work); nvme_retry_req(req); #else @@ -4399,7 +4400,7 @@ static void nvme_free_ctrl(struct device *dev) if (!subsys || ctrl->instance != subsys->instance) ida_free(&nvme_instance_ida, ctrl->instance); - + key_put(ctrl->tls_key); nvme_free_cels(ctrl); nvme_mpath_uninit(ctrl); nvme_auth_stop(ctrl); @@ -4723,12 +4724,16 @@ static int __init nvme_core_init(void) result = PTR_ERR(nvme_ns_chr_class); goto unregister_generic_ns; } - - result = nvme_init_auth(); + result = nvme_keyring_init(); if (result) goto destroy_ns_chr; + result = nvme_init_auth(); + if (result) + goto keyring_exit; return 0; +keyring_exit: + nvme_keyring_exit(); destroy_ns_chr: class_destroy(nvme_ns_chr_class); unregister_generic_ns: @@ -4752,6 +4757,7 @@ out: static void __exit nvme_core_exit(void) { nvme_exit_auth(); + nvme_keyring_exit(); class_destroy(nvme_ns_chr_class); class_destroy(nvme_subsys_class); class_destroy(nvme_class); diff --git a/drivers/nvme/host/fabrics.c b/drivers/nvme/host/fabrics.c index 8175d49f2909..4673ead69c5f 100644 --- a/drivers/nvme/host/fabrics.c +++ b/drivers/nvme/host/fabrics.c @@ -12,6 +12,7 @@ #include <linux/seq_file.h> #include "nvme.h" #include "fabrics.h" +#include <linux/nvme-keyring.h> static LIST_HEAD(nvmf_transports); static DECLARE_RWSEM(nvmf_transports_rwsem); @@ -622,6 +623,23 @@ static struct nvmf_transport_ops *nvmf_lookup_transport( return NULL; } +static struct key *nvmf_parse_key(int key_id) +{ + struct key *key; + + if (!IS_ENABLED(CONFIG_NVME_TCP_TLS)) { + pr_err("TLS is not supported\n"); + return ERR_PTR(-EINVAL); + } + + key = key_lookup(key_id); + if (!IS_ERR(key)) + pr_err("key id %08x not found\n", key_id); + else + pr_debug("Using key id %08x\n", key_id); + return key; +} + static const match_table_t opt_tokens = { { NVMF_OPT_TRANSPORT, "transport=%s" }, { NVMF_OPT_TRADDR, "traddr=%s" }, @@ -643,10 +661,17 @@ static const match_table_t opt_tokens = { { NVMF_OPT_NR_WRITE_QUEUES, "nr_write_queues=%d" }, { NVMF_OPT_NR_POLL_QUEUES, "nr_poll_queues=%d" }, { NVMF_OPT_TOS, "tos=%d" }, +#ifdef CONFIG_NVME_TCP_TLS + { NVMF_OPT_KEYRING, "keyring=%d" }, + { NVMF_OPT_TLS_KEY, "tls_key=%d" }, +#endif { NVMF_OPT_FAIL_FAST_TMO, "fast_io_fail_tmo=%d" }, { NVMF_OPT_DISCOVERY, "discovery" }, { NVMF_OPT_DHCHAP_SECRET, "dhchap_secret=%s" }, { NVMF_OPT_DHCHAP_CTRL_SECRET, "dhchap_ctrl_secret=%s" }, +#ifdef CONFIG_NVME_TCP_TLS + { NVMF_OPT_TLS, "tls" }, +#endif { NVMF_OPT_ERR, NULL } }; @@ -657,9 +682,10 @@ static int nvmf_parse_options(struct nvmf_ctrl_options *opts, char *options, *o, *p; int token, ret = 0; size_t nqnlen = 0; - int ctrl_loss_tmo = NVMF_DEF_CTRL_LOSS_TMO; + int ctrl_loss_tmo = NVMF_DEF_CTRL_LOSS_TMO, key_id; uuid_t hostid; char hostnqn[NVMF_NQN_SIZE]; + struct key *key; /* Set defaults */ opts->queue_size = NVMF_DEF_QUEUE_SIZE; @@ -671,6 +697,9 @@ static int nvmf_parse_options(struct nvmf_ctrl_options *opts, opts->hdr_digest = false; opts->data_digest = false; opts->tos = -1; /* < 0 == use transport default */ + opts->tls = false; + opts->tls_key = NULL; + opts->keyring = NULL; options = o = kstrdup(buf, GFP_KERNEL); if (!options) @@ -924,6 +953,32 @@ static int nvmf_parse_options(struct nvmf_ctrl_options *opts, } opts->tos = token; break; + case NVMF_OPT_KEYRING: + if (match_int(args, &key_id) || key_id <= 0) { + ret = -EINVAL; + goto out; + } + key = nvmf_parse_key(key_id); + if (IS_ERR(key)) { + ret = PTR_ERR(key); + goto out; + } + key_put(opts->keyring); + opts->keyring = key; + break; + case NVMF_OPT_TLS_KEY: + if (match_int(args, &key_id) || key_id <= 0) { + ret = -EINVAL; + goto out; + } + key = nvmf_parse_key(key_id); + if (IS_ERR(key)) { + ret = PTR_ERR(key); + goto out; + } + key_put(opts->tls_key); + opts->tls_key = key; + break; case NVMF_OPT_DISCOVERY: opts->discovery_nqn = true; break; @@ -955,6 +1010,14 @@ static int nvmf_parse_options(struct nvmf_ctrl_options *opts, kfree(opts->dhchap_ctrl_secret); opts->dhchap_ctrl_secret = p; break; + case NVMF_OPT_TLS: + if (!IS_ENABLED(CONFIG_NVME_TCP_TLS)) { + pr_err("TLS is not supported\n"); + ret = -EINVAL; + goto out; + } + opts->tls = true; + break; default: pr_warn("unknown parameter or missing value '%s' in ctrl creation request\n", p); @@ -1156,6 +1219,8 @@ static int nvmf_check_allowed_opts(struct nvmf_ctrl_options *opts, void nvmf_free_options(struct nvmf_ctrl_options *opts) { nvmf_host_put(opts->host); + key_put(opts->keyring); + key_put(opts->tls_key); kfree(opts->transport); kfree(opts->traddr); kfree(opts->trsvcid); diff --git a/drivers/nvme/host/fabrics.h b/drivers/nvme/host/fabrics.h index 82e7a27ffbde..fbaee5a7be19 100644 --- a/drivers/nvme/host/fabrics.h +++ b/drivers/nvme/host/fabrics.h @@ -70,6 +70,9 @@ enum { NVMF_OPT_DISCOVERY = 1 << 22, NVMF_OPT_DHCHAP_SECRET = 1 << 23, NVMF_OPT_DHCHAP_CTRL_SECRET = 1 << 24, + NVMF_OPT_TLS = 1 << 25, + NVMF_OPT_KEYRING = 1 << 26, + NVMF_OPT_TLS_KEY = 1 << 27, }; /** @@ -102,6 +105,9 @@ enum { * @dhchap_secret: DH-HMAC-CHAP secret * @dhchap_ctrl_secret: DH-HMAC-CHAP controller secret for bi-directional * authentication + * @keyring: Keyring to use for key lookups + * @tls_key: TLS key for encrypted connections (TCP) + * @tls: Start TLS encrypted connections (TCP) * @disable_sqflow: disable controller sq flow control * @hdr_digest: generate/verify header digest (TCP) * @data_digest: generate/verify data digest (TCP) @@ -128,6 +134,9 @@ struct nvmf_ctrl_options { struct nvmf_host *host; char *dhchap_secret; char *dhchap_ctrl_secret; + struct key *keyring; + struct key *tls_key; + bool tls; bool disable_sqflow; bool hdr_digest; bool data_digest; diff --git a/drivers/nvme/host/nvme.h b/drivers/nvme/host/nvme.h index f35647c470af..39a90b7cb125 100644 --- a/drivers/nvme/host/nvme.h +++ b/drivers/nvme/host/nvme.h @@ -349,7 +349,7 @@ struct nvme_ctrl { struct work_struct ana_work; #endif -#ifdef CONFIG_NVME_AUTH +#ifdef CONFIG_NVME_HOST_AUTH struct work_struct dhchap_auth_work; struct mutex dhchap_auth_mutex; struct nvme_dhchap_queue_context *dhchap_ctxs; @@ -357,6 +357,7 @@ struct nvme_ctrl { struct nvme_dhchap_key *ctrl_key; u16 transaction; #endif + struct key *tls_key; /* Power saving configuration */ u64 ps_max_latency_us; @@ -1048,7 +1049,7 @@ static inline bool nvme_ctrl_sgl_supported(struct nvme_ctrl *ctrl) return ctrl->sgls & ((1 << 0) | (1 << 1)); } -#ifdef CONFIG_NVME_AUTH +#ifdef CONFIG_NVME_HOST_AUTH int __init nvme_init_auth(void); void __exit nvme_exit_auth(void); int nvme_auth_init_ctrl(struct nvme_ctrl *ctrl); diff --git a/drivers/nvme/host/sysfs.c b/drivers/nvme/host/sysfs.c index 212e1b05d298..c6b7fbd4d34d 100644 --- a/drivers/nvme/host/sysfs.c +++ b/drivers/nvme/host/sysfs.c @@ -409,7 +409,7 @@ static ssize_t dctype_show(struct device *dev, } static DEVICE_ATTR_RO(dctype); -#ifdef CONFIG_NVME_AUTH +#ifdef CONFIG_NVME_HOST_AUTH static ssize_t nvme_ctrl_dhchap_secret_show(struct device *dev, struct device_attribute *attr, char *buf) { @@ -527,6 +527,19 @@ static DEVICE_ATTR(dhchap_ctrl_secret, S_IRUGO | S_IWUSR, nvme_ctrl_dhchap_ctrl_secret_show, nvme_ctrl_dhchap_ctrl_secret_store); #endif +#ifdef CONFIG_NVME_TCP_TLS +static ssize_t tls_key_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct nvme_ctrl *ctrl = dev_get_drvdata(dev); + + if (!ctrl->tls_key) + return 0; + return sysfs_emit(buf, "%08x", key_serial(ctrl->tls_key)); +} +static DEVICE_ATTR_RO(tls_key); +#endif + static struct attribute *nvme_dev_attrs[] = { &dev_attr_reset_controller.attr, &dev_attr_rescan_controller.attr, @@ -550,10 +563,13 @@ static struct attribute *nvme_dev_attrs[] = { &dev_attr_kato.attr, &dev_attr_cntrltype.attr, &dev_attr_dctype.attr, -#ifdef CONFIG_NVME_AUTH +#ifdef CONFIG_NVME_HOST_AUTH &dev_attr_dhchap_secret.attr, &dev_attr_dhchap_ctrl_secret.attr, #endif +#ifdef CONFIG_NVME_TCP_TLS + &dev_attr_tls_key.attr, +#endif NULL }; @@ -577,12 +593,17 @@ static umode_t nvme_dev_attrs_are_visible(struct kobject *kobj, return 0; if (a == &dev_attr_fast_io_fail_tmo.attr && !ctrl->opts) return 0; -#ifdef CONFIG_NVME_AUTH +#ifdef CONFIG_NVME_HOST_AUTH if (a == &dev_attr_dhchap_secret.attr && !ctrl->opts) return 0; if (a == &dev_attr_dhchap_ctrl_secret.attr && !ctrl->opts) return 0; #endif +#ifdef CONFIG_NVME_TCP_TLS + if (a == &dev_attr_tls_key.attr && + (!ctrl->opts || strcmp(ctrl->opts->transport, "tcp"))) + return 0; +#endif return a->mode; } diff --git a/drivers/nvme/host/tcp.c b/drivers/nvme/host/tcp.c index 5b332d9f87fc..4714a902f4ca 100644 --- a/drivers/nvme/host/tcp.c +++ b/drivers/nvme/host/tcp.c @@ -8,9 +8,14 @@ #include <linux/init.h> #include <linux/slab.h> #include <linux/err.h> +#include <linux/key.h> #include <linux/nvme-tcp.h> +#include <linux/nvme-keyring.h> #include <net/sock.h> #include <net/tcp.h> +#include <net/tls.h> +#include <net/tls_prot.h> +#include <net/handshake.h> #include <linux/blk-mq.h> #include <crypto/hash.h> #include <net/busy_poll.h> @@ -31,6 +36,16 @@ static int so_priority; module_param(so_priority, int, 0644); MODULE_PARM_DESC(so_priority, "nvme tcp socket optimize priority"); +#ifdef CONFIG_NVME_TCP_TLS +/* + * TLS handshake timeout + */ +static int tls_handshake_timeout = 10; +module_param(tls_handshake_timeout, int, 0644); +MODULE_PARM_DESC(tls_handshake_timeout, + "nvme TLS handshake timeout in seconds (default 10)"); +#endif + #ifdef CONFIG_DEBUG_LOCK_ALLOC /* lockdep can detect a circular dependency of the form * sk_lock -> mmap_lock (page fault) -> fs locks -> sk_lock @@ -146,7 +161,10 @@ struct nvme_tcp_queue { struct ahash_request *snd_hash; __le32 exp_ddgst; __le32 recv_ddgst; - +#ifdef CONFIG_NVME_TCP_TLS + struct completion tls_complete; + int tls_err; +#endif struct page_frag_cache pf_cache; void (*state_change)(struct sock *); @@ -1338,7 +1356,9 @@ static void nvme_tcp_free_queue(struct nvme_ctrl *nctrl, int qid) } noreclaim_flag = memalloc_noreclaim_save(); - sock_release(queue->sock); + /* ->sock will be released by fput() */ + fput(queue->sock->file); + queue->sock = NULL; memalloc_noreclaim_restore(noreclaim_flag); kfree(queue->pdu); @@ -1350,6 +1370,8 @@ static int nvme_tcp_init_connection(struct nvme_tcp_queue *queue) { struct nvme_tcp_icreq_pdu *icreq; struct nvme_tcp_icresp_pdu *icresp; + char cbuf[CMSG_LEN(sizeof(char))] = {}; + u8 ctype; struct msghdr msg = {}; struct kvec iov; bool ctrl_hdgst, ctrl_ddgst; @@ -1381,17 +1403,35 @@ static int nvme_tcp_init_connection(struct nvme_tcp_queue *queue) iov.iov_base = icreq; iov.iov_len = sizeof(*icreq); ret = kernel_sendmsg(queue->sock, &msg, &iov, 1, iov.iov_len); - if (ret < 0) + if (ret < 0) { + pr_warn("queue %d: failed to send icreq, error %d\n", + nvme_tcp_queue_id(queue), ret); goto free_icresp; + } memset(&msg, 0, sizeof(msg)); iov.iov_base = icresp; iov.iov_len = sizeof(*icresp); + if (queue->ctrl->ctrl.opts->tls) { + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + } ret = kernel_recvmsg(queue->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); - if (ret < 0) + if (ret < 0) { + pr_warn("queue %d: failed to receive icresp, error %d\n", + nvme_tcp_queue_id(queue), ret); goto free_icresp; - + } + if (queue->ctrl->ctrl.opts->tls) { + ctype = tls_get_record_type(queue->sock->sk, + (struct cmsghdr *)cbuf); + if (ctype != TLS_RECORD_TYPE_DATA) { + pr_err("queue %d: unhandled TLS record %d\n", + nvme_tcp_queue_id(queue), ctype); + return -ENOTCONN; + } + } ret = -EINVAL; if (icresp->hdr.type != nvme_tcp_icresp) { pr_err("queue %d: bad type returned %d\n", @@ -1507,11 +1547,99 @@ static void nvme_tcp_set_queue_io_cpu(struct nvme_tcp_queue *queue) queue->io_cpu = cpumask_next_wrap(n - 1, cpu_online_mask, -1, false); } -static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid) +#ifdef CONFIG_NVME_TCP_TLS +static void nvme_tcp_tls_done(void *data, int status, key_serial_t pskid) +{ + struct nvme_tcp_queue *queue = data; + struct nvme_tcp_ctrl *ctrl = queue->ctrl; + int qid = nvme_tcp_queue_id(queue); + struct key *tls_key; + + dev_dbg(ctrl->ctrl.device, "queue %d: TLS handshake done, key %x, status %d\n", + qid, pskid, status); + + if (status) { + queue->tls_err = -status; + goto out_complete; + } + + tls_key = key_lookup(pskid); + if (IS_ERR(tls_key)) { + dev_warn(ctrl->ctrl.device, "queue %d: Invalid key %x\n", + qid, pskid); + queue->tls_err = -ENOKEY; + } else { + ctrl->ctrl.tls_key = tls_key; + queue->tls_err = 0; + } + +out_complete: + complete(&queue->tls_complete); +} + +static int nvme_tcp_start_tls(struct nvme_ctrl *nctrl, + struct nvme_tcp_queue *queue, + key_serial_t pskid) +{ + int qid = nvme_tcp_queue_id(queue); + int ret; + struct tls_handshake_args args; + unsigned long tmo = tls_handshake_timeout * HZ; + key_serial_t keyring = nvme_keyring_id(); + + dev_dbg(nctrl->device, "queue %d: start TLS with key %x\n", + qid, pskid); + memset(&args, 0, sizeof(args)); + args.ta_sock = queue->sock; + args.ta_done = nvme_tcp_tls_done; + args.ta_data = queue; + args.ta_my_peerids[0] = pskid; + args.ta_num_peerids = 1; + if (nctrl->opts->keyring) + keyring = key_serial(nctrl->opts->keyring); + args.ta_keyring = keyring; + args.ta_timeout_ms = tls_handshake_timeout * 1000; + queue->tls_err = -EOPNOTSUPP; + init_completion(&queue->tls_complete); + ret = tls_client_hello_psk(&args, GFP_KERNEL); + if (ret) { + dev_err(nctrl->device, "queue %d: failed to start TLS: %d\n", + qid, ret); + return ret; + } + ret = wait_for_completion_interruptible_timeout(&queue->tls_complete, tmo); + if (ret <= 0) { + if (ret == 0) + ret = -ETIMEDOUT; + + dev_err(nctrl->device, + "queue %d: TLS handshake failed, error %d\n", + qid, ret); + tls_handshake_cancel(queue->sock->sk); + } else { + dev_dbg(nctrl->device, + "queue %d: TLS handshake complete, error %d\n", + qid, queue->tls_err); + ret = queue->tls_err; + } + return ret; +} +#else +static int nvme_tcp_start_tls(struct nvme_ctrl *nctrl, + struct nvme_tcp_queue *queue, + key_serial_t pskid) +{ + return -EPROTONOSUPPORT; +} +#endif + +static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid, + key_serial_t pskid) { struct nvme_tcp_ctrl *ctrl = to_tcp_ctrl(nctrl); struct nvme_tcp_queue *queue = &ctrl->queues[qid]; int ret, rcv_pdu_size; + struct file *sock_file; mutex_init(&queue->queue_lock); queue->ctrl = ctrl; @@ -1534,6 +1662,11 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid) goto err_destroy_mutex; } + sock_file = sock_alloc_file(queue->sock, O_CLOEXEC, NULL); + if (IS_ERR(sock_file)) { + ret = PTR_ERR(sock_file); + goto err_destroy_mutex; + } nvme_tcp_reclassify_socket(queue->sock); /* Single syn retry */ @@ -1624,6 +1757,13 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid) goto err_rcv_pdu; } + /* If PSKs are configured try to start TLS */ + if (pskid) { + ret = nvme_tcp_start_tls(nctrl, queue, pskid); + if (ret) + goto err_init_connect; + } + ret = nvme_tcp_init_connection(queue); if (ret) goto err_init_connect; @@ -1640,7 +1780,8 @@ err_crypto: if (queue->hdr_digest || queue->data_digest) nvme_tcp_free_crypto(queue); err_sock: - sock_release(queue->sock); + /* ->sock will be released by fput() */ + fput(queue->sock->file); queue->sock = NULL; err_destroy_mutex: mutex_destroy(&queue->send_mutex); @@ -1772,10 +1913,25 @@ out_stop_queues: static int nvme_tcp_alloc_admin_queue(struct nvme_ctrl *ctrl) { int ret; + key_serial_t pskid = 0; - ret = nvme_tcp_alloc_queue(ctrl, 0); + if (ctrl->opts->tls) { + if (ctrl->opts->tls_key) + pskid = key_serial(ctrl->opts->tls_key); + else + pskid = nvme_tls_psk_default(ctrl->opts->keyring, + ctrl->opts->host->nqn, + ctrl->opts->subsysnqn); + if (!pskid) { + dev_err(ctrl->device, "no valid PSK found\n"); + ret = -ENOKEY; + goto out_free_queue; + } + } + + ret = nvme_tcp_alloc_queue(ctrl, 0, pskid); if (ret) - return ret; + goto out_free_queue; ret = nvme_tcp_alloc_async_req(to_tcp_ctrl(ctrl)); if (ret) @@ -1792,8 +1948,13 @@ static int __nvme_tcp_alloc_io_queues(struct nvme_ctrl *ctrl) { int i, ret; + if (ctrl->opts->tls && !ctrl->tls_key) { + dev_err(ctrl->device, "no PSK negotiated\n"); + return -ENOKEY; + } for (i = 1; i < ctrl->queue_count; i++) { - ret = nvme_tcp_alloc_queue(ctrl, i); + ret = nvme_tcp_alloc_queue(ctrl, i, + key_serial(ctrl->tls_key)); if (ret) goto out_free_queues; } @@ -2621,7 +2782,8 @@ static struct nvmf_transport_ops nvme_tcp_transport = { NVMF_OPT_HOST_TRADDR | NVMF_OPT_CTRL_LOSS_TMO | NVMF_OPT_HDR_DIGEST | NVMF_OPT_DATA_DIGEST | NVMF_OPT_NR_WRITE_QUEUES | NVMF_OPT_NR_POLL_QUEUES | - NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE, + NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE | NVMF_OPT_TLS | + NVMF_OPT_KEYRING | NVMF_OPT_TLS_KEY, .create_ctrl = nvme_tcp_create_ctrl, }; diff --git a/drivers/nvme/target/Kconfig b/drivers/nvme/target/Kconfig index 79fc64035ee3..fa479c9f5c3d 100644 --- a/drivers/nvme/target/Kconfig +++ b/drivers/nvme/target/Kconfig @@ -84,16 +84,26 @@ config NVME_TARGET_TCP If unsure, say N. +config NVME_TARGET_TCP_TLS + bool "NVMe over Fabrics TCP target TLS encryption support" + depends on NVME_TARGET_TCP + select NVME_COMMON + select NVME_KEYRING + select NET_HANDSHAKE + select KEYS + help + Enables TLS encryption for the NVMe TCP target using the netlink handshake API. + + The TLS handshake daemon is available at + https://github.com/oracle/ktls-utils. + + If unsure, say N. + config NVME_TARGET_AUTH bool "NVMe over Fabrics In-band Authentication support" depends on NVME_TARGET select NVME_COMMON - select CRYPTO - select CRYPTO_HMAC - select CRYPTO_SHA256 - select CRYPTO_SHA512 - select CRYPTO_DH - select CRYPTO_DH_RFC7919_GROUPS + select NVME_AUTH help This enables support for NVMe over Fabrics In-band Authentication diff --git a/drivers/nvme/target/auth.c b/drivers/nvme/target/auth.c index 4dcddcf95279..3ddbc3880cac 100644 --- a/drivers/nvme/target/auth.c +++ b/drivers/nvme/target/auth.c @@ -267,7 +267,8 @@ int nvmet_auth_host_hash(struct nvmet_req *req, u8 *response, struct shash_desc *shash; struct nvmet_ctrl *ctrl = req->sq->ctrl; const char *hash_name; - u8 *challenge = req->sq->dhchap_c1, *host_response; + u8 *challenge = req->sq->dhchap_c1; + struct nvme_dhchap_key *transformed_key; u8 buf[4]; int ret; @@ -291,14 +292,15 @@ int nvmet_auth_host_hash(struct nvmet_req *req, u8 *response, goto out_free_tfm; } - host_response = nvme_auth_transform_key(ctrl->host_key, ctrl->hostnqn); - if (IS_ERR(host_response)) { - ret = PTR_ERR(host_response); + transformed_key = nvme_auth_transform_key(ctrl->host_key, + ctrl->hostnqn); + if (IS_ERR(transformed_key)) { + ret = PTR_ERR(transformed_key); goto out_free_tfm; } - ret = crypto_shash_setkey(shash_tfm, host_response, - ctrl->host_key->len); + ret = crypto_shash_setkey(shash_tfm, transformed_key->key, + transformed_key->len); if (ret) goto out_free_response; @@ -365,7 +367,7 @@ out: kfree(challenge); kfree(shash); out_free_response: - kfree_sensitive(host_response); + nvme_auth_free_key(transformed_key); out_free_tfm: crypto_free_shash(shash_tfm); return 0; @@ -378,7 +380,8 @@ int nvmet_auth_ctrl_hash(struct nvmet_req *req, u8 *response, struct shash_desc *shash; struct nvmet_ctrl *ctrl = req->sq->ctrl; const char *hash_name; - u8 *challenge = req->sq->dhchap_c2, *ctrl_response; + u8 *challenge = req->sq->dhchap_c2; + struct nvme_dhchap_key *transformed_key; u8 buf[4]; int ret; @@ -402,15 +405,15 @@ int nvmet_auth_ctrl_hash(struct nvmet_req *req, u8 *response, goto out_free_tfm; } - ctrl_response = nvme_auth_transform_key(ctrl->ctrl_key, + transformed_key = nvme_auth_transform_key(ctrl->ctrl_key, ctrl->subsysnqn); - if (IS_ERR(ctrl_response)) { - ret = PTR_ERR(ctrl_response); + if (IS_ERR(transformed_key)) { + ret = PTR_ERR(transformed_key); goto out_free_tfm; } - ret = crypto_shash_setkey(shash_tfm, ctrl_response, - ctrl->ctrl_key->len); + ret = crypto_shash_setkey(shash_tfm, transformed_key->key, + transformed_key->len); if (ret) goto out_free_response; @@ -474,7 +477,7 @@ out: kfree(challenge); kfree(shash); out_free_response: - kfree_sensitive(ctrl_response); + nvme_auth_free_key(transformed_key); out_free_tfm: crypto_free_shash(shash_tfm); return 0; diff --git a/drivers/nvme/target/configfs.c b/drivers/nvme/target/configfs.c index 907143870da5..9eed6e6765ea 100644 --- a/drivers/nvme/target/configfs.c +++ b/drivers/nvme/target/configfs.c @@ -15,6 +15,7 @@ #ifdef CONFIG_NVME_TARGET_AUTH #include <linux/nvme-auth.h> #endif +#include <linux/nvme-keyring.h> #include <crypto/hash.h> #include <crypto/kpp.h> @@ -159,10 +160,14 @@ static const struct nvmet_type_name_map nvmet_addr_treq[] = { { NVMF_TREQ_NOT_REQUIRED, "not required" }, }; +static inline u8 nvmet_port_disc_addr_treq_mask(struct nvmet_port *port) +{ + return (port->disc_addr.treq & ~NVME_TREQ_SECURE_CHANNEL_MASK); +} + static ssize_t nvmet_addr_treq_show(struct config_item *item, char *page) { - u8 treq = to_nvmet_port(item)->disc_addr.treq & - NVME_TREQ_SECURE_CHANNEL_MASK; + u8 treq = nvmet_port_disc_addr_treq_secure_channel(to_nvmet_port(item)); int i; for (i = 0; i < ARRAY_SIZE(nvmet_addr_treq); i++) { @@ -178,7 +183,7 @@ static ssize_t nvmet_addr_treq_store(struct config_item *item, const char *page, size_t count) { struct nvmet_port *port = to_nvmet_port(item); - u8 treq = port->disc_addr.treq & ~NVME_TREQ_SECURE_CHANNEL_MASK; + u8 treq = nvmet_port_disc_addr_treq_mask(port); int i; if (nvmet_is_port_enabled(port, __func__)) @@ -193,6 +198,20 @@ static ssize_t nvmet_addr_treq_store(struct config_item *item, return -EINVAL; found: + if (port->disc_addr.trtype == NVMF_TRTYPE_TCP && + port->disc_addr.tsas.tcp.sectype == NVMF_TCP_SECTYPE_TLS13) { + switch (nvmet_addr_treq[i].type) { + case NVMF_TREQ_NOT_SPECIFIED: + pr_debug("treq '%s' not allowed for TLS1.3\n", + nvmet_addr_treq[i].name); + return -EINVAL; + case NVMF_TREQ_NOT_REQUIRED: + pr_warn("Allow non-TLS connections while TLS1.3 is enabled\n"); + break; + default: + break; + } + } treq |= nvmet_addr_treq[i].type; port->disc_addr.treq = treq; return count; @@ -303,6 +322,11 @@ static void nvmet_port_init_tsas_rdma(struct nvmet_port *port) port->disc_addr.tsas.rdma.cms = NVMF_RDMA_CMS_RDMA_CM; } +static void nvmet_port_init_tsas_tcp(struct nvmet_port *port, int sectype) +{ + port->disc_addr.tsas.tcp.sectype = sectype; +} + static ssize_t nvmet_addr_trtype_store(struct config_item *item, const char *page, size_t count) { @@ -325,11 +349,99 @@ found: port->disc_addr.trtype = nvmet_transport[i].type; if (port->disc_addr.trtype == NVMF_TRTYPE_RDMA) nvmet_port_init_tsas_rdma(port); + else if (port->disc_addr.trtype == NVMF_TRTYPE_TCP) + nvmet_port_init_tsas_tcp(port, NVMF_TCP_SECTYPE_NONE); return count; } CONFIGFS_ATTR(nvmet_, addr_trtype); +static const struct nvmet_type_name_map nvmet_addr_tsas_tcp[] = { + { NVMF_TCP_SECTYPE_NONE, "none" }, + { NVMF_TCP_SECTYPE_TLS13, "tls1.3" }, +}; + +static const struct nvmet_type_name_map nvmet_addr_tsas_rdma[] = { + { NVMF_RDMA_QPTYPE_CONNECTED, "connected" }, + { NVMF_RDMA_QPTYPE_DATAGRAM, "datagram" }, +}; + +static ssize_t nvmet_addr_tsas_show(struct config_item *item, + char *page) +{ + struct nvmet_port *port = to_nvmet_port(item); + int i; + + if (port->disc_addr.trtype == NVMF_TRTYPE_TCP) { + for (i = 0; i < ARRAY_SIZE(nvmet_addr_tsas_tcp); i++) { + if (port->disc_addr.tsas.tcp.sectype == nvmet_addr_tsas_tcp[i].type) + return sprintf(page, "%s\n", nvmet_addr_tsas_tcp[i].name); + } + } else if (port->disc_addr.trtype == NVMF_TRTYPE_RDMA) { + for (i = 0; i < ARRAY_SIZE(nvmet_addr_tsas_rdma); i++) { + if (port->disc_addr.tsas.rdma.qptype == nvmet_addr_tsas_rdma[i].type) + return sprintf(page, "%s\n", nvmet_addr_tsas_rdma[i].name); + } + } + return sprintf(page, "reserved\n"); +} + +static ssize_t nvmet_addr_tsas_store(struct config_item *item, + const char *page, size_t count) +{ + struct nvmet_port *port = to_nvmet_port(item); + u8 treq = nvmet_port_disc_addr_treq_mask(port); + u8 sectype; + int i; + + if (nvmet_is_port_enabled(port, __func__)) + return -EACCES; + + if (port->disc_addr.trtype != NVMF_TRTYPE_TCP) + return -EINVAL; + + for (i = 0; i < ARRAY_SIZE(nvmet_addr_tsas_tcp); i++) { + if (sysfs_streq(page, nvmet_addr_tsas_tcp[i].name)) { + sectype = nvmet_addr_tsas_tcp[i].type; + goto found; + } + } + + pr_err("Invalid value '%s' for tsas\n", page); + return -EINVAL; + +found: + if (sectype == NVMF_TCP_SECTYPE_TLS13) { + if (!IS_ENABLED(CONFIG_NVME_TARGET_TCP_TLS)) { + pr_err("TLS is not supported\n"); + return -EINVAL; + } + if (!port->keyring) { + pr_err("TLS keyring not configured\n"); + return -EINVAL; + } + } + + nvmet_port_init_tsas_tcp(port, sectype); + /* + * If TLS is enabled TREQ should be set to 'required' per default + */ + if (sectype == NVMF_TCP_SECTYPE_TLS13) { + u8 sc = nvmet_port_disc_addr_treq_secure_channel(port); + + if (sc == NVMF_TREQ_NOT_SPECIFIED) + treq |= NVMF_TREQ_REQUIRED; + else + treq |= sc; + } else { + treq |= NVMF_TREQ_NOT_SPECIFIED; + } + port->disc_addr.treq = treq; + return count; +} + +CONFIGFS_ATTR(nvmet_, addr_tsas); + /* * Namespace structures & file operation functions below */ @@ -1731,6 +1843,7 @@ static void nvmet_port_release(struct config_item *item) flush_workqueue(nvmet_wq); list_del(&port->global_entry); + key_put(port->keyring); kfree(port->ana_state); kfree(port); } @@ -1741,6 +1854,7 @@ static struct configfs_attribute *nvmet_port_attrs[] = { &nvmet_attr_addr_traddr, &nvmet_attr_addr_trsvcid, &nvmet_attr_addr_trtype, + &nvmet_attr_addr_tsas, &nvmet_attr_param_inline_data_size, #ifdef CONFIG_BLK_DEV_INTEGRITY &nvmet_attr_param_pi_enable, @@ -1779,6 +1893,14 @@ static struct config_group *nvmet_ports_make(struct config_group *group, return ERR_PTR(-ENOMEM); } + if (nvme_keyring_id()) { + port->keyring = key_lookup(nvme_keyring_id()); + if (IS_ERR(port->keyring)) { + pr_warn("NVMe keyring not available, disabling TLS\n"); + port->keyring = NULL; + } + } + for (i = 1; i <= NVMET_MAX_ANAGRPS; i++) { if (i == NVMET_DEFAULT_ANA_GRPID) port->ana_state[1] = NVME_ANA_OPTIMIZED; diff --git a/drivers/nvme/target/fc.c b/drivers/nvme/target/fc.c index 1ab6601fdd5c..bd59990b5250 100644 --- a/drivers/nvme/target/fc.c +++ b/drivers/nvme/target/fc.c @@ -146,7 +146,8 @@ struct nvmet_fc_tgt_queue { struct workqueue_struct *work_q; struct kref ref; struct rcu_head rcu; - struct nvmet_fc_fcp_iod fod[]; /* array of fcp_iods */ + /* array of fcp_iods */ + struct nvmet_fc_fcp_iod fod[] __counted_by(sqsize); } __aligned(sizeof(unsigned long long)); struct nvmet_fc_hostport { diff --git a/drivers/nvme/target/nvmet.h b/drivers/nvme/target/nvmet.h index 8cfd60f3b564..3e179019ca7c 100644 --- a/drivers/nvme/target/nvmet.h +++ b/drivers/nvme/target/nvmet.h @@ -158,6 +158,7 @@ struct nvmet_port { struct config_group ana_groups_group; struct nvmet_ana_group ana_default_group; enum nvme_ana_state *ana_state; + struct key *keyring; void *priv; bool enabled; int inline_data_size; @@ -178,6 +179,16 @@ static inline struct nvmet_port *ana_groups_to_port( ana_groups_group); } +static inline u8 nvmet_port_disc_addr_treq_secure_channel(struct nvmet_port *port) +{ + return (port->disc_addr.treq & NVME_TREQ_SECURE_CHANNEL_MASK); +} + +static inline bool nvmet_port_secure_channel_required(struct nvmet_port *port) +{ + return nvmet_port_disc_addr_treq_secure_channel(port) == NVMF_TREQ_REQUIRED; +} + struct nvmet_ctrl { struct nvmet_subsys *subsys; struct nvmet_sq **sqs; diff --git a/drivers/nvme/target/tcp.c b/drivers/nvme/target/tcp.c index cd92d7ddf5ed..4336fe048e43 100644 --- a/drivers/nvme/target/tcp.c +++ b/drivers/nvme/target/tcp.c @@ -8,9 +8,14 @@ #include <linux/init.h> #include <linux/slab.h> #include <linux/err.h> +#include <linux/key.h> #include <linux/nvme-tcp.h> +#include <linux/nvme-keyring.h> #include <net/sock.h> #include <net/tcp.h> +#include <net/tls.h> +#include <net/tls_prot.h> +#include <net/handshake.h> #include <linux/inet.h> #include <linux/llist.h> #include <crypto/hash.h> @@ -66,6 +71,16 @@ device_param_cb(idle_poll_period_usecs, &set_param_ops, MODULE_PARM_DESC(idle_poll_period_usecs, "nvmet tcp io_work poll till idle time period in usecs: Default 0"); +#ifdef CONFIG_NVME_TARGET_TCP_TLS +/* + * TLS handshake timeout + */ +static int tls_handshake_timeout = 10; +module_param(tls_handshake_timeout, int, 0644); +MODULE_PARM_DESC(tls_handshake_timeout, + "nvme TLS handshake timeout in seconds (default 10)"); +#endif + #define NVMET_TCP_RECV_BUDGET 8 #define NVMET_TCP_SEND_BUDGET 8 #define NVMET_TCP_IO_WORK_BUDGET 64 @@ -104,6 +119,7 @@ struct nvmet_tcp_cmd { u32 pdu_len; u32 pdu_recv; int sg_idx; + char recv_cbuf[CMSG_LEN(sizeof(char))]; struct msghdr recv_msg; struct bio_vec *iov; u32 flags; @@ -122,8 +138,10 @@ struct nvmet_tcp_cmd { enum nvmet_tcp_queue_state { NVMET_TCP_Q_CONNECTING, + NVMET_TCP_Q_TLS_HANDSHAKE, NVMET_TCP_Q_LIVE, NVMET_TCP_Q_DISCONNECTING, + NVMET_TCP_Q_FAILED, }; struct nvmet_tcp_queue { @@ -132,6 +150,7 @@ struct nvmet_tcp_queue { struct work_struct io_work; struct nvmet_cq nvme_cq; struct nvmet_sq nvme_sq; + struct kref kref; /* send state */ struct nvmet_tcp_cmd *cmds; @@ -155,6 +174,10 @@ struct nvmet_tcp_queue { struct ahash_request *snd_hash; struct ahash_request *rcv_hash; + /* TLS state */ + key_serial_t tls_pskid; + struct delayed_work tls_handshake_tmo_work; + unsigned long poll_end; spinlock_t state_lock; @@ -918,6 +941,7 @@ static int nvmet_tcp_handle_icreq(struct nvmet_tcp_queue *queue) free_crypto: if (queue->hdr_digest || queue->data_digest) nvmet_tcp_free_crypto(queue); + queue->state = NVMET_TCP_Q_FAILED; return ret; } @@ -1099,20 +1123,65 @@ static inline bool nvmet_tcp_pdu_valid(u8 type) return false; } +static int nvmet_tcp_tls_record_ok(struct nvmet_tcp_queue *queue, + struct msghdr *msg, char *cbuf) +{ + struct cmsghdr *cmsg = (struct cmsghdr *)cbuf; + u8 ctype, level, description; + int ret = 0; + + ctype = tls_get_record_type(queue->sock->sk, cmsg); + switch (ctype) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(queue->sock->sk, msg, &level, &description); + if (level == TLS_ALERT_LEVEL_FATAL) { + pr_err("queue %d: TLS Alert desc %u\n", + queue->idx, description); + ret = -ENOTCONN; + } else { + pr_warn("queue %d: TLS Alert desc %u\n", + queue->idx, description); + ret = -EAGAIN; + } + break; + default: + /* discard this record type */ + pr_err("queue %d: TLS record %d unhandled\n", + queue->idx, ctype); + ret = -EAGAIN; + break; + } + return ret; +} + static int nvmet_tcp_try_recv_pdu(struct nvmet_tcp_queue *queue) { struct nvme_tcp_hdr *hdr = &queue->pdu.cmd.hdr; - int len; + int len, ret; struct kvec iov; + char cbuf[CMSG_LEN(sizeof(char))] = {}; struct msghdr msg = { .msg_flags = MSG_DONTWAIT }; recv: iov.iov_base = (void *)&queue->pdu + queue->offset; iov.iov_len = queue->left; + if (queue->tls_pskid) { + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + } len = kernel_recvmsg(queue->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); if (unlikely(len < 0)) return len; + if (queue->tls_pskid) { + ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + if (ret < 0) + return ret; + } queue->offset += len; queue->left -= len; @@ -1165,16 +1234,22 @@ static void nvmet_tcp_prep_recv_ddgst(struct nvmet_tcp_cmd *cmd) static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue) { struct nvmet_tcp_cmd *cmd = queue->cmd; - int ret; + int len, ret; while (msg_data_left(&cmd->recv_msg)) { - ret = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg, + len = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg, cmd->recv_msg.msg_flags); - if (ret <= 0) - return ret; + if (len <= 0) + return len; + if (queue->tls_pskid) { + ret = nvmet_tcp_tls_record_ok(cmd->queue, + &cmd->recv_msg, cmd->recv_cbuf); + if (ret < 0) + return ret; + } - cmd->pdu_recv += ret; - cmd->rbytes_done += ret; + cmd->pdu_recv += len; + cmd->rbytes_done += len; } if (queue->data_digest) { @@ -1192,20 +1267,30 @@ static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue) static int nvmet_tcp_try_recv_ddgst(struct nvmet_tcp_queue *queue) { struct nvmet_tcp_cmd *cmd = queue->cmd; - int ret; + int ret, len; + char cbuf[CMSG_LEN(sizeof(char))] = {}; struct msghdr msg = { .msg_flags = MSG_DONTWAIT }; struct kvec iov = { .iov_base = (void *)&cmd->recv_ddgst + queue->offset, .iov_len = queue->left }; - ret = kernel_recvmsg(queue->sock, &msg, &iov, 1, + if (queue->tls_pskid) { + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + } + len = kernel_recvmsg(queue->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); - if (unlikely(ret < 0)) - return ret; + if (unlikely(len < 0)) + return len; + if (queue->tls_pskid) { + ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + if (ret < 0) + return ret; + } - queue->offset += ret; - queue->left -= ret; + queue->offset += len; + queue->left -= len; if (queue->left) return -EAGAIN; @@ -1283,14 +1368,27 @@ done: return ret; } +static void nvmet_tcp_release_queue(struct kref *kref) +{ + struct nvmet_tcp_queue *queue = + container_of(kref, struct nvmet_tcp_queue, kref); + + WARN_ON(queue->state != NVMET_TCP_Q_DISCONNECTING); + queue_work(nvmet_wq, &queue->release_work); +} + static void nvmet_tcp_schedule_release_queue(struct nvmet_tcp_queue *queue) { - spin_lock(&queue->state_lock); + spin_lock_bh(&queue->state_lock); + if (queue->state == NVMET_TCP_Q_TLS_HANDSHAKE) { + /* Socket closed during handshake */ + tls_handshake_cancel(queue->sock->sk); + } if (queue->state != NVMET_TCP_Q_DISCONNECTING) { queue->state = NVMET_TCP_Q_DISCONNECTING; - queue_work(nvmet_wq, &queue->release_work); + kref_put(&queue->kref, nvmet_tcp_release_queue); } - spin_unlock(&queue->state_lock); + spin_unlock_bh(&queue->state_lock); } static inline void nvmet_tcp_arm_queue_deadline(struct nvmet_tcp_queue *queue) @@ -1372,6 +1470,10 @@ static int nvmet_tcp_alloc_cmd(struct nvmet_tcp_queue *queue, if (!c->r2t_pdu) goto out_free_data; + if (queue->state == NVMET_TCP_Q_TLS_HANDSHAKE) { + c->recv_msg.msg_control = c->recv_cbuf; + c->recv_msg.msg_controllen = sizeof(c->recv_cbuf); + } c->recv_msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL; list_add_tail(&c->entry, &queue->free_list); @@ -1485,6 +1587,7 @@ static void nvmet_tcp_release_queue_work(struct work_struct *w) mutex_unlock(&nvmet_tcp_queue_mutex); nvmet_tcp_restore_socket_callbacks(queue); + cancel_delayed_work_sync(&queue->tls_handshake_tmo_work); cancel_work_sync(&queue->io_work); /* stop accepting incoming data */ queue->rcv_state = NVMET_TCP_RECV_ERR; @@ -1493,12 +1596,12 @@ static void nvmet_tcp_release_queue_work(struct work_struct *w) nvmet_sq_destroy(&queue->nvme_sq); cancel_work_sync(&queue->io_work); nvmet_tcp_free_cmd_data_in_buffers(queue); - sock_release(queue->sock); + /* ->sock will be released by fput() */ + fput(queue->sock->file); nvmet_tcp_free_cmds(queue); if (queue->hdr_digest || queue->data_digest) nvmet_tcp_free_crypto(queue); ida_free(&nvmet_tcp_queue_ida, queue->idx); - page = virt_to_head_page(queue->pf_cache.va); __page_frag_cache_drain(page, queue->pf_cache.pagecnt_bias); kfree(queue); @@ -1512,8 +1615,13 @@ static void nvmet_tcp_data_ready(struct sock *sk) read_lock_bh(&sk->sk_callback_lock); queue = sk->sk_user_data; - if (likely(queue)) - queue_work_on(queue_cpu(queue), nvmet_tcp_wq, &queue->io_work); + if (likely(queue)) { + if (queue->data_ready) + queue->data_ready(sk); + if (queue->state != NVMET_TCP_Q_TLS_HANDSHAKE) + queue_work_on(queue_cpu(queue), nvmet_tcp_wq, + &queue->io_work); + } read_unlock_bh(&sk->sk_callback_lock); } @@ -1621,31 +1729,174 @@ static int nvmet_tcp_set_queue_sock(struct nvmet_tcp_queue *queue) return ret; } -static int nvmet_tcp_alloc_queue(struct nvmet_tcp_port *port, +#ifdef CONFIG_NVME_TARGET_TCP_TLS +static int nvmet_tcp_try_peek_pdu(struct nvmet_tcp_queue *queue) +{ + struct nvme_tcp_hdr *hdr = &queue->pdu.cmd.hdr; + int len, ret; + struct kvec iov = { + .iov_base = (u8 *)&queue->pdu + queue->offset, + .iov_len = sizeof(struct nvme_tcp_hdr), + }; + char cbuf[CMSG_LEN(sizeof(char))] = {}; + struct msghdr msg = { + .msg_control = cbuf, + .msg_controllen = sizeof(cbuf), + .msg_flags = MSG_PEEK, + }; + + if (nvmet_port_secure_channel_required(queue->port->nport)) + return 0; + + len = kernel_recvmsg(queue->sock, &msg, &iov, 1, + iov.iov_len, msg.msg_flags); + if (unlikely(len < 0)) { + pr_debug("queue %d: peek error %d\n", + queue->idx, len); + return len; + } + + ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + if (ret < 0) + return ret; + + if (len < sizeof(struct nvme_tcp_hdr)) { + pr_debug("queue %d: short read, %d bytes missing\n", + queue->idx, (int)iov.iov_len - len); + return -EAGAIN; + } + pr_debug("queue %d: hdr type %d hlen %d plen %d size %d\n", + queue->idx, hdr->type, hdr->hlen, hdr->plen, + (int)sizeof(struct nvme_tcp_icreq_pdu)); + if (hdr->type == nvme_tcp_icreq && + hdr->hlen == sizeof(struct nvme_tcp_icreq_pdu) && + hdr->plen == (__le32)sizeof(struct nvme_tcp_icreq_pdu)) { + pr_debug("queue %d: icreq detected\n", + queue->idx); + return len; + } + return 0; +} + +static void nvmet_tcp_tls_handshake_done(void *data, int status, + key_serial_t peerid) +{ + struct nvmet_tcp_queue *queue = data; + + pr_debug("queue %d: TLS handshake done, key %x, status %d\n", + queue->idx, peerid, status); + spin_lock_bh(&queue->state_lock); + if (WARN_ON(queue->state != NVMET_TCP_Q_TLS_HANDSHAKE)) { + spin_unlock_bh(&queue->state_lock); + return; + } + if (!status) { + queue->tls_pskid = peerid; + queue->state = NVMET_TCP_Q_CONNECTING; + } else + queue->state = NVMET_TCP_Q_FAILED; + spin_unlock_bh(&queue->state_lock); + + cancel_delayed_work_sync(&queue->tls_handshake_tmo_work); + if (status) + nvmet_tcp_schedule_release_queue(queue); + else + nvmet_tcp_set_queue_sock(queue); + kref_put(&queue->kref, nvmet_tcp_release_queue); +} + +static void nvmet_tcp_tls_handshake_timeout(struct work_struct *w) +{ + struct nvmet_tcp_queue *queue = container_of(to_delayed_work(w), + struct nvmet_tcp_queue, tls_handshake_tmo_work); + + pr_warn("queue %d: TLS handshake timeout\n", queue->idx); + /* + * If tls_handshake_cancel() fails we've lost the race with + * nvmet_tcp_tls_handshake_done() */ + if (!tls_handshake_cancel(queue->sock->sk)) + return; + spin_lock_bh(&queue->state_lock); + if (WARN_ON(queue->state != NVMET_TCP_Q_TLS_HANDSHAKE)) { + spin_unlock_bh(&queue->state_lock); + return; + } + queue->state = NVMET_TCP_Q_FAILED; + spin_unlock_bh(&queue->state_lock); + nvmet_tcp_schedule_release_queue(queue); + kref_put(&queue->kref, nvmet_tcp_release_queue); +} + +static int nvmet_tcp_tls_handshake(struct nvmet_tcp_queue *queue) +{ + int ret = -EOPNOTSUPP; + struct tls_handshake_args args; + + if (queue->state != NVMET_TCP_Q_TLS_HANDSHAKE) { + pr_warn("cannot start TLS in state %d\n", queue->state); + return -EINVAL; + } + + kref_get(&queue->kref); + pr_debug("queue %d: TLS ServerHello\n", queue->idx); + memset(&args, 0, sizeof(args)); + args.ta_sock = queue->sock; + args.ta_done = nvmet_tcp_tls_handshake_done; + args.ta_data = queue; + args.ta_keyring = key_serial(queue->port->nport->keyring); + args.ta_timeout_ms = tls_handshake_timeout * 1000; + + ret = tls_server_hello_psk(&args, GFP_KERNEL); + if (ret) { + kref_put(&queue->kref, nvmet_tcp_release_queue); + pr_err("failed to start TLS, err=%d\n", ret); + } else { + queue_delayed_work(nvmet_wq, &queue->tls_handshake_tmo_work, + tls_handshake_timeout * HZ); + } + return ret; +} +#endif + +static void nvmet_tcp_alloc_queue(struct nvmet_tcp_port *port, struct socket *newsock) { struct nvmet_tcp_queue *queue; + struct file *sock_file = NULL; int ret; queue = kzalloc(sizeof(*queue), GFP_KERNEL); - if (!queue) - return -ENOMEM; + if (!queue) { + ret = -ENOMEM; + goto out_release; + } INIT_WORK(&queue->release_work, nvmet_tcp_release_queue_work); INIT_WORK(&queue->io_work, nvmet_tcp_io_work); + kref_init(&queue->kref); queue->sock = newsock; queue->port = port; queue->nr_cmds = 0; spin_lock_init(&queue->state_lock); - queue->state = NVMET_TCP_Q_CONNECTING; + if (queue->port->nport->disc_addr.tsas.tcp.sectype == + NVMF_TCP_SECTYPE_TLS13) + queue->state = NVMET_TCP_Q_TLS_HANDSHAKE; + else + queue->state = NVMET_TCP_Q_CONNECTING; INIT_LIST_HEAD(&queue->free_list); init_llist_head(&queue->resp_list); INIT_LIST_HEAD(&queue->resp_send_list); + sock_file = sock_alloc_file(queue->sock, O_CLOEXEC, NULL); + if (IS_ERR(sock_file)) { + ret = PTR_ERR(sock_file); + goto out_free_queue; + } + queue->idx = ida_alloc(&nvmet_tcp_queue_ida, GFP_KERNEL); if (queue->idx < 0) { ret = queue->idx; - goto out_free_queue; + goto out_sock; } ret = nvmet_tcp_alloc_cmd(queue, &queue->connect); @@ -1662,11 +1913,33 @@ static int nvmet_tcp_alloc_queue(struct nvmet_tcp_port *port, list_add_tail(&queue->queue_list, &nvmet_tcp_queue_list); mutex_unlock(&nvmet_tcp_queue_mutex); +#ifdef CONFIG_NVME_TARGET_TCP_TLS + INIT_DELAYED_WORK(&queue->tls_handshake_tmo_work, + nvmet_tcp_tls_handshake_timeout); + if (queue->state == NVMET_TCP_Q_TLS_HANDSHAKE) { + struct sock *sk = queue->sock->sk; + + /* Restore the default callbacks before starting upcall */ + read_lock_bh(&sk->sk_callback_lock); + sk->sk_user_data = NULL; + sk->sk_data_ready = port->data_ready; + read_unlock_bh(&sk->sk_callback_lock); + if (!nvmet_tcp_try_peek_pdu(queue)) { + if (!nvmet_tcp_tls_handshake(queue)) + return; + /* TLS handshake failed, terminate the connection */ + goto out_destroy_sq; + } + /* Not a TLS connection, continue with normal processing */ + queue->state = NVMET_TCP_Q_CONNECTING; + } +#endif + ret = nvmet_tcp_set_queue_sock(queue); if (ret) goto out_destroy_sq; - return 0; + return; out_destroy_sq: mutex_lock(&nvmet_tcp_queue_mutex); list_del_init(&queue->queue_list); @@ -1676,9 +1949,14 @@ out_free_connect: nvmet_tcp_free_cmd(&queue->connect); out_ida_remove: ida_free(&nvmet_tcp_queue_ida, queue->idx); +out_sock: + fput(queue->sock->file); out_free_queue: kfree(queue); - return ret; +out_release: + pr_err("failed to allocate queue, error %d\n", ret); + if (!sock_file) + sock_release(newsock); } static void nvmet_tcp_accept_work(struct work_struct *w) @@ -1695,11 +1973,7 @@ static void nvmet_tcp_accept_work(struct work_struct *w) pr_warn("failed to accept err=%d\n", ret); return; } - ret = nvmet_tcp_alloc_queue(port, newsock); - if (ret) { - pr_err("failed to allocate queue\n"); - sock_release(newsock); - } + nvmet_tcp_alloc_queue(port, newsock); } } diff --git a/include/linux/badblocks.h b/include/linux/badblocks.h index 2426276b9bd3..670f2dae692f 100644 --- a/include/linux/badblocks.h +++ b/include/linux/badblocks.h @@ -15,6 +15,7 @@ #define BB_OFFSET(x) (((x) & BB_OFFSET_MASK) >> 9) #define BB_LEN(x) (((x) & BB_LEN_MASK) + 1) #define BB_ACK(x) (!!((x) & BB_ACK_MASK)) +#define BB_END(x) (BB_OFFSET(x) + BB_LEN(x)) #define BB_MAKE(a, l, ack) (((a)<<9) | ((l)-1) | ((u64)(!!(ack)) << 63)) /* Bad block numbers are stored sorted in a single page. @@ -41,6 +42,12 @@ struct badblocks { sector_t size; /* in sectors */ }; +struct badblocks_context { + sector_t start; + sector_t len; + int ack; +}; + int badblocks_check(struct badblocks *bb, sector_t s, int sectors, sector_t *first_bad, int *bad_sectors); int badblocks_set(struct badblocks *bb, sector_t s, int sectors, @@ -63,4 +70,27 @@ static inline void devm_exit_badblocks(struct device *dev, struct badblocks *bb) } badblocks_exit(bb); } + +static inline int badblocks_full(struct badblocks *bb) +{ + return (bb->count >= MAX_BADBLOCKS); +} + +static inline int badblocks_empty(struct badblocks *bb) +{ + return (bb->count == 0); +} + +static inline void set_changed(struct badblocks *bb) +{ + if (bb->changed != 1) + bb->changed = 1; +} + +static inline void clear_changed(struct badblocks *bb) +{ + if (bb->changed != 0) + bb->changed = 0; +} + #endif diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h index 106cdc55ff3b..b4391e0a9bc8 100644 --- a/include/linux/io_uring.h +++ b/include/linux/io_uring.h @@ -20,8 +20,15 @@ enum io_uring_cmd_flags { IO_URING_F_SQE128 = (1 << 8), IO_URING_F_CQE32 = (1 << 9), IO_URING_F_IOPOLL = (1 << 10), + + /* set when uring wants to cancel a previously issued command */ + IO_URING_F_CANCEL = (1 << 11), }; +/* only top 8 bits of sqe->uring_cmd_flags for kernel internal use */ +#define IORING_URING_CMD_CANCELABLE (1U << 30) +#define IORING_URING_CMD_POLLED (1U << 31) + struct io_uring_cmd { struct file *file; const struct io_uring_sqe *sqe; @@ -82,6 +89,9 @@ static inline void io_uring_free(struct task_struct *tsk) __io_uring_free(tsk); } int io_uring_cmd_sock(struct io_uring_cmd *cmd, unsigned int issue_flags); +void io_uring_cmd_mark_cancelable(struct io_uring_cmd *cmd, + unsigned int issue_flags); +struct task_struct *io_uring_cmd_get_task(struct io_uring_cmd *cmd); #else static inline int io_uring_cmd_import_fixed(u64 ubuf, unsigned long len, int rw, struct iov_iter *iter, void *ioucmd) @@ -122,6 +132,14 @@ static inline int io_uring_cmd_sock(struct io_uring_cmd *cmd, { return -EOPNOTSUPP; } +static inline void io_uring_cmd_mark_cancelable(struct io_uring_cmd *cmd, + unsigned int issue_flags) +{ +} +static inline struct task_struct *io_uring_cmd_get_task(struct io_uring_cmd *cmd) +{ + return NULL; +} #endif #endif diff --git a/include/linux/io_uring_types.h b/include/linux/io_uring_types.h index 13d19b9be9f4..e4e67899b134 100644 --- a/include/linux/io_uring_types.h +++ b/include/linux/io_uring_types.h @@ -265,6 +265,12 @@ struct io_ring_ctx { */ struct io_wq_work_list iopoll_list; bool poll_multi_queue; + + /* + * Any cancelable uring_cmd is added to this list in + * ->uring_cmd() by io_uring_cmd_insert_cancelable() + */ + struct hlist_head cancelable_uring_cmd; } ____cacheline_aligned_in_smp; struct { @@ -313,6 +319,8 @@ struct io_ring_ctx { struct list_head cq_overflow_list; struct io_hash_table cancel_table; + struct hlist_head waitid_list; + const struct cred *sq_creds; /* cred used for __io_sq_thread() */ struct io_sq_data *sq_data; /* if using sq thread polling */ @@ -342,8 +350,6 @@ struct io_ring_ctx { struct wait_queue_head rsrc_quiesce_wq; unsigned rsrc_quiesce; - struct list_head io_buffers_pages; - #if defined(CONFIG_UNIX) struct socket *ring_sock; #endif diff --git a/include/linux/key.h b/include/linux/key.h index 938d7ecfb495..943a432da3ae 100644 --- a/include/linux/key.h +++ b/include/linux/key.h @@ -515,6 +515,7 @@ extern void key_init(void); #define key_init() do { } while(0) #define key_free_user_ns(ns) do { } while(0) #define key_remove_domain(d) do { } while(0) +#define key_lookup(k) NULL #endif /* CONFIG_KEYS */ #endif /* __KERNEL__ */ diff --git a/include/linux/nvme-auth.h b/include/linux/nvme-auth.h index dcb8030062dd..c1d0bc5d9624 100644 --- a/include/linux/nvme-auth.h +++ b/include/linux/nvme-auth.h @@ -9,9 +9,9 @@ #include <crypto/kpp.h> struct nvme_dhchap_key { - u8 *key; size_t len; u8 hash; + u8 key[]; }; u32 nvme_auth_get_seqnum(void); @@ -24,10 +24,13 @@ const char *nvme_auth_digest_name(u8 hmac_id); size_t nvme_auth_hmac_hash_len(u8 hmac_id); u8 nvme_auth_hmac_id(const char *hmac_name); +u32 nvme_auth_key_struct_size(u32 key_len); struct nvme_dhchap_key *nvme_auth_extract_key(unsigned char *secret, u8 key_hash); void nvme_auth_free_key(struct nvme_dhchap_key *key); -u8 *nvme_auth_transform_key(struct nvme_dhchap_key *key, char *nqn); +struct nvme_dhchap_key *nvme_auth_alloc_key(u32 len, u8 hash); +struct nvme_dhchap_key *nvme_auth_transform_key( + struct nvme_dhchap_key *key, char *nqn); int nvme_auth_generate_key(u8 *secret, struct nvme_dhchap_key **ret_key); int nvme_auth_augmented_challenge(u8 hmac_id, u8 *skey, size_t skey_len, u8 *challenge, u8 *aug, size_t hlen); diff --git a/include/linux/nvme-keyring.h b/include/linux/nvme-keyring.h new file mode 100644 index 000000000000..4efea9dd967c --- /dev/null +++ b/include/linux/nvme-keyring.h @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (c) 2023 Hannes Reinecke, SUSE Labs + */ + +#ifndef _NVME_KEYRING_H +#define _NVME_KEYRING_H + +#ifdef CONFIG_NVME_KEYRING + +key_serial_t nvme_tls_psk_default(struct key *keyring, + const char *hostnqn, const char *subnqn); + +key_serial_t nvme_keyring_id(void); +int nvme_keyring_init(void); +void nvme_keyring_exit(void); + +#else + +static inline key_serial_t nvme_tls_psk_default(struct key *keyring, + const char *hostnqn, const char *subnqn) +{ + return 0; +} +static inline key_serial_t nvme_keyring_id(void) +{ + return 0; +} +static inline int nvme_keyring_init(void) +{ + return 0; +} +static inline void nvme_keyring_exit(void) {} + +#endif /* !CONFIG_NVME_KEYRING */ +#endif /* _NVME_KEYRING_H */ diff --git a/include/linux/nvme-tcp.h b/include/linux/nvme-tcp.h index 57ebe1267f7f..e07e8978d691 100644 --- a/include/linux/nvme-tcp.h +++ b/include/linux/nvme-tcp.h @@ -18,6 +18,12 @@ enum nvme_tcp_pfv { NVME_TCP_PFV_1_0 = 0x0, }; +enum nvme_tcp_tls_cipher { + NVME_TCP_TLS_CIPHER_INVALID = 0, + NVME_TCP_TLS_CIPHER_SHA256 = 1, + NVME_TCP_TLS_CIPHER_SHA384 = 2, +}; + enum nvme_tcp_fatal_error_status { NVME_TCP_FES_INVALID_PDU_HDR = 0x01, NVME_TCP_FES_PDU_SEQ_ERR = 0x02, diff --git a/include/linux/nvme.h b/include/linux/nvme.h index 26dd3f859d9d..a7ba74babad7 100644 --- a/include/linux/nvme.h +++ b/include/linux/nvme.h @@ -108,6 +108,13 @@ enum { NVMF_RDMA_CMS_RDMA_CM = 1, /* Sockets based endpoint addressing */ }; +/* TSAS SECTYPE for TCP transport */ +enum { + NVMF_TCP_SECTYPE_NONE = 0, /* No Security */ + NVMF_TCP_SECTYPE_TLS12 = 1, /* TLSv1.2, NVMe-oF 1.1 and NVMe-TCP 3.6.1.1 */ + NVMF_TCP_SECTYPE_TLS13 = 2, /* TLSv1.3, NVMe-oF 1.1 and NVMe-TCP 3.6.1.1 */ +}; + #define NVME_AQ_DEPTH 32 #define NVME_NR_AEN_COMMANDS 1 #define NVME_AQ_BLK_MQ_DEPTH (NVME_AQ_DEPTH - NVME_NR_AEN_COMMANDS) @@ -1493,6 +1500,9 @@ struct nvmf_disc_rsp_page_entry { __u16 pkey; __u8 resv10[246]; } rdma; + struct tcp { + __u8 sectype; + } tcp; } tsas; }; diff --git a/include/linux/sed-opal-key.h b/include/linux/sed-opal-key.h new file mode 100644 index 000000000000..0ca03054e8f6 --- /dev/null +++ b/include/linux/sed-opal-key.h @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * SED key operations. + * + * Copyright (C) 2023 IBM Corporation + * + * These are the accessor functions (read/write) for SED Opal + * keys. Specific keystores can provide overrides. + * + */ + +#include <linux/kernel.h> + +#ifdef CONFIG_PSERIES_PLPKS_SED +int sed_read_key(char *keyname, char *key, u_int *keylen); +int sed_write_key(char *keyname, char *key, u_int keylen); +#else +static inline +int sed_read_key(char *keyname, char *key, u_int *keylen) { + return -EOPNOTSUPP; +} +static inline +int sed_write_key(char *keyname, char *key, u_int keylen) { + return -EOPNOTSUPP; +} +#endif diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h index 8e61f8b7c2ce..425f64eee44e 100644 --- a/include/uapi/linux/io_uring.h +++ b/include/uapi/linux/io_uring.h @@ -65,6 +65,7 @@ struct io_uring_sqe { __u32 xattr_flags; __u32 msg_ring_flags; __u32 uring_cmd_flags; + __u32 waitid_flags; }; __u64 user_data; /* data to be passed back at completion time */ /* pack this to avoid bogus arm OABI complaints */ @@ -240,19 +241,20 @@ enum io_uring_op { IORING_OP_URING_CMD, IORING_OP_SEND_ZC, IORING_OP_SENDMSG_ZC, + IORING_OP_READ_MULTISHOT, + IORING_OP_WAITID, /* this goes last, obviously */ IORING_OP_LAST, }; /* - * sqe->uring_cmd_flags + * sqe->uring_cmd_flags top 8bits aren't available for userspace * IORING_URING_CMD_FIXED use registered buffer; pass this flag * along with setting sqe->buf_index. - * IORING_URING_CMD_POLLED driver use only */ #define IORING_URING_CMD_FIXED (1U << 0) -#define IORING_URING_CMD_POLLED (1U << 31) +#define IORING_URING_CMD_MASK IORING_URING_CMD_FIXED /* diff --git a/io_uring/Makefile b/io_uring/Makefile index 8cc8e5387a75..7bd64e442567 100644 --- a/io_uring/Makefile +++ b/io_uring/Makefile @@ -7,5 +7,6 @@ obj-$(CONFIG_IO_URING) += io_uring.o xattr.o nop.o fs.o splice.o \ openclose.o uring_cmd.o epoll.o \ statx.o net.o msg_ring.o timeout.o \ sqpoll.o fdinfo.o tctx.o poll.o \ - cancel.o kbuf.o rsrc.o rw.o opdef.o notif.o + cancel.o kbuf.o rsrc.o rw.o opdef.o \ + notif.o waitid.o obj-$(CONFIG_IO_WQ) += io-wq.o diff --git a/io_uring/cancel.c b/io_uring/cancel.c index 7b23607cf4af..eb77a51c5a79 100644 --- a/io_uring/cancel.c +++ b/io_uring/cancel.c @@ -15,6 +15,7 @@ #include "tctx.h" #include "poll.h" #include "timeout.h" +#include "waitid.h" #include "cancel.h" struct io_cancel { @@ -119,6 +120,10 @@ int io_try_cancel(struct io_uring_task *tctx, struct io_cancel_data *cd, if (ret != -ENOENT) return ret; + ret = io_waitid_cancel(ctx, cd, issue_flags); + if (ret != -ENOENT) + return ret; + spin_lock(&ctx->completion_lock); if (!(cd->flags & IORING_ASYNC_CANCEL_FD)) ret = io_timeout_cancel(ctx, cd); diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c index 783ed0fff71b..b9e1af5772f3 100644 --- a/io_uring/io_uring.c +++ b/io_uring/io_uring.c @@ -92,6 +92,7 @@ #include "cancel.h" #include "net.h" #include "notif.h" +#include "waitid.h" #include "timeout.h" #include "poll.h" @@ -338,7 +339,6 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p) spin_lock_init(&ctx->completion_lock); spin_lock_init(&ctx->timeout_lock); INIT_WQ_LIST(&ctx->iopoll_list); - INIT_LIST_HEAD(&ctx->io_buffers_pages); INIT_LIST_HEAD(&ctx->io_buffers_comp); INIT_LIST_HEAD(&ctx->defer_list); INIT_LIST_HEAD(&ctx->timeout_list); @@ -348,8 +348,10 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p) INIT_LIST_HEAD(&ctx->tctx_list); ctx->submit_state.free_list.next = NULL; INIT_WQ_LIST(&ctx->locked_free_list); + INIT_HLIST_HEAD(&ctx->waitid_list); INIT_DELAYED_WORK(&ctx->fallback_work, io_fallback_req_func); INIT_WQ_LIST(&ctx->submit_state.compl_reqs); + INIT_HLIST_HEAD(&ctx->cancelable_uring_cmd); return ctx; err: kfree(ctx->cancel_table.hbs); @@ -3256,6 +3258,37 @@ static __cold bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx) return ret; } +static bool io_uring_try_cancel_uring_cmd(struct io_ring_ctx *ctx, + struct task_struct *task, bool cancel_all) +{ + struct hlist_node *tmp; + struct io_kiocb *req; + bool ret = false; + + lockdep_assert_held(&ctx->uring_lock); + + hlist_for_each_entry_safe(req, tmp, &ctx->cancelable_uring_cmd, + hash_node) { + struct io_uring_cmd *cmd = io_kiocb_to_cmd(req, + struct io_uring_cmd); + struct file *file = req->file; + + if (!cancel_all && req->task != task) + continue; + + if (cmd->flags & IORING_URING_CMD_CANCELABLE) { + /* ->sqe isn't available if no async data */ + if (!req_has_async_data(req)) + cmd->sqe = NULL; + file->f_op->uring_cmd(cmd, IO_URING_F_CANCEL); + ret = true; + } + } + io_submit_flush_completions(ctx); + + return ret; +} + static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx, struct task_struct *task, bool cancel_all) @@ -3303,6 +3336,8 @@ static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx, ret |= io_cancel_defer_files(ctx, task, cancel_all); mutex_lock(&ctx->uring_lock); ret |= io_poll_remove_all(ctx, task, cancel_all); + ret |= io_waitid_remove_all(ctx, task, cancel_all); + ret |= io_uring_try_cancel_uring_cmd(ctx, task, cancel_all); mutex_unlock(&ctx->uring_lock); ret |= io_kill_timeouts(ctx, task, cancel_all); if (task) @@ -4666,6 +4701,9 @@ static int __init io_uring_init(void) BUILD_BUG_ON(sizeof(atomic_t) != sizeof(u32)); + /* top 8bits are for internal use */ + BUILD_BUG_ON((IORING_URING_CMD_MASK & 0xff000000) != 0); + io_uring_optable_init(); /* @@ -4681,6 +4719,9 @@ static int __init io_uring_init(void) SLAB_ACCOUNT | SLAB_TYPESAFE_BY_RCU, offsetof(struct io_kiocb, cmd.data), sizeof_field(struct io_kiocb, cmd.data), NULL); + io_buf_cachep = kmem_cache_create("io_buffer", sizeof(struct io_buffer), 0, + SLAB_HWCACHE_ALIGN | SLAB_PANIC | SLAB_ACCOUNT, + NULL); #ifdef CONFIG_SYSCTL register_sysctl_init("kernel", kernel_io_uring_disabled_table); diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h index 547c30582fb8..2ff719ae1b57 100644 --- a/io_uring/io_uring.h +++ b/io_uring/io_uring.h @@ -330,6 +330,7 @@ static inline bool io_req_cache_empty(struct io_ring_ctx *ctx) } extern struct kmem_cache *req_cachep; +extern struct kmem_cache *io_buf_cachep; static inline struct io_kiocb *io_extract_req(struct io_ring_ctx *ctx) { diff --git a/io_uring/kbuf.c b/io_uring/kbuf.c index 556f4df25b0f..d5a04467666f 100644 --- a/io_uring/kbuf.c +++ b/io_uring/kbuf.c @@ -19,12 +19,17 @@ #define BGID_ARRAY 64 +/* BIDs are addressed by a 16-bit field in a CQE */ +#define MAX_BIDS_PER_BGID (1 << 16) + +struct kmem_cache *io_buf_cachep; + struct io_provide_buf { struct file *file; __u64 addr; __u32 len; __u32 bgid; - __u16 nbufs; + __u32 nbufs; __u16 bid; }; @@ -255,6 +260,8 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx, void io_destroy_buffers(struct io_ring_ctx *ctx) { struct io_buffer_list *bl; + struct list_head *item, *tmp; + struct io_buffer *buf; unsigned long index; int i; @@ -270,12 +277,9 @@ void io_destroy_buffers(struct io_ring_ctx *ctx) kfree(bl); } - while (!list_empty(&ctx->io_buffers_pages)) { - struct page *page; - - page = list_first_entry(&ctx->io_buffers_pages, struct page, lru); - list_del_init(&page->lru); - __free_page(page); + list_for_each_safe(item, tmp, &ctx->io_buffers_cache) { + buf = list_entry(item, struct io_buffer, list); + kmem_cache_free(io_buf_cachep, buf); } } @@ -289,7 +293,7 @@ int io_remove_buffers_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) return -EINVAL; tmp = READ_ONCE(sqe->fd); - if (!tmp || tmp > USHRT_MAX) + if (!tmp || tmp > MAX_BIDS_PER_BGID) return -EINVAL; memset(p, 0, sizeof(*p)); @@ -332,7 +336,7 @@ int io_provide_buffers_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe return -EINVAL; tmp = READ_ONCE(sqe->fd); - if (!tmp || tmp > USHRT_MAX) + if (!tmp || tmp > MAX_BIDS_PER_BGID) return -E2BIG; p->nbufs = tmp; p->addr = READ_ONCE(sqe->addr); @@ -352,17 +356,18 @@ int io_provide_buffers_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe tmp = READ_ONCE(sqe->off); if (tmp > USHRT_MAX) return -E2BIG; - if (tmp + p->nbufs >= USHRT_MAX) + if (tmp + p->nbufs > MAX_BIDS_PER_BGID) return -EINVAL; p->bid = tmp; return 0; } +#define IO_BUFFER_ALLOC_BATCH 64 + static int io_refill_buffer_cache(struct io_ring_ctx *ctx) { - struct io_buffer *buf; - struct page *page; - int bufs_in_page; + struct io_buffer *bufs[IO_BUFFER_ALLOC_BATCH]; + int allocated; /* * Completions that don't happen inline (eg not under uring_lock) will @@ -382,22 +387,25 @@ static int io_refill_buffer_cache(struct io_ring_ctx *ctx) /* * No free buffers and no completion entries either. Allocate a new - * page worth of buffer entries and add those to our freelist. + * batch of buffer entries and add those to our freelist. */ - page = alloc_page(GFP_KERNEL_ACCOUNT); - if (!page) - return -ENOMEM; - - list_add(&page->lru, &ctx->io_buffers_pages); - buf = page_address(page); - bufs_in_page = PAGE_SIZE / sizeof(*buf); - while (bufs_in_page) { - list_add_tail(&buf->list, &ctx->io_buffers_cache); - buf++; - bufs_in_page--; + allocated = kmem_cache_alloc_bulk(io_buf_cachep, GFP_KERNEL_ACCOUNT, + ARRAY_SIZE(bufs), (void **) bufs); + if (unlikely(!allocated)) { + /* + * Bulk alloc is all-or-nothing. If we fail to get a batch, + * retry single alloc to be on the safe side. + */ + bufs[0] = kmem_cache_alloc(io_buf_cachep, GFP_KERNEL); + if (!bufs[0]) + return -ENOMEM; + allocated = 1; } + while (allocated) + list_add_tail(&bufs[--allocated]->list, &ctx->io_buffers_cache); + return 0; } diff --git a/io_uring/opdef.c b/io_uring/opdef.c index 3b9c6489b8b6..aadcbf7136b0 100644 --- a/io_uring/opdef.c +++ b/io_uring/opdef.c @@ -33,6 +33,7 @@ #include "poll.h" #include "cancel.h" #include "rw.h" +#include "waitid.h" static int io_no_issue(struct io_kiocb *req, unsigned int issue_flags) { @@ -63,6 +64,7 @@ const struct io_issue_def io_issue_defs[] = { .ioprio = 1, .iopoll = 1, .iopoll_queue = 1, + .vectored = 1, .prep = io_prep_rw, .issue = io_read, }, @@ -76,6 +78,7 @@ const struct io_issue_def io_issue_defs[] = { .ioprio = 1, .iopoll = 1, .iopoll_queue = 1, + .vectored = 1, .prep = io_prep_rw, .issue = io_write, }, @@ -428,9 +431,21 @@ const struct io_issue_def io_issue_defs[] = { .prep = io_eopnotsupp_prep, #endif }, + [IORING_OP_READ_MULTISHOT] = { + .needs_file = 1, + .unbound_nonreg_file = 1, + .pollin = 1, + .buffer_select = 1, + .audit_skip = 1, + .prep = io_read_mshot_prep, + .issue = io_read_mshot, + }, + [IORING_OP_WAITID] = { + .prep = io_waitid_prep, + .issue = io_waitid, + }, }; - const struct io_cold_def io_cold_defs[] = { [IORING_OP_NOP] = { .name = "NOP", @@ -648,6 +663,13 @@ const struct io_cold_def io_cold_defs[] = { .fail = io_sendrecv_fail, #endif }, + [IORING_OP_READ_MULTISHOT] = { + .name = "READ_MULTISHOT", + }, + [IORING_OP_WAITID] = { + .name = "WAITID", + .async_size = sizeof(struct io_waitid_async), + }, }; const char *io_uring_get_opcode(u8 opcode) diff --git a/io_uring/opdef.h b/io_uring/opdef.h index c22c8696e749..9e5435ec27d0 100644 --- a/io_uring/opdef.h +++ b/io_uring/opdef.h @@ -29,6 +29,8 @@ struct io_issue_def { unsigned iopoll_queue : 1; /* opcode specific path will handle ->async_data allocation if needed */ unsigned manual_alloc : 1; + /* vectored opcode, set if 1) vectored, and 2) handler needs to know */ + unsigned vectored : 1; int (*issue)(struct io_kiocb *, unsigned int); int (*prep)(struct io_kiocb *, const struct io_uring_sqe *); diff --git a/io_uring/rsrc.c b/io_uring/rsrc.c index d9c853d10587..7034be555334 100644 --- a/io_uring/rsrc.c +++ b/io_uring/rsrc.c @@ -1037,39 +1037,36 @@ struct page **io_pin_pages(unsigned long ubuf, unsigned long len, int *npages) { unsigned long start, end, nr_pages; struct page **pages = NULL; - int pret, ret = -ENOMEM; + int ret; end = (ubuf + len + PAGE_SIZE - 1) >> PAGE_SHIFT; start = ubuf >> PAGE_SHIFT; nr_pages = end - start; + WARN_ON(!nr_pages); pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL); if (!pages) - goto done; + return ERR_PTR(-ENOMEM); - ret = 0; mmap_read_lock(current->mm); - pret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM, - pages); - if (pret == nr_pages) + ret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM, pages); + mmap_read_unlock(current->mm); + + /* success, mapped all pages */ + if (ret == nr_pages) { *npages = nr_pages; - else - ret = pret < 0 ? pret : -EFAULT; + return pages; + } - mmap_read_unlock(current->mm); - if (ret) { + /* partial map, or didn't map anything */ + if (ret >= 0) { /* if we did partial map, release any pages we did get */ - if (pret > 0) - unpin_user_pages(pages, pret); - goto done; - } - ret = 0; -done: - if (ret < 0) { - kvfree(pages); - pages = ERR_PTR(ret); + if (ret) + unpin_user_pages(pages, ret); + ret = -EFAULT; } - return pages; + kvfree(pages); + return ERR_PTR(ret); } static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov, diff --git a/io_uring/rw.c b/io_uring/rw.c index c8c822fa7980..ec0cc38ea682 100644 --- a/io_uring/rw.c +++ b/io_uring/rw.c @@ -123,6 +123,22 @@ int io_prep_rw(struct io_kiocb *req, const struct io_uring_sqe *sqe) return 0; } +/* + * Multishot read is prepared just like a normal read/write request, only + * difference is that we set the MULTISHOT flag. + */ +int io_read_mshot_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) +{ + int ret; + + ret = io_prep_rw(req, sqe); + if (unlikely(ret)) + return ret; + + req->flags |= REQ_F_APOLL_MULTISHOT; + return 0; +} + void io_readv_writev_cleanup(struct io_kiocb *req) { struct io_async_rw *io = req->async_data; @@ -388,8 +404,7 @@ static struct iovec *__io_import_iovec(int ddir, struct io_kiocb *req, buf = u64_to_user_ptr(rw->addr); sqe_len = rw->len; - if (opcode == IORING_OP_READ || opcode == IORING_OP_WRITE || - (req->flags & REQ_F_BUFFER_SELECT)) { + if (!io_issue_defs[opcode].vectored || req->flags & REQ_F_BUFFER_SELECT) { if (io_do_buffer_select(req)) { buf = io_buffer_select(req, &sqe_len, issue_flags); if (!buf) @@ -708,7 +723,7 @@ static int io_rw_init_file(struct io_kiocb *req, fmode_t mode) return 0; } -int io_read(struct io_kiocb *req, unsigned int issue_flags) +static int __io_read(struct io_kiocb *req, unsigned int issue_flags) { struct io_rw *rw = io_kiocb_to_cmd(req, struct io_rw); struct io_rw_state __s, *s = &__s; @@ -776,8 +791,11 @@ int io_read(struct io_kiocb *req, unsigned int issue_flags) if (ret == -EAGAIN || (req->flags & REQ_F_REISSUE)) { req->flags &= ~REQ_F_REISSUE; - /* if we can poll, just do that */ - if (req->opcode == IORING_OP_READ && file_can_poll(req->file)) + /* + * If we can poll, just do that. For a vectored read, we'll + * need to copy state first. + */ + if (file_can_poll(req->file) && !io_issue_defs[req->opcode].vectored) return -EAGAIN; /* IOPOLL retry should happen for io-wq threads */ if (!force_nonblock && !(req->ctx->flags & IORING_SETUP_IOPOLL)) @@ -853,7 +871,69 @@ done: /* it's faster to check here then delegate to kfree */ if (iovec) kfree(iovec); - return kiocb_done(req, ret, issue_flags); + return ret; +} + +int io_read(struct io_kiocb *req, unsigned int issue_flags) +{ + int ret; + + ret = __io_read(req, issue_flags); + if (ret >= 0) + return kiocb_done(req, ret, issue_flags); + + return ret; +} + +int io_read_mshot(struct io_kiocb *req, unsigned int issue_flags) +{ + unsigned int cflags = 0; + int ret; + + /* + * Multishot MUST be used on a pollable file + */ + if (!file_can_poll(req->file)) + return -EBADFD; + + ret = __io_read(req, issue_flags); + + /* + * If we get -EAGAIN, recycle our buffer and just let normal poll + * handling arm it. + */ + if (ret == -EAGAIN) { + io_kbuf_recycle(req, issue_flags); + return -EAGAIN; + } + + /* + * Any successful return value will keep the multishot read armed. + */ + if (ret > 0) { + /* + * Put our buffer and post a CQE. If we fail to post a CQE, then + * jump to the termination path. This request is then done. + */ + cflags = io_put_kbuf(req, issue_flags); + + if (io_fill_cqe_req_aux(req, + issue_flags & IO_URING_F_COMPLETE_DEFER, + ret, cflags | IORING_CQE_F_MORE)) { + if (issue_flags & IO_URING_F_MULTISHOT) + return IOU_ISSUE_SKIP_COMPLETE; + return -EAGAIN; + } + } + + /* + * Either an error, or we've hit overflow posting the CQE. For any + * multishot request, hitting overflow will terminate it. + */ + io_req_set_res(req, ret, cflags); + if (issue_flags & IO_URING_F_MULTISHOT) + return IOU_STOP_MULTISHOT; + return IOU_OK; } int io_write(struct io_kiocb *req, unsigned int issue_flags) diff --git a/io_uring/rw.h b/io_uring/rw.h index 4b89f9659366..c5aed03d42a4 100644 --- a/io_uring/rw.h +++ b/io_uring/rw.h @@ -23,3 +23,5 @@ int io_writev_prep_async(struct io_kiocb *req); void io_readv_writev_cleanup(struct io_kiocb *req); void io_rw_fail(struct io_kiocb *req); void io_req_rw_complete(struct io_kiocb *req, struct io_tw_state *ts); +int io_read_mshot_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe); +int io_read_mshot(struct io_kiocb *req, unsigned int issue_flags); diff --git a/io_uring/uring_cmd.c b/io_uring/uring_cmd.c index 537795fddc87..00a5e5621a28 100644 --- a/io_uring/uring_cmd.c +++ b/io_uring/uring_cmd.c @@ -13,6 +13,51 @@ #include "rsrc.h" #include "uring_cmd.h" +static void io_uring_cmd_del_cancelable(struct io_uring_cmd *cmd, + unsigned int issue_flags) +{ + struct io_kiocb *req = cmd_to_io_kiocb(cmd); + struct io_ring_ctx *ctx = req->ctx; + + if (!(cmd->flags & IORING_URING_CMD_CANCELABLE)) + return; + + cmd->flags &= ~IORING_URING_CMD_CANCELABLE; + io_ring_submit_lock(ctx, issue_flags); + hlist_del(&req->hash_node); + io_ring_submit_unlock(ctx, issue_flags); +} + +/* + * Mark this command as concelable, then io_uring_try_cancel_uring_cmd() + * will try to cancel this issued command by sending ->uring_cmd() with + * issue_flags of IO_URING_F_CANCEL. + * + * The command is guaranteed to not be done when calling ->uring_cmd() + * with IO_URING_F_CANCEL, but it is driver's responsibility to deal + * with race between io_uring canceling and normal completion. + */ +void io_uring_cmd_mark_cancelable(struct io_uring_cmd *cmd, + unsigned int issue_flags) +{ + struct io_kiocb *req = cmd_to_io_kiocb(cmd); + struct io_ring_ctx *ctx = req->ctx; + + if (!(cmd->flags & IORING_URING_CMD_CANCELABLE)) { + cmd->flags |= IORING_URING_CMD_CANCELABLE; + io_ring_submit_lock(ctx, issue_flags); + hlist_add_head(&req->hash_node, &ctx->cancelable_uring_cmd); + io_ring_submit_unlock(ctx, issue_flags); + } +} +EXPORT_SYMBOL_GPL(io_uring_cmd_mark_cancelable); + +struct task_struct *io_uring_cmd_get_task(struct io_uring_cmd *cmd) +{ + return cmd_to_io_kiocb(cmd)->task; +} +EXPORT_SYMBOL_GPL(io_uring_cmd_get_task); + static void io_uring_cmd_work(struct io_kiocb *req, struct io_tw_state *ts) { struct io_uring_cmd *ioucmd = io_kiocb_to_cmd(req, struct io_uring_cmd); @@ -56,6 +101,8 @@ void io_uring_cmd_done(struct io_uring_cmd *ioucmd, ssize_t ret, ssize_t res2, { struct io_kiocb *req = cmd_to_io_kiocb(ioucmd); + io_uring_cmd_del_cancelable(ioucmd, issue_flags); + if (ret < 0) req_set_fail(req); @@ -91,7 +138,7 @@ int io_uring_cmd_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) return -EINVAL; ioucmd->flags = READ_ONCE(sqe->uring_cmd_flags); - if (ioucmd->flags & ~IORING_URING_CMD_FIXED) + if (ioucmd->flags & ~IORING_URING_CMD_MASK) return -EINVAL; if (ioucmd->flags & IORING_URING_CMD_FIXED) { diff --git a/io_uring/waitid.c b/io_uring/waitid.c new file mode 100644 index 000000000000..6f851978606d --- /dev/null +++ b/io_uring/waitid.c @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Support for async notification of waitid + */ +#include <linux/kernel.h> +#include <linux/errno.h> +#include <linux/fs.h> +#include <linux/file.h> +#include <linux/compat.h> +#include <linux/io_uring.h> + +#include <uapi/linux/io_uring.h> + +#include "io_uring.h" +#include "cancel.h" +#include "waitid.h" +#include "../kernel/exit.h" + +static void io_waitid_cb(struct io_kiocb *req, struct io_tw_state *ts); + +#define IO_WAITID_CANCEL_FLAG BIT(31) +#define IO_WAITID_REF_MASK GENMASK(30, 0) + +struct io_waitid { + struct file *file; + int which; + pid_t upid; + int options; + atomic_t refs; + struct wait_queue_head *head; + struct siginfo __user *infop; + struct waitid_info info; +}; + +static void io_waitid_free(struct io_kiocb *req) +{ + struct io_waitid_async *iwa = req->async_data; + + put_pid(iwa->wo.wo_pid); + kfree(req->async_data); + req->async_data = NULL; + req->flags &= ~REQ_F_ASYNC_DATA; +} + +#ifdef CONFIG_COMPAT +static bool io_waitid_compat_copy_si(struct io_waitid *iw, int signo) +{ + struct compat_siginfo __user *infop; + bool ret; + + infop = (struct compat_siginfo __user *) iw->infop; + + if (!user_write_access_begin(infop, sizeof(*infop))) + return false; + + unsafe_put_user(signo, &infop->si_signo, Efault); + unsafe_put_user(0, &infop->si_errno, Efault); + unsafe_put_user(iw->info.cause, &infop->si_code, Efault); + unsafe_put_user(iw->info.pid, &infop->si_pid, Efault); + unsafe_put_user(iw->info.uid, &infop->si_uid, Efault); + unsafe_put_user(iw->info.status, &infop->si_status, Efault); + ret = true; +done: + user_write_access_end(); + return ret; +Efault: + ret = false; + goto done; +} +#endif + +static bool io_waitid_copy_si(struct io_kiocb *req, int signo) +{ + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + bool ret; + + if (!iw->infop) + return true; + +#ifdef CONFIG_COMPAT + if (req->ctx->compat) + return io_waitid_compat_copy_si(iw, signo); +#endif + + if (!user_write_access_begin(iw->infop, sizeof(*iw->infop))) + return false; + + unsafe_put_user(signo, &iw->infop->si_signo, Efault); + unsafe_put_user(0, &iw->infop->si_errno, Efault); + unsafe_put_user(iw->info.cause, &iw->infop->si_code, Efault); + unsafe_put_user(iw->info.pid, &iw->infop->si_pid, Efault); + unsafe_put_user(iw->info.uid, &iw->infop->si_uid, Efault); + unsafe_put_user(iw->info.status, &iw->infop->si_status, Efault); + ret = true; +done: + user_write_access_end(); + return ret; +Efault: + ret = false; + goto done; +} + +static int io_waitid_finish(struct io_kiocb *req, int ret) +{ + int signo = 0; + + if (ret > 0) { + signo = SIGCHLD; + ret = 0; + } + + if (!io_waitid_copy_si(req, signo)) + ret = -EFAULT; + io_waitid_free(req); + return ret; +} + +static void io_waitid_complete(struct io_kiocb *req, int ret) +{ + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + struct io_tw_state ts = { .locked = true }; + + /* anyone completing better be holding a reference */ + WARN_ON_ONCE(!(atomic_read(&iw->refs) & IO_WAITID_REF_MASK)); + + lockdep_assert_held(&req->ctx->uring_lock); + + /* + * Did cancel find it meanwhile? + */ + if (hlist_unhashed(&req->hash_node)) + return; + + hlist_del_init(&req->hash_node); + + ret = io_waitid_finish(req, ret); + if (ret < 0) + req_set_fail(req); + io_req_set_res(req, ret, 0); + io_req_task_complete(req, &ts); +} + +static bool __io_waitid_cancel(struct io_ring_ctx *ctx, struct io_kiocb *req) +{ + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + struct io_waitid_async *iwa = req->async_data; + + /* + * Mark us canceled regardless of ownership. This will prevent a + * potential retry from a spurious wakeup. + */ + atomic_or(IO_WAITID_CANCEL_FLAG, &iw->refs); + + /* claim ownership */ + if (atomic_fetch_inc(&iw->refs) & IO_WAITID_REF_MASK) + return false; + + spin_lock_irq(&iw->head->lock); + list_del_init(&iwa->wo.child_wait.entry); + spin_unlock_irq(&iw->head->lock); + io_waitid_complete(req, -ECANCELED); + return true; +} + +int io_waitid_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd, + unsigned int issue_flags) +{ + struct hlist_node *tmp; + struct io_kiocb *req; + int nr = 0; + + if (cd->flags & (IORING_ASYNC_CANCEL_FD|IORING_ASYNC_CANCEL_FD_FIXED)) + return -ENOENT; + + io_ring_submit_lock(ctx, issue_flags); + hlist_for_each_entry_safe(req, tmp, &ctx->waitid_list, hash_node) { + if (req->cqe.user_data != cd->data && + !(cd->flags & IORING_ASYNC_CANCEL_ANY)) + continue; + if (__io_waitid_cancel(ctx, req)) + nr++; + if (!(cd->flags & IORING_ASYNC_CANCEL_ALL)) + break; + } + io_ring_submit_unlock(ctx, issue_flags); + + if (nr) + return nr; + + return -ENOENT; +} + +bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task, + bool cancel_all) +{ + struct hlist_node *tmp; + struct io_kiocb *req; + bool found = false; + + lockdep_assert_held(&ctx->uring_lock); + + hlist_for_each_entry_safe(req, tmp, &ctx->waitid_list, hash_node) { + if (!io_match_task_safe(req, task, cancel_all)) + continue; + __io_waitid_cancel(ctx, req); + found = true; + } + + return found; +} + +static inline bool io_waitid_drop_issue_ref(struct io_kiocb *req) +{ + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + struct io_waitid_async *iwa = req->async_data; + + if (!atomic_sub_return(1, &iw->refs)) + return false; + + /* + * Wakeup triggered, racing with us. It was prevented from + * completing because of that, queue up the tw to do that. + */ + req->io_task_work.func = io_waitid_cb; + io_req_task_work_add(req); + remove_wait_queue(iw->head, &iwa->wo.child_wait); + return true; +} + +static void io_waitid_cb(struct io_kiocb *req, struct io_tw_state *ts) +{ + struct io_waitid_async *iwa = req->async_data; + struct io_ring_ctx *ctx = req->ctx; + int ret; + + io_tw_lock(ctx, ts); + + ret = __do_wait(&iwa->wo); + + /* + * If we get -ERESTARTSYS here, we need to re-arm and check again + * to ensure we get another callback. If the retry works, then we can + * just remove ourselves from the waitqueue again and finish the + * request. + */ + if (unlikely(ret == -ERESTARTSYS)) { + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + + /* Don't retry if cancel found it meanwhile */ + ret = -ECANCELED; + if (!(atomic_read(&iw->refs) & IO_WAITID_CANCEL_FLAG)) { + iw->head = ¤t->signal->wait_chldexit; + add_wait_queue(iw->head, &iwa->wo.child_wait); + ret = __do_wait(&iwa->wo); + if (ret == -ERESTARTSYS) { + /* retry armed, drop our ref */ + io_waitid_drop_issue_ref(req); + return; + } + + remove_wait_queue(iw->head, &iwa->wo.child_wait); + } + } + + io_waitid_complete(req, ret); +} + +static int io_waitid_wait(struct wait_queue_entry *wait, unsigned mode, + int sync, void *key) +{ + struct wait_opts *wo = container_of(wait, struct wait_opts, child_wait); + struct io_waitid_async *iwa = container_of(wo, struct io_waitid_async, wo); + struct io_kiocb *req = iwa->req; + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + struct task_struct *p = key; + + if (!pid_child_should_wake(wo, p)) + return 0; + + /* cancel is in progress */ + if (atomic_fetch_inc(&iw->refs) & IO_WAITID_REF_MASK) + return 1; + + req->io_task_work.func = io_waitid_cb; + io_req_task_work_add(req); + list_del_init(&wait->entry); + return 1; +} + +int io_waitid_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) +{ + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + + if (sqe->addr || sqe->buf_index || sqe->addr3 || sqe->waitid_flags) + return -EINVAL; + + iw->which = READ_ONCE(sqe->len); + iw->upid = READ_ONCE(sqe->fd); + iw->options = READ_ONCE(sqe->file_index); + iw->infop = u64_to_user_ptr(READ_ONCE(sqe->addr2)); + return 0; +} + +int io_waitid(struct io_kiocb *req, unsigned int issue_flags) +{ + struct io_waitid *iw = io_kiocb_to_cmd(req, struct io_waitid); + struct io_ring_ctx *ctx = req->ctx; + struct io_waitid_async *iwa; + int ret; + + if (io_alloc_async_data(req)) + return -ENOMEM; + + iwa = req->async_data; + iwa->req = req; + + ret = kernel_waitid_prepare(&iwa->wo, iw->which, iw->upid, &iw->info, + iw->options, NULL); + if (ret) + goto done; + + /* + * Mark the request as busy upfront, in case we're racing with the + * wakeup. If we are, then we'll notice when we drop this initial + * reference again after arming. + */ + atomic_set(&iw->refs, 1); + + /* + * Cancel must hold the ctx lock, so there's no risk of cancelation + * finding us until a) we remain on the list, and b) the lock is + * dropped. We only need to worry about racing with the wakeup + * callback. + */ + io_ring_submit_lock(ctx, issue_flags); + hlist_add_head(&req->hash_node, &ctx->waitid_list); + + init_waitqueue_func_entry(&iwa->wo.child_wait, io_waitid_wait); + iwa->wo.child_wait.private = req->task; + iw->head = ¤t->signal->wait_chldexit; + add_wait_queue(iw->head, &iwa->wo.child_wait); + + ret = __do_wait(&iwa->wo); + if (ret == -ERESTARTSYS) { + /* + * Nobody else grabbed a reference, it'll complete when we get + * a waitqueue callback, or if someone cancels it. + */ + if (!io_waitid_drop_issue_ref(req)) { + io_ring_submit_unlock(ctx, issue_flags); + return IOU_ISSUE_SKIP_COMPLETE; + } + + /* + * Wakeup triggered, racing with us. It was prevented from + * completing because of that, queue up the tw to do that. + */ + io_ring_submit_unlock(ctx, issue_flags); + return IOU_ISSUE_SKIP_COMPLETE; + } + + hlist_del_init(&req->hash_node); + remove_wait_queue(iw->head, &iwa->wo.child_wait); + ret = io_waitid_finish(req, ret); + + io_ring_submit_unlock(ctx, issue_flags); +done: + if (ret < 0) + req_set_fail(req); + io_req_set_res(req, ret, 0); + return IOU_OK; +} diff --git a/io_uring/waitid.h b/io_uring/waitid.h new file mode 100644 index 000000000000..956a8adafe8c --- /dev/null +++ b/io_uring/waitid.h @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include "../kernel/exit.h" + +struct io_waitid_async { + struct io_kiocb *req; + struct wait_opts wo; +}; + +int io_waitid_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe); +int io_waitid(struct io_kiocb *req, unsigned int issue_flags); +int io_waitid_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd, + unsigned int issue_flags); +bool io_waitid_remove_all(struct io_ring_ctx *ctx, struct task_struct *task, + bool cancel_all); diff --git a/kernel/exit.c b/kernel/exit.c index edb50b4c9972..2b4a232f2f68 100644 --- a/kernel/exit.c +++ b/kernel/exit.c @@ -74,6 +74,8 @@ #include <asm/unistd.h> #include <asm/mmu_context.h> +#include "exit.h" + /* * The default value should be high enough to not crash a system that randomly * crashes its kernel from time to time, but low enough to at least not permit @@ -1037,26 +1039,6 @@ SYSCALL_DEFINE1(exit_group, int, error_code) return 0; } -struct waitid_info { - pid_t pid; - uid_t uid; - int status; - int cause; -}; - -struct wait_opts { - enum pid_type wo_type; - int wo_flags; - struct pid *wo_pid; - - struct waitid_info *wo_info; - int wo_stat; - struct rusage *wo_rusage; - - wait_queue_entry_t child_wait; - int notask_error; -}; - static int eligible_pid(struct wait_opts *wo, struct task_struct *p) { return wo->wo_type == PIDTYPE_MAX || @@ -1520,6 +1502,17 @@ static int ptrace_do_wait(struct wait_opts *wo, struct task_struct *tsk) return 0; } +bool pid_child_should_wake(struct wait_opts *wo, struct task_struct *p) +{ + if (!eligible_pid(wo, p)) + return false; + + if ((wo->wo_flags & __WNOTHREAD) && wo->child_wait.private != p->parent) + return false; + + return true; +} + static int child_wait_callback(wait_queue_entry_t *wait, unsigned mode, int sync, void *key) { @@ -1527,13 +1520,10 @@ static int child_wait_callback(wait_queue_entry_t *wait, unsigned mode, child_wait); struct task_struct *p = key; - if (!eligible_pid(wo, p)) - return 0; + if (pid_child_should_wake(wo, p)) + return default_wake_function(wait, mode, sync, key); - if ((wo->wo_flags & __WNOTHREAD) && wait->private != p->parent) - return 0; - - return default_wake_function(wait, mode, sync, key); + return 0; } void __wake_up_parent(struct task_struct *p, struct task_struct *parent) @@ -1582,16 +1572,10 @@ static int do_wait_pid(struct wait_opts *wo) return 0; } -static long do_wait(struct wait_opts *wo) +long __do_wait(struct wait_opts *wo) { - int retval; + long retval; - trace_sched_process_wait(wo->wo_pid); - - init_waitqueue_func_entry(&wo->child_wait, child_wait_callback); - wo->child_wait.private = current; - add_wait_queue(¤t->signal->wait_chldexit, &wo->child_wait); -repeat: /* * If there is nothing that can match our criteria, just get out. * We will clear ->notask_error to zero if we see any child that @@ -1603,24 +1587,23 @@ repeat: (!wo->wo_pid || !pid_has_task(wo->wo_pid, wo->wo_type))) goto notask; - set_current_state(TASK_INTERRUPTIBLE); read_lock(&tasklist_lock); if (wo->wo_type == PIDTYPE_PID) { retval = do_wait_pid(wo); if (retval) - goto end; + return retval; } else { struct task_struct *tsk = current; do { retval = do_wait_thread(wo, tsk); if (retval) - goto end; + return retval; retval = ptrace_do_wait(wo, tsk); if (retval) - goto end; + return retval; if (wo->wo_flags & __WNOTHREAD) break; @@ -1630,27 +1613,44 @@ repeat: notask: retval = wo->notask_error; - if (!retval && !(wo->wo_flags & WNOHANG)) { - retval = -ERESTARTSYS; - if (!signal_pending(current)) { - schedule(); - goto repeat; - } - } -end: + if (!retval && !(wo->wo_flags & WNOHANG)) + return -ERESTARTSYS; + + return retval; +} + +static long do_wait(struct wait_opts *wo) +{ + int retval; + + trace_sched_process_wait(wo->wo_pid); + + init_waitqueue_func_entry(&wo->child_wait, child_wait_callback); + wo->child_wait.private = current; + add_wait_queue(¤t->signal->wait_chldexit, &wo->child_wait); + + do { + set_current_state(TASK_INTERRUPTIBLE); + retval = __do_wait(wo); + if (retval != -ERESTARTSYS) + break; + if (signal_pending(current)) + break; + schedule(); + } while (1); + __set_current_state(TASK_RUNNING); remove_wait_queue(¤t->signal->wait_chldexit, &wo->child_wait); return retval; } -static long kernel_waitid(int which, pid_t upid, struct waitid_info *infop, - int options, struct rusage *ru) +int kernel_waitid_prepare(struct wait_opts *wo, int which, pid_t upid, + struct waitid_info *infop, int options, + struct rusage *ru) { - struct wait_opts wo; + unsigned int f_flags = 0; struct pid *pid = NULL; enum pid_type type; - long ret; - unsigned int f_flags = 0; if (options & ~(WNOHANG|WNOWAIT|WEXITED|WSTOPPED|WCONTINUED| __WNOTHREAD|__WCLONE|__WALL)) @@ -1693,19 +1693,32 @@ static long kernel_waitid(int which, pid_t upid, struct waitid_info *infop, return -EINVAL; } - wo.wo_type = type; - wo.wo_pid = pid; - wo.wo_flags = options; - wo.wo_info = infop; - wo.wo_rusage = ru; + wo->wo_type = type; + wo->wo_pid = pid; + wo->wo_flags = options; + wo->wo_info = infop; + wo->wo_rusage = ru; if (f_flags & O_NONBLOCK) - wo.wo_flags |= WNOHANG; + wo->wo_flags |= WNOHANG; + + return 0; +} + +static long kernel_waitid(int which, pid_t upid, struct waitid_info *infop, + int options, struct rusage *ru) +{ + struct wait_opts wo; + long ret; + + ret = kernel_waitid_prepare(&wo, which, upid, infop, options, ru); + if (ret) + return ret; ret = do_wait(&wo); - if (!ret && !(options & WNOHANG) && (f_flags & O_NONBLOCK)) + if (!ret && !(options & WNOHANG) && (wo.wo_flags & WNOHANG)) ret = -EAGAIN; - put_pid(pid); + put_pid(wo.wo_pid); return ret; } diff --git a/kernel/exit.h b/kernel/exit.h new file mode 100644 index 000000000000..278faa26a653 --- /dev/null +++ b/kernel/exit.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: GPL-2.0-only +#ifndef LINUX_WAITID_H +#define LINUX_WAITID_H + +struct waitid_info { + pid_t pid; + uid_t uid; + int status; + int cause; +}; + +struct wait_opts { + enum pid_type wo_type; + int wo_flags; + struct pid *wo_pid; + + struct waitid_info *wo_info; + int wo_stat; + struct rusage *wo_rusage; + + wait_queue_entry_t child_wait; + int notask_error; +}; + +bool pid_child_should_wake(struct wait_opts *wo, struct task_struct *p); +long __do_wait(struct wait_opts *wo); +int kernel_waitid_prepare(struct wait_opts *wo, int which, pid_t upid, + struct waitid_info *infop, int options, + struct rusage *ru); +#endif diff --git a/security/keys/key.c b/security/keys/key.c index 5c0c7df833f8..0260a1902922 100644 --- a/security/keys/key.c +++ b/security/keys/key.c @@ -693,6 +693,7 @@ error: spin_unlock(&key_serial_lock); return key; } +EXPORT_SYMBOL(key_lookup); /* * Find and lock the specified key type against removal. |