diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -11,7 +11,6 @@
 import functools
 import itertools
 import random
-import re
 
 from collections import defaultdict
 from datetime import timedelta
@@ -56,6 +55,7 @@
     DirectoryRow,
     DirectoryEntryRow,
     ObjectCountRow,
+    OriginRow,
     ReleaseRow,
     RevisionRow,
     RevisionParentRow,
@@ -245,6 +245,7 @@
         self._releases = Table(ReleaseRow)
         self._snapshots = Table(SnapshotRow)
         self._snapshot_branches = Table(SnapshotBranchRow)
+        self._origins = Table(OriginRow)
         self._stat_counters = defaultdict(int)
 
     def increment_counter(self, object_type: str, nb: int):
@@ -482,6 +483,40 @@
             if count >= limit:
                 break
 
+    ##########################
+    # 'origin' table
+    ##########################
+
+    def origin_add_one(self, origin: OriginRow) -> None:
+        self._origins.insert(origin)
+        self.increment_counter("origin", 1)
+
+    def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
+        return self._origins.get_from_partition_key((sha1,))
+
+    def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
+        return self.origin_get_by_sha1(origin_url_to_sha1(url))
+
+    def origin_list(
+        self, start_token: int, limit: int
+    ) -> Iterable[Tuple[int, OriginRow]]:
+        """Returns an iterable of (token, origin)"""
+        matches = [
+            (token, row)
+            for (token, partition) in self._origins.data.items()
+            for (clustering_key, row) in partition.items()
+            if token >= start_token
+        ]
+        matches.sort()
+        return matches[0:limit]
+
+    def origin_iter_all(self) -> Iterable[OriginRow]:
+        return (
+            row
+            for (token, partition) in self._origins.data.items()
+            for (clustering_key, row) in partition.items()
+        )
+
 
 class InMemoryStorage(CassandraStorage):
     _cql_runner: InMemoryCqlRunner  # type: ignore
@@ -492,8 +527,6 @@
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
-        self._origins = {}
-        self._origins_by_sha1 = {}
         self._origin_visits = {}
         self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
         self._persons = {}
@@ -533,114 +566,11 @@
     def check_config(self, *, check_write: bool) -> bool:
         return True
 
-    def _convert_origin(self, t):
-        if t is None:
-            return None
-
-        return t.to_dict()
-
-    def origin_get_one(self, origin_url: str) -> Optional[Origin]:
-        return self._origins.get(origin_url)
-
-    def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
-        return [self.origin_get_one(origin_url) for origin_url in origins]
-
-    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
-        return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
-
-    def origin_list(
-        self, page_token: Optional[str] = None, limit: int = 100
-    ) -> PagedResult[Origin]:
-        origin_urls = sorted(self._origins)
-        from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0
-        next_page_token = None
-
-        # Take one more origin so we can reuse it as the next page token if any
-        origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]]
-
-        if len(origins) > limit:
-            # last origin id is the next page token
-            next_page_token = str(origins[-1].url)
-            # excluding that origin from the result to respect the limit size
-            origins = origins[:limit]
-
-        assert len(origins) <= limit
-
-        return PagedResult(results=origins, next_page_token=next_page_token)
-
-    def origin_search(
-        self,
-        url_pattern: str,
-        page_token: Optional[str] = None,
-        limit: int = 50,
-        regexp: bool = False,
-        with_visit: bool = False,
-    ) -> PagedResult[Origin]:
-        next_page_token = None
-        offset = int(page_token) if page_token else 0
-
-        origins = self._origins.values()
-        if regexp:
-            pat = re.compile(url_pattern)
-            origins = [orig for orig in origins if pat.search(orig.url)]
-        else:
-            origins = [orig for orig in origins if url_pattern in orig.url]
-
-        if with_visit:
-            filtered_origins = []
-            for orig in origins:
-                visits = (
-                    self._origin_visit_get_updated(ov.origin, ov.visit)
-                    for ov in self._origin_visits[orig.url]
-                )
-                for ov in visits:
-                    snapshot = ov["snapshot"]
-                    if snapshot and not list(self.snapshot_missing([snapshot])):
-                        filtered_origins.append(orig)
-                        break
-        else:
-            filtered_origins = origins
-
-        # Take one more origin so we can reuse it as the next page token if any
-        origins = filtered_origins[offset : offset + limit + 1]
-        if len(origins) > limit:
-            # next offset
-            next_page_token = str(offset + limit)
-            # excluding that origin from the result to respect the limit size
-            origins = origins[:limit]
-
-        assert len(origins) <= limit
-        return PagedResult(results=origins, next_page_token=next_page_token)
-
-    def origin_count(
-        self, url_pattern: str, regexp: bool = False, with_visit: bool = False
-    ) -> int:
-        actual_page = self.origin_search(
-            url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins),
-        )
-        assert actual_page.next_page_token is None
-        return len(actual_page.results)
-
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
-        added = 0
         for origin in origins:
-            if origin.url not in self._origins:
-                self.origin_add_one(origin)
-                added += 1
-
-        self._cql_runner.increment_counter("origin", added)
-
-        return {"origin:add": added}
-
-    def origin_add_one(self, origin: Origin) -> str:
-        if origin.url not in self._origins:
-            self.journal_writer.origin_add([origin])
-            self._origins[origin.url] = origin
-            self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin
-            self._origin_visits[origin.url] = []
-            self._objects[origin.url].append(("origin", origin.url))
-
-        return origin.url
+            if origin.url not in self._origin_visits:
+                self._origin_visits[origin.url] = []
+        return super().origin_add(origins)
 
     def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
         for visit in visits:
@@ -651,8 +581,7 @@
         all_visits = []
         for visit in visits:
             origin_url = visit.origin
-            if origin_url in self._origins:
-                origin = self._origins[origin_url]
+            if list(self._cql_runner.origin_get_by_url(origin_url)):
                 if visit.visit:
                     self.journal_writer.origin_visit_add([visit])
                     while len(self._origin_visits[origin_url]) < visit.visit:
@@ -800,12 +729,11 @@
                 f"{','.join(VISIT_STATUSES)} authorized"
             )
 
-        ori = self._origins.get(origin)
-        if not ori:
+        if not list(self._cql_runner.origin_get_by_url(origin)):
            return None
 
         visits = sorted(
-            self._origin_visits[ori.url], key=lambda v: (v.date, v.visit), reverse=True,
+            self._origin_visits[origin], key=lambda v: (v.date, v.visit), reverse=True,
        )
         for visit in visits:
             if type is not None and visit.type != type:
@@ -876,8 +804,7 @@
                 f"{','.join(VISIT_STATUSES)} authorized"
             )
 
-        ori = self._origins.get(origin_url)
-        if not ori:
+        if not list(self._cql_runner.origin_get_by_url(origin_url)):
             return None
 
         visit_key = (origin_url, visit)
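Note on the new token-based origin_list above: it replaces the old URL-keyed pagination with Cassandra-style token paging. Here is a minimal sketch of how a caller could page through all origins with it. The OriginRow field next_visit_id, the starting token value, and the origin_url_to_sha1 import path are assumptions borrowed from the Cassandra backend this class mirrors, not part of this diff.

    from swh.storage.cassandra.model import OriginRow
    from swh.storage.in_memory import InMemoryCqlRunner
    from swh.storage.utils import origin_url_to_sha1  # assumed location

    runner = InMemoryCqlRunner()
    for url in ("https://example.com/repo1", "https://example.com/repo2"):
        # The sha1 of the URL is the partition key; next_visit_id=1 is assumed
        # to match the Cassandra OriginRow model.
        runner.origin_add_one(
            OriginRow(sha1=origin_url_to_sha1(url), url=url, next_visit_id=1)
        )

    token = -(2 ** 63)  # assumed murmur3 lower bound, as in Cassandra
    while True:
        page = list(runner.origin_list(start_token=token, limit=10))
        for tok, row in page:
            print(tok, row.url)
        if len(page) < 10:
            break
        # origin_list keeps rows with token >= start_token, so resume just
        # past the last token seen to avoid re-reading the final row.
        token = page[-1][0] + 1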
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -8,7 +8,9 @@
 import swh.storage.api.server as server
 import swh.storage.storage
 from swh.storage import get_storage
-from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import (
+    TestStorageGeneratedData as _TestStorageGeneratedData,
+)
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
 
 # tests are executed using imported classes (TestStorage and
@@ -60,14 +62,36 @@
     storage.journal_writer = journal_writer
 
 
-class TestStorage(_TestStorage):
-    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
-    def test_content_update(self):
-        pass
-
+class TestStorageApi(_TestStorage):
     @pytest.mark.skip(
         'The "person" table of the pgsql is a legacy thing, and not '
         "supported by the cassandra backend."
     )
     def test_person_fullname_unicity(self):
         pass
+
+    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+    def test_content_update(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+
+class TestStorageApiGeneratedData(_TestStorageGeneratedData):
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_no_visits(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_and_snapshot(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_no_snapshot(self):
+        pass
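A note on the import change above: pytest collects any module-level class whose name matches Test*, which is why the old "# noqa" line deliberately exposed TestStorageGeneratedData for direct execution. Re-importing it as _TestStorageGeneratedData hides it from collection, so only the subclass carrying the Cassandra-specific skips runs. A conceptual demo of that collection rule, with illustrative names:

    # pytest's default python_classes pattern is "Test*":
    class _Base:                 # leading underscore: pytest does not collect it
        def test_inherited(self):
            assert True

    class TestDerived(_Base):    # collected; runs test_inherited exactly once
        pass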
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -10,7 +10,9 @@
 from swh.storage.cassandra.model import BaseRow
 from swh.storage.in_memory import SortedList, Table
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
-from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import (
+    TestStorageGeneratedData as _TestStorageGeneratedData,
+)
 
 
 # tests are executed using imported classes (TestStorage and
@@ -168,3 +170,25 @@
     @pytest.mark.skip("content_update is not yet implemented for Cassandra")
     def test_content_update(self):
         pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+
+class TestInMemoryStorageGeneratedData(_TestStorageGeneratedData):
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_no_visits(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_and_snapshot(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_no_snapshot(self):
+        pass
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -206,7 +206,6 @@
     assert got_persons == expected_persons
 
     for attr_ in (
-        "origins",
         "origin_visits",
         "origin_visit_statuses",
     ):
@@ -223,6 +222,7 @@
         "revisions",
         "releases",
         "snapshots",
+        "origins",
     ):
         if exclude and attr_ in exclude:
             continue
@@ -379,10 +379,7 @@
     got_persons = set(dst._persons.values())
     assert got_persons == expected_persons
 
-    for attr_ in (
-        "origins",
-        "origin_visit_statuses",
-    ):
+    for attr_ in ("origin_visit_statuses",):
         expected_objects = [
             (id, maybe_anonymize(attr_, obj))
             for id, obj in sorted(getattr(src, f"_{attr_}").items())
@@ -399,6 +396,7 @@
         "revisions",
         "releases",
         "snapshots",
+        "origins",
    ):
         expected_objects = [
             (id, nullify_ctime(maybe_anonymize(attr_, obj)))
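The test_replay hunks move "origins" out of the dict-based comparison loops: origins now live in a Cassandra-style Table on the runner rather than in a private _origins mapping on the storage object, so they join the group of objects compared like the other model rows. A hypothetical helper, not part of this diff, sketching an equivalent comparison through the origin_iter_all() added earlier:

    # Compare the origins held by two InMemoryCqlRunner instances.
    # Illustrative only; field access assumes OriginRow has a url attribute,
    # as implied by the in_memory.py changes above.
    def assert_same_origins(src_runner, dst_runner):
        src_urls = sorted(row.url for row in src_runner.origin_iter_all())
        dst_urls = sorted(row.url for row in dst_runner.origin_iter_all())
        assert src_urls == dst_urls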