diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -25,9 +25,6 @@ return datetime.datetime.now(tz=datetime.timezone.utc) -OriginVisitKey = collections.namedtuple('OriginVisitKey', 'origin date') - - class Storage: def __init__(self): self._contents = {} @@ -38,8 +35,8 @@ self._revisions = {} self._releases = {} self._snapshots = {} - self._origins = {} - self._origin_visits = {} + self._origins = [] + self._origin_visits = [] self._origin_metadata = defaultdict(list) self._tools = {} self._metadata_providers = {} @@ -552,7 +549,7 @@ of the target branch for aliases) Raises: - ValueError: if the origin or visit id does not exist. + ValueError: if the origin's or visit's identifier does not exist. """ snapshot_id = snapshot['id'] if snapshot_id not in self._snapshots: @@ -564,9 +561,12 @@ '_sorted_branch_names': sorted(snapshot['branches']) } self._objects[snapshot_id].append(('snapshot', snapshot_id)) - if visit not in self._origin_visits: - raise ValueError('Origin %s has no visit %s' % (origin, visit)) - self._origin_visits[visit]['snapshot'] = snapshot_id + if origin <= len(self._origin_visits) and \ + visit <= len(self._origin_visits[origin-1]): + self._origin_visits[origin-1][visit-1]['snapshot'] = snapshot_id + else: + raise ValueError('Origin with id %s does not exist or has no visit' + ' with id %s' % (origin, visit)) def snapshot_get(self, snapshot_id): """Get the content, possibly partial, of a snapshot with the given id @@ -604,8 +604,8 @@ should be used instead. Args: - origin (int): the origin identifier - visit (int): the visit identifier + origin (int): the origin's identifier + visit (int): the visit's identifier Returns: dict: None if the snapshot does not exist; a dict with three keys otherwise: @@ -617,9 +617,10 @@ branches. """ - if visit not in self._origin_visits: + if origin > len(self._origins) or \ + visit > len(self._origin_visits[origin-1]): return None - snapshot_id = self._origin_visits[visit]['snapshot'] + snapshot_id = self._origin_visits[origin-1][visit-1]['snapshot'] if snapshot_id: return self.snapshot_get(snapshot_id) else: @@ -639,7 +640,7 @@ should be used instead. Args: - origin (int): the origin identifier + origin (int): the origin's identifier allowed_statuses (list of str): list of visit statuses considered to find the latest snapshot for the visit. For instance, ``allowed_statuses=['full']`` will only consider visits that @@ -653,22 +654,19 @@ or :const:`None` if the snapshot has less than 1000 branches. """ - if allowed_statuses is None: - visits_dates = list(itertools.chain( - *self._origins[origin]['visits_dates'].values())) - else: - last_visits = self._origins[origin]['visits_dates'] - visits_dates = list(itertools.chain( - *map(last_visits.__getitem__, allowed_statuses))) - - for visit_date in sorted(visits_dates, reverse=True): - visit_id = OriginVisitKey(origin=origin, date=visit_date) - snapshot_id = self._origin_visits[visit_id]['snapshot'] + visits = self._origin_visits[origin-1] + if allowed_statuses is not None: + visits = [visit for visit in visits + if visit['status'] in allowed_statuses] + snapshot = None + for visit in sorted(visits, key=lambda v: (v['date'], v['visit']), + reverse=True): + snapshot_id = visit['snapshot'] snapshot = self.snapshot_get(snapshot_id) if snapshot: - return snapshot + break - return None + return snapshot def snapshot_count_branches(self, snapshot_id, db=None, cur=None): """Count the number of branches in the snapshot with the given id @@ -784,7 +782,7 @@ or the id: - - id: the origin id + - id (int): the origin's identifier Returns: dict: the origin dictionary with the keys: @@ -798,18 +796,17 @@ """ if 'id' in origin: - key = origin['id'] + origin_id = origin['id'] elif 'type' in origin and 'url' in origin: - key = self._origin_key(origin) + origin_id = self._origin_id(origin) else: raise ValueError('Origin must have either id or (type and url).') - if key not in self._origins: - return None - else: - origin = copy.deepcopy(self._origins[key]) - del origin['visits_dates'] - origin['id'] = self._origin_key(origin) - return origin + origin = None + # self._origin_id can return None + if origin_id is not None: + origin = copy.deepcopy(self._origins[origin_id-1]) + origin['id'] = origin_id + return origin def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): @@ -830,19 +827,16 @@ An iterable of dict containing origin information as returned by :meth:`swh.storage.storage.Storage.origin_get`. """ - origins = iter(self._origins.values()) + origins = self._origins if regexp: pat = re.compile(url_pattern) - origins = (orig for orig in origins if pat.match(orig['url'])) + origins = [orig for orig in origins if pat.match(orig['url'])] else: - origins = (orig for orig in origins if url_pattern in orig['url']) + origins = [orig for orig in origins if url_pattern in orig['url']] if with_visit: - origins = (orig for orig in origins if orig['visits_dates']) - origins = sorted(origins, key=self._origin_key) + origins = [orig for orig in origins + if len(self._origin_visits[orig['id']-1]) > 0] origins = copy.deepcopy(origins[offset:offset+limit]) - for orig in origins: - del orig['visits_dates'] - orig['id'] = self._origin_key(orig) return origins def origin_add(self, origins): @@ -881,13 +875,16 @@ """ origin = copy.deepcopy(origin) assert 'id' not in origin - assert 'visits_dates' not in origin - key = self._origin_key(origin) - origin['visits_dates'] = defaultdict(set) - if key not in self._origins: - self._origins[key] = origin - self._objects[key].append(('origin', key)) - return key + origin_id = self._origin_id(origin) + if origin_id is None: + # origin ids are in the range [1, +inf[ + origin_id = len(self._origins) + 1 + origin['id'] = origin_id + self._origins.append(origin) + self._origin_visits.append([]) + key = (origin['type'], origin['url']) + self._objects[key].append(('origin', origin_id)) + return origin_id def fetch_history_start(self, origin_id): """Add an entry for origin origin_id in fetch_history. Returns the id @@ -911,14 +908,14 @@ """Add an origin_visit for the origin at date with status 'ongoing'. Args: - origin: Visited Origin id + origin (int): visited origin's identifier date: timestamp of such visit Returns: dict: dictionary with keys origin and visit where: - - origin: origin identifier - - visit: the visit identifier for the new visit occurrence + - origin: origin's identifier + - visit: the visit's identifier for the new visit occurrence """ if ts is None: @@ -934,45 +931,44 @@ if isinstance(date, str): date = dateutil.parser.parse(date) - status = 'ongoing' - - visit = { + visit_ret = None + if origin <= len(self._origin_visits): + # visit ids are in the range [1, +inf[ + visit_id = len(self._origin_visits[origin-1]) + 1 + status = 'ongoing' + visit = { 'origin': origin, 'date': date, 'status': status, 'snapshot': None, 'metadata': None, - } - key = OriginVisitKey(origin=origin, date=date) - visit['visit'] = key - if key not in self._origin_visits: - self._origin_visits[key] = copy.deepcopy(visit) - self._origins[origin]['visits_dates'][status].add(date) + 'visit': visit_id + } + self._origin_visits[origin-1].append(copy.deepcopy(visit)) + visit_ret = { + 'origin': origin, + 'visit': visit_id, + } - return { - 'origin': origin, - 'visit': key, - } + return visit_ret def origin_visit_update(self, origin, visit_id, status, metadata=None): """Update an origin_visit's status. Args: - origin: Visited Origin id - visit_id: Visit's id - status: Visit's new status - metadata: Data associated to the visit + origin (int): visited origin's identifier + visit_id (int): visit's identifier + status: visit's new status + metadata: data associated to the visit Returns: None """ - old_status = self._origin_visits[visit_id]['status'] - self._origins[origin]['visits_dates'][old_status] \ - .remove(visit_id.date) - self._origins[origin]['visits_dates'][status] \ - .add(visit_id.date) - self._origin_visits[visit_id].update({ + if origin > len(self._origin_visits) or \ + visit_id > len(self._origin_visits[origin-1]): + return + self._origin_visits[origin-1][visit_id-1].update({ 'status': status, 'metadata': metadata}) @@ -980,39 +976,41 @@ """Retrieve all the origin's visit's information. Args: - origin (int): The occurrence's origin (identifier). - last_visit: Starting point from which listing the next visits - Default to None - limit (int): Number of results to return from the last visit. - Default to None + origin (int): the origin's identifier + last_visit (int): visit's id from which listing the next ones, + default to None + limit (int): maximum number of results to return, + default to None Yields: List of visits. """ - visits_dates = sorted(itertools.chain.from_iterable( - self._origins[origin]['visits_dates'].values())) + visits = self._origin_visits[origin-1] if last_visit is not None: - from_index = bisect.bisect_right(visits_dates, last_visit.date) - visits_dates = visits_dates[from_index:] + visits = visits[last_visit:] if limit is not None: - visits_dates = visits_dates[:limit] - keys = (OriginVisitKey(origin=origin, date=date) - for date in visits_dates) - yield from map(self._origin_visits.__getitem__, keys) + visits = visits[:limit] + for visit in visits: + visit_id = visit['visit'] + yield self._origin_visits[origin-1][visit_id-1] def origin_visit_get_by(self, origin, visit): """Retrieve origin visit's information. Args: - origin: The occurrence's origin (identifier). + origin (int): the origin's identifier Returns: The information on that particular (origin, visit) or None if it does not exist """ - return self._origin_visits.get(visit) + origin_visit = None + if origin <= len(self._origin_visits) and \ + visit <= len(self._origin_visits[origin-1]): + origin_visit = self._origin_visits[origin-1][visit-1] + return origin_visit def stat_counters(self): """compute statistics about the number of tuples in various tables @@ -1050,7 +1048,7 @@ metadata. Args: - origin_id: the origin's id for which the metadata is added + origin_id (int): the origin's id for which the metadata is added ts (datetime): timestamp of the found metadata provider: id of the provider of metadata (ex:'hal') tool: id of the tool used to extract metadata @@ -1074,13 +1072,13 @@ """Retrieve list of all origin_metadata entries for the origin_id Args: - origin_id (int): the unique origin identifier + origin_id (int): the unique origin's identifier provider_type (str): (optional) type of provider Returns: list of dicts: the origin_metadata dictionary with the keys: - - origin_id (int): origin's id + - origin_id (int): origin's identifier - discovery_date (datetime): timestamp of discovery - tool_id (int): metadata's extracting tool - metadata (jsonb) @@ -1195,16 +1193,21 @@ 'url': provider['provider_url']}) return self._metadata_providers.get(key) + def _origin_id(self, origin): + origin_id = None + for stored_origin in self._origins: + if stored_origin['type'] == origin['type'] and \ + stored_origin['url'] == origin['url']: + origin_id = stored_origin['id'] + break + return origin_id + @staticmethod def _content_key(content): """A stable key for a content""" return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS)) @staticmethod - def _origin_key(origin): - return (origin['type'], origin['url']) - - @staticmethod def _tool_key(tool): return (tool['name'], tool['version'], tuple(sorted(tool['configuration'].items()))) diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1418,6 +1418,11 @@ self.date_visit2) visit2_id = origin_visit2['visit'] + # Add a visit with the same date as the previous one + origin_visit3 = self.storage.origin_visit_add(origin_id, + self.date_visit2) + visit3_id = origin_visit3['visit'] + # Two visits, both with no snapshot: latest snapshot is None self.assertIsNone(self.storage.snapshot_get_latest(origin_id)) @@ -1453,6 +1458,12 @@ allowed_statuses=['full']), ) + # Add snapshot to visit3 (same date as visit2) and check that + # the new snapshot is returned + self.storage.snapshot_add(origin_id, visit3_id, self.complete_snapshot) + self.assertEqual(self.complete_snapshot, + self.storage.snapshot_get_latest(origin_id)) + def test_stat_counters(self): expected_keys = ['content', 'directory', 'origin', 'person', 'revision']