diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -122,8 +122,17 @@ 'target_types': target_types }) - def origin_get(self, origin): - return self.post('origin/get', {'origin': origin}) + def origin_get(self, origins=None, *, origin=None): + if origin is None: + if origins is None: + raise TypeError('origin_get expected 1 argument') + else: + assert origins is None + origins = origin + warnings.warn("argument 'origin' of origin_get was renamed " + "to 'origins' in v0.0.123.", + DeprecationWarning) + return self.post('origin/get', {'origins': origins}) def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False): diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -596,30 +596,30 @@ origin_cols = ['id', 'type', 'url'] - def origin_get_with(self, type, url, cur=None): + def origin_get_with(self, origins, cur=None): """Retrieve the origin id from its type and url if found.""" cur = self._cursor(cur) - query = """SELECT %s - FROM origin - WHERE type=%%s AND url=%%s - """ % ','.join(self.origin_cols) + query = """SELECT %s FROM (VALUES %%s) as t(type, url) + LEFT JOIN origin + ON (t.type=origin.type AND t.url=origin.url) + """ % ','.join('origin.' + col for col in self.origin_cols) - cur.execute(query, (type, url)) - return cur.fetchone() + yield from execute_values_generator( + cur, query, origins) - def origin_get(self, id, cur=None): + def origin_get(self, ids, cur=None): """Retrieve the origin per its identifier. """ cur = self._cursor(cur) - query = """SELECT %s - FROM origin WHERE id=%%s - """ % ','.join(self.origin_cols) + query = """SELECT %s FROM (VALUES %%s) as t(id) + LEFT JOIN origin ON t.id = origin.id + """ % ','.join('origin.' 
+ col for col in self.origin_cols) - cur.execute(query, (id,)) - return cur.fetchone() + yield from execute_values_generator( + cur, query, ((id,) for id in ids)) def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, cur=None): diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -778,13 +778,14 @@ } for obj in objs] return ret - def origin_get(self, origin): + def origin_get(self, origins): """Return the origin either identified by its id or its tuple (type, url). Args: - origin: dictionary representing the individual origin to find. - This dict has either the keys type and url: + origins: a list of dictionaries representing the individual + origins to find. + These dicts have either the keys type and url: - type (FIXME: enum TBD): the origin type ('git', 'wget', ...) - url (bytes): the url the origin points to @@ -804,18 +805,47 @@ ValueError: if the keys does not match (url and type) nor id. 
""" - if 'id' in origin: - origin_id = origin['id'] - elif 'type' in origin and 'url' in origin: - origin_id = self._origin_id(origin) + if isinstance(origins, dict): + # Old API + return_single = True + origins = [origins] else: - raise ValueError('Origin must have either id or (type and url).') - origin = None - # self._origin_id can return None - if origin_id is not None and origin_id <= len(self._origins): - origin = copy.deepcopy(self._origins[origin_id-1]) - origin['id'] = origin_id - return origin + return_single = False + + # Sanity check to be error-compatible with the pgsql backend + if any('id' in origin for origin in origins) \ + and not all('id' in origin for origin in origins): + raise ValueError( + 'Either all origins or none at all should have an "id".') + if any('type' in origin and 'url' in origin for origin in origins) \ + and not all('type' in origin and 'url' in origin + for origin in origins): + raise ValueError( + 'Either all origins or none at all should have a ' + '"type" and an "url".') + + results = [] + for origin in origins: + if 'id' in origin: + origin_id = origin['id'] + elif 'type' in origin and 'url' in origin: + origin_id = self._origin_id(origin) + else: + raise ValueError( + 'Origin must have either id or (type and url).') + origin = None + # self._origin_id can return None + if origin_id is not None and origin_id <= len(self._origins): + origin = copy.deepcopy(self._origins[origin_id-1]) + origin['id'] = origin_id + results.append(origin) + + if return_single: + assert len(results) == 1 + if results[0] is not None: + return results[0] + else: + return results def origin_get_range(self, origin_from=1, origin_count=100): """Retrieve ``origin_count`` origins whose ids are greater diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1081,13 +1081,14 @@ origin_keys = ['id', 'type', 'url'] @db_transaction(statement_timeout=500) - def origin_get(self, origin, 
db=None, cur=None): + def origin_get(self, origins, db=None, cur=None): """Return the origin either identified by its id or its tuple (type, url). Args: - origin: dictionary representing the individual origin to find. - This dict has either the keys type and url: + origins: a list of dictionaries representing the individual + origins to find. + These dicts have either the keys type and url: - type (FIXME: enum TBD): the origin type ('git', 'wget', ...) - url (bytes): the url the origin points to @@ -1107,17 +1108,44 @@ ValueError: if the keys does not match (url and type) nor id. """ - origin_id = origin.get('id') - if origin_id: # check lookup per id first - ori = db.origin_get(origin_id, cur) - elif 'type' in origin and 'url' in origin: # or lookup per type, url - ori = db.origin_get_with(origin['type'], origin['url'], cur) + if isinstance(origins, dict): + # Old API + return_single = True + origins = [origins] + elif len(origins) == 0: + return [] + else: + return_single = False + + origin_ids = [origin.get('id') for origin in origins] + origin_types_and_urls = [(origin.get('type'), origin.get('url')) + for origin in origins] + if any(origin_ids): + # Lookup per ID + if all(origin_ids): + results = db.origin_get(origin_ids, cur) + else: + raise ValueError( + 'Either all origins or none at all should have an "id".') + elif any(type_ and url for (type_, url) in origin_types_and_urls): + # Lookup per type + URL + if all(type_ and url for (type_, url) in origin_types_and_urls): + results = db.origin_get_with(origin_types_and_urls, cur) + else: + raise ValueError( + 'Either all origins or none at all should have a ' + '"type" and an "url".') else: # unsupported lookup raise ValueError('Origin must have either id or (type and url).') - if ori: - return dict(zip(self.origin_keys, ori)) - return None + results = [dict(zip(self.origin_keys, result)) + for result in results] + if return_single: + assert len(results) == 1 + if results[0]['id'] is not None: + return 
results[0] + else: + return [None if res['id'] is None else res for res in results] @db_transaction_generator() def origin_search(self, url_pattern, offset=0, limit=50, @@ -1211,9 +1239,10 @@ exists. """ - data = db.origin_get_with(origin['type'], origin['url'], cur) - if data: - return data[0] + id_ = list(db.origin_get_with( + [(origin['type'], origin['url'])], cur))[0][0] + if id_: + return id_ return db.origin_add(origin['type'], origin['url'], cur) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -958,21 +958,21 @@ self.assertEqual(id, id2) def test_origin_add(self): - origin0 = self.storage.origin_get(self.origin) + origin0 = self.storage.origin_get([self.origin])[0] self.assertIsNone(origin0) origin1, origin2 = self.storage.origin_add([self.origin, self.origin2]) - actual_origin = self.storage.origin_get({ + actual_origin = self.storage.origin_get([{ 'url': self.origin['url'], 'type': self.origin['type'], - }) + }])[0] self.assertEqual(actual_origin['id'], origin1['id']) - actual_origin2 = self.storage.origin_get({ + actual_origin2 = self.storage.origin_get([{ 'url': self.origin2['url'], 'type': self.origin2['type'], - }) + }])[0] self.assertEqual(actual_origin2['id'], origin2['id']) def test_origin_add_twice(self): @@ -981,13 +981,13 @@ self.assertEqual(add1, add2) - def test_origin_get(self): + def test_origin_get_legacy(self): self.assertIsNone(self.storage.origin_get(self.origin)) id = self.storage.origin_add_one(self.origin) # lookup per type and url (returns id) - actual_origin0 = self.storage.origin_get({'url': self.origin['url'], - 'type': self.origin['type']}) + actual_origin0 = self.storage.origin_get( + {'url': self.origin['url'], 'type': self.origin['type']}) self.assertEqual(actual_origin0['id'], id) # lookup per id (returns dict) @@ -997,6 +997,33 @@ 'type': self.origin['type'], 'url': self.origin['url']}) + def 
test_origin_get(self): + self.assertIsNone(self.storage.origin_get(self.origin)) + id = self.storage.origin_add_one(self.origin) + + # lookup per type and url (returns a one-element list of dicts) + actual_origin0 = self.storage.origin_get( + [{'url': self.origin['url'], 'type': self.origin['type']}]) + self.assertEqual(len(actual_origin0), 1, actual_origin0) + self.assertEqual(actual_origin0[0]['id'], id) + + # lookup per id (returns a one-element list of dicts) + actual_origin1 = self.storage.origin_get([{'id': id}]) + + self.assertEqual(len(actual_origin1), 1, actual_origin1) + self.assertEqual(actual_origin1[0], {'id': id, + 'type': self.origin['type'], + 'url': self.origin['url']}) + + def test_origin_get_consistency(self): + self.assertIsNone(self.storage.origin_get(self.origin)) + id = self.storage.origin_add_one(self.origin) + + with self.assertRaises(ValueError): + self.storage.origin_get([ + {'url': self.origin['url'], 'type': self.origin['type']}, + {'id': id}]) + def test_origin_search(self): found_origins = list(self.storage.origin_search(self.origin['url'])) self.assertEqual(len(found_origins), 0) @@ -1055,7 +1082,7 @@ def test_origin_visit_add(self): # given - self.assertIsNone(self.storage.origin_get(self.origin2)) + self.assertIsNone(self.storage.origin_get([self.origin2])[0]) origin_id = self.storage.origin_add_one(self.origin2) self.assertIsNotNone(origin_id) @@ -2143,14 +2170,21 @@ keys_to_check) - def test_origin_get_invalid_id(self): - + def test_origin_get_invalid_id_legacy(self): invalid_origin_id = 1 origin_info = self.storage.origin_get({'id': invalid_origin_id}) self.assertIsNone(origin_info) - origin_visits = list(self.storage.origin_visit_get(invalid_origin_id)) + origin_visits = list(self.storage.origin_visit_get( + invalid_origin_id)) + self.assertEqual(origin_visits, []) + + def test_origin_get_invalid_id(self): + origin_info = self.storage.origin_get([{'id': 1}, {'id': 2}]) + self.assertEqual(origin_info, [None, None]) + + origin_visits = list(self.storage.origin_visit_get(1)) 
self.assertEqual(origin_visits, []) @given(gen_origins(min_size=100, max_size=100))