diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h index 1c53c4c4d88fecc3e0b8e1632b9132f644fea403..568a87c5e0d0f95622bc153ecb9a39e803479ad7 100644 --- a/include/net/af_vsock.h +++ b/include/net/af_vsock.h @@ -78,6 +78,7 @@ struct vsock_sock { s64 vsock_stream_has_data(struct vsock_sock *vsk); s64 vsock_stream_has_space(struct vsock_sock *vsk); struct sock *vsock_create_connected(struct sock *parent); +void vsock_data_ready(struct sock *sk); /**** TRANSPORT ****/ @@ -135,6 +136,7 @@ struct vsock_transport { u64 (*stream_rcvhiwat)(struct vsock_sock *); bool (*stream_is_active)(struct vsock_sock *); bool (*stream_allow)(u32 cid, u32 port); + int (*set_rcvlowat)(struct vsock_sock *vsk, int val); /* SEQ_PACKET. */ ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg, diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index b4ee163154a683ef22568c6c05825fa2ad14a768..ee418701cdee902ee7fea08767f51be055b87ad6 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -882,6 +882,16 @@ s64 vsock_stream_has_space(struct vsock_sock *vsk) } EXPORT_SYMBOL_GPL(vsock_stream_has_space); +void vsock_data_ready(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= sk->sk_rcvlowat || + sock_flag(sk, SOCK_DONE)) + sk->sk_data_ready(sk); +} +EXPORT_SYMBOL_GPL(vsock_data_ready); + static int vsock_release(struct socket *sock) { __vsock_release(sock->sk, 0); @@ -1066,8 +1076,9 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock, if (transport && transport->stream_is_active(vsk) && !(sk->sk_shutdown & RCV_SHUTDOWN)) { bool data_ready_now = false; + int target = sock_rcvlowat(sk, 0, INT_MAX); int ret = transport->notify_poll_in( - vsk, 1, &data_ready_now); + vsk, target, &data_ready_now); if (ret < 0) { mask |= EPOLLERR; } else { @@ -2137,6 +2148,25 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, return err; } +static int vsock_set_rcvlowat(struct sock *sk, int val) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + if (val > vsk->buffer_size) + return -EINVAL; + + transport = vsk->transport; + + if (transport && transport->set_rcvlowat) + return transport->set_rcvlowat(vsk, val); + + WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); + return 0; +} + static const struct proto_ops vsock_stream_ops = { .family = PF_VSOCK, .owner = THIS_MODULE, @@ -2156,6 +2186,7 @@ static const struct proto_ops vsock_stream_ops = { .recvmsg = vsock_connectible_recvmsg, .mmap = sock_no_mmap, .sendpage = sock_no_sendpage, + .set_rcvlowat = vsock_set_rcvlowat, }; static const struct proto_ops vsock_seqpacket_ops = { diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c index fd98229e3db30988b983761b9c809ce059d6d9b2..59c3e2697069045ae1227f0ea1d1064823148517 100644 --- a/net/vmw_vsock/hyperv_transport.c +++ b/net/vmw_vsock/hyperv_transport.c @@ -815,6 +815,12 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, return 0; } +static +int hvs_set_rcvlowat(struct vsock_sock *vsk, int val) +{ + return -EOPNOTSUPP; +} + static struct vsock_transport hvs_transport = { .module = THIS_MODULE, @@ -850,6 +856,7 @@ static struct vsock_transport hvs_transport = { .notify_send_pre_enqueue = hvs_notify_send_pre_enqueue, .notify_send_post_enqueue = hvs_notify_send_post_enqueue, + .set_rcvlowat = hvs_set_rcvlowat }; static bool hvs_check_transport(struct vsock_sock *vsk) diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index ec2c2afbf0d060b4f6335a7616c8deba0a95fde6..35863132f4f117d3e2435cd3e9c9510814844b79 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -634,10 +634,7 @@ virtio_transport_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *data_ready_now) { - if (vsock_stream_has_data(vsk)) - *data_ready_now = true; - else - *data_ready_now = false; + *data_ready_now = vsock_stream_has_data(vsk) >= target; return 0; } @@ -1084,7 +1081,7 @@ virtio_transport_recv_connected(struct sock *sk, switch (le16_to_cpu(pkt->hdr.op)) { case VIRTIO_VSOCK_OP_RW: virtio_transport_recv_enqueue(vsk, pkt); - sk->sk_data_ready(sk); + vsock_data_ready(sk); return err; case VIRTIO_VSOCK_OP_CREDIT_REQUEST: virtio_transport_send_credit_update(vsk); diff --git a/net/vmw_vsock/vmci_transport_notify.c b/net/vmw_vsock/vmci_transport_notify.c index d69fc4b595ad42b2784698ea1e8e2b7b60632343..7c3a7db134b2820033c5d475731b0ab2b84382e5 100644 --- a/net/vmw_vsock/vmci_transport_notify.c +++ b/net/vmw_vsock/vmci_transport_notify.c @@ -307,7 +307,7 @@ vmci_transport_handle_wrote(struct sock *sk, struct vsock_sock *vsk = vsock_sk(sk); PKT_FIELD(vsk, sent_waiting_read) = false; #endif - sk->sk_data_ready(sk); + vsock_data_ready(sk); } static void vmci_transport_notify_pkt_socket_init(struct sock *sk) @@ -340,12 +340,12 @@ vmci_transport_notify_pkt_poll_in(struct sock *sk, { struct vsock_sock *vsk = vsock_sk(sk); - if (vsock_stream_has_data(vsk)) { + if (vsock_stream_has_data(vsk) >= target) { *data_ready_now = true; } else { - /* We can't read right now because there is nothing in the - * queue. Ask for notifications when there is something to - * read. + /* We can't read right now because there is not enough data + * in the queue. Ask for notifications when there is something + * to read. */ if (sk->sk_state == TCP_ESTABLISHED) { if (!send_waiting_read(sk, 1)) diff --git a/net/vmw_vsock/vmci_transport_notify_qstate.c b/net/vmw_vsock/vmci_transport_notify_qstate.c index 0f36d7c45db3933f25add3a3e15f6f5fb8c242f7..e96a88d850a86a9ffbd422674b080ef324dc4799 100644 --- a/net/vmw_vsock/vmci_transport_notify_qstate.c +++ b/net/vmw_vsock/vmci_transport_notify_qstate.c @@ -84,7 +84,7 @@ vmci_transport_handle_wrote(struct sock *sk, bool bottom_half, struct sockaddr_vm *dst, struct sockaddr_vm *src) { - sk->sk_data_ready(sk); + vsock_data_ready(sk); } static void vsock_block_update_write_window(struct sock *sk) @@ -161,12 +161,12 @@ vmci_transport_notify_pkt_poll_in(struct sock *sk, { struct vsock_sock *vsk = vsock_sk(sk); - if (vsock_stream_has_data(vsk)) { + if (vsock_stream_has_data(vsk) >= target) { *data_ready_now = true; } else { - /* We can't read right now because there is nothing in the - * queue. Ask for notifications when there is something to - * read. + /* We can't read right now because there is not enough data + * in the queue. Ask for notifications when there is something + * to read. */ if (sk->sk_state == TCP_ESTABLISHED) vsock_block_update_write_window(sk); @@ -282,7 +282,7 @@ vmci_transport_notify_pkt_recv_post_dequeue( /* See the comment in * vmci_transport_notify_pkt_send_post_enqueue(). */ - sk->sk_data_ready(sk); + vsock_data_ready(sk); } return err; diff --git a/tools/testing/vsock/vsock_test.c b/tools/testing/vsock/vsock_test.c index dc577461afc20546e44f700ddd49b83b78ececb4..bb6d691cb30d0d2fccdd1ea53c824f83ebde11c7 100644 --- a/tools/testing/vsock/vsock_test.c +++ b/tools/testing/vsock/vsock_test.c @@ -18,6 +18,7 @@ #include <sys/socket.h> #include <time.h> #include <sys/mman.h> +#include <poll.h> #include "timeout.h" #include "control.h" @@ -596,6 +597,108 @@ static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opt close(fd); } +#define RCVLOWAT_BUF_SIZE 128 + +static void test_stream_poll_rcvlowat_server(const struct test_opts *opts) +{ + int fd; + int i; + + fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL); + if (fd < 0) { + perror("accept"); + exit(EXIT_FAILURE); + } + + /* Send 1 byte. */ + send_byte(fd, 1, 0); + + control_writeln("SRVSENT"); + + /* Wait until client is ready to receive rest of data. */ + control_expectln("CLNSENT"); + + for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++) + send_byte(fd, 1, 0); + + /* Keep socket in active state. */ + control_expectln("POLLDONE"); + + close(fd); +} + +static void test_stream_poll_rcvlowat_client(const struct test_opts *opts) +{ + unsigned long lowat_val = RCVLOWAT_BUF_SIZE; + char buf[RCVLOWAT_BUF_SIZE]; + struct pollfd fds; + ssize_t read_res; + short poll_flags; + int fd; + + fd = vsock_stream_connect(opts->peer_cid, 1234); + if (fd < 0) { + perror("connect"); + exit(EXIT_FAILURE); + } + + if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT, + &lowat_val, sizeof(lowat_val))) { + perror("setsockopt"); + exit(EXIT_FAILURE); + } + + control_expectln("SRVSENT"); + + /* At this point, server sent 1 byte. */ + fds.fd = fd; + poll_flags = POLLIN | POLLRDNORM; + fds.events = poll_flags; + + /* Try to wait for 1 sec. */ + if (poll(&fds, 1, 1000) < 0) { + perror("poll"); + exit(EXIT_FAILURE); + } + + /* poll() must return nothing. */ + if (fds.revents) { + fprintf(stderr, "Unexpected poll result %hx\n", + fds.revents); + exit(EXIT_FAILURE); + } + + /* Tell server to send rest of data. */ + control_writeln("CLNSENT"); + + /* Poll for data. */ + if (poll(&fds, 1, 10000) < 0) { + perror("poll"); + exit(EXIT_FAILURE); + } + + /* Only these two bits are expected. */ + if (fds.revents != poll_flags) { + fprintf(stderr, "Unexpected poll result %hx\n", + fds.revents); + exit(EXIT_FAILURE); + } + + /* Use MSG_DONTWAIT, if call is going to wait, EAGAIN + * will be returned. + */ + read_res = recv(fd, buf, sizeof(buf), MSG_DONTWAIT); + if (read_res != RCVLOWAT_BUF_SIZE) { + fprintf(stderr, "Unexpected recv result %zi\n", + read_res); + exit(EXIT_FAILURE); + } + + control_writeln("POLLDONE"); + + close(fd); +} + static struct test_case test_cases[] = { { .name = "SOCK_STREAM connection reset", @@ -646,6 +749,11 @@ static struct test_case test_cases[] = { .run_client = test_seqpacket_invalid_rec_buffer_client, .run_server = test_seqpacket_invalid_rec_buffer_server, }, + { + .name = "SOCK_STREAM poll() + SO_RCVLOWAT", + .run_client = test_stream_poll_rcvlowat_client, + .run_server = test_stream_poll_rcvlowat_server, + }, {}, };