Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F7124226
D3605.id12700.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
14 KB
Subscribers
None
D3605.id12700.diff
View Options
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -685,39 +685,17 @@
return results
- def origin_get(self, origins):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- else:
- return_single = False
-
- if any("id" in origin for origin in origins):
- raise StorageArgumentException("Origin ids are not supported.")
-
- results = [self.origin_get_one(origin) for origin in origins]
-
- if return_single:
- assert len(results) == 1
- return results[0]
- else:
- return results
+ def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]:
+ return [self.origin_get_one(origin) for origin in origins]
- def origin_get_one(self, origin: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- if "id" in origin:
- raise StorageArgumentException("Origin ids are not supported.")
- if "url" not in origin:
- raise StorageArgumentException("Missing origin url")
- rows = self._cql_runner.origin_get_by_url(origin["url"])
+ def origin_get_one(self, origin_url: str) -> Optional[Origin]:
+ """Given an origin url, return the origin if it exists, None otherwise
- rows = list(rows)
+ """
+ rows = list(self._cql_runner.origin_get_by_url(origin_url))
if rows:
assert len(rows) == 1
- result = rows[0]._asdict()
- return {
- "url": result["url"],
- }
+ return Origin(url=rows[0].url)
else:
return None
@@ -770,12 +748,7 @@
def origin_add(self, origins: Iterable[Origin]) -> Dict[str, int]:
origins = list(origins)
- known_origins = [
- Origin.from_dict(d)
- for d in self.origin_get([origin.to_dict() for origin in origins])
- if d is not None
- ]
- to_add = [origin for origin in origins if origin not in known_origins]
+ to_add = [ori for ori in origins if self.origin_get_one(ori.url) is None]
self.journal_writer.origin_add(to_add)
for origin in to_add:
self._cql_runner.origin_add_one(origin)
@@ -783,7 +756,7 @@
def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]:
for visit in visits:
- origin = self.origin_get({"url": visit.origin})
+ origin = self.origin_get_one(visit.origin)
if not origin: # Cannot add a visit without an origin
raise StorageArgumentException("Unknown origin %s", visit.origin)
@@ -822,7 +795,7 @@
) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
- origin_url = self.origin_get({"url": visit_status.origin})
+ origin_url = self.origin_get_one(visit_status.origin)
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -670,43 +670,11 @@
return t.to_dict()
- def origin_get(self, origins):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- else:
- return_single = False
+ def origin_get_one(self, origin_url: str) -> Optional[Origin]:
+ return self._origins.get(origin_url)
- # Sanity check to be error-compatible with the pgsql backend
- if any("id" in origin for origin in origins) and not all(
- "id" in origin for origin in origins
- ):
- raise StorageArgumentException(
- 'Either all origins or none at all should have an "id".'
- )
- if any("url" in origin for origin in origins) and not all(
- "url" in origin for origin in origins
- ):
- raise StorageArgumentException(
- "Either all origins or none at all should have " 'an "url" key.'
- )
-
- results = []
- for origin in origins:
- result = None
- if "url" in origin:
- if origin["url"] in self._origins:
- result = self._origins[origin["url"]]
- else:
- raise StorageArgumentException("Origin must have an url.")
- results.append(self._convert_origin(result))
-
- if return_single:
- assert len(results) == 1
- return results[0]
- else:
- return results
+ def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]:
+ return [self.origin_get_one(origin_url) for origin_url in origins]
def origin_get_by_sha1(self, sha1s):
return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
@@ -803,7 +771,7 @@
def origin_visit_add(self, visits: Iterable[OriginVisit]) -> Iterable[OriginVisit]:
for visit in visits:
- origin = self.origin_get({"url": visit.origin})
+ origin = self.origin_get_one(visit.origin)
if not origin: # Cannot add a visit without an origin
raise StorageArgumentException("Unknown origin %s", visit.origin)
@@ -855,7 +823,7 @@
) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
- origin_url = self.origin_get({"url": visit_status.origin})
+ origin_url = self.origin_get_one(visit_status.origin)
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -944,28 +944,15 @@
...
@remote_api_endpoint("origin/get")
- def origin_get(self, origins):
- """Return origins, either all identified by their ids or all
- identified by tuples (type, url).
-
- If the url is given and the type is omitted, one of the origins with
- that url is returned.
+ def origin_get(self, origins: Iterable[str]) -> Iterable[Optional[Origin]]:
+ """Return origins.
Args:
- origin: a list of dictionaries representing the individual
- origins to find.
- These dicts have the key url:
-
- - url (bytes): the url the origin points to
+ origins: a list of origin urls to find
Returns:
- dict: the origin dictionary with the keys:
-
- - id: origin's id
- - url: origin's url
-
- Raises:
- ValueError: if the url or the id don't exist.
+ the list of associated existing origin model objects. The unknown origins
+ will be returned as None at the same index as the input.
"""
...
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -802,7 +802,7 @@
self, visits: Iterable[OriginVisit], db=None, cur=None
) -> Iterable[OriginVisit]:
for visit in visits:
- origin = self.origin_get({"url": visit.origin}, db=db, cur=cur)
+ origin = self.origin_get([visit.origin], db=db, cur=cur)[0]
if not origin: # Cannot add a visit without an origin
raise StorageArgumentException("Unknown origin %s", visit.origin)
@@ -851,7 +851,7 @@
) -> None:
# First round to check existence (fail early if any is ko)
for visit_status in visit_statuses:
- origin_url = self.origin_get({"url": visit_status.origin}, db=db, cur=cur)
+ origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0]
if not origin_url:
raise StorageArgumentException(f"Unknown origin {visit_status.origin}")
@@ -973,28 +973,17 @@
@timed
@db_transaction(statement_timeout=500)
- def origin_get(self, origins, db=None, cur=None):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- elif len(origins) == 0:
- return []
- else:
- return_single = False
-
- origin_urls = [origin["url"] for origin in origins]
- results = db.origin_get_by_url(origin_urls, cur)
-
- results = [dict(zip(db.origin_cols, result)) for result in results]
- if return_single:
- assert len(results) == 1
- if results[0]["url"] is not None:
- return results[0]
- else:
- return None
- else:
- return [None if res["url"] is None else res for res in results]
+ def origin_get(
+ self, origins: Iterable[str], db=None, cur=None
+ ) -> Iterable[Optional[Origin]]:
+ origin_urls = list(origins)
+ rows = db.origin_get_by_url(origin_urls, cur)
+ result: List[Optional[Origin]] = []
+ for row in rows:
+ origin_d = dict(zip(db.origin_cols, row))
+ url = origin_d["url"]
+ result.append(None if url is None else Origin(url=url))
+ return result
@timed
@db_transaction_generator(statement_timeout=500)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1169,22 +1169,19 @@
}
def test_origin_add(self, swh_storage, sample_data):
- origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
+ origins = list(sample_data.origins[:2])
+ origin_urls = [o.url for o in origins]
- assert swh_storage.origin_get([origin_dict])[0] is None
+ assert swh_storage.origin_get(origin_urls) == [None, None]
- stats = swh_storage.origin_add([origin, origin2])
+ stats = swh_storage.origin_add(origins)
assert stats == {"origin:add": 2}
- actual_origin = swh_storage.origin_get([origin_dict])[0]
- assert actual_origin["url"] == origin.url
-
- actual_origin2 = swh_storage.origin_get([origin2_dict])[0]
- assert actual_origin2["url"] == origin2.url
+ actual_origins = swh_storage.origin_get(origin_urls)
+ assert actual_origins == origins
assert set(swh_storage.journal_writer.journal.objects) == set(
- [("origin", origin), ("origin", origin2),]
+ [("origin", origins[0]), ("origin", origins[1]),]
)
swh_storage.refresh_stat_counters()
@@ -1192,7 +1189,6 @@
def test_origin_add_from_generator(self, swh_storage, sample_data):
origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
def _ori_gen():
yield origin
@@ -1201,11 +1197,8 @@
stats = swh_storage.origin_add(_ori_gen())
assert stats == {"origin:add": 2}
- actual_origin = swh_storage.origin_get([origin_dict])[0]
- assert actual_origin["url"] == origin.url
-
- actual_origin2 = swh_storage.origin_get([origin2_dict])[0]
- assert actual_origin2["url"] == origin2.url
+ actual_origins = swh_storage.origin_get([origin.url, origin2.url])
+ assert actual_origins == [origin, origin2]
assert set(swh_storage.journal_writer.journal.objects) == set(
[("origin", origin), ("origin", origin2),]
@@ -1216,7 +1209,6 @@
def test_origin_add_twice(self, swh_storage, sample_data):
origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
add1 = swh_storage.origin_add([origin, origin2])
assert set(swh_storage.journal_writer.journal.objects) == set(
@@ -1230,33 +1222,17 @@
)
assert add2 == {"origin:add": 0}
- def test_origin_get_legacy(self, swh_storage, sample_data):
- origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
-
- assert swh_storage.origin_get(origin_dict) is None
- swh_storage.origin_add([origin])
-
- actual_origin0 = swh_storage.origin_get(origin_dict)
- assert actual_origin0["url"] == origin.url
-
def test_origin_get(self, swh_storage, sample_data):
origin, origin2 = sample_data.origins[:2]
- origin_dict, origin2_dict = [o.to_dict() for o in [origin, origin2]]
- assert swh_storage.origin_get(origin_dict) is None
- assert swh_storage.origin_get([origin_dict]) == [None]
+ assert swh_storage.origin_get([origin.url]) == [None]
swh_storage.origin_add([origin])
- actual_origins = swh_storage.origin_get([origin_dict])
- assert len(actual_origins) == 1
-
- actual_origin0 = swh_storage.origin_get(origin_dict)
- assert actual_origin0 == actual_origins[0]
- assert actual_origin0["url"] == origin.url
+ actual_origins = swh_storage.origin_get([origin.url])
+ assert actual_origins == [origin]
- actual_origins = swh_storage.origin_get([origin_dict, {"url": "not://exists"}])
- assert actual_origins == [origin_dict, None]
+ actual_origins = swh_storage.origin_get([origin.url, "not://exists"])
+ assert actual_origins == [origin, None]
def _generate_random_visits(self, nb_visits=100, start=0, end=7):
"""Generate random visits within the last 2 months (to avoid
@@ -1422,7 +1398,7 @@
def test_origin_get_by_sha1(self, swh_storage, sample_data):
origin = sample_data.origin
- assert swh_storage.origin_get(origin.to_dict()) is None
+ assert swh_storage.origin_get([origin.url])[0] is None
swh_storage.origin_add([origin])
origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)]))
@@ -1431,7 +1407,7 @@
def test_origin_get_by_sha1_not_found(self, swh_storage, sample_data):
origin = sample_data.origin
- assert swh_storage.origin_get(origin.to_dict()) is None
+ assert swh_storage.origin_get([origin.url])[0] is None
origins = list(swh_storage.origin_get_by_sha1([sha1(origin.url)]))
assert len(origins) == 1
assert origins[0] is None
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Dec 21 2024, 1:48 AM (11 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3226495
Attached To
D3605: storage*: origin_get(Iterable[str]) -> Iterable[Optional[Origin]]
Event Timeline
Log In to Comment