diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/in_memory.py
@@ -0,0 +1,1167 @@
+# Copyright (C) 2015-2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import re
+import bisect
+import dateutil.parser
+import collections
+from collections import defaultdict
+import copy
+import datetime
+import itertools
+import random
+import warnings
+
+from swh.model.hashutil import DEFAULT_ALGORITHMS
+from swh.model.identifiers import normalize_timestamp
+
+
+def now():
+    return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+OriginVisitKey = collections.namedtuple('OriginVisitKey', 'origin date')
+
+
+class Storage:
+    def __init__(self):
+        self._contents = {}
+        self._contents_data = {}
+        self._content_indexes = defaultdict(lambda: defaultdict(set))
+
+        self._directories = {}
+        self._revisions = {}
+        self._releases = {}
+        self._snapshots = {}
+        self._origins = {}
+        self._origin_visits = {}
+        self._origin_metadata = defaultdict(list)
+        self._tools = {}
+        self._metadata_providers = {}
+        self._objects = defaultdict(list)
+
+    def check_config(self, *, check_write):
+        """Check that the storage is configured and ready to go."""
+        return True
+
+    def content_add(self, contents):
+        """Add content blobs to the storage
+
+        Note: in case of DB errors, objects might have already been added to
+        the object storage and will not be removed. Since addition to the
+        object storage is idempotent, that should not be a problem.
+
+        Args:
+            contents (iterable): iterable of dictionaries representing
+                individual pieces of content to add. Each dictionary has the
+                following keys:
+
+                - data (bytes): the actual content
+                - length (int): content length (default: -1)
+                - one key for each checksum algorithm in
+                  :data:`swh.model.hashutil.DEFAULT_ALGORITHMS`, mapped to the
+                  corresponding checksum
+                - status (str): one of visible, hidden, absent
+                - reason (str): if status = absent, the reason why
+                - origin (int): if status = absent, the origin we saw the
+                  content in
+
+        """
+        for content in contents:
+            key = self._content_key(content)
+            if key in self._contents:
+                continue
+            for algorithm in DEFAULT_ALGORITHMS:
+                if content[algorithm] in self._content_indexes[algorithm]:
+                    from . import HashCollision
+                    raise HashCollision(algorithm, content[algorithm], key)
+            for algorithm in DEFAULT_ALGORITHMS:
+                self._content_indexes[algorithm][content[algorithm]].add(key)
+            self._objects[content['sha1_git']].append(
+                ('content', content['sha1']))
+            self._contents[key] = copy.deepcopy(content)
+            self._contents[key]['ctime'] = now()
+            if self._contents[key]['status'] == 'visible':
+                self._contents_data[key] = self._contents[key].pop('data')
+
+    def content_get_metadata(self, sha1s):
+        """Retrieve content metadata in bulk
+
+        Args:
+            sha1s: iterable of content identifiers (sha1)
+
+        Returns:
+            an iterable with content metadata corresponding to the given ids
+        """
+        # FIXME: the return value should be a mapping from search key to found
+        # content*s*
+        for sha1 in sha1s:
+            if sha1 in self._content_indexes['sha1']:
+                objs = self._content_indexes['sha1'][sha1]
+                # random.sample() rejects sets since Python 3.11
+                key = random.choice(list(objs))
+                data = copy.deepcopy(self._contents[key])
+                data.pop('ctime')
+                yield data
+            else:
+                # FIXME: should really be None
+                yield {
+                    'sha1': sha1,
+                    'sha1_git': None,
+                    'sha256': None,
+                    'blake2s256': None,
+                    'length': None,
+                    'status': None,
+                }
+
+    def content_find(self, content):
+        """Find a content from its checksums
+
+        Args:
+            content (dict): a dictionary mapping one or more of the
+                checksum algorithms of
+                :data:`swh.model.hashutil.DEFAULT_ALGORITHMS` to the
+                corresponding checksum
+
+        Returns:
+            dict: a copy of a content matching all the given
+            checksums, or None if there is none
+
+        Raises:
+            ValueError: if no key of `content` is a checksum algorithm
+
+        """
+        if not set(content).intersection(DEFAULT_ALGORITHMS):
+            raise ValueError('content keys must contain at least one of: '
+                             '%s' % ', '.join(sorted(DEFAULT_ALGORITHMS)))
+        found = []
+        for algo in DEFAULT_ALGORITHMS:
+            hash = content.get(algo)
+            if hash:
+                # an unknown checksum matches no content at all
+                found.append(self._content_indexes[algo].get(hash, set()))
+        if not found:
+            return None
+        keys = set.intersection(*found)
+        if not keys:
+            return None
+
+        # FIXME: should really be a list of all the objects found
+        return copy.deepcopy(self._contents[next(iter(keys))])
+
+    def content_missing(self, contents, key_hash='sha1'):
+        """List content missing from storage
+
+        Args:
+            contents ([dict]): iterable of dictionaries containing one
+                               key for each checksum algorithm in
+                               :data:`swh.model.hashutil.ALGORITHMS`,
+                               mapped to the corresponding checksum,
+                               and a length key mapped to the content
+                               length.
+
+            key_hash (str): name of the column to use as hash id
+                            result (default: 'sha1')
+
+        Returns:
+            iterable ([bytes]): missing content ids (as per the
+            key_hash column)
+        """
+        for content in contents:
+            if self._content_key(content) not in self._contents:
+                yield content[key_hash]
+
+    def content_missing_per_sha1(self, contents):
+        """List content missing from storage based only on sha1.
+
+        Args:
+            contents: Iterable of sha1 to check for absence.
+
+        Returns:
+            iterable: missing ids
+
+        Raises:
+            TODO: an exception when we get a hash collision.
+
+        """
+        for content in contents:
+            if content not in self._content_indexes['sha1']:
+                yield content
+
+    def directory_add(self, directories):
+        """Add directories to the storage
+
+        Args:
+            directories (iterable): iterable of dictionaries representing the
+                individual directories to add. Each dict has the following
+                keys:
+
+                - id (sha1_git): the id of the directory to add
+                - entries (list): list of dicts for each entry in the
+                      directory.  Each dict has the following keys:
+
+                      - name (bytes)
+                      - type (one of 'file', 'dir', 'rev'): type of the
+                        directory entry (file, directory, revision)
+                      - target (sha1_git): id of the object pointed at by the
+                        directory entry
+                      - perms (int): entry permissions
+        """
+        for directory in directories:
+            if directory['id'] not in self._directories:
+                self._directories[directory['id']] = copy.deepcopy(directory)
+                self._objects[directory['id']].append(
+                    ('directory', directory['id']))
+
+    def directory_missing(self, directory_ids):
+        """List directories missing from storage
+
+        Args:
+            directories (iterable): an iterable of directory ids
+
+        Yields:
+            missing directory ids
+
+        """
+        for id in directory_ids:
+            if id not in self._directories:
+                yield id
+
+    def _join_dentry_to_content(self, dentry):
+        """Return a new dict made of `dentry` plus the status, checksums
+        and length of the content it points to (all None if the entry is
+        not a file or the content is unknown)."""
+        keys = (
+            'status',
+            'sha1',
+            'sha1_git',
+            'sha256',
+            'length',
+        )
+        ret = dict.fromkeys(keys)
+        ret.update(dentry)
+        if ret['type'] == 'file':
+            content = self.content_find({'sha1_git': ret['target']})
+            if content:
+                for key in keys:
+                    ret[key] = content[key]
+        return ret
+
+    def directory_ls(self, directory_id):
+        """Get entries for one directory.
+
+        Args:
+            - directory_id: the directory to list entries from.
+
+        Yields:
+            entries of such directory, each augmented with the status,
+            checksums and length of the content it points to, and a
+            ``dir_id`` key. Nothing is yielded for an unknown directory.
+
+        """
+        if directory_id in self._directories:
+            for entry in self._directories[directory_id]['entries']:
+                ret = self._join_dentry_to_content(entry)
+                ret['dir_id'] = directory_id
+                yield ret
+
+    def directory_entry_get_by_path(self, directory, paths):
+        """Get the directory entry (either file or dir) from directory with path.
+
+        Args:
+            - directory: sha1 of the top level directory
+            - paths: path to lookup from the top level directory. From left
+              (top) to right (bottom).
+
+        Returns:
+            The corresponding directory entry if found, None otherwise.
+
+        """
+        if not paths:
+            return
+
+        contents = list(self.directory_ls(directory))
+
+        if not contents:
+            return
+
+        def _get_entry(entries, name):
+            for entry in entries:
+                if entry['name'] == name:
+                    return entry
+
+        first_item = _get_entry(contents, paths[0])
+
+        if len(paths) == 1:
+            return first_item
+
+        if not first_item or first_item['type'] != 'dir':
+            return
+
+        return self.directory_entry_get_by_path(
+                first_item['target'], paths[1:])
+
+    def revision_add(self, revisions):
+        """Add revisions to the storage
+
+        Args:
+            revisions (iterable): iterable of dictionaries representing the
+                individual revisions to add. Each dict has the following keys:
+
+                - id (sha1_git): id of the revision to add
+                - date (datetime.DateTime): date the revision was written
+                - date_offset (int): offset from UTC in minutes the revision
+                  was written
+                - date_neg_utc_offset (boolean): whether a null date_offset
+                  represents a negative UTC offset
+                - committer_date (datetime.DateTime): date the revision got
+                  added to the origin
+                - committer_date_offset (int): offset from UTC in minutes the
+                  revision was added to the origin
+                - committer_date_neg_utc_offset (boolean): whether a null
+                  committer_date_offset represents a negative UTC offset
+                - type (one of 'git', 'tar'): type of the revision added
+                - directory (sha1_git): the directory the revision points at
+                - message (bytes): the message associated with the revision
+                - author_name (bytes): the name of the revision author
+                - author_email (bytes): the email of the revision author
+                - committer_name (bytes): the name of the revision committer
+                - committer_email (bytes): the email of the revision committer
+                - metadata (jsonb): extra information as dictionary
+                - synthetic (bool): revision's nature (tarball, directory
+                  creates synthetic revision)
+                - parents (list of sha1_git): the parents of this revision
+
+        """
+        for revision in revisions:
+            if revision['id'] not in self._revisions:
+                self._revisions[revision['id']] = rev = copy.deepcopy(revision)
+                rev['date'] = normalize_timestamp(rev.get('date'))
+                rev['committer_date'] = normalize_timestamp(
+                    rev.get('committer_date'))
+                self._objects[revision['id']].append(
+                    ('revision', revision['id']))
+
+    def revision_missing(self, revision_ids):
+        """List revisions missing from storage
+
+        Args:
+            revisions (iterable): revision ids
+
+        Yields:
+            missing revision ids
+
+        """
+        for id in revision_ids:
+            if id not in self._revisions:
+                yield id
+
+    def revision_get(self, revision_ids):
+        """Get revisions from their ids
+
+        Args:
+            revision_ids (iterable): revision ids
+
+        Yields:
+            a copy of each revision, or None for unknown ids
+
+        """
+        for id in revision_ids:
+            yield copy.deepcopy(self._revisions.get(id))
+
+    def _get_parent_revs(self, rev_id, seen, limit):
+        """Yield copies of revision `rev_id` and of its ancestors, depth-first
+
+        Revisions already in `seen` are skipped, and each yielded revision is
+        added to `seen`; nothing more is yielded once `seen` holds `limit`
+        revisions (if `limit` is set). Unknown revisions are skipped.
+
+        An explicit stack is used instead of recursion, so that long
+        histories do not hit Python's recursion limit.
+        """
+        stack = [rev_id]
+        while stack:
+            if limit and len(seen) >= limit:
+                return
+            rev_id = stack.pop()
+            if rev_id in seen or rev_id not in self._revisions:
+                continue
+            seen.add(rev_id)
+            revision = self._revisions[rev_id]
+            yield copy.deepcopy(revision)
+            # push parents in reverse order so that they are popped in order
+            stack.extend(reversed(revision['parents']))
+
+    def revision_log(self, revision_ids, limit=None):
+        """Fetch revision entry from the given root revisions.
+
+        Args:
+            revisions: array of root revision to lookup
+            limit: limitation on the output result. Default to None.
+
+        Yields:
+            List of revision log from such revisions root.
+
+        """
+        seen = set()
+        for rev_id in revision_ids:
+            yield from self._get_parent_revs(rev_id, seen, limit)
+
+    def revision_shortlog(self, revisions, limit=None):
+        """Fetch the shortlog for the given revisions
+
+        Args:
+            revisions: list of root revisions to lookup
+            limit: depth limitation for the output
+
+        Yields:
+            a list of (id, parents) tuples.
+
+        """
+        yield from ((rev['id'], rev['parents'])
+                    for rev in self.revision_log(revisions, limit))
+
+    def release_add(self, releases):
+        """Add releases to the storage
+
+        Args:
+            releases (iterable): iterable of dictionaries representing the
+                individual releases to add. Each dict has the following keys:
+
+                - id (sha1_git): id of the release to add
+                - revision (sha1_git): id of the revision the release points to
+                - date (datetime.DateTime): the date the release was made
+                - date_offset (int): offset from UTC in minutes the release was
+                  made
+                - date_neg_utc_offset (boolean): whether a null date_offset
+                  represents a negative UTC offset
+                - name (bytes): the name of the release
+                - comment (bytes): the comment associated with the release
+                - author_name (bytes): the name of the release author
+                - author_email (bytes): the email of the release author
+
+        """
+        for rel in releases:
+            if rel['id'] in self._releases:
+                continue
+            # work on a copy, so that the caller's dict is left untouched
+            rel = copy.deepcopy(rel)
+            rel['date'] = normalize_timestamp(rel['date'])
+            self._releases[rel['id']] = rel
+            self._objects[rel['id']].append(
+                ('release', rel['id']))
+
+    def release_missing(self, releases):
+        """List releases missing from storage
+
+        Args:
+            releases: an iterable of release ids
+
+        Returns:
+            a list of missing release ids
+
+        """
+        yield from (rel for rel in releases if rel not in self._releases)
+
+    def release_get(self, releases):
+        """Given a list of sha1, return the releases' information
+
+        Args:
+            releases: list of sha1s
+
+        Yields:
+            dicts with the same keys as those given to `release_add`
+            (or None for unknown releases)
+
+        """
+        for rel_id in releases:
+            yield copy.deepcopy(self._releases.get(rel_id))
+
+    def snapshot_add(self, origin, visit, snapshot):
+        """Add a snapshot for the given origin/visit couple
+
+        Args:
+            origin (int): id of the origin
+            visit (int): id of the visit
+            snapshot (dict): the snapshot to add to the visit, containing the
+              following keys:
+
+              - **id** (:class:`bytes`): id of the snapshot
+              - **branches** (:class:`dict`): branches the snapshot contains,
+                mapping the branch name (:class:`bytes`) to the branch target,
+                itself a :class:`dict` (or ``None`` if the branch points to an
+                unknown object)
+
+                - **target_type** (:class:`str`): one of ``content``,
+                  ``directory``, ``revision``, ``release``,
+                  ``snapshot``, ``alias``
+                - **target** (:class:`bytes`): identifier of the target
+                  (currently a ``sha1_git`` for all object kinds, or the name
+                  of the target branch for aliases)
+        """
+        snapshot_id = snapshot['id']
+        if snapshot_id not in self._snapshots:
+            self._snapshots[snapshot_id] = {
+                'origin': origin,
+                'visit': visit,
+                'id': snapshot_id,
+                'branches': copy.deepcopy(snapshot['branches']),
+                '_sorted_branch_names': sorted(snapshot['branches'])
+                }
+        self._origin_visits[visit]['snapshot'] = snapshot_id
+
+    def snapshot_get(self, snapshot_id):
+        """Get the content, possibly partial, of a snapshot with the given id
+
+        The branches of the snapshot are iterated in the lexicographical
+        order of their names.
+
+        .. warning:: At most 1000 branches contained in the snapshot will be
+            returned for performance reasons. In order to browse the whole
+            set of branches, the method :meth:`snapshot_get_branches`
+            should be used instead.
+
+        Args:
+            snapshot_id (bytes): identifier of the snapshot
+        Returns:
+            dict: a dict with three keys:
+                * **id**: identifier of the snapshot
+                * **branches**: a dict of branches contained in the snapshot
+                  whose keys are the branches' names.
+                * **next_branch**: the name of the first branch not returned
+                  or :const:`None` if the snapshot has less than 1000
+                  branches.
+        """
+        return self.snapshot_get_branches(snapshot_id)
+
+    def snapshot_get_by_origin_visit(self, origin, visit):
+        """Get the content, possibly partial, of a snapshot for the given origin visit
+
+        The branches of the snapshot are iterated in the lexicographical
+        order of their names.
+
+        .. warning:: At most 1000 branches contained in the snapshot will be
+            returned for performance reasons. In order to browse the whole
+            set of branches, the method :meth:`snapshot_get_branches`
+            should be used instead.
+
+        Args:
+            origin (int): the origin identifier
+            visit (int): the visit identifier
+        Returns:
+            dict: None if the snapshot does not exist;
+              a dict with three keys otherwise:
+                * **id**: identifier of the snapshot
+                * **branches**: a dict of branches contained in the snapshot
+                  whose keys are the branches' names.
+                * **next_branch**: the name of the first branch not returned
+                  or :const:`None` if the snapshot has less than 1000
+                  branches.
+
+        """
+        if visit not in self._origin_visits:
+            return None
+        snapshot_id = self._origin_visits[visit]['snapshot']
+        if snapshot_id:
+            return self.snapshot_get(snapshot_id)
+        else:
+            return None
+
+    def snapshot_get_latest(self, origin, allowed_statuses=None):
+        """Get the content, possibly partial, of the latest snapshot for the
+        given origin, optionally only from visits that have one of the given
+        allowed_statuses
+
+        The branches of the snapshot are iterated in the lexicographical
+        order of their names.
+
+        .. warning:: At most 1000 branches contained in the snapshot will be
+            returned for performance reasons. In order to browse the whole
+            set of branches, the method :meth:`snapshot_get_branches`
+            should be used instead.
+
+        Args:
+            origin (int): the origin identifier
+            allowed_statuses (list of str): list of visit statuses considered
+                to find the latest snapshot for the visit. For instance,
+                ``allowed_statuses=['full']`` will only consider visits that
+                have successfully run to completion.
+        Returns:
+            dict: a dict with three keys:
+                * **id**: identifier of the snapshot
+                * **branches**: a dict of branches contained in the snapshot
+                  whose keys are the branches' names.
+                * **next_branch**: the name of the first branch not returned
+                  or :const:`None` if the snapshot has less than 1000
+                  branches.
+        """
+        if allowed_statuses is None:
+            visits_dates = list(itertools.chain(
+                *self._origins[origin]['visits_dates'].values()))
+        else:
+            last_visits = self._origins[origin]['visits_dates']
+            visits_dates = list(itertools.chain(
+                *map(last_visits.__getitem__, allowed_statuses)))
+
+        for visit_date in sorted(visits_dates, reverse=True):
+            visit_id = self._origin_visit_key(
+                {'origin': origin, 'date': visit_date})
+            snapshot_id = self._origin_visits[visit_id]['snapshot']
+            snapshot = self.snapshot_get(snapshot_id)
+            if snapshot:
+                return snapshot
+
+        return None
+
+    def snapshot_count_branches(self, snapshot_id, db=None, cur=None):
+        """Count the number of branches in the snapshot with the given id
+
+        Args:
+            snapshot_id (bytes): identifier of the snapshot
+
+        Returns:
+            dict: A dict whose keys are the target types of branches and
+            values their corresponding amount
+        """
+        branches = list(self._snapshots[snapshot_id]['branches'].values())
+        return collections.Counter(branch['target_type'] if branch else None
+                                   for branch in branches)
+
+    def snapshot_get_branches(self, snapshot_id, branches_from=b'',
+                              branches_count=1000, target_types=None):
+        """Get the content, possibly partial, of a snapshot with the given id
+
+        The branches of the snapshot are iterated in the lexicographical
+        order of their names.
+
+        Args:
+            snapshot_id (bytes): identifier of the snapshot
+            branches_from (bytes): optional parameter used to skip branches
+                whose name is lesser than it before returning them
+            branches_count (int): optional parameter used to restrain
+                the amount of returned branches
+            target_types (list): optional parameter used to filter the
+                target types of branch to return (possible values that can be
+                contained in that list are `'content', 'directory',
+                'revision', 'release', 'snapshot', 'alias'`)
+        Returns:
+            dict: None if the snapshot does not exist;
+              a dict with three keys otherwise:
+                * **id**: identifier of the snapshot
+                * **branches**: a dict of branches contained in the snapshot
+                  whose keys are the branches' names.
+                * **next_branch**: the name of the first branch not returned
+                  or :const:`None` if the snapshot has less than
+                  `branches_count` branches after `branches_from` included.
+        """
+        snapshot = self._snapshots.get(snapshot_id, None)
+        if snapshot is None:
+            return None
+        sorted_branch_names = snapshot['_sorted_branch_names']
+        from_index = bisect.bisect_left(
+            sorted_branch_names, branches_from)
+        if target_types:
+            next_branch = None
+            branches = {}
+            for branch_name in sorted_branch_names[from_index:]:
+                branch = snapshot['branches'][branch_name]
+                if branch and branch['target_type'] in target_types:
+                    if len(branches) < branches_count:
+                        branches[branch_name] = branch
+                    else:
+                        next_branch = branch_name
+                        break
+        else:
+            # As there is no 'target_types', we can do that much faster
+            to_index = from_index + branches_count
+            returned_branch_names = sorted_branch_names[from_index:to_index]
+            branches = {branch_name: snapshot['branches'][branch_name]
+                        for branch_name in returned_branch_names}
+            if to_index >= len(sorted_branch_names):
+                next_branch = None
+            else:
+                next_branch = sorted_branch_names[to_index]
+        return {
+                'id': snapshot_id,
+                'branches': branches,
+                'next_branch': next_branch,
+                }
+
+    def object_find_by_sha1_git(self, ids, db=None, cur=None):
+        """Return the objects found with the given ids.
+
+        Args:
+            ids: a generator of sha1_gits
+
+        Returns:
+            dict: a mapping from id to the list of objects found. Each object
+            found is itself a dict with keys:
+
+            - sha1_git: the input id
+            - type: the type of object found
+            - id: the id of the object found
+            - object_id: the numeric id of the object found.
+
+        """
+        ret = {}
+        for id_ in ids:
+            objs = self._objects.get(id_, [])
+            ret[id_] = [{
+                    'sha1_git': id_,
+                    'type': obj[0],
+                    'id': obj[1],
+                    'object_id': id_,
+                    } for obj in objs]
+        return ret
+
+    def origin_get(self, origin):
+        """Return the origin either identified by its id or its tuple
+        (type, url).
+
+        Args:
+            origin: dictionary representing the individual origin to find.
+                This dict has either the keys type and url:
+
+                - type (FIXME: enum TBD): the origin type ('git', 'wget', ...)
+                - url (bytes): the url the origin points to
+
+                or the id:
+
+                - id: the origin id
+
+        Returns:
+            dict: the origin dictionary with the keys:
+
+            - id: origin's id
+            - type: origin's type
+            - url: origin's url
+
+        Raises:
+            ValueError: if the keys does not match (url and type) nor id.
+
+        """
+        if 'id' in origin:
+            key = origin['id']
+        elif 'type' in origin and 'url' in origin:
+            key = self._origin_key(origin)
+        else:
+            raise ValueError('Origin must have either id or (type and url).')
+        if key not in self._origins:
+            return None
+        else:
+            origin = copy.deepcopy(self._origins[key])
+            del origin['visits_dates']
+            origin['id'] = self._origin_key(origin)
+            return origin
+
+    def origin_search(self, url_pattern, offset=0, limit=50,
+                      regexp=False, with_visit=False, db=None, cur=None):
+        """Search for origins whose urls contain a provided string pattern
+        or match a provided regular expression.
+        The search is performed in a case insensitive way.
+
+        Args:
+            url_pattern (str): the string pattern to search for in origin urls
+            offset (int): number of found origins to skip before returning
+                results
+            limit (int): the maximum number of found origins to return
+            regexp (bool): if True, consider the provided pattern as a regular
+                expression and return origins whose urls match it
+            with_visit (bool): if True, filter out origins with no visit
+
+        Returns:
+            An iterable of dict containing origin information as returned
+            by :meth:`swh.storage.storage.Storage.origin_get`.
+        """
+        origins = iter(self._origins.values())
+        # matching is case-insensitive, as documented above
+        if regexp:
+            pat = re.compile(url_pattern, re.IGNORECASE)
+            origins = (orig for orig in origins if pat.search(orig['url']))
+        else:
+            pattern = url_pattern.lower()
+            origins = (orig for orig in origins
+                       if pattern in orig['url'].lower())
+        if with_visit:
+            origins = (orig for orig in origins if orig['visits_dates'])
+        origins = sorted(origins, key=self._origin_key)
+        origins = copy.deepcopy(origins[offset:offset + limit])
+        for orig in origins:
+            del orig['visits_dates']
+            orig['id'] = self._origin_key(orig)
+        return origins
+
+    def origin_add(self, origins):
+        """Add origins to the storage
+
+        Args:
+            origins: list of dictionaries representing the individual origins,
+                with the following keys:
+
+                - type: the origin type ('git', 'svn', 'deb', ...)
+                - url (bytes): the url the origin points to
+
+        Returns:
+            list: given origins as dict updated with their id
+
+        """
+        origins = copy.deepcopy(origins)
+        for origin in origins:
+            origin['id'] = self.origin_add_one(origin)
+        return origins
+
+    def origin_add_one(self, origin):
+        """Add origin to the storage
+
+        Args:
+            origin: dictionary representing the individual origin to add. This
+                dict has the following keys:
+
+                - type (FIXME: enum TBD): the origin type ('git', 'wget', ...)
+                - url (bytes): the url the origin points to
+
+        Returns:
+            the id of the added origin, or of the identical one that already
+            exists.
+
+        """
+        origin = copy.deepcopy(origin)
+        assert 'id' not in origin
+        assert 'visits_dates' not in origin
+        key = self._origin_key(origin)
+        origin['visits_dates'] = defaultdict(set)
+        if key not in self._origins:
+            self._origins[key] = origin
+        return key
+
+    def origin_visit_add(self, origin, date=None, *, ts=None):
+        """Add an origin_visit for the origin at date with status 'ongoing'.
+
+        Args:
+            origin: Visited Origin id
+            date: timestamp of such visit
+
+        Returns:
+            dict: dictionary with keys origin and visit where:
+
+            - origin: origin identifier
+            - visit: the visit identifier for the new visit occurrence
+
+        """
+        if ts is None:
+            if date is None:
+                raise TypeError('origin_visit_add expected 2 arguments.')
+        else:
+            assert date is None
+            warnings.warn("argument 'ts' of origin_visit_add was renamed "
+                          "to 'date' in v0.0.109.",
+                          DeprecationWarning)
+            date = ts
+
+        if isinstance(date, str):
+            date = dateutil.parser.parse(date)
+
+        status = 'ongoing'
+
+        visit = {
+            'origin': origin,
+            'date': date,
+            'status': status,
+            'snapshot': None,
+            'metadata': None,
+        }
+        key = self._origin_visit_key(visit)
+        visit['visit'] = key
+        if key not in self._origin_visits:
+            self._origin_visits[key] = copy.deepcopy(visit)
+            self._origins[origin]['visits_dates'][status].add(date)
+
+        return {
+            'origin': visit['origin'],
+            'visit': key,
+        }
+
+    def origin_visit_update(self, origin, visit_id, status, metadata=None):
+        """Update an origin_visit's status.
+
+        Args:
+            origin: Visited Origin id
+            visit_id: Visit's id
+            status: Visit's new status
+            metadata: Data associated to the visit
+
+        Returns:
+            None
+
+        """
+        old_status = self._origin_visits[visit_id]['status']
+        self._origins[origin]['visits_dates'][old_status] \
+            .remove(visit_id.date)
+        self._origins[origin]['visits_dates'][status] \
+            .add(visit_id.date)
+        self._origin_visits[visit_id].update({
+            'status': status,
+            'metadata': metadata})
+
+    def origin_visit_get(self, origin, last_visit=None, limit=None):
+        """Retrieve all the origin's visit's information.
+
+        Args:
+            origin (int): The occurrence's origin (identifier).
+            last_visit: Starting point from which listing the next visits
+                Default to None
+            limit (int): Number of results to return from the last visit.
+                Default to None
+
+        Yields:
+            List of visits.
+
+        """
+        visits_dates = sorted(itertools.chain.from_iterable(
+            self._origins[origin]['visits_dates'].values()))
+        if last_visit is not None:
+            from_index = bisect.bisect_right(visits_dates, last_visit.date)
+            visits_dates = visits_dates[from_index:]
+        if limit is not None:
+            visits_dates = visits_dates[0:limit]
+        keys = (self._origin_visit_key({'origin': origin, 'date': date})
+                for date in visits_dates)
+        for key in keys:
+            yield copy.deepcopy(self._origin_visits[key])
+
+    def origin_visit_get_by(self, origin, visit):
+        """Retrieve origin visit's information.
+
+        Args:
+            origin: The occurrence's origin (identifier).
+            visit: The visit identifier (as returned by origin_visit_add).
+
+        Returns:
+            The information on that particular (origin, visit) or None if
+            it does not exist
+
+        """
+        return copy.deepcopy(self._origin_visits.get(visit, None))
+
+    def stat_counters(self):
+        """compute statistics about the number of tuples in various tables
+
+        Returns:
+            dict: a dictionary mapping textual labels (e.g., content) to
+            integer values (e.g., the number of tuples in table content)
+
+        """
+        keys = (
+            'content',
+            'directory',
+            'directory_entry_dir',
+            'directory_entry_file',
+            'directory_entry_rev',
+            'origin',
+            'origin_visit',
+            'person',
+            'release',
+            'revision',
+            'revision_history',
+            'skipped_content',
+            'snapshot'
+        )
+        stats = {key: 0 for key in keys}
+        # each value of self._objects is a list of (type, id) pairs
+        stats.update(collections.Counter(
+            obj_type
+            for objs in self._objects.values()
+            for (obj_type, _) in objs))
+        # origins and their visits are not recorded in self._objects
+        stats['origin'] = len(self._origins)
+        stats['origin_visit'] = len(self._origin_visits)
+        return stats
+
+    def refresh_stat_counters(self):
+        """Recomputes the statistics for `stat_counters`."""
+        pass
+
+    def origin_metadata_add(self, origin_id, ts, provider, tool, metadata,
+                            db=None, cur=None):
+        """ Add an origin_metadata for the origin at ts with provenance and
+        metadata.
+
+        Args:
+            origin_id: the origin's id for which the metadata is added
+            ts (datetime): timestamp of the found metadata
+            provider: id of the provider of metadata (ex:'hal')
+            tool: id of the tool used to extract metadata
+            metadata (jsonb): the metadata retrieved at the time and location
+        """
+        if isinstance(ts, str):
+            ts = dateutil.parser.parse(ts)
+
+        origin_metadata = {
+                'origin_id': origin_id,
+                'discovery_date': ts,
+                'tool_id': tool,
+                'metadata': metadata,
+                'provider_id': provider,
+                }
+        key = self._origin_metadata_key(origin_metadata)
+        self._origin_metadata[key].append(origin_metadata)
+        return None
+
+    def origin_metadata_get_by(self, origin_id, provider_type=None, db=None,
+                               cur=None):
+        """Retrieve list of all origin_metadata entries for the origin_id
+
+        Args:
+            origin_id (int): the unique origin identifier
+            provider_type (str): (optional) type of provider
+
+        Returns:
+            list of dicts: the origin_metadata dictionary with the keys:
+
+            - origin_id (int): origin's id
+            - discovery_date (datetime): timestamp of discovery
+            - tool_id (int): metadata's extracting tool
+            - metadata (jsonb)
+            - provider_id (int): metadata's provider
+            - provider_name (str)
+            - provider_type (str)
+            - provider_url (str)
+
+        """
+        metadata = []
+        key = self._origin_metadata_key({'origin_id': origin_id})
+        for item in self._origin_metadata[key]:
+            item = copy.deepcopy(item)
+            provider = self.metadata_provider_get(item['provider_id'])
+            if provider_type and provider['type'] != provider_type:
+                continue
+            for attr in ('name', 'type', 'url'):
+                item['provider_' + attr] = provider[attr]
+            metadata.append(item)
+        return metadata
+
+    def tool_add(self, tools):
+        """Add new tools to the storage.
+
+        Args:
+            tools (iterable of :class:`dict`): Tool information to add to
+              storage. Each tool is a :class:`dict` with the following keys:
+
+              - name (:class:`str`): name of the tool
+              - version (:class:`str`): version of the tool
+              - configuration (:class:`dict`): configuration of the tool,
+                must be json-encodable
+
+        Yields:
+            :class:`dict`: All the tools inserted in storage
+            (including the internal ``id``). The order of the list is not
+            guaranteed to match the order of the initial list.
+
+        """
+        inserted = []
+        for tool in tools:
+            key = self._tool_key(tool)
+            assert 'id' not in tool
+            record = copy.deepcopy(tool)
+            record['id'] = key  # TODO: remove this
+            if key not in self._tools:
+                self._tools[key] = record
+            inserted.append(copy.deepcopy(self._tools[key]))
+
+        yield from inserted
+
+    def tool_get(self, tool):
+        """Retrieve tool information.
+
+        Args:
+            tool (dict): Tool information we want to retrieve from storage.
+              The dicts have the same keys as those used in :func:`tool_add`.
+
+        Returns:
+            dict: The full tool information if it exists (``id`` included),
+            None otherwise.
+
+        """
+        return copy.deepcopy(self._tools.get(self._tool_key(tool), None))
+
+    def metadata_provider_add(self, provider_name, provider_type, provider_url,
+                              metadata):
+        """Add a metadata provider.
+
+        Args:
+            provider_name (str): Its name
+            provider_type (str): Its type
+            provider_url (str): Its URL
+            metadata: JSON-encodable object
+
+        Returns:
+            dict: same as args, plus an 'id' key.
+        """
+        provider = {
+                'name': provider_name,
+                'type': provider_type,
+                'url': provider_url,
+                'metadata': metadata,
+                }
+        key = self._metadata_provider_key(provider)
+        provider['id'] = key
+        self._metadata_providers[key] = provider
+        return provider.copy()
+
+    def metadata_provider_get(self, provider_id, db=None, cur=None):
+        """Get a metadata provider
+
+        Args:
+            provider_id: Its identifier, as given by `metadata_provider_add`.
+
+        Returns:
+            dict: same as `metadata_provider_add`;
+                  or None if it does not exist.
+        """
+        return self._metadata_providers.get(provider_id, None)
+
+    def metadata_provider_get_by(self, provider, db=None, cur=None):
+        """Get a metadata provider
+
+        Args:
+            provider_name: Its name
+            provider_url: Its URL
+
+        Returns:
+            dict: same as `metadata_provider_add`;
+                  or None if it does not exist.
+        """
+        key = self._metadata_provider_key({
+            'name': provider['provider_name'],
+            'url': provider['provider_url']})
+        return self._metadata_providers.get(key, None)
+
+    @staticmethod
+    def _content_key(content):
+        """A stable key for a content"""
+        return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS))
+
+    @staticmethod
+    def _origin_key(origin):
+        return (origin['type'], origin['url'])
+
+    @staticmethod
+    def _origin_visit_key(visit):
+        assert not isinstance(visit['origin'], dict), \
+            "visit['origin'] must be an origin key."
+        return OriginVisitKey(visit['origin'], visit['date'])
+
+    @staticmethod
+    def _origin_metadata_key(om):
+        return (om['origin_id'],)
+
+    @staticmethod
+    def _tool_key(tool):
+        return (tool['name'], tool['version'],
+                tuple(sorted(tool['configuration'].items())))
+
+    @staticmethod
+    def _metadata_provider_key(provider):
+        return (provider['name'], provider['url'])
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_in_memory.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+import unittest
+
+import pytest
+
+from swh.storage.in_memory import Storage
+
+from swh.storage.tests.test_storage import CommonTestStorage
+
+
+class TestInMemoryStorage(CommonTestStorage, unittest.TestCase):
+    """Test the in-memory storage API
+
+    This class doesn't define any tests as we want identical
+    functionality between local and remote storage. All the tests are
+    therefore defined in CommonTestStorage.
+    """
+    def setUp(self):
+        super().setUp()
+        self.storage = Storage()
+
+    @pytest.mark.skip('postgresql-specific test')
+    def test_content_add(self):
+        pass
+
+    @pytest.mark.skip('postgresql-specific test')
+    def test_skipped_content_add(self):
+        pass
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -783,8 +783,10 @@
 
         # hack: ids generated
         for actual_result in actual_results:
-            del actual_result['author']['id']
-            del actual_result['committer']['id']
+            if 'id' in actual_result['author']:
+                del actual_result['author']['id']
+            if 'id' in actual_result['committer']:
+                del actual_result['committer']['id']
 
         self.assertEqual(len(actual_results), 2)  # rev4 -child-> rev3
         self.assertEqual(actual_results[0],
@@ -802,8 +804,10 @@
 
         # hack: ids generated
         for actual_result in actual_results:
-            del actual_result['author']['id']
-            del actual_result['committer']['id']
+            if 'id' in actual_result['author']:
+                del actual_result['author']['id']
+            if 'id' in actual_result['committer']:
+                del actual_result['committer']['id']
 
         self.assertEqual(len(actual_results), 1)
         self.assertEqual(actual_results[0], self.revision4)
@@ -847,8 +851,10 @@
             [self.revision['id'], self.revision2['id']]))
 
         # when
-        del actual_revisions[0]['author']['id']  # hack: ids are generated
-        del actual_revisions[0]['committer']['id']
+        if 'id' in actual_revisions[0]['author']:
+            del actual_revisions[0]['author']['id']  # hack: ids are generated
+        if 'id' in actual_revisions[0]['committer']:
+            del actual_revisions[0]['committer']['id']
 
         self.assertEqual(len(actual_revisions), 2)
         self.assertEqual(actual_revisions[0],