diff mbox series

[bitbake-devel,1/2] asyncrpc: Add Client Pool object

Message ID 20240129194208.4096506-1-JPEWhacker@gmail.com
State New
Headers show
Series [bitbake-devel,1/2] asyncrpc: Add Client Pool object | expand

Commit Message

Joshua Watt Jan. 29, 2024, 7:42 p.m. UTC
Adds an abstract base class that can be used to implement a pool of
client connections. The class implements a thread that runs an async
event loop, and allows derived classes to schedule work on the loop and
wait for the work to be finished.

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
 bitbake/lib/bb/asyncrpc/__init__.py |  2 +-
 bitbake/lib/bb/asyncrpc/client.py   | 77 +++++++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)
diff mbox series

Patch

diff --git a/bitbake/lib/bb/asyncrpc/__init__.py b/bitbake/lib/bb/asyncrpc/__init__.py
index a4371643d74..639e1607f8e 100644
--- a/bitbake/lib/bb/asyncrpc/__init__.py
+++ b/bitbake/lib/bb/asyncrpc/__init__.py
@@ -5,7 +5,7 @@ 
 #
 
 
-from .client import AsyncClient, Client
+from .client import AsyncClient, Client, ClientPool
 from .serv import AsyncServer, AsyncServerConnection
 from .connection import DEFAULT_MAX_CHUNK
 from .exceptions import (
diff --git a/bitbake/lib/bb/asyncrpc/client.py b/bitbake/lib/bb/asyncrpc/client.py
index 0d7cd85780d..a6228bb0ba0 100644
--- a/bitbake/lib/bb/asyncrpc/client.py
+++ b/bitbake/lib/bb/asyncrpc/client.py
@@ -10,6 +10,8 @@  import json
 import os
 import socket
 import sys
+import contextlib
+from threading import Thread
 from .connection import StreamConnection, WebsocketConnection, DEFAULT_MAX_CHUNK
 from .exceptions import ConnectionClosedError, InvokeError
 
@@ -180,3 +182,78 @@  class Client(object):
     def __exit__(self, exc_type, exc_value, traceback):
         self.close()
         return False
+
+
+class ClientPool(object):
+    def __init__(self, max_clients):
+        self.avail_clients = []
+        self.num_clients = 0
+        self.max_clients = max_clients
+        self.loop = None
+        self.client_condition = None
+
+    @abc.abstractmethod
+    async def _new_client(self):
+        raise NotImplementedError("Must be implemented in derived class")
+
+    def close(self):
+        if self.client_condition:
+            self.client_condition = None
+
+        if self.loop:
+            self.loop.run_until_complete(self.__close_clients())
+            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
+            self.loop.close()
+            self.loop = None
+
+    def run_tasks(self, tasks):
+        if not self.loop:
+            self.loop = asyncio.new_event_loop()
+
+        thread = Thread(target=self.__thread_main, args=(tasks,))
+        thread.start()
+        thread.join()
+
+    @contextlib.asynccontextmanager
+    async def get_client(self):
+        async with self.client_condition:
+            if self.avail_clients:
+                client = self.avail_clients.pop()
+            elif self.num_clients < self.max_clients:
+                self.num_clients += 1
+                client = await self._new_client()
+            else:
+                while not self.avail_clients:
+                    await self.client_condition.wait()
+                client = self.avail_clients.pop()
+
+        try:
+            yield client
+        finally:
+            async with self.client_condition:
+                self.avail_clients.append(client)
+                self.client_condition.notify()
+
+    def __thread_main(self, tasks):
+        async def process_task(task):
+            async with self.get_client() as client:
+                await task(client)
+
+        asyncio.set_event_loop(self.loop)
+        if not self.client_condition:
+            self.client_condition = asyncio.Condition()
+        tasks = [process_task(t) for t in tasks]
+        self.loop.run_until_complete(asyncio.gather(*tasks))
+
+    async def __close_clients(self):
+        for c in self.avail_clients:
+            await c.close()
+        self.avail_clients = []
+        self.num_clients = 0
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+        return False