diff options
Diffstat (limited to 'net/vmw_vsock/virtio_transport_common.c')
| -rw-r--r-- | net/vmw_vsock/virtio_transport_common.c | 48 | 
1 files changed, 31 insertions, 17 deletions
| diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index a1581c77cf84..ee78b4082ef9 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -94,6 +94,11 @@ virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,  					 info->op,  					 info->flags); +	if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) { +		WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n"); +		goto out; +	} +  	return skb;  out: @@ -241,21 +246,18 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,  }  static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, -					struct sk_buff *skb) +					u32 len)  { -	if (vvs->rx_bytes + skb->len > vvs->buf_alloc) +	if (vvs->rx_bytes + len > vvs->buf_alloc)  		return false; -	vvs->rx_bytes += skb->len; +	vvs->rx_bytes += len;  	return true;  }  static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, -					struct sk_buff *skb) +					u32 len)  { -	int len; - -	len = skb_headroom(skb) - sizeof(struct virtio_vsock_hdr) - skb->len;  	vvs->rx_bytes -= len;  	vvs->fwd_cnt += len;  } @@ -366,8 +368,15 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,  	u32 free_space;  	spin_lock_bh(&vvs->rx_lock); + +	if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes, +		      "rx_queue is empty, but rx_bytes is non-zero\n")) { +		spin_unlock_bh(&vvs->rx_lock); +		return err; +	} +  	while (total < len && !skb_queue_empty(&vvs->rx_queue)) { -		skb = __skb_dequeue(&vvs->rx_queue); +		skb = skb_peek(&vvs->rx_queue);  		bytes = len - total;  		if (bytes > skb->len) @@ -388,10 +397,11 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,  		skb_pull(skb, bytes);  		if (skb->len == 0) { -			virtio_transport_dec_rx_pkt(vvs, skb); +			u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len); + +			virtio_transport_dec_rx_pkt(vvs, pkt_len); +			__skb_unlink(skb, &vvs->rx_queue);  			consume_skb(skb); -		} else { -			__skb_queue_head(&vvs->rx_queue, skb);  		}  	} @@ -437,17 +447,17 @@ static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,  	while (!msg_ready) {  		struct virtio_vsock_hdr *hdr; +		size_t pkt_len;  		skb = __skb_dequeue(&vvs->rx_queue);  		if (!skb)  			break;  		hdr = virtio_vsock_hdr(skb); +		pkt_len = (size_t)le32_to_cpu(hdr->len);  		if (dequeued_len >= 0) { -			size_t pkt_len;  			size_t bytes_to_copy; -			pkt_len = (size_t)le32_to_cpu(hdr->len);  			bytes_to_copy = min(user_buf_len, pkt_len);  			if (bytes_to_copy) { @@ -466,7 +476,6 @@ static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,  					dequeued_len = err;  				} else {  					user_buf_len -= bytes_to_copy; -					skb_pull(skb, bytes_to_copy);  				}  				spin_lock_bh(&vvs->rx_lock); @@ -484,7 +493,7 @@ static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,  				msg->msg_flags |= MSG_EOR;  		} -		virtio_transport_dec_rx_pkt(vvs, skb); +		virtio_transport_dec_rx_pkt(vvs, pkt_len);  		kfree_skb(skb);  	} @@ -1040,7 +1049,7 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,  	spin_lock_bh(&vvs->rx_lock); -	can_enqueue = virtio_transport_inc_rx_pkt(vvs, skb); +	can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);  	if (!can_enqueue) {  		free_pkt = true;  		goto out; @@ -1071,7 +1080,7 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,  			memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);  			free_pkt = true;  			last_hdr->flags |= hdr->flags; -			last_hdr->len = cpu_to_le32(last_skb->len); +			le32_add_cpu(&last_hdr->len, len);  			goto out;  		}  	} @@ -1299,6 +1308,11 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,  		goto free_pkt;  	} +	if (!skb_set_owner_sk_safe(skb, sk)) { +		WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n"); +		goto free_pkt; +	} +  	vsk = vsock_sk(sk);  	lock_sock(sk); |