Page MenuHomeSoftware Heritage

D1088.id3455.diff
No OneTemporary

D1088.id3455.diff

diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -122,8 +122,17 @@
'target_types': target_types
})
- def origin_get(self, origin):
- return self.post('origin/get', {'origin': origin})
+ def origin_get(self, origins=None, *, origin=None):
+ if origin is None:
+ if origins is None:
+ raise TypeError('origin_get expected 1 argument')
+ else:
+ assert origins is None
+ origins = origin
+ warnings.warn("argument 'origin' of origin_get was renamed "
+ "to 'origins' in v0.0.123.",
+ DeprecationWarning)
+ return self.post('origin/get', {'origins': origins})
def origin_search(self, url_pattern, offset=0, limit=50, regexp=False,
with_visit=False):
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -596,30 +596,30 @@
origin_cols = ['id', 'type', 'url']
- def origin_get_with(self, type, url, cur=None):
+ def origin_get_with(self, origins, cur=None):
"""Retrieve the origin id from its type and url if found."""
cur = self._cursor(cur)
- query = """SELECT %s
- FROM origin
- WHERE type=%%s AND url=%%s
- """ % ','.join(self.origin_cols)
+ query = """SELECT %s FROM (VALUES %%s) as t(type, url)
+ LEFT JOIN origin
+ ON (t.type=origin.type AND t.url=origin.url)
+ """ % ','.join('origin.' + col for col in self.origin_cols)
- cur.execute(query, (type, url))
- return cur.fetchone()
+ yield from execute_values_generator(
+ cur, query, origins)
- def origin_get(self, id, cur=None):
+ def origin_get(self, ids, cur=None):
"""Retrieve the origin per its identifier.
"""
cur = self._cursor(cur)
- query = """SELECT %s
- FROM origin WHERE id=%%s
- """ % ','.join(self.origin_cols)
+ query = """SELECT %s FROM (VALUES %%s) as t(id)
+ LEFT JOIN origin ON t.id = origin.id
+ """ % ','.join('origin.' + col for col in self.origin_cols)
- cur.execute(query, (id,))
- return cur.fetchone()
+ yield from execute_values_generator(
+ cur, query, ((id,) for id in ids))
def origin_search(self, url_pattern, offset=0, limit=50,
regexp=False, with_visit=False, cur=None):
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -778,13 +778,14 @@
} for obj in objs]
return ret
- def origin_get(self, origin):
- """Return the origin either identified by its id or its tuple
- (type, url).
+ def origin_get(self, origins):
+ """Return origins, either all identified by their ids or all
+ identified by tuples (type, url).
Args:
- origin: dictionary representing the individual origin to find.
- This dict has either the keys type and url:
+ origins: a list of dictionaries representing the individual
+ origins to find.
+ These dicts have either the keys type and url:
- type (FIXME: enum TBD): the origin type ('git', 'wget', ...)
- url (bytes): the url the origin points to
@@ -804,18 +805,46 @@
ValueError: if the keys does not match (url and type) nor id.
"""
- if 'id' in origin:
- origin_id = origin['id']
- elif 'type' in origin and 'url' in origin:
- origin_id = self._origin_id(origin)
+ if isinstance(origins, dict):
+ # Old API
+ return_single = True
+ origins = [origins]
else:
- raise ValueError('Origin must have either id or (type and url).')
- origin = None
- # self._origin_id can return None
- if origin_id is not None and origin_id <= len(self._origins):
- origin = copy.deepcopy(self._origins[origin_id-1])
- origin['id'] = origin_id
- return origin
+ return_single = False
+
+ # Sanity check to be error-compatible with the pgsql backend
+ if any('id' in origin for origin in origins) \
+ and not all('id' in origin for origin in origins):
+ raise ValueError(
+ 'Either all origins or none at all should have an "id".')
+ if any('type' in origin and 'url' in origin for origin in origins) \
+ and not all('type' in origin and 'url' in origin
+ for origin in origins):
+ raise ValueError(
+ 'Either all origins or none at all should have a '
+ '"type" and an "url".')
+
+ results = []
+ for origin in origins:
+ if 'id' in origin:
+ origin_id = origin['id']
+ elif 'type' in origin and 'url' in origin:
+ origin_id = self._origin_id(origin)
+ else:
+ raise ValueError(
+ 'Origin must have either id or (type and url).')
+ origin = None
+ # self._origin_id can return None
+ if origin_id is not None and origin_id <= len(self._origins):
+ origin = copy.deepcopy(self._origins[origin_id-1])
+ origin['id'] = origin_id
+ results.append(origin)
+
+ if return_single:
+ assert len(results) == 1
+ return results[0]
+ else:
+ return results
def origin_get_range(self, origin_from=1, origin_count=100):
"""Retrieve ``origin_count`` origins whose ids are greater
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1081,13 +1081,14 @@
origin_keys = ['id', 'type', 'url']
@db_transaction(statement_timeout=500)
- def origin_get(self, origin, db=None, cur=None):
- """Return the origin either identified by its id or its tuple
- (type, url).
+ def origin_get(self, origins, db=None, cur=None):
+ """Return origins, either all identified by their ids or all
+ identified by tuples (type, url).
Args:
- origin: dictionary representing the individual origin to find.
- This dict has either the keys type and url:
+ origins: a list of dictionaries representing the individual
+ origins to find.
+ These dicts have either the keys type and url:
- type (FIXME: enum TBD): the origin type ('git', 'wget', ...)
- url (bytes): the url the origin points to
@@ -1107,17 +1108,46 @@
ValueError: if the keys does not match (url and type) nor id.
"""
- origin_id = origin.get('id')
- if origin_id: # check lookup per id first
- ori = db.origin_get(origin_id, cur)
- elif 'type' in origin and 'url' in origin: # or lookup per type, url
- ori = db.origin_get_with(origin['type'], origin['url'], cur)
+ if isinstance(origins, dict):
+ # Old API
+ return_single = True
+ origins = [origins]
+ elif len(origins) == 0:
+ return []
+ else:
+ return_single = False
+
+ origin_ids = [origin.get('id') for origin in origins]
+ origin_types_and_urls = [(origin.get('type'), origin.get('url'))
+ for origin in origins]
+ if any(origin_ids):
+ # Lookup per ID
+ if all(origin_ids):
+ results = db.origin_get(origin_ids, cur)
+ else:
+ raise ValueError(
+ 'Either all origins or none at all should have an "id".')
+ elif any(type_ and url for (type_, url) in origin_types_and_urls):
+ # Lookup per type + URL
+ if all(type_ and url for (type_, url) in origin_types_and_urls):
+ results = db.origin_get_with(origin_types_and_urls, cur)
+ else:
+ raise ValueError(
+ 'Either all origins or none at all should have a '
+ '"type" and an "url".')
else: # unsupported lookup
raise ValueError('Origin must have either id or (type and url).')
- if ori:
- return dict(zip(self.origin_keys, ori))
- return None
+ results = [dict(zip(self.origin_keys, result))
+ for result in results]
+ if return_single:
+ assert len(results) == 1
+ if results[0]['id'] is not None:
+ return results[0]
+ else:
+ return None
+ else:
+ return [None if res['id'] is None else res for res in results]
@db_transaction_generator()
def origin_search(self, url_pattern, offset=0, limit=50,
@@ -1211,9 +1241,10 @@
exists.
"""
- data = db.origin_get_with(origin['type'], origin['url'], cur)
- if data:
- return data[0]
+ origin_id = list(db.origin_get_with(
+ [(origin['type'], origin['url'])], cur))[0][0]
+ if origin_id:
+ return origin_id
return db.origin_add(origin['type'], origin['url'], cur)
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -958,21 +958,21 @@
self.assertEqual(id, id2)
def test_origin_add(self):
- origin0 = self.storage.origin_get(self.origin)
+ origin0 = self.storage.origin_get([self.origin])[0]
self.assertIsNone(origin0)
origin1, origin2 = self.storage.origin_add([self.origin, self.origin2])
- actual_origin = self.storage.origin_get({
+ actual_origin = self.storage.origin_get([{
'url': self.origin['url'],
'type': self.origin['type'],
- })
+ }])[0]
self.assertEqual(actual_origin['id'], origin1['id'])
- actual_origin2 = self.storage.origin_get({
+ actual_origin2 = self.storage.origin_get([{
'url': self.origin2['url'],
'type': self.origin2['type'],
- })
+ }])[0]
self.assertEqual(actual_origin2['id'], origin2['id'])
def test_origin_add_twice(self):
@@ -981,13 +981,13 @@
self.assertEqual(add1, add2)
- def test_origin_get(self):
+ def test_origin_get_legacy(self):
self.assertIsNone(self.storage.origin_get(self.origin))
id = self.storage.origin_add_one(self.origin)
# lookup per type and url (returns id)
- actual_origin0 = self.storage.origin_get({'url': self.origin['url'],
- 'type': self.origin['type']})
+ actual_origin0 = self.storage.origin_get(
+ {'url': self.origin['url'], 'type': self.origin['type']})
self.assertEqual(actual_origin0['id'], id)
# lookup per id (returns dict)
@@ -997,6 +997,33 @@
'type': self.origin['type'],
'url': self.origin['url']})
+ def test_origin_get(self):
+ self.assertIsNone(self.storage.origin_get(self.origin))
+ origin_id = self.storage.origin_add_one(self.origin)
+
+ # lookup per type and url (returns a list of origin dicts)
+ actual_origin0 = self.storage.origin_get(
+ [{'url': self.origin['url'], 'type': self.origin['type']}])
+ self.assertEqual(len(actual_origin0), 1, actual_origin0)
+ self.assertEqual(actual_origin0[0]['id'], origin_id)
+
+ # lookup per id (returns a list of origin dicts)
+ actual_origin1 = self.storage.origin_get([{'id': origin_id}])
+
+ self.assertEqual(len(actual_origin1), 1, actual_origin1)
+ self.assertEqual(actual_origin1[0], {'id': origin_id,
+ 'type': self.origin['type'],
+ 'url': self.origin['url']})
+
+ def test_origin_get_consistency(self):
+ self.assertIsNone(self.storage.origin_get(self.origin))
+ id = self.storage.origin_add_one(self.origin)
+
+ with self.assertRaises(ValueError):
+ self.storage.origin_get([
+ {'url': self.origin['url'], 'type': self.origin['type']},
+ {'id': id}])
+
def test_origin_search(self):
found_origins = list(self.storage.origin_search(self.origin['url']))
self.assertEqual(len(found_origins), 0)
@@ -1055,7 +1082,7 @@
def test_origin_visit_add(self):
# given
- self.assertIsNone(self.storage.origin_get(self.origin2))
+ self.assertIsNone(self.storage.origin_get([self.origin2])[0])
origin_id = self.storage.origin_add_one(self.origin2)
self.assertIsNotNone(origin_id)
@@ -2143,14 +2170,21 @@
keys_to_check)
- def test_origin_get_invalid_id(self):
-
+ def test_origin_get_invalid_id_legacy(self):
invalid_origin_id = 1
origin_info = self.storage.origin_get({'id': invalid_origin_id})
self.assertIsNone(origin_info)
- origin_visits = list(self.storage.origin_visit_get(invalid_origin_id))
+ origin_visits = list(self.storage.origin_visit_get(
+ invalid_origin_id))
+ self.assertEqual(origin_visits, [])
+
+ def test_origin_get_invalid_id(self):
+ origin_info = self.storage.origin_get([{'id': 1}, {'id': 2}])
+ self.assertEqual(origin_info, [None, None])
+
+ origin_visits = list(self.storage.origin_visit_get(1))
self.assertEqual(origin_visits, [])
@given(gen_origins(min_size=100, max_size=100))

File Metadata

Mime Type
text/plain
Expires
Sun, Aug 24, 6:04 PM (2 d, 3 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3215581

Event Timeline