D3777.id.diff
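
Summary (reconstructed from the diff below, not part of the original submission): move origin storage in swh.storage's in-memory backend from ad-hoc dicts on InMemoryStorage (_origins, _origins_by_sha1) to an OriginRow table on InMemoryCqlRunner, so that origin_get, origin_list and origin_search are inherited from the CassandraStorage base class instead of being reimplemented; origin_add keeps only a thin override that initializes _origin_visits before delegating to super(). The test modules gain explicit skips for origin_count, which the Cassandra backend does not support. A usage sketch of the new runner API follows the in_memory.py hunks.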

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -11,7 +11,6 @@
import functools
import itertools
import random
-import re
from collections import defaultdict
from datetime import timedelta
@@ -56,6 +55,7 @@
DirectoryRow,
DirectoryEntryRow,
ObjectCountRow,
+ OriginRow,
ReleaseRow,
RevisionRow,
RevisionParentRow,
@@ -245,6 +245,7 @@
self._releases = Table(ReleaseRow)
self._snapshots = Table(SnapshotRow)
self._snapshot_branches = Table(SnapshotBranchRow)
+ self._origins = Table(OriginRow)
self._stat_counters = defaultdict(int)
def increment_counter(self, object_type: str, nb: int):
@@ -482,6 +483,40 @@
if count >= limit:
break
+ ##########################
+ # 'origin' table
+ ##########################
+
+ def origin_add_one(self, origin: OriginRow) -> None:
+ self._origins.insert(origin)
+ self.increment_counter("origin", 1)
+
+ def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
+ return self._origins.get_from_partition_key((sha1,))
+
+ def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
+ return self.origin_get_by_sha1(origin_url_to_sha1(url))
+
+ def origin_list(
+ self, start_token: int, limit: int
+ ) -> Iterable[Tuple[int, OriginRow]]:
+ """Returns an iterable of (token, origin)"""
+ matches = [
+ (token, row)
+ for (token, partition) in self._origins.data.items()
+ for (clustering_key, row) in partition.items()
+ if token >= start_token
+ ]
+ matches.sort()
+ return matches[0:limit]
+
+ def origin_iter_all(self) -> Iterable[OriginRow]:
+ return (
+ row
+ for (token, partition) in self._origins.data.items()
+ for (clustering_key, row) in partition.items()
+ )
+
class InMemoryStorage(CassandraStorage):
_cql_runner: InMemoryCqlRunner # type: ignore
@@ -492,8 +527,6 @@
def reset(self):
self._cql_runner = InMemoryCqlRunner()
- self._origins = {}
- self._origins_by_sha1 = {}
self._origin_visits = {}
self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
self._persons = {}
@@ -533,114 +566,11 @@
def check_config(self, *, check_write: bool) -> bool:
return True
- def _convert_origin(self, t):
- if t is None:
- return None
-
- return t.to_dict()
-
- def origin_get_one(self, origin_url: str) -> Optional[Origin]:
- return self._origins.get(origin_url)
-
- def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
- return [self.origin_get_one(origin_url) for origin_url in origins]
-
- def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
- return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
-
- def origin_list(
- self, page_token: Optional[str] = None, limit: int = 100
- ) -> PagedResult[Origin]:
- origin_urls = sorted(self._origins)
- from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0
- next_page_token = None
-
- # Take one more origin so we can reuse it as the next page token if any
- origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]]
-
- if len(origins) > limit:
- # last origin id is the next page token
- next_page_token = str(origins[-1].url)
- # excluding that origin from the result to respect the limit size
- origins = origins[:limit]
-
- assert len(origins) <= limit
-
- return PagedResult(results=origins, next_page_token=next_page_token)
-
- def origin_search(
- self,
- url_pattern: str,
- page_token: Optional[str] = None,
- limit: int = 50,
- regexp: bool = False,
- with_visit: bool = False,
- ) -> PagedResult[Origin]:
- next_page_token = None
- offset = int(page_token) if page_token else 0
-
- origins = self._origins.values()
- if regexp:
- pat = re.compile(url_pattern)
- origins = [orig for orig in origins if pat.search(orig.url)]
- else:
- origins = [orig for orig in origins if url_pattern in orig.url]
-
- if with_visit:
- filtered_origins = []
- for orig in origins:
- visits = (
- self._origin_visit_get_updated(ov.origin, ov.visit)
- for ov in self._origin_visits[orig.url]
- )
- for ov in visits:
- snapshot = ov["snapshot"]
- if snapshot and not list(self.snapshot_missing([snapshot])):
- filtered_origins.append(orig)
- break
- else:
- filtered_origins = origins
-
- # Take one more origin so we can reuse it as the next page token if any
- origins = filtered_origins[offset : offset + limit + 1]
- if len(origins) > limit:
- # next offset
- next_page_token = str(offset + limit)
- # excluding that origin from the result to respect the limit size
- origins = origins[:limit]
-
- assert len(origins) <= limit
- return PagedResult(results=origins, next_page_token=next_page_token)
-
- def origin_count(
- self, url_pattern: str, regexp: bool = False, with_visit: bool = False
- ) -> int:
- actual_page = self.origin_search(
- url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins),
- )
- assert actual_page.next_page_token is None
- return len(actual_page.results)
-
def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
- added = 0
for origin in origins:
- if origin.url not in self._origins:
- self.origin_add_one(origin)
- added += 1
-
- self._cql_runner.increment_counter("origin", added)
-
- return {"origin:add": added}
-
- def origin_add_one(self, origin: Origin) -> str:
- if origin.url not in self._origins:
- self.journal_writer.origin_add([origin])
- self._origins[origin.url] = origin
- self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin
- self._origin_visits[origin.url] = []
- self._objects[origin.url].append(("origin", origin.url))
-
- return origin.url
+ if origin.url not in self._origin_visits:
+ self._origin_visits[origin.url] = []
+ return super().origin_add(origins)
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
for visit in visits:
@@ -651,8 +581,7 @@
all_visits = []
for visit in visits:
origin_url = visit.origin
- if origin_url in self._origins:
- origin = self._origins[origin_url]
+ if list(self._cql_runner.origin_get_by_url(origin_url)):
if visit.visit:
self.journal_writer.origin_visit_add([visit])
while len(self._origin_visits[origin_url]) < visit.visit:
@@ -800,12 +729,11 @@
f"{','.join(VISIT_STATUSES)} authorized"
)
- ori = self._origins.get(origin)
- if not ori:
+ if not list(self._cql_runner.origin_get_by_url(origin)):
return None
visits = sorted(
- self._origin_visits[ori.url], key=lambda v: (v.date, v.visit), reverse=True,
+ self._origin_visits[origin], key=lambda v: (v.date, v.visit), reverse=True,
)
for visit in visits:
if type is not None and visit.type != type:
@@ -876,8 +804,7 @@
f"{','.join(VISIT_STATUSES)} authorized"
)
- ori = self._origins.get(origin_url)
- if not ori:
+ if not list(self._cql_runner.origin_get_by_url(origin_url)):
return None
visit_key = (origin_url, visit)
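
A minimal usage sketch of the new InMemoryCqlRunner origin API (illustrative only, not part of the patch; the OriginRow constructor fields and the origin_url_to_sha1 import path are assumptions based on the surrounding swh.storage code):

    from swh.storage.cassandra.model import OriginRow
    from swh.storage.in_memory import InMemoryCqlRunner
    from swh.storage.utils import origin_url_to_sha1  # assumed import path

    runner = InMemoryCqlRunner()
    url = "https://example.org/user/repo"
    # OriginRow fields (sha1, url, next_visit_id) are assumed from the
    # swh.storage.cassandra model that this backend mirrors.
    runner.origin_add_one(
        OriginRow(sha1=origin_url_to_sha1(url), url=url, next_visit_id=1)
    )

    # Reads go through the sha1 partition key, as in the Cassandra schema.
    assert [row.url for row in runner.origin_get_by_url(url)] == [url]

    # origin_list pages by token; tokens come from hashing the partition
    # key and may be negative, so scan from the lowest 64-bit value.
    for token, row in runner.origin_list(start_token=-(2 ** 63), limit=10):
        print(token, row.url)
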
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -8,7 +8,9 @@
import swh.storage.api.server as server
import swh.storage.storage
from swh.storage import get_storage
-from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa
+from swh.storage.tests.test_storage import (
+ TestStorageGeneratedData as _TestStorageGeneratedData,
+)
from swh.storage.tests.test_storage import TestStorage as _TestStorage
# tests are executed using imported classes (TestStorage and
@@ -60,14 +62,36 @@
storage.journal_writer = journal_writer
-class TestStorage(_TestStorage):
- @pytest.mark.skip("content_update is not yet implemented for Cassandra")
- def test_content_update(self):
- pass
-
+class TestStorageApi(_TestStorage):
@pytest.mark.skip(
'The "person" table of the pgsql is a legacy thing, and not '
"supported by the cassandra backend."
)
def test_person_fullname_unicity(self):
pass
+
+ @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+ def test_content_update(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count(self):
+ pass
+
+
+class TestStorageApiGeneratedData(_TestStorageGeneratedData):
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count_with_visit_no_visits(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count_with_visit_with_visits_and_snapshot(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count_with_visit_with_visits_no_snapshot(self):
+ pass
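
The import rename above is load-bearing: pytest collects any module-level class whose name matches Test*, so binding the imported suite to an underscore-prefixed alias keeps the base class from being run a second time; only the local subclass, with its Cassandra-specific skips, is collected. A minimal illustration of the pattern, assuming default pytest collection settings:

    import pytest

    from swh.storage.tests.test_storage import (
        TestStorageGeneratedData as _TestStorageGeneratedData,  # not collected
    )

    class TestStorageApiGeneratedData(_TestStorageGeneratedData):  # collected
        @pytest.mark.skip("Not supported by Cassandra")
        def test_origin_count(self):
            pass
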
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -10,7 +10,9 @@
from swh.storage.cassandra.model import BaseRow
from swh.storage.in_memory import SortedList, Table
from swh.storage.tests.test_storage import TestStorage as _TestStorage
-from swh.storage.tests.test_storage import TestStorageGeneratedData # noqa
+from swh.storage.tests.test_storage import (
+ TestStorageGeneratedData as _TestStorageGeneratedData,
+)
# tests are executed using imported classes (TestStorage and
@@ -168,3 +170,25 @@
@pytest.mark.skip("content_update is not yet implemented for Cassandra")
def test_content_update(self):
pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count(self):
+ pass
+
+
+class TestInMemoryStorageGeneratedData(_TestStorageGeneratedData):
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count_with_visit_no_visits(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count_with_visit_with_visits_and_snapshot(self):
+ pass
+
+ @pytest.mark.skip("Not supported by Cassandra")
+ def test_origin_count_with_visit_with_visits_no_snapshot(self):
+ pass
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -206,7 +206,6 @@
assert got_persons == expected_persons
for attr_ in (
- "origins",
"origin_visits",
"origin_visit_statuses",
):
@@ -223,6 +222,7 @@
"revisions",
"releases",
"snapshots",
+ "origins",
):
if exclude and attr_ in exclude:
continue
@@ -379,10 +379,7 @@
got_persons = set(dst._persons.values())
assert got_persons == expected_persons
- for attr_ in (
- "origins",
- "origin_visit_statuses",
- ):
+ for attr_ in ("origin_visit_statuses",):
expected_objects = [
(id, maybe_anonymize(attr_, obj))
for id, obj in sorted(getattr(src, f"_{attr_}").items())
@@ -399,6 +396,7 @@
"revisions",
"releases",
"snapshots",
+ "origins",
):
expected_objects = [
(id, nullify_ctime(maybe_anonymize(attr_, obj)))
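
With origins now held in the runner's OriginRow table, the replay tests compare them alongside the other CQL-backed object types rather than through the removed _origins dict. A hedged equivalent of that check, assuming src and dst are InMemoryStorage instances as in the test module:

    # Not part of the patch: origins can be compared straight from the
    # in-memory CQL tables now that InMemoryStorage has no _origins dict.
    expected_urls = {row.url for row in src._cql_runner.origin_iter_all()}
    got_urls = {row.url for row in dst._cql_runner.origin_iter_all()}
    assert got_urls == expected_urls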
