diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -6,14 +6,12 @@ import base64 import bisect import collections -import copy import datetime import functools import itertools import random from collections import defaultdict -from datetime import timedelta from typing import ( Any, Callable, @@ -37,9 +35,6 @@ from swh.model.model import ( Content, SkippedContent, - OriginVisit, - OriginVisitStatus, - Origin, MetadataAuthority, MetadataAuthorityType, MetadataFetcher, @@ -56,6 +51,8 @@ DirectoryEntryRow, ObjectCountRow, OriginRow, + OriginVisitRow, + OriginVisitStatusRow, ReleaseRow, RevisionRow, RevisionParentRow, @@ -66,10 +63,8 @@ from swh.storage.interface import ( ListOrder, PagedResult, - VISIT_STATUSES, ) from swh.storage.objstorage import ObjStorage -from swh.storage.utils import now from .converters import origin_url_to_sha1 from .exc import StorageArgumentException @@ -246,6 +241,8 @@ self._snapshots = Table(SnapshotRow) self._snapshot_branches = Table(SnapshotBranchRow) self._origins = Table(OriginRow) + self._origin_visits = Table(OriginVisitRow) + self._origin_visit_statuses = Table(OriginVisitStatusRow) self._stat_counters = defaultdict(int) def increment_counter(self, object_type: str, nb: int): @@ -517,6 +514,109 @@ for (clustering_key, row) in partition.items() ) + def origin_generate_unique_visit_id(self, origin_url: str) -> int: + origin = list(self.origin_get_by_url(origin_url))[0] + visit_id = origin.next_visit_id + origin.next_visit_id += 1 + return visit_id + + ########################## + # 'origin_visit' table + ########################## + + def origin_visit_get( + self, + origin_url: str, + last_visit: Optional[int], + limit: Optional[int], + order: ListOrder, + ) -> Iterable[OriginVisitRow]: + visits = list(self._origin_visits.get_from_partition_key((origin_url,))) + + if last_visit is not None: + if order == ListOrder.ASC: + visits = [v for v in visits if v.visit > last_visit] + else: + visits = [v for v in visits if v.visit < last_visit] + + visits.sort(key=lambda v: v.visit, reverse=order == ListOrder.DESC) + + if limit is not None: + visits = visits[0:limit] + + return visits + + def origin_visit_add_one(self, visit: OriginVisitRow) -> None: + self._origin_visits.insert(visit) + self.increment_counter("origin_visit", 1) + + def origin_visit_get_one( + self, origin_url: str, visit_id: int + ) -> Optional[OriginVisitRow]: + return self._origin_visits.get_from_primary_key((origin_url, visit_id)) + + def origin_visit_get_all(self, origin_url: str) -> Iterable[OriginVisitRow]: + return self._origin_visits.get_from_partition_key((origin_url,)) + + def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]: + """Returns all origin visits in order from this token, + and wraps around the token space.""" + return ( + row + for (token, partition) in self._origin_visits.data.items() + for (clustering_key, row) in partition.items() + ) + + ########################## + # 'origin_visit_status' table + ########################## + + def origin_visit_status_get_range( + self, + origin: str, + visit: int, + date_from: Optional[datetime.datetime], + limit: int, + order: ListOrder, + ) -> Iterable[OriginVisitStatusRow]: + statuses = list(self.origin_visit_status_get(origin, visit)) + + if date_from is not None: + if order == ListOrder.ASC: + statuses = [s for s in statuses if s.date >= date_from] + else: + statuses = [s for s in statuses if s.date <= date_from] + + statuses.sort(key=lambda s: s.date, reverse=order == ListOrder.DESC) + + return statuses[0:limit] + + def origin_visit_status_add_one(self, visit_update: OriginVisitStatusRow) -> None: + self._origin_visit_statuses.insert(visit_update) + self.increment_counter("origin_visit_status", 1) + + def origin_visit_status_get_latest( + self, origin: str, visit: int, + ) -> Optional[OriginVisitStatusRow]: + """Given an origin visit id, return its latest origin_visit_status + + """ + return next(self.origin_visit_status_get(origin, visit), None) + + def origin_visit_status_get( + self, origin: str, visit: int, + ) -> Iterator[OriginVisitStatusRow]: + """Return all origin visit statuses for a given visit + + """ + statuses = [ + s + for s in self._origin_visit_statuses.get_from_partition_key((origin,)) + if s.visit == visit + ] + statuses.sort(key=lambda s: s.date, reverse=True) + return iter(statuses) + class InMemoryStorage(CassandraStorage): _cql_runner: InMemoryCqlRunner # type: ignore @@ -527,8 +627,6 @@ def reset(self): self._cql_runner = InMemoryCqlRunner() - self._origin_visits = {} - self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {} self._persons = {} # {object_type: {id: {authority: [metadata]}}} @@ -566,289 +664,6 @@ def check_config(self, *, check_write: bool) -> bool: return True - def origin_add(self, origins: List[Origin]) -> Dict[str, int]: - for origin in origins: - if origin.url not in self._origin_visits: - self._origin_visits[origin.url] = [] - return super().origin_add(origins) - - def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]: - for visit in visits: - origin = self.origin_get_one(visit.origin) - if not origin: # Cannot add a visit without an origin - raise StorageArgumentException("Unknown origin %s", visit.origin) - - all_visits = [] - for visit in visits: - origin_url = visit.origin - if list(self._cql_runner.origin_get_by_url(origin_url)): - if visit.visit: - self.journal_writer.origin_visit_add([visit]) - while len(self._origin_visits[origin_url]) < visit.visit: - self._origin_visits[origin_url].append(None) - self._origin_visits[origin_url][visit.visit - 1] = visit - else: - # visit ids are in the range [1, +inf[ - visit_id = len(self._origin_visits[origin_url]) + 1 - visit = attr.evolve(visit, visit=visit_id) - self.journal_writer.origin_visit_add([visit]) - self._origin_visits[origin_url].append(visit) - visit_key = (origin_url, visit.visit) - self._objects[visit_key].append(("origin_visit", None)) - assert visit.visit is not None - self._origin_visit_status_add_one( - OriginVisitStatus( - origin=visit.origin, - visit=visit.visit, - date=visit.date, - status="created", - snapshot=None, - ) - ) - all_visits.append(visit) - - self._cql_runner.increment_counter("origin_visit", len(all_visits)) - - return all_visits - - def _origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None: - """Add an origin visit status without checks. If already present, do nothing. - - """ - self.journal_writer.origin_visit_status_add([visit_status]) - visit_key = (visit_status.origin, visit_status.visit) - self._origin_visit_statuses.setdefault(visit_key, []) - visit_statuses = self._origin_visit_statuses[visit_key] - if visit_status not in visit_statuses: - visit_statuses.append(visit_status) - - def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus],) -> None: - # First round to check existence (fail early if any is ko) - for visit_status in visit_statuses: - origin_url = self.origin_get_one(visit_status.origin) - if not origin_url: - raise StorageArgumentException(f"Unknown origin {visit_status.origin}") - - for visit_status in visit_statuses: - self._origin_visit_status_add_one(visit_status) - - def _origin_visit_status_get_latest( - self, origin: str, visit_id: int - ) -> Tuple[OriginVisit, OriginVisitStatus]: - """Return a tuple of OriginVisit, latest associated OriginVisitStatus. - - """ - assert visit_id >= 1 - visit = self._origin_visits[origin][visit_id - 1] - assert visit is not None - visit_key = (origin, visit_id) - - visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date) - return visit, visit_update - - def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]: - """Merge origin visit and latest origin visit status - - """ - visit, visit_update = self._origin_visit_status_get_latest(origin, visit_id) - assert visit is not None and visit_update is not None - return { - # default to the values in visit - **visit.to_dict(), - # override with the last update - **visit_update.to_dict(), - # but keep the date of the creation of the origin visit - "date": visit.date, - } - - def origin_visit_get( - self, - origin: str, - page_token: Optional[str] = None, - order: ListOrder = ListOrder.ASC, - limit: int = 10, - ) -> PagedResult[OriginVisit]: - next_page_token = None - page_token = page_token or "0" - if not isinstance(order, ListOrder): - raise StorageArgumentException("order must be a ListOrder value") - if not isinstance(page_token, str): - raise StorageArgumentException("page_token must be a string.") - - visit_from = int(page_token) - origin_url = self._get_origin_url(origin) - extra_limit = limit + 1 - visits = sorted( - self._origin_visits.get(origin_url, []), - key=lambda v: v.visit, - reverse=(order == ListOrder.DESC), - ) - - if visit_from > 0 and order == ListOrder.ASC: - visits = [v for v in visits if v.visit > visit_from] - elif visit_from > 0 and order == ListOrder.DESC: - visits = [v for v in visits if v.visit < visit_from] - visits = visits[:extra_limit] - - assert len(visits) <= extra_limit - if len(visits) == extra_limit: - visits = visits[:limit] - next_page_token = str(visits[-1].visit) - - return PagedResult(results=visits, next_page_token=next_page_token) - - def origin_visit_find_by_date( - self, origin: str, visit_date: datetime.datetime - ) -> Optional[OriginVisit]: - origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits: - visits = self._origin_visits[origin_url] - visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) - return visit - return None - - def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]: - origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits and visit <= len( - self._origin_visits[origin_url] - ): - found_visit, _ = self._origin_visit_status_get_latest(origin, visit) - return found_visit - return None - - def origin_visit_get_latest( - self, - origin: str, - type: Optional[str] = None, - allowed_statuses: Optional[List[str]] = None, - require_snapshot: bool = False, - ) -> Optional[OriginVisit]: - if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): - raise StorageArgumentException( - f"Unknown allowed statuses {','.join(allowed_statuses)}, only " - f"{','.join(VISIT_STATUSES)} authorized" - ) - - if not list(self._cql_runner.origin_get_by_url(origin)): - return None - - visits = sorted( - self._origin_visits[origin], key=lambda v: (v.date, v.visit), reverse=True, - ) - for visit in visits: - if type is not None and visit.type != type: - continue - visit_statuses = self._origin_visit_statuses[origin, visit.visit] - - if allowed_statuses is not None: - visit_statuses = [ - vs for vs in visit_statuses if vs.status in allowed_statuses - ] - if require_snapshot: - visit_statuses = [vs for vs in visit_statuses if vs.snapshot] - - if visit_statuses: # we found visit statuses matching criteria - visit_status = max(visit_statuses, key=lambda vs: (vs.date, vs.visit)) - assert visit.origin == visit_status.origin - assert visit.visit == visit_status.visit - return visit - - return None - - def origin_visit_status_get( - self, - origin: str, - visit: int, - page_token: Optional[str] = None, - order: ListOrder = ListOrder.ASC, - limit: int = 10, - ) -> PagedResult[OriginVisitStatus]: - next_page_token = None - date_from = None - if page_token is not None: - date_from = datetime.datetime.fromisoformat(page_token) - - visit_statuses = sorted( - self._origin_visit_statuses.get((origin, visit), []), - key=lambda v: v.date, - reverse=(order == ListOrder.DESC), - ) - - if date_from is not None: - if order == ListOrder.ASC: - visit_statuses = [v for v in visit_statuses if v.date >= date_from] - elif order == ListOrder.DESC: - visit_statuses = [v for v in visit_statuses if v.date <= date_from] - - # Take one more visit status so we can reuse it as the next page token if any - visit_statuses = visit_statuses[: limit + 1] - - if len(visit_statuses) > limit: - # last visit status date is the next page token - next_page_token = str(visit_statuses[-1].date) - # excluding that visit status from the result to respect the limit size - visit_statuses = visit_statuses[:limit] - - return PagedResult(results=visit_statuses, next_page_token=next_page_token) - - def origin_visit_status_get_latest( - self, - origin_url: str, - visit: int, - allowed_statuses: Optional[List[str]] = None, - require_snapshot: bool = False, - ) -> Optional[OriginVisitStatus]: - if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES): - raise StorageArgumentException( - f"Unknown allowed statuses {','.join(allowed_statuses)}, only " - f"{','.join(VISIT_STATUSES)} authorized" - ) - - if not list(self._cql_runner.origin_get_by_url(origin_url)): - return None - - visit_key = (origin_url, visit) - visits = self._origin_visit_statuses.get(visit_key) - if not visits: - return None - - if allowed_statuses is not None: - visits = [visit for visit in visits if visit.status in allowed_statuses] - if require_snapshot: - visits = [visit for visit in visits if visit.snapshot] - - visit_status = max(visits, key=lambda v: (v.date, v.visit), default=None) - return visit_status - - def _select_random_origin_visit_by_type(self, type: str) -> str: - while True: - url = random.choice(list(self._origin_visits.keys())) - random_origin_visits = self._origin_visits[url] - if random_origin_visits[0].type == type: - return url - - def origin_visit_status_get_random( - self, type: str - ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]: - - url = self._select_random_origin_visit_by_type(type) - random_origin_visits = copy.deepcopy(self._origin_visits[url]) - random_origin_visits.reverse() - back_in_the_day = now() - timedelta(weeks=12) # 3 months back - # This should be enough for tests - for visit in random_origin_visits: - origin_visit, latest_visit_status = self._origin_visit_status_get_latest( - url, visit.visit - ) - assert latest_visit_status is not None - if ( - origin_visit.date > back_in_the_day - and latest_visit_status.status == "full" - ): - return origin_visit, latest_visit_status - else: - return None - def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: self.journal_writer.raw_extrinsic_metadata_add(metadata) for metadata_entry in metadata: @@ -995,12 +810,6 @@ self._metadata_authority_key(MetadataAuthority(type=type, url=url)) ) - def _get_origin_url(self, origin): - if isinstance(origin, str): - return origin - else: - raise TypeError("origin must be a string.") - @staticmethod def _metadata_fetcher_key(fetcher: MetadataFetcher) -> FetcherKey: return (fetcher.name, fetcher.version) diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -205,16 +205,6 @@ got_persons = set(dst._persons.values()) assert got_persons == expected_persons - for attr_ in ( - "origin_visits", - "origin_visit_statuses", - ): - if exclude and attr_ in exclude: - continue - expected_objects = sorted(getattr(src, f"_{attr_}").items()) - got_objects = sorted(getattr(dst, f"_{attr_}").items()) - assert got_objects == expected_objects, f"Mismatch object list for {attr_}" - for attr_ in ( "contents", "skipped_contents", @@ -223,6 +213,8 @@ "releases", "snapshots", "origins", + "origin_visits", + "origin_visit_statuses", ): if exclude and attr_ in exclude: continue @@ -360,10 +352,7 @@ def maybe_anonymize(attr_, row): if expected_anonymized: - if hasattr(row, "anonymize"): - # for model objects; cases below are for BaseRow objects - row = row.anonymize() or row - elif attr_ == "releases": + if attr_ == "releases": row = dataclasses.replace(row, author=row.author.anonymize()) elif attr_ == "revisions": row = dataclasses.replace( @@ -379,16 +368,6 @@ got_persons = set(dst._persons.values()) assert got_persons == expected_persons - for attr_ in ("origin_visit_statuses",): - expected_objects = [ - (id, maybe_anonymize(attr_, obj)) - for id, obj in sorted(getattr(src, f"_{attr_}").items()) - ] - got_objects = [ - (id, obj) for id, obj in sorted(getattr(dst, f"_{attr_}").items()) - ] - assert got_objects == expected_objects, f"Mismatch object list for {attr_}" - for attr_ in ( "contents", "skipped_contents", @@ -397,6 +376,7 @@ "releases", "snapshots", "origins", + "origin_visit_statuses", ): expected_objects = [ (id, nullify_ctime(maybe_anonymize(attr_, obj)))