diff --git a/requirements-swh.txt b/requirements-swh.txt
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,3 +1,3 @@
 swh.core[db,http] >= 0.0.94
-swh.model >= 0.0.63
+swh.model >= 0.0.66
 swh.objstorage >= 0.0.40
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -14,7 +14,7 @@
 
 from collections import defaultdict
 from datetime import timedelta
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import attr
 
@@ -27,6 +27,7 @@
     Release,
     Snapshot,
     OriginVisit,
+    OriginVisitStatus,
     Origin,
     SHA1_SIZE,
 )
@@ -64,6 +65,7 @@
         self._origins_by_id = []
         self._origins_by_sha1 = {}
         self._origin_visits = {}
+        self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
         self._persons = []
         self._origin_metadata = defaultdict(list)
         self._tools = {}
@@ -499,7 +501,11 @@
             self._origin_visits[origin_url]
         ):
             return None
-        snapshot_id = self._origin_visits[origin_url][visit - 1].snapshot
+
+        visit_update = self._origin_visit_get_updated(origin_url, visit)
+        if visit_update is None:  # padding slot left by origin_visit_upsert
+            return None
+        snapshot_id = visit_update.snapshot
         if snapshot_id:
             return self.snapshot_get(snapshot_id)
         else:
@@ -666,19 +672,21 @@
         else:
             origins = [orig for orig in origins if url_pattern in orig["url"]]
         if with_visit:
-            origins = [
-                orig
-                for orig in origins
-                if len(self._origin_visits[orig["url"]]) > 0
-                and set(
-                    ov.snapshot
+            filtered_origins = []
+            for orig in origins:
+                visits = (
+                    self._origin_visit_get_updated(ov.origin, ov.visit)
                     for ov in self._origin_visits[orig["url"]]
-                    if ov.snapshot
+                    if ov is not None  # skip padding slots left by upsert
                 )
-                & set(self._snapshots)
-            ]
+                for ov in visits:
+                    if ov.snapshot and ov.snapshot in self._snapshots:
+                        filtered_origins.append(orig)
+                        break
+        else:
+            filtered_origins = origins
 
-        return origins[offset : offset + limit]
+        return filtered_origins[offset : offset + limit]
 
     def origin_count(self, url_pattern, regexp=False, with_visit=False):
         return len(
@@ -730,19 +738,33 @@
         # visit ids are in the range [1, +inf[
         visit_id = len(self._origin_visits[origin_url]) + 1
         status = "ongoing"
-        visit = OriginVisit(
-            origin=origin_url,
-            date=date,
-            type=type,
-            status=status,
-            snapshot=None,
-            metadata=None,
-            visit=visit_id,
-        )
+        with convert_validation_exceptions():
+            visit = OriginVisit(
+                origin=origin_url,
+                date=date,
+                type=type,
+                # TODO: Remove when we remove those fields from the model
+                status=status,
+                snapshot=None,
+                metadata=None,
+                visit=visit_id,
+            )
         self._origin_visits[origin_url].append(visit)
-        visit = visit
+        assert visit.visit is not None
+        visit_key = (origin_url, visit.visit)
 
-        self._objects[(origin_url, visit.visit)].append(("origin_visit", None))
+        with convert_validation_exceptions():
+            visit_update = OriginVisitStatus(
+                origin=origin_url,
+                visit=visit_id,
+                date=date,
+                status=status,
+                snapshot=None,
+                metadata=None,
+            )
+        self._origin_visit_statuses[visit_key] = [visit_update]
+
+        self._objects[visit_key].append(("origin_visit", None))
 
         self.journal_writer.origin_visit_add(visit)
 
@@ -767,16 +789,28 @@
         except IndexError:
             raise StorageArgumentException("Unknown visit_id for this origin") from None
 
-        updates: Dict[str, Any] = {"status": status}
-        if metadata:
-            updates["metadata"] = metadata
-        if snapshot:
-            updates["snapshot"] = snapshot
+        # Retrieve the previous visit update
+        assert visit.visit is not None
+        visit_key = (origin_url, visit.visit)
+
+        last_visit_update = max(
+            self._origin_visit_statuses[visit_key], key=lambda v: v.date
+        )
 
         with convert_validation_exceptions():
-            visit = attr.evolve(visit, **updates)
+            visit_update = OriginVisitStatus(
+                origin=origin_url,
+                visit=visit_id,
+                date=date or now(),
+                status=status,
+                snapshot=snapshot or last_visit_update.snapshot,
+                metadata=metadata or last_visit_update.metadata,
+            )
+        self._origin_visit_statuses[visit_key].append(visit_update)
 
-        self.journal_writer.origin_visit_update(visit)
+        self.journal_writer.origin_visit_update(
+            self._origin_visit_get_updated(origin_url, visit_id)
+        )
 
         self._origin_visits[origin_url][visit_id - 1] = visit
 
@@ -787,28 +821,64 @@
 
         self.journal_writer.origin_visit_upsert(visits)
 
+        date = now()
+
         for visit in visits:
-            visit_id = visit.visit
+            assert visit.visit is not None
             origin_url = visit.origin
+            origin = self.origin_get({"url": origin_url})
+
+            if not origin:  # Cannot add a visit without an origin
+                raise StorageArgumentException("Unknown origin %s" % origin_url)
+
+            # visit ids are in the range [1, +inf[
+            visit_key = (origin_url, visit.visit)
+
+            with convert_validation_exceptions():
+                visit_update = OriginVisitStatus(
+                    origin=origin_url,
+                    visit=visit.visit,
+                    date=date,
+                    status=visit.status,
+                    snapshot=visit.snapshot,
+                    metadata=visit.metadata,
+                )
 
-            with convert_validation_exceptions():
-                visit = attr.evolve(visit, origin=origin_url)
-
-            self._objects[(origin_url, visit_id)].append(("origin_visit", None))
-
-            if visit_id:
-                while len(self._origin_visits[origin_url]) <= visit_id:
+            self._origin_visit_statuses.setdefault(visit_key, [])
+            while len(self._origin_visits[origin_url]) <= visit.visit:
                 self._origin_visits[origin_url].append(None)
 
-                self._origin_visits[origin_url][visit_id - 1] = visit
+            self._origin_visits[origin_url][visit.visit - 1] = visit
+            self._origin_visit_statuses[visit_key].append(visit_update)
 
-    def _convert_visit(self, visit):
-        if visit is None:
-            return
+            self._objects[visit_key].append(("origin_visit", None))
 
-        visit = visit.to_dict()
+    def _origin_visit_get_updated(
+        self, origin: str, visit_id: int
+    ) -> Optional[OriginVisit]:
+        """Merge an origin visit with its latest status (the one with max date).
 
-        return visit
+        Status fields override the visit's fields, except that the visit's own
+        creation date is kept. Returns None if the slot is an upsert padding None.
+        """
+        assert visit_id >= 1
+        visit = self._origin_visits[origin][visit_id - 1]
+        if visit is None:
+            return None
+        visit_key = (origin, visit_id)
+
+        visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date)
+
+        return OriginVisit.from_dict(
+            {
+                # default to the values in visit
+                **visit.to_dict(),
+                # override with the last update
+                **visit_update.to_dict(),
+                # but keep the date of the creation of the origin visit
+                "date": visit.date,
+            }
+        )
 
     def origin_visit_get(
         self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
@@ -825,7 +895,9 @@
                     continue
 
                 visit_id = visit.visit
-                yield self._convert_visit(self._origin_visits[origin_url][visit_id - 1])
+                visit_update = self._origin_visit_get_updated(origin_url, visit_id)
+                assert visit_update is not None
+                yield visit_update.to_dict()
 
     def origin_visit_find_by_date(
         self, origin: str, visit_date: datetime.datetime
@@ -834,7 +906,9 @@
         if origin_url in self._origin_visits:
             visits = self._origin_visits[origin_url]
             visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit))
-            return self._convert_visit(visit)
+            visit_update = self._origin_visit_get_updated(origin_url, visit.visit)
+            assert visit_update is not None
+            return visit_update.to_dict()
         return None
 
     def origin_visit_get_by(self, origin: str, visit: int) -> Optional[Dict[str, Any]]:
@@ -842,7 +916,9 @@
         if origin_url in self._origin_visits and visit <= len(
             self._origin_visits[origin_url]
         ):
-            return self._convert_visit(self._origin_visits[origin_url][visit - 1])
+            visit_update = self._origin_visit_get_updated(origin_url, visit)
+            if visit_update is not None:  # None for upsert padding slots
+                return visit_update.to_dict()
         return None
 
     def origin_visit_get_latest(
@@ -855,13 +931,21 @@
         if not ori:
             return None
         visits = self._origin_visits[ori.url]
+        visits = [
+            self._origin_visit_get_updated(visit.origin, visit.visit)
+            for visit in visits
+            if visit is not None
+        ]
+
         if allowed_statuses is not None:
             visits = [visit for visit in visits if visit.status in allowed_statuses]
         if require_snapshot:
             visits = [visit for visit in visits if visit.snapshot]
 
         visit = max(visits, key=lambda v: (v.date, v.visit), default=None)
-        return self._convert_visit(visit)
+        if visit is None:
+            return None
+        return visit.to_dict()
 
     def _select_random_origin_visit_by_type(self, type: str) -> str:
         while True:
@@ -877,8 +961,12 @@
         back_in_the_day = now() - timedelta(weeks=12)  # 3 months back
         # This should be enough for tests
         for visit in random_origin_visits:
-            if visit.date > back_in_the_day and visit.status == "full":
-                return visit.to_dict()
+            if visit is None:  # padding slot left by origin_visit_upsert
+                continue
+            updated_visit = self._origin_visit_get_updated(url, visit.visit)
+            assert updated_visit is not None
+            if updated_visit.date > back_in_the_day and updated_visit.status == "full":
+                return updated_visit.to_dict()
         else:
             return None
 