D3777.id.diff
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -11,7 +11,6 @@
 import functools
 import itertools
 import random
-import re
 from collections import defaultdict
 from datetime import timedelta
@@ -56,6 +55,7 @@
     DirectoryRow,
     DirectoryEntryRow,
     ObjectCountRow,
+    OriginRow,
     ReleaseRow,
     RevisionRow,
     RevisionParentRow,
@@ -245,6 +245,7 @@
         self._releases = Table(ReleaseRow)
         self._snapshots = Table(SnapshotRow)
         self._snapshot_branches = Table(SnapshotBranchRow)
+        self._origins = Table(OriginRow)
         self._stat_counters = defaultdict(int)
 
     def increment_counter(self, object_type: str, nb: int):
@@ -482,6 +483,40 @@
             if count >= limit:
                 break
 
+    ##########################
+    # 'origin' table
+    ##########################
+
+    def origin_add_one(self, origin: OriginRow) -> None:
+        self._origins.insert(origin)
+        self.increment_counter("origin", 1)
+
+    def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
+        return self._origins.get_from_partition_key((sha1,))
+
+    def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
+        return self.origin_get_by_sha1(origin_url_to_sha1(url))
+
+    def origin_list(
+        self, start_token: int, limit: int
+    ) -> Iterable[Tuple[int, OriginRow]]:
+        """Returns an iterable of (token, origin)"""
+        matches = [
+            (token, row)
+            for (token, partition) in self._origins.data.items()
+            for (clustering_key, row) in partition.items()
+            if token >= start_token
+        ]
+        matches.sort()
+        return matches[0:limit]
+
+    def origin_iter_all(self) -> Iterable[OriginRow]:
+        return (
+            row
+            for (token, partition) in self._origins.data.items()
+            for (clustering_key, row) in partition.items()
+        )
+
 
 class InMemoryStorage(CassandraStorage):
     _cql_runner: InMemoryCqlRunner  # type: ignore
@@ -492,8 +527,6 @@
 
     def reset(self):
         self._cql_runner = InMemoryCqlRunner()
-        self._origins = {}
-        self._origins_by_sha1 = {}
         self._origin_visits = {}
         self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
         self._persons = {}
@@ -533,114 +566,11 @@
     def check_config(self, *, check_write: bool) -> bool:
         return True
 
-    def _convert_origin(self, t):
-        if t is None:
-            return None
-
-        return t.to_dict()
-
-    def origin_get_one(self, origin_url: str) -> Optional[Origin]:
-        return self._origins.get(origin_url)
-
-    def origin_get(self, origins: List[str]) -> Iterable[Optional[Origin]]:
-        return [self.origin_get_one(origin_url) for origin_url in origins]
-
-    def origin_get_by_sha1(self, sha1s: List[bytes]) -> List[Optional[Dict[str, Any]]]:
-        return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
-
-    def origin_list(
-        self, page_token: Optional[str] = None, limit: int = 100
-    ) -> PagedResult[Origin]:
-        origin_urls = sorted(self._origins)
-        from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0
-        next_page_token = None
-
-        # Take one more origin so we can reuse it as the next page token if any
-        origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]]
-
-        if len(origins) > limit:
-            # last origin id is the next page token
-            next_page_token = str(origins[-1].url)
-            # excluding that origin from the result to respect the limit size
-            origins = origins[:limit]
-
-        assert len(origins) <= limit
-
-        return PagedResult(results=origins, next_page_token=next_page_token)
-
-    def origin_search(
-        self,
-        url_pattern: str,
-        page_token: Optional[str] = None,
-        limit: int = 50,
-        regexp: bool = False,
-        with_visit: bool = False,
-    ) -> PagedResult[Origin]:
-        next_page_token = None
-        offset = int(page_token) if page_token else 0
-
-        origins = self._origins.values()
-        if regexp:
-            pat = re.compile(url_pattern)
-            origins = [orig for orig in origins if pat.search(orig.url)]
-        else:
-            origins = [orig for orig in origins if url_pattern in orig.url]
-
-        if with_visit:
-            filtered_origins = []
-            for orig in origins:
-                visits = (
-                    self._origin_visit_get_updated(ov.origin, ov.visit)
-                    for ov in self._origin_visits[orig.url]
-                )
-                for ov in visits:
-                    snapshot = ov["snapshot"]
-                    if snapshot and not list(self.snapshot_missing([snapshot])):
-                        filtered_origins.append(orig)
-                        break
-        else:
-            filtered_origins = origins
-
-        # Take one more origin so we can reuse it as the next page token if any
-        origins = filtered_origins[offset : offset + limit + 1]
-        if len(origins) > limit:
-            # next offset
-            next_page_token = str(offset + limit)
-            # excluding that origin from the result to respect the limit size
-            origins = origins[:limit]
-
-        assert len(origins) <= limit
-        return PagedResult(results=origins, next_page_token=next_page_token)
-
-    def origin_count(
-        self, url_pattern: str, regexp: bool = False, with_visit: bool = False
-    ) -> int:
-        actual_page = self.origin_search(
-            url_pattern, regexp=regexp, with_visit=with_visit, limit=len(self._origins),
-        )
-        assert actual_page.next_page_token is None
-        return len(actual_page.results)
-
     def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
-        added = 0
         for origin in origins:
-            if origin.url not in self._origins:
-                self.origin_add_one(origin)
-                added += 1
-
-        self._cql_runner.increment_counter("origin", added)
-
-        return {"origin:add": added}
-
-    def origin_add_one(self, origin: Origin) -> str:
-        if origin.url not in self._origins:
-            self.journal_writer.origin_add([origin])
-            self._origins[origin.url] = origin
-            self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin
-            self._origin_visits[origin.url] = []
-            self._objects[origin.url].append(("origin", origin.url))
-
-        return origin.url
+            if origin.url not in self._origin_visits:
+                self._origin_visits[origin.url] = []
+        return super().origin_add(origins)
 
     def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
         for visit in visits:
@@ -651,8 +581,7 @@
         all_visits = []
         for visit in visits:
             origin_url = visit.origin
-            if origin_url in self._origins:
-                origin = self._origins[origin_url]
+            if list(self._cql_runner.origin_get_by_url(origin_url)):
                 if visit.visit:
                     self.journal_writer.origin_visit_add([visit])
                     while len(self._origin_visits[origin_url]) < visit.visit:
@@ -800,12 +729,11 @@
                 f"{','.join(VISIT_STATUSES)} authorized"
             )
 
-        ori = self._origins.get(origin)
-        if not ori:
+        if not list(self._cql_runner.origin_get_by_url(origin)):
             return None
 
         visits = sorted(
-            self._origin_visits[ori.url], key=lambda v: (v.date, v.visit), reverse=True,
+            self._origin_visits[origin], key=lambda v: (v.date, v.visit), reverse=True,
         )
         for visit in visits:
             if type is not None and visit.type != type:
@@ -876,8 +804,7 @@
                 f"{','.join(VISIT_STATUSES)} authorized"
            )
 
-        ori = self._origins.get(origin_url)
-        if not ori:
+        if not list(self._cql_runner.origin_get_by_url(origin_url)):
             return None
 
         visit_key = (origin_url, visit)
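Note on the new InMemoryCqlRunner origin table above: origins are stored in a token-indexed Table, mirroring how the Cassandra backend partitions rows, and origin_list pages through them by token rather than by URL. The following standalone sketch (not the actual swh.storage code; the token derivation and the OriginRow fields shown are illustrative assumptions) shows the data layout those comprehensions iterate over:

    from typing import Dict, List, Tuple

    class OriginRow:
        # Hypothetical stand-in for swh.storage.cassandra.model.OriginRow.
        def __init__(self, sha1: bytes, url: str):
            self.sha1 = sha1
            self.url = url

    class MiniOriginTable:
        # data[token][clustering_key] = row, the shape implied by
        # self._origins.data.items() / partition.items() in the patch.
        def __init__(self) -> None:
            self.data: Dict[int, Dict[bytes, OriginRow]] = {}

        def insert(self, row: OriginRow) -> None:
            # Assumed token function: the real Table hashes the partition key;
            # here we simply read an integer out of the sha1 prefix.
            token = int.from_bytes(row.sha1[:8], "big")
            self.data.setdefault(token, {})[row.sha1] = row

        def origin_list(self, start_token: int, limit: int) -> List[Tuple[int, OriginRow]]:
            # Same pagination idea as the patch: keep every (token, row) at or
            # past start_token, order by token, return at most `limit` entries.
            matches = [
                (token, row)
                for token, partition in self.data.items()
                for row in partition.values()
                if token >= start_token
            ]
            matches.sort(key=lambda pair: pair[0])
            return matches[:limit]

A caller can page through all origins by repeatedly passing a start_token greater than the last token it received, which is how the Cassandra-style token-range listing is emulated in memory.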
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -8,7 +8,9 @@
 import swh.storage.api.server as server
 import swh.storage.storage
 from swh.storage import get_storage
-from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import (
+    TestStorageGeneratedData as _TestStorageGeneratedData,
+)
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
 
 # tests are executed using imported classes (TestStorage and
@@ -60,14 +62,36 @@
     storage.journal_writer = journal_writer
 
 
-class TestStorage(_TestStorage):
-    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
-    def test_content_update(self):
-        pass
-
+class TestStorageApi(_TestStorage):
     @pytest.mark.skip(
         'The "person" table of the pgsql is a legacy thing, and not '
         "supported by the cassandra backend."
     )
     def test_person_fullname_unicity(self):
         pass
+
+    @pytest.mark.skip("content_update is not yet implemented for Cassandra")
+    def test_content_update(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+
+class TestStorageApiGeneratedData(_TestStorageGeneratedData):
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_no_visits(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_and_snapshot(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_no_snapshot(self):
+        pass
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
--- a/swh/storage/tests/test_in_memory.py
+++ b/swh/storage/tests/test_in_memory.py
@@ -10,7 +10,9 @@
 from swh.storage.cassandra.model import BaseRow
 from swh.storage.in_memory import SortedList, Table
 from swh.storage.tests.test_storage import TestStorage as _TestStorage
-from swh.storage.tests.test_storage import TestStorageGeneratedData  # noqa
+from swh.storage.tests.test_storage import (
+    TestStorageGeneratedData as _TestStorageGeneratedData,
+)
 
 # tests are executed using imported classes (TestStorage and
@@ -168,3 +170,25 @@
     @pytest.mark.skip("content_update is not yet implemented for Cassandra")
     def test_content_update(self):
         pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+
+class TestInMemoryStorageGeneratedData(_TestStorageGeneratedData):
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_no_visits(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_and_snapshot(self):
+        pass
+
+    @pytest.mark.skip("Not supported by Cassandra")
+    def test_origin_count_with_visit_with_visits_no_snapshot(self):
+        pass
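Why the test imports changed from a bare "# noqa" re-export to an underscore alias: pytest collects any module-level class whose name matches Test*, so importing TestStorageGeneratedData under its original name ran the whole generated-data suite directly against this backend; binding it as _TestStorageGeneratedData hides it from collection, and only the explicit subclasses above run, with @pytest.mark.skip overrides for the origin_count tests the Cassandra-style backend does not support. A schematic example of the pattern, with made-up class names rather than the swh.storage test suite:

    import pytest

    class _SharedStorageTests:
        # Not collected by pytest: the name does not start with "Test"
        # (imported classes get the same effect via an "as _Alias" import).
        def test_origin_count(self):
            assert self.count_origins() >= 0

    class TestInMemoryBackend(_SharedStorageTests):
        def count_origins(self) -> int:
            return 0

        @pytest.mark.skip("Not supported by this backend")
        def test_origin_count(self):
            # Overriding the inherited test with a skip marker disables just
            # this case; every other test from the shared base still runs.
            pass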
diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py
--- a/swh/storage/tests/test_replay.py
+++ b/swh/storage/tests/test_replay.py
@@ -206,7 +206,6 @@
     assert got_persons == expected_persons
 
     for attr_ in (
-        "origins",
         "origin_visits",
         "origin_visit_statuses",
     ):
@@ -223,6 +222,7 @@
         "revisions",
         "releases",
         "snapshots",
+        "origins",
     ):
         if exclude and attr_ in exclude:
             continue
@@ -379,10 +379,7 @@
     got_persons = set(dst._persons.values())
     assert got_persons == expected_persons
 
-    for attr_ in (
-        "origins",
-        "origin_visit_statuses",
-    ):
+    for attr_ in ("origin_visit_statuses",):
         expected_objects = [
             (id, maybe_anonymize(attr_, obj))
             for id, obj in sorted(getattr(src, f"_{attr_}").items())
@@ -399,6 +396,7 @@
         "revisions",
         "releases",
         "snapshots",
+        "origins",
     ):
         expected_objects = [
             (id, nullify_ctime(maybe_anonymize(attr_, obj)))
Attached to D3777: in_memory: Remove InMemoryStorage.origin_* and implement InMemoryCqlRunner.origin_*