diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -924,6 +924,7 @@ def origin_visit_get_latest( self, origin: str, + type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[Dict[str, Any]]: @@ -933,6 +934,8 @@ for row in rows: visit = self._format_origin_visit_row(row) updated_visit = self._origin_visit_apply_last_status(visit) + if type is not None and updated_visit["type"] != type: + continue if allowed_statuses and updated_visit["status"] not in allowed_statuses: continue if require_snapshot and updated_visit["snapshot"] is None: diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -6,7 +6,7 @@ import datetime import random import select -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple from swh.core.db import BaseDb from swh.core.db.db_utils import stored_procedure, jsonize @@ -671,13 +671,19 @@ return bool(cur.fetchone()) def origin_visit_get_latest( - self, origin_id: str, allowed_statuses=None, require_snapshot=False, cur=None + self, + origin_id: str, + type: Optional[str], + allowed_statuses: Optional[Iterable[str]] = None, + require_snapshot: bool = False, + cur=None, ): """Retrieve the most recent origin_visit of the given origin, with optional filters. Args: origin_id: the origin concerned + type: Optional visit type to filter on allowed_statuses: the visit statuses allowed for the returned visit require_snapshot (bool): If True, only a visit with a known snapshot will be returned. @@ -697,6 +703,10 @@ query_parts.append("WHERE o.url = %s") query_params: List[Any] = [origin_id] + if type is not None: + query_parts.append("AND ov.type = %s") + query_params.append(type) + if require_snapshot: query_parts.append("AND ovs.snapshot is not null") diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -918,6 +918,7 @@ def origin_visit_get_latest( self, origin: str, + type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[Dict[str, Any]]: @@ -932,6 +933,8 @@ if visit is not None ] + if type is not None: + visits = [visit for visit in visits if visit.type == type] if allowed_statuses is not None: visits = [visit for visit in visits if visit.status in allowed_statuses] if require_snapshot: diff --git a/swh/storage/interface.py b/swh/storage/interface.py --- a/swh/storage/interface.py +++ b/swh/storage/interface.py @@ -866,6 +866,7 @@ def origin_visit_get_latest( self, origin: str, + type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, ) -> Optional[Dict[str, Any]]: @@ -875,6 +876,7 @@ Args: origin: origin URL + type: Optional visit type to filter on allowed_statuses: list of visit statuses considered to find the latest visit. For instance, ``allowed_statuses=['full']`` will only consider visits that diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -984,6 +984,7 @@ def origin_visit_get_latest( self, origin: str, + type: Optional[str] = None, allowed_statuses: Optional[List[str]] = None, require_snapshot: bool = False, db=None, @@ -991,6 +992,7 @@ ) -> Optional[Dict[str, Any]]: row = db.origin_visit_get_latest( origin, + type=type, allowed_statuses=allowed_statuses, require_snapshot=require_snapshot, cur=cur, diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1784,10 +1784,20 @@ ) assert origin2.url != origin1.url assert origin_visit2 + assert origin_visit2["type"] == data.type_visit2 assert origin_visit2["status"] == "ongoing" assert origin_visit2["snapshot"] is None assert origin_visit2["metadata"] == {"intrinsic": "something"} + assert data.type_visit1 != data.type_visit2 + + origin_visit = swh_storage.origin_visit_get_latest( + origin2.url, + require_snapshot=False, + type=data.type_visit1, # wrong type will make the visit not found + ) + assert origin_visit is None, "Visit should not be found since wrong type" + actual_objects = list(swh_storage.journal_writer.journal.objects) expected_origins = [origin1, origin2]