diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -46,8 +46,9 @@
         self._revisions = {}
         self._releases = {}
         self._snapshots = {}
-        self._origins = []
-        self._origin_visits = []
+        self._origins = {}
+        self._origins_by_id = []
+        self._origin_visits = {}
         self._persons = []
         self._origin_metadata = defaultdict(list)
         self._tools = {}
@@ -777,10 +778,14 @@
                   branches.
 
         """
-        if origin > len(self._origins) or \
-                visit > len(self._origin_visits[origin-1]):
+        origin_url = self._get_origin_url(origin)
+        if not origin_url:
+            return
+
+        if origin_url not in self._origins or \
+                visit > len(self._origin_visits[origin_url]):
             return None
-        snapshot_id = self._origin_visits[origin-1][visit-1]['snapshot']
+        snapshot_id = self._origin_visits[origin_url][visit-1]['snapshot']
         if snapshot_id:
             return self.snapshot_get(snapshot_id)
         else:
@@ -814,14 +819,14 @@
                   or :const:`None` if the snapshot has less than 1000
                   branches.
 
         """
-        if isinstance(origin, int):
-            origin = self.origin_get({'id': origin})
-            if not origin:
-                return
-            origin = origin['url']
+        origin_url = self._get_origin_url(origin)
+        if not origin_url:
+            return
         visit = self.origin_visit_get_latest(
-            origin, allowed_statuses=allowed_statuses, require_snapshot=True)
+            origin_url,
+            allowed_statuses=allowed_statuses,
+            require_snapshot=True)
         if visit and visit['snapshot']:
             snapshot = self.snapshot_get(visit['snapshot'])
             if not snapshot:
@@ -980,19 +985,17 @@
 
         results = []
         for origin in origins:
+            result = None
             if 'id' in origin:
-                origin_id = origin['id']
+                origin_url = self._get_origin_url(origin['id'])
             elif 'url' in origin:
-                origin_id = self._origin_id(origin)
+                origin_url = origin['url']
             else:
                 raise ValueError(
-                    'Origin must have either id or (type and url).')
-            origin = None
-            # self._origin_id can return None
-            if origin_id is not None and origin_id <= len(self._origins):
-                origin = copy.deepcopy(self._origins[origin_id-1])
-                origin['id'] = origin_id
-            results.append(origin)
+                    'Origin must have either id or url.')
+            if origin_url in self._origins:
+                result = copy.deepcopy(self._origins[origin_url])
+            results.append(result)
 
         if return_single:
             assert len(results) == 1
@@ -1015,12 +1018,12 @@
             by :meth:`swh.storage.in_memory.Storage.origin_get`.
         """
         origin_from = max(origin_from, 1)
-        if origin_from <= len(self._origins):
+        if origin_from <= len(self._origins_by_id):
             max_idx = origin_from + origin_count - 1
-            if max_idx > len(self._origins):
-                max_idx = len(self._origins)
+            if max_idx > len(self._origins_by_id):
+                max_idx = len(self._origins_by_id)
             for idx in range(origin_from-1, max_idx):
-                yield copy.deepcopy(self._origins[idx])
+                yield copy.deepcopy(self._origins[self._origins_by_id[idx]])
 
     def origin_search(self, url_pattern, offset=0, limit=50,
                       regexp=False, with_visit=False, db=None, cur=None):
@@ -1041,7 +1044,7 @@
             An iterable of dict containing origin information as
             returned by :meth:`swh.storage.storage.Storage.origin_get`.
         """
-        origins = self._origins
+        origins = self._origins.values()
         if regexp:
             pat = re.compile(url_pattern)
             origins = [orig for orig in origins if pat.search(orig['url'])]
@@ -1051,6 +1054,8 @@
             origins = [orig for orig in origins
-                       if len(self._origin_visits[orig['id']-1]) > 0]
+                       if len(self._origin_visits[orig['url']]) > 0]
 
+        origins.sort(key=lambda origin: origin['id'])
+
         origins = copy.deepcopy(origins[offset:offset+limit])
         return origins
 
@@ -1110,19 +1115,19 @@
         """
         origin = copy.deepcopy(origin)
         assert 'id' not in origin
-        origin_id = self._origin_id(origin)
-        if origin_id is None:
+        if origin['url'] in self._origins:
+            origin_id = self._origins[origin['url']]['id']
+        else:
             if self.journal_writer:
                 self.journal_writer.write_addition('origin', origin)
             # origin ids are in the range [1, +inf[
             origin_id = len(self._origins) + 1
             origin['id'] = origin_id
-            self._origins.append(origin)
-            self._origin_visits.append([])
-            key = (origin['type'], origin['url'])
-            self._objects[key].append(('origin', origin_id))
-        else:
-            origin['id'] = origin_id
+            self._origins_by_id.append(origin['url'])
+            assert len(self._origins_by_id) == origin_id
+            self._origins[origin['url']] = origin
+            self._origin_visits[origin['url']] = []
+            self._objects[origin['url']].append(('origin', origin['url']))
 
         return origin_id
 
@@ -1172,39 +1177,39 @@
                 DeprecationWarning)
             date = ts
 
-        if isinstance(origin, str):
-            origin_id = self.origin_get({'url': origin})['id']
-        else:
-            origin_id = origin
+        origin_url = self._get_origin_url(origin)
+        if origin_url is None:
+            raise ValueError('Unknown origin.')
 
         if isinstance(date, str):
             date = dateutil.parser.parse(date)
 
         visit_ret = None
-        if origin_id <= len(self._origin_visits):
+        if origin_url in self._origins:
+            origin = self._origins[origin_url]
             # visit ids are in the range [1, +inf[
-            visit_id = len(self._origin_visits[origin_id-1]) + 1
+            visit_id = len(self._origin_visits[origin_url]) + 1
             status = 'ongoing'
             visit = {
-                'origin': origin_id,
+                'origin': origin,
                 'date': date,
-                'type': type or self._origins[origin_id-1]['type'],
+                'type': type or origin['type'],
                 'status': status,
                 'snapshot': None,
                 'metadata': None,
                 'visit': visit_id
             }
-            self._origin_visits[origin_id-1].append(visit)
+            self._origin_visits[origin_url].append(visit)
             visit_ret = {
-                'origin': origin_id,
+                'origin': origin['id'],
                 'visit': visit_id,
             }
 
-            self._objects[(origin_id, visit_id)].append(
+            self._objects[(origin_url, visit_id)].append(
                 ('origin_visit', None))
 
             if self.journal_writer:
-                origin = self.origin_get([{'id': origin_id}])[0]
+                origin = self._origins[origin_url].copy()
                 del origin['id']
                 self.journal_writer.write_addition('origin_visit', {
                     **visit, 'origin': origin})
@@ -1227,17 +1232,17 @@
             None
 
         """
-        if isinstance(origin, str):
-            origin_id = self.origin_get({'url': origin})['id']
-        else:
-            origin_id = origin
+        origin_url = self._get_origin_url(origin)
+        if origin_url not in self._origins:
+            raise ValueError('Unknown origin.')
 
         try:
-            visit = self._origin_visits[origin_id-1][visit_id-1]
+            visit = self._origin_visits[origin_url][visit_id-1]
         except IndexError:
-            raise ValueError('Invalid origin_id or visit_id') from None
+            raise ValueError('Unknown visit_id for this origin') \
+                from None
         if self.journal_writer:
-            origin = self.origin_get([{'id': origin_id}])[0]
+            origin = self._origins[origin_url].copy()
             del origin['id']
             self.journal_writer.write_update('origin_visit', {
                 'origin': origin, 'type': origin['type'],
@@ -1246,8 +1251,8 @@
                 'date': visit['date'],
                 'metadata': metadata or visit['metadata'],
                 'snapshot': snapshot or visit['snapshot']})
-        if origin_id > len(self._origin_visits) or \
-                visit_id > len(self._origin_visits[origin_id-1]):
+        if origin_url not in self._origin_visits or \
+                visit_id > len(self._origin_visits[origin_url]):
             return
         if status:
             visit['status'] = status
@@ -1259,7 +1264,7 @@
     def origin_visit_upsert(self, visits):
         """Add a origin_visits with a specific id and with all its data.
         If there is already an origin_visit with the same
-        `(origin_id, visit_id)`, updates it instead of inserting a new one.
+        `(origin_url, visit_id)`, updates it instead of inserting a new one.
 
         Args:
             visits: iterable of dicts with keys:
@@ -1277,30 +1282,28 @@
         for visit in visits:
             if isinstance(visit['date'], str):
                 visit['date'] = dateutil.parser.parse(visit['date'])
-            origin = visit['origin']
-            visit['origin'] = self.origin_get([origin])[0]
-            if not visit['origin']:
-                raise ValueError('Unknown origin: %s' % origin)
+            if visit['origin']['url'] not in self._origins:
+                raise ValueError('Unknown origin: %s' % visit['origin'])
 
         if self.journal_writer:
             for visit in visits:
-                visit = copy.deepcopy(visit)
+                visit = visit.copy()
+                visit['origin'] = self._origins[visit['origin']['url']].copy()
                 del visit['origin']['id']
                 self.journal_writer.write_addition('origin_visit', visit)
 
         for visit in visits:
-            origin_id = visit['origin']['id']
             visit_id = visit['visit']
+            origin_url = visit['origin']['url']
 
-            self._objects[(origin_id, visit_id)].append(
+            self._objects[(origin_url, visit_id)].append(
                 ('origin_visit', None))
 
-            while len(self._origin_visits[origin_id-1]) < visit_id:
-                self._origin_visits[origin_id-1].append(None)
+            while len(self._origin_visits[origin_url]) < visit_id:
+                self._origin_visits[origin_url].append(None)
 
             visit = visit.copy()
-            visit['origin'] = origin_id
-            visit = self._origin_visits[origin_id-1][visit_id-1] = visit
+            visit = self._origin_visits[origin_url][visit_id-1] = visit
 
     def origin_visit_get(self, origin, last_visit=None, limit=None):
         """Retrieve all the origin's visit's information.
@@ -1317,13 +1320,9 @@
             List of visits.
 
         """
-        if isinstance(origin, str):
-            origin = self.origin_get([{'url': origin}])[0]
-            if not origin:
-                return
-            origin = origin['id']
-        if origin <= len(self._origin_visits):
-            visits = self._origin_visits[origin-1]
+        origin_url = self._get_origin_url(origin)
+        if origin_url in self._origin_visits:
+            visits = self._origin_visits[origin_url]
             if last_visit is not None:
                 visits = visits[last_visit:]
             if limit is not None:
@@ -1332,7 +1331,13 @@
                 if not visit:
                     continue
                 visit_id = visit['visit']
-                yield copy.deepcopy(self._origin_visits[origin-1][visit_id-1])
+                visit = copy.deepcopy(
+                    self._origin_visits[origin_url][visit_id-1])
+
+                # TODO: remove this to return the origin url:
+                visit['origin'] = self._origins[visit['origin']['url']]['id']
+
+                yield visit
 
     def origin_visit_find_by_date(self, origin, visit_date):
         """Retrieves the origin visit whose date is closest to the provided
@@ -1347,12 +1352,12 @@
             A visit.
 
         """
-        origin = self.origin_get([{'url': origin}])[0]
-        if not origin:
-            return
-        origin = origin['id']
-        if origin <= len(self._origin_visits):
-            visits = self._origin_visits[origin-1]
+        origin_url = self._get_origin_url(origin)
+        if origin_url in self._origin_visits:
+            visits = copy.deepcopy(self._origin_visits[origin_url])
+            # TODO: remove this to return the origin url:
+            for visit in visits:
+                visit['origin'] = self._origins[visit['origin']['url']]['id']
             return min(
                 visits,
                 key=lambda v: (abs(v['date'] - visit_date), -v['visit']))
@@ -1368,16 +1373,18 @@
             it does not exist
 
         """
-        if isinstance(origin, str):
-            origin = self.origin_get({'url': origin})
-            if not origin:
-                return
-            origin = origin['id']
-        origin_visit = None
-        if origin <= len(self._origin_visits) and \
-                visit <= len(self._origin_visits[origin-1]):
-            origin_visit = self._origin_visits[origin-1][visit-1]
-        return copy.deepcopy(origin_visit)
+        origin_url = self._get_origin_url(origin)
+        visits = self._origin_visits.get(origin_url, [])
+        if not 1 <= visit <= len(visits) or visits[visit-1] is None:
+            # unknown origin or visit (upsert pads gaps with None)
+            return None
+        origin_visit = copy.deepcopy(visits[visit-1])
+
+        # TODO: remove this to return the origin url:
+        origin_visit['origin'] = \
+            self._origins[origin_visit['origin']['url']]['id']
+
+        return origin_visit
 
     def origin_visit_get_latest(
             self, origin, allowed_statuses=None, require_snapshot=False):
@@ -1405,11 +1412,10 @@
             snapshot (Optional[sha1_git]): identifier of the snapshot
                 associated to the visit
         """
-        origin = self.origin_get({'url': origin})
+        origin = self._origins.get(origin)
         if not origin:
             return
-        origin = origin['id']
-        visits = self._origin_visits[origin-1]
+        visits = self._origin_visits[origin['url']]
         if allowed_statuses is not None:
             visits = [visit for visit in visits
                       if visit['status'] in allowed_statuses]
@@ -1417,6 +1423,12 @@
             visits = [visit for visit in visits
                       if visit['snapshot']]
 
+        visits = copy.deepcopy(visits)
+
+        # TODO: remove this to return the origin url:
+        for visit in visits:
+            visit['origin'] = self._origins[visit['origin']['url']]['id']
+
         return max(visits, key=lambda v: (v['date'], v['visit']), default=None)
 
     def person_get(self, person):
@@ -1614,15 +1626,23 @@
         key = self._metadata_provider_key(provider)
         return self._metadata_providers.get(key)
 
-    def _origin_id(self, origin):
-        origin_id = None
-        for stored_origin in self._origins:
-            if stored_origin['url'] == origin['url'] \
-                    and ('type' not in origin
-                         or stored_origin['type'] == origin['type']):
-                origin_id = stored_origin['id']
-                break
-        return origin_id
+    def _get_origin_url(self, origin):
+        """Return the URL of an origin given either its URL or its id.
+
+        Strings are returned unchanged (their existence is not checked);
+        integers are looked up as origin ids, and None is returned for
+        ids that do not belong to a known origin.
+        """
+        if isinstance(origin, str):
+            return origin
+        elif isinstance(origin, int):
+            # origin ids are in the range [1, +inf[
+            if 1 <= origin <= len(self._origins_by_id):
+                return self._origins_by_id[origin-1]
+            else:
+                return None
+        else:
+            raise TypeError('origin must be a string or an integer.')
 
     def _person_add(self, person):
         """Add a person in storage.
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -1421,10 +1421,21 @@
         del actual_origin_2_or_3['id']
         del origin3['id']
 
-        self.assertEqual(list(self.journal_writer.objects),
-                         [('origin', self.origin),
-                          ('origin', self.origin2),
-                          ('origin', origin3)])
+        journal_objects = list(self.journal_writer.objects)
+
+        if len(journal_objects) == 3:
+            # pg storage, as it is today: 'type' is part of the primary
+            # key, so origin3 is journaled as a separate origin
+            self.assertEqual(journal_objects,
+                             [('origin', self.origin),
+                              ('origin', self.origin2),
+                              ('origin', origin3)])
+        else:
+            # in-memory storage (and pg storage once 'type' leaves the
+            # primary key): origin3 is not written to the journal again
+            self.assertEqual(journal_objects,
+                             [('origin', self.origin),
+                              ('origin', self.origin2)])
 
     def test_origin_get_legacy(self):
         self.assertIsNone(self.storage.origin_get(self.origin))