Page Menu | Home | Software Heritage

D3780.id13288.diff
No One | Temporary

D3780.id13288.diff

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -6,14 +6,12 @@
import base64
import bisect
import collections
-import copy
import datetime
import functools
import itertools
import random
from collections import defaultdict
-from datetime import timedelta
from typing import (
Any,
Callable,
@@ -37,9 +35,6 @@
from swh.model.model import (
Content,
SkippedContent,
- OriginVisit,
- OriginVisitStatus,
- Origin,
MetadataAuthority,
MetadataAuthorityType,
MetadataFetcher,
@@ -56,6 +51,8 @@
DirectoryEntryRow,
ObjectCountRow,
OriginRow,
+ OriginVisitRow,
+ OriginVisitStatusRow,
ReleaseRow,
RevisionRow,
RevisionParentRow,
@@ -66,10 +63,8 @@
from swh.storage.interface import (
ListOrder,
PagedResult,
- VISIT_STATUSES,
)
from swh.storage.objstorage import ObjStorage
-from swh.storage.utils import now
from .converters import origin_url_to_sha1
from .exc import StorageArgumentException
@@ -246,6 +241,8 @@
self._snapshots = Table(SnapshotRow)
self._snapshot_branches = Table(SnapshotBranchRow)
self._origins = Table(OriginRow)
+ self._origin_visits = Table(OriginVisitRow)
+ self._origin_visit_statuses = Table(OriginVisitStatusRow)
self._stat_counters = defaultdict(int)
def increment_counter(self, object_type: str, nb: int):
@@ -517,6 +514,109 @@
for (clustering_key, row) in partition.items()
)
+ def origin_generate_unique_visit_id(self, origin_url: str) -> int:
+ origin = list(self.origin_get_by_url(origin_url))[0]
+ visit_id = origin.next_visit_id
+ origin.next_visit_id += 1
+ return visit_id
+
+ ##########################
+ # 'origin_visit' table
+ ##########################
+
+ def origin_visit_get(
+ self,
+ origin_url: str,
+ last_visit: Optional[int],
+ limit: Optional[int],
+ order: ListOrder,
+ ) -> Iterable[OriginVisitRow]:
+ visits = list(self._origin_visits.get_from_partition_key((origin_url,)))
+
+ if last_visit is not None:
+ if order == ListOrder.ASC:
+ visits = [v for v in visits if v.visit > last_visit]
+ else:
+ visits = [v for v in visits if v.visit < last_visit]
+
+ visits.sort(key=lambda v: v.visit, reverse=order == ListOrder.DESC)
+
+ if limit is not None:
+ visits = visits[0:limit]
+
+ return visits
+
+ def origin_visit_add_one(self, visit: OriginVisitRow) -> None:
+ self._origin_visits.insert(visit)
+ self.increment_counter("origin_visit", 1)
+
+ def origin_visit_get_one(
+ self, origin_url: str, visit_id: int
+ ) -> Optional[OriginVisitRow]:
+ return self._origin_visits.get_from_primary_key((origin_url, visit_id))
+
+ def origin_visit_get_all(self, origin_url: str) -> Iterable[OriginVisitRow]:
+ return self._origin_visits.get_from_partition_key((origin_url,))
+
+ def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
+ """Returns all origin visits, partition by partition.
+ NOTE: ``start_token`` is currently unused; every row is always yielded."""
+ return (
+ row
+ for (token, partition) in self._origin_visits.data.items()
+ for (clustering_key, row) in partition.items()
+ )
+
+ ##########################
+ # 'origin_visit_status' table
+ ##########################
+
+ def origin_visit_status_get_range(
+ self,
+ origin: str,
+ visit: int,
+ date_from: Optional[datetime.datetime],
+ limit: int,
+ order: ListOrder,
+ ) -> Iterable[OriginVisitStatusRow]:
+ statuses = list(self.origin_visit_status_get(origin, visit))
+
+ if date_from is not None:
+ if order == ListOrder.ASC:
+ statuses = [s for s in statuses if s.date >= date_from]
+ else:
+ statuses = [s for s in statuses if s.date <= date_from]
+
+ statuses.sort(key=lambda s: s.date, reverse=order == ListOrder.DESC)
+
+ return statuses[0:limit]
+
+ def origin_visit_status_add_one(self, visit_update: OriginVisitStatusRow) -> None:
+ self._origin_visit_statuses.insert(visit_update)
+ self.increment_counter("origin_visit_status", 1)
+
+ def origin_visit_status_get_latest(
+ self, origin: str, visit: int,
+ ) -> Optional[OriginVisitStatusRow]:
+ """Given an origin visit id, return its latest origin_visit_status
+
+ """
+ return next(self.origin_visit_status_get(origin, visit), None)
+
+ def origin_visit_status_get(
+ self, origin: str, visit: int,
+ ) -> Iterator[OriginVisitStatusRow]:
+ """Return all origin visit statuses for a given visit
+
+ """
+ statuses = [
+ s
+ for s in self._origin_visit_statuses.get_from_partition_key((origin,))
+ if s.visit == visit
+ ]
+ statuses.sort(key=lambda s: s.date, reverse=True)
+ return iter(statuses)
+
class InMemoryStorage(CassandraStorage):
_cql_runner: InMemoryCqlRunner # type: ignore
@@ -527,8 +627,6 @@
def reset(self):
self._cql_runner = InMemoryCqlRunner()
- self._origin_visits = {}
- self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
self._persons = {}
# {object_type: {id: {authority: [metadata]}}}
@@ -566,289 +664,6 @@
def check_config(self, *, check_write: bool) -> bool:
return True
- def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
- for origin in origins:
- if origin.url not in self._origin_visits:
- self._origin_visits[origin.url] = []
- return super().origin_add(origins)
-
- def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
- for visit in visits:
- origin = self.origin_get_one(visit.origin)
- if not origin: # Cannot add a visit without an origin
- raise StorageArgumentException("Unknown origin %s", visit.origin)
-
- all_visits = []
- for visit in visits:
- origin_url = visit.origin
- if list(self._cql_runner.origin_get_by_url(origin_url)):
- if visit.visit:
- self.journal_writer.origin_visit_add([visit])
- while len(self._origin_visits[origin_url]) < visit.visit:
- self._origin_visits[origin_url].append(None)
- self._origin_visits[origin_url][visit.visit - 1] = visit
- else:
- # visit ids are in the range [1, +inf[
- visit_id = len(self._origin_visits[origin_url]) + 1
- visit = attr.evolve(visit, visit=visit_id)
- self.journal_writer.origin_visit_add([visit])
- self._origin_visits[origin_url].append(visit)
- visit_key = (origin_url, visit.visit)
- self._objects[visit_key].append(("origin_visit", None))
- assert visit.visit is not None
- self._origin_visit_status_add_one(
- OriginVisitStatus(
- origin=visit.origin,
- visit=visit.visit,
- date=visit.date,
- status="created",
- snapshot=None,
- )
- )
- all_visits.append(visit)
-
- self._cql_runner.increment_counter("origin_visit", len(all_visits))
-
- return all_visits
-
- def _origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None:
- """Add an origin visit status without checks. If already present, do nothing.
-
- """
- self.journal_writer.origin_visit_status_add([visit_status])
- visit_key = (visit_status.origin, visit_status.visit)
- self._origin_visit_statuses.setdefault(visit_key, [])
- visit_statuses = self._origin_visit_statuses[visit_key]
- if visit_status not in visit_statuses:
- visit_statuses.append(visit_status)
-
- def origin_visit_status_add(self, visit_statuses: List[OriginVisitStatus],) -> None:
- # First round to check existence (fail early if any is ko)
- for visit_status in visit_statuses:
- origin_url = self.origin_get_one(visit_status.origin)
- if not origin_url:
- raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
-
- for visit_status in visit_statuses:
- self._origin_visit_status_add_one(visit_status)
-
- def _origin_visit_status_get_latest(
- self, origin: str, visit_id: int
- ) -> Tuple[OriginVisit, OriginVisitStatus]:
- """Return a tuple of OriginVisit, latest associated OriginVisitStatus.
-
- """
- assert visit_id >= 1
- visit = self._origin_visits[origin][visit_id - 1]
- assert visit is not None
- visit_key = (origin, visit_id)
-
- visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date)
- return visit, visit_update
-
- def _origin_visit_get_updated(self, origin: str, visit_id: int) -> Dict[str, Any]:
- """Merge origin visit and latest origin visit status
-
- """
- visit, visit_update = self._origin_visit_status_get_latest(origin, visit_id)
- assert visit is not None and visit_update is not None
- return {
- # default to the values in visit
- **visit.to_dict(),
- # override with the last update
- **visit_update.to_dict(),
- # but keep the date of the creation of the origin visit
- "date": visit.date,
- }
-
- def origin_visit_get(
- self,
- origin: str,
- page_token: Optional[str] = None,
- order: ListOrder = ListOrder.ASC,
- limit: int = 10,
- ) -> PagedResult[OriginVisit]:
- next_page_token = None
- page_token = page_token or "0"
- if not isinstance(order, ListOrder):
- raise StorageArgumentException("order must be a ListOrder value")
- if not isinstance(page_token, str):
- raise StorageArgumentException("page_token must be a string.")
-
- visit_from = int(page_token)
- origin_url = self._get_origin_url(origin)
- extra_limit = limit + 1
- visits = sorted(
- self._origin_visits.get(origin_url, []),
- key=lambda v: v.visit,
- reverse=(order == ListOrder.DESC),
- )
-
- if visit_from > 0 and order == ListOrder.ASC:
- visits = [v for v in visits if v.visit > visit_from]
- elif visit_from > 0 and order == ListOrder.DESC:
- visits = [v for v in visits if v.visit < visit_from]
- visits = visits[:extra_limit]
-
- assert len(visits) <= extra_limit
- if len(visits) == extra_limit:
- visits = visits[:limit]
- next_page_token = str(visits[-1].visit)
-
- return PagedResult(results=visits, next_page_token=next_page_token)
-
- def origin_visit_find_by_date(
- self, origin: str, visit_date: datetime.datetime
- ) -> Optional[OriginVisit]:
- origin_url = self._get_origin_url(origin)
- if origin_url in self._origin_visits:
- visits = self._origin_visits[origin_url]
- visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit))
- return visit
- return None
-
- def origin_visit_get_by(self, origin: str, visit: int) -> Optional[OriginVisit]:
- origin_url = self._get_origin_url(origin)
- if origin_url in self._origin_visits and visit <= len(
- self._origin_visits[origin_url]
- ):
- found_visit, _ = self._origin_visit_status_get_latest(origin, visit)
- return found_visit
- return None
-
- def origin_visit_get_latest(
- self,
- origin: str,
- type: Optional[str] = None,
- allowed_statuses: Optional[List[str]] = None,
- require_snapshot: bool = False,
- ) -> Optional[OriginVisit]:
- if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
- raise StorageArgumentException(
- f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
- f"{','.join(VISIT_STATUSES)} authorized"
- )
-
- if not list(self._cql_runner.origin_get_by_url(origin)):
- return None
-
- visits = sorted(
- self._origin_visits[origin], key=lambda v: (v.date, v.visit), reverse=True,
- )
- for visit in visits:
- if type is not None and visit.type != type:
- continue
- visit_statuses = self._origin_visit_statuses[origin, visit.visit]
-
- if allowed_statuses is not None:
- visit_statuses = [
- vs for vs in visit_statuses if vs.status in allowed_statuses
- ]
- if require_snapshot:
- visit_statuses = [vs for vs in visit_statuses if vs.snapshot]
-
- if visit_statuses: # we found visit statuses matching criteria
- visit_status = max(visit_statuses, key=lambda vs: (vs.date, vs.visit))
- assert visit.origin == visit_status.origin
- assert visit.visit == visit_status.visit
- return visit
-
- return None
-
- def origin_visit_status_get(
- self,
- origin: str,
- visit: int,
- page_token: Optional[str] = None,
- order: ListOrder = ListOrder.ASC,
- limit: int = 10,
- ) -> PagedResult[OriginVisitStatus]:
- next_page_token = None
- date_from = None
- if page_token is not None:
- date_from = datetime.datetime.fromisoformat(page_token)
-
- visit_statuses = sorted(
- self._origin_visit_statuses.get((origin, visit), []),
- key=lambda v: v.date,
- reverse=(order == ListOrder.DESC),
- )
-
- if date_from is not None:
- if order == ListOrder.ASC:
- visit_statuses = [v for v in visit_statuses if v.date >= date_from]
- elif order == ListOrder.DESC:
- visit_statuses = [v for v in visit_statuses if v.date <= date_from]
-
- # Take one more visit status so we can reuse it as the next page token if any
- visit_statuses = visit_statuses[: limit + 1]
-
- if len(visit_statuses) > limit:
- # last visit status date is the next page token
- next_page_token = str(visit_statuses[-1].date)
- # excluding that visit status from the result to respect the limit size
- visit_statuses = visit_statuses[:limit]
-
- return PagedResult(results=visit_statuses, next_page_token=next_page_token)
-
- def origin_visit_status_get_latest(
- self,
- origin_url: str,
- visit: int,
- allowed_statuses: Optional[List[str]] = None,
- require_snapshot: bool = False,
- ) -> Optional[OriginVisitStatus]:
- if allowed_statuses and not set(allowed_statuses).intersection(VISIT_STATUSES):
- raise StorageArgumentException(
- f"Unknown allowed statuses {','.join(allowed_statuses)}, only "
- f"{','.join(VISIT_STATUSES)} authorized"
- )
-
- if not list(self._cql_runner.origin_get_by_url(origin_url)):
- return None
-
- visit_key = (origin_url, visit)
- visits = self._origin_visit_statuses.get(visit_key)
- if not visits:
- return None
-
- if allowed_statuses is not None:
- visits = [visit for visit in visits if visit.status in allowed_statuses]
- if require_snapshot:
- visits = [visit for visit in visits if visit.snapshot]
-
- visit_status = max(visits, key=lambda v: (v.date, v.visit), default=None)
- return visit_status
-
- def _select_random_origin_visit_by_type(self, type: str) -> str:
- while True:
- url = random.choice(list(self._origin_visits.keys()))
- random_origin_visits = self._origin_visits[url]
- if random_origin_visits[0].type == type:
- return url
-
- def origin_visit_status_get_random(
- self, type: str
- ) -> Optional[Tuple[OriginVisit, OriginVisitStatus]]:
-
- url = self._select_random_origin_visit_by_type(type)
- random_origin_visits = copy.deepcopy(self._origin_visits[url])
- random_origin_visits.reverse()
- back_in_the_day = now() - timedelta(weeks=12) # 3 months back
- # This should be enough for tests
- for visit in random_origin_visits:
- origin_visit, latest_visit_status = self._origin_visit_status_get_latest(
- url, visit.visit
- )
- assert latest_visit_status is not None
- if (
- origin_visit.date > back_in_the_day
- and latest_visit_status.status == "full"
- ):
- return origin_visit, latest_visit_status
- else:
- return None
-
def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None:
self.journal_writer.raw_extrinsic_metadata_add(metadata)
for metadata_entry in metadata:
@@ -995,12 +810,6 @@
self._metadata_authority_key(MetadataAuthority(type=type, url=url))
)
- def _get_origin_url(self, origin):
- if isinstance(origin, str):
- return origin
- else:
- raise TypeError("origin must be a string.")
-
@staticmethod
def _metadata_fetcher_key(fetcher: MetadataFetcher) -> FetcherKey:
return (fetcher.name, fetcher.version)
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -205,16 +205,6 @@
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr_ in (
- "origin_visits",
- "origin_visit_statuses",
- ):
- if exclude and attr_ in exclude:
- continue
- expected_objects = sorted(getattr(src, f"_{attr_}").items())
- got_objects = sorted(getattr(dst, f"_{attr_}").items())
- assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
-
for attr_ in (
"contents",
"skipped_contents",
@@ -223,6 +213,8 @@
"releases",
"snapshots",
"origins",
+ "origin_visits",
+ "origin_visit_statuses",
):
if exclude and attr_ in exclude:
continue
@@ -374,16 +366,6 @@
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr_ in ("origin_visit_statuses",):
- expected_objects = [
- (id, maybe_anonymize(obj))
- for id, obj in sorted(getattr(src, f"_{attr_}").items())
- ]
- got_objects = [
- (id, obj) for id, obj in sorted(getattr(dst, f"_{attr_}").items())
- ]
- assert got_objects == expected_objects, f"Mismatch object list for {attr_}"
-
for attr_ in (
"contents",
"skipped_contents",
@@ -392,6 +374,7 @@
"releases",
"snapshots",
"origins",
+ "origin_visit_statuses",
):
expected_objects = [
(id, nullify_ctime(maybe_anonymize(attr_, obj)))

File Metadata

Mime Type: text/plain
Expires: Nov 5 2024, 1:10 AM (9 w, 1 d ago)
Storage Engine: blob
Storage Format: Raw Data
Storage Handle: 3232943

Event Timeline