Page Menu · Home · Software Heritage

D3605.id12700.diff
No One · Temporary

D3605.id12700.diff

diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -685,39 +685,17 @@
return results
- def origin_get(self, origins):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- else:
- return_single = False
-
- if any("id" in origin for origin in origins):
- raise StorageArgumentException("Origin ids are not supported.")
-
- results = [self.origin_get_one(origin) for origin in origins]
-
- if return_single:
- assert len(results) == 1
- return results[0]
- else:
- return results
+ def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]:
+ return [self.origin_get_one(origin) for origin in origins]
- def origin_get_one(self, origin: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- if "id" in origin:
- raise StorageArgumentException("Origin ids are not supported.")
- if "url" not in origin:
- raise StorageArgumentException("Missing origin url")
- rows = self._cql_runner.origin_get_by_url(origin["url"])
+ def origin_get_one(self, origin_url: str) -> Optional[Origin]:
+ """Given an origin url, return the origin if it exists, None otherwise.
- rows = list(rows)
+ """
+ rows = list(self._cql_runner.origin_get_by_url(origin_url))
if rows:
assert len(rows) == 1
- result = rows[0]._asdict()
- return {
- "url": result["url"],
- }
+ return Origin(url=rows[0].url)
else:
return None
@@ -770,12 +748,7 @@
def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]:
origins = list(origins)
- known_origins = [
- Origin.from_dict(d)
- for d in self.origin_get([origin.to_dict() for origin in origins])
- if d is not None
- ]
- to_add = [origin for origin in origins if origin not in known_origins]
+ to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(to_add)
for origin in to_add:
self._cql_runner.origin_add_one(origin)
@@ -783,7 +756,7 @@
def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]:
for visit in visits:
- origin = self.origin_get({"url": visit.origin})
+ origin = self.origin_get_one(visit.origin)
if not origin: # Cannot add a visit without an origin
raise StorageArgumentException("Unknown origin %s", visit.origin)
@@ -822,7 +795,7 @@
) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
- origin_url = self.origin_get({"url": visit_status.origin})
+ origin_url = self.origin_get_one(visit_status.origin)
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -670,43 +670,11 @@
return t.to_dict()
- def origin_get(self, origins):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- else:
- return_single = False
+ def origin_get_one(self, origin_url: str) -> Optional[Origin]:
+ return self._origins.get(origin_url)
- # Sanity check to be error-compatible with the pgsql backend
- if any("id" in origin for origin in origins) and not all(
- "id" in origin for origin in origins
- ):
- raise StorageArgumentException(
- 'Either all origins or none at all should have an "id".'
- )
- if any("url" in origin for origin in origins) and not all(
- "url" in origin for origin in origins
- ):
- raise StorageArgumentException(
- "Either all origins or none at all should have " 'an "url" key.'
- )
-
- results = []
- for origin in origins:
- result = None
- if "url" in origin:
- if origin["url"] in self._origins:
- result = self._origins[origin["url"]]
- else:
- raise StorageArgumentException("Origin must have an url.")
- results.append(self._convert_origin(result))
-
- if return_single:
- assert len(results) == 1
- return results[0]
- else:
- return results
+ def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]:
+ return [self.origin_get_one(origin_url) for origin_url in origins]
def origin_get_by_sha1(self, sha1s):
return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
@@ -803,7 +771,7 @@
def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]:
for visit in visits:
- origin = self.origin_get({"url": visit.origin})
+ origin = self.origin_get_one(visit.origin)
if not origin: # Cannot add a visit without an origin
raise StorageArgumentException("Unknown origin %s", visit.origin)
@@ -855,7 +823,7 @@
) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
- origin_url = self.origin_get({"url": visit_status.origin})
+ origin_url = self.origin_get_one(visit_status.origin)
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -944,28 +944,15 @@
...
@remote_api_endpoint("origin/get")
- def origin_get(self, origins):
- """Return origins, either all identified by their ids or all
- identified by tuples (type, url).
-
- If the url is given and the type is omitted, one of the origins with
- that url is returned.
+ def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]:
+ """Return origins.
Args:
- origin: a list of dictionaries representing the individual
- origins to find.
- These dicts have the key url:
-
- - url (bytes): the url the origin points to
+ origins: a list of urls to find
Returns:
- dict: the origin dictionary with the keys:
-
- - id: origin's id
- - url: origin's url
-
- Raises:
- ValueError: if the url or the id don't exist.
+ the list of matching Origin model objects, in the same order as the
+ input; an unknown url is returned as None at its index.
"""
...
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -802,7 +802,7 @@
self, visits: Iterable[OriginVisit], db=None, cur=None
) -> Iterable[OriginVisit]:
for visit in visits:
- origin = self.origin_get({"url": visit.origin}, db=db, cur=cur)
+ origin = self.origin_get([visit.origin], db=db, cur=cur)[0]
if not origin: # Cannot add a visit without an origin
raise StorageArgumentException("Unknown origin %s", visit.origin)
@@ -851,7 +851,7 @@
) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
- origin_url = self.origin_get({"url": visit_status.origin}, db=db, cur=cur)
+ origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0]
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
@@ -973,28 +973,17 @@
@timed
@db_transaction(statement_timeout=500)
- def origin_get(self, origins, db=None, cur=None):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- elif len(origins) == 0:
- return []
- else:
- return_single = False
-
- origin_urls = [origin["url"] for origin in origins]
- results = db.origin_get_by_url(origin_urls, cur)
-
- results = [dict(zip(db.origin_cols, result)) for result in results]
- if return_single:
- assert len(results) == 1
- if results[0]["url"] is not None:
- return results[0]
- else:
- return None
- else:
- return [None if res["url"] is None else res for res in results]
+ def origin_get(
+ self, origins: Iterable[str], db=None, cur=None
+ ) -> Iterable[Optional[Origin]]:
+ origin_urls = list(origins)
+ rows = db.origin_get_by_url(origin_urls, cur)
+ result: List[Optional[Origin]] = []
+ for row in rows:
+ origin_d = dict(zip(db.origin_cols, row))
+ url = origin_d["url"]
+ result.append(None if url is None else Origin(url=url))
+ return result
@timed
@db_transaction_generator(statement_timeout=500)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1169,22 +1169,19 @@
}
def test_origin_add(self, swh_storage, sample_data):
- origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
+ origins = list(sample_data.origins[:2])
+ origin_urls = [o.url for o in origins]
- assert swh_storage.origin_get([origin_dict])[0] is None
+ assert swh_storage.origin_get(origin_urls) == [None, None]
- stats = swh_storage.origin_add([origin, origin2])
+ stats = swh_storage.origin_add(origins)
assert stats == {"origin:add": 2}
- actual_origin = swh_storage.origin_get([origin_dict])[0]
- assert actual_origin["url"] == origin.url
-
- actual_origin2 = swh_storage.origin_get([origin2_dict])[0]
- assert actual_origin2["url"] == origin2.url
+ actual_origins = swh_storage.origin_get(origin_urls)
+ assert actual_origins == origins
assert set(swh_storage.journal_writer.journal.objects) == set(
- [("origin", origin), ("origin", origin2),]
+ [("origin", origins[0]), ("origin", origins[1]),]
)
swh_storage.refresh_stat_counters()
@@ -1192,7 +1189,6 @@
def test_origin_add_from_generator(self, swh_storage, sample_data):
origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
def _ori_gen():
yield origin
@@ -1201,11 +1197,8 @@
stats = swh_storage.origin_add(_ori_gen())
assert stats == {"origin:add": 2}
- actual_origin = swh_storage.origin_get([origin_dict])[0]
- assert actual_origin["url"] == origin.url
-
- actual_origin2 = swh_storage.origin_get([origin2_dict])[0]
- assert actual_origin2["url"] == origin2.url
+ actual_origins = swh_storage.origin_get([origin.url, origin2.url])
+ assert actual_origins == [origin, origin2]
assert set(swh_storage.journal_writer.journal.objects) == set(
[("origin", origin), ("origin", origin2),]
@@ -1216,7 +1209,6 @@
def test_origin_add_twice(self, swh_storage, sample_data):
origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
add1 = swh_storage.origin_add([origin, origin2])
assert set(swh_storage.journal_writer.journal.objects) == set(
@@ -1230,33 +1222,17 @@
)
assert add2 == {"origin:add": 0}
- def test_origin_get_legacy(self, swh_storage, sample_data):
- origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
-
- assert swh_storage.origin_get(origin_dict) is None
- swh_storage.origin_add([origin])
-
- actual_origin0 = swh_storage.origin_get(origin_dict)
- assert actual_origin0["url"] == origin.url
-
def test_origin_get(self, swh_storage, sample_data):
origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
- assert swh_storage.origin_get(origin_dict) is None
- assert swh_storage.origin_get([origin_dict]) == [None]
+ assert swh_storage.origin_get([origin.url]) == [None]
swh_storage.origin_add([origin])
- actual_origins = swh_storage.origin_get([origin_dict])
- assert len(actual_origins) == 1
-
- actual_origin0 = swh_storage.origin_get(origin_dict)
- assert actual_origin0 == actual_origins[0]
- assert actual_origin0["url"] == origin.url
+ actual_origins = swh_storage.origin_get([origin.url])
+ assert actual_origins == [origin]
- actual_origins = swh_storage.origin_get([origin_dict, {"url": "not://exists"}])
- assert actual_origins == [origin_dict, None]
+ actual_origins = swh_storage.origin_get([origin.url, "not://exists"])
+ assert actual_origins == [origin, None]
def _generate_random_visits(self, nb_visits=100, start=0, end=7):
"""Generate random visits within the last 2 months (to avoid
@@ -1422,7 +1398,7 @@
def test_origin_get_by_sha1(self, swh_storage, sample_data):
origin = sample_data.origin
- assert swh_storage.origin_get(origin.to_dict()) is None
+ assert swh_storage.origin_get([origin.url])[0] is None
swh_storage.origin_add([origin])
origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)]))
@@ -1431,7 +1407,7 @@
def test_origin_get_by_sha1_not_found(self, swh_storage, sample_data):
origin = sample_data.origin
- assert swh_storage.origin_get(origin.to_dict()) is None
+ assert swh_storage.origin_get([origin.url])[0] is None
origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)]))
assert len(origins) == 1
assert origins[0] is None

File Metadata

Mime Type
text/plain
Expires
Dec 21 2024, 1:48 AM (11 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3226495

Event Timeline