diff options
Diffstat (limited to 'net/core/sock.c')
| -rw-r--r-- | net/core/sock.c | 134 | 
1 files changed, 94 insertions, 40 deletions
diff --git a/net/core/sock.c b/net/core/sock.c index 788c1372663c..eeb6cbac6f49 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -703,15 +703,17 @@ static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen)  			goto out;  	} -	return sock_bindtoindex(sk, index, true); +	sockopt_lock_sock(sk); +	ret = sock_bindtoindex_locked(sk, index); +	sockopt_release_sock(sk);  out:  #endif  	return ret;  } -static int sock_getbindtodevice(struct sock *sk, char __user *optval, -				int __user *optlen, int len) +static int sock_getbindtodevice(struct sock *sk, sockptr_t optval, +				sockptr_t optlen, int len)  {  	int ret = -ENOPROTOOPT;  #ifdef CONFIG_NETDEVICES @@ -735,12 +737,12 @@ static int sock_getbindtodevice(struct sock *sk, char __user *optval,  	len = strlen(devname) + 1;  	ret = -EFAULT; -	if (copy_to_user(optval, devname, len)) +	if (copy_to_sockptr(optval, devname, len))  		goto out;  zero:  	ret = -EFAULT; -	if (put_user(len, optlen)) +	if (copy_to_sockptr(optlen, &len, sizeof(int)))  		goto out;  	ret = 0; @@ -1036,17 +1038,51 @@ static int sock_reserve_memory(struct sock *sk, int bytes)  	return 0;  } +void sockopt_lock_sock(struct sock *sk) +{ +	/* When current->bpf_ctx is set, the setsockopt is called from +	 * a bpf prog.  bpf has ensured the sk lock has been +	 * acquired before calling setsockopt(). +	 */ +	if (has_current_bpf_ctx()) +		return; + +	lock_sock(sk); +} +EXPORT_SYMBOL(sockopt_lock_sock); + +void sockopt_release_sock(struct sock *sk) +{ +	if (has_current_bpf_ctx()) +		return; + +	release_sock(sk); +} +EXPORT_SYMBOL(sockopt_release_sock); + +bool sockopt_ns_capable(struct user_namespace *ns, int cap) +{ +	return has_current_bpf_ctx() || ns_capable(ns, cap); +} +EXPORT_SYMBOL(sockopt_ns_capable); + +bool sockopt_capable(int cap) +{ +	return has_current_bpf_ctx() || capable(cap); +} +EXPORT_SYMBOL(sockopt_capable); +  /*   *	This is meant for all protocols to use and covers goings on   *	at the socket level. Everything here is generic.   */ -int sock_setsockopt(struct socket *sock, int level, int optname, -		    sockptr_t optval, unsigned int optlen) +int sk_setsockopt(struct sock *sk, int level, int optname, +		  sockptr_t optval, unsigned int optlen)  {  	struct so_timestamping timestamping; +	struct socket *sock = sk->sk_socket;  	struct sock_txtime sk_txtime; -	struct sock *sk = sock->sk;  	int val;  	int valbool;  	struct linger ling; @@ -1067,11 +1103,11 @@ int sock_setsockopt(struct socket *sock, int level, int optname,  	valbool = val ? 1 : 0; -	lock_sock(sk); +	sockopt_lock_sock(sk);  	switch (optname) {  	case SO_DEBUG: -		if (val && !capable(CAP_NET_ADMIN)) +		if (val && !sockopt_capable(CAP_NET_ADMIN))  			ret = -EACCES;  		else  			sock_valbool_flag(sk, SOCK_DBG, valbool); @@ -1115,7 +1151,7 @@ set_sndbuf:  		break;  	case SO_SNDBUFFORCE: -		if (!capable(CAP_NET_ADMIN)) { +		if (!sockopt_capable(CAP_NET_ADMIN)) {  			ret = -EPERM;  			break;  		} @@ -1137,7 +1173,7 @@ set_sndbuf:  		break;  	case SO_RCVBUFFORCE: -		if (!capable(CAP_NET_ADMIN)) { +		if (!sockopt_capable(CAP_NET_ADMIN)) {  			ret = -EPERM;  			break;  		} @@ -1164,8 +1200,8 @@ set_sndbuf:  	case SO_PRIORITY:  		if ((val >= 0 && val <= 6) || -		    ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) || -		    ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) +		    sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) || +		    sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN))  			sk->sk_priority = val;  		else  			ret = -EPERM; @@ -1228,7 +1264,7 @@ set_sndbuf:  	case SO_RCVLOWAT:  		if (val < 0)  			val = INT_MAX; -		if (sock->ops->set_rcvlowat) +		if (sock && sock->ops->set_rcvlowat)  			ret = sock->ops->set_rcvlowat(sk, val);  		else  			WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); @@ -1310,8 +1346,8 @@ set_sndbuf:  			clear_bit(SOCK_PASSSEC, &sock->flags);  		break;  	case SO_MARK: -		if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && -		    !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { +		if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && +		    !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {  			ret = -EPERM;  			break;  		} @@ -1319,8 +1355,8 @@ set_sndbuf:  		__sock_set_mark(sk, val);  		break;  	case SO_RCVMARK: -		if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && -		    !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { +		if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && +		    !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {  			ret = -EPERM;  			break;  		} @@ -1354,7 +1390,7 @@ set_sndbuf:  #ifdef CONFIG_NET_RX_BUSY_POLL  	case SO_BUSY_POLL:  		/* allow unprivileged users to decrease the value */ -		if ((val > sk->sk_ll_usec) && !capable(CAP_NET_ADMIN)) +		if ((val > sk->sk_ll_usec) && !sockopt_capable(CAP_NET_ADMIN))  			ret = -EPERM;  		else {  			if (val < 0) @@ -1364,13 +1400,13 @@ set_sndbuf:  		}  		break;  	case SO_PREFER_BUSY_POLL: -		if (valbool && !capable(CAP_NET_ADMIN)) +		if (valbool && !sockopt_capable(CAP_NET_ADMIN))  			ret = -EPERM;  		else  			WRITE_ONCE(sk->sk_prefer_busy_poll, valbool);  		break;  	case SO_BUSY_POLL_BUDGET: -		if (val > READ_ONCE(sk->sk_busy_poll_budget) && !capable(CAP_NET_ADMIN)) { +		if (val > READ_ONCE(sk->sk_busy_poll_budget) && !sockopt_capable(CAP_NET_ADMIN)) {  			ret = -EPERM;  		} else {  			if (val < 0 || val > U16_MAX) @@ -1441,7 +1477,7 @@ set_sndbuf:  		 * scheduler has enough safe guards.  		 */  		if (sk_txtime.clockid != CLOCK_MONOTONIC && -		    !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { +		    !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) {  			ret = -EPERM;  			break;  		} @@ -1496,9 +1532,16 @@ set_sndbuf:  		ret = -ENOPROTOOPT;  		break;  	} -	release_sock(sk); +	sockopt_release_sock(sk);  	return ret;  } + +int sock_setsockopt(struct socket *sock, int level, int optname, +		    sockptr_t optval, unsigned int optlen) +{ +	return sk_setsockopt(sock->sk, level, optname, +			     optval, optlen); +}  EXPORT_SYMBOL(sock_setsockopt);  static const struct cred *sk_get_peer_cred(struct sock *sk) @@ -1525,22 +1568,25 @@ static void cred_to_ucred(struct pid *pid, const struct cred *cred,  	}  } -static int groups_to_user(gid_t __user *dst, const struct group_info *src) +static int groups_to_user(sockptr_t dst, const struct group_info *src)  {  	struct user_namespace *user_ns = current_user_ns();  	int i; -	for (i = 0; i < src->ngroups; i++) -		if (put_user(from_kgid_munged(user_ns, src->gid[i]), dst + i)) +	for (i = 0; i < src->ngroups; i++) { +		gid_t gid = from_kgid_munged(user_ns, src->gid[i]); + +		if (copy_to_sockptr_offset(dst, i * sizeof(gid), &gid, sizeof(gid)))  			return -EFAULT; +	}  	return 0;  } -int sock_getsockopt(struct socket *sock, int level, int optname, -		    char __user *optval, int __user *optlen) +int sk_getsockopt(struct sock *sk, int level, int optname, +		  sockptr_t optval, sockptr_t optlen)  { -	struct sock *sk = sock->sk; +	struct socket *sock = sk->sk_socket;  	union {  		int val; @@ -1557,7 +1603,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  	int lv = sizeof(int);  	int len; -	if (get_user(len, optlen)) +	if (copy_from_sockptr(&len, optlen, sizeof(int)))  		return -EFAULT;  	if (len < 0)  		return -EINVAL; @@ -1692,7 +1738,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  		cred_to_ucred(sk->sk_peer_pid, sk->sk_peer_cred, &peercred);  		spin_unlock(&sk->sk_peer_lock); -		if (copy_to_user(optval, &peercred, len)) +		if (copy_to_sockptr(optval, &peercred, len))  			return -EFAULT;  		goto lenout;  	} @@ -1710,11 +1756,11 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  		if (len < n * sizeof(gid_t)) {  			len = n * sizeof(gid_t);  			put_cred(cred); -			return put_user(len, optlen) ? -EFAULT : -ERANGE; +			return copy_to_sockptr(optlen, &len, sizeof(int)) ? -EFAULT : -ERANGE;  		}  		len = n * sizeof(gid_t); -		ret = groups_to_user((gid_t __user *)optval, cred->group_info); +		ret = groups_to_user(optval, cred->group_info);  		put_cred(cred);  		if (ret)  			return ret; @@ -1730,7 +1776,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  			return -ENOTCONN;  		if (lv < len)  			return -EINVAL; -		if (copy_to_user(optval, address, len)) +		if (copy_to_sockptr(optval, address, len))  			return -EFAULT;  		goto lenout;  	} @@ -1747,7 +1793,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  		break;  	case SO_PEERSEC: -		return security_socket_getpeersec_stream(sock, optval, optlen, len); +		return security_socket_getpeersec_stream(sock, optval.user, optlen.user, len);  	case SO_MARK:  		v.val = sk->sk_mark; @@ -1779,7 +1825,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  		return sock_getbindtodevice(sk, optval, optlen, len);  	case SO_GET_FILTER: -		len = sk_get_filter(sk, (struct sock_filter __user *)optval, len); +		len = sk_get_filter(sk, optval, len);  		if (len < 0)  			return len; @@ -1827,7 +1873,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  		sk_get_meminfo(sk, meminfo);  		len = min_t(unsigned int, len, sizeof(meminfo)); -		if (copy_to_user(optval, &meminfo, len)) +		if (copy_to_sockptr(optval, &meminfo, len))  			return -EFAULT;  		goto lenout; @@ -1896,14 +1942,22 @@ int sock_getsockopt(struct socket *sock, int level, int optname,  	if (len > lv)  		len = lv; -	if (copy_to_user(optval, &v, len)) +	if (copy_to_sockptr(optval, &v, len))  		return -EFAULT;  lenout: -	if (put_user(len, optlen)) +	if (copy_to_sockptr(optlen, &len, sizeof(int)))  		return -EFAULT;  	return 0;  } +int sock_getsockopt(struct socket *sock, int level, int optname, +		    char __user *optval, int __user *optlen) +{ +	return sk_getsockopt(sock->sk, level, optname, +			     USER_SOCKPTR(optval), +			     USER_SOCKPTR(optlen)); +} +  /*   * Initialize an sk_lock.   *  |