diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -17,7 +17,8 @@ import attr -from swh.model.model import Content, Directory, Revision, Release, Snapshot +from swh.model.model import \ + Content, Directory, Revision, Release, Snapshot, OriginVisit, Origin from swh.model.hashutil import DEFAULT_ALGORITHMS from swh.objstorage import get_objstorage from swh.objstorage.exc import ObjNotFoundError @@ -831,7 +832,7 @@ if origin_url not in self._origins or \ visit > len(self._origin_visits[origin_url]): return None - snapshot_id = self._origin_visits[origin_url][visit-1]['snapshot'] + snapshot_id = self._origin_visits[origin_url][visit-1].snapshot if snapshot_id: return self.snapshot_get(snapshot_id) else: @@ -986,6 +987,15 @@ } for obj in objs] return ret + def _convert_origin(self, t): + if t is None: + return None + (origin_id, origin) = t + origin = origin.to_dict() + if ENABLE_ORIGIN_IDS: + origin['id'] = origin_id + return origin + def origin_get(self, origins): """Return origins, either all identified by their ids or all identified by tuples (type, url). 
@@ -1043,11 +1053,11 @@ result = self._origins[self._origins_by_id[origin['id']-1]] elif 'url' in origin: if origin['url'] in self._origins: - result = copy.deepcopy(self._origins[origin['url']]) + result = self._origins[origin['url']] else: raise ValueError( 'Origin must have either id or url.') - results.append(result) + results.append(self._convert_origin(result)) if return_single: assert len(results) == 1 @@ -1075,7 +1085,8 @@ if max_idx > len(self._origins_by_id): max_idx = len(self._origins_by_id) for idx in range(origin_from-1, max_idx): - yield copy.deepcopy(self._origins[self._origins_by_id[idx]]) + yield self._convert_origin( + self._origins[self._origins_by_id[idx]]) def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): @@ -1096,7 +1107,7 @@ An iterable of dict containing origin information as returned by :meth:`swh.storage.storage.Storage.origin_get`. """ - origins = self._origins.values() + origins = map(self._convert_origin, self._origins.values()) if regexp: pat = re.compile(url_pattern) origins = [orig for orig in origins if pat.search(orig['url'])] @@ -1109,8 +1120,7 @@ if ENABLE_ORIGIN_IDS: origins.sort(key=lambda origin: origin['id']) - origins = copy.deepcopy(origins[offset:offset+limit]) - return origins + return origins[offset:offset+limit] def origin_count(self, url_pattern, regexp=False, with_visit=False, db=None, cur=None): @@ -1169,28 +1179,28 @@ exists. 
""" - origin = copy.deepcopy(origin) - assert 'id' not in origin - if origin['url'] in self._origins: + origin = Origin.from_dict(origin) + if origin.url in self._origins: if ENABLE_ORIGIN_IDS: - origin_id = self._origins[origin['url']]['id'] + (origin_id, _) = self._origins[origin.url] else: if self.journal_writer: self.journal_writer.write_addition('origin', origin) if ENABLE_ORIGIN_IDS: # origin ids are in the range [1, +inf[ origin_id = len(self._origins) + 1 - origin['id'] = origin_id - self._origins_by_id.append(origin['url']) + self._origins_by_id.append(origin.url) assert len(self._origins_by_id) == origin_id - self._origins[origin['url']] = origin - self._origin_visits[origin['url']] = [] - self._objects[origin['url']].append(('origin', origin['url'])) + else: + origin_id = None + self._origins[origin.url] = (origin_id, origin) + self._origin_visits[origin.url] = [] + self._objects[origin.url].append(('origin', origin.url)) if ENABLE_ORIGIN_IDS: return origin_id else: - return origin['url'] + return origin.url def fetch_history_start(self, origin_id): """Add an entry for origin origin_id in fetch_history. 
Returns the id @@ -1245,25 +1255,27 @@ if isinstance(date, str): date = dateutil.parser.parse(date) + elif not isinstance(date, datetime.datetime): + raise TypeError('date must be a datetime or a string.') visit_ret = None if origin_url in self._origins: - origin = self._origins[origin_url] + (origin_id, origin) = self._origins[origin_url] # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 status = 'ongoing' - visit = { - 'origin': {'url': origin['url']}, - 'date': date, - 'type': type or origin['type'], - 'status': status, - 'snapshot': None, - 'metadata': None, - 'visit': visit_id - } + visit = OriginVisit( + origin=origin, + date=date, + type=type or origin.type, + status=status, + snapshot=None, + metadata=None, + visit=visit_id, + ) self._origin_visits[origin_url].append(visit) visit_ret = { - 'origin': origin['id'] if ENABLE_ORIGIN_IDS else origin['url'], + 'origin': origin_id if ENABLE_ORIGIN_IDS else origin.url, 'visit': visit_id, } @@ -1271,11 +1283,7 @@ ('origin_visit', None)) if self.journal_writer: - origin = self._origins[origin_url].copy() - if 'id' in origin: - del origin['id'] - self.journal_writer.write_addition('origin_visit', { - **visit, 'origin': origin}) + self.journal_writer.write_addition('origin_visit', visit) return visit_ret @@ -1304,27 +1312,25 @@ except IndexError: raise ValueError('Unknown visit_id for this origin') \ from None + + if origin_url not in self._origin_visits or \ + visit_id > len(self._origin_visits[origin_url]): + return + + updates = {} + if status: + updates['status'] = status + if metadata: + updates['metadata'] = metadata + if snapshot: + updates['snapshot'] = snapshot + + visit = attr.evolve(visit, **updates) + if self.journal_writer: - origin = self._origins[origin_url].copy() - if 'id' in origin: - del origin['id'] - self.journal_writer.write_update('origin_visit', { - 'origin': origin, - 'type': origin['type'], - 'visit': visit_id, - 'status': status or visit['status'], - 'date': visit['date'], - 'metadata': metadata or 
visit['metadata'], - 'snapshot': snapshot or 
visit['snapshot']}) + self.journal_writer.write_update('origin_visit', visit) + + self._origin_visits[origin_url][visit_id-1] = visit - if origin_url not in self._origin_visits or \ - visit_id > len(self._origin_visits[origin_url]): - return - if status: - visit['status'] = status - if metadata: - visit['metadata'] = metadata - if snapshot: - visit['snapshot'] = snapshot def origin_visit_upsert(self, visits): """Add a origin_visits with a specific id and with all its data. @@ -1343,22 +1349,17 @@ snapshot (sha1_git): identifier of the snapshot to add to the visit """ - visits = copy.deepcopy(visits) - for visit in visits: - if isinstance(visit['date'], str): - visit['date'] = dateutil.parser.parse(visit['date']) + visits = [OriginVisit.from_dict(d) for d in visits] if self.journal_writer: for visit in visits: - visit = visit.copy() - visit['origin'] = self._origins[visit['origin']['url']].copy() - if 'id' in visit['origin']: - del visit['origin']['id'] + (_, origin) = self._origins[visit.origin.url] + visit = attr.evolve(visit, origin=origin) self.journal_writer.write_addition('origin_visit', visit) for visit in visits: - visit_id = visit['visit'] - origin_url = visit['origin']['url'] + visit_id = visit.visit + origin_url = visit.origin.url self._objects[(origin_url, visit_id)].append( ('origin_visit', None)) @@ -1366,21 +1367,18 @@ while len(self._origin_visits[origin_url]) < visit_id: self._origin_visits[origin_url].append(None) - visit = visit.copy() - visit['origin'] = {'url': visit['origin']['url']} - - visit = self._origin_visits[origin_url][visit_id-1] = visit + self._origin_visits[origin_url][visit_id-1] = visit def _convert_visit(self, visit): if visit is None: return - visit = visit.copy() - origin = self._origins[visit['origin']['url']] + (origin_id, origin) = self._origins[visit.origin.url] + visit = visit.to_dict() if ENABLE_ORIGIN_IDS: - visit['origin'] = origin['id'] + visit['origin'] = origin_id else: - visit['origin'] = origin['url'] + 
visit['origin'] = origin.url return visit @@ -1408,7 +1406,7 @@ for visit in visits: if not visit: continue - visit_id = visit['visit'] + visit_id = visit.visit yield self._convert_visit( self._origin_visits[origin_url][visit_id-1]) @@ -1431,7 +1429,7 @@ visits = self._origin_visits[origin_url] visit = min( visits, - key=lambda v: (abs(v['date'] - visit_date), -v['visit'])) + key=lambda v: (abs(v.date - visit_date), -v.visit)) return self._convert_visit(visit) def origin_visit_get_by(self, origin, visit): @@ -1477,19 +1475,21 @@ snapshot (Optional[sha1_git]): identifier of the snapshot associated to the visit """ - origin = self._origins.get(origin) - if not origin: + res = self._origins.get(origin) + if not res: return - visits = self._origin_visits[origin['url']] + (_, origin) = res + visits = [visit for visit in self._origin_visits[origin.url] + if visit is not None] if allowed_statuses is not None: visits = [visit for visit in visits - if visit['status'] in allowed_statuses] + if visit.status in allowed_statuses] if require_snapshot: visits = [visit for visit in visits - if visit['snapshot']] + if visit.snapshot] visit = max( - visits, key=lambda v: (v['date'], v['visit']), default=None) + visits, key=lambda v: (v.date, v.visit), default=None) return self._convert_visit(visit) def stat_counters(self): diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1459,6 +1459,13 @@ self.assertEqual(add1, add2) + def test_origin_add_validation(self): + with self.assertRaisesRegex((TypeError, KeyError), 'url'): + self.storage.origin_add([{'type': 'git'}]) + + with self.assertRaisesRegex((TypeError, KeyError), 'type'): + self.storage.origin_add([{'url': 'file:///dev/null'}]) + def test_origin_get_legacy(self): self.assertIsNone(self.storage.origin_get(self.origin)) id = self.storage.origin_add_one(self.origin) @@ -1762,6 +1769,12 @@ ('origin_visit', data1), ('origin_visit', data2)]) + def 
test_origin_visit_add_validation(self): + origin_id_or_url = self.storage.origin_add_one(self.origin2) + + with self.assertRaises((TypeError, psycopg2.errors.UndefinedFunction)): + self.storage.origin_visit_add(origin_id_or_url, date=[b'foo']) + @given(strategies.booleans()) def test_origin_visit_update(self, use_url): if not self._test_origin_ids and not use_url: @@ -1918,6 +1931,18 @@ ('origin_visit', data4), ('origin_visit', data5)]) + def test_origin_visit_update_validation(self): + origin_id = self.storage.origin_add_one(self.origin) + visit = self.storage.origin_visit_add( + origin_id, + date=self.date_visit2) + + with self.assertRaisesRegex( + (ValueError, psycopg2.errors.InvalidTextRepresentation), + 'status'): + self.storage.origin_visit_update( + origin_id, visit['visit'], status='foobar') + def test_origin_visit_find_by_date(self): # given self.storage.origin_add_one(self.origin) @@ -3852,7 +3877,8 @@ origin_id = self.storage.origin_add_one(obj.pop('origin')) if 'visit' in obj: del obj['visit'] - self.storage.origin_visit_add(origin_id, **obj) + self.storage.origin_visit_add( + origin_id, date=obj['date'], type=obj['type']) else: method = getattr(self.storage, obj_type + '_add') try: