aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid S. Miller <davem@davemloft.net>2023-06-12 10:45:50 +0100
committerDavid S. Miller <davem@davemloft.net>2023-06-12 10:45:50 +0100
commitba47545c756b55f4b114c45fea7d52dd1577e181 (patch)
treef7a3daf73864e0b55ef2e44e85671706a5494c89
parent55d7c91406b4b486ea8c50e2fb31f1e1a0ef5143 (diff)
parent97154bcf4d1b7cabefec8a72cff5fbb91d5afb7b (diff)
Merge branch 'SCM_PIDFD-SCM_PEERPIDFD'
Alexander Mikhalitsyn says: ==================== Add SCM_PIDFD and SO_PEERPIDFD 1. Implement SCM_PIDFD, a new type of CMSG type analogical to SCM_CREDENTIALS, but it contains pidfd instead of plain pid, which allows programmers not to care about PID reuse problem. 2. Add SO_PEERPIDFD which allows to get pidfd of peer socket holder pidfd. This thing is direct analog of SO_PEERCRED which allows to get plain PID. 3. Add SCM_PIDFD / SO_PEERPIDFD kselftest Idea comes from UAPI kernel group: https://uapi-group.org/kernel-features/ Big thanks to Christian Brauner and Lennart Poettering for productive discussions about this and Luca Boccassi for testing and reviewing this. === Motivation behind this patchset Eric Dumazet raised a question: > It seems that we already can use pidfd_open() (since linux-5.3), and > pass the resulting fd in af_unix SCM_RIGHTS message ? Yes, it's possible, but it means that from the receiver side we need to trust the sent pidfd (in SCM_RIGHTS), or always use combination of SCM_RIGHTS+SCM_CREDENTIALS, then we can extract pidfd from SCM_RIGHTS, then acquire plain pid from pidfd and after compare it with the pid from SCM_CREDENTIALS. A few comments from other folks regarding this. Christian Brauner wrote: >Let me try and provide some of the missing background. >There are a range of use-cases where we would like to authenticate a >client through sockets without being susceptible to PID recycling >attacks. Currently, we can't do this as the race isn't fully fixable. >We can only apply mitigations. >What this patchset will allows us to do is to get a pidfd without the >client having to send us an fd explicitly via SCM_RIGHTS. As that's >already possibly as you correctly point out. >But for protocols like polkit this is quite important. Every message is >standalone and we would need to force a complete protocol change where >we would need to require that every client allocate and send a pidfd via >SCM_RIGHTS. That would also mean patching through all polkit users. >For something like systemd-journald where we provide logging facilities >and want to add metadata to the log we would also immensely benefit from >being able to get a receiver-side controlled pidfd. >With the message type we envisioned we don't need to change the sender >at all and can be safe against pid recycling. >Link: https://gitlab.freedesktop.org/polkit/polkit/-/merge_requests/154 >Link: https://uapi-group.org/kernel-features Lennart Poettering wrote: >So yes, this is of course possible, but it would mean the pidfd would >have to be transported as part of the user protocol, explicitly sent >by the sender. (Moreover, the receiver after receiving the pidfd would >then still have to somehow be able to prove that the pidfd it just >received actually refers to the peer's process and not some random >process. – this part is actually solvable in userspace, but ugly) >The big thing is simply that we want that the pidfd is associated >*implicity* with each AF_UNIX connection, not explicitly. A lot of >userspace already relies on this, both in the authentication area >(polkit) as well as in the logging area (systemd-journald). Right now >using the PID field from SO_PEERCREDS/SCM_CREDENTIALS is racy though >and very hard to get right. Making this available as pidfd too, would >solve this raciness, without otherwise changing semantics of it all: >receivers can still enable the creds stuff as they wish, and the data >is then implicitly appended to the connections/datagrams the sender >initiates. >Or to turn this around: things like polkit are typically used to >authenticate arbitrary dbus methods calls: some service implements a >dbus method call, and when an unprivileged client then issues that >call, it will take the client's info, go to polkit and ask it if this >is ok. If we wanted to send the pidfd as part of the protocol we >basically would have to extend every single method call to contain the >client's pidfd along with it as an additional argument, which would be >a massive undertaking: it would change the prototypes of basically >*all* methods a service defines… And that's just ugly. >Note that Alex' patch set doesn't expose anything that wasn't exposed >before, or attach, propagate what wasn't before. All it does, is make >the field already available anyway (the struct ucred .pid field) >available also in a better way (as a pidfd), to solve a variety of >races, with no effect on the protocol actually spoken within the >AF_UNIX transport. It's a seamless improvement of the status quo. ==================== Signed-off-by: David S. Miller <davem@davemloft.net>
-rw-r--r--arch/alpha/include/uapi/asm/socket.h3
-rw-r--r--arch/mips/include/uapi/asm/socket.h3
-rw-r--r--arch/parisc/include/uapi/asm/socket.h3
-rw-r--r--arch/sparc/include/uapi/asm/socket.h3
-rw-r--r--include/linux/net.h1
-rw-r--r--include/linux/socket.h1
-rw-r--r--include/net/scm.h39
-rw-r--r--include/uapi/asm-generic/socket.h3
-rw-r--r--net/core/sock.c44
-rw-r--r--net/mptcp/sockopt.c1
-rw-r--r--net/unix/Kconfig6
-rw-r--r--net/unix/af_unix.c34
-rw-r--r--tools/include/uapi/asm-generic/socket.h3
-rw-r--r--tools/testing/selftests/net/.gitignore1
-rw-r--r--tools/testing/selftests/net/af_unix/Makefile3
-rw-r--r--tools/testing/selftests/net/af_unix/scm_pidfd.c430
16 files changed, 565 insertions, 13 deletions
diff --git a/arch/alpha/include/uapi/asm/socket.h b/arch/alpha/include/uapi/asm/socket.h
index 739891b94136..e94f621903fe 100644
--- a/arch/alpha/include/uapi/asm/socket.h
+++ b/arch/alpha/include/uapi/asm/socket.h
@@ -137,6 +137,9 @@
#define SO_RCVMARK 75
+#define SO_PASSPIDFD 76
+#define SO_PEERPIDFD 77
+
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64
diff --git a/arch/mips/include/uapi/asm/socket.h b/arch/mips/include/uapi/asm/socket.h
index 18f3d95ecfec..60ebaed28a4c 100644
--- a/arch/mips/include/uapi/asm/socket.h
+++ b/arch/mips/include/uapi/asm/socket.h
@@ -148,6 +148,9 @@
#define SO_RCVMARK 75
+#define SO_PASSPIDFD 76
+#define SO_PEERPIDFD 77
+
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64
diff --git a/arch/parisc/include/uapi/asm/socket.h b/arch/parisc/include/uapi/asm/socket.h
index f486d3dfb6bb..be264c2b1a11 100644
--- a/arch/parisc/include/uapi/asm/socket.h
+++ b/arch/parisc/include/uapi/asm/socket.h
@@ -129,6 +129,9 @@
#define SO_RCVMARK 0x4049
+#define SO_PASSPIDFD 0x404A
+#define SO_PEERPIDFD 0x404B
+
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64
diff --git a/arch/sparc/include/uapi/asm/socket.h b/arch/sparc/include/uapi/asm/socket.h
index 2fda57a3ea86..682da3714686 100644
--- a/arch/sparc/include/uapi/asm/socket.h
+++ b/arch/sparc/include/uapi/asm/socket.h
@@ -130,6 +130,9 @@
#define SO_RCVMARK 0x0054
+#define SO_PASSPIDFD 0x0055
+#define SO_PEERPIDFD 0x0056
+
#if !defined(__KERNEL__)
diff --git a/include/linux/net.h b/include/linux/net.h
index 8defc8f1d82e..23324e9a2b3d 100644
--- a/include/linux/net.h
+++ b/include/linux/net.h
@@ -43,6 +43,7 @@ struct net;
#define SOCK_PASSSEC 4
#define SOCK_SUPPORT_ZC 5
#define SOCK_CUSTOM_SOCKOPT 6
+#define SOCK_PASSPIDFD 7
#ifndef ARCH_HAS_SOCKET_TYPES
/**
diff --git a/include/linux/socket.h b/include/linux/socket.h
index 3fd3436bc09f..58204700018a 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -177,6 +177,7 @@ static inline size_t msg_data_left(struct msghdr *msg)
#define SCM_RIGHTS 0x01 /* rw: access rights (array of int) */
#define SCM_CREDENTIALS 0x02 /* rw: struct ucred */
#define SCM_SECURITY 0x03 /* rw: security label */
+#define SCM_PIDFD 0x04 /* ro: pidfd (int) */
struct ucred {
__u32 pid;
diff --git a/include/net/scm.h b/include/net/scm.h
index 585adc1346bd..c67f765a165b 100644
--- a/include/net/scm.h
+++ b/include/net/scm.h
@@ -120,12 +120,44 @@ static inline bool scm_has_secdata(struct socket *sock)
}
#endif /* CONFIG_SECURITY_NETWORK */
+static __inline__ void scm_pidfd_recv(struct msghdr *msg, struct scm_cookie *scm)
+{
+ struct file *pidfd_file = NULL;
+ int pidfd;
+
+ /*
+ * put_cmsg() doesn't return an error if CMSG is truncated,
+ * that's why we need to opencode these checks here.
+ */
+ if ((msg->msg_controllen <= sizeof(struct cmsghdr)) ||
+ (msg->msg_controllen - sizeof(struct cmsghdr)) < sizeof(int)) {
+ msg->msg_flags |= MSG_CTRUNC;
+ return;
+ }
+
+ WARN_ON_ONCE(!scm->pid);
+ pidfd = pidfd_prepare(scm->pid, 0, &pidfd_file);
+
+ if (put_cmsg(msg, SOL_SOCKET, SCM_PIDFD, sizeof(int), &pidfd)) {
+ if (pidfd_file) {
+ put_unused_fd(pidfd);
+ fput(pidfd_file);
+ }
+
+ return;
+ }
+
+ if (pidfd_file)
+ fd_install(pidfd, pidfd_file);
+}
+
static __inline__ void scm_recv(struct socket *sock, struct msghdr *msg,
struct scm_cookie *scm, int flags)
{
if (!msg->msg_control) {
- if (test_bit(SOCK_PASSCRED, &sock->flags) || scm->fp ||
- scm_has_secdata(sock))
+ if (test_bit(SOCK_PASSCRED, &sock->flags) ||
+ test_bit(SOCK_PASSPIDFD, &sock->flags) ||
+ scm->fp || scm_has_secdata(sock))
msg->msg_flags |= MSG_CTRUNC;
scm_destroy(scm);
return;
@@ -141,6 +173,9 @@ static __inline__ void scm_recv(struct socket *sock, struct msghdr *msg,
put_cmsg(msg, SOL_SOCKET, SCM_CREDENTIALS, sizeof(ucreds), &ucreds);
}
+ if (test_bit(SOCK_PASSPIDFD, &sock->flags))
+ scm_pidfd_recv(msg, scm);
+
scm_destroy_cred(scm);
scm_passec(sock, msg, scm);
diff --git a/include/uapi/asm-generic/socket.h b/include/uapi/asm-generic/socket.h
index 638230899e98..8ce8a39a1e5f 100644
--- a/include/uapi/asm-generic/socket.h
+++ b/include/uapi/asm-generic/socket.h
@@ -132,6 +132,9 @@
#define SO_RCVMARK 75
+#define SO_PASSPIDFD 76
+#define SO_PEERPIDFD 77
+
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
diff --git a/net/core/sock.c b/net/core/sock.c
index 24f2761bdb1d..ea66b1afadd0 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1246,6 +1246,13 @@ set_sndbuf:
clear_bit(SOCK_PASSCRED, &sock->flags);
break;
+ case SO_PASSPIDFD:
+ if (valbool)
+ set_bit(SOCK_PASSPIDFD, &sock->flags);
+ else
+ clear_bit(SOCK_PASSPIDFD, &sock->flags);
+ break;
+
case SO_TIMESTAMP_OLD:
case SO_TIMESTAMP_NEW:
case SO_TIMESTAMPNS_OLD:
@@ -1732,6 +1739,10 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
v.val = !!test_bit(SOCK_PASSCRED, &sock->flags);
break;
+ case SO_PASSPIDFD:
+ v.val = !!test_bit(SOCK_PASSPIDFD, &sock->flags);
+ break;
+
case SO_PEERCRED:
{
struct ucred peercred;
@@ -1747,6 +1758,39 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
goto lenout;
}
+ case SO_PEERPIDFD:
+ {
+ struct pid *peer_pid;
+ struct file *pidfd_file = NULL;
+ int pidfd;
+
+ if (len > sizeof(pidfd))
+ len = sizeof(pidfd);
+
+ spin_lock(&sk->sk_peer_lock);
+ peer_pid = get_pid(sk->sk_peer_pid);
+ spin_unlock(&sk->sk_peer_lock);
+
+ if (!peer_pid)
+ return -ESRCH;
+
+ pidfd = pidfd_prepare(peer_pid, 0, &pidfd_file);
+ put_pid(peer_pid);
+ if (pidfd < 0)
+ return pidfd;
+
+ if (copy_to_sockptr(optval, &pidfd, len) ||
+ copy_to_sockptr(optlen, &len, sizeof(int))) {
+ put_unused_fd(pidfd);
+ fput(pidfd_file);
+
+ return -EFAULT;
+ }
+
+ fd_install(pidfd, pidfd_file);
+ return 0;
+ }
+
case SO_PEERGROUPS:
{
const struct cred *cred;
diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c
index d4258869ac48..e172a5848b0d 100644
--- a/net/mptcp/sockopt.c
+++ b/net/mptcp/sockopt.c
@@ -355,6 +355,7 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
case SO_BROADCAST:
case SO_BSDCOMPAT:
case SO_PASSCRED:
+ case SO_PASSPIDFD:
case SO_PASSSEC:
case SO_RXQ_OVFL:
case SO_WIFI_STATUS:
diff --git a/net/unix/Kconfig b/net/unix/Kconfig
index b7f811216820..28b232f281ab 100644
--- a/net/unix/Kconfig
+++ b/net/unix/Kconfig
@@ -4,7 +4,7 @@
#
config UNIX
- tristate "Unix domain sockets"
+ bool "Unix domain sockets"
help
If you say Y here, you will include support for Unix domain sockets;
sockets are the standard Unix mechanism for establishing and
@@ -14,10 +14,6 @@ config UNIX
an embedded system or something similar, you therefore definitely
want to say Y here.
- To compile this driver as a module, choose M here: the module will be
- called unix. Note that several important services won't work
- correctly if you say M here and then neglect to load the module.
-
Say Y unless you know what you are doing.
config UNIX_SCM
diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c
index 653136d68b32..73c61a010b01 100644
--- a/net/unix/af_unix.c
+++ b/net/unix/af_unix.c
@@ -921,11 +921,26 @@ static void unix_unhash(struct sock *sk)
*/
}
+static bool unix_bpf_bypass_getsockopt(int level, int optname)
+{
+ if (level == SOL_SOCKET) {
+ switch (optname) {
+ case SO_PEERPIDFD:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ return false;
+}
+
struct proto unix_dgram_proto = {
.name = "UNIX",
.owner = THIS_MODULE,
.obj_size = sizeof(struct unix_sock),
.close = unix_close,
+ .bpf_bypass_getsockopt = unix_bpf_bypass_getsockopt,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = unix_dgram_bpf_update_proto,
#endif
@@ -937,6 +952,7 @@ struct proto unix_stream_proto = {
.obj_size = sizeof(struct unix_sock),
.close = unix_close,
.unhash = unix_unhash,
+ .bpf_bypass_getsockopt = unix_bpf_bypass_getsockopt,
#ifdef CONFIG_BPF_SYSCALL
.psock_update_sk_prot = unix_stream_bpf_update_proto,
#endif
@@ -1361,7 +1377,8 @@ static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr,
if (err)
goto out;
- if (test_bit(SOCK_PASSCRED, &sock->flags) &&
+ if ((test_bit(SOCK_PASSCRED, &sock->flags) ||
+ test_bit(SOCK_PASSPIDFD, &sock->flags)) &&
!unix_sk(sk)->addr) {
err = unix_autobind(sk);
if (err)
@@ -1469,7 +1486,8 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
if (err)
goto out;
- if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) {
+ if ((test_bit(SOCK_PASSCRED, &sock->flags) ||
+ test_bit(SOCK_PASSPIDFD, &sock->flags)) && !u->addr) {
err = unix_autobind(sk);
if (err)
goto out;
@@ -1670,6 +1688,8 @@ static void unix_sock_inherit_flags(const struct socket *old,
{
if (test_bit(SOCK_PASSCRED, &old->flags))
set_bit(SOCK_PASSCRED, &new->flags);
+ if (test_bit(SOCK_PASSPIDFD, &old->flags))
+ set_bit(SOCK_PASSPIDFD, &new->flags);
if (test_bit(SOCK_PASSSEC, &old->flags))
set_bit(SOCK_PASSSEC, &new->flags);
}
@@ -1819,8 +1839,10 @@ static bool unix_passcred_enabled(const struct socket *sock,
const struct sock *other)
{
return test_bit(SOCK_PASSCRED, &sock->flags) ||
+ test_bit(SOCK_PASSPIDFD, &sock->flags) ||
!other->sk_socket ||
- test_bit(SOCK_PASSCRED, &other->sk_socket->flags);
+ test_bit(SOCK_PASSCRED, &other->sk_socket->flags) ||
+ test_bit(SOCK_PASSPIDFD, &other->sk_socket->flags);
}
/*
@@ -1904,7 +1926,8 @@ static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) {
+ if ((test_bit(SOCK_PASSCRED, &sock->flags) ||
+ test_bit(SOCK_PASSPIDFD, &sock->flags)) && !u->addr) {
err = unix_autobind(sk);
if (err)
goto out;
@@ -2718,7 +2741,8 @@ unlock:
/* Never glue messages from different writers */
if (!unix_skb_scm_eq(skb, &scm))
break;
- } else if (test_bit(SOCK_PASSCRED, &sock->flags)) {
+ } else if (test_bit(SOCK_PASSCRED, &sock->flags) ||
+ test_bit(SOCK_PASSPIDFD, &sock->flags)) {
/* Copy credentials */
scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid);
unix_set_secdata(&scm, skb);
diff --git a/tools/include/uapi/asm-generic/socket.h b/tools/include/uapi/asm-generic/socket.h
index 8756df13be50..54d9c8bf7c55 100644
--- a/tools/include/uapi/asm-generic/socket.h
+++ b/tools/include/uapi/asm-generic/socket.h
@@ -121,6 +121,9 @@
#define SO_RCVMARK 75
+#define SO_PASSPIDFD 76
+#define SO_PEERPIDFD 77
+
#if !defined(__KERNEL__)
#if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
diff --git a/tools/testing/selftests/net/.gitignore b/tools/testing/selftests/net/.gitignore
index f27a7338b60e..501854a89cc0 100644
--- a/tools/testing/selftests/net/.gitignore
+++ b/tools/testing/selftests/net/.gitignore
@@ -29,6 +29,7 @@ reuseport_bpf_numa
reuseport_dualstack
rxtimestamp
sctp_hello
+scm_pidfd
sk_bind_sendto_listen
sk_connect_zero_addr
socket
diff --git a/tools/testing/selftests/net/af_unix/Makefile b/tools/testing/selftests/net/af_unix/Makefile
index 1e4b397cece6..221c387a7d7f 100644
--- a/tools/testing/selftests/net/af_unix/Makefile
+++ b/tools/testing/selftests/net/af_unix/Makefile
@@ -1,3 +1,4 @@
-TEST_GEN_PROGS := diag_uid test_unix_oob unix_connect
+CFLAGS += $(KHDR_INCLUDES)
+TEST_GEN_PROGS := diag_uid test_unix_oob unix_connect scm_pidfd
include ../../lib.mk
diff --git a/tools/testing/selftests/net/af_unix/scm_pidfd.c b/tools/testing/selftests/net/af_unix/scm_pidfd.c
new file mode 100644
index 000000000000..a86222143d79
--- /dev/null
+++ b/tools/testing/selftests/net/af_unix/scm_pidfd.c
@@ -0,0 +1,430 @@
+// SPDX-License-Identifier: GPL-2.0 OR MIT
+#define _GNU_SOURCE
+#include <error.h>
+#include <limits.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+#include <linux/socket.h>
+#include <unistd.h>
+#include <string.h>
+#include <errno.h>
+#include <sys/un.h>
+#include <sys/signal.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+#include "../../kselftest_harness.h"
+
+#define clean_errno() (errno == 0 ? "None" : strerror(errno))
+#define log_err(MSG, ...) \
+ fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", __FILE__, __LINE__, \
+ clean_errno(), ##__VA_ARGS__)
+
+#ifndef SCM_PIDFD
+#define SCM_PIDFD 0x04
+#endif
+
+static void child_die()
+{
+ exit(1);
+}
+
+static int safe_int(const char *numstr, int *converted)
+{
+ char *err = NULL;
+ long sli;
+
+ errno = 0;
+ sli = strtol(numstr, &err, 0);
+ if (errno == ERANGE && (sli == LONG_MAX || sli == LONG_MIN))
+ return -ERANGE;
+
+ if (errno != 0 && sli == 0)
+ return -EINVAL;
+
+ if (err == numstr || *err != '\0')
+ return -EINVAL;
+
+ if (sli > INT_MAX || sli < INT_MIN)
+ return -ERANGE;
+
+ *converted = (int)sli;
+ return 0;
+}
+
+static int char_left_gc(const char *buffer, size_t len)
+{
+ size_t i;
+
+ for (i = 0; i < len; i++) {
+ if (buffer[i] == ' ' || buffer[i] == '\t')
+ continue;
+
+ return i;
+ }
+
+ return 0;
+}
+
+static int char_right_gc(const char *buffer, size_t len)
+{
+ int i;
+
+ for (i = len - 1; i >= 0; i--) {
+ if (buffer[i] == ' ' || buffer[i] == '\t' ||
+ buffer[i] == '\n' || buffer[i] == '\0')
+ continue;
+
+ return i + 1;
+ }
+
+ return 0;
+}
+
+static char *trim_whitespace_in_place(char *buffer)
+{
+ buffer += char_left_gc(buffer, strlen(buffer));
+ buffer[char_right_gc(buffer, strlen(buffer))] = '\0';
+ return buffer;
+}
+
+/* borrowed (with all helpers) from pidfd/pidfd_open_test.c */
+static pid_t get_pid_from_fdinfo_file(int pidfd, const char *key, size_t keylen)
+{
+ int ret;
+ char path[512];
+ FILE *f;
+ size_t n = 0;
+ pid_t result = -1;
+ char *line = NULL;
+
+ snprintf(path, sizeof(path), "/proc/self/fdinfo/%d", pidfd);
+
+ f = fopen(path, "re");
+ if (!f)
+ return -1;
+
+ while (getline(&line, &n, f) != -1) {
+ char *numstr;
+
+ if (strncmp(line, key, keylen))
+ continue;
+
+ numstr = trim_whitespace_in_place(line + 4);
+ ret = safe_int(numstr, &result);
+ if (ret < 0)
+ goto out;
+
+ break;
+ }
+
+out:
+ free(line);
+ fclose(f);
+ return result;
+}
+
+static int cmsg_check(int fd)
+{
+ struct msghdr msg = { 0 };
+ struct cmsghdr *cmsg;
+ struct iovec iov;
+ struct ucred *ucred = NULL;
+ int data = 0;
+ char control[CMSG_SPACE(sizeof(struct ucred)) +
+ CMSG_SPACE(sizeof(int))] = { 0 };
+ int *pidfd = NULL;
+ pid_t parent_pid;
+ int err;
+
+ iov.iov_base = &data;
+ iov.iov_len = sizeof(data);
+
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ err = recvmsg(fd, &msg, 0);
+ if (err < 0) {
+ log_err("recvmsg");
+ return 1;
+ }
+
+ if (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
+ log_err("recvmsg: truncated");
+ return 1;
+ }
+
+ for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL;
+ cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_PIDFD) {
+ if (cmsg->cmsg_len < sizeof(*pidfd)) {
+ log_err("CMSG parse: SCM_PIDFD wrong len");
+ return 1;
+ }
+
+ pidfd = (void *)CMSG_DATA(cmsg);
+ }
+
+ if (cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_CREDENTIALS) {
+ if (cmsg->cmsg_len < sizeof(*ucred)) {
+ log_err("CMSG parse: SCM_CREDENTIALS wrong len");
+ return 1;
+ }
+
+ ucred = (void *)CMSG_DATA(cmsg);
+ }
+ }
+
+ /* send(pfd, "x", sizeof(char), 0) */
+ if (data != 'x') {
+ log_err("recvmsg: data corruption");
+ return 1;
+ }
+
+ if (!pidfd) {
+ log_err("CMSG parse: SCM_PIDFD not found");
+ return 1;
+ }
+
+ if (!ucred) {
+ log_err("CMSG parse: SCM_CREDENTIALS not found");
+ return 1;
+ }
+
+ /* pidfd from SCM_PIDFD should point to the parent process PID */
+ parent_pid =
+ get_pid_from_fdinfo_file(*pidfd, "Pid:", sizeof("Pid:") - 1);
+ if (parent_pid != getppid()) {
+ log_err("wrong SCM_PIDFD %d != %d", parent_pid, getppid());
+ return 1;
+ }
+
+ return 0;
+}
+
+struct sock_addr {
+ char sock_name[32];
+ struct sockaddr_un listen_addr;
+ socklen_t addrlen;
+};
+
+FIXTURE(scm_pidfd)
+{
+ int server;
+ pid_t client_pid;
+ int startup_pipe[2];
+ struct sock_addr server_addr;
+ struct sock_addr *client_addr;
+};
+
+FIXTURE_VARIANT(scm_pidfd)
+{
+ int type;
+ bool abstract;
+};
+
+FIXTURE_VARIANT_ADD(scm_pidfd, stream_pathname)
+{
+ .type = SOCK_STREAM,
+ .abstract = 0,
+};
+
+FIXTURE_VARIANT_ADD(scm_pidfd, stream_abstract)
+{
+ .type = SOCK_STREAM,
+ .abstract = 1,
+};
+
+FIXTURE_VARIANT_ADD(scm_pidfd, dgram_pathname)
+{
+ .type = SOCK_DGRAM,
+ .abstract = 0,
+};
+
+FIXTURE_VARIANT_ADD(scm_pidfd, dgram_abstract)
+{
+ .type = SOCK_DGRAM,
+ .abstract = 1,
+};
+
+FIXTURE_SETUP(scm_pidfd)
+{
+ self->client_addr = mmap(NULL, sizeof(*self->client_addr), PROT_READ | PROT_WRITE,
+ MAP_SHARED | MAP_ANONYMOUS, -1, 0);
+ ASSERT_NE(MAP_FAILED, self->client_addr);
+}
+
+FIXTURE_TEARDOWN(scm_pidfd)
+{
+ close(self->server);
+
+ kill(self->client_pid, SIGKILL);
+ waitpid(self->client_pid, NULL, 0);
+
+ if (!variant->abstract) {
+ unlink(self->server_addr.sock_name);
+ unlink(self->client_addr->sock_name);
+ }
+}
+
+static void fill_sockaddr(struct sock_addr *addr, bool abstract)
+{
+ char *sun_path_buf = (char *)&addr->listen_addr.sun_path;
+
+ addr->listen_addr.sun_family = AF_UNIX;
+ addr->addrlen = offsetof(struct sockaddr_un, sun_path);
+ snprintf(addr->sock_name, sizeof(addr->sock_name), "scm_pidfd_%d", getpid());
+ addr->addrlen += strlen(addr->sock_name);
+ if (abstract) {
+ *sun_path_buf = '\0';
+ addr->addrlen++;
+ sun_path_buf++;
+ } else {
+ unlink(addr->sock_name);
+ }
+ memcpy(sun_path_buf, addr->sock_name, strlen(addr->sock_name));
+}
+
+static void client(FIXTURE_DATA(scm_pidfd) *self,
+ const FIXTURE_VARIANT(scm_pidfd) *variant)
+{
+ int err;
+ int cfd;
+ socklen_t len;
+ struct ucred peer_cred;
+ int peer_pidfd;
+ pid_t peer_pid;
+ int on = 0;
+
+ cfd = socket(AF_UNIX, variant->type, 0);
+ if (cfd < 0) {
+ log_err("socket");
+ child_die();
+ }
+
+ if (variant->type == SOCK_DGRAM) {
+ fill_sockaddr(self->client_addr, variant->abstract);
+
+ if (bind(cfd, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen)) {
+ log_err("bind");
+ child_die();
+ }
+ }
+
+ if (connect(cfd, (struct sockaddr *)&self->server_addr.listen_addr,
+ self->server_addr.addrlen) != 0) {
+ log_err("connect");
+ child_die();
+ }
+
+ on = 1;
+ if (setsockopt(cfd, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
+ log_err("Failed to set SO_PASSCRED");
+ child_die();
+ }
+
+ if (setsockopt(cfd, SOL_SOCKET, SO_PASSPIDFD, &on, sizeof(on))) {
+ log_err("Failed to set SO_PASSPIDFD");
+ child_die();
+ }
+
+ close(self->startup_pipe[1]);
+
+ if (cmsg_check(cfd)) {
+ log_err("cmsg_check failed");
+ child_die();
+ }
+
+ /* skip further for SOCK_DGRAM as it's not applicable */
+ if (variant->type == SOCK_DGRAM)
+ return;
+
+ len = sizeof(peer_cred);
+ if (getsockopt(cfd, SOL_SOCKET, SO_PEERCRED, &peer_cred, &len)) {
+ log_err("Failed to get SO_PEERCRED");
+ child_die();
+ }
+
+ len = sizeof(peer_pidfd);
+ if (getsockopt(cfd, SOL_SOCKET, SO_PEERPIDFD, &peer_pidfd, &len)) {
+ log_err("Failed to get SO_PEERPIDFD");
+ child_die();
+ }
+
+ /* pid from SO_PEERCRED should point to the parent process PID */
+ if (peer_cred.pid != getppid()) {
+ log_err("peer_cred.pid != getppid(): %d != %d", peer_cred.pid, getppid());
+ child_die();
+ }
+
+ peer_pid = get_pid_from_fdinfo_file(peer_pidfd,
+ "Pid:", sizeof("Pid:") - 1);
+ if (peer_pid != peer_cred.pid) {
+ log_err("peer_pid != peer_cred.pid: %d != %d", peer_pid, peer_cred.pid);
+ child_die();
+ }
+}
+
+TEST_F(scm_pidfd, test)
+{
+ int err;
+ int pfd;
+ int child_status = 0;
+
+ self->server = socket(AF_UNIX, variant->type, 0);
+ ASSERT_NE(-1, self->server);
+
+ fill_sockaddr(&self->server_addr, variant->abstract);
+
+ err = bind(self->server, (struct sockaddr *)&self->server_addr.listen_addr, self->server_addr.addrlen);
+ ASSERT_EQ(0, err);
+
+ if (variant->type == SOCK_STREAM) {
+ err = listen(self->server, 1);
+ ASSERT_EQ(0, err);
+ }
+
+ err = pipe(self->startup_pipe);
+ ASSERT_NE(-1, err);
+
+ self->client_pid = fork();
+ ASSERT_NE(-1, self->client_pid);
+ if (self->client_pid == 0) {
+ close(self->server);
+ close(self->startup_pipe[0]);
+ client(self, variant);
+ exit(0);
+ }
+ close(self->startup_pipe[1]);
+
+ if (variant->type == SOCK_STREAM) {
+ pfd = accept(self->server, NULL, NULL);
+ ASSERT_NE(-1, pfd);
+ } else {
+ pfd = self->server;
+ }
+
+ /* wait until the child arrives at checkpoint */
+ read(self->startup_pipe[0], &err, sizeof(int));
+ close(self->startup_pipe[0]);
+
+ if (variant->type == SOCK_DGRAM) {
+ err = sendto(pfd, "x", sizeof(char), 0, (struct sockaddr *)&self->client_addr->listen_addr, self->client_addr->addrlen);
+ ASSERT_NE(-1, err);
+ } else {
+ err = send(pfd, "x", sizeof(char), 0);
+ ASSERT_NE(-1, err);
+ }
+
+ close(pfd);
+ waitpid(self->client_pid, &child_status, 0);
+ ASSERT_EQ(0, WIFEXITED(child_status) ? WEXITSTATUS(child_status) : 1);
+}
+
+TEST_HARNESS_MAIN