Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/postgresql/provenancedb_base.py
from datetime import datetime | from datetime import datetime | ||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from typing import Dict, Generator, Iterable, Optional, Set, Tuple | from typing import Dict, Generator, Iterable, Optional, Set, Tuple | ||||
import psycopg2 | import psycopg2 | ||||
import psycopg2.extras | import psycopg2.extras | ||||
from typing_extensions import Literal | from typing_extensions import Literal | ||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.model.model import Sha1Git | from swh.model.model import Sha1Git | ||||
from ..provenance import ProvenanceResult, RelationType | from ..provenance import EntityType, ProvenanceResult, RelationType | ||||
class ProvenanceDBBase: | class ProvenanceDBBase: | ||||
raise_on_commit: bool = False | raise_on_commit: bool = False | ||||
def __init__(self, conn: psycopg2.extensions.connection): | def __init__(self, conn: psycopg2.extensions.connection): | ||||
BaseDb.adapt_conn(conn) | BaseDb.adapt_conn(conn) | ||||
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) | conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) | ||||
Show All 33 Lines | def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | ||||
return self._entity_get_date("content", ids) | return self._entity_get_date("content", ids) | ||||
def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
    """Store dates for directories.

    Delegates to ``_entity_set_date`` with the ``"directory"`` entity name.

    Args:
        dates: mapping from directory identifier to the date to record.

    Returns:
        Whatever ``_entity_set_date`` returns for the ``"directory"`` entity.
    """
    return self._entity_set_date("directory", dates)
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
    """Look up the dates recorded for the given directories.

    Delegates to ``_entity_get_date`` with the ``"directory"`` entity name.

    Args:
        ids: identifiers of the directories to look up.

    Returns:
        Whatever ``_entity_get_date`` returns for the ``"directory"`` entity.
    """
    return self._entity_get_date("directory", ids)
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
    """Return the ``sha1`` of every row in the table backing ``entity``.

    Args:
        entity: entity kind; its ``value`` is used verbatim as the table name.

    Returns:
        The set of ``sha1`` column values fetched from that table.
    """
    # The table name comes from the enum value, not from caller-supplied text,
    # so it is interpolated directly (identifiers cannot be bound parameters).
    query = f"SELECT sha1 FROM {entity.value}"
    self.cursor.execute(query)
    rows = self.cursor.fetchall()
    return {row["sha1"] for row in rows}
def location_get(self) -> Set[bytes]:
    """Return every path stored in the ``location`` table.

    Returns:
        The set of ``path`` values produced by the query, one per distinct
        value.
    """
    # NOTE(review): PostgreSQL's encode(bytea, 'escape') yields text, so the
    # driver may hand back str rather than the annotated bytes -- confirm
    # against the connection's type adapters.
    query = "SELECT encode(location.path::bytea, 'escape') AS path FROM location"
    self.cursor.execute(query)
    rows = self.cursor.fetchall()
    return {row["path"] for row in rows}
def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: | def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: | ||||
try: | try: | ||||
if urls: | if urls: | ||||
sql = """ | sql = """ | ||||
LOCK TABLE ONLY origin; | LOCK TABLE ONLY origin; | ||||
INSERT INTO origin(sha1, url) VALUES %s | INSERT INTO origin(sha1, url) VALUES %s | ||||
ON CONFLICT DO NOTHING | ON CONFLICT DO NOTHING | ||||
""" | """ | ||||
▲ Show 20 Lines • Show All 126 Lines • ▼ Show 20 Lines | ) -> bool: | ||||
logging.exception("Unexpected error") | logging.exception("Unexpected error") | ||||
if self.raise_on_commit: | if self.raise_on_commit: | ||||
raise | raise | ||||
return False | return False | ||||
def relation_get(
    self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
    """Return the rows of ``relation`` that involve the given identifiers.

    Thin wrapper around ``_relation_get``.

    Args:
        relation: relation (table) to query.
        ids: identifiers to filter on.
        reverse: forwarded unchanged to ``_relation_get``, where it selects
            which side of the relation ``ids`` is matched against.

    Returns:
        A set of ``(src, dst, path)`` tuples as produced by ``_relation_get``.
    """
    return self._relation_get(relation, ids, reverse)


def relation_get_all(
    self, relation: RelationType
) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
    """Return every row of ``relation``.

    Calls ``_relation_get`` with ``None`` in place of the identifiers, which
    disables filtering.

    Args:
        relation: relation (table) to dump.

    Returns:
        A set of ``(src, dst, path)`` tuples as produced by ``_relation_get``.
    """
    return self._relation_get(relation, None)
def _entity_get_date( | def _entity_get_date( | ||||
self, | self, | ||||
entity: Literal["content", "directory", "revision"], | entity: Literal["content", "directory", "revision"], | ||||
ids: Iterable[Sha1Git], | ids: Iterable[Sha1Git], | ||||
) -> Dict[Sha1Git, datetime]: | ) -> Dict[Sha1Git, datetime]: | ||||
dates: Dict[Sha1Git, datetime] = {} | dates: Dict[Sha1Git, datetime] = {} | ||||
sha1s = tuple(ids) | sha1s = tuple(ids) | ||||
Show All 25 Lines | ) -> bool: | ||||
return True | return True | ||||
except: # noqa: E722 | except: # noqa: E722 | ||||
# Unexpected error occurred, rollback all changes and log message | # Unexpected error occurred, rollback all changes and log message | ||||
logging.exception("Unexpected error") | logging.exception("Unexpected error") | ||||
if self.raise_on_commit: | if self.raise_on_commit: | ||||
raise | raise | ||||
return False | return False | ||||
def _relation_get( | |||||
self, | |||||
relation: RelationType, | |||||
ids: Optional[Iterable[Sha1Git]], | |||||
reverse: bool = False, | |||||
) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: | |||||
result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set() | |||||
sha1s: Optional[Tuple[Tuple[bytes, ...]]] | |||||
if ids is not None: | |||||
sha1s = (tuple(ids),) | |||||
where = f"WHERE {'S.sha1' if not reverse else 'D.sha1'} IN %s" | |||||
else: | |||||
sha1s = None | |||||
where = "" | |||||
if sha1s is None or sha1s[0]: | |||||
table = relation.value | |||||
src, *_, dst = table.split("_") | |||||
# TODO: improve this! | |||||
if src == "revision" and dst == "revision": | |||||
src_field = "prev" | |||||
dst_field = "next" | |||||
else: | |||||
src_field = src | |||||
dst_field = dst | |||||
joins = [ | |||||
f"INNER JOIN {src} AS S ON (S.id=R.{src_field})", | |||||
f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})", | |||||
] | |||||
selected = ["S.sha1 AS src", "D.sha1 AS dst"] | |||||
if self._relation_uses_location_table(relation): | |||||
joins.append("INNER JOIN location AS L ON (L.id=R.location)") | |||||
selected.append("L.path AS path") | |||||
else: | |||||
selected.append("NULL AS path") | |||||
sql = f""" | |||||
SELECT {", ".join(selected)} | |||||
FROM {table} AS R | |||||
{" ".join(joins)} | |||||
{where} | |||||
""" | |||||
self.cursor.execute(sql, sha1s) | |||||
result.update( | |||||
(row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall() | |||||
) | |||||
return result | |||||
def _relation_uses_location_table(self, relation: RelationType) -> bool: | def _relation_uses_location_table(self, relation: RelationType) -> bool: | ||||
... | ... |