diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -29,6 +29,9 @@ return datetime.datetime.now(tz=datetime.timezone.utc) +ENABLE_ORIGIN_IDS = True + + class Storage: def __init__(self, journal_writer=None): self._contents = {} @@ -987,6 +990,7 @@ for origin in origins: result = None if 'id' in origin: + assert ENABLE_ORIGIN_IDS, 'origin ids are disabled' if origin['id'] <= len(self._origins_by_id): result = self._origins[self._origins_by_id[origin['id']-1]] elif 'url' in origin: @@ -1052,9 +1056,10 @@ origins = [orig for orig in origins if url_pattern in orig['url']] if with_visit: origins = [orig for orig in origins - if len(self._origin_visits[orig['id']-1]) > 0] + if len(self._origin_visits[orig['url']]) > 0] - origins.sort(key=lambda origin: origin['id']) + if ENABLE_ORIGIN_IDS: + origins.sort(key=lambda origin: origin['id']) origins = copy.deepcopy(origins[offset:offset+limit]) return origins @@ -1095,7 +1100,10 @@ """ origins = copy.deepcopy(origins) for origin in origins: - origin['id'] = self.origin_add_one(origin) + if ENABLE_ORIGIN_IDS: + origin['id'] = self.origin_add_one(origin) + else: + self.origin_add_one(origin) return origins def origin_add_one(self, origin): @@ -1116,25 +1124,31 @@ origin = copy.deepcopy(origin) assert 'id' not in origin if origin['url'] in self._origins: - origin_id = self._origins[origin['url']]['id'] + if ENABLE_ORIGIN_IDS: + origin_id = self._origins[origin['url']]['id'] else: if self.journal_writer: self.journal_writer.write_addition('origin', origin) - # origin ids are in the range [1, +inf[ - origin_id = len(self._origins) + 1 - origin['id'] = origin_id - self._origins_by_id.append(origin['url']) - assert len(self._origins_by_id) == origin_id + if ENABLE_ORIGIN_IDS: + # origin ids are in the range [1, +inf[ + origin_id = len(self._origins) + 1 + origin['id'] = origin_id + self._origins_by_id.append(origin['url']) + assert len(self._origins_by_id)
== origin_id self._origins[origin['url']] = origin self._origin_visits[origin['url']] = [] self._objects[origin['url']].append(('origin', origin['url'])) - return origin_id + if ENABLE_ORIGIN_IDS: + return origin_id + else: + return origin['url'] def fetch_history_start(self, origin_id): """Add an entry for origin origin_id in fetch_history. Returns the id of the added fetch_history entry """ + assert ENABLE_ORIGIN_IDS, 'origin ids are disabled' pass def fetch_history_end(self, fetch_history_id, data): @@ -1201,7 +1215,7 @@ } self._origin_visits[origin_url].append(visit) visit_ret = { - 'origin': origin['id'], + 'origin': origin['id'] if ENABLE_ORIGIN_IDS else origin['url'], 'visit': visit_id, } @@ -1210,7 +1224,8 @@ if self.journal_writer: origin = self._origins[origin_url].copy() - del origin['id'] + if 'id' in origin: + del origin['id'] self.journal_writer.write_addition('origin_visit', { **visit, 'origin': origin}) @@ -1243,7 +1258,8 @@ from None if self.journal_writer: origin = self._origins[origin_url].copy() - del origin['id'] + if 'id' in origin: + del origin['id'] self.journal_writer.write_update('origin_visit', { 'origin': origin, 'type': origin['type'], 'visit': visit_id, @@ -1304,6 +1320,20 @@ visit = self._origin_visits[origin_url][visit_id-1] = visit + def _convert_visit(self, visit): + if visit is None: + return None + + visit = copy.deepcopy(visit)  # don't alias stored visit + if ENABLE_ORIGIN_IDS: + visit['origin'] = \ + self._origins[visit['origin']['url']]['id'] + else: + visit['origin'] = \ + self._origins[visit['origin']['url']]['url'] + + return visit + def origin_visit_get(self, origin, last_visit=None, limit=None): """Retrieve all the origin's visit's information.
@@ -1329,13 +1359,9 @@ if not visit: continue visit_id = visit['visit'] - visit = copy.deepcopy( - self._origin_visits[origin_url][visit_id-1]) - - # TODO: remove this to return the origin url: - visit['origin'] = self._origins[visit['origin']['url']]['id'] - yield visit + yield self._convert_visit( + self._origin_visits[origin_url][visit_id-1]) def origin_visit_find_by_date(self, origin, visit_date): """Retrieves the origin visit whose date is closest to the provided @@ -1353,9 +1379,10 @@ origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: visits = self._origin_visits[origin_url] - return min( + visit = min( visits, key=lambda v: (abs(v['date'] - visit_date), -v['visit'])) + return self._convert_visit(visit) def origin_visit_get_by(self, origin, visit): """Retrieve origin visit's information. @@ -1371,13 +1398,8 @@ origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits and \ visit <= len(self._origin_visits[origin_url]): - visit = self._origin_visits[origin_url][visit-1] - visit = copy.deepcopy(visit) - - # TODO: remove this to return the origin url: - visit['origin'] = self._origins[visit['origin']['url']]['id'] - - return visit + return self._convert_visit( + self._origin_visits[origin_url][visit-1]) def origin_visit_get_latest( self, origin, allowed_statuses=None, require_snapshot=False): @@ -1416,13 +1438,9 @@ visits = [visit for visit in visits if visit['snapshot']] - visits = copy.deepcopy(visits) - - # TODO: remove this to return the origin url: - for visit in visits: - visit['origin'] = self._origins[visit['origin']['url']]['id'] - - return max(visits, key=lambda v: (v['date'], v['visit']), default=None) + visit = max( + visits, key=lambda v: (v['date'], v['visit']), default=None) + return self._convert_visit(visit) def person_get(self, person): """Return the persons identified by their ids.