diff --git a/requirements-swh.txt b/requirements-swh.txt index 1865990..a6fe211 100644 --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,6 +1,6 @@ # Add here internal Software Heritage dependencies, one per line. -swh.core[db,http] >= 0.14 +swh.core[db,http] >= 2 swh.model >= 2.6.1 swh.storage swh.graph >= 2.0.0 swh.journal diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py index 99163ba..bca3cf0 100644 --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,121 +1,124 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from typing import TYPE_CHECKING import warnings if TYPE_CHECKING: from .archive import ArchiveInterface from .interface import ProvenanceInterface, ProvenanceStorageInterface def get_archive(cls: str, **kwargs) -> ArchiveInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: cls: archive's class, either 'api', 'direct' or 'graph' args: dictionary of arguments passed to the archive class constructor Returns: an instance of archive object (either using swh.storage API or direct queries to the archive's database) Raises: :cls:`ValueError` if passed an unknown archive class. """ if cls == "api": from swh.storage import get_storage from .storage.archive import ArchiveStorage return ArchiveStorage(get_storage(**kwargs["storage"])) elif cls == "direct": from swh.core.db import BaseDb from .postgresql.archive import ArchivePostgreSQL return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) elif cls == "graph": try: from swh.storage import get_storage from .swhgraph.archive import ArchiveGraph return ArchiveGraph(kwargs.get("url"), get_storage(**kwargs["storage"])) except ModuleNotFoundError: raise EnvironmentError( "Graph configuration required but module is not installed." ) elif cls == "multiplexer": from .multiplexer.archive import ArchiveMultiplexed archives = [] for ctr, archive in enumerate(kwargs["archives"]): name = archive.pop("name", f"backend_{ctr}") archives.append((name, get_archive(**archive))) return ArchiveMultiplexed(archives) else: raise ValueError def get_provenance(**kwargs) -> ProvenanceInterface: """Get an provenance object with arguments ``args``. Args: args: dictionary of arguments to retrieve a swh.provenance.storage class (see :func:`get_provenance_storage` for details) Returns: an instance of provenance object """ from .provenance import Provenance return Provenance(get_provenance_storage(**kwargs)) def get_provenance_storage(cls: str, **kwargs) -> ProvenanceStorageInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: cls: storage's class, only 'local' is currently supported args: dictionary of arguments passed to the storage class constructor Returns: an instance of storage object Raises: :cls:`ValueError` if passed an unknown archive class. 
""" if cls in ["local", "postgresql"]: from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql if cls == "local": warnings.warn( '"local" class is deprecated for provenance storage, please ' 'use "postgresql" class instead.', DeprecationWarning, ) raise_on_commit = kwargs.get("raise_on_commit", False) return ProvenanceStoragePostgreSql( raise_on_commit=raise_on_commit, **kwargs["db"] ) elif cls == "rabbitmq": from .api.client import ProvenanceStorageRabbitMQClient rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs) if TYPE_CHECKING: assert isinstance(rmq_storage, ProvenanceStorageInterface) return rmq_storage raise ValueError + + +get_datastore = get_provenance_storage diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index 79ea282..f5471be 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,401 +1,403 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from contextlib import contextmanager from datetime import datetime from functools import wraps import itertools import logging from types import TracebackType from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union import psycopg2.extensions import psycopg2.extras from swh.core.db import BaseDb from swh.core.statsd import statsd from swh.model.model import Sha1Git from ..interface import ( DirectoryData, EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds" def handle_raise_on_commit(f): @wraps(f) def handle(self, *args, **kwargs): try: return f(self, *args, **kwargs) except BaseException as ex: # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise ex return False return handle class ProvenanceStoragePostgreSql: + current_version = 3 + def __init__( self, page_size: Optional[int] = None, raise_on_commit: bool = False, **kwargs ) -> None: self.conn: Optional[psycopg2.extensions.connection] = None self.conn_args = kwargs self._flavor: Optional[str] = None self.page_size = page_size self.raise_on_commit = raise_on_commit def __enter__(self) -> ProvenanceStorageInterface: self.open() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.close() @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extras.RealDictCursor, None, None]: if self.conn is None: raise RuntimeError( "Tried to access ProvenanceStoragePostgreSQL transaction() without opening it" ) self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") flavor = cursor.fetchone() assert flavor # please mypy self._flavor = flavor["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor 
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) def close(self) -> None: assert self.conn is not None self.conn.close() @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) @handle_raise_on_commit def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: if cnts: sql = """ INSERT INTO content(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,content.date) """ page_size = self.page_size or len(cnts) with self.transaction() as cursor: psycopg2.extras.execute_values( cursor, sql, argslist=cnts.items(), page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM content WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) @handle_raise_on_commit def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: data = [(sha1, rev.date, rev.flat) for sha1, rev in dirs.items()] if data: sql = """ INSERT INTO directory(sha1, date, flat) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, directory.date), flat=(EXCLUDED.flat OR directory.flat) """ page_size = self.page_size or len(data) with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=data, page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: result: Dict[Sha1Git, DirectoryData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date, flat FROM directory WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], DirectoryData(date=row["date"], flat=row["flat"])) for row in cursor ) return result @statsd.timed( metric=STORAGE_DURATION_METRIC, tags={"method": "directory_iter_not_flattenned"} ) def directory_iter_not_flattenned( self, limit: int, start_id: Sha1Git ) -> List[Sha1Git]: sql = """ SELECT sha1 FROM directory WHERE flat=false AND sha1>%s ORDER BY sha1 LIMIT %s """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(start_id, limit)) return [row["sha1"] for row in cursor] @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) @handle_raise_on_commit def location_add(self, paths: Iterable[bytes]) -> bool: if self.with_path(): values = [(path,) for path in paths] if values: sql = """ INSERT INTO location(path) VALUES %s ON CONFLICT DO NOTHING """ page_size = self.page_size or len(values) with self.transaction() as cursor: psycopg2.extras.execute_values( cursor, sql, argslist=values, page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) def location_get_all(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {row["path"] for row in cursor} @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) @handle_raise_on_commit def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ page_size = self.page_size or len(orgs) with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=orgs.items(), page_size=page_size, ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) def open(self) -> None: self.conn = BaseDb.connect(**self.conn_args).conn BaseDb.adapt_conn(self.conn) with self.transaction() as cursor: cursor.execute("SET timezone TO 'UTC'") @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"}) def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) @handle_raise_on_commit def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: if isinstance(revs, dict): data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] else: data = [(sha1, None, None) for sha1 in revs] if data: sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin FROM (VALUES %s) AS V(rev, date, org) LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, revision.date), origin=COALESCE(EXCLUDED.origin, revision.origin) """ page_size = self.page_size or len(data) with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=data, page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) @handle_raise_on_commit def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") page_size = self.page_size or len(rows) # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", argslist=rows, page_size=page_size, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Dict[Sha1Git, Set[RelationData]]: result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: 
sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) for row in cursor: src = row.pop("src") result.setdefault(src, set()).add(RelationData(**row)) return result @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "with_path"}) def with_path(self) -> bool: return "with-path" in self.flavor diff --git a/swh/provenance/sql/30-schema.sql b/swh/provenance/sql/30-schema.sql index 038a2de..73551d7 100644 --- a/swh/provenance/sql/30-schema.sql +++ b/swh/provenance/sql/30-schema.sql @@ -1,173 +1,157 @@ -- psql variables to get the current database flavor select position('denormalized' in swh_get_dbflavor()::text) = 0 as dbflavor_norm \gset select position('without-path' in swh_get_dbflavor()::text) = 0 as dbflavor_with_path \gset -create table dbversion -( - version int primary key, - release timestamptz, - description text -); - -comment on table dbversion is 'Details of current db version'; -comment on column dbversion.version is 'SQL schema version'; -comment on column dbversion.release is 'Version deployment timestamp'; -comment on column dbversion.description is 'Release description'; - --- latest schema version -insert into dbversion(version, release, description) - values(3, now(), 'Work In Progress'); - -- a Git object ID, i.e., a Git-style salted SHA1 checksum create domain sha1_git as bytea check (length(value) = 20); -- UNIX path (absolute, relative, individual path component, etc.) create domain unix_path as bytea; -- relation filter options for querying create type rel_flt as enum ( 'filter-src', 'filter-dst', 'no-filter' ); comment on type rel_flt is 'Relation get filter types'; -- entity tables create table content ( id bigserial primary key, -- internal identifier of the content blob sha1 sha1_git unique not null, -- intrinsic identifier of the content blob date timestamptz not null -- timestamp of the revision where the blob appears early ); comment on column content.id is 'Content internal identifier'; comment on column content.sha1 is 'Content intrinsic identifier'; comment on column content.date is 'Earliest timestamp for the content (first seen time)'; create table directory ( id bigserial primary key, -- internal identifier of the directory appearing in an isochrone inner frontier sha1 sha1_git unique not null, -- intrinsic identifier of the directory date timestamptz not null, -- max timestamp among those of the directory children's flat boolean not null default false -- flag acknowledging if the directory is flattenned in the model ); comment on column directory.id is 'Directory internal identifier'; comment on column directory.sha1 is 'Directory intrinsic identifier'; comment on column directory.date is 'Latest timestamp for the content in the directory'; create table revision ( id bigserial primary key, -- internal identifier of the revision sha1 sha1_git unique not null, -- intrinsic identifier of the revision date timestamptz, -- timestamp of the revision origin bigint -- id of the preferred origin -- foreign key (origin) references origin (id) ); comment on column revision.id is 'Revision internal identifier'; comment on column revision.sha1 is 'Revision intrinsic identifier'; comment on column revision.date is 'Revision timestamp'; comment on column revision.origin is 
'preferred origin for the revision'; create table location ( id bigserial primary key, -- internal identifier of the location path unix_path -- path to the location ); comment on column location.id is 'Location internal identifier'; comment on column location.path is 'Path to the location'; create table origin ( id bigserial primary key, -- internal identifier of the origin sha1 sha1_git unique not null, -- intrinsic identifier of the origin url text -- url of the origin ); comment on column origin.id is 'Origin internal identifier'; comment on column origin.sha1 is 'Origin intrinsic identifier'; comment on column origin.url is 'URL of the origin'; -- relation tables create table content_in_revision ( content bigint not null, -- internal identifier of the content blob \if :dbflavor_norm revision bigint not null, -- internal identifier of the revision where the blob appears for the first time location bigint -- location of the content relative to the revision's root directory \else revision bigint[], -- internal identifiers of the revisions where the blob appears for the first time location bigint[] -- locations of the content relative to the revisions' root directory \endif -- foreign key (content) references content (id), -- foreign key (revision) references revision (id), -- foreign key (location) references location (id) ); comment on column content_in_revision.content is 'Content internal identifier'; \if :dbflavor_norm comment on column content_in_revision.revision is 'Revision internal identifier'; comment on column content_in_revision.location is 'Location of content in revision'; \else comment on column content_in_revision.revision is 'Revision/location internal identifiers'; \endif create table content_in_directory ( content bigint not null, -- internal identifier of the content blob \if :dbflavor_norm directory bigint not null, -- internal identifier of the directory containing the blob location bigint -- location of the content relative to its parent directory in the isochrone frontier \else directory bigint[], -- internal reference of the directories containing the blob location bigint[] -- locations of the content relative to its parent directories in the isochrone frontier \endif -- foreign key (content) references content (id), -- foreign key (directory) references directory (id), -- foreign key (location) references location (id) ); comment on column content_in_directory.content is 'Content internal identifier'; \if :dbflavor_norm comment on column content_in_directory.directory is 'Directory internal identifier'; comment on column content_in_directory.location is 'Location of content in directory'; \else comment on column content_in_directory.directory is 'Directory/location internal identifiers'; \endif create table directory_in_revision ( directory bigint not null, -- internal identifier of the directory appearing in the revision \if :dbflavor_norm revision bigint not null, -- internal identifier of the revision containing the directory location bigint -- location of the directory relative to the revision's root directory \else revision bigint[], -- internal identifiers of the revisions containing the directory location bigint[] -- locations of the directory relative to the revisions' root directory \endif -- foreign key (directory) references directory (id), -- foreign key (revision) references revision (id), -- foreign key (location) references location (id) ); comment on column directory_in_revision.directory is 'Directory internal identifier'; \if :dbflavor_norm 
comment on column directory_in_revision.revision is 'Revision internal identifier'; comment on column directory_in_revision.location is 'Location of content in revision'; \else comment on column directory_in_revision.revision is 'Revision/location internal identifiers'; \endif create table revision_in_origin ( revision bigint not null, -- internal identifier of the revision poined by the origin origin bigint not null -- internal identifier of the origin that points to the revision -- foreign key (revision) references revision (id), -- foreign key (origin) references origin (id) ); comment on column revision_in_origin.revision is 'Revision internal identifier'; comment on column revision_in_origin.origin is 'Origin internal identifier'; create table revision_before_revision ( prev bigserial not null, -- internal identifier of the source revision next bigserial not null -- internal identifier of the destination revision -- foreign key (prev) references revision (id), -- foreign key (next) references revision (id) ); comment on column revision_before_revision.prev is 'Source revision internal identifier'; comment on column revision_before_revision.next is 'Destination revision internal identifier'; diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py index 297d06b..858e056 100644 --- a/swh/provenance/tests/conftest.py +++ b/swh/provenance/tests/conftest.py @@ -1,207 +1,177 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from contextlib import contextmanager from datetime import datetime +from functools import partial import multiprocessing from os import path from pathlib import Path from typing import Any, Dict, Generator, List from _pytest.fixtures import SubRequest from aiohttp.test_utils import TestClient, TestServer, loop_context import msgpack import psycopg2.extensions import pytest -from pytest_postgresql.factories import postgresql +from pytest_postgresql import factories +from swh.core.db.db_utils import initialize_database_for_module from swh.graph.http_rpc_server import make_app from swh.journal.serializers import msgpack_ext_hook from swh.model.model import BaseModel, TimestampWithTimezone from swh.provenance import get_provenance, get_provenance_storage from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface +from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql from swh.provenance.storage.archive import ArchiveStorage from swh.storage.interface import StorageInterface from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects - -@pytest.fixture( - params=[ - "with-path", - "without-path", - "with-path-denormalized", - "without-path-denormalized", - ] +provenance_postgresql_proc = factories.postgresql_proc( + load=[ + partial( + initialize_database_for_module, + modname="provenance", + flavor="with-path", + version=ProvenanceStoragePostgreSql.current_version, + ) + ], ) -def provenance_postgresqldb( - request: SubRequest, - postgresql: psycopg2.extensions.connection, -) -> Dict[str, str]: - """return a working and initialized provenance db""" - from swh.core.db.db_utils import ( - init_admin_extensions, - populate_database_for_package, - ) - init_admin_extensions("swh.provenance", postgresql.dsn) - 
populate_database_for_package( - "swh.provenance", postgresql.dsn, flavor=request.param - ) - return postgresql.get_dsn_parameters() +postgres_provenance = factories.postgresql("provenance_postgresql_proc") -@pytest.fixture(params=["postgresql", "rabbitmq"]) +@pytest.fixture() +def provenance_postgresqldb(request, postgres_provenance): + return postgres_provenance.get_dsn_parameters() + + +@pytest.fixture() def provenance_storage( request: SubRequest, provenance_postgresqldb: Dict[str, str], ) -> Generator[ProvenanceStorageInterface, None, None]: """Return a working and initialized ProvenanceStorageInterface object""" - if request.param == "rabbitmq": - from swh.provenance.api.server import ProvenanceStorageRabbitMQServer - - rabbitmq = request.getfixturevalue("rabbitmq") - host = rabbitmq.args["host"] - port = rabbitmq.args["port"] - rabbitmq_params: Dict[str, Any] = { - "url": f"amqp://guest:guest@{host}:{port}/%2f", - "storage_config": { - "cls": "postgresql", - "db": provenance_postgresqldb, - "raise_on_commit": True, - }, - } - server = ProvenanceStorageRabbitMQServer( - url=rabbitmq_params["url"], storage_config=rabbitmq_params["storage_config"] - ) - server.start() - with get_provenance_storage(cls=request.param, **rabbitmq_params) as storage: - yield storage - server.stop() - - else: - # in test sessions, we DO want to raise any exception occurring at commit time - with get_provenance_storage( - cls=request.param, db=provenance_postgresqldb, raise_on_commit=True - ) as storage: - yield storage - - -provenance_postgresql = postgresql("postgresql_proc", dbname="provenance_tests") + # in test sessions, we DO want to raise any exception occurring at commit time + with get_provenance_storage( + cls="postgresql", db=provenance_postgresqldb, raise_on_commit=True + ) as storage: + yield storage @pytest.fixture def provenance( - provenance_postgresql: psycopg2.extensions.connection, + postgres_provenance: psycopg2.extensions.connection, ) -> Generator[ProvenanceInterface, None, None]: """Return a working and initialized ProvenanceInterface object""" from swh.core.db.db_utils import ( init_admin_extensions, populate_database_for_package, ) - init_admin_extensions("swh.provenance", provenance_postgresql.dsn) + init_admin_extensions("swh.provenance", postgres_provenance.dsn) populate_database_for_package( - "swh.provenance", provenance_postgresql.dsn, flavor="with-path" + "swh.provenance", postgres_provenance.dsn, flavor="with-path" ) # in test sessions, we DO want to raise any exception occurring at commit time with get_provenance( cls="postgresql", - db=provenance_postgresql.get_dsn_parameters(), + db=postgres_provenance.get_dsn_parameters(), raise_on_commit=True, ) as provenance: yield provenance @pytest.fixture def archive(swh_storage: StorageInterface) -> ArchiveInterface: """Return an ArchiveStorage-based ArchiveInterface object""" return ArchiveStorage(swh_storage) def fill_storage(storage: StorageInterface, data: Dict[str, List[dict]]) -> None: objects = { objtype: [objs_from_dict(objtype, d) for d in dicts] for objtype, dicts in data.items() } process_replay_objects(objects, storage=storage) def get_datafile(fname: str) -> str: return path.join(path.dirname(__file__), "data", fname) # TODO: this should return Dict[str, List[BaseModel]] directly, but it requires # refactoring several tests def load_repo_data(repo: str) -> Dict[str, List[dict]]: data: Dict[str, List[dict]] = {} with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj: unpacker = msgpack.Unpacker( fobj, raw=False, 
ext_hook=msgpack_ext_hook, strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) for msg in unpacker: if len(msg) == 2: # old format objtype, objd = msg else: # now we should have a triplet (type, key, value) objtype, _, objd = msg data.setdefault(objtype, []).append(objd) return data def objs_from_dict(object_type: str, dict_repr: dict) -> BaseModel: if object_type in OBJECT_FIXERS: dict_repr = OBJECT_FIXERS[object_type](dict_repr) obj = OBJECT_CONVERTERS[object_type](dict_repr) return obj def ts2dt(ts: Dict[str, Any]) -> datetime: return TimestampWithTimezone.from_dict(ts).to_datetime() def run_grpc_server(queue, dataset_path): try: config = { "graph": { "cls": "local", "grpc_server": {"path": dataset_path}, "http_rpc_server": {"debug": True}, } } with loop_context() as loop: app = make_app(config=config) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url("/graph/") queue.put((url, app["rpc_url"])) loop.run_forever() except Exception as e: queue.put(e) @contextmanager def grpc_server(dataset): dataset_path = ( Path(__file__).parents[0] / "data/swhgraph" / dataset / "compressed/example" ) queue = multiprocessing.Queue() server = multiprocessing.Process( target=run_grpc_server, kwargs={"queue": queue, "dataset_path": dataset_path} ) server.start() res = queue.get() if isinstance(res, Exception): raise res grpc_url = res[1] try: yield grpc_url finally: server.terminate() diff --git a/swh/provenance/tests/test_journal_client.py b/swh/provenance/tests/test_journal_client.py index c0fc79c..4fd6854 100644 --- a/swh/provenance/tests/test_journal_client.py +++ b/swh/provenance/tests/test_journal_client.py @@ -1,135 +1,135 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Dict from confluent_kafka import Consumer import pytest from swh.model.hashutil import MultiHash from swh.provenance.tests.conftest import fill_storage, load_repo_data from swh.storage.interface import StorageInterface from .test_utils import invoke, write_configuration_path @pytest.fixture def swh_storage_backend_config(swh_storage_backend_config, kafka_server, kafka_prefix): writer_config = { "cls": "kafka", "brokers": [kafka_server], "client_id": "kafka_writer", "prefix": kafka_prefix, "anonymize": False, } yield {**swh_storage_backend_config, "journal_writer": writer_config} def test_cli_origin_from_journal_client( swh_storage: StorageInterface, swh_storage_backend_config: Dict, kafka_prefix: str, kafka_server: str, consumer: Consumer, tmp_path: str, provenance, - provenance_postgresql, + postgres_provenance, ) -> None: """Test origin journal client cli""" # Prepare storage data data = load_repo_data("cmdbts2") assert len(data["origin"]) >= 1 origin_url = data["origin"][0]["url"] fill_storage(swh_storage, data) # Prepare configuration for cli call swh_storage_backend_config.pop("journal_writer", None) # no need for that config storage_config_dict = swh_storage_backend_config cfg = { "journal_client": { "cls": "kafka", "brokers": [kafka_server], "group_id": "toto", "prefix": kafka_prefix, "stop_on_eof": True, }, "provenance": { "archive": { "cls": "api", "storage": storage_config_dict, }, "storage": { "cls": "postgresql", - "db": provenance_postgresql.get_dsn_parameters(), + "db": 
postgres_provenance.get_dsn_parameters(), }, }, } config_path = write_configuration_path(cfg, tmp_path) # call the cli 'swh provenance origin from-journal' result = invoke(["origin", "from-journal"], config_path) assert result.exit_code == 0, f"Unexpected result: {result.output}" origin_sha1 = MultiHash.from_data( origin_url.encode(), hash_names=["sha1"] ).digest()["sha1"] actual_result = provenance.storage.origin_get([origin_sha1]) assert actual_result == {origin_sha1: origin_url} def test_cli_revision_from_journal_client( swh_storage: StorageInterface, swh_storage_backend_config: Dict, kafka_prefix: str, kafka_server: str, consumer: Consumer, tmp_path: str, provenance, - provenance_postgresql, + postgres_provenance, ) -> None: """Test revision journal client cli""" # Prepare storage data data = load_repo_data("cmdbts2") assert len(data["origin"]) >= 1 fill_storage(swh_storage, data) # Prepare configuration for cli call swh_storage_backend_config.pop("journal_writer", None) # no need for that config storage_config_dict = swh_storage_backend_config cfg = { "journal_client": { "cls": "kafka", "brokers": [kafka_server], "group_id": "toto", "prefix": kafka_prefix, "stop_on_eof": True, }, "provenance": { "archive": { "cls": "api", "storage": storage_config_dict, }, "storage": { "cls": "postgresql", - "db": provenance_postgresql.get_dsn_parameters(), + "db": postgres_provenance.get_dsn_parameters(), }, }, } config_path = write_configuration_path(cfg, tmp_path) revisions = [rev["id"] for rev in data["revision"]] result = provenance.storage.revision_get(revisions) assert not result # call the cli 'swh provenance revision from-journal' cli_result = invoke(["revision", "from-journal"], config_path) assert cli_result.exit_code == 0, f"Unexpected result: {result.output}" result = provenance.storage.revision_get(revisions) assert set(result.keys()) == set(revisions) diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py index f36b469..fee0a88 100644 --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -1,463 +1,470 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import inspect import os from typing import Any, Dict, Iterable, Optional, Set, Tuple from swh.model.hashutil import hash_to_bytes from swh.model.model import Origin, Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ( DirectoryData, EntityType, ProvenanceInterface, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from swh.provenance.model import OriginEntry, RevisionEntry from swh.provenance.origin import origin_add from swh.provenance.provenance import Provenance from swh.provenance.revision import revision_add from swh.provenance.tests.conftest import fill_storage, load_repo_data, ts2dt -def test_provenance_storage_content( - provenance_storage: ProvenanceStorageInterface, -) -> None: - """Tests content methods for every `ProvenanceStorageInterface` implementation.""" - - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data("cmdbts2") - - # Add all content present in the current repo to the storage, just assigning their - # creation dates. 
Then check that the returned results when querying are the same. - cnt_dates = { - cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"]) - } - assert provenance_storage.content_add(cnt_dates) - assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates - assert provenance_storage.entity_get_all(EntityType.CONTENT) == set( - cnt_dates.keys() - ) +class TestProvenanceStorage: + def test_provenance_storage_content( + self, + provenance_storage: ProvenanceStorageInterface, + ) -> None: + """Tests content methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Add all content present in the current repo to the storage, just assigning their + # creation dates. Then check that the returned results when querying are the same. + cnt_dates = { + cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"]) + } + assert provenance_storage.content_add(cnt_dates) + assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates + assert provenance_storage.entity_get_all(EntityType.CONTENT) == set( + cnt_dates.keys() + ) + def test_provenance_storage_directory( + self, + provenance_storage: ProvenanceStorageInterface, + ) -> None: + """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Of all directories present in the current repo, only assign a date to those + # containing blobs (picking the max date among the available ones). Then check that + # the returned results when querying are the same. + def getmaxdate( + directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] + ) -> Optional[datetime]: + dates = [ + content["ctime"] + for entry in directory["entries"] + for content in contents + if entry["type"] == "file" and entry["target"] == content["sha1_git"] + ] + return max(dates) if dates else None + + flat_values = (False, True) + dir_dates = {} + for idx, dir in enumerate(data["directory"]): + date = getmaxdate(dir, data["content"]) + if date is not None: + dir_dates[dir["id"]] = DirectoryData( + date=date, flat=flat_values[idx % 2] + ) + assert provenance_storage.directory_add(dir_dates) + assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates + assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == set( + dir_dates.keys() + ) -def test_provenance_storage_directory( - provenance_storage: ProvenanceStorageInterface, -) -> None: - """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" - - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data("cmdbts2") - - # Of all directories present in the current repo, only assign a date to those - # containing blobs (picking the max date among the available ones). Then check that - # the returned results when querying are the same. 
- def getmaxdate( - directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] - ) -> Optional[datetime]: - dates = [ - content["ctime"] - for entry in directory["entries"] - for content in contents - if entry["type"] == "file" and entry["target"] == content["sha1_git"] - ] - return max(dates) if dates else None - - flat_values = (False, True) - dir_dates = {} - for idx, dir in enumerate(data["directory"]): - date = getmaxdate(dir, data["content"]) - if date is not None: - dir_dates[dir["id"]] = DirectoryData(date=date, flat=flat_values[idx % 2]) - assert provenance_storage.directory_add(dir_dates) - assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates - assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == set( - dir_dates.keys() - ) + def test_provenance_storage_location( + self, + provenance_storage: ProvenanceStorageInterface, + ) -> None: + """Tests location methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Add all names of entries present in the directories of the current repo as paths + # to the storage. Then check that the returned results when querying are the same. + paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} + assert provenance_storage.location_add(paths) + + if provenance_storage.with_path(): + assert provenance_storage.location_get_all() == paths + else: + assert provenance_storage.location_get_all() == set() + + def test_provenance_storage_origin( + self, + provenance_storage: ProvenanceStorageInterface, + ) -> None: + """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Test origin methods. + # Add all origins present in the current repo to the storage. Then check that the + # returned results when querying are the same. + orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]} + assert orgs + assert provenance_storage.origin_add(orgs) + assert provenance_storage.origin_get(set(orgs.keys())) == orgs + assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys()) + + def test_provenance_storage_revision( + self, + provenance_storage: ProvenanceStorageInterface, + ) -> None: + """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Test revision methods. + # Add all revisions present in the current repo to the storage, assigning their + # dates and an arbitrary origin to each one. Then check that the returned results + # when querying are the same. + origin = Origin(url=next(iter(data["origin"]))["url"]) + # Origin must be inserted in advance. 
+ assert provenance_storage.origin_add({origin.id: origin.url}) + + revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0} + rev_data = { + rev["id"]: RevisionData( + date=ts2dt(rev["date"]) if idx % 2 != 0 else None, + origin=origin.id if idx % 3 != 0 else None, + ) + for idx, rev in enumerate(data["revision"]) + if idx % 6 != 0 + } + assert revs + assert provenance_storage.revision_add(revs) + assert provenance_storage.revision_add(rev_data) + assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data + assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set( + rev_data.keys() + ) + def test_provenance_storage_relation( + self, + provenance_storage: ProvenanceStorageInterface, + ) -> None: + """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" -def test_provenance_storage_location( - provenance_storage: ProvenanceStorageInterface, -) -> None: - """Tests location methods for every `ProvenanceStorageInterface` implementation.""" + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data("cmdbts2") + # Test content-in-revision relation. + # Create flat models of every root directory for the revisions in the dataset. + cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for rev in data["revision"]: + root = next( + subdir + for subdir in data["directory"] + if subdir["id"] == rev["directory"] + ) + for cnt, rel in dircontent(data, rev["id"], root): + cnt_in_rev.setdefault(cnt, set()).add(rel) + relation_add_and_compare_result( + provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev + ) - # Add all names of entries present in the directories of the current repo as paths - # to the storage. Then check that the returned results when querying are the same. - paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} - assert provenance_storage.location_add(paths) + # Test content-in-directory relation. + # Create flat models for every directory in the dataset. + cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} + for dir in data["directory"]: + for cnt, rel in dircontent(data, dir["id"], dir): + cnt_in_dir.setdefault(cnt, set()).add(rel) + relation_add_and_compare_result( + provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir + ) - if provenance_storage.with_path(): - assert provenance_storage.location_get_all() == paths - else: - assert provenance_storage.location_get_all() == set() + # Test content-in-directory relation. + # Add root directories to their correspondent revision in the dataset. + dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for rev in data["revision"]: + dir_in_rev.setdefault(rev["directory"], set()).add( + RelationData(dst=rev["id"], path=b".") + ) + relation_add_and_compare_result( + provenance_storage, RelationType.DIR_IN_REV, dir_in_rev + ) + # Test revision-in-origin relation. + # Origins must be inserted in advance (cannot be done by `entity_add` inside + # `relation_add_and_compare_result`). + orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]} + assert provenance_storage.origin_add(orgs) + # Add all revisions that are head of some snapshot branch to the corresponding + # origin. 
+ rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} + for status in data["origin_visit_status"]: + if status["snapshot"] is not None: + for snapshot in data["snapshot"]: + if snapshot["id"] == status["snapshot"]: + for branch in snapshot["branches"].values(): + if branch["target_type"] == "revision": + rev_in_org.setdefault(branch["target"], set()).add( + RelationData( + dst=Origin(url=status["origin"]).id, + path=None, + ) + ) + relation_add_and_compare_result( + provenance_storage, RelationType.REV_IN_ORG, rev_in_org + ) -def test_provenance_storage_origin( - provenance_storage: ProvenanceStorageInterface, -) -> None: - """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" + # Test revision-before-revision relation. + # For each revision in the data set add an entry for each parent to the relation. + rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {} + for rev in data["revision"]: + for parent in rev["parents"]: + rev_before_rev.setdefault(parent, set()).add( + RelationData(dst=rev["id"], path=None) + ) + relation_add_and_compare_result( + provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev + ) - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data("cmdbts2") + def test_provenance_storage_find( + self, + provenance: ProvenanceInterface, + provenance_storage: ProvenanceStorageInterface, + archive: ArchiveInterface, + ) -> None: + """Tests `content_find_first` and `content_find_all` methods for every + `ProvenanceStorageInterface` implementation. + """ + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + fill_storage(archive.storage, data) + + # Test content_find_first and content_find_all, first only executing the + # revision-content algorithm, then adding the origin-revision layer. + def adapt_result( + result: Optional[ProvenanceResult], with_path: bool + ) -> Optional[ProvenanceResult]: + if result is not None: + return ProvenanceResult( + result.content, + result.revision, + result.date, + result.origin, + result.path if with_path else b"", + ) + return result + + # Execute the revision-content algorithm on both storages. + revisions = [ + RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"]) + for rev in data["revision"] + ] + revision_add(provenance, archive, revisions) + revision_add(Provenance(provenance_storage), archive, revisions) - # Test origin methods. - # Add all origins present in the current repo to the storage. Then check that the - # returned results when querying are the same. 
- orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]} - assert orgs - assert provenance_storage.origin_add(orgs) - assert provenance_storage.origin_get(set(orgs.keys())) == orgs - assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys()) + assert adapt_result( + ProvenanceResult( + content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), + revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), + date=datetime.fromtimestamp(1000000000.0, timezone.utc), + origin=None, + path=b"A/B/C/a", + ), + provenance_storage.with_path(), + ) == provenance_storage.content_find_first( + hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494") + ) + for cnt in {cnt["sha1_git"] for cnt in data["content"]}: + assert adapt_result( + provenance.storage.content_find_first(cnt), + provenance_storage.with_path(), + ) == provenance_storage.content_find_first(cnt) + assert { + adapt_result(occur, provenance_storage.with_path()) + for occur in provenance.storage.content_find_all(cnt) + } == set(provenance_storage.content_find_all(cnt)) + + # Execute the origin-revision algorithm on both storages. + origins = [ + OriginEntry(url=sta["origin"], snapshot=sta["snapshot"]) + for sta in data["origin_visit_status"] + if sta["snapshot"] is not None + ] + origin_add(provenance, archive, origins) + origin_add(Provenance(provenance_storage), archive, origins) -def test_provenance_storage_revision( - provenance_storage: ProvenanceStorageInterface, -) -> None: - """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" - - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data("cmdbts2") - - # Test revision methods. - # Add all revisions present in the current repo to the storage, assigning their - # dates and an arbitrary origin to each one. Then check that the returned results - # when querying are the same. - origin = Origin(url=next(iter(data["origin"]))["url"]) - # Origin must be inserted in advance. 
- assert provenance_storage.origin_add({origin.id: origin.url}) - - revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0} - rev_data = { - rev["id"]: RevisionData( - date=ts2dt(rev["date"]) if idx % 2 != 0 else None, - origin=origin.id if idx % 3 != 0 else None, + assert adapt_result( + ProvenanceResult( + content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), + revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), + date=datetime.fromtimestamp(1000000000.0, timezone.utc), + origin="https://cmdbts2", + path=b"A/B/C/a", + ), + provenance_storage.with_path(), + ) == provenance_storage.content_find_first( + hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494") ) - for idx, rev in enumerate(data["revision"]) - if idx % 6 != 0 - } - assert revs - assert provenance_storage.revision_add(revs) - assert provenance_storage.revision_add(rev_data) - assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data - assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set( - rev_data.keys() - ) + + for cnt in {cnt["sha1_git"] for cnt in data["content"]}: + assert adapt_result( + provenance.storage.content_find_first(cnt), + provenance_storage.with_path(), + ) == provenance_storage.content_find_first(cnt) + assert { + adapt_result(occur, provenance_storage.with_path()) + for occur in provenance.storage.content_find_all(cnt) + } == set(provenance_storage.content_find_all(cnt)) + + def test_types(self, provenance_storage: ProvenanceStorageInterface) -> None: + """Checks all methods of ProvenanceStorageInterface are implemented by this + backend, and that they have the same signature.""" + # Create an instance of the protocol (which cannot be instantiated + # directly, so this creates a subclass, then instantiates it) + interface = type("_", (ProvenanceStorageInterface,), {})() + + assert "content_find_first" in dir(interface) + + missing_methods = [] + + for meth_name in dir(interface): + if meth_name.startswith("_"): + continue + interface_meth = getattr(interface, meth_name) + try: + concrete_meth = getattr(provenance_storage, meth_name) + except AttributeError: + if not getattr(interface_meth, "deprecated_endpoint", False): + # The backend is missing a (non-deprecated) endpoint + missing_methods.append(meth_name) + continue + + expected_signature = inspect.signature(interface_meth) + actual_signature = inspect.signature(concrete_meth) + + assert expected_signature == actual_signature, meth_name + + assert missing_methods == [] + + # If all the assertions above succeed, then this one should too. + # But there's no harm in double-checking. + # And we could replace the assertions above by this one, but unlike + # the assertions above, it doesn't explain what is missing. 
+ assert isinstance(provenance_storage, ProvenanceStorageInterface) def dircontent( data: Dict[str, Any], ref: Sha1Git, dir: Dict[str, Any], prefix: bytes = b"", ) -> Iterable[Tuple[Sha1Git, RelationData]]: content = { ( entry["target"], RelationData(dst=ref, path=os.path.join(prefix, entry["name"])), ) for entry in dir["entries"] if entry["type"] == "file" } for entry in dir["entries"]: if entry["type"] == "dir": child = next( subdir for subdir in data["directory"] if subdir["id"] == entry["target"] ) content.update( dircontent(data, ref, child, os.path.join(prefix, entry["name"])) ) return content def entity_add( storage: ProvenanceStorageInterface, entity: EntityType, ids: Set[Sha1Git] ) -> bool: now = datetime.now(tz=timezone.utc) if entity == EntityType.CONTENT: return storage.content_add({sha1: now for sha1 in ids}) elif entity == EntityType.DIRECTORY: return storage.directory_add( {sha1: DirectoryData(date=now, flat=False) for sha1 in ids} ) else: # entity == EntityType.REVISION: return storage.revision_add( {sha1: RevisionData(date=None, origin=None) for sha1 in ids} ) def relation_add_and_compare_result( storage: ProvenanceStorageInterface, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]], ) -> None: # Source, destinations and locations must be added in advance. src, *_, dst = relation.value.split("_") srcs = {sha1 for sha1 in data} if src != "origin": assert entity_add(storage, EntityType(src), srcs) dsts = {rel.dst for rels in data.values() for rel in rels} if dst != "origin": assert entity_add(storage, EntityType(dst), dsts) if storage.with_path(): assert storage.location_add( {rel.path for rels in data.values() for rel in rels if rel.path is not None} ) assert data assert storage.relation_add(relation, data) for src_sha1 in srcs: relation_compare_result( storage.relation_get(relation, [src_sha1]), {src_sha1: data[src_sha1]}, storage.with_path(), ) for dst_sha1 in dsts: relation_compare_result( storage.relation_get(relation, [dst_sha1], reverse=True), { src_sha1: { RelationData(dst=dst_sha1, path=rel.path) for rel in rels if dst_sha1 == rel.dst } for src_sha1, rels in data.items() if dst_sha1 in {rel.dst for rel in rels} }, storage.with_path(), ) relation_compare_result( storage.relation_get_all(relation), data, storage.with_path() ) def relation_compare_result( computed: Dict[Sha1Git, Set[RelationData]], expected: Dict[Sha1Git, Set[RelationData]], with_path: bool, ) -> None: assert { src_sha1: { RelationData(dst=rel.dst, path=rel.path if with_path else None) for rel in rels } for src_sha1, rels in expected.items() } == computed - - -def test_provenance_storage_relation( - provenance_storage: ProvenanceStorageInterface, -) -> None: - """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" - - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data("cmdbts2") - - # Test content-in-revision relation. - # Create flat models of every root directory for the revisions in the dataset. - cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} - for rev in data["revision"]: - root = next( - subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] - ) - for cnt, rel in dircontent(data, rev["id"], root): - cnt_in_rev.setdefault(cnt, set()).add(rel) - relation_add_and_compare_result( - provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev - ) - - # Test content-in-directory relation. - # Create flat models for every directory in the dataset. 
-    cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {}
-    for dir in data["directory"]:
-        for cnt, rel in dircontent(data, dir["id"], dir):
-            cnt_in_dir.setdefault(cnt, set()).add(rel)
-    relation_add_and_compare_result(
-        provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir
-    )
-
-    # Test content-in-directory relation.
-    # Add root directories to their correspondent revision in the dataset.
-    dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
-    for rev in data["revision"]:
-        dir_in_rev.setdefault(rev["directory"], set()).add(
-            RelationData(dst=rev["id"], path=b".")
-        )
-    relation_add_and_compare_result(
-        provenance_storage, RelationType.DIR_IN_REV, dir_in_rev
-    )
-
-    # Test revision-in-origin relation.
-    # Origins must be inserted in advance (cannot be done by `entity_add` inside
-    # `relation_add_and_compare_result`).
-    orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]}
-    assert provenance_storage.origin_add(orgs)
-    # Add all revisions that are head of some snapshot branch to the corresponding
-    # origin.
-    rev_in_org: Dict[Sha1Git, Set[RelationData]] = {}
-    for status in data["origin_visit_status"]:
-        if status["snapshot"] is not None:
-            for snapshot in data["snapshot"]:
-                if snapshot["id"] == status["snapshot"]:
-                    for branch in snapshot["branches"].values():
-                        if branch["target_type"] == "revision":
-                            rev_in_org.setdefault(branch["target"], set()).add(
-                                RelationData(
-                                    dst=Origin(url=status["origin"]).id,
-                                    path=None,
-                                )
-                            )
-    relation_add_and_compare_result(
-        provenance_storage, RelationType.REV_IN_ORG, rev_in_org
-    )
-
-    # Test revision-before-revision relation.
-    # For each revision in the data set add an entry for each parent to the relation.
-    rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {}
-    for rev in data["revision"]:
-        for parent in rev["parents"]:
-            rev_before_rev.setdefault(parent, set()).add(
-                RelationData(dst=rev["id"], path=None)
-            )
-    relation_add_and_compare_result(
-        provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev
-    )
-
-
-def test_provenance_storage_find(
-    provenance: ProvenanceInterface,
-    provenance_storage: ProvenanceStorageInterface,
-    archive: ArchiveInterface,
-) -> None:
-    """Tests `content_find_first` and `content_find_all` methods for every
-    `ProvenanceStorageInterface` implementation.
-    """
-
-    # Read data/README.md for more details on how these datasets are generated.
-    data = load_repo_data("cmdbts2")
-    fill_storage(archive.storage, data)
-
-    # Test content_find_first and content_find_all, first only executing the
-    # revision-content algorithm, then adding the origin-revision layer.
-    def adapt_result(
-        result: Optional[ProvenanceResult], with_path: bool
-    ) -> Optional[ProvenanceResult]:
-        if result is not None:
-            return ProvenanceResult(
-                result.content,
-                result.revision,
-                result.date,
-                result.origin,
-                result.path if with_path else b"",
-            )
-        return result
-
-    # Execute the revision-content algorithm on both storages.
-    revisions = [
-        RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"])
-        for rev in data["revision"]
-    ]
-    revision_add(provenance, archive, revisions)
-    revision_add(Provenance(provenance_storage), archive, revisions)
-
-    assert adapt_result(
-        ProvenanceResult(
-            content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
-            revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
-            date=datetime.fromtimestamp(1000000000.0, timezone.utc),
-            origin=None,
-            path=b"A/B/C/a",
-        ),
-        provenance_storage.with_path(),
-    ) == provenance_storage.content_find_first(
-        hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
-    )
-
-    for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
-        assert adapt_result(
-            provenance.storage.content_find_first(cnt), provenance_storage.with_path()
-        ) == provenance_storage.content_find_first(cnt)
-        assert {
-            adapt_result(occur, provenance_storage.with_path())
-            for occur in provenance.storage.content_find_all(cnt)
-        } == set(provenance_storage.content_find_all(cnt))
-
-    # Execute the origin-revision algorithm on both storages.
-    origins = [
-        OriginEntry(url=sta["origin"], snapshot=sta["snapshot"])
-        for sta in data["origin_visit_status"]
-        if sta["snapshot"] is not None
-    ]
-    origin_add(provenance, archive, origins)
-    origin_add(Provenance(provenance_storage), archive, origins)
-
-    assert adapt_result(
-        ProvenanceResult(
-            content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
-            revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
-            date=datetime.fromtimestamp(1000000000.0, timezone.utc),
-            origin="https://cmdbts2",
-            path=b"A/B/C/a",
-        ),
-        provenance_storage.with_path(),
-    ) == provenance_storage.content_find_first(
-        hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
-    )
-
-    for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
-        assert adapt_result(
-            provenance.storage.content_find_first(cnt), provenance_storage.with_path()
-        ) == provenance_storage.content_find_first(cnt)
-        assert {
-            adapt_result(occur, provenance_storage.with_path())
-            for occur in provenance.storage.content_find_all(cnt)
-        } == set(provenance_storage.content_find_all(cnt))
-
-
-def test_types(provenance_storage: ProvenanceStorageInterface) -> None:
-    """Checks all methods of ProvenanceStorageInterface are implemented by this
-    backend, and that they have the same signature."""
-    # Create an instance of the protocol (which cannot be instantiated
-    # directly, so this creates a subclass, then instantiates it)
-    interface = type("_", (ProvenanceStorageInterface,), {})()
-
-    assert "content_find_first" in dir(interface)
-
-    missing_methods = []
-
-    for meth_name in dir(interface):
-        if meth_name.startswith("_"):
-            continue
-        interface_meth = getattr(interface, meth_name)
-        try:
-            concrete_meth = getattr(provenance_storage, meth_name)
-        except AttributeError:
-            if not getattr(interface_meth, "deprecated_endpoint", False):
-                # The backend is missing a (non-deprecated) endpoint
-                missing_methods.append(meth_name)
-            continue
-
-        expected_signature = inspect.signature(interface_meth)
-        actual_signature = inspect.signature(concrete_meth)
-
-        assert expected_signature == actual_signature, meth_name
-
-    assert missing_methods == []
-
-    # If all the assertions above succeed, then this one should too.
-    # But there's no harm in double-checking.
-    # And we could replace the assertions above by this one, but unlike
-    # the assertions above, it doesn't explain what is missing.
-    assert isinstance(provenance_storage, ProvenanceStorageInterface)
diff --git a/swh/provenance/tests/test_provenance_storage_rabbitmq.py b/swh/provenance/tests/test_provenance_storage_rabbitmq.py
new file mode 100644
index 0000000..48cf787
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_storage_rabbitmq.py
@@ -0,0 +1,38 @@
+from typing import Any, Dict, Generator
+
+import pytest
+
+from swh.provenance import get_provenance_storage
+from swh.provenance.interface import ProvenanceStorageInterface
+
+from .test_provenance_storage import TestProvenanceStorage  # noqa: F401
+
+
+@pytest.fixture()
+def provenance_storage(
+    provenance_postgresqldb: Dict[str, str],
+    rabbitmq,
+) -> Generator[ProvenanceStorageInterface, None, None]:
+    """Return a working and initialized ProvenanceStorageInterface object"""
+
+    from swh.provenance.api.server import ProvenanceStorageRabbitMQServer
+
+    host = rabbitmq.args["host"]
+    port = rabbitmq.args["port"]
+    rabbitmq_params: Dict[str, Any] = {
+        "url": f"amqp://guest:guest@{host}:{port}/%2f",
+        "storage_config": {
+            "cls": "postgresql",
+            "db": provenance_postgresqldb,
+            "raise_on_commit": True,
+        },
+    }
+    server = ProvenanceStorageRabbitMQServer(
+        url=rabbitmq_params["url"], storage_config=rabbitmq_params["storage_config"]
+    )
+    server.start()
+    try:
+        with get_provenance_storage(cls="rabbitmq", **rabbitmq_params) as storage:
+            yield storage
+    finally:
+        server.stop()
diff --git a/swh/provenance/tests/test_provenance_storage_with_path_denormalized.py b/swh/provenance/tests/test_provenance_storage_with_path_denormalized.py
new file mode 100644
index 0000000..c721c56
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_storage_with_path_denormalized.py
@@ -0,0 +1,19 @@
+from functools import partial
+
+from pytest_postgresql import factories
+
+from swh.core.db.db_utils import initialize_database_for_module
+from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql
+
+from .test_provenance_storage import TestProvenanceStorage  # noqa: F401
+
+provenance_postgresql_proc = factories.postgresql_proc(
+    load=[
+        partial(
+            initialize_database_for_module,
+            modname="provenance",
+            flavor="with-path-denormalized",
+            version=ProvenanceStoragePostgreSql.current_version,
+        )
+    ],
+)
diff --git a/swh/provenance/tests/test_provenance_storage_without_path.py b/swh/provenance/tests/test_provenance_storage_without_path.py
new file mode 100644
index 0000000..fc77300
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_storage_without_path.py
@@ -0,0 +1,19 @@
+from functools import partial
+
+from pytest_postgresql import factories
+
+from swh.core.db.db_utils import initialize_database_for_module
+from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql
+
+from .test_provenance_storage import TestProvenanceStorage  # noqa: F401
+
+provenance_postgresql_proc = factories.postgresql_proc(
+    load=[
+        partial(
+            initialize_database_for_module,
+            modname="provenance",
+            flavor="without-path",
+            version=ProvenanceStoragePostgreSql.current_version,
+        )
+    ],
+)
diff --git a/swh/provenance/tests/test_provenance_storage_without_path_denormalized.py b/swh/provenance/tests/test_provenance_storage_without_path_denormalized.py
new file mode 100644
index 0000000..550d702
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_storage_without_path_denormalized.py
@@ -0,0 +1,19 @@
+from functools import partial
+
+from pytest_postgresql import factories
+
+from swh.core.db.db_utils import initialize_database_for_module
+from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql
+
+from .test_provenance_storage import TestProvenanceStorage  # noqa: F401
+
+provenance_postgresql_proc = factories.postgresql_proc(
+    load=[
+        partial(
+            initialize_database_for_module,
+            modname="provenance",
+            flavor="without-path-denormalized",
+            version=ProvenanceStoragePostgreSql.current_version,
+        )
+    ],
+)
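
Note on the pattern introduced by the new test modules above: each one re-imports the shared TestProvenanceStorage class (so pytest collects it in that module) and overrides the provenance_postgresql_proc fixture with a database initialized for a specific schema flavor via initialize_database_for_module and ProvenanceStoragePostgreSql.current_version. As a minimal sketch of how a further variant could reuse the same suite, the module below targets a hypothetical "with-path" flavor; the file name and flavor string are illustrative only and are not part of this change.

# swh/provenance/tests/test_provenance_storage_with_path.py  (hypothetical sketch)
from functools import partial

from pytest_postgresql import factories

from swh.core.db.db_utils import initialize_database_for_module
from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql

# Re-exporting the shared test class makes pytest collect it in this module,
# where it runs against the fixture overrides defined here.
from .test_provenance_storage import TestProvenanceStorage  # noqa: F401

# Override the PostgreSQL process fixture so the provenance schema is loaded
# with the desired flavor before the shared tests run.
provenance_postgresql_proc = factories.postgresql_proc(
    load=[
        partial(
            initialize_database_for_module,
            modname="provenance",
            flavor="with-path",  # hypothetical flavor, for illustration only
            version=ProvenanceStoragePostgreSql.current_version,
        )
    ],
)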