diff mbox series

[bitbake-devel,1/5] hashserv: sqlalchemy: Use _execute() helper

Message ID 20240218200743.2982923-2-JPEWhacker@gmail.com
State New
Headers show
Series Implement parallel Query API | expand

Commit Message

Joshua Watt Feb. 18, 2024, 8:07 p.m. UTC
Use the _execute() helper to execute queries. This helper does the
logging of the statement that was being done manually everywhere.

Signed-off-by: Joshua Watt <JPEWhacker@gmail.com>
---
 bitbake/lib/hashserv/sqlalchemy.py | 297 ++++++++++++++---------------
 1 file changed, 140 insertions(+), 157 deletions(-)
diff mbox series

Patch

diff --git a/bitbake/lib/hashserv/sqlalchemy.py b/bitbake/lib/hashserv/sqlalchemy.py
index 89a6b86d9d8..873547809a0 100644
--- a/bitbake/lib/hashserv/sqlalchemy.py
+++ b/bitbake/lib/hashserv/sqlalchemy.py
@@ -233,124 +233,113 @@  class Database(object):
         return row.value
 
     async def get_unihash_by_taskhash_full(self, method, taskhash):
-        statement = (
-            select(
-                OuthashesV2,
-                UnihashesV3.unihash.label("unihash"),
-            )
-            .join(
-                UnihashesV3,
-                and_(
-                    UnihashesV3.method == OuthashesV2.method,
-                    UnihashesV3.taskhash == OuthashesV2.taskhash,
-                ),
-            )
-            .where(
-                OuthashesV2.method == method,
-                OuthashesV2.taskhash == taskhash,
-            )
-            .order_by(
-                OuthashesV2.created.asc(),
-            )
-            .limit(1)
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(
+                    OuthashesV2,
+                    UnihashesV3.unihash.label("unihash"),
+                )
+                .join(
+                    UnihashesV3,
+                    and_(
+                        UnihashesV3.method == OuthashesV2.method,
+                        UnihashesV3.taskhash == OuthashesV2.taskhash,
+                    ),
+                )
+                .where(
+                    OuthashesV2.method == method,
+                    OuthashesV2.taskhash == taskhash,
+                )
+                .order_by(
+                    OuthashesV2.created.asc(),
+                )
+                .limit(1)
+            )
             return map_row(result.first())
 
     async def get_unihash_by_outhash(self, method, outhash):
-        statement = (
-            select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
-            .join(
-                UnihashesV3,
-                and_(
-                    UnihashesV3.method == OuthashesV2.method,
-                    UnihashesV3.taskhash == OuthashesV2.taskhash,
-                ),
-            )
-            .where(
-                OuthashesV2.method == method,
-                OuthashesV2.outhash == outhash,
-            )
-            .order_by(
-                OuthashesV2.created.asc(),
-            )
-            .limit(1)
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
+                .join(
+                    UnihashesV3,
+                    and_(
+                        UnihashesV3.method == OuthashesV2.method,
+                        UnihashesV3.taskhash == OuthashesV2.taskhash,
+                    ),
+                )
+                .where(
+                    OuthashesV2.method == method,
+                    OuthashesV2.outhash == outhash,
+                )
+                .order_by(
+                    OuthashesV2.created.asc(),
+                )
+                .limit(1)
+            )
             return map_row(result.first())
 
     async def get_outhash(self, method, outhash):
-        statement = (
-            select(OuthashesV2)
-            .where(
-                OuthashesV2.method == method,
-                OuthashesV2.outhash == outhash,
-            )
-            .order_by(
-                OuthashesV2.created.asc(),
-            )
-            .limit(1)
-        )
-
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(OuthashesV2)
+                .where(
+                    OuthashesV2.method == method,
+                    OuthashesV2.outhash == outhash,
+                )
+                .order_by(
+                    OuthashesV2.created.asc(),
+                )
+                .limit(1)
+            )
             return map_row(result.first())
 
     async def get_equivalent_for_outhash(self, method, outhash, taskhash):
-        statement = (
-            select(
-                OuthashesV2.taskhash.label("taskhash"),
-                UnihashesV3.unihash.label("unihash"),
-            )
-            .join(
-                UnihashesV3,
-                and_(
-                    UnihashesV3.method == OuthashesV2.method,
-                    UnihashesV3.taskhash == OuthashesV2.taskhash,
-                ),
-            )
-            .where(
-                OuthashesV2.method == method,
-                OuthashesV2.outhash == outhash,
-                OuthashesV2.taskhash != taskhash,
-            )
-            .order_by(
-                OuthashesV2.created.asc(),
-            )
-            .limit(1)
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(
+                    OuthashesV2.taskhash.label("taskhash"),
+                    UnihashesV3.unihash.label("unihash"),
+                )
+                .join(
+                    UnihashesV3,
+                    and_(
+                        UnihashesV3.method == OuthashesV2.method,
+                        UnihashesV3.taskhash == OuthashesV2.taskhash,
+                    ),
+                )
+                .where(
+                    OuthashesV2.method == method,
+                    OuthashesV2.outhash == outhash,
+                    OuthashesV2.taskhash != taskhash,
+                )
+                .order_by(
+                    OuthashesV2.created.asc(),
+                )
+                .limit(1)
+            )
             return map_row(result.first())
 
     async def get_equivalent(self, method, taskhash):
-        statement = select(
-            UnihashesV3.unihash,
-            UnihashesV3.method,
-            UnihashesV3.taskhash,
-        ).where(
-            UnihashesV3.method == method,
-            UnihashesV3.taskhash == taskhash,
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(
+                    UnihashesV3.unihash,
+                    UnihashesV3.method,
+                    UnihashesV3.taskhash,
+                ).where(
+                    UnihashesV3.method == method,
+                    UnihashesV3.taskhash == taskhash,
+                )
+            )
             return map_row(result.first())
 
     async def remove(self, condition):
         async def do_remove(table):
             where = _make_condition_statement(table, condition)
             if where:
-                statement = delete(table).where(*where)
-                self.logger.debug("%s", statement)
                 async with self.db.begin():
-                    result = await self.db.execute(statement)
+                    result = await self._execute(delete(table).where(*where))
                 return result.rowcount
 
             return 0
@@ -417,21 +406,21 @@  class Database(object):
             return result.rowcount
 
     async def clean_unused(self, oldest):
-        statement = delete(OuthashesV2).where(
-            OuthashesV2.created < oldest,
-            ~(
-                select(UnihashesV3.id)
-                .where(
-                    UnihashesV3.method == OuthashesV2.method,
-                    UnihashesV3.taskhash == OuthashesV2.taskhash,
-                )
-                .limit(1)
-                .exists()
-            ),
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                delete(OuthashesV2).where(
+                    OuthashesV2.created < oldest,
+                    ~(
+                        select(UnihashesV3.id)
+                        .where(
+                            UnihashesV3.method == OuthashesV2.method,
+                            UnihashesV3.taskhash == OuthashesV2.taskhash,
+                        )
+                        .limit(1)
+                        .exists()
+                    ),
+                )
+            )
             return result.rowcount
 
     async def insert_unihash(self, method, taskhash, unihash):
@@ -461,11 +450,9 @@  class Database(object):
         if "created" in data and not isinstance(data["created"], datetime):
             data["created"] = datetime.fromisoformat(data["created"])
 
-        statement = insert(OuthashesV2).values(**data)
-        self.logger.debug("%s", statement)
         try:
             async with self.db.begin():
-                await self.db.execute(statement)
+                await self._execute(insert(OuthashesV2).values(**data))
             return True
         except IntegrityError:
             self.logger.debug(
@@ -474,16 +461,16 @@  class Database(object):
             return False
 
     async def _get_user(self, username):
-        statement = select(
-            Users.username,
-            Users.permissions,
-            Users.token,
-        ).where(
-            Users.username == username,
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(
+                    Users.username,
+                    Users.permissions,
+                    Users.token,
+                ).where(
+                    Users.username == username,
+                )
+            )
             return result.first()
 
     async def lookup_user_token(self, username):
@@ -496,70 +483,66 @@  class Database(object):
         return map_user(await self._get_user(username))
 
     async def set_user_token(self, username, token):
-        statement = (
-            update(Users)
-            .where(
-                Users.username == username,
-            )
-            .values(
-                token=token,
-            )
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                update(Users)
+                .where(
+                    Users.username == username,
+                )
+                .values(
+                    token=token,
+                )
+            )
             return result.rowcount != 0
 
     async def set_user_perms(self, username, permissions):
-        statement = (
-            update(Users)
-            .where(Users.username == username)
-            .values(permissions=" ".join(permissions))
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                update(Users)
+                .where(Users.username == username)
+                .values(permissions=" ".join(permissions))
+            )
             return result.rowcount != 0
 
     async def get_all_users(self):
-        statement = select(
-            Users.username,
-            Users.permissions,
-        )
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                select(
+                    Users.username,
+                    Users.permissions,
+                )
+            )
             return [map_user(row) for row in result]
 
     async def new_user(self, username, permissions, token):
-        statement = insert(Users).values(
-            username=username,
-            permissions=" ".join(permissions),
-            token=token,
-        )
-        self.logger.debug("%s", statement)
         try:
             async with self.db.begin():
-                await self.db.execute(statement)
+                await self._execute(
+                    insert(Users).values(
+                        username=username,
+                        permissions=" ".join(permissions),
+                        token=token,
+                    )
+                )
             return True
         except IntegrityError as e:
             self.logger.debug("Cannot create new user %s: %s", username, e)
             return False
 
     async def delete_user(self, username):
-        statement = delete(Users).where(Users.username == username)
-        self.logger.debug("%s", statement)
         async with self.db.begin():
-            result = await self.db.execute(statement)
+            result = await self._execute(
+                delete(Users).where(Users.username == username)
+            )
             return result.rowcount != 0
 
     async def get_usage(self):
         usage = {}
         async with self.db.begin() as session:
             for name, table in Base.metadata.tables.items():
-                statement = select(func.count()).select_from(table)
-                self.logger.debug("%s", statement)
-                result = await self.db.execute(statement)
+                result = await self._execute(
+                    statement=select(func.count()).select_from(table)
+                )
                 usage[name] = {
                     "rows": result.scalar(),
                 }