Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/postgresql/provenance.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from __future__ import annotations | |||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from datetime import datetime | from datetime import datetime | ||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from typing import Dict, Generator, Iterable, List, Optional, Set, Union | from types import TracebackType | ||||
from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union | |||||
import psycopg2.extensions | import psycopg2.extensions | ||||
import psycopg2.extras | import psycopg2.extras | ||||
from typing_extensions import Literal | from typing_extensions import Literal | ||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.model.model import Sha1Git | from swh.model.model import Sha1Git | ||||
from ..interface import ( | from ..interface import ( | ||||
EntityType, | EntityType, | ||||
ProvenanceResult, | ProvenanceResult, | ||||
ProvenanceStorageInterface, | |||||
RelationData, | RelationData, | ||||
RelationType, | RelationType, | ||||
RevisionData, | RevisionData, | ||||
) | ) | ||||
LOGGER = logging.getLogger(__name__) | LOGGER = logging.getLogger(__name__) | ||||
class ProvenanceStoragePostgreSql: | class ProvenanceStoragePostgreSql: | ||||
def __init__( | def __init__(self, raise_on_commit: bool = False, **kwargs) -> None: | ||||
self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False | self.conn_args = kwargs | ||||
) -> None: | |||||
BaseDb.adapt_conn(conn) | |||||
self.conn = conn | |||||
with self.transaction() as cursor: | |||||
cursor.execute("SET timezone TO 'UTC'") | |||||
self._flavor: Optional[str] = None | self._flavor: Optional[str] = None | ||||
self.raise_on_commit = raise_on_commit | self.raise_on_commit = raise_on_commit | ||||
def __enter__(self) -> ProvenanceStorageInterface: | |||||
self.open() | |||||
return self | |||||
def __exit__( | |||||
self, | |||||
exc_type: Optional[Type[BaseException]], | |||||
exc_val: Optional[BaseException], | |||||
exc_tb: Optional[TracebackType], | |||||
) -> None: | |||||
self.close() | |||||
@contextmanager | @contextmanager | ||||
def transaction( | def transaction( | ||||
self, readonly: bool = False | self, readonly: bool = False | ||||
) -> Generator[psycopg2.extensions.cursor, None, None]: | ) -> Generator[psycopg2.extensions.cursor, None, None]: | ||||
self.conn.set_session(readonly=readonly) | self.conn.set_session(readonly=readonly) | ||||
with self.conn: | with self.conn: | ||||
with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: | with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: | ||||
yield cur | yield cur | ||||
@property | @property | ||||
def flavor(self) -> str: | def flavor(self) -> str: | ||||
if self._flavor is None: | if self._flavor is None: | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute("SELECT swh_get_dbflavor() AS flavor") | cursor.execute("SELECT swh_get_dbflavor() AS flavor") | ||||
self._flavor = cursor.fetchone()["flavor"] | self._flavor = cursor.fetchone()["flavor"] | ||||
assert self._flavor is not None | assert self._flavor is not None | ||||
return self._flavor | return self._flavor | ||||
@property | @property | ||||
def denormalized(self) -> bool: | def denormalized(self) -> bool: | ||||
return "denormalized" in self.flavor | return "denormalized" in self.flavor | ||||
def close(self) -> None: | |||||
self.conn.close() | |||||
def content_add( | def content_add( | ||||
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | ||||
) -> bool: | ) -> bool: | ||||
return self._entity_set_date("content", cnts) | return self._entity_set_date("content", cnts) | ||||
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | ||||
sql = "SELECT * FROM swh_provenance_content_find_first(%s)" | sql = "SELECT * FROM swh_provenance_content_find_first(%s)" | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
▲ Show 20 Lines • Show All 64 Lines • ▼ Show 20 Lines | def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: | ||||
return True | return True | ||||
except: # noqa: E722 | except: # noqa: E722 | ||||
# Unexpected error occurred, rollback all changes and log message | # Unexpected error occurred, rollback all changes and log message | ||||
LOGGER.exception("Unexpected error") | LOGGER.exception("Unexpected error") | ||||
if self.raise_on_commit: | if self.raise_on_commit: | ||||
raise | raise | ||||
return False | return False | ||||
def open(self) -> None: | |||||
self.conn = BaseDb.connect(**self.conn_args).conn | |||||
BaseDb.adapt_conn(self.conn) | |||||
with self.transaction() as cursor: | |||||
cursor.execute("SET timezone TO 'UTC'") | |||||
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: | def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: | ||||
urls: Dict[Sha1Git, str] = {} | urls: Dict[Sha1Git, str] = {} | ||||
sha1s = tuple(ids) | sha1s = tuple(ids) | ||||
if sha1s: | if sha1s: | ||||
# TODO: consider splitting this query in several ones if sha1s is too big! | # TODO: consider splitting this query in several ones if sha1s is too big! | ||||
values = ", ".join(itertools.repeat("%s", len(sha1s))) | values = ", ".join(itertools.repeat("%s", len(sha1s))) | ||||
sql = f""" | sql = f""" | ||||
SELECT sha1, url | SELECT sha1, url | ||||
▲ Show 20 Lines • Show All 172 Lines • Show Last 20 Lines |