diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -10,12 +10,13 @@ import os import sys import threading -from typing import Any, Callable, Iterable, Mapping, Optional +from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type, TypeVar from contextlib import contextmanager import psycopg2 import psycopg2.extras +import psycopg2.pool logger = logging.getLogger(__name__) @@ -112,6 +113,9 @@ return data.tobytes() +BaseDbType = TypeVar("BaseDbType", bound="BaseDb") + + class BaseDb: """Base class for swh.*.*Db. @@ -119,8 +123,8 @@ """ - @classmethod - def adapt_conn(cls, conn): + @staticmethod + def adapt_conn(conn: psycopg2.extensions.connection): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) @@ -130,7 +134,7 @@ psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod - def connect(cls, *args, **kwargs): + def connect(cls: Type[BaseDbType], *args, **kwargs) -> BaseDbType: """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific @@ -144,11 +148,17 @@ return cls(conn) @classmethod - def from_pool(cls, pool): + def from_pool( + cls: Type[BaseDbType], pool: psycopg2.pool.AbstractConnectionPool + ) -> BaseDbType: conn = pool.getconn() return cls(conn, pool=pool) - def __init__(self, conn, pool=None): + def __init__( + self, + conn: psycopg2.extensions.connection, + pool: Optional[psycopg2.pool.AbstractConnectionPool] = None, + ): """create a DB proxy Args: @@ -160,11 +170,13 @@ self.conn = conn self.pool = pool - def put_conn(self): + def put_conn(self) -> None: if self.pool: self.pool.putconn(self.conn) - def cursor(self, cur_arg=None): + def cursor( + self, cur_arg: Optional[psycopg2.extensions.cursor] = None + ) -> psycopg2.extensions.cursor: """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored @@ -179,7 +191,7 @@ _cursor = cursor # for bw compat @contextmanager - def transaction(self): + def transaction(self) -> Iterator[psycopg2.extensions.cursor]: """context manager to execute within a DB transaction Yields: @@ -283,5 +295,5 @@ # postgresql returned an error, let's raise it. raise exc_info[1].with_traceback(exc_info[2]) - def mktemp(self, tblname, cur=None): + def mktemp(self, tblname: str, cur: Optional[psycopg2.extensions.cursor] = None): self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,))