diff mbox series

[1/1] hashserv: Unihash cache

Message ID 20231130081525.2537624-2-tobiasha@axis.com
State New
Headers show
Series hashserv unihash cache | expand

Commit Message

Tobias Hagelborn Nov. 30, 2023, 8:15 a.m. UTC
Cache unihashes to off-load reads of existing unihashes.
Due to non-reproducible builds, the output hash has to be considered.
Only return a unihash if the output hash matches.
The cache is least-recently-used (LRU) and bound in size.

Data from handle_report and query_equivalent are inserted in the
unihash cache. This caches unihashes from the database that have
not been written during the current session.

Stats have been added for hits, misses and size of the unihash cache.

Signed-off-by: Tobias Hagelborn <tobias.hagelborn@axis.com>
---
 lib/hashserv/server.py | 188 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 159 insertions(+), 29 deletions(-)

Comments

Joshua Watt Nov. 30, 2023, 12:58 p.m. UTC | #1
On Thu, Nov 30, 2023 at 1:15 AM Tobias Hagelborn
<tobias.hagelborn@axis.com> wrote:
>
> Cache unihashes to off-load reads of existing unihashes.
> Due to non reproducible builds, the output hash has to be considered.
> Only return a unihash if the output hash matches.
> The cache is least-recently-used (LRU) and bound in size.
>
> Data from handle_report and query_equivalent are inserted in the
> unihash cache. This caches unihashes from the database that have
> not been written during the current session.

Is the purpose of this to prevent needing to talk to the database?
Clients are generally doing client side caching, which means they
won't request unihashes they already know about in the first place.

The tricky part here is that we are trying to keep the server
stateless (stats excluded, but on the TODO list) so that multiple
instances can be run at the same time while talking to the same
backend SQL server for load-balancing and redundancy. Not to say that
caching can't work, but it still needs to be correct even if this
server doesn't see all of the RPC calls that manipulate the database.
If we are going to add state, I think we'll need to expand the test
suite to test the different API calls across multiple servers (TBH, we
should probably be doing this already, but since the server is
stateless it's a lot easier to prove it works).

I'll look this over and see if I can reason through its correctness
w.r.t. multiple servers (or if you've done that, please reply with your
assertions), but tests for that would help _substantially_ :)

>
> Stats have been added for hits, misses and size of the unihash cache.
>
> Signed-off-by: Tobias Hagelborn <tobias.hagelborn@axis.com>
> ---
>  lib/hashserv/server.py | 188 ++++++++++++++++++++++++++++++++++-------
>  1 file changed, 159 insertions(+), 29 deletions(-)
>
> diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
> index a8650783..3bfd4e2f 100644
> --- a/lib/hashserv/server.py
> +++ b/lib/hashserv/server.py
> @@ -4,6 +4,8 @@
>  #
>
>  from datetime import datetime, timedelta
> +from collections import OrderedDict
> +from collections.abc import MutableMapping
>  import asyncio
>  import logging
>  import math
> @@ -95,17 +97,34 @@ class Sample(object):
>
>
>  class Stats(object):
> -    def __init__(self):
> +
> +    named_stats = (
> +        'average',
> +        'equivs',
> +        'max_time',
> +        'num',
> +        'stdev',
> +        'total_time',
> +        'unihash_cache_hits',
> +        'unihash_cache_inserts',
> +        'unihash_cache_misses',
> +        'unihash_cache_size',
> +    )
> +
> +    def __init__(self, unihash_cache):
>          self.reset()
> +        self.unihash_cache = unihash_cache
>
>      def reset(self):
> -        self.num = 0
> -        self.total_time = 0
> -        self.max_time = 0
>          self.m = 0
>          self.s = 0
>          self.current_elapsed = None
>
> +        self.num = 0
> +        self.total_time = 0
> +        self.max_time = 0
> +        self.equivs = 0
> +
>      def add(self, elapsed):
>          self.num += 1
>          if self.num == 1:
> @@ -136,12 +155,24 @@ class Stats(object):
>              return 0
>          return math.sqrt(self.s / (self.num - 1))
>
> -    def todict(self):
> -        return {
> -            k: getattr(self, k)
> -            for k in ("num", "total_time", "max_time", "average", "stdev")
> -        }
> +    @property
> +    def unihash_cache_hits(self):
> +        return self.unihash_cache.stats_hits
>
> +    @property
> +    def unihash_cache_inserts(self):
> +        return self.unihash_cache.stats_inserts
> +
> +    @property
> +    def unihash_cache_misses(self):
> +        return self.unihash_cache.stats_misses
> +
> +    @property
> +    def unihash_cache_size(self):
> +        return len(self.unihash_cache)
> +
> +    def todict(self):
> +        return {k: getattr(self, k) for k in self.named_stats}
>
>  token_refresh_semaphore = asyncio.Lock()
>
> @@ -232,6 +263,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
>          upstream,
>          read_only,
>          anon_perms,
> +        unihash_cache
>      ):
>          super().__init__(socket, "OEHASHEQUIV", logger)
>          self.db_engine = db_engine
> @@ -242,6 +274,7 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
>          self.read_only = read_only
>          self.user = None
>          self.anon_perms = anon_perms
> +        self.unihash_cache = unihash_cache
>
>          self.handlers.update(
>              {
> @@ -413,20 +446,23 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
>
>                  (method, taskhash) = l.split()
>                  # self.logger.debug('Looking up %s %s' % (method, taskhash))
> -                row = await self.db.get_equivalent(method, taskhash)
> -
> -                if row is not None:
> -                    msg = row["unihash"]
> -                    # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
> -                elif self.upstream_client is not None:
> -                    upstream = await self.upstream_client.get_unihash(method, taskhash)
> -                    if upstream:
> -                        msg = upstream
> -                    else:
> -                        msg = ""
> -                else:
> -                    msg = ""
> -
> +                unihash = self.unihash_cache.get_hash(method,taskhash)
> +                if not unihash:
> +                    row = await self.db.get_equivalent(method, taskhash)
> +
> +                    if row is not None:
> +                        unihash = row['unihash']
> +                        # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
> +                        self.request_stats.equivs+=1
> +                    elif self.upstream_client is not None:
> +                        upstream = await self.upstream_client.get_unihash(method, taskhash)
> +                        if upstream:
> +                            unihash = upstream
> +                # Cache the found item in the read cache
> +                msg = ""
> +                if unihash:
> +                    self.unihash_cache.insert_hash(method, taskhash, unihash, outhash=None)
> +                    msg = unihash
>                  await self.socket.send(msg)
>              finally:
>                  request_measure.end()
> @@ -461,6 +497,16 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
>      # report is made inside the function
>      @permissions(READ_PERM)
>      async def handle_report(self, data):
> +
> +        unihash = self.unihash_cache.get_hash(data['method'],data['taskhash'],data['outhash'])
> +        if unihash:
> +            d = {
> +                'taskhash': data['taskhash'],
> +                'method': data['method'],
> +                'unihash': unihash,
> +            }
> +            return d
> +
>          if self.read_only or not self.user_has_permissions(REPORT_PERM):
>              return await self.report_readonly(data)
>
> @@ -511,11 +557,13 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
>          else:
>              unihash = data["unihash"]
>
> -        return {
> -            "taskhash": data["taskhash"],
> -            "method": data["method"],
> -            "unihash": unihash,
> -        }
> +        d = {
> +                'taskhash': data['taskhash'],
> +                'method': data['method'],
> +                'unihash': unihash,
> +            }
> +        self.unihash_cache.insert_hash(d['method'], d['taskhash'], unihash, data['outhash'])
> +        return d
>
>      @permissions(READ_PERM, REPORT_PERM)
>      async def handle_equivreport(self, data):
> @@ -738,6 +786,85 @@ class ServerClient(bb.asyncrpc.AsyncServerConnection):
>              "permissions": self.return_perms(self.user.permissions),
>          }
>
> +# LRU Cache Dict based on work by Alex Martelli and martineau used under CC BY 4.0
> +# https://stackoverflow.com/a/2438926
> +class LRUCache(MutableMapping):
> +    def __init__(self, maxlen, items=None):
> +        self._maxlen = maxlen
> +        self.d = OrderedDict()
> +        if items:
> +            for k, v in items:
> +                self[k] = v
> +
> +    @property
> +    def maxlen(self):
> +        return self._maxlen
> +
> +    def __getitem__(self, key):
> +        self.d.move_to_end(key)
> +        return self.d[key]
> +
> +    def __setitem__(self, key, value):
> +        if key in self.d:
> +            self.d.move_to_end(key)
> +        elif len(self.d) == self.maxlen:
> +            self.d.popitem(last=False)
> +        self.d[key] = value
> +
> +    def __delitem__(self, key):
> +        del self.d[key]
> +
> +    def __iter__(self):
> +        return self.d.__iter__()
> +
> +    def __len__(self):
> +        return len(self.d)
> +
> +
> +class UnihashCache():
> +    """
> +    Size limited LRU cache (dict) of taskhash->(unihash,output-hash)
> +    if output-hash is provided, take it into account when matching,
> +    otherwise only map task-hash to unihash.
> +    """
> +
> +    def __init__(self, maxlen=0x20000):
> +        self.hash_cache = LRUCache(maxlen)
> +        self.stats_hits = 0
> +        self.stats_inserts = 0
> +        self.stats_misses = 0
> +
> +    def get_hash(self, method, taskhash, outhash=None):
> +        method_hash = hash(method)
> +        taskhash_hash = hash(taskhash)
> +        cache_entry = self.hash_cache.get((method_hash,taskhash_hash))
> +        result = None
> +        if cache_entry:
> +            if not outhash:
> +                result = cache_entry[0]
> +            else:
> +                outhash_hash = hash(outhash)
> +                if outhash_hash == cache_entry[1]:
> +                    result = cache_entry[0]
> +                else:
> +                    result = None
> +        if result:
> +            self.stats_hits += 1
> +        else:
> +            self.stats_misses += 1
> +        return result
> +
> +    def insert_hash(self, method, taskhash, unihash, outhash=None):
> +        method_hash = hash(method)
> +        taskhash_hash = hash(taskhash)
> +        outhash_hash = hash(outhash) if outhash else None
> +        cache_key=(method_hash,taskhash_hash)
> +        if not self.hash_cache.get(cache_key):
> +            self.hash_cache[cache_key] = (unihash, outhash_hash)
> +            self.stats_inserts += 1
> +
> +    def __len__(self):
> +        return len(self.hash_cache)
>
>  class Server(bb.asyncrpc.AsyncServer):
>      def __init__(
> @@ -765,11 +892,13 @@ class Server(bb.asyncrpc.AsyncServer):
>
>          super().__init__(logger)
>
> -        self.request_stats = Stats()
>          self.db_engine = db_engine
>          self.upstream = upstream
>          self.read_only = read_only
>          self.backfill_queue = None
> +        self.unihash_cache = UnihashCache()
> +        self.request_stats = Stats(self.unihash_cache)
> +
>          self.anon_perms = set(anon_perms)
>          self.admin_username = admin_username
>          self.admin_password = admin_password
> @@ -787,6 +916,7 @@ class Server(bb.asyncrpc.AsyncServer):
>              self.upstream,
>              self.read_only,
>              self.anon_perms,
> +            self.unihash_cache,
>          )
>
>      async def create_admin_user(self):
> --
> 2.30.2
>
>
> -=-=-=-=-=-=-=-=-=-=-=-
> Links: You receive all messages sent to this group.
> View/Reply Online (#191480): https://lists.openembedded.org/g/openembedded-core/message/191480
> Mute This Topic: https://lists.openembedded.org/mt/102890273/3616693
> Group Owner: openembedded-core+owner@lists.openembedded.org
> Unsubscribe: https://lists.openembedded.org/g/openembedded-core/unsub [JPEWhacker@gmail.com]
> -=-=-=-=-=-=-=-=-=-=-=-
>
diff mbox series

Patch

diff --git a/lib/hashserv/server.py b/lib/hashserv/server.py
index a8650783..3bfd4e2f 100644
--- a/lib/hashserv/server.py
+++ b/lib/hashserv/server.py
@@ -4,6 +4,8 @@ 
 #
 
 from datetime import datetime, timedelta
+from collections import OrderedDict
+from collections.abc import MutableMapping
 import asyncio
 import logging
 import math
@@ -95,17 +97,34 @@  class Sample(object):
 
 
 class Stats(object):
-    def __init__(self):
+
+    named_stats = (
+        'average',
+        'equivs',
+        'max_time',
+        'num',
+        'stdev',
+        'total_time',
+        'unihash_cache_hits',
+        'unihash_cache_inserts',
+        'unihash_cache_misses',
+        'unihash_cache_size',
+    )
+
+    def __init__(self, unihash_cache):
         self.reset()
+        self.unihash_cache = unihash_cache
 
     def reset(self):
-        self.num = 0
-        self.total_time = 0
-        self.max_time = 0
         self.m = 0
         self.s = 0
         self.current_elapsed = None
 
+        self.num = 0
+        self.total_time = 0
+        self.max_time = 0
+        self.equivs = 0
+
     def add(self, elapsed):
         self.num += 1
         if self.num == 1:
@@ -136,12 +155,24 @@  class Stats(object):
             return 0
         return math.sqrt(self.s / (self.num - 1))
 
-    def todict(self):
-        return {
-            k: getattr(self, k)
-            for k in ("num", "total_time", "max_time", "average", "stdev")
-        }
+    @property
+    def unihash_cache_hits(self):
+        return self.unihash_cache.stats_hits
 
+    @property
+    def unihash_cache_inserts(self):
+        return self.unihash_cache.stats_inserts
+
+    @property
+    def unihash_cache_misses(self):
+        return self.unihash_cache.stats_misses
+
+    @property
+    def unihash_cache_size(self):
+        return len(self.unihash_cache)
+
+    def todict(self):
+        return {k: getattr(self, k) for k in self.named_stats}
 
 token_refresh_semaphore = asyncio.Lock()
 
@@ -232,6 +263,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
         upstream,
         read_only,
         anon_perms,
+        unihash_cache
     ):
         super().__init__(socket, "OEHASHEQUIV", logger)
         self.db_engine = db_engine
@@ -242,6 +274,7 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
         self.read_only = read_only
         self.user = None
         self.anon_perms = anon_perms
+        self.unihash_cache = unihash_cache
 
         self.handlers.update(
             {
@@ -413,20 +446,23 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
 
                 (method, taskhash) = l.split()
                 # self.logger.debug('Looking up %s %s' % (method, taskhash))
-                row = await self.db.get_equivalent(method, taskhash)
-
-                if row is not None:
-                    msg = row["unihash"]
-                    # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
-                elif self.upstream_client is not None:
-                    upstream = await self.upstream_client.get_unihash(method, taskhash)
-                    if upstream:
-                        msg = upstream
-                    else:
-                        msg = ""
-                else:
-                    msg = ""
-
+                unihash = self.unihash_cache.get_hash(method,taskhash)
+                if not unihash:
+                    row = await self.db.get_equivalent(method, taskhash)
+
+                    if row is not None:
+                        unihash = row['unihash']
+                        # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
+                        self.request_stats.equivs+=1
+                    elif self.upstream_client is not None:
+                        upstream = await self.upstream_client.get_unihash(method, taskhash)
+                        if upstream:
+                            unihash = upstream
+                # Cache the found item in the read cache
+                msg = ""
+                if unihash:
+                    self.unihash_cache.insert_hash(method, taskhash, unihash, outhash=None)
+                    msg = unihash
                 await self.socket.send(msg)
             finally:
                 request_measure.end()
@@ -461,6 +497,16 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
     # report is made inside the function
     @permissions(READ_PERM)
     async def handle_report(self, data):
+
+        unihash = self.unihash_cache.get_hash(data['method'],data['taskhash'],data['outhash'])
+        if unihash:
+            d = {
+                'taskhash': data['taskhash'],
+                'method': data['method'],
+                'unihash': unihash,
+            }
+            return d
+
         if self.read_only or not self.user_has_permissions(REPORT_PERM):
             return await self.report_readonly(data)
 
@@ -511,11 +557,13 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
         else:
             unihash = data["unihash"]
 
-        return {
-            "taskhash": data["taskhash"],
-            "method": data["method"],
-            "unihash": unihash,
-        }
+        d = {
+                'taskhash': data['taskhash'],
+                'method': data['method'],
+                'unihash': unihash,
+            }
+        self.unihash_cache.insert_hash(d['method'], d['taskhash'], unihash, data['outhash'])
+        return d
 
     @permissions(READ_PERM, REPORT_PERM)
     async def handle_equivreport(self, data):
@@ -738,6 +786,85 @@  class ServerClient(bb.asyncrpc.AsyncServerConnection):
             "permissions": self.return_perms(self.user.permissions),
         }
 
+# LRU Cache Dict based on work by Alex Martelli and martineau used under CC BY 4.0
+# https://stackoverflow.com/a/2438926
+class LRUCache(MutableMapping):
+    def __init__(self, maxlen, items=None):
+        self._maxlen = maxlen
+        self.d = OrderedDict()
+        if items:
+            for k, v in items:
+                self[k] = v
+
+    @property
+    def maxlen(self):
+        return self._maxlen
+
+    def __getitem__(self, key):
+        self.d.move_to_end(key)
+        return self.d[key]
+
+    def __setitem__(self, key, value):
+        if key in self.d:
+            self.d.move_to_end(key)
+        elif len(self.d) == self.maxlen:
+            self.d.popitem(last=False)
+        self.d[key] = value
+
+    def __delitem__(self, key):
+        del self.d[key]
+
+    def __iter__(self):
+        return self.d.__iter__()
+
+    def __len__(self):
+        return len(self.d)
+
+
+class UnihashCache():
+    """
+    Size limited LRU cache (dict) of taskhash->(unihash,output-hash)
+    if output-hash is provided, take it into account when matching,
+    otherwise only map task-hash to unihash.
+    """
+
+    def __init__(self, maxlen=0x20000):
+        self.hash_cache = LRUCache(maxlen)
+        self.stats_hits = 0
+        self.stats_inserts = 0
+        self.stats_misses = 0
+
+    def get_hash(self, method, taskhash, outhash=None):
+        method_hash = hash(method)
+        taskhash_hash = hash(taskhash)
+        cache_entry = self.hash_cache.get((method_hash,taskhash_hash))
+        result = None
+        if cache_entry:
+            if not outhash:
+                result = cache_entry[0]
+            else:
+                outhash_hash = hash(outhash)
+                if outhash_hash == cache_entry[1]:
+                    result = cache_entry[0]
+                else:
+                    result = None
+        if result:
+            self.stats_hits += 1
+        else:
+            self.stats_misses += 1
+        return result
+
+    def insert_hash(self, method, taskhash, unihash, outhash=None):
+        method_hash = hash(method)
+        taskhash_hash = hash(taskhash)
+        outhash_hash = hash(outhash) if outhash else None
+        cache_key=(method_hash,taskhash_hash)
+        if not self.hash_cache.get(cache_key):
+            self.hash_cache[cache_key] = (unihash, outhash_hash)
+            self.stats_inserts += 1
+
+    def __len__(self):
+        return len(self.hash_cache)
 
 class Server(bb.asyncrpc.AsyncServer):
     def __init__(
@@ -765,11 +892,13 @@  class Server(bb.asyncrpc.AsyncServer):
 
         super().__init__(logger)
 
-        self.request_stats = Stats()
         self.db_engine = db_engine
         self.upstream = upstream
         self.read_only = read_only
         self.backfill_queue = None
+        self.unihash_cache = UnihashCache()
+        self.request_stats = Stats(self.unihash_cache)
+
         self.anon_perms = set(anon_perms)
         self.admin_username = admin_username
         self.admin_password = admin_password
@@ -787,6 +916,7 @@  class Server(bb.asyncrpc.AsyncServer):
             self.upstream,
             self.read_only,
             self.anon_perms,
+            self.unihash_cache,
         )
 
     async def create_admin_user(self):