D3194.diff

diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -3,57 +3,45 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-import re
import bisect
-import dateutil
import collections
-import copy
import datetime
+import functools
import itertools
+import json
+import logging
import random
-
-from collections import defaultdict
-from datetime import timedelta
from typing import (
Any,
Callable,
Dict,
- Generic,
- Hashable,
Iterable,
Iterator,
+ Generic,
List,
Optional,
Tuple,
TypeVar,
- Union,
)
-import attr
from swh.model.model import (
- BaseContent,
+ Sha1Git,
+ TimestampWithTimezone,
+ Timestamp,
+ Person,
Content,
SkippedContent,
- Directory,
- Revision,
- Release,
- Snapshot,
OriginVisit,
OriginVisitStatus,
Origin,
- SHA1_SIZE,
)
-from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex
-from swh.storage.objstorage import ObjStorage
-from swh.storage.validate import convert_validation_exceptions
-from swh.storage.utils import now
-from .exc import StorageArgumentException, HashCollision
-
-from .converters import origin_url_to_sha1
-from .utils import get_partition_bounds_bytes
-from .writer import JournalWriter
+from swh.storage.cassandra.schema import HASH_ALGORITHMS
+from swh.storage.cassandra.storage import CassandraStorage
+from swh.storage.exc import StorageArgumentException
+from swh.storage.objstorage import ObjStorage
+from swh.storage.writer import JournalWriter
# Max block size of contents to return
BULK_BLOCK_CONTENT_LEN_MAX = 10000
@@ -92,6 +80,9 @@
for (k, item) in self.data:
yield item
+ def __len__(self):
+ return len(self.data)
+
def iter_from(self, start_key: SortedListKey) -> Iterator[SortedListItem]:
"""Returns an iterator over all the elements whose key is greater
than or equal to `start_key`.
@@ -103,1098 +94,129 @@
yield item
-class InMemoryStorage:
- def __init__(self, journal_writer=None):
+Row = TypeVar("Row")
- self.reset()
- self.journal_writer = JournalWriter(journal_writer)
-
- def reset(self):
- self._contents = {}
- self._content_indexes = defaultdict(lambda: defaultdict(set))
- self._skipped_contents = {}
- self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
- self._directories = {}
- self._revisions = {}
- self._releases = {}
- self._snapshots = {}
- self._origins = {}
- self._origins_by_id = []
- self._origins_by_sha1 = {}
- self._origin_visits = {}
- self._origin_visit_statuses: Dict[Tuple[str, int], List[OriginVisitStatus]] = {}
- self._persons = []
- # {origin_url: {authority: [metadata]}}
- self._origin_metadata: Dict[
- str, Dict[Hashable, SortedList[datetime.datetime, Dict[str, Any]]]
- ] = defaultdict(
- lambda: defaultdict(lambda: SortedList(key=lambda x: x["discovery_date"]))
- ) # noqa
-
- self._metadata_fetchers: Dict[Hashable, Dict[str, Any]] = {}
- self._metadata_authorities: Dict[Hashable, Dict[str, Any]] = {}
- self._objects = defaultdict(list)
- self._sorted_sha1s = SortedList[bytes, bytes]()
-
- self.objstorage = ObjStorage({"cls": "memory", "args": {}})
+class Table(Generic[Row]):
+ """A dict-like class, that provides range requests and order based on
+ the hash of keys."""
- def check_config(self, *, check_write):
- return True
+ primary_key: Tuple[str, ...]
+ clustering_key: Tuple[str, ...]
- def _content_add(self, contents: Iterable[Content], with_data: bool) -> Dict:
- self.journal_writer.content_add(contents)
+ def __init__(self):
+ if None in (self.primary_key, self.clustering_key):
+ raise TypeError(f"{self.__class__.__name__} is missing key definition.")
- content_add = 0
- if with_data:
- summary = self.objstorage.content_add(
- c for c in contents if c.status != "absent"
- )
- content_add_bytes = summary["content:add:bytes"]
+ self._list = SortedList[Tuple, Row](key=self._get_row_primary_key)
- for content in contents:
- key = self._content_key(content)
- if key in self._contents:
- continue
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = content.get_hash(algorithm)
- if hash_ in self._content_indexes[algorithm] and (
- algorithm not in {"blake2s256", "sha256"}
- ):
- colliding_content_hashes = []
- # Add the already stored contents
- for content_hashes_set in self._content_indexes[algorithm][hash_]:
- hashes = dict(content_hashes_set)
- colliding_content_hashes.append(hashes)
- # Add the new colliding content
- colliding_content_hashes.append(content.hashes())
- raise HashCollision(algorithm, hash_, colliding_content_hashes)
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = content.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].add(key)
- self._objects[content.sha1_git].append(("content", content.sha1))
- self._contents[key] = content
- self._sorted_sha1s.add(content.sha1)
- self._contents[key] = attr.evolve(self._contents[key], data=None)
- content_add += 1
+ def get_row_hash(self, row: Row) -> int:
+ return hash(tuple(getattr(row, key) for key in self.primary_key))
- summary = {
- "content:add": content_add,
- }
- if with_data:
- summary["content:add:bytes"] = content_add_bytes
-
- return summary
-
- def content_add(self, content: Iterable[Content]) -> Dict:
- content = [attr.evolve(c, ctime=now()) for c in content]
- return self._content_add(content, with_data=True)
+ def _get_row_primary_key(self, row: Row) -> Tuple:
+ return (self.get_row_hash(row),) + tuple(
+ getattr(row, key) for key in self.clustering_key
+ )
- def content_update(self, content, keys=[]):
- self.journal_writer.content_update(content)
+ def add_one(self, row: Row):
+ self._list.add(row)
- for cont_update in content:
- cont_update = cont_update.copy()
- sha1 = cont_update.pop("sha1")
- for old_key in self._content_indexes["sha1"][sha1]:
- old_cont = self._contents.pop(old_key)
+ def __iter__(self):
+ return iter(self._list)
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = old_cont.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].remove(old_key)
+ def __len__(self):
+ return len(self._list)
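+ def iter_from(self, start_key: Tuple) -> Iterator[Row]:
+ # Expose the underlying SortedList's range iteration; the runner's
+ # content_get_token_range() relies on it.
+ return self._list.iter_from(start_key)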
- new_cont = attr.evolve(old_cont, **cont_update)
- new_key = self._content_key(new_cont)
- self._contents[new_key] = new_cont
+class ContentTable(Table[Content]):
+ primary_key = ("sha1", "sha1_git", "sha256", "blake2s256")
+ clustering_key = ()
- for algorithm in DEFAULT_ALGORITHMS:
- hash_ = new_cont.get_hash(algorithm)
- self._content_indexes[algorithm][hash_].add(new_key)
- def content_add_metadata(self, content: Iterable[Content]) -> Dict:
- return self._content_add(content, with_data=False)
+class InMemRunner:
+ def __init__(self):
+ self._content = ContentTable()
- def content_get(self, content):
- # FIXME: Make this method support slicing the `data`.
- if len(content) > BULK_BLOCK_CONTENT_LEN_MAX:
- raise StorageArgumentException(
- "Sending at most %s contents." % BULK_BLOCK_CONTENT_LEN_MAX
- )
- yield from self.objstorage.content_get(content)
+ def _content_add_finalize(self, content: Content):
+ self._content.add_one(content)
- def content_get_range(self, start, end, limit=1000):
- if limit is None:
- raise StorageArgumentException("limit should not be None")
- sha1s = (
- (sha1, content_key)
- for sha1 in self._sorted_sha1s.iter_from(start)
- for content_key in self._content_indexes["sha1"][sha1]
- )
- matched = []
- next_content = None
- for sha1, key in sha1s:
- if sha1 > end:
- break
- if len(matched) >= limit:
- next_content = sha1
- break
- matched.append(self._contents[key].to_dict())
- return {
- "contents": matched,
- "next": next_content,
- }
+ def content_add_prepare(self, content: Content) -> Tuple[int, Callable[[], None]]:
+ token = self._content.get_row_hash(content)
+ finalizer = functools.partial(self._content_add_finalize, content)
+ return (token, finalizer)
- def content_get_partition(
- self,
- partition_id: int,
- nb_partitions: int,
- limit: int = 1000,
- page_token: str = None,
- ):
- if limit is None:
- raise StorageArgumentException("limit should not be None")
- (start, end) = get_partition_bounds_bytes(
- partition_id, nb_partitions, SHA1_SIZE
- )
- if page_token:
- start = hash_to_bytes(page_token)
- if end is None:
- end = b"\xff" * SHA1_SIZE
- result = self.content_get_range(start, end, limit)
- result2 = {
- "contents": result["contents"],
- "next_page_token": None,
- }
- if result["next"]:
- result2["next_page_token"] = hash_to_hex(result["next"])
- return result2
-
- def content_get_metadata(self, contents: List[bytes]) -> Dict[bytes, List[Dict]]:
- result: Dict = {sha1: [] for sha1 in contents}
- for sha1 in contents:
- if sha1 in self._content_indexes["sha1"]:
- objs = self._content_indexes["sha1"][sha1]
- # only 1 element as content_add_metadata would have raised a
- # hash collision otherwise
- for key in objs:
- d = self._contents[key].to_dict()
- del d["ctime"]
- if "data" in d:
- del d["data"]
- result[sha1].append(d)
- return result
-
- def content_find(self, content):
- if not set(content).intersection(DEFAULT_ALGORITHMS):
- raise StorageArgumentException(
- "content keys must contain at least one of: %s"
- % ", ".join(sorted(DEFAULT_ALGORITHMS))
- )
- found = []
- for algo in DEFAULT_ALGORITHMS:
- hash = content.get(algo)
- if hash and hash in self._content_indexes[algo]:
- found.append(self._content_indexes[algo][hash])
-
- if not found:
- return []
-
- keys = list(set.intersection(*found))
- return [self._contents[key].to_dict() for key in keys]
-
- def content_missing(self, content, key_hash="sha1"):
- for cont in content:
- for (algo, hash_) in cont.items():
- if algo not in DEFAULT_ALGORITHMS:
- continue
- if hash_ not in self._content_indexes.get(algo, []):
- yield cont[key_hash]
+ def content_get_from_pk(self, content_hashes: Dict[str, bytes]) -> Optional[Content]:
+ # Linear scan emulating a primary-key lookup: return the first row whose
+ # hashes all match, or None if there is no such row.
+ for content in self._content:
+ for algo in HASH_ALGORITHMS:
+ if content_hashes[algo] != getattr(content, algo):
break
else:
- for result in self.content_find(cont):
- if result["status"] == "missing":
- yield cont[key_hash]
-
- def content_missing_per_sha1(self, contents):
- for content in contents:
- if content not in self._content_indexes["sha1"]:
- yield content
-
- def content_missing_per_sha1_git(self, contents):
- for content in contents:
- if content not in self._content_indexes["sha1_git"]:
- yield content
-
- def content_get_random(self):
- return random.choice(list(self._content_indexes["sha1_git"]))
-
- def _skipped_content_add(self, contents: List[SkippedContent]) -> Dict:
- self.journal_writer.skipped_content_add(contents)
-
- summary = {"skipped_content:add": 0}
-
- missing_contents = self.skipped_content_missing([c.hashes() for c in contents])
- missing = {self._content_key(c) for c in missing_contents}
- contents = [c for c in contents if self._content_key(c) in missing]
- for content in contents:
- key = self._content_key(content)
- for algo in DEFAULT_ALGORITHMS:
- if content.get_hash(algo):
- self._skipped_content_indexes[algo][content.get_hash(algo)].add(key)
- self._skipped_contents[key] = content
- summary["skipped_content:add"] += 1
-
- return summary
-
- def skipped_content_add(self, content: Iterable[SkippedContent]) -> Dict:
- content = [attr.evolve(c, ctime=now()) for c in content]
- return self._skipped_content_add(content)
-
- def skipped_content_missing(self, contents):
- for content in contents:
- matches = list(self._skipped_contents.values())
- for (algorithm, key) in self._content_key(content):
- if algorithm == "blake2s256":
- continue
- # Filter out skipped contents with the same hash
- matches = [
- match for match in matches if match.get_hash(algorithm) == key
- ]
- # if none of the contents match
- if not matches:
- yield {algo: content[algo] for algo in DEFAULT_ALGORITHMS}
-
- def directory_add(self, directories: Iterable[Directory]) -> Dict:
- directories = [dir_ for dir_ in directories if dir_.id not in self._directories]
- self.journal_writer.directory_add(directories)
-
- count = 0
- for directory in directories:
- count += 1
- self._directories[directory.id] = directory
- self._objects[directory.id].append(("directory", directory.id))
-
- return {"directory:add": count}
-
- def directory_missing(self, directories):
- for id in directories:
- if id not in self._directories:
- yield id
-
- def _join_dentry_to_content(self, dentry):
- keys = (
- "status",
- "sha1",
- "sha1_git",
- "sha256",
- "length",
- )
- ret = dict.fromkeys(keys)
- ret.update(dentry)
- if ret["type"] == "file":
- # TODO: Make it able to handle more than one content
- content = self.content_find({"sha1_git": ret["target"]})
- if content:
- content = content[0]
- for key in keys:
- ret[key] = content[key]
- return ret
-
- def _directory_ls(self, directory_id, recursive, prefix=b""):
- if directory_id in self._directories:
- for entry in self._directories[directory_id].entries:
- ret = self._join_dentry_to_content(entry.to_dict())
- ret["name"] = prefix + ret["name"]
- ret["dir_id"] = directory_id
- yield ret
- if recursive and ret["type"] == "dir":
- yield from self._directory_ls(
- ret["target"], True, prefix + ret["name"] + b"/"
- )
-
- def directory_ls(self, directory, recursive=False):
- yield from self._directory_ls(directory, recursive)
-
- def directory_entry_get_by_path(self, directory, paths):
- return self._directory_entry_get_by_path(directory, paths, b"")
-
- def directory_get_random(self):
- if not self._directories:
+ return content
+ else:
return None
- return random.choice(list(self._directories))
-
- def _directory_entry_get_by_path(self, directory, paths, prefix):
- if not paths:
- return
-
- contents = list(self.directory_ls(directory))
-
- if not contents:
- return
-
- def _get_entry(entries, name):
- for entry in entries:
- if entry["name"] == name:
- entry = entry.copy()
- entry["name"] = prefix + entry["name"]
- return entry
-
- first_item = _get_entry(contents, paths[0])
-
- if len(paths) == 1:
- return first_item
-
- if not first_item or first_item["type"] != "dir":
- return
-
- return self._directory_entry_get_by_path(
- first_item["target"], paths[1:], prefix + paths[0] + b"/"
- )
-
- def revision_add(self, revisions: Iterable[Revision]) -> Dict:
- revisions = [rev for rev in revisions if rev.id not in self._revisions]
- self.journal_writer.revision_add(revisions)
-
- count = 0
- for revision in revisions:
- revision = attr.evolve(
- revision,
- committer=self._person_add(revision.committer),
- author=self._person_add(revision.author),
- )
- self._revisions[revision.id] = revision
- self._objects[revision.id].append(("revision", revision.id))
- count += 1
-
- return {"revision:add": count}
-
- def revision_missing(self, revisions):
- for id in revisions:
- if id not in self._revisions:
- yield id
-
- def revision_get(self, revisions):
- for id in revisions:
- if id in self._revisions:
- yield self._revisions.get(id).to_dict()
- else:
- yield None
-
- def _get_parent_revs(self, rev_id, seen, limit):
- if limit and len(seen) >= limit:
- return
- if rev_id in seen or rev_id not in self._revisions:
- return
- seen.add(rev_id)
- yield self._revisions[rev_id].to_dict()
- for parent in self._revisions[rev_id].parents:
- yield from self._get_parent_revs(parent, seen, limit)
-
- def revision_log(self, revisions, limit=None):
- seen = set()
- for rev_id in revisions:
- yield from self._get_parent_revs(rev_id, seen, limit)
-
- def revision_shortlog(self, revisions, limit=None):
- yield from (
- (rev["id"], rev["parents"]) for rev in self.revision_log(revisions, limit)
- )
-
- def revision_get_random(self):
- return random.choice(list(self._revisions))
-
- def release_add(self, releases: Iterable[Release]) -> Dict:
- releases = [rel for rel in releases if rel.id not in self._releases]
- self.journal_writer.release_add(releases)
-
- count = 0
- for rel in releases:
- if rel.author:
- self._person_add(rel.author)
- self._objects[rel.id].append(("release", rel.id))
- self._releases[rel.id] = rel
- count += 1
-
- return {"release:add": count}
-
- def release_missing(self, releases):
- yield from (rel for rel in releases if rel not in self._releases)
-
- def release_get(self, releases):
- for rel_id in releases:
- if rel_id in self._releases:
- yield self._releases[rel_id].to_dict()
- else:
- yield None
- def release_get_random(self):
- return random.choice(list(self._releases))
+ def content_get_from_token(self, token: int) -> Iterator[Content]:
+ for content in self._content:
+ if self._content.get_row_hash(content) == token:
+ yield content
- def snapshot_add(self, snapshots: Iterable[Snapshot]) -> Dict:
- count = 0
- snapshots = (snap for snap in snapshots if snap.id not in self._snapshots)
- for snapshot in snapshots:
- self.journal_writer.snapshot_add([snapshot])
- sorted_branch_names = sorted(snapshot.branches)
- self._snapshots[snapshot.id] = (snapshot, sorted_branch_names)
- self._objects[snapshot.id].append(("snapshot", snapshot.id))
- count += 1
+ def content_get_random(self) -> Content:
+ return random.choice(list(self._content))
- return {"snapshot:add": count}
+ def content_get_token_range(
+ self, start: int, end: int, limit: int
+ ) -> Iterator[Content]:
+ for (i, content) in enumerate(self._content.iter_from((start,))):
+ if i >= limit or self._content.get_row_hash(content) >= end:
+ return
+ yield content
- def snapshot_missing(self, snapshots):
- for id in snapshots:
- if id not in self._snapshots:
- yield id
+ def content_missing_per_sha1_git(self, ids: List[bytes]) -> Iterable[bytes]:
+ missing = set(ids)
+ for content in self._content:
+ # discard(), not remove(): the table may hold contents that were not asked for
+ missing.discard(content.sha1_git)
+ return missing
- def snapshot_get(self, snapshot_id):
- return self.snapshot_get_branches(snapshot_id)
+ def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
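+ # No-op: the in-memory backend looks contents up by scanning the table,
+ # so no auxiliary index needs to be maintained.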
+ pass
- def snapshot_get_by_origin_visit(self, origin, visit):
- origin_url = self._get_origin_url(origin)
- if not origin_url:
- return
+ def content_get_tokens_from_single_hash(
+ self, algo: str, hash_: bytes
+ ) -> Iterable[int]:
+ assert algo in HASH_ALGORITHMS
+ for content in self._content:
+ if getattr(content, algo) == hash_:
+ yield self._content.get_row_hash(content)
- if origin_url not in self._origins or visit > len(
- self._origin_visits[origin_url]
- ):
- return None
-
- visit = self._origin_visit_get_updated(origin_url, visit)
- snapshot_id = visit.snapshot
- if snapshot_id:
- return self.snapshot_get(snapshot_id)
- else:
- return None
- def snapshot_get_latest(self, origin, allowed_statuses=None):
- origin_url = self._get_origin_url(origin)
- if not origin_url:
- return
+class InMemoryStorage(CassandraStorage):
+ def __init__(self, journal_writer=None):
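+ # CassandraStorage.__init__ is not called here: reset() installs an
+ # in-memory runner in place of the Cassandra session.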
+ self.reset()
+ self.journal_writer = JournalWriter(journal_writer)
+ self.objstorage = ObjStorage({"cls": "memory", "args": {}})
- visit = self.origin_visit_get_latest(
- origin_url, allowed_statuses=allowed_statuses, require_snapshot=True
- )
- if visit and visit["snapshot"]:
- snapshot = self.snapshot_get(visit["snapshot"])
- if not snapshot:
- raise StorageArgumentException(
- "last origin visit references an unknown snapshot"
- )
- return snapshot
+ def reset(self):
+ self._cql_runner = self._runner = InMemRunner()
- def snapshot_count_branches(self, snapshot_id):
- (snapshot, _) = self._snapshots[snapshot_id]
- return collections.Counter(
- branch.target_type.value if branch else None
- for branch in snapshot.branches.values()
+ def content_get_range(self, start, end, limit=1000):
+ # TODO: remove this when swh-index stops using it
+ if limit is None:
+ raise StorageArgumentException("limit should not be None")
+ contents = sorted(
+ (cont for cont in self._runner._content if start <= cont.sha1 <= end),
+ key=lambda cont: cont.sha1,
)
-
- def snapshot_get_branches(
- self, snapshot_id, branches_from=b"", branches_count=1000, target_types=None
- ):
- res = self._snapshots.get(snapshot_id)
- if res is None:
- return None
- (snapshot, sorted_branch_names) = res
- from_index = bisect.bisect_left(sorted_branch_names, branches_from)
- if target_types:
- next_branch = None
- branches = {}
- for branch_name in sorted_branch_names[from_index:]:
- branch = snapshot.branches[branch_name]
- if branch and branch.target_type.value in target_types:
- if len(branches) < branches_count:
- branches[branch_name] = branch
- else:
- next_branch = branch_name
- break
+ if len(contents) <= limit:
+ next_content = None
else:
- # As there is no 'target_types', we can do that much faster
- to_index = from_index + branches_count
- returned_branch_names = sorted_branch_names[from_index:to_index]
- branches = {
- branch_name: snapshot.branches[branch_name]
- for branch_name in returned_branch_names
- }
- if to_index >= len(sorted_branch_names):
- next_branch = None
- else:
- next_branch = sorted_branch_names[to_index]
-
- branches = {
- name: branch.to_dict() if branch else None
- for (name, branch) in branches.items()
- }
-
+ # "next" is the sha1 at which the following page starts
+ next_content = contents[limit].sha1
+ contents = contents[:limit]
return {
- "id": snapshot_id,
- "branches": branches,
- "next_branch": next_branch,
- }
-
- def snapshot_get_random(self):
- return random.choice(list(self._snapshots))
-
- def object_find_by_sha1_git(self, ids):
- ret = {}
- for id_ in ids:
- objs = self._objects.get(id_, [])
- ret[id_] = [{"sha1_git": id_, "type": obj[0],} for obj in objs]
- return ret
-
- def _convert_origin(self, t):
- if t is None:
- return None
-
- return t.to_dict()
-
- def origin_get(self, origins):
- if isinstance(origins, dict):
- # Old API
- return_single = True
- origins = [origins]
- else:
- return_single = False
-
- # Sanity check to be error-compatible with the pgsql backend
- if any("id" in origin for origin in origins) and not all(
- "id" in origin for origin in origins
- ):
- raise StorageArgumentException(
- 'Either all origins or none at all should have an "id".'
- )
- if any("url" in origin for origin in origins) and not all(
- "url" in origin for origin in origins
- ):
- raise StorageArgumentException(
- "Either all origins or none at all should have " 'an "url" key.'
- )
-
- results = []
- for origin in origins:
- result = None
- if "url" in origin:
- if origin["url"] in self._origins:
- result = self._origins[origin["url"]]
- else:
- raise StorageArgumentException("Origin must have an url.")
- results.append(self._convert_origin(result))
-
- if return_single:
- assert len(results) == 1
- return results[0]
- else:
- return results
-
- def origin_get_by_sha1(self, sha1s):
- return [self._convert_origin(self._origins_by_sha1.get(sha1)) for sha1 in sha1s]
-
- def origin_get_range(self, origin_from=1, origin_count=100):
- origin_from = max(origin_from, 1)
- if origin_from <= len(self._origins_by_id):
- max_idx = origin_from + origin_count - 1
- if max_idx > len(self._origins_by_id):
- max_idx = len(self._origins_by_id)
- for idx in range(origin_from - 1, max_idx):
- origin = self._convert_origin(self._origins[self._origins_by_id[idx]])
- yield {"id": idx + 1, **origin}
-
- def origin_list(self, page_token: Optional[str] = None, limit: int = 100) -> dict:
- origin_urls = sorted(self._origins)
- if page_token:
- from_ = bisect.bisect_left(origin_urls, page_token)
- else:
- from_ = 0
-
- result = {
- "origins": [
- {"url": origin_url} for origin_url in origin_urls[from_ : from_ + limit]
- ]
+ "contents": matched,
+ "next": next_content,
}
- if from_ + limit < len(origin_urls):
- result["next_page_token"] = origin_urls[from_ + limit]
-
- return result
-
- def origin_search(
- self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False
- ):
- origins = map(self._convert_origin, self._origins.values())
- if regexp:
- pat = re.compile(url_pattern)
- origins = [orig for orig in origins if pat.search(orig["url"])]
- else:
- origins = [orig for orig in origins if url_pattern in orig["url"]]
- if with_visit:
- filtered_origins = []
- for orig in origins:
- visits = (
- self._origin_visit_get_updated(ov.origin, ov.visit)
- for ov in self._origin_visits[orig["url"]]
- )
- for ov in visits:
- if ov.snapshot and ov.snapshot in self._snapshots:
- filtered_origins.append(orig)
- break
- else:
- filtered_origins = origins
-
- return filtered_origins[offset : offset + limit]
-
- def origin_count(self, url_pattern, regexp=False, with_visit=False):
- return len(
- self.origin_search(
- url_pattern,
- regexp=regexp,
- with_visit=with_visit,
- limit=len(self._origins),
- )
- )
-
- def origin_add(self, origins: Iterable[Origin]) -> List[Dict]:
- origins = copy.deepcopy(list(origins))
- for origin in origins:
- self.origin_add_one(origin)
- return [origin.to_dict() for origin in origins]
-
- def origin_add_one(self, origin: Origin) -> str:
- if origin.url not in self._origins:
- self.journal_writer.origin_add([origin])
- # generate an origin_id because it is needed by origin_get_range.
- # TODO: remove this when we remove origin_get_range
- origin_id = len(self._origins) + 1
- self._origins_by_id.append(origin.url)
- assert len(self._origins_by_id) == origin_id
-
- self._origins[origin.url] = origin
- self._origins_by_sha1[origin_url_to_sha1(origin.url)] = origin
- self._origin_visits[origin.url] = []
- self._objects[origin.url].append(("origin", origin.url))
-
- return origin.url
-
- def origin_visit_add(
- self, origin_url: str, date: Union[str, datetime.datetime], type: str
- ) -> OriginVisit:
- if isinstance(date, str):
- # FIXME: Converge on iso8601 at some point
- date = dateutil.parser.parse(date)
- elif not isinstance(date, datetime.datetime):
- raise StorageArgumentException("Date must be a datetime or a string")
-
- origin = self.origin_get({"url": origin_url})
- if not origin: # Cannot add a visit without an origin
- raise StorageArgumentException("Unknown origin %s", origin_url)
-
- if origin_url in self._origins:
- origin = self._origins[origin_url]
- # visit ids are in the range [1, +inf[
- visit_id = len(self._origin_visits[origin_url]) + 1
- status = "ongoing"
- with convert_validation_exceptions():
- visit = OriginVisit(
- origin=origin_url,
- date=date,
- type=type,
- # TODO: Remove when we remove those fields from the model
- status=status,
- snapshot=None,
- metadata=None,
- visit=visit_id,
- )
- self._origin_visits[origin_url].append(visit)
- assert visit.visit is not None
- visit_key = (origin_url, visit.visit)
-
- with convert_validation_exceptions():
- visit_update = OriginVisitStatus(
- origin=origin_url,
- visit=visit_id,
- date=date,
- status=status,
- snapshot=None,
- metadata=None,
- )
- self._origin_visit_statuses[visit_key] = [visit_update]
-
- self._objects[visit_key].append(("origin_visit", None))
-
- self.journal_writer.origin_visit_add([visit])
-
- # return last visit
- return visit
-
- def origin_visit_update(
- self,
- origin: str,
- visit_id: int,
- status: str,
- metadata: Optional[Dict] = None,
- snapshot: Optional[bytes] = None,
- date: Optional[datetime.datetime] = None,
- ):
- origin_url = self._get_origin_url(origin)
- if origin_url is None:
- raise StorageArgumentException("Unknown origin.")
-
- try:
- visit = self._origin_visits[origin_url][visit_id - 1]
- except IndexError:
- raise StorageArgumentException("Unknown visit_id for this origin") from None
-
- # Retrieve the previous visit status
- assert visit.visit is not None
- visit_key = (origin_url, visit.visit)
-
- last_visit_update = max(
- self._origin_visit_statuses[visit_key], key=lambda v: v.date
- )
-
- with convert_validation_exceptions():
- visit_update = OriginVisitStatus(
- origin=origin_url,
- visit=visit_id,
- date=date or now(),
- status=status,
- snapshot=snapshot or last_visit_update.snapshot,
- metadata=metadata or last_visit_update.metadata,
- )
- self._origin_visit_statuses[visit_key].append(visit_update)
-
- self.journal_writer.origin_visit_update(
- [self._origin_visit_get_updated(origin_url, visit_id)]
- )
-
- self._origin_visits[origin_url][visit_id - 1] = visit
-
- def origin_visit_upsert(self, visits: Iterable[OriginVisit]) -> None:
- for visit in visits:
- if visit.visit is None:
- raise StorageArgumentException(f"Missing visit id for visit {visit}")
-
- self.journal_writer.origin_visit_upsert(visits)
-
- date = now()
-
- for visit in visits:
- assert visit.visit is not None
- origin_url = visit.origin
- origin = self.origin_get({"url": origin_url})
-
- if not origin: # Cannot add a visit without an origin
- raise StorageArgumentException("Unknown origin %s", origin_url)
-
- if origin_url in self._origins:
- origin = self._origins[origin_url]
- # visit ids are in the range [1, +inf[
- assert visit.visit is not None
- visit_key = (origin_url, visit.visit)
-
- with convert_validation_exceptions():
- visit_update = OriginVisitStatus(
- origin=origin_url,
- visit=visit.visit,
- date=date,
- status=visit.status,
- snapshot=visit.snapshot,
- metadata=visit.metadata,
- )
-
- self._origin_visit_statuses.setdefault(visit_key, [])
- while len(self._origin_visits[origin_url]) <= visit.visit:
- self._origin_visits[origin_url].append(None)
-
- self._origin_visits[origin_url][visit.visit - 1] = visit
- self._origin_visit_statuses[visit_key].append(visit_update)
-
- self._objects[visit_key].append(("origin_visit", None))
-
- def _origin_visit_get_updated(self, origin: str, visit_id: int) -> OriginVisit:
- """Merge origin visit and latest origin visit status
-
- """
- assert visit_id >= 1
- visit = self._origin_visits[origin][visit_id - 1]
- assert visit is not None
- visit_key = (origin, visit_id)
-
- visit_update = max(self._origin_visit_statuses[visit_key], key=lambda v: v.date)
-
- return OriginVisit.from_dict(
- {
- # default to the values in visit
- **visit.to_dict(),
- # override with the last update
- **visit_update.to_dict(),
- # but keep the date of the creation of the origin visit
- "date": visit.date,
- }
- )
-
- def origin_visit_get(
- self, origin: str, last_visit: Optional[int] = None, limit: Optional[int] = None
- ) -> Iterable[Dict[str, Any]]:
- origin_url = self._get_origin_url(origin)
- if origin_url in self._origin_visits:
- visits = self._origin_visits[origin_url]
- if last_visit is not None:
- visits = visits[last_visit:]
- if limit is not None:
- visits = visits[:limit]
- for visit in visits:
- if not visit:
- continue
- visit_id = visit.visit
-
- visit_update = self._origin_visit_get_updated(origin_url, visit_id)
- assert visit_update is not None
- yield visit_update.to_dict()
-
- def origin_visit_find_by_date(
- self, origin: str, visit_date: datetime.datetime
- ) -> Optional[Dict[str, Any]]:
- origin_url = self._get_origin_url(origin)
- if origin_url in self._origin_visits:
- visits = self._origin_visits[origin_url]
- visit = min(visits, key=lambda v: (abs(v.date - visit_date), -v.visit))
- visit_update = self._origin_visit_get_updated(origin, visit.visit)
- assert visit_update is not None
- return visit_update.to_dict()
- return None
-
- def origin_visit_get_by(self, origin: str, visit: int) -> Optional[Dict[str, Any]]:
- origin_url = self._get_origin_url(origin)
- if origin_url in self._origin_visits and visit <= len(
- self._origin_visits[origin_url]
- ):
- visit_update = self._origin_visit_get_updated(origin_url, visit)
- assert visit_update is not None
- return visit_update.to_dict()
- return None
-
- def origin_visit_get_latest(
- self,
- origin: str,
- allowed_statuses: Optional[List[str]] = None,
- require_snapshot: bool = False,
- ) -> Optional[Dict[str, Any]]:
- ori = self._origins.get(origin)
- if not ori:
- return None
- visits = self._origin_visits[ori.url]
- visits = [
- self._origin_visit_get_updated(visit.origin, visit.visit)
- for visit in visits
- if visit is not None
- ]
-
- if allowed_statuses is not None:
- visits = [visit for visit in visits if visit.status in allowed_statuses]
- if require_snapshot:
- visits = [visit for visit in visits if visit.snapshot]
-
- visit = max(visits, key=lambda v: (v.date, v.visit), default=None)
- if visit is None:
- return None
- return visit.to_dict()
-
- def _select_random_origin_visit_by_type(self, type: str) -> str:
- while True:
- url = random.choice(list(self._origin_visits.keys()))
- random_origin_visits = self._origin_visits[url]
- if random_origin_visits[0].type == type:
- return url
-
- def origin_visit_get_random(self, type: str) -> Optional[Dict[str, Any]]:
- url = self._select_random_origin_visit_by_type(type)
- random_origin_visits = copy.deepcopy(self._origin_visits[url])
- random_origin_visits.reverse()
- back_in_the_day = now() - timedelta(weeks=12) # 3 months back
- # This should be enough for tests
- for visit in random_origin_visits:
- updated_visit = self._origin_visit_get_updated(url, visit.visit)
- assert updated_visit is not None
- if updated_visit.date > back_in_the_day and updated_visit.status == "full":
- return updated_visit.to_dict()
- else:
- return None
-
def stat_counters(self):
- keys = (
- "content",
- "directory",
- "origin",
- "origin_visit",
- "person",
- "release",
- "revision",
- "skipped_content",
- "snapshot",
- )
- stats = {key: 0 for key in keys}
- stats.update(
- collections.Counter(
- obj_type
- for (obj_type, obj_id) in itertools.chain(*self._objects.values())
- )
- )
- return stats
-
- def refresh_stat_counters(self):
- pass
-
- def origin_metadata_add(
- self,
- origin_url: str,
- discovery_date: datetime.datetime,
- authority: Dict[str, Any],
- fetcher: Dict[str, Any],
- format: str,
- metadata: bytes,
- ) -> None:
- if not isinstance(origin_url, str):
- raise StorageArgumentException(
- "origin_id must be str, not %r" % (origin_url,)
- )
- if not isinstance(metadata, bytes):
- raise StorageArgumentException(
- "metadata must be bytes, not %r" % (metadata,)
- )
- authority_key = self._metadata_authority_key(authority)
- if authority_key not in self._metadata_authorities:
- raise StorageArgumentException(f"Unknown authority {authority}")
- fetcher_key = self._metadata_fetcher_key(fetcher)
- if fetcher_key not in self._metadata_fetchers:
- raise StorageArgumentException(f"Unknown fetcher {fetcher}")
-
- origin_metadata = {
- "origin_url": origin_url,
- "discovery_date": discovery_date,
- "authority": authority_key,
- "fetcher": fetcher_key,
- "format": format,
- "metadata": metadata,
- }
- self._origin_metadata[origin_url][authority_key].add(origin_metadata)
- return None
-
- def origin_metadata_get(
- self,
- origin_url: str,
- authority: Dict[str, str],
- after: Optional[datetime.datetime] = None,
- limit: Optional[int] = None,
- ) -> List[Dict[str, Any]]:
- if not isinstance(origin_url, str):
- raise TypeError("origin_url must be str, not %r" % (origin_url,))
-
- authority_key = self._metadata_authority_key(authority)
-
- if after is None:
- entries = iter(self._origin_metadata[origin_url][authority_key])
- else:
- entries = self._origin_metadata[origin_url][authority_key].iter_from(after)
- if limit:
- entries = itertools.islice(entries, 0, limit)
-
- results = []
- for entry in entries:
- authority = self._metadata_authorities[entry["authority"]]
- fetcher = self._metadata_fetchers[entry["fetcher"]]
- results.append(
- {
- **entry,
- "authority": {"type": authority["type"], "url": authority["url"],},
- "fetcher": {
- "name": fetcher["name"],
- "version": fetcher["version"],
- },
- }
- )
- return results
-
- def metadata_fetcher_add(
- self, name: str, version: str, metadata: Dict[str, Any]
- ) -> None:
- fetcher = {
- "name": name,
- "version": version,
- "metadata": metadata,
- }
- key = self._metadata_fetcher_key(fetcher)
- if key not in self._metadata_fetchers:
- self._metadata_fetchers[key] = fetcher
-
- def metadata_fetcher_get(self, name: str, version: str) -> Optional[Dict[str, Any]]:
- return self._metadata_fetchers.get(
- self._metadata_fetcher_key({"name": name, "version": version})
- )
-
- def metadata_authority_add(
- self, type: str, url: str, metadata: Dict[str, Any]
- ) -> None:
- authority = {
- "type": type,
- "url": url,
- "metadata": metadata,
+ return {
+ "content": len(self._runner._content),
}
- key = self._metadata_authority_key(authority)
- self._metadata_authorities[key] = authority
-
- def metadata_authority_get(self, type: str, url: str) -> Optional[Dict[str, Any]]:
- return self._metadata_authorities.get(
- self._metadata_authority_key({"type": type, "url": url})
- )
-
- def _get_origin_url(self, origin):
- if isinstance(origin, str):
- return origin
- else:
- raise TypeError("origin must be a string.")
-
- def _person_add(self, person):
- key = ("person", person.fullname)
- if key not in self._objects:
- person_id = len(self._persons) + 1
- self._persons.append(person)
- self._objects[key].append(("person", person_id))
- else:
- person_id = self._objects[key][0][1]
- person = self._persons[person_id - 1]
- return person
-
- @staticmethod
- def _content_key(content):
- """ A stable key and the algorithm for a content"""
- if isinstance(content, BaseContent):
- content = content.to_dict()
- return tuple((key, content.get(key)) for key in sorted(DEFAULT_ALGORITHMS))
-
- @staticmethod
- def _metadata_fetcher_key(fetcher: Dict) -> Hashable:
- return (fetcher["name"], fetcher["version"])
-
- @staticmethod
- def _metadata_authority_key(authority: Dict) -> Hashable:
- return (authority["type"], authority["url"])
-
- def diff_directories(self, from_dir, to_dir, track_renaming=False):
- raise NotImplementedError("InMemoryStorage.diff_directories")
-
- def diff_revisions(self, from_rev, to_rev, track_renaming=False):
- raise NotImplementedError("InMemoryStorage.diff_revisions")
-
- def diff_revision(self, revision, track_renaming=False):
- raise NotImplementedError("InMemoryStorage.diff_revision")
-
- def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None:
- """Do nothing
-
- """
- return None
-
- def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict:
- return {}
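
For context, the runner introduced above mirrors Cassandra's data model: every Table orders its rows by a token obtained by hashing the primary key, and lookups either match all primary-key hashes or walk the token order. The following is a rough usage sketch, not part of the diff; it assumes the classes are importable from swh.storage.in_memory and that swh.model's Content.from_data() and hashes() helpers behave as described.

    from swh.model.model import Content
    from swh.storage.in_memory import InMemRunner

    runner = InMemRunner()

    # Insertion is split in two steps, as in the Cassandra backend: compute the
    # token first, then call the finalizer to actually store the row.
    cont = Content.from_data(b"hello world")
    token, finalizer = runner.content_add_prepare(cont)
    finalizer()

    # Lookup by full primary key (all four hashes must match)...
    assert runner.content_get_from_pk(cont.hashes()) == cont

    # ...or by token, the way the Cassandra partitioner addresses rows.
    assert list(runner.content_get_from_token(token)) == [cont]
    assert list(runner.content_get_tokens_from_single_hash("sha1", cont.sha1)) == [token]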
