diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -14,13 +14,13 @@
 
 from collections import defaultdict
 from datetime import timedelta
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import attr
 
 from swh.model.model import (
     BaseContent, Content, SkippedContent, Directory, Revision,
-    Release, Snapshot, OriginVisit, Origin, SHA1_SIZE
+    Release, Snapshot, OriginVisit, OriginVisitUpdate, Origin, SHA1_SIZE
 )
 from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
 from swh.storage.objstorage import ObjStorage
@@ -56,6 +56,8 @@
         self._origins_by_id = []
         self._origins_by_sha1 = {}
         self._origin_visits = {}
+        self._origin_visit_updates: Dict[
+            Tuple[str, int], List[OriginVisitUpdate]] = {}
         self._persons = []
         self._origin_metadata = defaultdict(list)
         self._tools = {}
@@ -496,7 +498,12 @@
         if origin_url not in self._origins or \
                 visit > len(self._origin_visits[origin_url]):
             return None
-        snapshot_id = self._origin_visits[origin_url][visit-1].snapshot
+
+        visit_update = self._origin_visit_get_updated(origin_url, visit)
+        if visit_update is None:
+            # placeholder slot left by origin_visit_upsert padding
+            return None
+        snapshot_id = visit_update.snapshot
         if snapshot_id:
             return self.snapshot_get(snapshot_id)
         else:
@@ -661,15 +668,19 @@
         else:
             origins = [orig for orig in origins if url_pattern in orig['url']]
         if with_visit:
-            origins = [
-                orig for orig in origins
-                if len(self._origin_visits[orig['url']]) > 0 and
-                set(ov.snapshot
-                    for ov in self._origin_visits[orig['url']]
-                    if ov.snapshot) &
-                set(self._snapshots)]
+            filtered_origins = []
+            for orig in origins:
+                visits = (self._origin_visit_get_updated(ov.origin, ov.visit)
+                          for ov in self._origin_visits[orig['url']]
+                          if ov is not None)
+                for ov in visits:
+                    if ov.snapshot and ov.snapshot in self._snapshots:
+                        filtered_origins.append(orig)
+                        break
+        else:
+            filtered_origins = origins
 
-        return origins[offset:offset+limit]
+        return filtered_origins[offset:offset+limit]
 
     def origin_count(self, url_pattern, regexp=False, with_visit=False):
         return len(self.origin_search(url_pattern, regexp=regexp,
@@ -718,19 +729,33 @@
         # visit ids are in the range [1, +inf[
         visit_id = len(self._origin_visits[origin_url]) + 1
         status = 'ongoing'
-        visit = OriginVisit(
-            origin=origin_url,
-            date=date,
-            type=type,
-            status=status,
-            snapshot=None,
-            metadata=None,
-            visit=visit_id,
-        )
+        with convert_validation_exceptions():
+            visit = OriginVisit(
+                origin=origin_url,
+                date=date,
+                type=type,
+                # TODO: Remove when we remove those fields from the model
+                status=status,
+                snapshot=None,
+                metadata=None,
+                visit=visit_id,
+            )
         self._origin_visits[origin_url].append(visit)
-        visit = visit
+        assert visit.visit is not None
+        visit_key = (origin_url, visit.visit)
 
-        self._objects[(origin_url, visit.visit)].append(
+        with convert_validation_exceptions():
+            visit_update = OriginVisitUpdate(
+                origin=origin_url,
+                visit=visit_id,
+                date=date,
+                status=status,
+                snapshot=None,
+                metadata=None,
+            )
+        self._origin_visit_updates[visit_key] = [visit_update]
+
+        self._objects[visit_key].append(
             ('origin_visit', None))
 
         self.journal_writer.origin_visit_add(visit)
@@ -752,18 +777,26 @@
             raise StorageArgumentException(
                 'Unknown visit_id for this origin') from None
 
-        updates: Dict[str, Any] = {
-            'status': status
-        }
-        if metadata:
-            updates['metadata'] = metadata
-        if snapshot:
-            updates['snapshot'] = snapshot
+        # Retrieve the previous visit update
+        assert visit.visit is not None
+        visit_key = (origin_url, visit.visit)
+
+        last_visit_update = max(
+            self._origin_visit_updates[visit_key], key=lambda v: v.date)
 
         with convert_validation_exceptions():
-            visit = attr.evolve(visit, **updates)
+            visit_update = OriginVisitUpdate(
+                origin=origin_url,
+                visit=visit_id,
+                date=date or now(),
+                status=status,
+                snapshot=snapshot or last_visit_update.snapshot,
+                metadata=metadata or last_visit_update.metadata,
+            )
+        self._origin_visit_updates[visit_key].append(visit_update)
 
-        self.journal_writer.origin_visit_update(visit)
+        self.journal_writer.origin_visit_update(
+            self._origin_visit_get_updated(origin_url, visit_id))
 
         self._origin_visits[origin_url][visit_id-1] = visit
 
@@ -775,29 +808,65 @@
 
         self.journal_writer.origin_visit_upsert(visits)
 
+        date = now()
+
         for visit in visits:
-            visit_id = visit.visit
+            assert visit.visit is not None
             origin_url = visit.origin
+            origin = self.origin_get({'url': origin_url})
 
-            with convert_validation_exceptions():
-                visit = attr.evolve(visit, origin=origin_url)
-
-            self._objects[(origin_url, visit_id)].append(
-                ('origin_visit', None))
-
-            if visit_id:
-                while len(self._origin_visits[origin_url]) <= visit_id:
-                    self._origin_visits[origin_url].append(None)
+            if not origin:  # Cannot add a visit without an origin
+                raise StorageArgumentException(
+                    'Unknown origin %s' % origin_url)
+
+            visit_key = (origin_url, visit.visit)
+
+            with convert_validation_exceptions():
+                visit_update = OriginVisitUpdate(
+                    origin=origin_url,
+                    visit=visit.visit,
+                    date=date,
+                    status=visit.status,
+                    snapshot=visit.snapshot,
+                    metadata=visit.metadata,
+                )
+
+            # visit ids are in the range [1, +inf[: pad with None so that
+            # index visit.visit-1 exists, without a trailing placeholder
+            while len(self._origin_visits[origin_url]) < visit.visit:
+                self._origin_visits[origin_url].append(None)
 
-                self._origin_visits[origin_url][visit_id-1] = visit
+            self._origin_visits[origin_url][visit.visit-1] = visit
+            self._origin_visit_updates.setdefault(
+                visit_key, []).append(visit_update)
 
-    def _convert_visit(self, visit):
+            self._objects[visit_key].append(
+                ('origin_visit', None))
+
+    def _origin_visit_get_updated(
+            self, origin: str, visit_id: int) -> Optional[OriginVisit]:
+        """Return visit `visit_id` of `origin` with its latest update applied.
+
+        Fields carried by the most recent OriginVisitUpdate (by date)
+        override the visit's own; the visit's creation date is kept.
+        Returns None if the slot is a placeholder (None)."""
+        assert visit_id >= 1
+        visit = self._origin_visits[origin][visit_id-1]
         if visit is None:
-            return
+            return None
+        visit_key = (origin, visit_id)
 
-        visit = visit.to_dict()
+        visit_update = max(
+            self._origin_visit_updates[visit_key], key=lambda v: v.date)
 
-        return visit
+        return OriginVisit.from_dict({
+            # default to the values in visit
+            **visit.to_dict(),
+            # override with the last update
+            **visit_update.to_dict(),
+            # but keep the date of the creation of the origin visit
+            'date': visit.date
+        })
 
     def origin_visit_get(
             self, origin: str, last_visit: Optional[int] = None,
@@ -814,8 +883,10 @@
                 continue
             visit_id = visit.visit
 
-            yield self._convert_visit(
-                self._origin_visits[origin_url][visit_id-1])
+            visit_update = self._origin_visit_get_updated(
+                origin_url, visit_id)
+            assert visit_update is not None
+            yield visit_update.to_dict()
 
     def origin_visit_find_by_date(
             self, origin: str,
@@ -826,7 +897,10 @@
             visit = min(
                 visits,
                 key=lambda v: (abs(v.date - visit_date), -v.visit))
-            return self._convert_visit(visit)
+            visit_update = self._origin_visit_get_updated(
+                origin, visit.visit)
+            assert visit_update is not None
+            return visit_update.to_dict()
         return None
 
     def origin_visit_get_by(
@@ -834,8 +908,10 @@
         origin_url = self._get_origin_url(origin)
         if origin_url in self._origin_visits and \
                 visit <= len(self._origin_visits[origin_url]):
-            return self._convert_visit(
-                self._origin_visits[origin_url][visit-1])
+            visit_update = self._origin_visit_get_updated(
+                origin_url, visit)
+            if visit_update is not None:
+                return visit_update.to_dict()
         return None
 
     def origin_visit_get_latest(
@@ -845,6 +921,10 @@
         if not ori:
             return None
         visits = self._origin_visits[ori.url]
+        visits = [self._origin_visit_get_updated(visit.origin, visit.visit)
+                  for visit in visits
+                  if visit is not None]
+
         if allowed_statuses is not None:
             visits = [visit for visit in visits
                       if visit.status in allowed_statuses]
@@ -854,7 +934,9 @@
 
         visit = max(
             visits, key=lambda v: (v.date, v.visit), default=None)
-        return self._convert_visit(visit)
+        if visit is None:
+            return None
+        return visit.to_dict()
 
     def _select_random_origin_visit_by_type(self, type: str) -> str:
         while True:
@@ -870,8 +952,11 @@
         back_in_the_day = now() - timedelta(weeks=12)  # 3 months back
         # This should be enough for tests
         for visit in random_origin_visits:
-            if visit.date > back_in_the_day and visit.status == 'full':
-                return visit.to_dict()
+            updated_visit = self._origin_visit_get_updated(
+                url, visit.visit)
+            if updated_visit.date > back_in_the_day \
+                    and updated_visit.status == 'full':
+                return updated_visit.to_dict()
         else:
             return None
 