diff mbox series

[bitbake-devel,v5,01/22] asyncrpc: Abstract sockets

Message ID 20231101154216.2758185-2-JPEWhacker@gmail.com
State New
Headers show
Series Bitbake Hash Server WebSockets, Alternate Database Backend, and User Management | expand

Commit Message

Joshua Watt Nov. 1, 2023, 3:41 p.m. UTC
Rewrites the asyncrpc client and server code to make it possible to have
other transport backends that are not stream based (e.g. websockets
which are message based). The connection handling classes are now shared
between both the client and server to make it easier to implement new
transport mechanisms

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
 lib/bb/asyncrpc/__init__.py   |  32 +---
 lib/bb/asyncrpc/client.py     |  78 +++------
 lib/bb/asyncrpc/connection.py |  95 +++++++++++
 lib/bb/asyncrpc/exceptions.py |  17 ++
 lib/bb/asyncrpc/serv.py       | 298 +++++++++++++++++-----------------
 lib/hashserv/__init__.py      |  21 ---
 lib/hashserv/client.py        |  38 ++---
 lib/hashserv/server.py        | 115 ++++++-------
 lib/prserv/client.py          |   8 +-
 lib/prserv/serv.py            |  31 ++--
 10 files changed, 380 insertions(+), 353 deletions(-)
 create mode 100644 lib/bb/asyncrpc/connection.py
 create mode 100644 lib/bb/asyncrpc/exceptions.py
diff mbox series

Patch

diff --git a/lib/bb/asyncrpc/__init__.py b/lib/bb/asyncrpc/__init__.py
index 9a85e996..9f677eac 100644
--- a/lib/bb/asyncrpc/__init__.py
+++ b/lib/bb/asyncrpc/__init__.py
@@ -4,30 +4,12 @@ 
 # SPDX-License-Identifier: GPL-2.0-only
 #
 
-import itertools
-import json
-
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
-
-def chunkify(msg, max_chunk):
-    if len(msg) < max_chunk - 1:
-        yield ''.join((msg, "\n"))
-    else:
-        yield ''.join((json.dumps({
-                'chunk-stream': None
-            }), "\n"))
-
-        args = [iter(msg)] * (max_chunk - 1)
-        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
-            yield ''.join(itertools.chain(m, "\n"))
-        yield "\n"
-
 
 from .client import AsyncClient, Client
-from .serv import AsyncServer, AsyncServerConnection, ClientError, ServerError
+from .serv import AsyncServer, AsyncServerConnection
+from .connection import DEFAULT_MAX_CHUNK
+from .exceptions import (
+    ClientError,
+    ServerError,
+    ConnectionClosedError,
+)
diff --git a/lib/bb/asyncrpc/client.py b/lib/bb/asyncrpc/client.py
index fa042bbe..7f33099b 100644
--- a/lib/bb/asyncrpc/client.py
+++ b/lib/bb/asyncrpc/client.py
@@ -10,13 +10,13 @@  import json
 import os
 import socket
 import sys
-from . import chunkify, DEFAULT_MAX_CHUNK
+from .connection import StreamConnection, DEFAULT_MAX_CHUNK
+from .exceptions import ConnectionClosedError
 
 
 class AsyncClient(object):
     def __init__(self, proto_name, proto_version, logger, timeout=30):
-        self.reader = None
-        self.writer = None
+        self.socket = None
         self.max_chunk = DEFAULT_MAX_CHUNK
         self.proto_name = proto_name
         self.proto_version = proto_version
@@ -25,7 +25,8 @@  class AsyncClient(object):
 
     async def connect_tcp(self, address, port):
         async def connect_sock():
-            return await asyncio.open_connection(address, port)
+            reader, writer = await asyncio.open_connection(address, port)
+            return StreamConnection(reader, writer, self.timeout, self.max_chunk)
 
         self._connect_sock = connect_sock
 
@@ -40,27 +41,27 @@  class AsyncClient(object):
                 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
                 sock.connect(os.path.basename(path))
             finally:
-               os.chdir(cwd)
-            return await asyncio.open_unix_connection(sock=sock)
+                os.chdir(cwd)
+            reader, writer = await asyncio.open_unix_connection(sock=sock)
+            return StreamConnection(reader, writer, self.timeout, self.max_chunk)
 
         self._connect_sock = connect_sock
 
     async def setup_connection(self):
-        s = '%s %s\n\n' % (self.proto_name, self.proto_version)
-        self.writer.write(s.encode("utf-8"))
-        await self.writer.drain()
+        # Send headers
+        await self.socket.send("%s %s" % (self.proto_name, self.proto_version))
+        # End of headers
+        await self.socket.send("")
 
     async def connect(self):
-        if self.reader is None or self.writer is None:
-            (self.reader, self.writer) = await self._connect_sock()
+        if self.socket is None:
+            self.socket = await self._connect_sock()
             await self.setup_connection()
 
     async def close(self):
-        self.reader = None
-
-        if self.writer is not None:
-            self.writer.close()
-            self.writer = None
+        if self.socket is not None:
+            await self.socket.close()
+            self.socket = None
 
     async def _send_wrapper(self, proc):
         count = 0
@@ -71,6 +72,7 @@  class AsyncClient(object):
             except (
                 OSError,
                 ConnectionError,
+                ConnectionClosedError,
                 json.JSONDecodeError,
                 UnicodeDecodeError,
             ) as e:
@@ -82,49 +84,15 @@  class AsyncClient(object):
                 await self.close()
                 count += 1
 
-    async def send_message(self, msg):
-        async def get_line():
-            try:
-                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
-            except asyncio.TimeoutError:
-                raise ConnectionError("Timed out waiting for server")
-
-            if not line:
-                raise ConnectionError("Connection closed")
-
-            line = line.decode("utf-8")
-
-            if not line.endswith("\n"):
-                raise ConnectionError("Bad message %r" % (line))
-
-            return line
-
+    async def invoke(self, msg):
         async def proc():
-            for c in chunkify(json.dumps(msg), self.max_chunk):
-                self.writer.write(c.encode("utf-8"))
-            await self.writer.drain()
-
-            l = await get_line()
-
-            m = json.loads(l)
-            if m and "chunk-stream" in m:
-                lines = []
-                while True:
-                    l = (await get_line()).rstrip("\n")
-                    if not l:
-                        break
-                    lines.append(l)
-
-                m = json.loads("".join(lines))
-
-            return m
+            await self.socket.send_message(msg)
+            return await self.socket.recv_message()
 
         return await self._send_wrapper(proc)
 
     async def ping(self):
-        return await self.send_message(
-            {'ping': {}}
-        )
+        return await self.invoke({"ping": {}})
 
 
 class Client(object):
@@ -142,7 +110,7 @@  class Client(object):
         # required (but harmless) with it.
         asyncio.set_event_loop(self.loop)
 
-        self._add_methods('connect_tcp', 'ping')
+        self._add_methods("connect_tcp", "ping")
 
     @abc.abstractmethod
     def _get_async_client(self):
diff --git a/lib/bb/asyncrpc/connection.py b/lib/bb/asyncrpc/connection.py
new file mode 100644
index 00000000..c4fd2475
--- /dev/null
+++ b/lib/bb/asyncrpc/connection.py
@@ -0,0 +1,95 @@ 
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import asyncio
+import itertools
+import json
+from .exceptions import ClientError, ConnectionClosedError
+
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+
+def chunkify(msg, max_chunk):
+    if len(msg) < max_chunk - 1:
+        yield "".join((msg, "\n"))
+    else:
+        yield "".join((json.dumps({"chunk-stream": None}), "\n"))
+
+        args = [iter(msg)] * (max_chunk - 1)
+        for m in map("".join, itertools.zip_longest(*args, fillvalue="")):
+            yield "".join(itertools.chain(m, "\n"))
+        yield "\n"
+
+
+class StreamConnection(object):
+    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
+        self.reader = reader
+        self.writer = writer
+        self.timeout = timeout
+        self.max_chunk = max_chunk
+
+    @property
+    def address(self):
+        return self.writer.get_extra_info("peername")
+
+    async def send_message(self, msg):
+        for c in chunkify(json.dumps(msg), self.max_chunk):
+            self.writer.write(c.encode("utf-8"))
+        await self.writer.drain()
+
+    async def recv_message(self):
+        l = await self.recv()
+
+        m = json.loads(l)
+        if not m:
+            return m
+
+        if "chunk-stream" in m:
+            lines = []
+            while True:
+                l = await self.recv()
+                if not l:
+                    break
+                lines.append(l)
+
+            m = json.loads("".join(lines))
+
+        return m
+
+    async def send(self, msg):
+        self.writer.write(("%s\n" % msg).encode("utf-8"))
+        await self.writer.drain()
+
+    async def recv(self):
+        if self.timeout < 0:
+            line = await self.reader.readline()
+        else:
+            try:
+                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
+            except asyncio.TimeoutError:
+                raise ConnectionError("Timed out waiting for data")
+
+        if not line:
+            raise ConnectionClosedError("Connection closed")
+
+        line = line.decode("utf-8")
+
+        if not line.endswith("\n"):
+            raise ConnectionError("Bad message %r" % (line))
+
+        return line.rstrip()
+
+    async def close(self):
+        self.reader = None
+        if self.writer is not None:
+            self.writer.close()
+            self.writer = None
diff --git a/lib/bb/asyncrpc/exceptions.py b/lib/bb/asyncrpc/exceptions.py
new file mode 100644
index 00000000..a8942b4f
--- /dev/null
+++ b/lib/bb/asyncrpc/exceptions.py
@@ -0,0 +1,17 @@ 
+#
+# Copyright BitBake Contributors
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+
+class ClientError(Exception):
+    pass
+
+
+class ServerError(Exception):
+    pass
+
+
+class ConnectionClosedError(Exception):
+    pass
diff --git a/lib/bb/asyncrpc/serv.py b/lib/bb/asyncrpc/serv.py
index d2de4891..8d4da1e2 100644
--- a/lib/bb/asyncrpc/serv.py
+++ b/lib/bb/asyncrpc/serv.py
@@ -12,241 +12,242 @@  import signal
 import socket
 import sys
 import multiprocessing
-from . import chunkify, DEFAULT_MAX_CHUNK
-
-
-class ClientError(Exception):
-    pass
-
-
-class ServerError(Exception):
-    pass
+from .connection import StreamConnection
+from .exceptions import ClientError, ServerError, ConnectionClosedError
 
 
 class AsyncServerConnection(object):
-    def __init__(self, reader, writer, proto_name, logger):
-        self.reader = reader
-        self.writer = writer
+    def __init__(self, socket, proto_name, logger):
+        self.socket = socket
         self.proto_name = proto_name
-        self.max_chunk = DEFAULT_MAX_CHUNK
         self.handlers = {
-            'chunk-stream': self.handle_chunk,
-            'ping': self.handle_ping,
+            "ping": self.handle_ping,
         }
         self.logger = logger
 
+    async def close(self):
+        await self.socket.close()
+
     async def process_requests(self):
         try:
-            self.addr = self.writer.get_extra_info('peername')
-            self.logger.debug('Client %r connected' % (self.addr,))
+            self.logger.info("Client %r connected" % (self.socket.address,))
 
             # Read protocol and version
-            client_protocol = await self.reader.readline()
+            client_protocol = await self.socket.recv()
             if not client_protocol:
                 return
 
-            (client_proto_name, client_proto_version) = client_protocol.decode('utf-8').rstrip().split()
+            (client_proto_name, client_proto_version) = client_protocol.split()
             if client_proto_name != self.proto_name:
-                self.logger.debug('Rejecting invalid protocol %s' % (self.proto_name))
+                self.logger.debug("Rejecting invalid protocol %s" % (self.proto_name))
                 return
 
-            self.proto_version = tuple(int(v) for v in client_proto_version.split('.'))
+            self.proto_version = tuple(int(v) for v in client_proto_version.split("."))
             if not self.validate_proto_version():
-                self.logger.debug('Rejecting invalid protocol version %s' % (client_proto_version))
+                self.logger.debug(
+                    "Rejecting invalid protocol version %s" % (client_proto_version)
+                )
                 return
 
             # Read headers. Currently, no headers are implemented, so look for
             # an empty line to signal the end of the headers
             while True:
-                line = await self.reader.readline()
-                if not line:
-                    return
-
-                line = line.decode('utf-8').rstrip()
-                if not line:
+                header = await self.socket.recv()
+                if not header:
                     break
 
             # Handle messages
             while True:
-                d = await self.read_message()
+                d = await self.socket.recv_message()
                 if d is None:
                     break
-                await self.dispatch_message(d)
-                await self.writer.drain()
-        except ClientError as e:
+                response = await self.dispatch_message(d)
+                await self.socket.send_message(response)
+        except ConnectionClosedError as e:
+            self.logger.info(str(e))
+        except (ClientError, ConnectionError) as e:
             self.logger.error(str(e))
         finally:
-            self.writer.close()
+            await self.close()
 
     async def dispatch_message(self, msg):
         for k in self.handlers.keys():
             if k in msg:
-                self.logger.debug('Handling %s' % k)
-                await self.handlers[k](msg[k])
-                return
+                self.logger.debug("Handling %s" % k)
+                return await self.handlers[k](msg[k])
 
         raise ClientError("Unrecognized command %r" % msg)
 
-    def write_message(self, msg):
-        for c in chunkify(json.dumps(msg), self.max_chunk):
-            self.writer.write(c.encode('utf-8'))
+    async def handle_ping(self, request):
+        return {"alive": True}
 
-    async def read_message(self):
-        l = await self.reader.readline()
-        if not l:
-            return None
 
-        try:
-            message = l.decode('utf-8')
+class StreamServer(object):
+    def __init__(self, handler, logger):
+        self.handler = handler
+        self.logger = logger
+        self.closed = False
 
-            if not message.endswith('\n'):
-                return None
+    async def handle_stream_client(self, reader, writer):
+        # writer.transport.set_write_buffer_limits(0)
+        socket = StreamConnection(reader, writer, -1)
+        if self.closed:
+            await socket.close()
+            return
+
+        await self.handler(socket)
+
+    async def stop(self):
+        self.closed = True
+
+
+class TCPStreamServer(StreamServer):
+    def __init__(self, host, port, handler, logger):
+        super().__init__(handler, logger)
+        self.host = host
+        self.port = port
+
+    def start(self, loop):
+        self.server = loop.run_until_complete(
+            asyncio.start_server(self.handle_stream_client, self.host, self.port)
+        )
+
+        for s in self.server.sockets:
+            self.logger.debug("Listening on %r" % (s.getsockname(),))
+            # Newer python does this automatically. Do it manually here for
+            # maximum compatibility
+            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+            # Enable keep alives. This prevents broken client connections
+            # from persisting on the server for long periods of time.
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
+            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
+            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
+
+        name = self.server.sockets[0].getsockname()
+        if self.server.sockets[0].family == socket.AF_INET6:
+            self.address = "[%s]:%d" % (name[0], name[1])
+        else:
+            self.address = "%s:%d" % (name[0], name[1])
+
+        return [self.server.wait_closed()]
+
+    async def stop(self):
+        await super().stop()
+        self.server.close()
+
+    def cleanup(self):
+        pass
 
-            return json.loads(message)
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            self.logger.error('Bad message from client: %r' % message)
-            raise e
 
-    async def handle_chunk(self, request):
-        lines = []
-        try:
-            while True:
-                l = await self.reader.readline()
-                l = l.rstrip(b"\n").decode("utf-8")
-                if not l:
-                    break
-                lines.append(l)
+class UnixStreamServer(StreamServer):
+    def __init__(self, path, handler, logger):
+        super().__init__(handler, logger)
+        self.path = path
 
-            msg = json.loads(''.join(lines))
-        except (json.JSONDecodeError, UnicodeDecodeError) as e:
-            self.logger.error('Bad message from client: %r' % lines)
-            raise e
+    def start(self, loop):
+        cwd = os.getcwd()
+        try:
+            # Work around path length limits in AF_UNIX
+            os.chdir(os.path.dirname(self.path))
+            self.server = loop.run_until_complete(
+                asyncio.start_unix_server(
+                    self.handle_stream_client, os.path.basename(self.path)
+                )
+            )
+        finally:
+            os.chdir(cwd)
 
-        if 'chunk-stream' in msg:
-            raise ClientError("Nested chunks are not allowed")
+        self.logger.debug("Listening on %r" % self.path)
+        self.address = "unix://%s" % os.path.abspath(self.path)
+        return [self.server.wait_closed()]
 
-        await self.dispatch_message(msg)
+    async def stop(self):
+        await super().stop()
+        self.server.close()
 
-    async def handle_ping(self, request):
-        response = {'alive': True}
-        self.write_message(response)
+    def cleanup(self):
+        os.unlink(self.path)
 
 
 class AsyncServer(object):
     def __init__(self, logger):
-        self._cleanup_socket = None
         self.logger = logger
-        self.start = None
-        self.address = None
         self.loop = None
+        self.run_tasks = []
 
     def start_tcp_server(self, host, port):
-        def start_tcp():
-            self.server = self.loop.run_until_complete(
-                asyncio.start_server(self.handle_client, host, port)
-            )
-
-            for s in self.server.sockets:
-                self.logger.debug('Listening on %r' % (s.getsockname(),))
-                # Newer python does this automatically. Do it manually here for
-                # maximum compatibility
-                s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
-                s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
-                # Enable keep alives. This prevents broken client connections
-                # from persisting on the server for long periods of time.
-                s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
-                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)
-                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 15)
-                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 4)
-
-            name = self.server.sockets[0].getsockname()
-            if self.server.sockets[0].family == socket.AF_INET6:
-                self.address = "[%s]:%d" % (name[0], name[1])
-            else:
-                self.address = "%s:%d" % (name[0], name[1])
-
-        self.start = start_tcp
+        self.server = TCPStreamServer(host, port, self._client_handler, self.logger)
 
     def start_unix_server(self, path):
-        def cleanup():
-            os.unlink(path)
-
-        def start_unix():
-            cwd = os.getcwd()
-            try:
-                # Work around path length limits in AF_UNIX
-                os.chdir(os.path.dirname(path))
-                self.server = self.loop.run_until_complete(
-                    asyncio.start_unix_server(self.handle_client, os.path.basename(path))
-                )
-            finally:
-                os.chdir(cwd)
-
-            self.logger.debug('Listening on %r' % path)
-
-            self._cleanup_socket = cleanup
-            self.address = "unix://%s" % os.path.abspath(path)
-
-        self.start = start_unix
-
-    @abc.abstractmethod
-    def accept_client(self, reader, writer):
-        pass
+        self.server = UnixStreamServer(path, self._client_handler, self.logger)
 
-    async def handle_client(self, reader, writer):
-        # writer.transport.set_write_buffer_limits(0)
+    async def _client_handler(self, socket):
         try:
-            client = self.accept_client(reader, writer)
+            client = self.accept_client(socket)
             await client.process_requests()
         except Exception as e:
             import traceback
-            self.logger.error('Error from client: %s' % str(e), exc_info=True)
+
+            self.logger.error("Error from client: %s" % str(e), exc_info=True)
             traceback.print_exc()
-            writer.close()
-        self.logger.debug('Client disconnected')
+            await socket.close()
+        self.logger.debug("Client disconnected")
 
-    def run_loop_forever(self):
-        try:
-            self.loop.run_forever()
-        except KeyboardInterrupt:
-            pass
+    @abc.abstractmethod
+    def accept_client(self, socket):
+        pass
+
+    async def stop(self):
+        self.logger.debug("Stopping server")
+        await self.server.stop()
+
+    def start(self):
+        tasks = self.server.start(self.loop)
+        self.address = self.server.address
+        return tasks
 
     def signal_handler(self):
         self.logger.debug("Got exit signal")
-        self.loop.stop()
+        self.loop.create_task(self.stop())
 
-    def _serve_forever(self):
+    def _serve_forever(self, tasks):
         try:
             self.loop.add_signal_handler(signal.SIGTERM, self.signal_handler)
+            self.loop.add_signal_handler(signal.SIGINT, self.signal_handler)
+            self.loop.add_signal_handler(signal.SIGQUIT, self.signal_handler)
             signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGTERM])
 
-            self.run_loop_forever()
-            self.server.close()
+            self.loop.run_until_complete(asyncio.gather(*tasks))
 
-            self.loop.run_until_complete(self.server.wait_closed())
-            self.logger.debug('Server shutting down')
+            self.logger.debug("Server shutting down")
         finally:
-            if self._cleanup_socket is not None:
-                self._cleanup_socket()
+            self.server.cleanup()
 
     def serve_forever(self):
         """
         Serve requests in the current process
         """
+        self._create_loop()
+        tasks = self.start()
+        self._serve_forever(tasks)
+        self.loop.close()
+
+    def _create_loop(self):
         # Create loop and override any loop that may have existed in
         # a parent process.  It is possible that the usecases of
         # serve_forever might be constrained enough to allow using
         # get_event_loop here, but better safe than sorry for now.
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.loop)
-        self.start()
-        self._serve_forever()
 
     def serve_as_process(self, *, prefunc=None, args=()):
         """
         Serve requests in a child process
         """
+
         def run(queue):
             # Create loop and override any loop that may have existed
             # in a parent process.  Without doing this and instead
@@ -259,18 +260,19 @@  class AsyncServer(object):
             # more general, though, as any potential use of asyncio in
             # Cooker could create a loop that needs to replaced in this
             # new process.
-            self.loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(self.loop)
+            self._create_loop()
             try:
-                self.start()
+                self.address = None
+                tasks = self.start()
             finally:
+                # Always put the server address to wake up the parent task
                 queue.put(self.address)
                 queue.close()
 
             if prefunc is not None:
                 prefunc(self, *args)
 
-            self._serve_forever()
+            self._serve_forever(tasks)
 
             if sys.version_info >= (3, 6):
                 self.loop.run_until_complete(self.loop.shutdown_asyncgens())
diff --git a/lib/hashserv/__init__.py b/lib/hashserv/__init__.py
index 9cb3fd57..3a401835 100644
--- a/lib/hashserv/__init__.py
+++ b/lib/hashserv/__init__.py
@@ -15,13 +15,6 @@  UNIX_PREFIX = "unix://"
 ADDR_TYPE_UNIX = 0
 ADDR_TYPE_TCP = 1
 
-# The Python async server defaults to a 64K receive buffer, so we hardcode our
-# maximum chunk size. It would be better if the client and server reported to
-# each other what the maximum chunk sizes were, but that will slow down the
-# connection setup with a round trip delay so I'd rather not do that unless it
-# is necessary
-DEFAULT_MAX_CHUNK = 32 * 1024
-
 UNIHASH_TABLE_DEFINITION = (
     ("method", "TEXT NOT NULL", "UNIQUE"),
     ("taskhash", "TEXT NOT NULL", "UNIQUE"),
@@ -102,20 +95,6 @@  def parse_address(addr):
         return (ADDR_TYPE_TCP, (host, int(port)))
 
 
-def chunkify(msg, max_chunk):
-    if len(msg) < max_chunk - 1:
-        yield ''.join((msg, "\n"))
-    else:
-        yield ''.join((json.dumps({
-                'chunk-stream': None
-            }), "\n"))
-
-        args = [iter(msg)] * (max_chunk - 1)
-        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
-            yield ''.join(itertools.chain(m, "\n"))
-        yield "\n"
-
-
 def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False):
     from . import server
     db = setup_database(dbname, sync=sync)
diff --git a/lib/hashserv/client.py b/lib/hashserv/client.py
index f676d267..ebf1c361 100644
--- a/lib/hashserv/client.py
+++ b/lib/hashserv/client.py
@@ -28,24 +28,24 @@  class AsyncClient(bb.asyncrpc.AsyncClient):
 
     async def send_stream(self, msg):
         async def proc():
-            self.writer.write(("%s\n" % msg).encode("utf-8"))
-            await self.writer.drain()
-            l = await self.reader.readline()
-            if not l:
-                raise ConnectionError("Connection closed")
-            return l.decode("utf-8").rstrip()
+            await self.socket.send(msg)
+            return await self.socket.recv()
 
         return await self._send_wrapper(proc)
 
     async def _set_mode(self, new_mode):
+        async def stream_to_normal():
+            await self.socket.send("END")
+            return await self.socket.recv_message()
+
         if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
-            r = await self.send_stream("END")
+            r = await self._send_wrapper(stream_to_normal)
             if r != "ok":
-                raise ConnectionError("Bad response from server %r" % r)
+                raise ConnectionError("Unable to transition to normal mode: Bad response from server %r" % r)
         elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
-            r = await self.send_message({"get-stream": None})
+            r = await self.invoke({"get-stream": None})
             if r != "ok":
-                raise ConnectionError("Bad response from server %r" % r)
+                raise ConnectionError("Unable to transition to stream mode: Bad response from server %r" % r)
         elif new_mode != self.mode:
             raise Exception(
                 "Undefined mode transition %r -> %r" % (self.mode, new_mode)
@@ -67,7 +67,7 @@  class AsyncClient(bb.asyncrpc.AsyncClient):
         m["method"] = method
         m["outhash"] = outhash
         m["unihash"] = unihash
-        return await self.send_message({"report": m})
+        return await self.invoke({"report": m})
 
     async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
         await self._set_mode(self.MODE_NORMAL)
@@ -75,39 +75,39 @@  class AsyncClient(bb.asyncrpc.AsyncClient):
         m["taskhash"] = taskhash
         m["method"] = method
         m["unihash"] = unihash
-        return await self.send_message({"report-equiv": m})
+        return await self.invoke({"report-equiv": m})
 
     async def get_taskhash(self, method, taskhash, all_properties=False):
         await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message(
+        return await self.invoke(
             {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
         )
 
     async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
         await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message(
+        return await self.invoke(
             {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method, "with_unihash": with_unihash}}
         )
 
     async def get_stats(self):
         await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"get-stats": None})
+        return await self.invoke({"get-stats": None})
 
     async def reset_stats(self):
         await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"reset-stats": None})
+        return await self.invoke({"reset-stats": None})
 
     async def backfill_wait(self):
         await self._set_mode(self.MODE_NORMAL)
-        return (await self.send_message({"backfill-wait": None}))["tasks"]
+        return (await self.invoke({"backfill-wait": None}))["tasks"]
 
     async def remove(self, where):
         await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"remove": {"where": where}})
+        return await self.invoke({"remove": {"where": where}})
 
     async def clean_unused(self, max_age):
         await self._set_mode(self.MODE_NORMAL)
-        return await self.send_message({"clean-unused": {"max_age_seconds": max_age}})
+        return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
 
 
 class Client(bb.asyncrpc.Client):
diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index 45bf476b..6d3a4751 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -165,8 +165,8 @@  class ServerCursor(object):
 
 
 class ServerClient(bb.asyncrpc.AsyncServerConnection):
-    def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
-        super().__init__(reader, writer, 'OEHASHEQUIV', logger)
+    def __init__(self, socket, db, request_stats, backfill_queue, upstream, read_only):
+        super().__init__(socket, 'OEHASHEQUIV', logger)
         self.db = db
         self.request_stats = request_stats
         self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
@@ -209,12 +209,11 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
             if k in msg:
                 logger.debug('Handling %s' % k)
                 if 'stream' in k:
-                    await self.handlers[k](msg[k])
+                    return await self.handlers[k](msg[k])
                 else:
                     with self.request_stats.start_sample() as self.request_sample, \
                             self.request_sample.measure():
-                        await self.handlers[k](msg[k])
-                return
+                        return await self.handlers[k](msg[k])
 
         raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
 
@@ -224,9 +223,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
         fetch_all = request.get('all', False)
 
         with closing(self.db.cursor()) as cursor:
-            d = await self.get_unihash(cursor, method, taskhash, fetch_all)
-
-        self.write_message(d)
+            return await self.get_unihash(cursor, method, taskhash, fetch_all)
 
     async def get_unihash(self, cursor, method, taskhash, fetch_all=False):
         d = None
@@ -274,9 +271,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
         with_unihash = request.get("with_unihash", True)
 
         with closing(self.db.cursor()) as cursor:
-            d = await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
-
-        self.write_message(d)
+            return await self.get_outhash(cursor, method, outhash, taskhash, with_unihash)
 
     async def get_outhash(self, cursor, method, outhash, taskhash, with_unihash=True):
         d = None
@@ -334,14 +329,14 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
         )
 
     async def handle_get_stream(self, request):
-        self.write_message('ok')
+        await self.socket.send_message("ok")
 
         while True:
             upstream = None
 
-            l = await self.reader.readline()
+            l = await self.socket.recv()
             if not l:
-                return
+                break
 
             try:
                 # This inner loop is very sensitive and must be as fast as
@@ -352,10 +347,8 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
                 request_measure = self.request_sample.measure()
                 request_measure.start()
 
-                l = l.decode('utf-8').rstrip()
                 if l == 'END':
-                    self.writer.write('ok\n'.encode('utf-8'))
-                    return
+                    break
 
                 (method, taskhash) = l.split()
                 #logger.debug('Looking up %s %s' % (method, taskhash))
@@ -366,29 +359,29 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
                     cursor.close()
 
                 if row is not None:
-                    msg = ('%s\n' % row['unihash']).encode('utf-8')
+                    msg = row['unihash']
                     #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
                 elif self.upstream_client is not None:
                     upstream = await self.upstream_client.get_unihash(method, taskhash)
                     if upstream:
-                        msg = ("%s\n" % upstream).encode("utf-8")
+                        msg = upstream
                     else:
-                        msg = "\n".encode("utf-8")
+                        msg = ""
                 else:
-                    msg = '\n'.encode('utf-8')
+                    msg = ""
 
-                self.writer.write(msg)
+                await self.socket.send(msg)
             finally:
                 request_measure.end()
                 self.request_sample.end()
 
-            await self.writer.drain()
-
             # Post to the backfill queue after writing the result to minimize
             # the turn around time on a request
             if upstream is not None:
                 await self.backfill_queue.put((method, taskhash))
 
+        return "ok"
+
     async def handle_report(self, data):
         with closing(self.db.cursor()) as cursor:
             outhash_data = {
@@ -468,7 +461,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
                 'unihash': unihash,
             }
 
-        self.write_message(d)
+        return d
 
     async def handle_equivreport(self, data):
         with closing(self.db.cursor()) as cursor:
@@ -491,30 +484,28 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
 
             d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
 
-        self.write_message(d)
+        return d
 
 
     async def handle_get_stats(self, request):
-        d = {
+        return {
             'requests': self.request_stats.todict(),
         }
 
-        self.write_message(d)
-
     async def handle_reset_stats(self, request):
         d = {
             'requests': self.request_stats.todict(),
         }
 
         self.request_stats.reset()
-        self.write_message(d)
+        return d
 
     async def handle_backfill_wait(self, request):
         d = {
             'tasks': self.backfill_queue.qsize(),
         }
         await self.backfill_queue.join()
-        self.write_message(d)
+        return d
 
     async def handle_remove(self, request):
         condition = request["where"]
@@ -541,7 +532,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
             count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v2", cursor)
             self.db.commit()
 
-        self.write_message({"count": count})
+        return {"count": count}
 
     async def handle_clean_unused(self, request):
         max_age = request["max_age_seconds"]
@@ -558,7 +549,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
             )
             count = cursor.rowcount
 
-        self.write_message({"count": count})
+        return {"count": count}
 
     def query_equivalent(self, cursor, method, taskhash):
         # This is part of the inner loop and must be as fast as possible
@@ -583,41 +574,33 @@  class Server(bb.asyncrpc.AsyncServer):
         self.db = db
         self.upstream = upstream
         self.read_only = read_only
+        self.backfill_queue = None
 
-    def accept_client(self, reader, writer):
-        return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
+    def accept_client(self, socket):
+        return ServerClient(socket, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
 
-    @contextmanager
-    def _backfill_worker(self):
-        async def backfill_worker_task():
-            client = await create_async_client(self.upstream)
-            try:
-                while True:
-                    item = await self.backfill_queue.get()
-                    if item is None:
-                        self.backfill_queue.task_done()
-                        break
-                    method, taskhash = item
-                    await copy_unihash_from_upstream(client, self.db, method, taskhash)
+    async def backfill_worker_task(self):
+        client = await create_async_client(self.upstream)
+        try:
+            while True:
+                item = await self.backfill_queue.get()
+                if item is None:
                     self.backfill_queue.task_done()
-            finally:
-                await client.close()
+                    break
+                method, taskhash = item
+                await copy_unihash_from_upstream(client, self.db, method, taskhash)
+                self.backfill_queue.task_done()
+        finally:
+            await client.close()
 
-        async def join_worker(worker):
+    def start(self):
+        tasks = super().start()
+        if self.upstream:
+            self.backfill_queue = asyncio.Queue()
+            tasks += [self.backfill_worker_task()]
+        return tasks
+
+    async def stop(self):
+        if self.backfill_queue is not None:
             await self.backfill_queue.put(None)
-            await worker
-
-        if self.upstream is not None:
-            worker = asyncio.ensure_future(backfill_worker_task())
-            try:
-                yield
-            finally:
-                self.loop.run_until_complete(join_worker(worker))
-        else:
-            yield
-
-    def run_loop_forever(self):
-        self.backfill_queue = asyncio.Queue()
-
-        with self._backfill_worker():
-            super().run_loop_forever()
+        await super().stop()
diff --git a/lib/prserv/client.py b/lib/prserv/client.py
index 69ab7a4a..6b81356f 100644
--- a/lib/prserv/client.py
+++ b/lib/prserv/client.py
@@ -14,28 +14,28 @@  class PRAsyncClient(bb.asyncrpc.AsyncClient):
         super().__init__('PRSERVICE', '1.0', logger)
 
     async def getPR(self, version, pkgarch, checksum):
-        response = await self.send_message(
+        response = await self.invoke(
             {'get-pr': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum}}
         )
         if response:
             return response['value']
 
     async def importone(self, version, pkgarch, checksum, value):
-        response = await self.send_message(
+        response = await self.invoke(
             {'import-one': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'value': value}}
         )
         if response:
             return response['value']
 
     async def export(self, version, pkgarch, checksum, colinfo):
-        response = await self.send_message(
+        response = await self.invoke(
             {'export': {'version': version, 'pkgarch': pkgarch, 'checksum': checksum, 'colinfo': colinfo}}
         )
         if response:
             return (response['metainfo'], response['datainfo'])
 
     async def is_readonly(self):
-        response = await self.send_message(
+        response = await self.invoke(
             {'is-readonly': {}}
         )
         if response:
diff --git a/lib/prserv/serv.py b/lib/prserv/serv.py
index c686b206..ea793316 100644
--- a/lib/prserv/serv.py
+++ b/lib/prserv/serv.py
@@ -20,8 +20,8 @@  PIDPREFIX = "/tmp/PRServer_%s_%s.pid"
 singleton = None
 
 class PRServerClient(bb.asyncrpc.AsyncServerConnection):
-    def __init__(self, reader, writer, table, read_only):
-        super().__init__(reader, writer, 'PRSERVICE', logger)
+    def __init__(self, socket, table, read_only):
+        super().__init__(socket, 'PRSERVICE', logger)
         self.handlers.update({
             'get-pr': self.handle_get_pr,
             'import-one': self.handle_import_one,
@@ -36,12 +36,12 @@  class PRServerClient(bb.asyncrpc.AsyncServerConnection):
 
     async def dispatch_message(self, msg):
         try:
-            await super().dispatch_message(msg)
+            return await super().dispatch_message(msg)
         except:
             self.table.sync()
             raise
-
-        self.table.sync_if_dirty()
+        else:
+            self.table.sync_if_dirty()
 
     async def handle_get_pr(self, request):
         version = request['version']
@@ -57,7 +57,7 @@  class PRServerClient(bb.asyncrpc.AsyncServerConnection):
         except sqlite3.Error as exc:
             logger.error(str(exc))
 
-        self.write_message(response)
+        return response
 
     async def handle_import_one(self, request):
         response = None
@@ -71,7 +71,7 @@  class PRServerClient(bb.asyncrpc.AsyncServerConnection):
             if value is not None:
                 response = {'value': value}
 
-        self.write_message(response)
+        return response
 
     async def handle_export(self, request):
         version = request['version']
@@ -85,12 +85,10 @@  class PRServerClient(bb.asyncrpc.AsyncServerConnection):
             logger.error(str(exc))
             metainfo = datainfo = None
 
-        response = {'metainfo': metainfo, 'datainfo': datainfo}
-        self.write_message(response)
+        return {'metainfo': metainfo, 'datainfo': datainfo}
 
     async def handle_is_readonly(self, request):
-        response = {'readonly': self.read_only}
-        self.write_message(response)
+        return {'readonly': self.read_only}
 
 class PRServer(bb.asyncrpc.AsyncServer):
     def __init__(self, dbfile, read_only=False):
@@ -99,20 +97,23 @@  class PRServer(bb.asyncrpc.AsyncServer):
         self.table = None
         self.read_only = read_only
 
-    def accept_client(self, reader, writer):
-        return PRServerClient(reader, writer, self.table, self.read_only)
+    def accept_client(self, socket):
+        return PRServerClient(socket, self.table, self.read_only)
 
-    def _serve_forever(self):
+    def start(self):
+        tasks = super().start()
         self.db = prserv.db.PRData(self.dbfile, read_only=self.read_only)
         self.table = self.db["PRMAIN"]
 
         logger.info("Started PRServer with DBfile: %s, Address: %s, PID: %s" %
                      (self.dbfile, self.address, str(os.getpid())))
 
-        super()._serve_forever()
+        return tasks
 
+    async def stop(self):
         self.table.sync_if_dirty()
         self.db.disconnect()
+        await super().stop()
 
     def signal_handler(self):
         super().signal_handler()