diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -14,13 +14,13 @@ from collections import defaultdict from datetime import timedelta -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import attr from swh.model.model import ( BaseContent, Content, SkippedContent, Directory, Revision, - Release, Snapshot, OriginVisit, Origin, SHA1_SIZE + Release, Snapshot, OriginVisit, OriginVisitUpdate, Origin, SHA1_SIZE ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.storage.objstorage import ObjStorage @@ -29,6 +29,7 @@ from .converters import origin_url_to_sha1 from .utils import get_partition_bounds_bytes +from .validate import convert_validation_exceptions from .writer import JournalWriter # Max block size of contents to return @@ -58,6 +59,8 @@ self._origins_by_id = [] self._origins_by_sha1 = {} self._origin_visits = {} + self._origin_visit_updates: Dict[ + Tuple[str, int], List[OriginVisitUpdate]] = {} self._persons = [] self._origin_metadata = defaultdict(list) self._tools = {} @@ -500,7 +503,9 @@ if origin_url not in self._origins or \ visit > len(self._origin_visits[origin_url]): return None - snapshot_id = self._origin_visits[origin_url][visit-1].snapshot + + visit = self._origin_visit_get_updated(origin_url, visit) + snapshot_id = visit.snapshot if snapshot_id: return self.snapshot_get(snapshot_id) else: @@ -665,15 +670,18 @@ else: origins = [orig for orig in origins if url_pattern in orig['url']] if with_visit: - origins = [ - orig for orig in origins - if len(self._origin_visits[orig['url']]) > 0 and - set(ov.snapshot - for ov in self._origin_visits[orig['url']] - if ov.snapshot) & - set(self._snapshots)] + filtered_origins = [] + for orig in origins: + visits = (self._origin_visit_get_updated(ov.origin, ov.visit) + for ov in self._origin_visits[orig['url']]) + for ov in visits: + if ov.snapshot and ov.snapshot in self._snapshots: + filtered_origins.append(orig) + break + else: + filtered_origins = origins - return origins[offset:offset+limit] + return filtered_origins[offset:offset+limit] def origin_count(self, url_pattern, regexp=False, with_visit=False): return len(self.origin_search(url_pattern, regexp=regexp, @@ -722,19 +730,33 @@ # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 status = 'ongoing' - visit = OriginVisit( - origin=origin_url, - date=date, - type=type, - status=status, - snapshot=None, - metadata=None, - visit=visit_id, - ) + with convert_validation_exceptions(): + visit = OriginVisit( + origin=origin_url, + date=date, + type=type, + # TODO: Remove when we remove those fields from the model + status=status, + snapshot=None, + metadata=None, + visit=visit_id, + ) self._origin_visits[origin_url].append(visit) - visit = visit - - self._objects[(origin_url, visit.visit)].append( + assert visit.visit is not None + visit_key = (origin_url, visit.visit) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date, + status=status, + snapshot=None, + metadata=None, + ) + self._origin_visit_updates[visit_key] = [visit_update] + + self._objects[visit_key].append( ('origin_visit', None)) self.journal_writer.origin_visit_add(visit) @@ -744,7 +766,8 @@ def origin_visit_update( self, origin: str, visit_id: int, status: Optional[str] = None, - metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None): + metadata: Optional[Dict] = None, snapshot: Optional[bytes] = None, + date: Optional[datetime.datetime] = None): origin_url = self._get_origin_url(origin) if origin_url is None: raise StorageArgumentException('Unknown origin.') @@ -755,43 +778,67 @@ raise StorageArgumentException( 'Unknown visit_id for this origin') from None - updates: Dict[str, Any] = {} - if status: - updates['status'] = status - if metadata: - updates['metadata'] = metadata - if snapshot: - updates['snapshot'] = snapshot + # Retrieve the previous visit update + assert visit.visit is not None + visit_key = (origin_url, visit.visit) - try: - visit = attr.evolve(visit, **updates) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) + last_visit_update = max( + self._origin_visit_updates[visit_key], key=lambda v: v.date) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit_id, + date=date or now(), + status=status or last_visit_update.status, + snapshot=snapshot or last_visit_update.snapshot, + metadata=metadata or last_visit_update.metadata, + ) + self._origin_visit_updates[visit_key].append(visit_update) - self.journal_writer.origin_visit_update(visit) + self.journal_writer.origin_visit_update( + self._origin_visit_get_updated(origin_url, visit_id)) self._origin_visits[origin_url][visit_id-1] = visit def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: self.journal_writer.origin_visit_upsert(visits) + date = now() + for visit in visits: - visit_id = visit.visit origin_url = visit.origin + origin = self.origin_get({'url': origin_url}) - try: - visit = attr.evolve(visit, origin=origin_url) - except (KeyError, TypeError, ValueError) as e: - raise StorageArgumentException(*e.args) - - self._objects[(origin_url, visit_id)].append( - ('origin_visit', None)) - - if visit_id: - while len(self._origin_visits[origin_url]) <= visit_id: + if not origin: # Cannot add a visit without an origin + raise StorageArgumentException( + 'Unknown origin %s', origin_url) + + if origin_url in self._origins: + origin = self._origins[origin_url] + # visit ids are in the range [1, +inf[ + assert visit.visit is not None + visit_key = (origin_url, visit.visit) + + with convert_validation_exceptions(): + visit_update = OriginVisitUpdate( + origin=origin_url, + visit=visit.visit, + date=date, + status=visit.status, + snapshot=visit.snapshot, + metadata=visit.metadata, + ) + + self._origin_visit_updates.setdefault(visit_key, []) + while len(self._origin_visits[origin_url]) <= visit.visit: self._origin_visits[origin_url].append(None) - self._origin_visits[origin_url][visit_id-1] = visit + self._origin_visits[origin_url][visit.visit-1] = visit + self._origin_visit_updates[visit_key].append(visit_update) + + self._objects[visit_key].append( + ('origin_visit', None)) def _convert_visit(self, visit): if visit is None: @@ -801,6 +848,25 @@ return visit + def _origin_visit_get_updated(self, origin: str, visit_id: int): + assert visit_id >= 1 + visit = self._origin_visits[origin][visit_id-1] + if visit is None: + return None + visit_key = (origin, visit_id) + + visit_update = max( + self._origin_visit_updates[visit_key], key=lambda v: v.date) + + return OriginVisit.from_dict({ + # default to the values in visit + **visit.to_dict(), + # override with the last update + **visit_update.to_dict(), + # but keep the date of the creation of the origin visit + 'date': visit.date + }) + def origin_visit_get(self, origin, last_visit=None, limit=None): origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits: @@ -814,8 +880,8 @@ continue visit_id = visit.visit - yield self._convert_visit( - self._origin_visits[origin_url][visit_id-1]) + yield self._origin_visit_get_updated( + origin_url, visit_id).to_dict() def origin_visit_find_by_date(self, origin, visit_date): origin_url = self._get_origin_url(origin) @@ -824,14 +890,15 @@ visit = min( visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) - return self._convert_visit(visit) + return self._origin_visit_get_updated( + origin, visit.visit).to_dict() def origin_visit_get_by(self, origin, visit): origin_url = self._get_origin_url(origin) if origin_url in self._origin_visits and \ visit <= len(self._origin_visits[origin_url]): - return self._convert_visit( - self._origin_visits[origin_url][visit-1]) + return self._origin_visit_get_updated( + origin_url, visit).to_dict() def origin_visit_get_latest( self, origin, allowed_statuses=None, require_snapshot=False): @@ -839,6 +906,9 @@ if not origin: return visits = self._origin_visits[origin.url] + visits = [self._origin_visit_get_updated(visit.origin, visit.visit) + for visit in visits + if visit is not None] if allowed_statuses is not None: visits = [visit for visit in visits if visit.status in allowed_statuses] @@ -848,7 +918,9 @@ visit = max( visits, key=lambda v: (v.date, v.visit), default=None) - return self._convert_visit(visit) + if visit is None: + return None + return visit.to_dict() def _select_random_origin_visit_by_type(self, type: str) -> str: while True: @@ -864,8 +936,11 @@ back_in_the_day = now() - timedelta(weeks=12) # 3 months back # This should be enough for tests for visit in random_origin_visits: - if visit.date > back_in_the_day and visit.status == 'full': - return visit.to_dict() + updated_visit = self._origin_visit_get_updated( + url, visit.visit) + if updated_visit.date > back_in_the_day \ + and updated_visit.status == 'full': + return updated_visit.to_dict() else: return None diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -1674,8 +1674,8 @@ # given origin_url = swh_storage.origin_add_one(data.origin) origin_url2 = swh_storage.origin_add_one(data.origin2) - date_visit = datetime.datetime.now(datetime.timezone.utc) - date_visit2 = date_visit + datetime.timedelta(minutes=1) + date_visit = data.date_visit1 + date_visit2 = data.date_visit2 # Round to milliseconds before insertion, so equality doesn't fail # after a round-trip through a DB (eg. Cassandra)