Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/tests/conftest.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from os import path | from os import path | ||||
import re | import re | ||||
from typing import Iterable, Iterator, List, Optional | from typing import Any, Dict, Iterable, Iterator, List, Optional | ||||
import msgpack | import msgpack | ||||
import psycopg2 | |||||
import pytest | import pytest | ||||
from typing_extensions import TypedDict | from typing_extensions import TypedDict | ||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.journal.serializers import msgpack_ext_hook | from swh.journal.serializers import msgpack_ext_hook | ||||
from swh.model.hashutil import hash_to_bytes | from swh.model.hashutil import hash_to_bytes | ||||
from swh.model.model import Sha1Git | from swh.model.model import Sha1Git | ||||
from swh.model.tests.swh_model_data import TEST_OBJECTS | from swh.model.tests.swh_model_data import TEST_OBJECTS | ||||
from swh.provenance import get_provenance | from swh.provenance import get_provenance | ||||
from swh.provenance.archive import ArchiveInterface | |||||
from swh.provenance.postgresql.archive import ArchivePostgreSQL | from swh.provenance.postgresql.archive import ArchivePostgreSQL | ||||
from swh.provenance.postgresql.provenancedb_base import ProvenanceDBBase | |||||
from swh.provenance.provenance import ProvenanceInterface | |||||
from swh.provenance.storage.archive import ArchiveStorage | from swh.provenance.storage.archive import ArchiveStorage | ||||
from swh.storage.postgresql.storage import Storage | |||||
from swh.storage.replay import process_replay_objects | from swh.storage.replay import process_replay_objects | ||||
@pytest.fixture(params=["with-path", "without-path"])
def provenance(
    request, postgresql: psycopg2.extensions.connection
) -> ProvenanceInterface:
    """Return a working and initialized provenance db.

    Parametrized on both supported schema flavors so every test using this
    fixture runs against the "with-path" and the "without-path" layouts.
    """
    from swh.core.cli.db import populate_database_for_package

    flavor = request.param
    populate_database_for_package("swh.provenance", postgresql.dsn, flavor=flavor)

    BaseDb.adapt_conn(postgresql)

    # Turn the libpq DSN ("key=value key=value ...") into a dict, dropping the
    # "options" entry injected by the postgresql fixture.  Split each item on
    # the FIRST "=" only: a value may itself contain an "=" sign, and
    # item.split("=")[1] would silently truncate it.
    args: Dict[str, str] = {}
    for item in postgresql.dsn.split():
        key, _, value = item.partition("=")
        if key != "options":
            args[key] = value

    prov = get_provenance(cls="local", db=args)
    assert isinstance(prov.storage, ProvenanceDBBase)
    assert prov.storage.flavor == flavor
    # in test sessions, we DO want to raise any exception occurring at commit time
    prov.storage.raise_on_commit = True
    return prov
@pytest.fixture
def swh_storage_with_objects(swh_storage: Storage) -> Storage:
    """Return a Storage object (postgresql-based by default) pre-filled with a
    few objects of every type.

    The inserted content comes from swh.model.tests.swh_model_data.
    """
    object_types = (
        "content",
        "skipped_content",
        "directory",
        "revision",
        "release",
        "snapshot",
        "origin",
        "origin_visit",
        "origin_visit_status",
    )
    for object_type in object_types:
        # each object type has a matching <type>_add method on the storage
        add_objects = getattr(swh_storage, f"{object_type}_add")
        add_objects(TEST_OBJECTS[object_type])
    return swh_storage
@pytest.fixture
def archive_direct(swh_storage_with_objects: Storage) -> ArchiveInterface:
    """Return an ArchiveInterface talking directly to the storage database."""
    connection = swh_storage_with_objects.get_db().conn
    return ArchivePostgreSQL(connection)
@pytest.fixture
def archive_api(swh_storage_with_objects: Storage) -> ArchiveInterface:
    """Return an ArchiveInterface going through the swh.storage API."""
    archive = ArchiveStorage(swh_storage_with_objects)
    return archive
@pytest.fixture(params=["archive", "db"])
def archive(request, swh_storage_with_objects: Storage) -> Iterator[ArchiveInterface]:
    """Return a ArchivePostgreSQL based StorageInterface object"""
    # this is a workaround to prevent tests from hanging because of an unclosed
    # transaction.
    # TODO: refactor the ArchivePostgreSQL to properly deal with
    # transactions and get rid of this fixture
    if request.param != "db":
        yield ArchiveStorage(swh_storage_with_objects)
        return
    direct_archive = ArchivePostgreSQL(conn=swh_storage_with_objects.get_db().conn)
    yield direct_archive
    # explicit rollback so the shared connection is left without an open
    # transaction once the test is done
    direct_archive.conn.rollback()
def get_datafile(fname: str) -> str:
    """Return the path of test data file *fname*.

    Data files live in the "data" directory next to this module.
    """
    data_dir = path.join(path.dirname(__file__), "data")
    return path.join(data_dir, fname)
def load_repo_data(repo: str) -> Dict[str, Any]:
    """Load the msgpack-serialized test objects for repository *repo*.

    Returns a mapping from object type name to the list of deserialized
    objects of that type found in the data file.
    """
    result: Dict[str, Any] = {}
    with open(get_datafile(f"{repo}.msgpack"), "rb") as stream:
        unpacker = msgpack.Unpacker(
            stream,
            raw=False,
            ext_hook=msgpack_ext_hook,
            strict_map_key=False,
            timestamp=3,  # convert Timestamp in datetime objects (tz UTC)
        )
        for object_type, obj in unpacker:
            if object_type not in result:
                result[object_type] = []
            result[object_type].append(obj)
    return result
def filter_dict(d: Dict[Any, Any], keys: Iterable[Any]) -> Dict[Any, Any]:
    """Return a copy of *d* restricted to the entries whose key is in *keys*."""
    selected: Dict[Any, Any] = {}
    for key, value in d.items():
        if key in keys:
            selected[key] = value
    return selected
def fill_storage(storage: Storage, data: Dict[str, Any]) -> None:
    """Insert the objects from *data* into *storage*.

    *data* is a mapping from object type name to a list of objects
    (presumably as produced by load_repo_data — confirm with callers); it is
    handed to the swh journal replayer, which performs the actual inserts.
    """
    process_replay_objects(data, storage=storage)
class SynthRelation(TypedDict): | class SynthRelation(TypedDict): | ||||
prefix: Optional[str] | prefix: Optional[str] | ||||
path: str | path: str | ||||
src: Sha1Git | src: Sha1Git | ||||
dst: Sha1Git | dst: Sha1Git | ||||
▲ Show 20 Lines • Show All 56 Lines • ▼ Show 20 Lines | for m in (regex.match(line) for line in fobj): | ||||
if current_rev: | if current_rev: | ||||
yield _mk_synth_rev(current_rev) | yield _mk_synth_rev(current_rev) | ||||
current_rev.clear() | current_rev.clear() | ||||
current_rev.append(d) | current_rev.append(d) | ||||
if current_rev: | if current_rev: | ||||
yield _mk_synth_rev(current_rev) | yield _mk_synth_rev(current_rev) | ||||
def _mk_synth_rev(synth_rev) -> SynthRevision: | def _mk_synth_rev(synth_rev: List[Dict[str, str]]) -> SynthRevision: | ||||
assert synth_rev[0]["type"] == "R" | assert synth_rev[0]["type"] == "R" | ||||
rev = SynthRevision( | rev = SynthRevision( | ||||
sha1=hash_to_bytes(synth_rev[0]["sha1"]), | sha1=hash_to_bytes(synth_rev[0]["sha1"]), | ||||
date=float(synth_rev[0]["ts"]), | date=float(synth_rev[0]["ts"]), | ||||
msg=synth_rev[0]["revname"], | msg=synth_rev[0]["revname"], | ||||
R_C=[], | R_C=[], | ||||
R_D=[], | R_D=[], | ||||
D_C=[], | D_C=[], | ||||
▲ Show 20 Lines • Show All 42 Lines • Show Last 20 Lines |