Changeset View
Changeset View
Standalone View
Standalone View
swh/core/db/__init__.py
# Copyright (C) 2015-2019 The Software Heritage developers | # Copyright (C) 2015-2019 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import datetime | import datetime | ||||
import enum | import enum | ||||
import json | import json | ||||
import logging | import logging | ||||
import os | import os | ||||
import sys | import sys | ||||
import threading | import threading | ||||
from typing import Any, Callable, Iterable, Mapping, Optional | from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type, TypeVar | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
import psycopg2 | import psycopg2 | ||||
import psycopg2.extras | import psycopg2.extras | ||||
import psycopg2.pool | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
psycopg2.extras.register_uuid() | psycopg2.extras.register_uuid() | ||||
▲ Show 20 Lines • Show All 80 Lines • ▼ Show 20 Lines | |||||
def typecast_bytea(value, cur): | def typecast_bytea(value, cur): | ||||
if value is not None: | if value is not None: | ||||
data = psycopg2.BINARY(value, cur) | data = psycopg2.BINARY(value, cur) | ||||
return data.tobytes() | return data.tobytes() | ||||
BaseDbType = TypeVar("BaseDbType", bound="BaseDb") | |||||
class BaseDb: | class BaseDb: | ||||
"""Base class for swh.*.*Db. | """Base class for swh.*.*Db. | ||||
cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb | cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb | ||||
""" | """ | ||||
@classmethod | @staticmethod | ||||
def adapt_conn(cls, conn): | def adapt_conn(conn: psycopg2.extensions.connection): | ||||
"""Makes psycopg2 use 'bytes' to decode bytea instead of | """Makes psycopg2 use 'bytes' to decode bytea instead of | ||||
'memoryview', for this connection.""" | 'memoryview', for this connection.""" | ||||
t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) | t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) | ||||
psycopg2.extensions.register_type(t_bytes, conn) | psycopg2.extensions.register_type(t_bytes, conn) | ||||
t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) | t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) | ||||
psycopg2.extensions.register_type(t_bytes_array, conn) | psycopg2.extensions.register_type(t_bytes_array, conn) | ||||
@classmethod | @classmethod | ||||
def connect(cls, *args, **kwargs): | def connect(cls: Type[BaseDbType], *args, **kwargs) -> BaseDbType: | ||||
"""factory method to create a DB proxy | """factory method to create a DB proxy | ||||
Accepts all arguments of psycopg2.connect; only some specific | Accepts all arguments of psycopg2.connect; only some specific | ||||
possibilities are reported below. | possibilities are reported below. | ||||
Args: | Args: | ||||
connstring: libpq2 connection string | connstring: libpq2 connection string | ||||
""" | """ | ||||
conn = psycopg2.connect(*args, **kwargs) | conn = psycopg2.connect(*args, **kwargs) | ||||
return cls(conn) | return cls(conn) | ||||
@classmethod | @classmethod | ||||
def from_pool(cls, pool): | def from_pool( | ||||
cls: Type[BaseDbType], pool: psycopg2.pool.AbstractConnectionPool | |||||
) -> BaseDbType: | |||||
conn = pool.getconn() | conn = pool.getconn() | ||||
return cls(conn, pool=pool) | return cls(conn, pool=pool) | ||||
def __init__(self, conn, pool=None): | def __init__( | ||||
self, | |||||
conn: psycopg2.extensions.connection, | |||||
pool: Optional[psycopg2.pool.AbstractConnectionPool] = None, | |||||
): | |||||
"""create a DB proxy | """create a DB proxy | ||||
Args: | Args: | ||||
conn: psycopg2 connection to the SWH DB | conn: psycopg2 connection to the SWH DB | ||||
pool: psycopg2 pool of connections | pool: psycopg2 pool of connections | ||||
""" | """ | ||||
self.adapt_conn(conn) | self.adapt_conn(conn) | ||||
self.conn = conn | self.conn = conn | ||||
self.pool = pool | self.pool = pool | ||||
def put_conn(self): | def put_conn(self) -> None: | ||||
if self.pool: | if self.pool: | ||||
self.pool.putconn(self.conn) | self.pool.putconn(self.conn) | ||||
def cursor(self, cur_arg=None): | def cursor( | ||||
self, cur_arg: Optional[psycopg2.extensions.cursor] = None | |||||
) -> psycopg2.extensions.cursor: | |||||
"""get a cursor: from cur_arg if given, or a fresh one otherwise | """get a cursor: from cur_arg if given, or a fresh one otherwise | ||||
meant to avoid boilerplate if/then/else in methods that proxy stored | meant to avoid boilerplate if/then/else in methods that proxy stored | ||||
procedures | procedures | ||||
""" | """ | ||||
if cur_arg is not None: | if cur_arg is not None: | ||||
return cur_arg | return cur_arg | ||||
else: | else: | ||||
return self.conn.cursor() | return self.conn.cursor() | ||||
_cursor = cursor # for bw compat | _cursor = cursor # for bw compat | ||||
@contextmanager | @contextmanager | ||||
def transaction(self): | def transaction(self) -> Iterator[psycopg2.extensions.cursor]: | ||||
"""context manager to execute within a DB transaction | """context manager to execute within a DB transaction | ||||
Yields: | Yields: | ||||
a psycopg2 cursor | a psycopg2 cursor | ||||
""" | """ | ||||
with self.conn.cursor() as cur: | with self.conn.cursor() as cur: | ||||
try: | try: | ||||
▲ Show 20 Lines • Show All 87 Lines • ▼ Show 20 Lines | ) -> None: | ||||
# No problem bubbling up exceptions, but we still need to make sure | # No problem bubbling up exceptions, but we still need to make sure | ||||
# we finish copying, even though we're probably going to cancel the | # we finish copying, even though we're probably going to cancel the | ||||
# transaction. | # transaction. | ||||
write_thread.join() | write_thread.join() | ||||
if exc_info: | if exc_info: | ||||
# postgresql returned an error, let's raise it. | # postgresql returned an error, let's raise it. | ||||
raise exc_info[1].with_traceback(exc_info[2]) | raise exc_info[1].with_traceback(exc_info[2]) | ||||
def mktemp(self, tblname, cur=None): | def mktemp(self, tblname: str, cur: Optional[psycopg2.extensions.cursor] = None): | ||||
self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,)) | self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,)) |