diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -1533,6 +1533,12 @@ tool: id of the tool used to extract metadata metadata (jsonb): the metadata retrieved at the time and location """ + if isinstance(origin_id, str): + origin = self.origin_get({'url': origin_id}) + if not origin: + return + origin_id = origin['id'] + if isinstance(ts, str): ts = dateutil.parser.parse(ts) @@ -1567,6 +1573,12 @@ - provider_url (str) """ + if isinstance(origin_id, str): + origin = self.origin_get({'url': origin_id}) + if not origin: + return + origin_id = origin['id'] + metadata = [] for item in self._origin_metadata[origin_id]: item = copy.deepcopy(item) diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -1722,6 +1722,12 @@ Returns: id (int): the origin_metadata unique id """ + if isinstance(origin_id, str): + origin = self.origin_get({'url': origin_id}, db=db, cur=cur) + if not origin: + return + origin_id = origin['id'] + if isinstance(ts, str): ts = dateutil.parser.parse(ts) @@ -1750,6 +1756,12 @@ - provider_url (str) """ + if isinstance(origin_id, str): + origin = self.origin_get({'url': origin_id}, db=db, cur=cur) + if not origin: + return + origin_id = origin['id'] + for line in db.origin_metadata_get_by(origin_id, provider_type, cur): yield dict(zip(db.origin_metadata_get_cols, line)) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -3419,12 +3419,19 @@ # then self.assertTrue(provider_id, actual_provider['id']) - def test_origin_metadata_add(self): + @given(strategies.booleans()) + def test_origin_metadata_add(self, use_url): + self.reset_storage() # given - origin_id = self.storage.origin_add([self.origin])[0]['id'] + origin = self.storage.origin_add([self.origin])[0] + origin_id = origin['id'] + if use_url: + origin = origin['url'] + else: + origin = origin['id'] origin_metadata0 = list(self.storage.origin_metadata_get_by( - origin_id)) - self.assertTrue(len(origin_metadata0) == 0) + origin)) + self.assertEqual(len(origin_metadata0), 0, origin_metadata0) tools = self.storage.tool_add([self.metadata_tool]) tool = tools[0] @@ -3441,19 +3448,19 @@ # when adding for the same origin 2 metadatas self.storage.origin_metadata_add( - origin_id, + origin, self.origin_metadata['discovery_date'], provider['id'], tool['id'], self.origin_metadata['metadata']) self.storage.origin_metadata_add( - origin_id, + origin, '2015-01-01 23:00:00+00', provider['id'], tool['id'], self.origin_metadata2['metadata']) actual_om = list(self.storage.origin_metadata_get_by( - origin_id)) + origin)) # then self.assertCountEqual( [item['origin_id'] for item in actual_om],