Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/postgresql/provenance.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from __future__ import annotations | from __future__ import annotations | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from datetime import datetime | from datetime import datetime | ||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from types import TracebackType | from types import TracebackType | ||||
from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union | from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union | ||||
import psycopg2.extensions | import psycopg2.extensions | ||||
import psycopg2.extras | import psycopg2.extras | ||||
from typing_extensions import Literal | |||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.core.statsd import statsd | from swh.core.statsd import statsd | ||||
from swh.model.model import Sha1Git | from swh.model.model import Sha1Git | ||||
from ..interface import ( | from ..interface import ( | ||||
DirectoryData, | |||||
EntityType, | EntityType, | ||||
ProvenanceResult, | ProvenanceResult, | ||||
ProvenanceStorageInterface, | ProvenanceStorageInterface, | ||||
RelationData, | RelationData, | ||||
RelationType, | RelationType, | ||||
RevisionData, | RevisionData, | ||||
) | ) | ||||
▲ Show 20 Lines • Show All 45 Lines • ▼ Show 20 Lines | class ProvenanceStoragePostgreSql: | ||||
    def denormalized(self) -> bool:
        """Return True when the database flavor uses denormalized tables.

        Detection is purely lexical: the substring ``"denormalized"`` is
        looked up in ``self.flavor``.
        """
        return "denormalized" in self.flavor
    @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"})
    def close(self) -> None:
        """Close the underlying psycopg2 connection.

        After this call the storage object can no longer execute queries.
        """
        self.conn.close()
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) | ||||
def content_add( | def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: | ||||
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | try: | ||||
) -> bool: | if cnts: | ||||
return self._entity_set_date("content", cnts) | sql = """ | ||||
INSERT INTO content(sha1, date) VALUES %s | |||||
ON CONFLICT (sha1) DO | |||||
UPDATE SET date=LEAST(EXCLUDED.date,content.date) | |||||
""" | |||||
page_size = self.page_size or len(cnts) | |||||
with self.transaction() as cursor: | |||||
psycopg2.extras.execute_values( | |||||
cursor, sql, argslist=cnts.items(), page_size=page_size | |||||
) | |||||
return True | |||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) | ||||
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | ||||
sql = "SELECT * FROM swh_provenance_content_find_first(%s)" | sql = "SELECT * FROM swh_provenance_content_find_first(%s)" | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=(id,)) | cursor.execute(query=sql, vars=(id,)) | ||||
row = cursor.fetchone() | row = cursor.fetchone() | ||||
return ProvenanceResult(**row) if row is not None else None | return ProvenanceResult(**row) if row is not None else None | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) | ||||
def content_find_all( | def content_find_all( | ||||
self, id: Sha1Git, limit: Optional[int] = None | self, id: Sha1Git, limit: Optional[int] = None | ||||
) -> Generator[ProvenanceResult, None, None]: | ) -> Generator[ProvenanceResult, None, None]: | ||||
sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" | sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=(id, limit)) | cursor.execute(query=sql, vars=(id, limit)) | ||||
yield from (ProvenanceResult(**row) for row in cursor) | yield from (ProvenanceResult(**row) for row in cursor) | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) | ||||
def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | ||||
return self._entity_get_date("content", ids) | dates: Dict[Sha1Git, datetime] = {} | ||||
sha1s = tuple(ids) | |||||
if sha1s: | |||||
# TODO: consider splitting this query in several ones if sha1s is too big! | |||||
values = ", ".join(itertools.repeat("%s", len(sha1s))) | |||||
sql = f""" | |||||
SELECT sha1, date | |||||
FROM content | |||||
WHERE sha1 IN ({values}) | |||||
AND date IS NOT NULL | |||||
""" | |||||
with self.transaction(readonly=True) as cursor: | |||||
cursor.execute(query=sql, vars=sha1s) | |||||
dates.update((row["sha1"], row["date"]) for row in cursor) | |||||
return dates | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) | ||||
def directory_add( | def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: | ||||
self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | data = [(sha1, rev.date, rev.flat) for sha1, rev in dirs.items()] | ||||
) -> bool: | try: | ||||
return self._entity_set_date("directory", dirs) | if data: | ||||
sql = """ | |||||
INSERT INTO directory(sha1, date, flat) VALUES %s | |||||
ON CONFLICT (sha1) DO | |||||
UPDATE SET | |||||
date=LEAST(EXCLUDED.date, directory.date), | |||||
flat=(EXCLUDED.flat OR directory.flat) | |||||
""" | |||||
page_size = self.page_size or len(data) | |||||
with self.transaction() as cursor: | |||||
psycopg2.extras.execute_values( | |||||
cur=cursor, sql=sql, argslist=data, page_size=page_size | |||||
) | |||||
return True | |||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) | ||||
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: | ||||
return self._entity_get_date("directory", ids) | result: Dict[Sha1Git, DirectoryData] = {} | ||||
sha1s = tuple(ids) | |||||
if sha1s: | |||||
# TODO: consider splitting this query in several ones if sha1s is too big! | |||||
values = ", ".join(itertools.repeat("%s", len(sha1s))) | |||||
sql = f""" | |||||
SELECT sha1, date, flat | |||||
FROM directory | |||||
WHERE sha1 IN ({values}) | |||||
AND date IS NOT NULL | |||||
""" | |||||
with self.transaction(readonly=True) as cursor: | |||||
cursor.execute(query=sql, vars=sha1s) | |||||
result.update( | |||||
(row["sha1"], DirectoryData(date=row["date"], flat=row["flat"])) | |||||
for row in cursor | |||||
) | |||||
return result | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) | ||||
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: | def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(f"SELECT sha1 FROM {entity.value}") | cursor.execute(f"SELECT sha1 FROM {entity.value}") | ||||
return {row["sha1"] for row in cursor} | return {row["sha1"] for row in cursor} | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) | ||||
▲ Show 20 Lines • Show All 166 Lines • ▼ Show 20 Lines | ) -> Dict[Sha1Git, Set[RelationData]]: | ||||
return self._relation_get(relation, ids, reverse) | return self._relation_get(relation, ids, reverse) | ||||
    @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"})
    def relation_get_all(
        self, relation: RelationType
    ) -> Dict[Sha1Git, Set[RelationData]]:
        """Return every stored pair of ``relation``, keyed by source sha1.

        Passing ``ids=None`` asks the ``_relation_get`` helper for the
        whole table instead of a filtered subset.
        """
        return self._relation_get(relation, None)
def _entity_get_date( | |||||
self, | |||||
entity: Literal["content", "directory", "revision"], | |||||
ids: Iterable[Sha1Git], | |||||
) -> Dict[Sha1Git, datetime]: | |||||
dates: Dict[Sha1Git, datetime] = {} | |||||
sha1s = tuple(ids) | |||||
if sha1s: | |||||
# TODO: consider splitting this query in several ones if sha1s is too big! | |||||
values = ", ".join(itertools.repeat("%s", len(sha1s))) | |||||
sql = f""" | |||||
SELECT sha1, date | |||||
FROM {entity} | |||||
WHERE sha1 IN ({values}) | |||||
AND date IS NOT NULL | |||||
""" | |||||
with self.transaction(readonly=True) as cursor: | |||||
cursor.execute(query=sql, vars=sha1s) | |||||
dates.update((row["sha1"], row["date"]) for row in cursor) | |||||
return dates | |||||
def _entity_set_date( | |||||
self, | |||||
entity: Literal["content", "directory"], | |||||
dates: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]], | |||||
) -> bool: | |||||
data = dates if isinstance(dates, dict) else dict.fromkeys(dates) | |||||
try: | |||||
if data: | |||||
sql = f""" | |||||
INSERT INTO {entity}(sha1, date) VALUES %s | |||||
ON CONFLICT (sha1) DO | |||||
UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) | |||||
""" | |||||
page_size = self.page_size or len(data) | |||||
with self.transaction() as cursor: | |||||
psycopg2.extras.execute_values( | |||||
cursor, sql, argslist=data.items(), page_size=page_size | |||||
) | |||||
return True | |||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
def _relation_get( | def _relation_get( | ||||
self, | self, | ||||
relation: RelationType, | relation: RelationType, | ||||
ids: Optional[Iterable[Sha1Git]], | ids: Optional[Iterable[Sha1Git]], | ||||
reverse: bool = False, | reverse: bool = False, | ||||
) -> Dict[Sha1Git, Set[RelationData]]: | ) -> Dict[Sha1Git, Set[RelationData]]: | ||||
result: Dict[Sha1Git, Set[RelationData]] = {} | result: Dict[Sha1Git, Set[RelationData]] = {} | ||||
Show All 25 Lines |