diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -661,21 +661,19 @@ origin_cols = ['id', 'type', 'url'] - def origin_get_with(self, origins, cur=None): - """Retrieve the origin id from its type and url if found.""" + def origin_get_by_url(self, origins, cur=None): + """Retrieve origin `(id, type, url)` from urls if found.""" cur = self._cursor(cur) - query = """SELECT %s FROM (VALUES %%s) as t(type, url) - LEFT JOIN origin - ON ((t.type IS NULL OR t.type=origin.type) - AND t.url=origin.url) + query = """SELECT %s FROM (VALUES %%s) as t(url) + LEFT JOIN origin ON t.url = origin.url """ % ','.join('origin.' + col for col in self.origin_cols) yield from execute_values_generator( - cur, query, origins) + cur, query, ((url,) for url in origins)) - def origin_get(self, ids, cur=None): - """Retrieve the origin per its identifier. + def origin_get_by_id(self, ids, cur=None): + """Retrieve origin `(id, type, url)` from ids if found. """ cur = self._cursor(cur) diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1189,7 +1189,7 @@ visit_id = len(self._origin_visits[origin_url]) + 1 status = 'ongoing' visit = { - 'origin': origin, + 'origin': {'url': origin['url']}, 'date': date, 'type': type or origin['type'], 'status': status, @@ -1245,7 +1245,8 @@ if 'id' in origin: del origin['id'] self.journal_writer.write_update('origin_visit', { - 'origin': origin, 'type': origin['type'], + 'origin': origin, + 'type': origin['type'], 'visit': visit_id, 'status': status or visit['status'], 'date': visit['date'], @@ -1302,6 +1303,7 @@ self._origin_visits[origin_url].append(None) visit = visit.copy() + visit['origin'] = {'url': visit['origin']['url']} visit = self._origin_visits[origin_url][visit_id-1] = visit diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1455,19 +1455,18 @@ return_single = False origin_ids = [origin.get('id') for origin in origins] - origin_types_and_urls = [(origin.get('type'), origin.get('url')) - for origin in origins] + origin_urls = [origin.get('url') for origin in origins] if any(origin_ids): # Lookup per ID if all(origin_ids): - results = db.origin_get(origin_ids, cur) + results = db.origin_get_by_id(origin_ids, cur) else: raise ValueError( 'Either all origins or none at all should have an "id".') - elif any(url for (type_, url) in origin_types_and_urls): + elif any(origin_urls): # Lookup per type + URL - if all(url for (type_, url) in origin_types_and_urls): - results = db.origin_get_with(origin_types_and_urls, cur) + if all(origin_urls): + results = db.origin_get_by_url(origin_urls, cur) else: raise ValueError( 'Either all origins or none at all should have ' @@ -1599,8 +1598,8 @@ exists. """ - origin_id = list(db.origin_get_with( - [(origin['type'], origin['url'])], cur))[0][0] + origin_id = list(db.origin_get_by_url( + [origin['url']], cur))[0][0] if origin_id: return origin_id diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1358,8 +1358,7 @@ id = self.storage.origin_add_one(self.origin) - actual_origin = self.storage.origin_get({'url': self.origin['url'], - 'type': self.origin['type']}) + actual_origin = self.storage.origin_get({'url': self.origin['url']}) if self._test_origin_ids: self.assertEqual(actual_origin['id'], id) self.assertEqual(actual_origin['url'], self.origin['url']) @@ -1376,7 +1375,6 @@ actual_origin = self.storage.origin_get([{ 'url': self.origin['url'], - 'type': self.origin['type'], }])[0] if self._test_origin_ids: self.assertEqual(actual_origin['id'], origin1['id']) @@ -1384,7 +1382,6 @@ actual_origin2 = self.storage.origin_get([{ 'url': self.origin2['url'], - 'type': self.origin2['type'], }])[0] if self._test_origin_ids: self.assertEqual(actual_origin2['id'], origin2['id']) @@ -1404,62 +1401,13 @@ self.assertEqual(add1, add2) - def test_origin_get_without_type(self): - origin0 = self.storage.origin_get([self.origin])[0] - self.assertIsNone(origin0) - - origin3 = self.origin2.copy() - origin3['type'] += 'foo' - - origin1, origin2, origin3 = self.storage.origin_add( - [self.origin, self.origin2, origin3]) - - actual_origin = self.storage.origin_get([{ - 'url': self.origin['url'], - }])[0] - if self._test_origin_ids: - self.assertEqual(actual_origin['id'], origin1['id']) - self.assertEqual(actual_origin['url'], origin1['url']) - - actual_origin_2_or_3 = self.storage.origin_get([{ - 'url': self.origin2['url'], - }])[0] - if self._test_origin_ids: - self.assertIn( - actual_origin_2_or_3['id'], - [origin2['id'], origin3['id']]) - self.assertIn( - actual_origin_2_or_3['url'], - [origin2['url'], origin3['url']]) - - if 'id' in actual_origin: - del actual_origin['id'] - del actual_origin_2_or_3['id'] - del origin3['id'] - - objects = list(self.journal_writer.objects) - - if len(objects) == 3: - # current behavior of the pg storage, where 'type' is part of - # the primary key - self.assertEqual(objects, - [('origin', self.origin), - ('origin', self.origin2), - ('origin', origin3)]) - else: - # current behavior of the in-mem storage, and future behavior - # of the pg storage, where 'type' is not part of the PK - self.assertEqual(objects, - [('origin', self.origin), - ('origin', self.origin2)]) - def test_origin_get_legacy(self): self.assertIsNone(self.storage.origin_get(self.origin)) id = self.storage.origin_add_one(self.origin) - # lookup per type and url (returns id) + # lookup per url (returns id) actual_origin0 = self.storage.origin_get( - {'url': self.origin['url'], 'type': self.origin['type']}) + {'url': self.origin['url']}) if self._test_origin_ids: self.assertEqual(actual_origin0['id'], id) self.assertEqual(actual_origin0['url'], self.origin['url']) @@ -1476,9 +1424,9 @@ self.assertIsNone(self.storage.origin_get(self.origin)) origin_id = self.storage.origin_add_one(self.origin) - # lookup per type and url (returns id) + # lookup per url (returns id) actual_origin0 = self.storage.origin_get( - [{'url': self.origin['url'], 'type': self.origin['type']}]) + [{'url': self.origin['url']}]) self.assertEqual(len(actual_origin0), 1, actual_origin0) if self._test_origin_ids: self.assertEqual(actual_origin0[0]['id'], origin_id) @@ -1499,7 +1447,7 @@ with self.assertRaises(ValueError): self.storage.origin_get([ - {'url': self.origin['url'], 'type': self.origin['type']}, + {'url': self.origin['url']}, {'id': id}]) def test_origin_search_single_result(self): @@ -3795,12 +3743,11 @@ origin_visits = list(self.storage.origin_visit_get(1)) self.assertEqual(origin_visits, []) - @given(strategies.sets(origins().map(lambda x: tuple(x.to_dict().items())), - min_size=6, max_size=15)) + @given(strategies.lists(origins().map(lambda x: x.to_dict()), + unique_by=lambda x: x['url'], + min_size=6, max_size=15)) def test_origin_get_range(self, new_origins): self.reset_storage() - new_origins = list(map(dict, new_origins)) - nb_origins = len(new_origins) self.storage.origin_add(new_origins)