diff --git a/swh/search/journal_client.py b/swh/search/journal_client.py
--- a/swh/search/journal_client.py
+++ b/swh/search/journal_client.py
@@ -4,9 +4,11 @@
 # See top-level LICENSE file for more information
 
 import logging
+from typing import Dict
 
 from swh.model.model import TargetType
 from swh.storage.algos.snapshot import snapshot_get_all_branches
+from swh.storage.interface import StorageInterface
 
 EXPECTED_MESSAGE_TYPES = {
     "origin",
@@ -16,11 +18,24 @@
 }
 
 
-def fetch_last_revision_release_date(snapshot_id, storage):
+def fetch_last_revision_release_date(
+    snapshot_id: bytes, storage: StorageInterface
+) -> Dict[str, str]:
+    """Return the latest revision/release dates targeted by a snapshot's branches.
+
+    Dates are ISO 8601 strings under ``last_revision_date`` and
+    ``last_release_date``; a key is omitted when no dated target of that kind
+    exists, so unknown or empty snapshots yield ``{}``.
+    """
     if not snapshot_id:
         return {}
 
-    branches = snapshot_get_all_branches(storage, snapshot_id).branches.values()
+    snapshot = snapshot_get_all_branches(storage, snapshot_id)
+    if not snapshot:
+        return {}
+
+    # Dangling branches are stored as None and have no target to follow.
+    branches = [branch for branch in snapshot.branches.values() if branch]
 
     tip_revision_ids = []
     tip_release_ids = []
@@ -34,16 +49,22 @@
     revision_datetimes = [
         revision.date.to_datetime()
         for revision in storage.revision_get(tip_revision_ids)
+        if revision and revision.date
     ]
 
     release_datetimes = [
-        release.date.to_datetime() for release in storage.release_get(tip_release_ids)
+        release.date.to_datetime()
+        for release in storage.release_get(tip_release_ids)
+        if release and release.date
     ]
 
-    return {
-        "last_revision_date": max(revision_datetimes).isoformat(),
-        "last_release_date": max(release_datetimes).isoformat(),
-    }
+    ret: Dict[str, str] = {}
+    if revision_datetimes:
+        ret["last_revision_date"] = max(revision_datetimes).isoformat()
+    if release_datetimes:
+        ret["last_release_date"] = max(release_datetimes).isoformat()
+
+    return ret
 
 
 def process_journal_objects(messages, *, search, storage=None):
diff --git a/swh/search/tests/test_journal_client.py b/swh/search/tests/test_journal_client.py
--- a/swh/search/tests/test_journal_client.py
+++ b/swh/search/tests/test_journal_client.py
@@ -7,6 +7,8 @@
 import functools
 from unittest.mock import MagicMock
 
+import pytest
+
 from swh.model.model import (
     ObjectType,
     Person,
@@ -20,7 +22,10 @@
     TimestampWithTimezone,
     hash_to_bytes,
 )
-from swh.search.journal_client import process_journal_objects
+from swh.search.journal_client import (
+    fetch_last_revision_release_date,
+    process_journal_objects,
+)
 from swh.storage import get_storage
 
 DATES = [
@@ -150,9 +155,34 @@
             ),
         },
     ),
+    Snapshot(
+        branches={
+            b"target/revision1": SnapshotBranch(
+                target_type=TargetType.REVISION, target=REVISIONS[0].id,
+            )
+        },
+    ),
+    Snapshot(
+        branches={
+            b"target/release1": SnapshotBranch(
+                target_type=TargetType.RELEASE, target=RELEASES[0].id
+            )
+        },
+    ),
+    Snapshot(branches={}),
 ]
 
 
+@pytest.fixture
+def storage():
+    storage = get_storage("memory")
+
+    storage.revision_add(REVISIONS)
+    storage.release_add(RELEASES)
+    storage.snapshot_add(SNAPSHOTS)
+    return storage
+
+
 def test_journal_client_origin_from_journal():
     search_mock = MagicMock()
 
@@ -182,13 +212,8 @@
     )
 
 
-def test_journal_client_origin_visit_status_from_journal():
+def test_journal_client_origin_visit_status_from_journal(storage):
     search_mock = MagicMock()
-    storage = get_storage("memory")
-
-    storage.revision_add(REVISIONS)
-    storage.release_add(RELEASES)
-    storage.snapshot_add(SNAPSHOTS)
 
     worker_fn = functools.partial(
         process_journal_objects, search=search_mock, storage=storage
@@ -274,3 +299,30 @@
             },
         ]
     )
+
+
+def test_fetch_last_revision_release_date(storage):
+    for snapshot in SNAPSHOTS:
+        dates = fetch_last_revision_release_date(snapshot.id, storage)
+        assert set(dates) <= {"last_revision_date", "last_release_date"}
+
+    revision_only, release_only, empty = SNAPSHOTS[-3:]
+
+    revision_dates = fetch_last_revision_release_date(revision_only.id, storage)
+    assert set(revision_dates) == {"last_revision_date"}
+
+    release_dates = fetch_last_revision_release_date(release_only.id, storage)
+    assert set(release_dates) == {"last_release_date"}
+
+    assert fetch_last_revision_release_date(empty.id, storage) == {}
+
+
+def test_fetch_last_revision_release_date_unknown_or_dangling(storage):
+    assert fetch_last_revision_release_date(b"", storage) == {}
+    # well-formed id that was never added to the storage
+    assert fetch_last_revision_release_date(b"\x00" * 20, storage) == {}
+
+    # dangling branches point to nothing and must be skipped, not dereferenced
+    dangling = Snapshot(branches={b"refs/heads/dangling": None})
+    storage.snapshot_add([dangling])
+    assert fetch_last_revision_release_date(dangling.id, storage) == {}