Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/postgresql/provenance.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from __future__ import annotations | from __future__ import annotations | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from datetime import datetime | from datetime import datetime | ||||
from functools import wraps | |||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from types import TracebackType | from types import TracebackType | ||||
from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union | from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union | ||||
import psycopg2.extensions | import psycopg2.extensions | ||||
import psycopg2.extras | import psycopg2.extras | ||||
Show All 11 Lines | from ..interface import ( | ||||
RevisionData, | RevisionData, | ||||
) | ) | ||||
LOGGER = logging.getLogger(__name__) | LOGGER = logging.getLogger(__name__) | ||||
STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds" | STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds" | ||||
def handle_raise_on_commit(f):
    """Decorator for ``ProvenanceStoragePostgreSql`` write methods.

    Wraps the decorated method in a try/except so that any unexpected
    exception is logged; the exception is re-raised when the storage was
    configured with ``raise_on_commit=True``, otherwise ``False`` is
    returned so callers can detect that the write did not commit.
    """

    @wraps(f)
    def handle(self, *args, **kwargs):
        try:
            return f(self, *args, **kwargs)
        except BaseException as ex:
            # Unexpected error occurred, rollback all changes and log message
            LOGGER.exception("Unexpected error")
            if self.raise_on_commit:
                raise ex
            return False

    return handle
class ProvenanceStoragePostgreSql: | class ProvenanceStoragePostgreSql: | ||||
def __init__( | def __init__( | ||||
self, page_size: Optional[int] = None, raise_on_commit: bool = False, **kwargs | self, page_size: Optional[int] = None, raise_on_commit: bool = False, **kwargs | ||||
) -> None: | ) -> None: | ||||
self.conn_args = kwargs | self.conn_args = kwargs | ||||
self._flavor: Optional[str] = None | self._flavor: Optional[str] = None | ||||
self.page_size = page_size | self.page_size = page_size | ||||
self.raise_on_commit = raise_on_commit | self.raise_on_commit = raise_on_commit | ||||
Show All 32 Lines | class ProvenanceStoragePostgreSql: | ||||
def denormalized(self) -> bool: | def denormalized(self) -> bool: | ||||
return "denormalized" in self.flavor | return "denormalized" in self.flavor | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) | ||||
def close(self) -> None: | def close(self) -> None: | ||||
self.conn.close() | self.conn.close() | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) | ||||
@handle_raise_on_commit | |||||
def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: | def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: | ||||
try: | |||||
if cnts: | if cnts: | ||||
sql = """ | sql = """ | ||||
INSERT INTO content(sha1, date) VALUES %s | INSERT INTO content(sha1, date) VALUES %s | ||||
ON CONFLICT (sha1) DO | ON CONFLICT (sha1) DO | ||||
UPDATE SET date=LEAST(EXCLUDED.date,content.date) | UPDATE SET date=LEAST(EXCLUDED.date,content.date) | ||||
""" | """ | ||||
page_size = self.page_size or len(cnts) | page_size = self.page_size or len(cnts) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cursor, sql, argslist=cnts.items(), page_size=page_size | cursor, sql, argslist=cnts.items(), page_size=page_size | ||||
) | ) | ||||
return True | return True | ||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) | ||||
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | ||||
sql = "SELECT * FROM swh_provenance_content_find_first(%s)" | sql = "SELECT * FROM swh_provenance_content_find_first(%s)" | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=(id,)) | cursor.execute(query=sql, vars=(id,)) | ||||
row = cursor.fetchone() | row = cursor.fetchone() | ||||
return ProvenanceResult(**row) if row is not None else None | return ProvenanceResult(**row) if row is not None else None | ||||
Show All 21 Lines | def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | ||||
AND date IS NOT NULL | AND date IS NOT NULL | ||||
""" | """ | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=sha1s) | cursor.execute(query=sql, vars=sha1s) | ||||
dates.update((row["sha1"], row["date"]) for row in cursor) | dates.update((row["sha1"], row["date"]) for row in cursor) | ||||
return dates | return dates | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) | ||||
@handle_raise_on_commit | |||||
def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: | def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: | ||||
data = [(sha1, rev.date, rev.flat) for sha1, rev in dirs.items()] | data = [(sha1, rev.date, rev.flat) for sha1, rev in dirs.items()] | ||||
try: | |||||
if data: | if data: | ||||
sql = """ | sql = """ | ||||
INSERT INTO directory(sha1, date, flat) VALUES %s | INSERT INTO directory(sha1, date, flat) VALUES %s | ||||
ON CONFLICT (sha1) DO | ON CONFLICT (sha1) DO | ||||
UPDATE SET | UPDATE SET | ||||
date=LEAST(EXCLUDED.date, directory.date), | date=LEAST(EXCLUDED.date, directory.date), | ||||
flat=(EXCLUDED.flat OR directory.flat) | flat=(EXCLUDED.flat OR directory.flat) | ||||
""" | """ | ||||
page_size = self.page_size or len(data) | page_size = self.page_size or len(data) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cur=cursor, sql=sql, argslist=data, page_size=page_size | cur=cursor, sql=sql, argslist=data, page_size=page_size | ||||
) | ) | ||||
return True | return True | ||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) | ||||
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: | def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: | ||||
result: Dict[Sha1Git, DirectoryData] = {} | result: Dict[Sha1Git, DirectoryData] = {} | ||||
sha1s = tuple(ids) | sha1s = tuple(ids) | ||||
if sha1s: | if sha1s: | ||||
# TODO: consider splitting this query in several ones if sha1s is too big! | # TODO: consider splitting this query in several ones if sha1s is too big! | ||||
values = ", ".join(itertools.repeat("%s", len(sha1s))) | values = ", ".join(itertools.repeat("%s", len(sha1s))) | ||||
Show All 13 Lines | class ProvenanceStoragePostgreSql: | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) | ||||
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: | def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(f"SELECT sha1 FROM {entity.value}") | cursor.execute(f"SELECT sha1 FROM {entity.value}") | ||||
return {row["sha1"] for row in cursor} | return {row["sha1"] for row in cursor} | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) | ||||
@handle_raise_on_commit | |||||
def location_add(self, paths: Iterable[bytes]) -> bool: | def location_add(self, paths: Iterable[bytes]) -> bool: | ||||
if not self.with_path(): | if self.with_path(): | ||||
return True | |||||
try: | |||||
values = [(path,) for path in paths] | values = [(path,) for path in paths] | ||||
if values: | if values: | ||||
sql = """ | sql = """ | ||||
INSERT INTO location(path) VALUES %s | INSERT INTO location(path) VALUES %s | ||||
ON CONFLICT DO NOTHING | ON CONFLICT DO NOTHING | ||||
""" | """ | ||||
page_size = self.page_size or len(values) | page_size = self.page_size or len(values) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cursor, sql, argslist=values, page_size=page_size | cursor, sql, argslist=values, page_size=page_size | ||||
) | ) | ||||
return True | return True | ||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) | ||||
def location_get_all(self) -> Set[bytes]: | def location_get_all(self) -> Set[bytes]: | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute("SELECT location.path AS path FROM location") | cursor.execute("SELECT location.path AS path FROM location") | ||||
return {row["path"] for row in cursor} | return {row["path"] for row in cursor} | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) | ||||
@handle_raise_on_commit | |||||
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: | def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: | ||||
try: | |||||
if orgs: | if orgs: | ||||
sql = """ | sql = """ | ||||
INSERT INTO origin(sha1, url) VALUES %s | INSERT INTO origin(sha1, url) VALUES %s | ||||
ON CONFLICT DO NOTHING | ON CONFLICT DO NOTHING | ||||
""" | """ | ||||
page_size = self.page_size or len(orgs) | page_size = self.page_size or len(orgs) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cur=cursor, | cur=cursor, | ||||
sql=sql, | sql=sql, | ||||
argslist=orgs.items(), | argslist=orgs.items(), | ||||
page_size=page_size, | page_size=page_size, | ||||
) | ) | ||||
return True | return True | ||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) | ||||
def open(self) -> None: | def open(self) -> None: | ||||
self.conn = BaseDb.connect(**self.conn_args).conn | self.conn = BaseDb.connect(**self.conn_args).conn | ||||
BaseDb.adapt_conn(self.conn) | BaseDb.adapt_conn(self.conn) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
cursor.execute("SET timezone TO 'UTC'") | cursor.execute("SET timezone TO 'UTC'") | ||||
Show All 10 Lines | def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: | ||||
WHERE sha1 IN ({values}) | WHERE sha1 IN ({values}) | ||||
""" | """ | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=sha1s) | cursor.execute(query=sql, vars=sha1s) | ||||
urls.update((row["sha1"], row["url"]) for row in cursor) | urls.update((row["sha1"], row["url"]) for row in cursor) | ||||
return urls | return urls | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) | ||||
@handle_raise_on_commit | |||||
def revision_add( | def revision_add( | ||||
self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] | self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] | ||||
) -> bool: | ) -> bool: | ||||
if isinstance(revs, dict): | if isinstance(revs, dict): | ||||
data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] | data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] | ||||
else: | else: | ||||
data = [(sha1, None, None) for sha1 in revs] | data = [(sha1, None, None) for sha1 in revs] | ||||
try: | |||||
if data: | if data: | ||||
sql = """ | sql = """ | ||||
INSERT INTO revision(sha1, date, origin) | INSERT INTO revision(sha1, date, origin) | ||||
(SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin | (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin | ||||
FROM (VALUES %s) AS V(rev, date, org) | FROM (VALUES %s) AS V(rev, date, org) | ||||
LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) | LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) | ||||
ON CONFLICT (sha1) DO | ON CONFLICT (sha1) DO | ||||
UPDATE SET | UPDATE SET | ||||
date=LEAST(EXCLUDED.date, revision.date), | date=LEAST(EXCLUDED.date, revision.date), | ||||
origin=COALESCE(EXCLUDED.origin, revision.origin) | origin=COALESCE(EXCLUDED.origin, revision.origin) | ||||
""" | """ | ||||
page_size = self.page_size or len(data) | page_size = self.page_size or len(data) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cur=cursor, sql=sql, argslist=data, page_size=page_size | cur=cursor, sql=sql, argslist=data, page_size=page_size | ||||
) | ) | ||||
return True | return True | ||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) | ||||
def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: | def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: | ||||
result: Dict[Sha1Git, RevisionData] = {} | result: Dict[Sha1Git, RevisionData] = {} | ||||
sha1s = tuple(ids) | sha1s = tuple(ids) | ||||
if sha1s: | if sha1s: | ||||
# TODO: consider splitting this query in several ones if sha1s is too big! | # TODO: consider splitting this query in several ones if sha1s is too big! | ||||
values = ", ".join(itertools.repeat("%s", len(sha1s))) | values = ", ".join(itertools.repeat("%s", len(sha1s))) | ||||
sql = f""" | sql = f""" | ||||
SELECT R.sha1, R.date, O.sha1 AS origin | SELECT R.sha1, R.date, O.sha1 AS origin | ||||
FROM revision AS R | FROM revision AS R | ||||
LEFT JOIN origin AS O ON (O.id=R.origin) | LEFT JOIN origin AS O ON (O.id=R.origin) | ||||
WHERE R.sha1 IN ({values}) | WHERE R.sha1 IN ({values}) | ||||
AND (R.date is not NULL OR O.sha1 is not NULL) | AND (R.date is not NULL OR O.sha1 is not NULL) | ||||
""" | """ | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=sha1s) | cursor.execute(query=sql, vars=sha1s) | ||||
result.update( | result.update( | ||||
(row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) | (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) | ||||
for row in cursor | for row in cursor | ||||
) | ) | ||||
return result | return result | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) | ||||
@handle_raise_on_commit | |||||
def relation_add( | def relation_add( | ||||
self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] | self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] | ||||
) -> bool: | ) -> bool: | ||||
rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] | rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] | ||||
try: | |||||
if rows: | if rows: | ||||
rel_table = relation.value | rel_table = relation.value | ||||
src_table, *_, dst_table = rel_table.split("_") | src_table, *_, dst_table = rel_table.split("_") | ||||
page_size = self.page_size or len(rows) | page_size = self.page_size or len(rows) | ||||
# Put the next three queries in a manual single transaction: | # Put the next three queries in a manual single transaction: | ||||
# they use the same temp table | # they use the same temp table | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
cursor.execute("SELECT swh_mktemp_relation_add()") | cursor.execute("SELECT swh_mktemp_relation_add()") | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cur=cursor, | cur=cursor, | ||||
sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", | sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", | ||||
argslist=rows, | argslist=rows, | ||||
page_size=page_size, | page_size=page_size, | ||||
) | ) | ||||
sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" | sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" | ||||
cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) | cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) | ||||
return True | return True | ||||
except: # noqa: E722 | |||||
# Unexpected error occurred, rollback all changes and log message | |||||
LOGGER.exception("Unexpected error") | |||||
if self.raise_on_commit: | |||||
raise | |||||
return False | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) | ||||
def relation_get( | def relation_get( | ||||
self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False | self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False | ||||
) -> Dict[Sha1Git, Set[RelationData]]: | ) -> Dict[Sha1Git, Set[RelationData]]: | ||||
return self._relation_get(relation, ids, reverse) | return self._relation_get(relation, ids, reverse) | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) | ||||
Show All 38 Lines |