diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -938,7 +938,7 @@
         return cur.fetchone()[0]
 
-    origin_metadata_get_cols = ['id', 'origin_id', 'discovery_date',
+    origin_metadata_get_cols = ['origin_id', 'discovery_date',
                                 'tool_id', 'metadata', 'provider_id',
                                 'provider_name', 'provider_type',
                                 'provider_url']
 
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/in_memory.py
@@ -0,0 +1,1102 @@
+# Copyright (C) 2015-2018 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import re
+import bisect
+import dateutil.parser
+import collections
+from collections import defaultdict
+import copy
+import datetime
+import itertools
+import random
+import warnings
+
+from swh.model.hashutil import DEFAULT_ALGORITHMS
+from swh.model.identifiers import normalize_timestamp
+
+
+def now():
+    return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+class HashCollision(Exception):
+    pass
+
+
+OriginVisitKey = collections.namedtuple('OriginVisitKey', 'origin date')
+
+
+class Storage:
+    def __init__(self):
+        self._contents = {}
+        self._contents_data = {}
+        self._content_indexes = defaultdict(lambda: defaultdict(set))
+
+        self._directories = {}
+        self._revisions = {}
+        self._releases = {}
+        self._snapshots = {}
+        self._origins = {}
+        self._origin_visits = {}
+        self._origin_metadata = defaultdict(list)
+        self._tools = {}
+        self._metadata_providers = {}
+        self._objects = defaultdict(list)
+
+    @staticmethod
+    def _content_key(content):
+        """A stable key for a content"""
+        return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS))
+
+    @staticmethod
+    def _origin_key(origin):
+        return (origin['type'], origin['url'])
+
+    @staticmethod
+    def _origin_visit_key(visit):
+        assert not isinstance(visit['origin'], dict), \
+            "visit['origin'] must be an origin key."
+        return OriginVisitKey(visit['origin'], visit['date'])
+
+    @staticmethod
+    def _origin_metadata_key(om):
+        return (om['origin_id'],)
+
+    @staticmethod
+    def _tool_key(tool):
+        conf = tuple(sorted(tool['configuration'].items()))
+        return (tool['name'], tool['version'], conf)
+
+    @staticmethod
+    def _metadata_provider_key(provider):
+        return (provider['name'], provider['url'])
+
+    def check_config(self, *, check_write):
+        """Check that the storage is configured and ready to go."""
+        return True
+
+    def content_add(self, contents):
+        """Add content blobs to the storage
+
+        Note: in case of DB errors, objects might have already been added to
+        the object storage and will not be removed. Since addition to the
+        object storage is idempotent, that should not be a problem.
+
+        Args:
+            contents (iterable): iterable of dictionaries representing
+                individual pieces of content to add.
Each dictionary has the + following keys: + + - data (bytes): the actual content + - length (int): content length (default: -1) + - one key for each checksum algorithm in + :data:`swh.model.hashutil.DEFAULT_ALGORITHMS`, mapped to the + corresponding checksum + - status (str): one of visible, hidden, absent + - reason (str): if status = absent, the reason why + - origin (int): if status = absent, the origin we saw the + content in + + """ + for content in contents: + key = self._content_key(content) + if key in self._contents: + continue + for algorithm in DEFAULT_ALGORITHMS: + if content[algorithm] in self._content_indexes[algorithm]: + raise HashCollision(content[algorithm], key) + for algorithm in DEFAULT_ALGORITHMS: + self._content_indexes[algorithm][content[algorithm]].add(key) + self._objects[content['sha1_git']].append( + ('content', content['sha1'])) + self._contents[key] = copy.deepcopy(content) + self._contents[key]['ctime'] = now() + if self._contents[key]['status'] == 'visible': + self._contents_data[key] = self._contents[key].pop('data') + + def content_get_metadata(self, sha1s): + """Retrieve content metadata in bulk + + Args: + content: iterable of content identifiers (sha1) + + Returns: + an iterable with content metadata corresponding to the given ids + """ + # FIXME: the return value should be a mapping from search key to found + # content*s* + for sha1 in sha1s: + if sha1 in self._content_indexes['sha1']: + objs = self._content_indexes['sha1'][sha1] + key = random.sample(objs, 1)[0] + data = copy.deepcopy(self._contents[key]) + data.pop('ctime') + yield data + else: + # FIXME: should really be None + yield { + 'sha1': sha1, + 'sha1_git': None, + 'sha256': None, + 'blake2s256': None, + 'length': None, + 'status': None, + } + + def content_find(self, content): + if not set(content).intersection(DEFAULT_ALGORITHMS): + raise ValueError('content keys must contain at least one of: ' + '%s' % ', '.join(sorted(DEFAULT_ALGORITHMS))) + found = [] + for algo in DEFAULT_ALGORITHMS: + hash = content.get(algo) + if hash and hash in self._content_indexes[algo]: + found.append(self._content_indexes[algo][hash]) + if not found: + return + keys = list(set.intersection(*found)) + + # FIXME: should really be a list of all the objects found + return copy.deepcopy(self._contents[keys[0]]) + + def content_missing(self, contents, key_hash='sha1'): + """List content missing from storage + + Args: + content ([dict]): iterable of dictionaries containing one + key for each checksum algorithm in + :data:`swh.model.hashutil.ALGORITHMS`, + mapped to the corresponding checksum, + and a length key mapped to the content + length. + + key_hash (str): name of the column to use as hash id + result (default: 'sha1') + + Returns: + iterable ([bytes]): missing content ids (as per the + key_hash column) + """ + for content in contents: + if self._content_key(content) not in self._contents: + yield content[key_hash] + + def content_missing_per_sha1(self, contents): + """List content missing from storage based only on sha1. + + Args: + contents: Iterable of sha1 to check for absence. + + Returns: + iterable: missing ids + + Raises: + TODO: an exception when we get a hash collision. + + """ + for content in contents: + if content not in self._content_indexes['sha1']: + yield content + + def directory_add(self, directories): + """Add directories to the storage + + Args: + directories (iterable): iterable of dictionaries representing the + individual directories to add. 
Each dict has the following + keys: + + - id (sha1_git): the id of the directory to add + - entries (list): list of dicts for each entry in the + directory. Each dict has the following keys: + + - name (bytes) + - type (one of 'file', 'dir', 'rev'): type of the + directory entry (file, directory, revision) + - target (sha1_git): id of the object pointed at by the + directory entry + - perms (int): entry permissions + """ + for directory in directories: + if directory['id'] not in self._directories: + self._directories[directory['id']] = copy.deepcopy(directory) + self._objects[directory['id']].append( + ('directory', directory['id'])) + + def directory_missing(self, directory_ids): + """List directories missing from storage + + Args: + directories (iterable): an iterable of directory ids + + Yields: + missing directory ids + + """ + for id in directory_ids: + if id not in self._directories: + yield id + + def _join_dentry_to_content(self, dentry): + keys = [ + 'status', + 'sha1', + 'sha1_git', + 'sha256', + 'length', + ] + # FIXME: should really be None + ret = {key: None for key in keys} + ret.update(dentry) + if ret['type'] == 'file': + content = self.content_find({'sha1_git': ret['target']}) + if content: + for key in keys: + ret[key] = content[key] + return ret + + def directory_ls(self, directory_id): + """Get entries for one directory. + + Args: + - directory: the directory to list entries from. + - recursive: if flag on, this list recursively from this directory. + + Returns: + List of entries for such directory. + + """ + if directory_id in self._directories: + for entry in self._directories[directory_id]['entries']: + ret = self._join_dentry_to_content(entry) + ret['dir_id'] = directory_id + yield ret + + def directory_entry_get_by_path(self, directory, paths): + """Get the directory entry (either file or dir) from directory with path. + + Args: + - directory: sha1 of the top level directory + - paths: path to lookup from the top level directory. From left + (top) to right (bottom). + + Returns: + The corresponding directory entry if found, None otherwise. + + """ + if not paths: + return + + contents = list(self.directory_ls(directory)) + + if not contents: + return + + def _get_entry(entries, name): + for entry in entries: + if entry['name'] == name: + return entry + + first_item = _get_entry(contents, paths[0]) + + if len(paths) == 1: + return first_item + + if not first_item or first_item['type'] != 'dir': + return + + return self.directory_entry_get_by_path( + first_item['target'], paths[1:]) + + def revision_add(self, revisions): + """Add revisions to the storage + + Args: + revisions (iterable): iterable of dictionaries representing the + individual revisions to add. 
Each dict has the following keys: + + - id (sha1_git): id of the revision to add + - date (datetime.DateTime): date the revision was written + - date_offset (int): offset from UTC in minutes the revision + was written + - date_neg_utc_offset (boolean): whether a null date_offset + represents a negative UTC offset + - committer_date (datetime.DateTime): date the revision got + added to the origin + - committer_date_offset (int): offset from UTC in minutes the + revision was added to the origin + - committer_date_neg_utc_offset (boolean): whether a null + committer_date_offset represents a negative UTC offset + - type (one of 'git', 'tar'): type of the revision added + - directory (sha1_git): the directory the revision points at + - message (bytes): the message associated with the revision + - author_name (bytes): the name of the revision author + - author_email (bytes): the email of the revision author + - committer_name (bytes): the name of the revision committer + - committer_email (bytes): the email of the revision committer + - metadata (jsonb): extra information as dictionary + - synthetic (bool): revision's nature (tarball, directory + creates synthetic revision) + - parents (list of sha1_git): the parents of this revision + + """ + for revision in revisions: + if revision['id'] not in self._revisions: + self._revisions[revision['id']] = rev = copy.deepcopy(revision) + if rev['author']: + rev['author']['id'] = 42 + if rev['committer']: + rev['committer']['id'] = 42 + rev['date'] = normalize_timestamp(rev.get('date')) + rev['committer_date'] = normalize_timestamp( + rev.get('committer_date')) + self._objects[revision['id']].append( + ('revision', revision['id'])) + + def revision_missing(self, revision_ids): + """List revisions missing from storage + + Args: + revisions (iterable): revision ids + + Yields: + missing revision ids + + """ + for id in revision_ids: + if id not in self._revisions: + yield id + + def revision_get(self, revision_ids): + for id in revision_ids: + yield copy.deepcopy(self._revisions.get(id)) + + def _get_parent_revs(self, rev_id, seen, limit): + if limit and len(seen) >= limit: + return + if rev_id in seen: + return + seen.add(rev_id) + yield self._revisions[rev_id] + for parent in self._revisions[rev_id]['parents']: + yield from self._get_parent_revs(parent, seen, limit) + + def revision_log(self, revision_ids, limit=None): + """Fetch revision entry from the given root revisions. + + Args: + revisions: array of root revision to lookup + limit: limitation on the output result. Default to None. + + Yields: + List of revision log from such revisions root. + + """ + seen = set() + for rev_id in revision_ids: + yield from self._get_parent_revs(rev_id, seen, limit) + + def revision_shortlog(self, revisions, limit=None): + """Fetch the shortlog for the given revisions + + Args: + revisions: list of root revisions to lookup + limit: depth limitation for the output + + Yields: + a list of (id, parents) tuples. + + """ + yield from ((rev['id'], rev['parents']) + for rev in self.revision_log(revisions, limit)) + + def release_add(self, releases): + """Add releases to the storage + + Args: + releases (iterable): iterable of dictionaries representing the + individual releases to add. 
Each dict has the following keys: + + - id (sha1_git): id of the release to add + - revision (sha1_git): id of the revision the release points to + - date (datetime.DateTime): the date the release was made + - date_offset (int): offset from UTC in minutes the release was + made + - date_neg_utc_offset (boolean): whether a null date_offset + represents a negative UTC offset + - name (bytes): the name of the release + - comment (bytes): the comment associated with the release + - author_name (bytes): the name of the release author + - author_email (bytes): the email of the release author + + """ + for rel in releases: + rel['date'] = normalize_timestamp(rel['date']) + self._objects[rel['id']].append( + ('release', rel['id'])) + self._releases.update({rel['id']: rel for rel in releases}) + + def release_missing(self, releases): + """List releases missing from storage + + Args: + releases: an iterable of release ids + + Returns: + a list of missing release ids + + """ + yield from (rel for rel in releases if rel not in self._releases) + + def release_get(self, releases): + """Given a list of sha1, return the releases's information + + Args: + releases: list of sha1s + + Yields: + dicts with the same keys as those given to `release_add` + + Raises: + ValueError: if the keys does not match (url and type) nor id. + + """ + yield from map(self._releases.__getitem__, releases) + + def snapshot_add(self, origin, visit, snapshot): + """Add a snapshot for the given origin/visit couple + + Args: + origin (int): id of the origin + visit (int): id of the visit + snapshot (dict): the snapshot to add to the visit, containing the + following keys: + + - **id** (:class:`bytes`): id of the snapshot + - **branches** (:class:`dict`): branches the snapshot contains, + mapping the branch name (:class:`bytes`) to the branch target, + itself a :class:`dict` (or ``None`` if the branch points to an + unknown object) + + - **target_type** (:class:`str`): one of ``content``, + ``directory``, ``revision``, ``release``, + ``snapshot``, ``alias`` + - **target** (:class:`bytes`): identifier of the target + (currently a ``sha1_git`` for all object kinds, or the name + of the target branch for aliases) + """ + snapshot_id = snapshot['id'] + snapshot = { + 'origin': origin, + 'visit': visit, + 'id': snapshot_id, + 'branches': copy.deepcopy(snapshot['branches']), + '_sorted_branch_names': sorted(snapshot['branches']) + } + if snapshot_id not in self._snapshots: + self._snapshots[snapshot_id] = snapshot + self._origin_visits[visit]['snapshot'] = snapshot_id + + def snapshot_get(self, snapshot_id): + """Get the content, possibly partial, of a snapshot with the given id + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + .. warning:: At most 1000 branches contained in the snapshot will be + returned for performance reasons. In order to browse the whole + set of branches, the method :meth:`snapshot_get_branches` + should be used instead. + + Args: + snapshot_id (bytes): identifier of the snapshot + Returns: + dict: a dict with three keys: + * **id**: identifier of the snapshot + * **branches**: a dict of branches contained in the snapshot + whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has less than 1000 + branches. 
+ """ + return self.snapshot_get_branches(snapshot_id) + + def snapshot_get_by_origin_visit(self, origin, visit): + """Get the content, possibly partial, of a snapshot for the given origin visit + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + .. warning:: At most 1000 branches contained in the snapshot will be + returned for performance reasons. In order to browse the whole + set of branches, the method :meth:`snapshot_get_branches` + should be used instead. + + Args: + origin (int): the origin identifier + visit (int): the visit identifier + Returns: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: + * **id**: identifier of the snapshot + * **branches**: a dict of branches contained in the snapshot + whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has less than 1000 + branches. + + """ + if visit not in self._origin_visits: + return None + snapshot_id = self._origin_visits[visit]['snapshot'] + if snapshot_id: + return self.snapshot_get(snapshot_id) + else: + return None + + def snapshot_get_latest(self, origin, allowed_statuses=None): + """Get the content, possibly partial, of the latest snapshot for the + given origin, optionally only from visits that have one of the given + allowed_statuses + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + .. warning:: At most 1000 branches contained in the snapshot will be + returned for performance reasons. In order to browse the whole + set of branches, the method :meth:`snapshot_get_branches` + should be used instead. + + Args: + origin (int): the origin identifier + allowed_statuses (list of str): list of visit statuses considered + to find the latest snapshot for the visit. For instance, + ``allowed_statuses=['full']`` will only consider visits that + have successfully run to completion. + Returns: + dict: a dict with three keys: + * **id**: identifier of the snapshot + * **branches**: a dict of branches contained in the snapshot + whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has less than 1000 + branches. 
+ """ + if allowed_statuses is None: + visits_dates = list(itertools.chain.from_iterable( + self._origins[origin]['visits_dates'].values())) + else: + last_visits = self._origins[origin]['visits_dates'] + visits_dates = [] + for status in allowed_statuses: + visits_dates.extend(last_visits[status]) + + for visit_date in sorted(visits_dates, reverse=True): + visit_id = self._origin_visit_key( + {'origin': origin, 'date': visit_date}) + snapshot_id = self._origin_visits[visit_id]['snapshot'] + snapshot = self.snapshot_get(snapshot_id) + if snapshot: + return snapshot + + return None + + def snapshot_count_branches(self, snapshot_id, db=None, cur=None): + """Count the number of branches in the snapshot with the given id + + Args: + snapshot_id (bytes): identifier of the snapshot + + Returns: + dict: A dict whose keys are the target types of branches and + values their corresponding amount + """ + branches = list(self._snapshots[snapshot_id]['branches'].values()) + return collections.Counter(branch['target_type'] if branch else None + for branch in branches) + + def snapshot_get_branches(self, snapshot_id, branches_from=b'', + branches_count=1000, target_types=None): + """Get the content, possibly partial, of a snapshot with the given id + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + Args: + snapshot_id (bytes): identifier of the snapshot + branches_from (bytes): optional parameter used to skip branches + whose name is lesser than it before returning them + branches_count (int): optional parameter used to restrain + the amount of returned branches + target_types (list): optional parameter used to filter the + target types of branch to return (possible values that can be + contained in that list are `'content', 'directory', + 'revision', 'release', 'snapshot', 'alias'`) + Returns: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: + * **id**: identifier of the snapshot + * **branches**: a dict of branches contained in the snapshot + whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has less than + `branches_count` branches after `branches_from` included. + """ + snapshot = self._snapshots.get(snapshot_id, None) + if snapshot is None: + return None + sorted_branch_names = snapshot['_sorted_branch_names'] + from_index = bisect.bisect_left( + sorted_branch_names, branches_from) + if target_types: + next_branch = None + branches = {} + for branch_name in sorted_branch_names[from_index:]: + branch = snapshot['branches'][branch_name] + if branch and branch['target_type'] in target_types: + if len(branches) < branches_count: + branches[branch_name] = branch + else: + next_branch = branch_name + break + else: + # As there is no 'target_types', we can do that much faster + to_index = from_index + branches_count + returned_branch_names = sorted_branch_names[from_index:to_index] + branches = {branch_name: snapshot['branches'][branch_name] + for branch_name in returned_branch_names} + if to_index >= len(sorted_branch_names): + next_branch = None + else: + next_branch = sorted_branch_names[to_index] + return { + 'id': snapshot_id, + 'branches': branches, + 'next_branch': next_branch, + } + + def object_find_by_sha1_git(self, ids, db=None, cur=None): + """Return the objects found with the given ids. + + Args: + ids: a generator of sha1_gits + + Returns: + dict: a mapping from id to the list of objects found. 
Each object + found is itself a dict with keys: + + - sha1_git: the input id + - type: the type of object found + - id: the id of the object found + - object_id: the numeric id of the object found. + + """ + ret = {} + for id_ in ids: + objs = self._objects.get(id_, []) + ret[id_] = [{ + 'sha1_git': id_, + 'type': obj[0], + 'id': obj[1], + 'object_id': id_, + } for obj in objs] + return ret + + def origin_get(self, origin): + """Return the origin either identified by its id or its tuple + (type, url). + + Args: + origin: dictionary representing the individual origin to find. + This dict has either the keys type and url: + + - type (FIXME: enum TBD): the origin type ('git', 'wget', ...) + - url (bytes): the url the origin points to + + or the id: + + - id: the origin id + + Returns: + dict: the origin dictionary with the keys: + + - id: origin's id + - type: origin's type + - url: origin's url + + Raises: + ValueError: if the keys does not match (url and type) nor id. + + """ + if 'id' in origin: + key = origin['id'] + elif 'type' in origin and 'url' in origin: + key = self._origin_key(origin) + else: + raise ValueError('Origin must have either id or (type and url).') + if key not in self._origins: + return None + else: + origin = copy.deepcopy(self._origins[key]) + del origin['visits_dates'] + origin['id'] = self._origin_key(origin) + return origin + + def origin_search(self, url_pattern, offset=0, limit=50, + regexp=False, with_visit=False, db=None, cur=None): + """Search for origins whose urls contain a provided string pattern + or match a provided regular expression. + The search is performed in a case insensitive way. + + Args: + url_pattern (str): the string pattern to search for in origin urls + offset (int): number of found origins to skip before returning + results + limit (int): the maximum number of found origins to return + regexp (bool): if True, consider the provided pattern as a regular + expression and return origins whose urls match it + with_visit (bool): if True, filter out origins with no visit + + Returns: + An iterable of dict containing origin information as returned + by :meth:`swh.storage.storage.Storage.origin_get`. + """ + origins = iter(self._origins.values()) + if regexp: + pat = re.compile(url_pattern) + origins = (orig for orig in origins if pat.match(orig['url'])) + else: + origins = (orig for orig in origins if url_pattern in orig['url']) + if with_visit: + origins = (orig for orig in origins if orig['visits_dates']) + origins = sorted(origins, key=self._origin_key) + origins = copy.deepcopy(origins[offset:offset+limit]) + for orig in origins: + del orig['visits_dates'] + orig['id'] = self._origin_key(orig) + return origins + + def origin_add(self, origins): + """Add origins to the storage + + Args: + origins: list of dictionaries representing the individual origins, + with the following keys: + + - type: the origin type ('git', 'svn', 'deb', ...) + - url (bytes): the url the origin points to + + Returns: + list: given origins as dict updated with their id + + """ + origins = copy.deepcopy(origins) + for origin in origins: + origin['id'] = self.origin_add_one(origin) + return origins + + def origin_add_one(self, origin): + """Add origin to the storage + + Args: + origin: dictionary representing the individual origin to add. This + dict has the following keys: + + - type (FIXME: enum TBD): the origin type ('git', 'wget', ...) + - url (bytes): the url the origin points to + + Returns: + the id of the added origin, or of the identical one that already + exists. 
+ + """ + origin = copy.deepcopy(origin) + assert 'id' not in origin + assert 'visits_dates' not in origin + key = self._origin_key(origin) + origin['visits_dates'] = defaultdict(set) + if key not in self._origins: + self._origins[key] = origin + return key + + def origin_visit_add(self, origin, date=None, *, ts=None): + """Add an origin_visit for the origin at date with status 'ongoing'. + + Args: + origin: Visited Origin id + date: timestamp of such visit + + Returns: + dict: dictionary with keys origin and visit where: + + - origin: origin identifier + - visit: the visit identifier for the new visit occurrence + + """ + if ts is None: + if date is None: + raise TypeError('origin_visit_add expected 2 arguments.') + else: + assert date is None + warnings.warn("argument 'ts' of origin_visit_add was renamed " + "to 'date' in v0.0.109.", + DeprecationWarning) + date = ts + + if isinstance(date, str): + date = dateutil.parser.parse(date) + + status = 'ongoing' + + visit = { + 'origin': origin, + 'date': date, + 'status': status, + 'snapshot': None, + 'metadata': None, + } + key = self._origin_visit_key(visit) + visit['visit'] = key + if key not in self._origin_visits: + self._origin_visits[key] = copy.deepcopy(visit) + self._origins[origin]['visits_dates'][status].add(date) + + return { + 'origin': visit['origin'], + 'visit': key, + } + + def origin_visit_update(self, origin, visit_id, status, metadata=None): + """Update an origin_visit's status. + + Args: + origin: Visited Origin id + visit_id: Visit's id + status: Visit's new status + metadata: Data associated to the visit + + Returns: + None + + """ + old_status = self._origin_visits[visit_id]['status'] + self._origins[origin]['visits_dates'][old_status] \ + .remove(visit_id.date) + self._origins[origin]['visits_dates'][status] \ + .add(visit_id.date) + self._origin_visits[visit_id].update({ + 'status': status, + 'metadata': metadata}) + + def origin_visit_get(self, origin, last_visit=None, limit=None): + """Retrieve all the origin's visit's information. + + Args: + origin (int): The occurrence's origin (identifier). + last_visit: Starting point from which listing the next visits + Default to None + limit (int): Number of results to return from the last visit. + Default to None + + Yields: + List of visits. + + """ + visits_dates = sorted(itertools.chain.from_iterable( + self._origins[origin]['visits_dates'].values())) + if last_visit is not None: + from_index = bisect.bisect_right(visits_dates, last_visit.date) + visits_dates = visits_dates[from_index:] + if limit is not None: + visits_dates = visits_dates[0:limit] + keys = (self._origin_visit_key({'origin': origin, 'date': date}) + for date in visits_dates) + yield from map(self._origin_visits.__getitem__, keys) + + def origin_visit_get_by(self, origin, visit): + """Retrieve origin visit's information. + + Args: + origin: The occurrence's origin (identifier). 
+            visit: The visit identifier
+
+        Returns:
+            The information on that particular (origin, visit) or None if
+            it does not exist
+
+        """
+        return self._origin_visits.get(visit, None)
+
+    def stat_counters(self):
+        """Compute statistics about the number of tuples in various tables
+
+        Returns:
+            dict: a dictionary mapping textual labels (e.g., content) to
+            integer values (e.g., the number of tuples in table content)
+
+        """
+        keys = ['content', 'directory', 'directory_entry_dir',
+                'origin', 'person', 'revision']
+        stats = {key: 0 for key in keys}
+        # self._objects maps each id to a list of (object_type, object_id)
+        # pairs, so flatten the lists before counting object types.
+        stats.update(collections.Counter(
+            obj_type
+            for (obj_type, obj_id)
+            in itertools.chain.from_iterable(self._objects.values())))
+        return stats
+
+    def refresh_stat_counters(self):
+        """Recomputes the statistics for `stat_counters`."""
+        pass
+
+    def origin_metadata_add(self, origin_id, ts, provider, tool, metadata,
+                            db=None, cur=None):
+        """ Add an origin_metadata for the origin at ts with provenance and
+        metadata.
+
+        Args:
+            origin_id: the origin's id for which the metadata is added
+            ts (datetime): timestamp of the found metadata
+            provider: id of the provider of metadata (ex:'hal')
+            tool: id of the tool used to extract metadata
+            metadata (jsonb): the metadata retrieved at the time and location
+        """
+        if isinstance(ts, str):
+            ts = dateutil.parser.parse(ts)
+
+        origin_metadata = {
+            'origin_id': origin_id,
+            'discovery_date': ts,
+            'tool_id': tool,
+            'metadata': metadata,
+            'provider_id': provider,
+        }
+        key = self._origin_metadata_key(origin_metadata)
+        self._origin_metadata[key].append(origin_metadata)
+        return None
+
+    def origin_metadata_get_by(self, origin_id, provider_type=None, db=None,
+                               cur=None):
+        """Retrieve list of all origin_metadata entries for the origin_id
+
+        Args:
+            origin_id (int): the unique origin identifier
+            provider_type (str): (optional) type of provider
+
+        Returns:
+            list of dicts: the origin_metadata dictionary with the keys:
+
+            - origin_id (int): origin's id
+            - discovery_date (datetime): timestamp of discovery
+            - tool_id (int): metadata's extracting tool
+            - metadata (jsonb)
+            - provider_id (int): metadata's provider
+            - provider_name (str)
+            - provider_type (str)
+            - provider_url (str)
+
+        """
+        metadata = []
+        key = self._origin_metadata_key({'origin_id': origin_id})
+        for item in self._origin_metadata[key]:
+            item = copy.deepcopy(item)
+            provider = self.metadata_provider_get(item['provider_id'])
+            if provider_type and provider['type'] != provider_type:
+                # honor the documented, optional provider_type filter
+                continue
+            for attr in ('name', 'type', 'url'):
+                item['provider_' + attr] = provider[attr]
+            metadata.append(item)
+        return metadata
+
+    def tool_add(self, tools):
+        """Add new tools to the storage.
+
+        Args:
+            tools (iterable of :class:`dict`): Tool information to add to
+              storage. Each tool is a :class:`dict` with the following keys:
+
+              - name (:class:`str`): name of the tool
+              - version (:class:`str`): version of the tool
+              - configuration (:class:`dict`): configuration of the tool,
+                must be json-encodable
+
+        Returns:
+            `iterable` of :class:`dict`: All the tools inserted in storage
+            (including the internal ``id``). The order of the list is not
+            guaranteed to match the order of the initial list.
+
+        """
+        inserted = []
+        for tool in tools:
+            key = self._tool_key(tool)
+            assert 'id' not in tool
+            record = copy.deepcopy(tool)
+            record['id'] = key  # TODO: remove this
+            if key not in self._tools:
+                self._tools[key] = record
+            inserted.append(copy.deepcopy(self._tools[key]))
+
+        return inserted
+
+    def tool_get(self, tool):
+        """Retrieve tool information.
+
+        Args:
+            tool (dict): Tool information we want to retrieve from storage.
+ The dicts have the same keys as those used in :func:`tool_add`. + + Returns: + dict: The full tool information if it exists (``id`` included), + None otherwise. + + """ + return self._tools.get(self._tool_key(tool), None) + + def metadata_provider_add(self, provider_name, provider_type, provider_url, + metadata): + """Add a metadata provider. + + Args: + provider_name (str): Its name + provider_type (str): Its type + provider_url (str): Its URL + metadata: JSON-encodable object + + Returns: + dict: same as args, plus an 'id' key. + """ + provider = { + 'name': provider_name, + 'type': provider_type, + 'url': provider_url, + 'metadata': metadata, + } + key = self._metadata_provider_key(provider) + provider['id'] = key + self._metadata_providers[key] = provider + return provider.copy() + + def metadata_provider_get(self, provider_id, db=None, cur=None): + """Get a metadata provider + + Args: + provider_id: Its identifier, as given by `metadata_provider_add`. + + Returns: + dict: same as `metadata_provider_add`; + or None if it does not exist. + """ + return self._metadata_providers.get(provider_id, None) + + def metadata_provider_get_by(self, provider, db=None, cur=None): + """Get a metadata provider + + Args: + provider_name: Its name + provider_url: Its URL + + Returns: + dict: same as `metadata_provider_add`; + or None if it does not exist. + """ + key = self._metadata_provider_key({ + 'name': provider['provider_name'], + 'url': provider['provider_url']}) + return self._metadata_providers.get(key, None) diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -669,11 +669,7 @@ releases: list of sha1s Yields: - releases: list of releases as dicts with the following keys: - - - id: origin's id - - revision: origin's type - - url: origin's url + dicts with the same keys as those given to `release_add` Raises: ValueError: if the keys does not match (url and type) nor id. @@ -768,7 +764,8 @@ origin (int): the origin identifier visit (int): the visit identifier Returns: - dict: a dict with three keys: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: * **id**: identifier of the snapshot * **branches**: a dict of branches contained in the snapshot whose keys are the branches' names. @@ -854,7 +851,8 @@ contained in that list are `'content', 'directory', 'revision', 'release', 'snapshot', 'alias'`) Returns: - dict: a dict with three keys: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: * **id**: identifier of the snapshot * **branches**: a dict of branches contained in the snapshot whose keys are the branches' names. @@ -912,7 +910,6 @@ - origin: origin identifier - visit: the visit identifier for the new visit occurrence - - ts (datetime.DateTime): the visit date """ if ts is None: @@ -957,7 +954,7 @@ Args: origin (int): The occurrence's origin (identifier). - last_visit (int): Starting point from which listing the next visits + last_visit: Starting point from which listing the next visits Default to None limit (int): Number of results to return from the last visit. 
Default to None @@ -1200,6 +1197,15 @@ return {k: v for (k, v) in db.stat_counters()} @db_transaction() + def refresh_stat_counters(self, db=None, cur=None): + """Recomputes the statistics for `stat_counters`.""" + keys = ['content', 'directory', 'directory_entry_dir', + 'origin', 'person', 'revision'] + + for key in keys: + cur.execute('select * from swh_update_counter(%s)', (key,)) + + @db_transaction() def origin_metadata_add(self, origin_id, ts, provider, tool, metadata, db=None, cur=None): """ Add an origin_metadata for the origin at ts with provenance and @@ -1233,7 +1239,6 @@ Returns: list of dicts: the origin_metadata dictionary with the keys: - - id (int): origin_metadata's id - origin_id (int): origin's id - discovery_date (datetime): timestamp of discovery - tool_id (int): metadata's extracting tool @@ -1260,8 +1265,8 @@ - configuration (:class:`dict`): configuration of the tool, must be json-encodable - Returns: - `iterable` of :class:`dict`: All the tools inserted in storage + Yields: + :class:`dict`: All the tools inserted in storage (including the internal ``id``). The order of the list is not guaranteed to match the order of the initial list. @@ -1302,11 +1307,30 @@ @db_transaction() def metadata_provider_add(self, provider_name, provider_type, provider_url, metadata, db=None, cur=None): + """Add a metadata provider. + + Args: + provider_name (str): Its name + provider_type (str): Its type + provider_url (str): Its URL + + Returns: + dict: same as args, plus an 'id' key. + """ return db.metadata_provider_add(provider_name, provider_type, provider_url, metadata, cur) @db_transaction() def metadata_provider_get(self, provider_id, db=None, cur=None): + """Get a metadata provider + + Args: + provider_id: Its identifier, as given by `metadata_provider_add`. + + Returns: + dict: same as `metadata_provider_add`; + or None if it does not exist. + """ result = db.metadata_provider_get(provider_id) if not result: return None @@ -1314,6 +1338,17 @@ @db_transaction() def metadata_provider_get_by(self, provider, db=None, cur=None): + """Get a metadata provider + + Args: + provider (dict): A dictionary with keys: + * provider_name: Its name + * provider_url: Its URL + + Returns: + dict: same as `metadata_provider_add`; + or None if it does not exist. + """ result = db.metadata_provider_get_by(provider['provider_name'], provider['provider_url']) if not result: diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import pytest import shutil import tempfile import unittest @@ -10,11 +11,12 @@ from swh.core.tests.server_testing import ServerTestFixture from swh.storage.api.client import RemoteStorage from swh.storage.api.server import app -from swh.storage.tests.test_storage import CommonTestStorage +from swh.storage.tests.test_storage import CommonTestStorage, \ + StorageTestDbFixture class TestRemoteStorage(CommonTestStorage, ServerTestFixture, - unittest.TestCase): + StorageTestDbFixture, unittest.TestCase): """Test the remote storage API. 
This class doesn't define any tests as we want identical @@ -52,3 +54,7 @@ def tearDown(self): super().tearDown() shutil.rmtree(self.storage_base) + + @pytest.mark.skip('refresh_stat_counters not available in the remote api.') + def test_stat_counters(self): + pass diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py new file mode 100644 --- /dev/null +++ b/swh/storage/tests/test_in_memory.py @@ -0,0 +1,31 @@ +# Copyright (C) 2018 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information +import unittest + +import pytest + +from swh.storage.in_memory import Storage + +from swh.storage.tests.test_storage import CommonTestStorage + + +class TestInMemoryStorage(CommonTestStorage, unittest.TestCase): + """Test the in-memory storage API + + This class doesn't define any tests as we want identical + functionality between local and remote storage. All the tests are + therefore defined in CommonTestStorage. + """ + def setUp(self): + super().setUp() + self.storage = Storage() + + @pytest.mark.skip('postgresql-specific test') + def test_content_add(self): + pass + + @pytest.mark.skip('postgresql-specific test') + def test_skipped_content_add(self): + pass diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -7,7 +7,6 @@ import datetime import unittest from collections import defaultdict -from operator import itemgetter from unittest.mock import Mock, patch import psycopg2 @@ -16,10 +15,11 @@ from swh.model import from_disk, identifiers from swh.model.hashutil import hash_to_bytes from swh.storage.tests.storage_testing import StorageTestFixture +from swh.storage.in_memory import HashCollision @pytest.mark.db -class BaseTestStorage(StorageTestFixture): +class StorageTestDbFixture(StorageTestFixture): def setUp(self): super().setUp() @@ -29,6 +29,15 @@ self.maxDiff = None + def tearDown(self): + self.reset_storage_tables() + super().tearDown() + + +class TestStorageData: + def setUp(self): + super().setUp() + self.cont = { 'data': b'42\n', 'length': 3, @@ -509,12 +518,8 @@ 'next_branch': None } - def tearDown(self): - self.reset_storage_tables() - super().tearDown() - -class CommonTestStorage(BaseTestStorage): +class CommonTestStorage(TestStorageData): """Base class for Storage testing. This class is used as-is to test local storage (see TestLocalStorage @@ -526,7 +531,6 @@ class twice. 
""" - @staticmethod def normalize_entity(entity): entity = copy.deepcopy(entity) @@ -565,7 +569,7 @@ sha256_array[0] += 1 cont1b['sha256'] = bytes(sha256_array) - with self.assertRaises(psycopg2.IntegrityError): + with self.assertRaises((psycopg2.IntegrityError, HashCollision)): self.storage.content_add([cont1, cont1b]) def test_skipped_content_add(self): @@ -677,7 +681,7 @@ stored_data = list(self.storage.directory_ls(self.dir['id'])) data_to_store = [] - for ent in sorted(self.dir['entries'], key=itemgetter('name')): + for ent in self.dir['entries']: data_to_store.append({ 'dir_id': self.dir['id'], 'type': ent['type'], @@ -691,7 +695,7 @@ 'length': None, }) - self.assertEqual(data_to_store, stored_data) + self.assertCountEqual(data_to_store, stored_data) after_missing = list(self.storage.directory_missing([self.dir['id']])) self.assertEqual([], after_missing) @@ -880,7 +884,8 @@ # then for actual_release in actual_releases: - del actual_release['author']['id'] # hack: ids are generated + if 'id' in actual_release['author']: + del actual_release['author']['id'] # hack: ids are generated self.assertEqual([self.normalize_entity(self.release), self.normalize_entity(self.release2)], @@ -1011,7 +1016,6 @@ # then self.assertEqual(origin_visit1['origin'], origin_id) self.assertIsNotNone(origin_visit1['visit']) - self.assertTrue(origin_visit1['visit'] > 0) actual_origin_visits = list(self.storage.origin_visit_get(origin_id)) self.assertEqual(actual_origin_visits, @@ -1399,9 +1403,7 @@ expected_keys = ['content', 'directory', 'directory_entry_dir', 'origin', 'person', 'revision'] - for key in expected_keys: - self.cursor.execute('select * from swh_update_counter(%s)', (key,)) - self.conn.commit() + self.storage.refresh_stat_counters() counters = self.storage.stat_counters() @@ -1545,6 +1547,11 @@ for obj in val: del obj['object_id'] + import pprint + pprint.pprint(expected) + print('\n\n\n') + pprint.pprint(ret) + self.maxDiff = None self.assertEqual(expected, ret) def test_tool_add(self): @@ -1664,8 +1671,7 @@ origin_metadata0 = list(self.storage.origin_metadata_get_by(origin_id)) self.assertTrue(len(origin_metadata0) == 0) - tools = list(self.storage.tool_add([self.metadata_tool])) - tool = tools[0] + tool = list(self.storage.tool_add([self.metadata_tool]))[0] self.storage.metadata_provider_add( self.provider['name'], @@ -1676,10 +1682,9 @@ 'provider_name': self.provider['name'], 'provider_url': self.provider['url'] }) - tool = self.storage.tool_get(self.metadata_tool) # when adding for the same origin 2 metadatas - o_m1 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata['discovery_date'], provider['id'], @@ -1687,7 +1692,6 @@ self.origin_metadata['metadata']) actual_om1 = list(self.storage.origin_metadata_get_by(origin_id)) # then - self.assertEqual(actual_om1[0]['id'], o_m1) self.assertEqual(len(actual_om1), 1) self.assertEqual(actual_om1[0]['origin_id'], origin_id) @@ -1704,21 +1708,21 @@ 'provider_name': self.provider['name'], 'provider_url': self.provider['url'] }) - tool = self.storage.tool_get(self.metadata_tool) + tool = list(self.storage.tool_add([self.metadata_tool]))[0] # when adding for the same origin 2 metadatas - o_m1 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata['discovery_date'], provider['id'], tool['id'], self.origin_metadata['metadata']) - o_m2 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id2, self.origin_metadata2['discovery_date'], 
provider['id'], tool['id'], self.origin_metadata2['metadata']) - o_m3 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata2['discovery_date'], provider['id'], @@ -1730,15 +1734,12 @@ expected_results = [{ 'origin_id': origin_id, 'discovery_date': datetime.datetime( - 2017, 1, 2, 0, 0, - tzinfo=psycopg2.tz.FixedOffsetTimezone( - offset=60, - name=None)), + 2017, 1, 1, 23, 0, + tzinfo=datetime.timezone.utc), 'metadata': { 'name': 'test_origin_metadata', 'version': '0.0.1' }, - 'id': o_m3, 'provider_id': provider['id'], 'provider_name': 'hal', 'provider_type': 'deposit-client', @@ -1747,15 +1748,12 @@ }, { 'origin_id': origin_id, 'discovery_date': datetime.datetime( - 2015, 1, 2, 0, 0, - tzinfo=psycopg2.tz.FixedOffsetTimezone( - offset=60, - name=None)), + 2015, 1, 1, 23, 0, + tzinfo=datetime.timezone.utc), 'metadata': { 'name': 'test_origin_metadata', 'version': '0.0.1' }, - 'id': o_m1, 'provider_id': provider['id'], 'provider_name': 'hal', 'provider_type': 'deposit-client', @@ -1766,8 +1764,7 @@ # then self.assertEqual(len(all_metadatas), 2) self.assertEqual(len(metadatas_for_origin2), 1) - self.assertEqual(metadatas_for_origin2[0]['id'], o_m2) - self.assertEqual(all_metadatas, expected_results) + self.assertCountEqual(all_metadatas, expected_results) def test_origin_metadata_get_by_provider_type(self): # given @@ -1796,16 +1793,16 @@ # using the only tool now inserted in the data.sql, but for this # provider should be a crawler tool (not yet implemented) - tool = self.storage.tool_get(self.metadata_tool) + tool = list(self.storage.tool_add([self.metadata_tool]))[0] # when adding for the same origin 2 metadatas - o_m1 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata['discovery_date'], provider1['id'], tool['id'], self.origin_metadata['metadata']) - o_m2 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id2, self.origin_metadata2['discovery_date'], provider2['id'], @@ -1816,18 +1813,18 @@ origin_metadata_get_by( origin_id2, provider_type)) + for item in m_by_provider: + if 'id' in item: + del item['id'] expected_results = [{ 'origin_id': origin_id2, 'discovery_date': datetime.datetime( - 2017, 1, 2, 0, 0, - tzinfo=psycopg2.tz.FixedOffsetTimezone( - offset=60, - name=None)), + 2017, 1, 1, 23, 0, + tzinfo=datetime.timezone.utc), 'metadata': { 'name': 'test_origin_metadata', 'version': '0.0.1' }, - 'id': o_m2, 'provider_id': provider2['id'], 'provider_name': 'swMATH', 'provider_type': provider_type, @@ -1838,11 +1835,10 @@ self.assertEqual(len(m_by_provider), 1) self.assertEqual(m_by_provider, expected_results) - self.assertEqual(m_by_provider[0]['id'], o_m2) - self.assertIsNotNone(o_m1) -class TestLocalStorage(CommonTestStorage, unittest.TestCase): +class TestLocalStorage(CommonTestStorage, StorageTestDbFixture, + unittest.TestCase): """Test the local storage""" # Can only be tested with local storage as you can't mock @@ -1921,7 +1917,8 @@ self.assertEqual(missing, [self.cont['sha1']]) -class AlteringSchemaTest(BaseTestStorage, unittest.TestCase): +class AlteringSchemaTest(TestStorageData, StorageTestDbFixture, + unittest.TestCase): """This class is dedicated for the rare case where the schema needs to be altered dynamically.