Diffstat (limited to 'drivers/vhost/vsock.c')
-rw-r--r--	drivers/vhost/vsock.c	227
1 file changed, 99 insertions(+), 128 deletions(-)
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 5703775af129..1f3b89c885cc 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -51,8 +51,7 @@ struct vhost_vsock {
 	struct hlist_node hash;
 
 	struct vhost_work send_pkt_work;
-	spinlock_t send_pkt_list_lock;
-	struct list_head send_pkt_list;	/* host->guest pending packets */
+	struct sk_buff_head send_pkt_queue; /* host->guest pending packets */
 
 	atomic_t queued_replies;
 
@@ -108,40 +107,31 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
 	vhost_disable_notify(&vsock->dev, vq);
 
 	do {
-		struct virtio_vsock_pkt *pkt;
+		struct virtio_vsock_hdr *hdr;
+		size_t iov_len, payload_len;
 		struct iov_iter iov_iter;
+		u32 flags_to_restore = 0;
+		struct sk_buff *skb;
 		unsigned out, in;
 		size_t nbytes;
-		size_t iov_len, payload_len;
 		int head;
-		u32 flags_to_restore = 0;
 
-		spin_lock_bh(&vsock->send_pkt_list_lock);
-		if (list_empty(&vsock->send_pkt_list)) {
-			spin_unlock_bh(&vsock->send_pkt_list_lock);
+		skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue);
+
+		if (!skb) {
 			vhost_enable_notify(&vsock->dev, vq);
 			break;
 		}
 
-		pkt = list_first_entry(&vsock->send_pkt_list,
-				       struct virtio_vsock_pkt, list);
-		list_del_init(&pkt->list);
-		spin_unlock_bh(&vsock->send_pkt_list_lock);
-
 		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 					 &out, &in, NULL, NULL);
 		if (head < 0) {
-			spin_lock_bh(&vsock->send_pkt_list_lock);
-			list_add(&pkt->list, &vsock->send_pkt_list);
-			spin_unlock_bh(&vsock->send_pkt_list_lock);
+			virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
 			break;
 		}
 
 		if (head == vq->num) {
-			spin_lock_bh(&vsock->send_pkt_list_lock);
-			list_add(&pkt->list, &vsock->send_pkt_list);
-			spin_unlock_bh(&vsock->send_pkt_list_lock);
-
+			virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
 			/* We cannot finish yet if more buffers snuck in while
 			 * re-enabling notify.
 			 */
@@ -153,26 +143,27 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
 		}
 
 		if (out) {
-			virtio_transport_free_pkt(pkt);
+			kfree_skb(skb);
 			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
 			break;
 		}
 
 		iov_len = iov_length(&vq->iov[out], in);
-		if (iov_len < sizeof(pkt->hdr)) {
-			virtio_transport_free_pkt(pkt);
+		if (iov_len < sizeof(*hdr)) {
+			kfree_skb(skb);
 			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
 			break;
 		}
 
-		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
-		payload_len = pkt->len - pkt->off;
+		iov_iter_init(&iov_iter, ITER_DEST, &vq->iov[out], in, iov_len);
+		payload_len = skb->len;
+		hdr = virtio_vsock_hdr(skb);
 
 		/* If the packet is greater than the space available in the
 		 * buffer, we split it using multiple buffers.
 		 */
-		if (payload_len > iov_len - sizeof(pkt->hdr)) {
-			payload_len = iov_len - sizeof(pkt->hdr);
+		if (payload_len > iov_len - sizeof(*hdr)) {
+			payload_len = iov_len - sizeof(*hdr);
 
 			/* As we are copying pieces of large packet's buffer to
 			 * small rx buffers, headers of packets in rx queue are
@@ -185,31 +176,30 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
 			 * bits set. After initialized header will be copied to
 			 * rx buffer, these required bits will be restored.
 			 */
-			if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
-				pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
+			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
+				hdr->flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
 				flags_to_restore |= VIRTIO_VSOCK_SEQ_EOM;
 
-				if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
-					pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
+				if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) {
+					hdr->flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
 					flags_to_restore |= VIRTIO_VSOCK_SEQ_EOR;
 				}
 			}
 		}
 
 		/* Set the correct length in the header */
-		pkt->hdr.len = cpu_to_le32(payload_len);
+		hdr->len = cpu_to_le32(payload_len);
 
-		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
-		if (nbytes != sizeof(pkt->hdr)) {
-			virtio_transport_free_pkt(pkt);
+		nbytes = copy_to_iter(hdr, sizeof(*hdr), &iov_iter);
+		if (nbytes != sizeof(*hdr)) {
+			kfree_skb(skb);
 			vq_err(vq, "Faulted on copying pkt hdr\n");
 			break;
 		}
 
-		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
-				      &iov_iter);
+		nbytes = copy_to_iter(skb->data, payload_len, &iov_iter);
 		if (nbytes != payload_len) {
-			virtio_transport_free_pkt(pkt);
+			kfree_skb(skb);
 			vq_err(vq, "Faulted on copying pkt buf\n");
 			break;
 		}
@@ -217,31 +207,28 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
 		/* Deliver to monitoring devices all packets that we
 		 * will transmit.
 		 */
-		virtio_transport_deliver_tap_pkt(pkt);
+		virtio_transport_deliver_tap_pkt(skb);
 
-		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
+		vhost_add_used(vq, head, sizeof(*hdr) + payload_len);
 		added = true;
 
-		pkt->off += payload_len;
+		skb_pull(skb, payload_len);
 		total_len += payload_len;
 
 		/* If we didn't send all the payload we can requeue the packet
 		 * to send it with the next available buffer.
 		 */
-		if (pkt->off < pkt->len) {
-			pkt->hdr.flags |= cpu_to_le32(flags_to_restore);
+		if (skb->len > 0) {
+			hdr->flags |= cpu_to_le32(flags_to_restore);
 
-			/* We are queueing the same virtio_vsock_pkt to handle
+			/* We are queueing the same skb to handle
 			 * the remaining bytes, and we want to deliver it
 			 * to monitoring devices in the next iteration.
 			 */
-			pkt->tap_delivered = false;
-
-			spin_lock_bh(&vsock->send_pkt_list_lock);
-			list_add(&pkt->list, &vsock->send_pkt_list);
-			spin_unlock_bh(&vsock->send_pkt_list_lock);
+			virtio_vsock_skb_clear_tap_delivered(skb);
+			virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
 		} else {
-			if (pkt->reply) {
+			if (virtio_vsock_skb_reply(skb)) {
 				int val;
 
 				val = atomic_dec_return(&vsock->queued_replies);
@@ -253,7 +240,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
 					restart_tx = true;
 			}
 
-			virtio_transport_free_pkt(pkt);
+			consume_skb(skb);
 		}
 	} while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
 	if (added)
@@ -278,28 +265,26 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
 }
 
 static int
-vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
+vhost_transport_send_pkt(struct sk_buff *skb)
 {
+	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
 	struct vhost_vsock *vsock;
-	int len = pkt->len;
+	int len = skb->len;
 
 	rcu_read_lock();
 
 	/* Find the vhost_vsock according to guest context id  */
-	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
+	vsock = vhost_vsock_get(le64_to_cpu(hdr->dst_cid));
 	if (!vsock) {
 		rcu_read_unlock();
-		virtio_transport_free_pkt(pkt);
+		kfree_skb(skb);
 		return -ENODEV;
 	}
 
-	if (pkt->reply)
+	if (virtio_vsock_skb_reply(skb))
 		atomic_inc(&vsock->queued_replies);
 
-	spin_lock_bh(&vsock->send_pkt_list_lock);
-	list_add_tail(&pkt->list, &vsock->send_pkt_list);
-	spin_unlock_bh(&vsock->send_pkt_list_lock);
-
+	virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb);
 	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
 
 	rcu_read_unlock();
@@ -310,10 +295,8 @@ static int
 vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 {
 	struct vhost_vsock *vsock;
-	struct virtio_vsock_pkt *pkt, *n;
 	int cnt = 0;
 	int ret = -ENODEV;
-	LIST_HEAD(freeme);
 
 	rcu_read_lock();
 
@@ -322,20 +305,7 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 	if (!vsock)
 		goto out;
 
-	spin_lock_bh(&vsock->send_pkt_list_lock);
-	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
-		if (pkt->vsk != vsk)
-			continue;
-		list_move(&pkt->list, &freeme);
-	}
-	spin_unlock_bh(&vsock->send_pkt_list_lock);
-
-	list_for_each_entry_safe(pkt, n, &freeme, list) {
-		if (pkt->reply)
-			cnt++;
-		list_del(&pkt->list);
-		virtio_transport_free_pkt(pkt);
-	}
+	cnt = virtio_transport_purge_skbs(vsk, &vsock->send_pkt_queue);
 
 	if (cnt) {
 		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
@@ -352,12 +322,14 @@ out:
 	return ret;
 }
 
-static struct virtio_vsock_pkt *
-vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
+static struct sk_buff *
+vhost_vsock_alloc_skb(struct vhost_virtqueue *vq,
 		      unsigned int out, unsigned int in)
 {
-	struct virtio_vsock_pkt *pkt;
+	struct virtio_vsock_hdr *hdr;
 	struct iov_iter iov_iter;
+	struct sk_buff *skb;
+	size_t payload_len;
 	size_t nbytes;
 	size_t len;
 
@@ -366,50 +338,48 @@ vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
 		return NULL;
 	}
 
-	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
-	if (!pkt)
+	len = iov_length(vq->iov, out);
+
+	/* len contains both payload and hdr */
+	skb = virtio_vsock_alloc_skb(len, GFP_KERNEL);
+	if (!skb)
 		return NULL;
 
-	len = iov_length(vq->iov, out);
-	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
+	iov_iter_init(&iov_iter, ITER_SOURCE, vq->iov, out, len);
 
-	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
-	if (nbytes != sizeof(pkt->hdr)) {
+	hdr = virtio_vsock_hdr(skb);
+	nbytes = copy_from_iter(hdr, sizeof(*hdr), &iov_iter);
+	if (nbytes != sizeof(*hdr)) {
 		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
-		       sizeof(pkt->hdr), nbytes);
-		kfree(pkt);
+		       sizeof(*hdr), nbytes);
+		kfree_skb(skb);
 		return NULL;
 	}
 
-	pkt->len = le32_to_cpu(pkt->hdr.len);
+	payload_len = le32_to_cpu(hdr->len);
 
 	/* No payload */
-	if (!pkt->len)
-		return pkt;
+	if (!payload_len)
+		return skb;
 
-	/* The pkt is too big */
-	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
-		kfree(pkt);
+	/* The pkt is too big or the length in the header is invalid */
+	if (payload_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE ||
+	    payload_len + sizeof(*hdr) > len) {
+		kfree_skb(skb);
 		return NULL;
 	}
 
-	pkt->buf = kvmalloc(pkt->len, GFP_KERNEL);
-	if (!pkt->buf) {
-		kfree(pkt);
-		return NULL;
-	}
+	virtio_vsock_skb_rx_put(skb);
 
-	pkt->buf_len = pkt->len;
-
-	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
-	if (nbytes != pkt->len) {
-		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
-		       pkt->len, nbytes);
-		virtio_transport_free_pkt(pkt);
+	nbytes = copy_from_iter(skb->data, payload_len, &iov_iter);
+	if (nbytes != payload_len) {
+		vq_err(vq, "Expected %zu byte payload, got %zu bytes\n",
+		       payload_len, nbytes);
+		kfree_skb(skb);
 		return NULL;
 	}
 
-	return pkt;
+	return skb;
 }
 
 /* Is there space left for replies to rx packets? */
@@ -496,9 +466,9 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 						  poll.work);
 	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
 						 dev);
-	struct virtio_vsock_pkt *pkt;
 	int head, pkts = 0, total_len = 0;
 	unsigned int out, in;
+	struct sk_buff *skb;
 	bool added = false;
 
 	mutex_lock(&vq->mutex);
@@ -511,6 +481,8 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 	vhost_disable_notify(&vsock->dev, vq);
 
 	do {
+		struct virtio_vsock_hdr *hdr;
+
 		if (!vhost_vsock_more_replies(vsock)) {
 			/* Stop tx until the device processes already
 			 * pending replies.  Leave tx virtqueue
@@ -532,24 +504,26 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 			break;
 		}
 
-		pkt = vhost_vsock_alloc_pkt(vq, out, in);
-		if (!pkt) {
+		skb = vhost_vsock_alloc_skb(vq, out, in);
+		if (!skb) {
 			vq_err(vq, "Faulted on pkt\n");
 			continue;
 		}
 
-		total_len += sizeof(pkt->hdr) + pkt->len;
+		total_len += sizeof(*hdr) + skb->len;
 
 		/* Deliver to monitoring devices all received packets */
-		virtio_transport_deliver_tap_pkt(pkt);
+		virtio_transport_deliver_tap_pkt(skb);
+
+		hdr = virtio_vsock_hdr(skb);
 
 		/* Only accept correctly addressed packets */
-		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
-		    le64_to_cpu(pkt->hdr.dst_cid) ==
+		if (le64_to_cpu(hdr->src_cid) == vsock->guest_cid &&
+		    le64_to_cpu(hdr->dst_cid) ==
 		    vhost_transport_get_local_cid())
-			virtio_transport_recv_pkt(&vhost_transport, pkt);
+			virtio_transport_recv_pkt(&vhost_transport, skb);
 		else
-			virtio_transport_free_pkt(pkt);
+			kfree_skb(skb);
 
 		vhost_add_used(vq, head, 0);
 		added = true;
@@ -693,8 +667,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
 		       VHOST_VSOCK_WEIGHT, true, NULL);
 
 	file->private_data = vsock;
-	spin_lock_init(&vsock->send_pkt_list_lock);
-	INIT_LIST_HEAD(&vsock->send_pkt_list);
+	skb_queue_head_init(&vsock->send_pkt_queue);
 	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
 	return 0;
 
@@ -760,16 +733,7 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
 	vhost_vsock_flush(vsock);
 	vhost_dev_stop(&vsock->dev);
 
-	spin_lock_bh(&vsock->send_pkt_list_lock);
-	while (!list_empty(&vsock->send_pkt_list)) {
-		struct virtio_vsock_pkt *pkt;
-
-		pkt = list_first_entry(&vsock->send_pkt_list,
-				struct virtio_vsock_pkt, list);
-		list_del_init(&pkt->list);
-		virtio_transport_free_pkt(pkt);
-	}
-	spin_unlock_bh(&vsock->send_pkt_list_lock);
+	virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue);
 
 	vhost_dev_cleanup(&vsock->dev);
 	kfree(vsock->dev.vqs);
@@ -959,7 +923,14 @@ static int __init vhost_vsock_init(void)
 				  VSOCK_TRANSPORT_F_H2G);
 	if (ret < 0)
 		return ret;
-	return misc_register(&vhost_vsock_misc);
+
+	ret = misc_register(&vhost_vsock_misc);
+	if (ret) {
+		vsock_core_unregister(&vhost_transport.transport);
+		return ret;
+	}
+
+	return 0;
 };
 
 static void __exit vhost_vsock_exit(void)
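Note on the host->guest send path: the subtlest part of this conversion is the split-and-requeue logic in vhost_transport_do_send_pkt(), which now advances through the payload with skb_pull() instead of a pkt->off cursor, and withholds the VIRTIO_VSOCK_SEQ_EOM/EOR delimiter bits from every fragment's header copy except the last. The sketch below is a standalone userspace model of that control flow only, not kernel code; SEQ_EOM's value and the rx buffer size are illustrative assumptions.

/* Model of the split-and-requeue loop: a payload larger than one rx
 * buffer goes out in pieces; EOM is cleared for each non-final
 * fragment and restored on requeue so only the last fragment
 * carries the message delimiter.
 */
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

#define SEQ_EOM		0x1	/* stands in for VIRTIO_VSOCK_SEQ_EOM */
#define RX_BUF_PAYLOAD	4096	/* assumed per-buffer payload space */

struct pkt {
	uint32_t flags;		/* header flags of the queued packet */
	size_t len;		/* bytes still to send (skb->len) */
};

int main(void)
{
	struct pkt p = { .flags = SEQ_EOM, .len = 10000 };

	while (p.len > 0) {
		uint32_t flags_to_restore = 0;
		size_t chunk = p.len;

		if (chunk > RX_BUF_PAYLOAD) {
			chunk = RX_BUF_PAYLOAD;
			/* hold EOM back from this fragment's header */
			if (p.flags & SEQ_EOM) {
				p.flags &= ~SEQ_EOM;
				flags_to_restore |= SEQ_EOM;
			}
		}

		printf("fragment: %zu bytes, EOM=%u\n",
		       chunk, (unsigned)(p.flags & SEQ_EOM));

		p.len -= chunk;		/* models skb_pull(skb, payload_len) */
		if (p.len > 0)		/* requeue: restore held-back flags */
			p.flags |= flags_to_restore;
	}
	return 0;
}

This mirrors why the patch applies flags_to_restore only on the requeue branch: once the remaining tail fits in one buffer, the header flags go out unmodified and the receiver sees the delimiter exactly once.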