diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -3,57 +3,45 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import re import bisect -import dateutil import collections -import copy import datetime +import functools import itertools +import json +import logging import random - -from collections import defaultdict -from datetime import timedelta from typing import ( Any, Callable, Dict, - Generic, - Hashable, Iterable, Iterator, + Generic, List, Optional, Tuple, TypeVar, - Union, ) -import attr from swh.model.model import ( - BaseContent, + Sha1Git, + TimestampWithTimezone, + Timestamp, + Person, Content, SkippedContent, - Directory, - Revision, - Release, - Snapshot, OriginVisit, OriginVisitStatus, Origin, - SHA1_SIZE, ) -from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex -from swh.storage.objstorage import ObjStorage -from swh.storage.validate import convert_validation_exceptions -from swh.storage.utils import now -from .exc import StorageArgumentException, HashCollision - -from .converters import origin_url_to_sha1 -from .utils import get_partition_bounds_bytes -from .writer import JournalWriter +from swh.storage.cassandra.schema import HASH_ALGORITHMS +from swh.storage.cassandra.storage import CassandraStorage +from swh.storage.exc import StorageArgumentException +from swh.storage.objstorage import ObjStorage +from swh.storage.writer import JournalWriter # Max block size of contents to return BULK_BLOCK_CONTENT_LEN_MAX = 10000 @@ -92,6 +80,9 @@ for (k, item) in self.data: yield item + def __len__(self): + return len(self.data) + def iter_from(self, start_key: SortedListKey) -> Iterator[SortedListItem]: """Returns an iterator over all the elements whose key is greater or equal to `start_key`. @@ -103,1098 +94,129 @@ yield item -class InMemoryStorage: - def __init__(self, journal_writer=None): +Row = TypeVar("Row") - self.reset() - self.journal_writer = JournalWriter(journal_writer) - - def reset(self): - self._contents = {} - self._content_indexes = defaultdict(lambda: defaultdict(set)) - self._skipped_contents = {} - self._skipped_content_indexes = defaultdict(lambda: defaultdict(set)) - self._directories = {} - self._revisions = {} - self._releases = {} - self._snapshots = {} - self._origins = {} - self._origins_by_id = [] - self._origins_by_sha1 = {} - self._origin_visits = {} - self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {} - self._persons = [] - # {origin_url: {authority: [metadata]}} - self._origin_metadata: Dict[ - str, Dict[Hashable, SortedList[datetime.datetime, Dict[str, Any]]] - ] = defaultdict( - lambda: defaultdict(lambda: SortedList(key=lambda x: x["discovery_date"])) - ) # noqa - - self._metadata_fetchers: Dict[Hashable, Dict[str, Any]] = {} - self._metadata_authorities: Dict[Hashable, Dict[str, Any]] = {} - self._objects = defaultdict(list) - self._sorted_sha1s = SortedList[bytes, bytes]() - - self.objstorage = ObjStorage({"cls": "memory", "args": {}}) +class Table(Generic[Row]): + """A dict-like class, that provides range requests and order based on + the hash of keys.""" - def check_config(self, *, check_write): - return True + primary_key: Tuple[str, ...] + clustering_key: Tuple[str, ...] 
- def _content_add(self, contents: Iterable[Content], with_data: bool) -> Dict: - self.journal_writer.content_add(contents) + def __init__(self): + if None in (self.primary_key, self.clustering_key): + raise TypeError(f"{self.__class__.__name__} is missing key definition.") - content_add = 0 - if with_data: - summary = self.objstorage.content_add( - c for c in contents if c.status != "absent" - ) - content_add_bytes = summary["content:add:bytes"] + self._list = SortedList[int, Row](key=self._get_row_primary_key) - for content in contents: - key = self._content_key(content) - if key in self._contents: - continue - for algorithm in DEFAULT_ALGORITHMS: - hash_ = content.get_hash(algorithm) - if hash_ in self._content_indexes[algorithm] and ( - algorithm not in {"blake2s256", "sha256"} - ): - colliding_content_hashes = [] - # Add the already stored contents - for content_hashes_set in self._content_indexes[algorithm][hash_]: - hashes = dict(content_hashes_set) - colliding_content_hashes.append(hashes) - # Add the new colliding content - colliding_content_hashes.append(content.hashes()) - raise HashCollision(algorithm, hash_, colliding_content_hashes) - for algorithm in DEFAULT_ALGORITHMS: - hash_ = content.get_hash(algorithm) - self._content_indexes[algorithm][hash_].add(key) - self._objects[content.sha1_git].append(("content", content.sha1)) - self._contents[key] = content - self._sorted_sha1s.add(content.sha1) - self._contents[key] = attr.evolve(self._contents[key], data=None) - content_add += 1 + def get_row_hash(self, row: Row) -> int: + return hash(tuple(getattr(row, key) for key in self.primary_key)) - summary = { - "content:add": content_add, - } - if with_data: - summary["content:add:bytes"] = content_add_bytes - - return summary - - def content_add(self, content: Iterable[Content]) -> Dict: - content = [attr.evolve(c, ctime=now()) for c in content] - return self._content_add(content, with_data=True) + def _get_row_primary_key(self, row: Row) -> Tuple: + return (self.get_row_hash(row),) + tuple( + getattr(row, key) for key in self.clustering_key + ) - def content_update(self, content, keys=[]): - self.journal_writer.content_update(content) + def add_one(self, row: Row): + self._list.add(row) - for cont_update in content: - cont_update = cont_update.copy() - sha1 = cont_update.pop("sha1") - for old_key in self._content_indexes["sha1"][sha1]: - old_cont = self._contents.pop(old_key) + def __iter__(self): + return iter(self._list) - for algorithm in DEFAULT_ALGORITHMS: - hash_ = old_cont.get_hash(algorithm) - self._content_indexes[algorithm][hash_].remove(old_key) + def __len__(self): + return len(self._list) - new_cont = attr.evolve(old_cont, **cont_update) - new_key = self._content_key(new_cont) - self._contents[new_key] = new_cont +class ContentTable(Table[Content]): + primary_key = ("sha1", "sha1_git", "sha256", "blake2s256") + clustering_key = () - for algorithm in DEFAULT_ALGORITHMS: - hash_ = new_cont.get_hash(algorithm) - self._content_indexes[algorithm][hash_].add(new_key) - def content_add_metadata(self, content: Iterable[Content]) -> Dict: - return self._content_add(content, with_data=False) +class InMemRunner: + def __init__(self): + self._content = ContentTable() - def content_get(self, content): - # FIXME: Make this method support slicing the `data`. - if len(content) > BULK_BLOCK_CONTENT_LEN_MAX: - raise StorageArgumentException( - "Sending at most %s contents." 
% BULK_BLOCK_CONTENT_LEN_MAX - ) - yield from self.objstorage.content_get(content) + def _content_add_finalize(self, content: Content): + self._content.add_one(content) - def content_get_range(self, start, end, limit=1000): - if limit is None: - raise StorageArgumentException("limit should not be None") - sha1s = ( - (sha1, content_key) - for sha1 in self._sorted_sha1s.iter_from(start) - for content_key in self._content_indexes["sha1"][sha1] - ) - matched = [] - next_content = None - for sha1, key in sha1s: - if sha1 > end: - break - if len(matched) >= limit: - next_content = sha1 - break - matched.append(self._contents[key].to_dict()) - return { - "contents": matched, - "next": next_content, - } + def content_add_prepare(self, content: Content) -> Tuple[int, Callable[[], None]]: + token = self._content.get_row_hash(content) + finalizer = functools.partial(self._content_add_finalize, content) + return (token, finalizer) - def content_get_partition( - self, - partition_id: int, - nb_partitions: int, - limit: int = 1000, - page_token: str = None, - ): - if limit is None: - raise StorageArgumentException("limit should not be None") - (start, end) = get_partition_bounds_bytes( - partition_id, nb_partitions, SHA1_SIZE - ) - if page_token: - start = hash_to_bytes(page_token) - if end is None: - end = b"\xff" * SHA1_SIZE - result = self.content_get_range(start, end, limit) - result2 = { - "contents": result["contents"], - "next_page_token": None, - } - if result["next"]: - result2["next_page_token"] = hash_to_hex(result["next"]) - return result2 - - def content_get_metadata(self, contents: List[bytes]) -> Dict[bytes, List[Dict]]: - result: Dict = {sha1: [] for sha1 in contents} - for sha1 in contents: - if sha1 in self._content_indexes["sha1"]: - objs = self._content_indexes["sha1"][sha1] - # only 1 element as content_add_metadata would have raised a - # hash collision otherwise - for key in objs: - d = self._contents[key].to_dict() - del d["ctime"] - if "data" in d: - del d["data"] - result[sha1].append(d) - return result - - def content_find(self, content): - if not set(content).intersection(DEFAULT_ALGORITHMS): - raise StorageArgumentException( - "content keys must contain at least one of: %s" - % ", ".join(sorted(DEFAULT_ALGORITHMS)) - ) - found = [] - for algo in DEFAULT_ALGORITHMS: - hash = content.get(algo) - if hash and hash in self._content_indexes[algo]: - found.append(self._content_indexes[algo][hash]) - - if not found: - return [] - - keys = list(set.intersection(*found)) - return [self._contents[key].to_dict() for key in keys] - - def content_missing(self, content, key_hash="sha1"): - for cont in content: - for (algo, hash_) in cont.items(): - if algo not in DEFAULT_ALGORITHMS: - continue - if hash_ not in self._content_indexes.get(algo, []): - yield cont[key_hash] + def content_get_from_pk(self, content_hashes: Dict[str, bytes]) -> None: + for content in self._content: + for algo in HASH_ALGORITHMS: + if content_hashes[algo] != getattr(content, algo): break else: - for result in self.content_find(cont): - if result["status"] == "missing": - yield cont[key_hash] - - def content_missing_per_sha1(self, contents): - for content in contents: - if content not in self._content_indexes["sha1"]: - yield content - - def content_missing_per_sha1_git(self, contents): - for content in contents: - if content not in self._content_indexes["sha1_git"]: - yield content - - def content_get_random(self): - return random.choice(list(self._content_indexes["sha1_git"])) - - def _skipped_content_add(self, 
contents: List[SkippedContent]) -> Dict: - self.journal_writer.skipped_content_add(contents) - - summary = {"skipped_content:add": 0} - - missing_contents = self.skipped_content_missing([c.hashes() for c in contents]) - missing = {self._content_key(c) for c in missing_contents} - contents = [c for c in contents if self._content_key(c) in missing] - for content in contents: - key = self._content_key(content) - for algo in DEFAULT_ALGORITHMS: - if content.get_hash(algo): - self._skipped_content_indexes[algo][content.get_hash(algo)].add(key) - self._skipped_contents[key] = content - summary["skipped_content:add"] += 1 - - return summary - - def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict: - content = [attr.evolve(c, ctime=now()) for c in content] - return self._skipped_content_add(content) - - def skipped_content_missing(self, contents): - for content in contents: - matches = list(self._skipped_contents.values()) - for (algorithm, key) in self._content_key(content): - if algorithm == "blake2s256": - continue - # Filter out skipped contents with the same hash - matches = [ - match for match in matches if match.get_hash(algorithm) == key - ] - # if none of the contents match - if not matches: - yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS} - - def directory_add(self, directories: Iterable[Directory]) -> Dict: - directories = [dir_ for dir_ in directories if dir_.id not in self._directories] - self.journal_writer.directory_add(directories) - - count = 0 - for directory in directories: - count += 1 - self._directories[directory.id] = directory - self._objects[directory.id].append(("directory", directory.id)) - - return {"directory:add": count} - - def directory_missing(self, directories): - for id in directories: - if id not in self._directories: - yield id - - def _join_dentry_to_content(self, dentry): - keys = ( - "status", - "sha1", - "sha1_git", - "sha256", - "length", - ) - ret = dict.fromkeys(keys) - ret.update(dentry) - if ret["type"] == "file": - # TODO: Make it able to handle more than one content - content = self.content_find({"sha1_git": ret["target"]}) - if content: - content = content[0] - for key in keys: - ret[key] = content[key] - return ret - - def _directory_ls(self, directory_id, recursive, prefix=b""): - if directory_id in self._directories: - for entry in self._directories[directory_id].entries: - ret = self._join_dentry_to_content(entry.to_dict()) - ret["name"] = prefix + ret["name"] - ret["dir_id"] = directory_id - yield ret - if recursive and ret["type"] == "dir": - yield from self._directory_ls( - ret["target"], True, prefix + ret["name"] + b"/" - ) - - def directory_ls(self, directory, recursive=False): - yield from self._directory_ls(directory, recursive) - - def directory_entry_get_by_path(self, directory, paths): - return self._directory_entry_get_by_path(directory, paths, b"") - - def directory_get_random(self): - if not self._directories: + return content + else: return None - return random.choice(list(self._directories)) - - def _directory_entry_get_by_path(self, directory, paths, prefix): - if not paths: - return - - contents = list(self.directory_ls(directory)) - - if not contents: - return - - def _get_entry(entries, name): - for entry in entries: - if entry["name"] == name: - entry = entry.copy() - entry["name"] = prefix + entry["name"] - return entry - - first_item = _get_entry(contents, paths[0]) - - if len(paths) == 1: - return first_item - - if not first_item or first_item["type"] != "dir": - return - - return 
self._directory_entry_get_by_path( - first_item["target"], paths[1:], prefix + paths[0] + b"/" - ) - - def revision_add(self, revisions: Iterable[Revision]) -> Dict: - revisions = [rev for rev in revisions if rev.id not in self._revisions] - self.journal_writer.revision_add(revisions) - - count = 0 - for revision in revisions: - revision = attr.evolve( - revision, - committer=self._person_add(revision.committer), - author=self._person_add(revision.author), - ) - self._revisions[revision.id] = revision - self._objects[revision.id].append(("revision", revision.id)) - count += 1 - - return {"revision:add": count} - - def revision_missing(self, revisions): - for id in revisions: - if id not in self._revisions: - yield id - - def revision_get(self, revisions): - for id in revisions: - if id in self._revisions: - yield self._revisions.get(id).to_dict() - else: - yield None - - def _get_parent_revs(self, rev_id, seen, limit): - if limit and len(seen) >= limit: - return - if rev_id in seen or rev_id not in self._revisions: - return - seen.add(rev_id) - yield self._revisions[rev_id].to_dict() - for parent in self._revisions[rev_id].parents: - yield from self._get_parent_revs(parent, seen, limit) - - def revision_log(self, revisions, limit=None): - seen = set() - for rev_id in revisions: - yield from self._get_parent_revs(rev_id, seen, limit) - - def revision_shortlog(self, revisions, limit=None): - yield from ( - (rev["id"], rev["parents"]) for rev in self.revision_log(revisions, limit) - ) - - def revision_get_random(self): - return random.choice(list(self._revisions)) - - def release_add(self, releases: Iterable[Release]) -> Dict: - releases = [rel for rel in releases if rel.id not in self._releases] - self.journal_writer.release_add(releases) - - count = 0 - for rel in releases: - if rel.author: - self._person_add(rel.author) - self._objects[rel.id].append(("release", rel.id)) - self._releases[rel.id] = rel - count += 1 - - return {"release:add": count} - - def release_missing(self, releases): - yield from (rel for rel in releases if rel not in self._releases) - - def release_get(self, releases): - for rel_id in releases: - if rel_id in self._releases: - yield self._releases[rel_id].to_dict() - else: - yield None - def release_get_random(self): - return random.choice(list(self._releases)) + def content_get_from_token(self, token) -> Iterator[Row]: + for content in self._content: + if self._content.get_row_hash(content) == token: + yield content - def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict: - count = 0 - snapshots = (snap for snap in snapshots if snap.id not in self._snapshots) - for snapshot in snapshots: - self.journal_writer.snapshot_add([snapshot]) - sorted_branch_names = sorted(snapshot.branches) - self._snapshots[snapshot.id] = (snapshot, sorted_branch_names) - self._objects[snapshot.id].append(("snapshot", snapshot.id)) - count += 1 + def content_get_random(self) -> Row: + return random.choice(list(self._content)) - return {"snapshot:add": count} + def content_get_token_range( + self, start: int, end: int, limit: int + ) -> Iterator[Row]: + for (i, content) in enumerate(self._content.iter_from((start,))): + if i >= limit or self._content.get_row_hash(content) >= end: + return + yield content - def snapshot_missing(self, snapshots): - for id in snapshots: - if id not in self._snapshots: - yield id + def content_missing_per_sha1_git(self, ids: List[bytes]) -> Iterable[bytes]: + missing = set(ids) + for content in self._content: + missing.remove(content.sha1_git) + return 
missing - def snapshot_get(self, snapshot_id): - return self.snapshot_get_branches(snapshot_id) + def content_index_add_one(self, algo: str, content: Content, token: int) -> None: + pass - def snapshot_get_by_origin_visit(self, origin, visit): - origin_url = self._get_origin_url(origin) - if not origin_url: - return + def content_get_tokens_from_single_hash( + self, algo: str, hash_: bytes + ) -> Iterable[int]: + assert algo in HASH_ALGORITHMS + for content in self._content: + if getattr(content, algo) == hash_: + yield self._content.get_row_hash(content) - if origin_url not in self._origins or visit > len( - self._origin_visits[origin_url] - ): - return None - - visit = self._origin_visit_get_updated(origin_url, visit) - snapshot_id = visit.snapshot - if snapshot_id: - return self.snapshot_get(snapshot_id) - else: - return None - def snapshot_get_latest(self, origin, allowed_statuses=None): - origin_url = self._get_origin_url(origin) - if not origin_url: - return +class InMemoryStorage(CassandraStorage): + def __init__(self, journal_writer=None): + self.reset() + self.journal_writer = JournalWriter(journal_writer) + self.objstorage = ObjStorage({"cls": "memory", "args": {}}) - visit = self.origin_visit_get_latest( - origin_url, allowed_statuses=allowed_statuses, require_snapshot=True - ) - if visit and visit["snapshot"]: - snapshot = self.snapshot_get(visit["snapshot"]) - if not snapshot: - raise StorageArgumentException( - "last origin visit references an unknown snapshot" - ) - return snapshot + def reset(self): + self._cql_runner = self._runner = InMemRunner() - def snapshot_count_branches(self, snapshot_id): - (snapshot, _) = self._snapshots[snapshot_id] - return collections.Counter( - branch.target_type.value if branch else None - for branch in snapshot.branches.values() + def content_get_range(self, start, end, limit=1000): + # TODO: remove this when swh-index stops using it + if limit is None: + raise StorageArgumentException("limit should not be None") + contents = sorted( + (cont for cont in self._runner._content if start <= cont.sha1 <= end), + key=lambda cont: cont.sha1, ) - - def snapshot_get_branches( - self, snapshot_id, branches_from=b"", branches_count=1000, target_types=None - ): - res = self._snapshots.get(snapshot_id) - if res is None: - return None - (snapshot, sorted_branch_names) = res - from_index = bisect.bisect_left(sorted_branch_names, branches_from) - if target_types: - next_branch = None - branches = {} - for branch_name in sorted_branch_names[from_index:]: - branch = snapshot.branches[branch_name] - if branch and branch.target_type.value in target_types: - if len(branches) < branches_count: - branches[branch_name] = branch - else: - next_branch = branch_name - break + if len(contents) <= limit: + next_content = None else: - # As there is no 'target_types', we can do that much faster - to_index = from_index + branches_count - returned_branch_names = sorted_branch_names[from_index:to_index] - branches = { - branch_name: snapshot.branches[branch_name] - for branch_name in returned_branch_names - } - if to_index >= len(sorted_branch_names): - next_branch = None - else: - next_branch = sorted_branch_names[to_index] - - branches = { - name: branch.to_dict() if branch else None - for (name, branch) in branches.items() - } - + next_content = contents[limit] + contents = contents[:limit] return { - "id": snapshot_id, - "branches": branches, - "next_branch": next_branch, - } - - def snapshot_get_random(self): - return random.choice(list(self._snapshots)) - - def 
object_find_by_sha1_git(self, ids): - ret = {} - for id_ in ids: - objs = self._objects.get(id_, []) - ret[id_] = [{"sha1_git": id_, "type": obj[0],} for obj in objs] - return ret - - def _convert_origin(self, t): - if t is None: - return None - - return t.to_dict() - - def origin_get(self, origins): - if isinstance(origins, dict): - # Old API - return_single = True - origins = [origins] - else: - return_single = False - - # Sanity check to be error-compatible with the pgsql backend - if any("id" in origin for origin in origins) and not all( - "id" in origin for origin in origins - ): - raise StorageArgumentException( - 'Either all origins or none at all should have an "id".' - ) - if any("url" in origin for origin in origins) and not all( - "url" in origin for origin in origins - ): - raise StorageArgumentException( - "Either all origins or none at all should have " 'an "url" key.' - ) - - results = [] - for origin in origins: - result = None - if "url" in origin: - if origin["url"] in self._origins: - result = self._origins[origin["url"]] - else: - raise StorageArgumentException("Origin must have an url.") - results.append(self._convert_origin(result)) - - if return_single: - assert len(results) == 1 - return results[0] - else: - return results - - def origin_get_by_sha1(self, sha1s): - return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s] - - def origin_get_range(self, origin_from=1, origin_count=100): - origin_from = max(origin_from, 1) - if origin_from <= len(self._origins_by_id): - max_idx = origin_from + origin_count - 1 - if max_idx > len(self._origins_by_id): - max_idx = len(self._origins_by_id) - for idx in range(origin_from - 1, max_idx): - origin = self._convert_origin(self._origins[self._origins_by_id[idx]]) - yield {"id": idx + 1, **origin} - - def origin_list(self, page_token: Optional[str] = None, limit: int = 100) -> dict: - origin_urls = sorted(self._origins) - if page_token: - from_ = bisect.bisect_left(origin_urls, page_token) - else: - from_ = 0 - - result = { - "origins": [ - {"url": origin_url} for origin_url in origin_urls[from_ : from_ + limit] - ] + "contents": matched, + "next": next_content, } - if from_ + limit < len(origin_urls): - result["next_page_token"] = origin_urls[from_ + limit] - - return result - - def origin_search( - self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False - ): - origins = map(self._convert_origin, self._origins.values()) - if regexp: - pat = re.compile(url_pattern) - origins = [orig for orig in origins if pat.search(orig["url"])] - else: - origins = [orig for orig in origins if url_pattern in orig["url"]] - if with_visit: - filtered_origins = [] - for orig in origins: - visits = ( - self._origin_visit_get_updated(ov.origin, ov.visit) - for ov in self._origin_visits[orig["url"]] - ) - for ov in visits: - if ov.snapshot and ov.snapshot in self._snapshots: - filtered_origins.append(orig) - break - else: - filtered_origins = origins - - return filtered_origins[offset : offset + limit] - - def origin_count(self, url_pattern, regexp=False, with_visit=False): - return len( - self.origin_search( - url_pattern, - regexp=regexp, - with_visit=with_visit, - limit=len(self._origins), - ) - ) - - def origin_add(self, origins: Iterable[Origin]) -> List[Dict]: - origins = copy.deepcopy(list(origins)) - for origin in origins: - self.origin_add_one(origin) - return [origin.to_dict() for origin in origins] - - def origin_add_one(self, origin: Origin) -> str: - if origin.url not in self._origins: - 
self.journal_writer.origin_add([origin]) - # generate an origin_id because it is needed by origin_get_range. - # TODO: remove this when we remove origin_get_range - origin_id = len(self._origins) + 1 - self._origins_by_id.append(origin.url) - assert len(self._origins_by_id) == origin_id - - self._origins[origin.url] = origin - self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin - self._origin_visits[origin.url] = [] - self._objects[origin.url].append(("origin", origin.url)) - - return origin.url - - def origin_visit_add( - self, origin_url: str, date: Union[str, datetime.datetime], type: str - ) -> OriginVisit: - if isinstance(date, str): - # FIXME: Converge on iso8601 at some point - date = dateutil.parser.parse(date) - elif not isinstance(date, datetime.datetime): - raise StorageArgumentException("Date must be a datetime or a string") - - origin = self.origin_get({"url": origin_url}) - if not origin: # Cannot add a visit without an origin - raise StorageArgumentException("Unknown origin %s", origin_url) - - if origin_url in self._origins: - origin = self._origins[origin_url] - # visit ids are in the range [1, +inf[ - visit_id = len(self._origin_visits[origin_url]) + 1 - status = "ongoing" - with convert_validation_exceptions(): - visit = OriginVisit( - origin=origin_url, - date=date, - type=type, - # TODO: Remove when we remove those fields from the model - status=status, - snapshot=None, - metadata=None, - visit=visit_id, - ) - self._origin_visits[origin_url].append(visit) - assert visit.visit is not None - visit_key = (origin_url, visit.visit) - - with convert_validation_exceptions(): - visit_update = OriginVisitStatus( - origin=origin_url, - visit=visit_id, - date=date, - status=status, - snapshot=None, - metadata=None, - ) - self._origin_visit_statuses[visit_key] = [visit_update] - - self._objects[visit_key].append(("origin_visit", None)) - - self.journal_writer.origin_visit_add([visit]) - - # return last visit - return visit - - def origin_visit_update( - self, - origin: str, - visit_id: int, - status: str, - metadata: Optional[Dict] = None, - snapshot: Optional[bytes] = None, - date: Optional[datetime.datetime] = None, - ): - origin_url = self._get_origin_url(origin) - if origin_url is None: - raise StorageArgumentException("Unknown origin.") - - try: - visit = self._origin_visits[origin_url][visit_id - 1] - except IndexError: - raise StorageArgumentException("Unknown visit_id for this origin") from None - - # Retrieve the previous visit status - assert visit.visit is not None - visit_key = (origin_url, visit.visit) - - last_visit_update = max( - self._origin_visit_statuses[visit_key], key=lambda v: v.date - ) - - with convert_validation_exceptions(): - visit_update = OriginVisitStatus( - origin=origin_url, - visit=visit_id, - date=date or now(), - status=status, - snapshot=snapshot or last_visit_update.snapshot, - metadata=metadata or last_visit_update.metadata, - ) - self._origin_visit_statuses[visit_key].append(visit_update) - - self.journal_writer.origin_visit_update( - [self._origin_visit_get_updated(origin_url, visit_id)] - ) - - self._origin_visits[origin_url][visit_id - 1] = visit - - def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None: - for visit in visits: - if visit.visit is None: - raise StorageArgumentException(f"Missing visit id for visit {visit}") - - self.journal_writer.origin_visit_upsert(visits) - - date = now() - - for visit in visits: - assert visit.visit is not None - origin_url = visit.origin - origin = self.origin_get({"url": 
origin_url}) - - if not origin: # Cannot add a visit without an origin - raise StorageArgumentException("Unknown origin %s", origin_url) - - if origin_url in self._origins: - origin = self._origins[origin_url] - # visit ids are in the range [1, +inf[ - assert visit.visit is not None - visit_key = (origin_url, visit.visit) - - with convert_validation_exceptions(): - visit_update = OriginVisitStatus( - origin=origin_url, - visit=visit.visit, - date=date, - status=visit.status, - snapshot=visit.snapshot, - metadata=visit.metadata, - ) - - self._origin_visit_statuses.setdefault(visit_key, []) - while len(self._origin_visits[origin_url]) <= visit.visit: - self._origin_visits[origin_url].append(None) - - self._origin_visits[origin_url][visit.visit - 1] = visit - self._origin_visit_statuses[visit_key].append(visit_update) - - self._objects[visit_key].append(("origin_visit", None)) - - def _origin_visit_get_updated(self, origin: str, visit_id: int) -> OriginVisit: - """Merge origin visit and latest origin visit status - - """ - assert visit_id >= 1 - visit = self._origin_visits[origin][visit_id - 1] - assert visit is not None - visit_key = (origin, visit_id) - - visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date) - - return OriginVisit.from_dict( - { - # default to the values in visit - **visit.to_dict(), - # override with the last update - **visit_update.to_dict(), - # but keep the date of the creation of the origin visit - "date": visit.date, - } - ) - - def origin_visit_get( - self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None - ) -> Iterable[Dict[str, Any]]: - origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits: - visits = self._origin_visits[origin_url] - if last_visit is not None: - visits = visits[last_visit:] - if limit is not None: - visits = visits[:limit] - for visit in visits: - if not visit: - continue - visit_id = visit.visit - - visit_update = self._origin_visit_get_updated(origin_url, visit_id) - assert visit_update is not None - yield visit_update.to_dict() - - def origin_visit_find_by_date( - self, origin: str, visit_date: datetime.datetime - ) -> Optional[Dict[str, Any]]: - origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits: - visits = self._origin_visits[origin_url] - visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit)) - visit_update = self._origin_visit_get_updated(origin, visit.visit) - assert visit_update is not None - return visit_update.to_dict() - return None - - def origin_visit_get_by(self, origin: str, visit: int) -> Optional[Dict[str, Any]]: - origin_url = self._get_origin_url(origin) - if origin_url in self._origin_visits and visit <= len( - self._origin_visits[origin_url] - ): - visit_update = self._origin_visit_get_updated(origin_url, visit) - assert visit_update is not None - return visit_update.to_dict() - return None - - def origin_visit_get_latest( - self, - origin: str, - allowed_statuses: Optional[List[str]] = None, - require_snapshot: bool = False, - ) -> Optional[Dict[str, Any]]: - ori = self._origins.get(origin) - if not ori: - return None - visits = self._origin_visits[ori.url] - visits = [ - self._origin_visit_get_updated(visit.origin, visit.visit) - for visit in visits - if visit is not None - ] - - if allowed_statuses is not None: - visits = [visit for visit in visits if visit.status in allowed_statuses] - if require_snapshot: - visits = [visit for visit in visits if visit.snapshot] - - visit = max(visits, key=lambda 
v: (v.date, v.visit), default=None) - if visit is None: - return None - return visit.to_dict() - - def _select_random_origin_visit_by_type(self, type: str) -> str: - while True: - url = random.choice(list(self._origin_visits.keys())) - random_origin_visits = self._origin_visits[url] - if random_origin_visits[0].type == type: - return url - - def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]: - url = self._select_random_origin_visit_by_type(type) - random_origin_visits = copy.deepcopy(self._origin_visits[url]) - random_origin_visits.reverse() - back_in_the_day = now() - timedelta(weeks=12) # 3 months back - # This should be enough for tests - for visit in random_origin_visits: - updated_visit = self._origin_visit_get_updated(url, visit.visit) - assert updated_visit is not None - if updated_visit.date > back_in_the_day and updated_visit.status == "full": - return updated_visit.to_dict() - else: - return None - def stat_counters(self): - keys = ( - "content", - "directory", - "origin", - "origin_visit", - "person", - "release", - "revision", - "skipped_content", - "snapshot", - ) - stats = {key: 0 for key in keys} - stats.update( - collections.Counter( - obj_type - for (obj_type, obj_id) in itertools.chain(*self._objects.values()) - ) - ) - return stats - - def refresh_stat_counters(self): - pass - - def origin_metadata_add( - self, - origin_url: str, - discovery_date: datetime.datetime, - authority: Dict[str, Any], - fetcher: Dict[str, Any], - format: str, - metadata: bytes, - ) -> None: - if not isinstance(origin_url, str): - raise StorageArgumentException( - "origin_id must be str, not %r" % (origin_url,) - ) - if not isinstance(metadata, bytes): - raise StorageArgumentException( - "metadata must be bytes, not %r" % (metadata,) - ) - authority_key = self._metadata_authority_key(authority) - if authority_key not in self._metadata_authorities: - raise StorageArgumentException(f"Unknown authority {authority}") - fetcher_key = self._metadata_fetcher_key(fetcher) - if fetcher_key not in self._metadata_fetchers: - raise StorageArgumentException(f"Unknown fetcher {fetcher}") - - origin_metadata = { - "origin_url": origin_url, - "discovery_date": discovery_date, - "authority": authority_key, - "fetcher": fetcher_key, - "format": format, - "metadata": metadata, - } - self._origin_metadata[origin_url][authority_key].add(origin_metadata) - return None - - def origin_metadata_get( - self, - origin_url: str, - authority: Dict[str, str], - after: Optional[datetime.datetime] = None, - limit: Optional[int] = None, - ) -> List[Dict[str, Any]]: - if not isinstance(origin_url, str): - raise TypeError("origin_url must be str, not %r" % (origin_url,)) - - authority_key = self._metadata_authority_key(authority) - - if after is None: - entries = iter(self._origin_metadata[origin_url][authority_key]) - else: - entries = self._origin_metadata[origin_url][authority_key].iter_from(after) - if limit: - entries = itertools.islice(entries, 0, limit) - - results = [] - for entry in entries: - authority = self._metadata_authorities[entry["authority"]] - fetcher = self._metadata_fetchers[entry["fetcher"]] - results.append( - { - **entry, - "authority": {"type": authority["type"], "url": authority["url"],}, - "fetcher": { - "name": fetcher["name"], - "version": fetcher["version"], - }, - } - ) - return results - - def metadata_fetcher_add( - self, name: str, version: str, metadata: Dict[str, Any] - ) -> None: - fetcher = { - "name": name, - "version": version, - "metadata": metadata, - } - key = 
self._metadata_fetcher_key(fetcher) - if key not in self._metadata_fetchers: - self._metadata_fetchers[key] = fetcher - - def metadata_fetcher_get(self, name: str, version: str) -> Optional[Dict[str, Any]]: - return self._metadata_fetchers.get( - self._metadata_fetcher_key({"name": name, "version": version}) - ) - - def metadata_authority_add( - self, type: str, url: str, metadata: Dict[str, Any] - ) -> None: - authority = { - "type": type, - "url": url, - "metadata": metadata, + return { + "content": len(self._runner._content), } - key = self._metadata_authority_key(authority) - self._metadata_authorities[key] = authority - - def metadata_authority_get(self, type: str, url: str) -> Optional[Dict[str, Any]]: - return self._metadata_authorities.get( - self._metadata_authority_key({"type": type, "url": url}) - ) - - def _get_origin_url(self, origin): - if isinstance(origin, str): - return origin - else: - raise TypeError("origin must be a string.") - - def _person_add(self, person): - key = ("person", person.fullname) - if key not in self._objects: - person_id = len(self._persons) + 1 - self._persons.append(person) - self._objects[key].append(("person", person_id)) - else: - person_id = self._objects[key][0][1] - person = self._persons[person_id - 1] - return person - - @staticmethod - def _content_key(content): - """ A stable key and the algorithm for a content""" - if isinstance(content, BaseContent): - content = content.to_dict() - return tuple((key, content.get(key)) for key in sorted(DEFAULT_ALGORITHMS)) - - @staticmethod - def _metadata_fetcher_key(fetcher: Dict) -> Hashable: - return (fetcher["name"], fetcher["version"]) - - @staticmethod - def _metadata_authority_key(authority: Dict) -> Hashable: - return (authority["type"], authority["url"]) - - def diff_directories(self, from_dir, to_dir, track_renaming=False): - raise NotImplementedError("InMemoryStorage.diff_directories") - - def diff_revisions(self, from_rev, to_rev, track_renaming=False): - raise NotImplementedError("InMemoryStorage.diff_revisions") - - def diff_revision(self, revision, track_renaming=False): - raise NotImplementedError("InMemoryStorage.diff_revision") - - def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None: - """Do nothing - - """ - return None - - def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict: - return {}
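
Note for reviewers (not part of the patch): the rewrite above replaces the old dict-and-index bookkeeping with a small in-memory emulation of the Cassandra backend, where rows are kept ordered by the hash ("token") of their primary key so that token-range queries such as content_get_token_range and partition scans work without a real cluster. The following standalone sketch illustrates that ordering idea only; all names (Row, HashOrderedTable, iter_token_range) are hypothetical and it does not import anything from swh.

import bisect
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass(frozen=True)
class Row:
    # Minimal stand-in for a Content row; only the fields used as keys here.
    sha1: bytes
    sha1_git: bytes


class HashOrderedTable:
    # Rows kept sorted by the hash ("token") of their primary key, emulating
    # Cassandra's token order in the spirit of the Table class added by this patch.
    primary_key: Tuple[str, ...] = ("sha1", "sha1_git")

    def __init__(self) -> None:
        self._tokens: List[int] = []  # sorted list of tokens
        self._rows: List[Row] = []    # rows, kept in the same order as _tokens

    def token(self, row: Row) -> int:
        return hash(tuple(getattr(row, k) for k in self.primary_key))

    def add_one(self, row: Row) -> None:
        token = self.token(row)
        i = bisect.bisect_right(self._tokens, token)
        self._tokens.insert(i, token)
        self._rows.insert(i, row)

    def iter_token_range(self, start: int, end: int, limit: int) -> Iterator[Row]:
        # Yield at most `limit` rows whose token falls in [start, end).
        i = bisect.bisect_left(self._tokens, start)
        count = 0
        for token, row in zip(self._tokens[i:], self._rows[i:]):
            if token >= end or count >= limit:
                return
            count += 1
            yield row


if __name__ == "__main__":
    table = HashOrderedTable()
    table.add_one(Row(sha1=b"\x01" * 20, sha1_git=b"\x02" * 20))
    table.add_one(Row(sha1=b"\x03" * 20, sha1_git=b"\x04" * 20))
    # A full scan is just a token-range query over the whole token space.
    print(list(table.iter_token_range(-(2 ** 63), 2 ** 63, limit=10)))

In the patch itself this role is played by Table/ContentTable plus InMemRunner, which exposes the token-oriented methods (content_add_prepare, content_get_from_token, content_get_token_range, ...) that the shared CassandraStorage algorithms expect, so InMemoryStorage only has to plug in the runner, the journal writer and the in-memory objstorage.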