diff --git a/pytest.ini b/pytest.ini
index d8fe211..5270638 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,4 +1,9 @@
 [pytest]
 norecursedirs = docs .*
 
 postgresql_postgres_options = -N 500
+
+markers =
+    kafka
+    grpc
+    rabbitmq
diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py
index 018b30d..860364a 100644
--- a/swh/provenance/tests/test_archive_interface.py
+++ b/swh/provenance/tests/test_archive_interface.py
@@ -1,275 +1,277 @@
 # Copyright (C) 2021-2022  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from collections import Counter
 from operator import itemgetter
 from typing import Any
 from typing import Counter as TCounter
 from typing import Dict, Iterable, List, Set, Tuple, Type, Union
 
 import pytest
 
 from swh.core.db import BaseDb
 from swh.model.model import (
     SWH_MODEL_OBJECT_TYPES,
     BaseModel,
     Content,
     Directory,
     DirectoryEntry,
     ObjectType,
     Origin,
     OriginVisitStatus,
     Release,
     Revision,
     Sha1Git,
     Snapshot,
     SnapshotBranch,
     TargetType,
 )
 from swh.model.swhids import CoreSWHID, ExtendedObjectType, ExtendedSWHID
 from swh.provenance.archive import ArchiveInterface
 from swh.provenance.archive.multiplexer import ArchiveMultiplexed
 from swh.provenance.archive.postgresql import ArchivePostgreSQL
 from swh.provenance.archive.storage import ArchiveStorage
 from swh.provenance.archive.swhgraph import ArchiveGraph
 from swh.storage.interface import StorageInterface
 from swh.storage.postgresql.storage import Storage
 
 from .utils import fill_storage, grpc_server, load_repo_data
 
 
 class ArchiveNoop:
     storage: StorageInterface
 
     def directory_ls(self, id: Sha1Git, minsize: int = 0) -> Iterable[Dict[str, Any]]:
         return []
 
     def revision_get_some_outbound_edges(
         self, revision_id: Sha1Git
     ) -> Iterable[Tuple[Sha1Git, Sha1Git]]:
         return []
 
     def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
         return []
 
 
 def check_directory_ls(
     reference: ArchiveInterface, archive: ArchiveInterface, data: Dict[str, List[dict]]
 ) -> None:
     for directory in data["directory"]:
         entries_ref = sorted(
             reference.directory_ls(directory["id"]), key=itemgetter("name")
         )
         entries = sorted(archive.directory_ls(directory["id"]), key=itemgetter("name"))
         assert entries_ref == entries
 
 
 def check_revision_get_some_outbound_edges(
     reference: ArchiveInterface, archive: ArchiveInterface, data: Dict[str, List[dict]]
 ) -> None:
     for revision in data["revision"]:
         parents_ref: TCounter[Tuple[Sha1Git, Sha1Git]] = Counter(
             reference.revision_get_some_outbound_edges(revision["id"])
         )
         parents: TCounter[Tuple[Sha1Git, Sha1Git]] = Counter(
             archive.revision_get_some_outbound_edges(revision["id"])
         )
 
         # Check that all the reference outbound edges are included in the other
         # archives's outbound edges
         assert set(parents_ref.items()) <= set(parents.items())
 
 
 def check_snapshot_get_heads(
     reference: ArchiveInterface, archive: ArchiveInterface, data: Dict[str, List[dict]]
 ) -> None:
     for snapshot in data["snapshot"]:
         heads_ref: TCounter[Sha1Git] = Counter(
             reference.snapshot_get_heads(snapshot["id"])
         )
         heads: TCounter[Sha1Git] = Counter(archive.snapshot_get_heads(snapshot["id"]))
         assert heads_ref == heads
 
 
 def get_object_class(object_type: str) -> Type[BaseModel]:
     return SWH_MODEL_OBJECT_TYPES[object_type]
 
 
 def data_to_model(data: Dict[str, List[dict]]) -> Dict[str, List[BaseModel]]:
     model: Dict[str, List[BaseModel]] = {}
     for object_type, objects in data.items():
         for object in objects:
             model.setdefault(object_type, []).append(
                 get_object_class(object_type).from_dict(object)
             )
     return model
 
 
 def add_link(
     edges: Set[
         Tuple[
             Union[CoreSWHID, ExtendedSWHID, str], Union[CoreSWHID, ExtendedSWHID, str]
         ]
     ],
     src_obj: Union[Content, Directory, Origin, Release, Revision, Snapshot],
     dst_id: bytes,
     dst_type: ExtendedObjectType,
 ) -> None:
     swhid = ExtendedSWHID(object_type=dst_type, object_id=dst_id)
     edges.add((src_obj.swhid(), swhid))
 
 
 def get_graph_data(
     data: Dict[str, List[dict]]
 ) -> Tuple[
     List[Union[CoreSWHID, ExtendedSWHID, str]],
     List[
         Tuple[
             Union[CoreSWHID, ExtendedSWHID, str], Union[CoreSWHID, ExtendedSWHID, str]
         ]
     ],
 ]:
     nodes: Set[Union[CoreSWHID, ExtendedSWHID, str]] = set()
     edges: Set[
         Tuple[
             Union[CoreSWHID, ExtendedSWHID, str], Union[CoreSWHID, ExtendedSWHID, str]
         ]
     ] = set()
 
     model = data_to_model(data)
 
     for origin in model["origin"]:
         assert isinstance(origin, Origin)
         nodes.add(origin.swhid())
         for status in model["origin_visit_status"]:
             assert isinstance(status, OriginVisitStatus)
             if status.origin == origin.url and status.snapshot is not None:
                 add_link(edges, origin, status.snapshot, ExtendedObjectType.SNAPSHOT)
 
     for snapshot in model["snapshot"]:
         assert isinstance(snapshot, Snapshot)
         nodes.add(snapshot.swhid())
         for branch in snapshot.branches.values():
             assert isinstance(branch, SnapshotBranch)
             if branch.target_type in [TargetType.RELEASE, TargetType.REVISION]:
                 target_type = (
                     ExtendedObjectType.RELEASE
                     if branch.target_type == TargetType.RELEASE
                     else ExtendedObjectType.REVISION
                 )
                 add_link(edges, snapshot, branch.target, target_type)
 
     for revision in model["revision"]:
         assert isinstance(revision, Revision)
         nodes.add(revision.swhid())
         # root directory
         add_link(edges, revision, revision.directory, ExtendedObjectType.DIRECTORY)
         # parent
         for parent in revision.parents:
             add_link(edges, revision, parent, ExtendedObjectType.REVISION)
 
     dir_entry_types = {
         "file": ExtendedObjectType.CONTENT,
         "dir": ExtendedObjectType.DIRECTORY,
         "rev": ExtendedObjectType.REVISION,
     }
     for directory in model["directory"]:
         assert isinstance(directory, Directory)
         nodes.add(directory.swhid())
         for entry in directory.entries:
             assert isinstance(entry, DirectoryEntry)
             add_link(edges, directory, entry.target, dir_entry_types[entry.type])
 
     for content in model["content"]:
         assert isinstance(content, Content)
         nodes.add(content.swhid())
 
     object_type = {
         ObjectType.CONTENT: ExtendedObjectType.CONTENT,
         ObjectType.DIRECTORY: ExtendedObjectType.DIRECTORY,
         ObjectType.REVISION: ExtendedObjectType.REVISION,
         ObjectType.RELEASE: ExtendedObjectType.RELEASE,
         ObjectType.SNAPSHOT: ExtendedObjectType.SNAPSHOT,
     }
     for release in model["release"]:
         assert isinstance(release, Release)
         nodes.add(release.swhid())
         if release.target is not None:
             add_link(edges, release, release.target, object_type[release.target_type])
 
     return list(nodes), list(edges)
 
 
 @pytest.mark.parametrize(
     "repo",
     ("cmdbts2", "out-of-order", "with-merges"),
 )
 def test_archive_interface(repo: str, archive: ArchiveInterface) -> None:
     # read data/README.md for more details on how these datasets are generated
     data = load_repo_data(repo)
     fill_storage(archive.storage, data)
 
     # test against ArchiveStorage
     archive_api = ArchiveStorage(archive.storage)
     check_directory_ls(archive, archive_api, data)
     check_revision_get_some_outbound_edges(archive, archive_api, data)
     check_snapshot_get_heads(archive, archive_api, data)
 
     # test against ArchivePostgreSQL
     assert isinstance(archive.storage, Storage)
     dsn = archive.storage.get_db().conn.dsn
     with BaseDb.connect(dsn).conn as conn:
         BaseDb.adapt_conn(conn)
         archive_direct = ArchivePostgreSQL(conn)
         check_directory_ls(archive, archive_direct, data)
         check_revision_get_some_outbound_edges(archive, archive_direct, data)
         check_snapshot_get_heads(archive, archive_direct, data)
 
 
+@pytest.mark.grpc
 @pytest.mark.parametrize(
     "repo",
     ("cmdbts2", "out-of-order", "with-merges"),
 )
 def test_archive_graph(repo: str, archive: ArchiveInterface) -> None:
     data = load_repo_data(repo)
     fill_storage(archive.storage, data)
 
     with grpc_server(repo) as url:
         # test against ArchiveGraph
         archive_graph = ArchiveGraph(url, archive.storage)
         with pytest.raises(NotImplementedError):
             check_directory_ls(archive, archive_graph, data)
         check_revision_get_some_outbound_edges(archive, archive_graph, data)
         check_snapshot_get_heads(archive, archive_graph, data)
 
 
+@pytest.mark.grpc
 @pytest.mark.parametrize(
     "repo",
     ("cmdbts2", "out-of-order", "with-merges"),
 )
 def test_archive_multiplexed(repo: str, archive: ArchiveInterface) -> None:
     # read data/README.md for more details on how these datasets are generated
     data = load_repo_data(repo)
     fill_storage(archive.storage, data)
 
     # test against ArchiveMultiplexer
     with grpc_server(repo) as url:
         archive_api = ArchiveStorage(archive.storage)
         archive_graph = ArchiveGraph(url, archive.storage)
         archive_multiplexed = ArchiveMultiplexed(
             [("noop", ArchiveNoop()), ("graph", archive_graph), ("api", archive_api)]
         )
         check_directory_ls(archive, archive_multiplexed, data)
         check_revision_get_some_outbound_edges(archive, archive_multiplexed, data)
         check_snapshot_get_heads(archive, archive_multiplexed, data)
 
 
 def test_noop_multiplexer():
     archive = ArchiveMultiplexed([("noop", ArchiveNoop())])
 
     assert not archive.directory_ls(Sha1Git(b"abcd"))
     assert not archive.revision_get_some_outbound_edges(Sha1Git(b"abcd"))
     assert not archive.snapshot_get_heads(Sha1Git(b"abcd"))
diff --git a/swh/provenance/tests/utils.py b/swh/provenance/tests/utils.py
index 21e624a..e43bd4f 100644
--- a/swh/provenance/tests/utils.py
+++ b/swh/provenance/tests/utils.py
@@ -1,128 +1,115 @@
 # Copyright (C) 2022  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from contextlib import contextmanager
 from datetime import datetime
 import logging
-import multiprocessing
 from os import path
 from pathlib import Path
+import socket
 import tempfile
+import time
 from typing import Any, Dict, List, Optional
 
-from aiohttp.test_utils import TestClient, TestServer, loop_context
 from click.testing import CliRunner, Result
 import msgpack
 from yaml import safe_dump
 
-from swh.graph.http_rpc_server import make_app
+from swh.graph.grpc_server import spawn_java_grpc_server, stop_java_grpc_server
 from swh.journal.serializers import msgpack_ext_hook
 from swh.model.model import BaseModel, TimestampWithTimezone
 from swh.provenance.cli import cli
 from swh.storage.interface import StorageInterface
 from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects
 
+logger = logging.getLogger(__name__)
+
 
 def invoke(
     args: List[str], config: Optional[Dict] = None, catch_exceptions: bool = False
 ) -> Result:
     """Invoke swh journal subcommands"""
     runner = CliRunner()
     with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd:
         if config is not None:
             safe_dump(config, config_fd)
             config_fd.seek(0)
             args = ["-C" + config_fd.name] + args
         result = runner.invoke(cli, args, obj={"log_level": logging.DEBUG}, env=None)
         if not catch_exceptions and result.exception:
             print(result.output)
             raise result.exception
         return result
 
 
 def fill_storage(storage: StorageInterface, data: Dict[str, List[dict]]) -> None:
     objects = {
         objtype: [objs_from_dict(objtype, d) for d in dicts]
         for objtype, dicts in data.items()
     }
     process_replay_objects(objects, storage=storage)
 
 
 def get_datafile(fname: str) -> str:
     return path.join(path.dirname(__file__), "data", fname)
 
 
 # TODO: this should return Dict[str, List[BaseModel]] directly, but it requires
 # refactoring several tests
 def load_repo_data(repo: str) -> Dict[str, List[dict]]:
     data: Dict[str, List[dict]] = {}
     with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj:
         unpacker = msgpack.Unpacker(
             fobj,
             raw=False,
             ext_hook=msgpack_ext_hook,
             strict_map_key=False,
             timestamp=3,  # convert Timestamp in datetime objects (tz UTC)
         )
         for msg in unpacker:
             if len(msg) == 2:
                 # old format
                 objtype, objd = msg
             else:
                 # now we should have a triplet (type, key, value)
                 objtype, _, objd = msg
             data.setdefault(objtype, []).append(objd)
     return data
 
 
 def objs_from_dict(object_type: str, dict_repr: dict) -> BaseModel:
     if object_type in OBJECT_FIXERS:
         dict_repr = OBJECT_FIXERS[object_type](dict_repr)
     obj = OBJECT_CONVERTERS[object_type](dict_repr)
     return obj
 
 
 def ts2dt(ts: Dict[str, Any]) -> datetime:
     return TimestampWithTimezone.from_dict(ts).to_datetime()
 
 
-def run_grpc_server(queue, dataset_path):
-    try:
-        config = {
-            "graph": {
-                "cls": "local",
-                "grpc_server": {"path": dataset_path},
-                "http_rpc_server": {"debug": True},
-            }
-        }
-        with loop_context() as loop:
-            app = make_app(config=config)
-            client = TestClient(TestServer(app), loop=loop)
-            loop.run_until_complete(client.start_server())
-            url = client.make_url("/graph/")
-            queue.put((url, app["rpc_url"]))
-            loop.run_forever()
-    except Exception as e:
-        queue.put(e)
-
-
 @contextmanager
 def grpc_server(dataset):
     dataset_path = (
         Path(__file__).parents[0] / "data/swhgraph" / dataset / "compressed/example"
     )
-    queue = multiprocessing.Queue()
-    server = multiprocessing.Process(
-        target=run_grpc_server, kwargs={"queue": queue, "dataset_path": dataset_path}
-    )
-    server.start()
-    res = queue.get()
-    if isinstance(res, Exception):
-        raise res
-    grpc_url = res[1]
+    server, port = spawn_java_grpc_server(path=dataset_path)
+    logging.debug("Spawned GRPC server on port %s", port)
     try:
-        yield grpc_url
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        logging.debug("Waiting for the TCP socket localhost:%s...", port)
+        for i in range(50):
+            if sock.connect_ex(("localhost", port)) == 0:
+                sock.close()
+                break
+            time.sleep(0.1)
+        else:
+            raise EnvironmentError(
+                f"Cannot connect to the GRPC server on localhost:{port}"
+            )
+        logger.debug("Connection to localhost:%s OK", port)
+        yield f"localhost:{port}"
     finally:
-        server.terminate()
+        stop_java_grpc_server(server)