diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -25,9 +25,6 @@ return datetime.datetime.now(tz=datetime.timezone.utc) -OriginVisitKey = collections.namedtuple('OriginVisitKey', 'origin date') - - class Storage: def __init__(self): self._contents = {} @@ -38,8 +35,8 @@ self._revisions = {} self._releases = {} self._snapshots = {} - self._origins = {} - self._origin_visits = {} + self._origins = [] + self._origin_visits = [] self._origin_metadata = defaultdict(list) self._tools = {} self._metadata_providers = {} @@ -550,6 +547,9 @@ - **target** (:class:`bytes`): identifier of the target (currently a ``sha1_git`` for all object kinds, or the name of the target branch for aliases) + + Raises: + ValueError: if the origin or visit id does not exist. """ snapshot_id = snapshot['id'] if snapshot_id not in self._snapshots: @@ -561,7 +561,12 @@ '_sorted_branch_names': sorted(snapshot['branches']) } self._objects[snapshot_id].append(('snapshot', snapshot_id)) - self._origin_visits[visit]['snapshot'] = snapshot_id + if origin <= len(self._origin_visits) and \ + visit <= len(self._origin_visits[origin-1]): + self._origin_visits[origin-1][visit-1]['snapshot'] = snapshot_id + else: + raise ValueError('Origin with id %s does not exist or has no visit' + ' with id %s' % (origin, visit)) def snapshot_get(self, snapshot_id): """Get the content, possibly partial, of a snapshot with the given id @@ -612,9 +617,10 @@ branches. """ - if visit not in self._origin_visits: + if origin > len(self._origins) or \ + visit > len(self._origin_visits[origin-1]): return None - snapshot_id = self._origin_visits[visit]['snapshot'] + snapshot_id = self._origin_visits[origin-1][visit-1]['snapshot'] if snapshot_id: return self.snapshot_get(snapshot_id) else: @@ -648,22 +654,19 @@ or :const:`None` if the snapshot has less than 1000 branches. """ - if allowed_statuses is None: - visits_dates = list(itertools.chain( - *self._origins[origin]['visits_dates'].values())) - else: - last_visits = self._origins[origin]['visits_dates'] - visits_dates = list(itertools.chain( - *map(last_visits.__getitem__, allowed_statuses))) - - for visit_date in sorted(visits_dates, reverse=True): - visit_id = OriginVisitKey(origin=origin, date=visit_date) - snapshot_id = self._origin_visits[visit_id]['snapshot'] + visits = self._origin_visits[origin-1] + if allowed_statuses is not None: + visits = [visit for visit in visits + if visit['status'] in allowed_statuses] + snapshot = None + for visit in sorted(visits, key=lambda v: (v['date'], v['visit']), + reverse=True): + snapshot_id = visit['snapshot'] snapshot = self.snapshot_get(snapshot_id) if snapshot: - return snapshot + break - return None + return snapshot def snapshot_count_branches(self, snapshot_id, db=None, cur=None): """Count the number of branches in the snapshot with the given id @@ -793,18 +796,17 @@ """ if 'id' in origin: - key = origin['id'] + origin_id = origin['id'] elif 'type' in origin and 'url' in origin: - key = self._origin_key(origin) + origin_id = self._origin_id(origin) else: raise ValueError('Origin must have either id or (type and url).') - if key not in self._origins: - return None - else: - origin = copy.deepcopy(self._origins[key]) - del origin['visits_dates'] - origin['id'] = self._origin_key(origin) - return origin + origin = None + # self._origin_id can return None + if origin_id is not None: + origin = copy.deepcopy(self._origins[origin_id-1]) + origin['id'] = origin_id + return origin def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): @@ -825,19 +827,16 @@ An iterable of dict containing origin information as returned by :meth:`swh.storage.storage.Storage.origin_get`. """ - origins = iter(self._origins.values()) + origins = self._origins if regexp: pat = re.compile(url_pattern) - origins = (orig for orig in origins if pat.match(orig['url'])) + origins = [orig for orig in origins if pat.match(orig['url'])] else: - origins = (orig for orig in origins if url_pattern in orig['url']) + origins = [orig for orig in origins if url_pattern in orig['url']] if with_visit: - origins = (orig for orig in origins if orig['visits_dates']) - origins = sorted(origins, key=self._origin_key) + origins = [orig for orig in origins + if len(self._origin_visits[orig['id']-1]) > 0] origins = copy.deepcopy(origins[offset:offset+limit]) - for orig in origins: - del orig['visits_dates'] - orig['id'] = self._origin_key(orig) return origins def origin_add(self, origins): @@ -876,13 +875,16 @@ """ origin = copy.deepcopy(origin) assert 'id' not in origin - assert 'visits_dates' not in origin - key = self._origin_key(origin) - origin['visits_dates'] = defaultdict(set) - if key not in self._origins: - self._origins[key] = origin - self._objects[key].append(('origin', key)) - return key + origin_id = self._origin_id(origin) + if origin_id is None: + # origin ids are in the range [1, +inf[ + origin_id = len(self._origins) + 1 + origin['id'] = origin_id + self._origins.append(origin) + self._origin_visits.append([]) + key = (origin['type'], origin['url']) + self._objects[key].append(('origin', origin_id)) + return origin_id def fetch_history_start(self, origin_id): """Add an entry for origin origin_id in fetch_history. Returns the id @@ -929,25 +931,26 @@ if isinstance(date, str): date = dateutil.parser.parse(date) - status = 'ongoing' - - visit = { + visit_ret = None + if origin <= len(self._origin_visits): + # visit ids are in the range [1, +inf[ + visit_id = len(self._origin_visits[origin-1]) + 1 + status = 'ongoing' + visit = { 'origin': origin, 'date': date, 'status': status, 'snapshot': None, 'metadata': None, - } - key = OriginVisitKey(origin=origin, date=date) - visit['visit'] = key - if key not in self._origin_visits: - self._origin_visits[key] = copy.deepcopy(visit) - self._origins[origin]['visits_dates'][status].add(date) + 'visit': visit_id + } + self._origin_visits[origin-1].append(copy.deepcopy(visit)) + visit_ret = { + 'origin': origin, + 'visit': visit_id, + } - return { - 'origin': origin, - 'visit': key, - } + return visit_ret def origin_visit_update(self, origin, visit_id, status, metadata=None): """Update an origin_visit's status. @@ -962,12 +965,10 @@ None """ - old_status = self._origin_visits[visit_id]['status'] - self._origins[origin]['visits_dates'][old_status] \ - .remove(visit_id.date) - self._origins[origin]['visits_dates'][status] \ - .add(visit_id.date) - self._origin_visits[visit_id].update({ + if origin > len(self._origin_visits) or \ + visit_id > len(self._origin_visits[origin-1]): + return + self._origin_visits[origin-1][visit_id-1].update({ 'status': status, 'metadata': metadata}) @@ -985,16 +986,14 @@ List of visits. """ - visits_dates = sorted(itertools.chain.from_iterable( - self._origins[origin]['visits_dates'].values())) + visits = self._origin_visits[origin-1] if last_visit is not None: - from_index = bisect.bisect_right(visits_dates, last_visit.date) - visits_dates = visits_dates[from_index:] + visits = visits[last_visit:] if limit is not None: - visits_dates = visits_dates[:limit] - keys = (OriginVisitKey(origin=origin, date=date) - for date in visits_dates) - yield from map(self._origin_visits.__getitem__, keys) + visits = visits[:limit] + for visit in visits: + visit_id = visit['visit'] + yield self._origin_visits[origin-1][visit_id-1] def origin_visit_get_by(self, origin, visit): """Retrieve origin visit's information. @@ -1007,7 +1006,11 @@ it does not exist """ - return self._origin_visits.get(visit) + origin_visit = None + if origin <= len(self._origin_visits) and \ + visit <= len(self._origin_visits[origin-1]): + origin_visit = self._origin_visits[origin-1][visit-1] + return origin_visit def stat_counters(self): """compute statistics about the number of tuples in various tables @@ -1190,16 +1193,21 @@ 'url': provider['provider_url']}) return self._metadata_providers.get(key) + def _origin_id(self, origin): + origin_id = None + for stored_origin in self._origins: + if stored_origin['type'] == origin['type'] and \ + stored_origin['url'] == origin['url']: + origin_id = stored_origin['id'] + break + return origin_id + @staticmethod def _content_key(content): """A stable key for a content""" return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS)) @staticmethod - def _origin_key(origin): - return (origin['type'], origin['url']) - - @staticmethod def _tool_key(tool): return (tool['name'], tool['version'], tuple(sorted(tool['configuration'].items()))) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1411,6 +1411,11 @@ self.date_visit2) visit2_id = origin_visit2['visit'] + # Add a visit with the same date as the previous one + origin_visit3 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit3_id = origin_visit3['visit'] + # Two visits, both with no snapshot: latest snapshot is None self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) @@ -1446,6 +1451,12 @@ allowed_statuses=['full']), ) + # Add snapshot to visit3 (same date as visit2) and check that + # the new snapshot is returned + self.storage.snapshot_add(origin_id, visit3_id, self.complete_snapshot) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id)) + def test_stat_counters(self): expected_keys = ['content', 'directory', 'origin', 'person', 'revision']