diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -1146,7 +1146,7 @@
         the origin's type.
 
         Args:
-            origin (int): visited origin's identifier
+            origin (Union[int,str]): visited origin's identifier or URL
             date: timestamp of such visit
             type (str): the type of loader used for the visit
                 (hg, git, ...)
@@ -1167,7 +1167,10 @@
                           DeprecationWarning)
             date = ts
 
-        origin_id = origin  # TODO: rename the argument
+        if isinstance(origin, str):
+            origin_id = self.origin_get({'url': origin})['id']
+        else:
+            origin_id = origin
 
         if isinstance(date, str):
             date = dateutil.parser.parse(date)
@@ -1205,7 +1208,7 @@
         """Update an origin_visit's status.
 
         Args:
-            origin (int): visited origin's identifier
+            origin (Union[int,str]): visited origin's identifier or URL
             visit_id (int): visit's identifier
             status: visit's new status
             metadata: data associated to the visit
@@ -1216,7 +1219,10 @@
             None
 
         """
-        origin_id = origin  # TODO: rename the argument
+        if isinstance(origin, str):
+            origin_id = self.origin_get({'url': origin})['id']
+        else:
+            origin_id = origin
 
         try:
             visit = self._origin_visits[origin_id-1][visit_id-1]
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1157,7 +1157,7 @@
         the origin's type.
 
         Args:
-            origin: Visited Origin id
+            origin (Union[int,str]): visited origin's identifier or URL
             date: timestamp of such visit
             type (str): the type of loader used for the visit
                 (hg, git, ...)
@@ -1178,13 +1178,17 @@
                           DeprecationWarning)
             date = ts
 
-        origin_id = origin  # TODO: rename the argument
+        if isinstance(origin, str):
+            origin = self.origin_get({'url': origin}, db=db, cur=cur)
+            origin_id = origin['id']
+        else:
+            origin = self.origin_get({'id': origin}, db=db, cur=cur)
+            origin_id = origin['id']
 
         if isinstance(date, str):
             date = dateutil.parser.parse(date)
 
         if type is None:
-            origin = self.origin_get({'id': origin})
             type = origin['type']
 
         visit_id = db.origin_visit_add(origin_id, date, type, cur)
@@ -1192,7 +1196,6 @@
         if self.journal_writer:
             # We can write to the journal only after inserting to the
             # DB, because we want the id of the visit
-            origin = self.origin_get([{'id': origin_id}], db=db, cur=cur)[0]
             del origin['id']
             self.journal_writer.write_addition('origin_visit', {
                 'origin': origin, 'date': date, 'type': type,
@@ -1211,7 +1214,7 @@
         """Update an origin_visit's status.
 
         Args:
-            origin: Visited Origin id
+            origin (Union[int,str]): visited origin's identifier or URL
             visit_id: Visit's id
             status: Visit's new status
             metadata: Data associated to the visit
@@ -1222,7 +1225,10 @@
             None
 
         """
-        origin_id = origin  # TODO: rename the argument
+        if isinstance(origin, str):
+            origin_id = self.origin_get({'url': origin}, db=db, cur=cur)['id']
+        else:
+            origin_id = origin
 
         visit = db.origin_visit_get(origin_id, visit_id, cur=cur)
 
@@ -1274,7 +1280,8 @@
         if self.journal_writer:
             for visit in visits:
                 visit = visit.copy()
-                origin = self.origin_get([{'id': visit['origin']}])[0]
+                origin = self.origin_get(
+                    [{'id': visit['origin']}], db=db, cur=cur)[0]
                 visit['origin'] = origin
                 if visit.get('type') is None:
                     visit['type'] = origin['type']
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1490,6 +1490,46 @@
                          [('origin', expected_origin),
                           ('origin_visit', data)])
 
+    def test_origin_visit_add_from_url(self):
+        # given
+        self.assertIsNone(self.storage.origin_get([self.origin2])[0])
+
+        origin_id = self.storage.origin_add_one(self.origin2)
+        origin_url = self.origin2['url']
+        self.assertIsNotNone(origin_id)
+
+        # when
+        origin_visit1 = self.storage.origin_visit_add(
+            origin_url,
+            type='git',
+            date=self.date_visit2)
+
+        actual_origin_visits = list(self.storage.origin_visit_get(origin_id))
+        self.assertEqual(actual_origin_visits,
+                         [{
+                             'origin': origin_id,
+                             'date': self.date_visit2,
+                             'visit': origin_visit1['visit'],
+                             'type': 'git',
+                             'status': 'ongoing',
+                             'metadata': None,
+                             'snapshot': None,
+                         }])
+
+        expected_origin = self.origin2.copy()
+        data = {
+            'origin': expected_origin,
+            'date': self.date_visit2,
+            'visit': origin_visit1['visit'],
+            'type': 'git',
+            'status': 'ongoing',
+            'metadata': None,
+            'snapshot': None,
+        }
+        self.assertEqual(list(self.journal_writer.objects),
+                         [('origin', expected_origin),
+                          ('origin_visit', data)])
+
     def test_origin_visit_add_default_type(self):
         # given
         self.assertIsNone(self.storage.origin_get([self.origin2])[0])