diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/in_memory.py
@@ -0,0 +1,681 @@
+# Copyright (C) 2015-2018  The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import bisect
+import dateutil.parser
+from collections import defaultdict, namedtuple
+import copy
+import datetime
+import itertools
+import random
+
+from swh.model.hashutil import DEFAULT_ALGORITHMS
+from swh.model.identifiers import normalize_timestamp
+
+
+def now():
+    return datetime.datetime.now(tz=datetime.timezone.utc)
+
+
+OriginVisitKey = namedtuple('OriginVisitKey', 'origin ts')
+
+
+class Storage:
+    def __init__(self):
+        self._contents = {}
+        self._contents_data = {}
+        self._content_indexes = defaultdict(lambda: defaultdict(set))
+
+        self._directories = {}
+        self._revisions = {}
+        self._releases = {}
+        self._snapshots = {}
+        self._origins = {}
+        self._origin_visits = {}
+        self._tools = {}
+
+    @staticmethod
+    def _content_key(content):
+        """A stable key for a content"""
+        return tuple(content.get(key) for key in sorted(DEFAULT_ALGORITHMS))
+
+    @staticmethod
+    def _origin_key(origin):
+        return (origin['type'], origin['url'])
+
+    @staticmethod
+    def _origin_visit_key(visit):
+        assert not isinstance(visit['origin'], dict), \
+            "visit['origin'] must be an origin key."
+        return OriginVisitKey(visit['origin'], visit['ts'])
+
+    @staticmethod
+    def _tool_key(tool):
+        conf = tuple(sorted(tool['configuration'].items()))
+        return (tool['name'], tool['version'], conf)
+
+    def check_config(self, *, check_write):
+        """Check that the storage is configured and ready to go."""
+        return True
+
+    def content_add(self, contents):
+        """Add content blobs to the storage
+
+        Note: in case of DB errors, objects might have already been added to
+        the object storage and will not be removed. Since addition to the
+        object storage is idempotent, that should not be a problem.
+
+        Args:
+            contents (iterable): iterable of dictionaries representing
+                individual pieces of content to add.
+                Each dictionary has the following keys:
+
+                - data (bytes): the actual content
+                - length (int): content length (default: -1)
+                - one key for each checksum algorithm in
+                  :data:`swh.model.hashutil.DEFAULT_ALGORITHMS`, mapped to
+                  the corresponding checksum
+                - status (str): one of visible, hidden, absent
+                - reason (str): if status = absent, the reason why
+                - origin (int): if status = absent, the origin we saw the
+                  content in
+
+        """
+        for content in contents:
+            key = self._content_key(content)
+            if key in self._contents:
+                continue
+            self._contents[key] = copy.deepcopy(content)
+            self._contents[key]['ctime'] = now()
+            self._contents_data[key] = self._contents[key].pop('data')
+            for algorithm in DEFAULT_ALGORITHMS:
+                self._content_indexes[algorithm][content[algorithm]].add(key)
+
+    def content_get_metadata(self, sha1s):
+        """Retrieve content metadata in bulk
+
+        Args:
+            sha1s: iterable of content identifiers (sha1)
+
+        Returns:
+            an iterable with content metadata corresponding to the given ids
+        """
+        # FIXME: the return value should be a mapping from search key to found
+        # content*s*
+        for sha1 in sha1s:
+            if sha1 in self._content_indexes['sha1']:
+                objs = self._content_indexes['sha1'][sha1]
+                # any stored content with this sha1 will do; note that
+                # random.sample() does not accept sets on recent Pythons
+                key = random.choice(list(objs))
+                data = copy.deepcopy(self._contents[key])
+                data.pop('ctime')
+                yield data
+            else:
+                # FIXME: should really be None
+                yield {
+                    'sha1': sha1,
+                    'sha1_git': None,
+                    'sha256': None,
+                    'blake2s256': None,
+                    'length': None,
+                    'status': None,
+                }
+
+    def content_find(self, content):
+        if not set(content).intersection(DEFAULT_ALGORITHMS):
+            raise ValueError('content keys must contain at least one of: '
+                             '%s' % ', '.join(sorted(DEFAULT_ALGORITHMS)))
+        found = []
+        for algo in DEFAULT_ALGORITHMS:
+            hash_ = content.get(algo)
+            if hash_ and hash_ in self._content_indexes[algo]:
+                found.append(self._content_indexes[algo][hash_])
+        if not found:
+            return
+        keys = list(set.intersection(*found))
+
+        # FIXME: should really be a list of all the objects found
+        return copy.deepcopy(self._contents[keys[0]])
+
+    def content_missing(self, contents, key_hash='sha1'):
+        """List content missing from storage
+
+        Args:
+            contents ([dict]): iterable of dictionaries containing one key
+                for each checksum algorithm in
+                :data:`swh.model.hashutil.ALGORITHMS`, mapped to the
+                corresponding checksum, and a length key mapped to the
+                content length.
+            key_hash (str): name of the column to use as hash id result
+                (default: 'sha1')
+
+        Returns:
+            iterable ([bytes]): missing content ids (as per the key_hash
+            column)
+        """
+        for content in contents:
+            if self._content_key(content) not in self._contents:
+                yield content[key_hash]
+
+    def content_missing_per_sha1(self, contents):
+        """List content missing from storage based only on sha1.
+
+        Args:
+            contents: Iterable of sha1 to check for absence.
+
+        Returns:
+            iterable: missing ids
+
+        Raises:
+            TODO: an exception when we get a hash collision.
+
+        """
+        for content in contents:
+            if content not in self._content_indexes['sha1']:
+                yield content
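+    # Illustrative usage sketch for the content API above; the digests
+    # below are hypothetical placeholders, not real checksums:
+    #
+    #   storage = Storage()
+    #   cont = {'data': b'42\n', 'length': 3, 'status': 'visible',
+    #           'sha1': b'<sha1>', 'sha1_git': b'<sha1_git>',
+    #           'sha256': b'<sha256>', 'blake2s256': b'<blake2s256>'}
+    #   storage.content_add([cont])
+    #   list(storage.content_missing([cont]))          # -> []
+    #   storage.content_find({'sha1': cont['sha1']})   # -> stored metadata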
+    def directory_add(self, directories):
+        """Add directories to the storage
+
+        Args:
+            directories (iterable): iterable of dictionaries representing the
+                individual directories to add. Each dict has the following
+                keys:
+
+                - id (sha1_git): the id of the directory to add
+                - entries (list): list of dicts for each entry in the
+                  directory. Each dict has the following keys:
+
+                  - name (bytes)
+                  - type (one of 'file', 'dir', 'rev'): type of the
+                    directory entry (file, directory, revision)
+                  - target (sha1_git): id of the object pointed at by the
+                    directory entry
+                  - perms (int): entry permissions
+        """
+        for directory in directories:
+            if directory['id'] not in self._directories:
+                self._directories[directory['id']] = copy.deepcopy(directory)
+
+    def directory_missing(self, directory_ids):
+        """List directories missing from storage
+
+        Args:
+            directory_ids (iterable): an iterable of directory ids
+
+        Yields:
+            missing directory ids
+
+        """
+        for id in directory_ids:
+            if id not in self._directories:
+                yield id
+
+    def _join_dentry_to_content(self, dentry):
+        keys = [
+            'status',
+            'sha1',
+            'sha1_git',
+            'sha256',
+            'length',
+        ]
+        # FIXME: should really be None
+        ret = {key: None for key in keys}
+        ret.update(dentry)
+        if ret['type'] == 'file':
+            content = self.content_find({'sha1_git': ret['target']})
+            if content:
+                for key in keys:
+                    ret[key] = content[key]
+        return ret
+
+    def directory_ls(self, directory_id):
+        """Get entries for one directory.
+
+        Args:
+            directory_id: the directory to list entries from.
+
+        Yields:
+            the directory entries, joined with the matching content metadata
+            when the entry is a file.
+
+        """
+        if directory_id in self._directories:
+            for entry in self._directories[directory_id]['entries']:
+                ret = self._join_dentry_to_content(entry)
+                ret['dir_id'] = directory_id
+                yield ret
+
+    def directory_entry_get_by_path(self, directory, paths):
+        """Get the directory entry (either file or dir) from directory with
+        path.
+
+        Args:
+            directory: sha1 of the top level directory
+            paths: path to lookup from the top level directory. From left
+                (top) to right (bottom).
+
+        Returns:
+            The corresponding directory entry if found, None otherwise.
+
+        """
+        if not paths:
+            return
+
+        contents = list(self.directory_ls(directory))
+
+        if not contents:
+            return
+
+        def _get_entry(entries, name):
+            for entry in entries:
+                if entry['name'] == name:
+                    return entry
+
+        first_item = _get_entry(contents, paths[0])
+
+        if len(paths) == 1:
+            return first_item
+
+        if not first_item or first_item['type'] != 'dir':
+            return
+
+        return self.directory_entry_get_by_path(
+            first_item['target'], paths[1:])
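+    # Illustrative sketch for the directory API above (hypothetical ids);
+    # entries are listed with directory_ls() and resolved with
+    # directory_entry_get_by_path():
+    #
+    #   directory = {'id': b'<dir id>', 'entries': [
+    #       {'name': b'README', 'type': 'file',
+    #        'target': b'<content sha1_git>', 'perms': 0o100644}]}
+    #   storage.directory_add([directory])
+    #   list(storage.directory_ls(directory['id']))
+    #   storage.directory_entry_get_by_path(directory['id'], [b'README'])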
+    def revision_add(self, revisions):
+        """Add revisions to the storage
+
+        Args:
+            revisions (iterable): iterable of dictionaries representing the
+                individual revisions to add. Each dict has the following
+                keys:
+
+                - id (sha1_git): id of the revision to add
+                - date (dict or datetime.datetime): date the revision was
+                  written; normalized with
+                  :func:`swh.model.identifiers.normalize_timestamp`
+                - committer_date (dict or datetime.datetime): date the
+                  revision got added to the origin
+                - type (one of 'git', 'tar'): type of the revision added
+                - directory (sha1_git): the directory the revision points at
+                - message (bytes): the message associated with the revision
+                - author (dict): the name and email of the revision author
+                - committer (dict): the name and email of the revision
+                  committer
+                - metadata (jsonb): extra information as dictionary
+                - synthetic (bool): revision's nature (tarball, directory
+                  creates synthetic revision)
+                - parents (list of sha1_git): the parents of this revision
+
+        """
+        for revision in revisions:
+            if revision['id'] not in self._revisions:
+                self._revisions[revision['id']] = rev = copy.deepcopy(revision)
+                # the in-memory storage does not deduplicate persons, so a
+                # placeholder person id is used instead
+                if rev['author']:
+                    rev['author']['id'] = 42
+                if rev['committer']:
+                    rev['committer']['id'] = 42
+                rev['date'] = normalize_timestamp(rev.get('date'))
+                rev['committer_date'] = normalize_timestamp(
+                    rev.get('committer_date'))
+
+    def revision_missing(self, revision_ids):
+        """List revisions missing from storage
+
+        Args:
+            revision_ids (iterable): revision ids
+
+        Yields:
+            missing revision ids
+
+        """
+        for id in revision_ids:
+            if id not in self._revisions:
+                yield id
+
+    def revision_get(self, revision_ids):
+        for id in revision_ids:
+            yield copy.deepcopy(self._revisions.get(id))
+
+    def revision_log(self, revision_ids, limit=None):
+        """Fetch revision entries from the given root revisions.
+
+        Args:
+            revision_ids: array of root revisions to lookup
+            limit: limitation on the output result. Default to None.
+
+        Yields:
+            revision entries, walking from the given roots through their
+            parents.
+
+        """
+        seen = set()
+        to_visit = list(revision_ids)
+        while to_visit and (limit is None or len(seen) < limit):
+            id = to_visit.pop(0)
+            if id in seen or id not in self._revisions:
+                continue
+            seen.add(id)
+            rev = self._revisions[id]
+            yield copy.deepcopy(rev)
+            to_visit.extend(rev['parents'])
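+    # Illustrative sketch for the revision API above (hypothetical ids);
+    # dates may be given as normalized dicts or as datetimes:
+    #
+    #   revision = {
+    #       'id': b'<rev id>', 'type': 'git', 'directory': b'<dir id>',
+    #       'date': {'timestamp': 1234567890, 'offset': 0},
+    #       'committer_date': {'timestamp': 1234567890, 'offset': 0},
+    #       'author': {'name': b'Ada', 'email': b'ada@example.com'},
+    #       'committer': {'name': b'Ada', 'email': b'ada@example.com'},
+    #       'message': b'initial commit', 'metadata': None,
+    #       'synthetic': False, 'parents': [],
+    #   }
+    #   storage.revision_add([revision])
+    #   list(storage.revision_log([revision['id']]))  # walks parent links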
+ + """ + ids = [id for id in revision_ids if id in self._revisions] + ids # TODO + + def snapshot_add(self, origin, visit, snapshot): + """Add a snapshot for the given origin/visit couple + + Args: + origin (int): id of the origin + visit (int): id of the visit + snapshot (dict): the snapshot to add to the visit, containing the + following keys: + + - **id** (:class:`bytes`): id of the snapshot + - **branches** (:class:`dict`): branches the snapshot contains, + mapping the branch name (:class:`bytes`) to the branch target, + itself a :class:`dict` (or ``None`` if the branch points to an + unknown object) + + - **target_type** (:class:`str`): one of ``content``, + ``directory``, ``revision``, ``release``, + ``snapshot``, ``alias`` + - **target** (:class:`bytes`): identifier of the target + (currently a ``sha1_git`` for all object kinds, or the name + of the target branch for aliases) + """ + snapshot_id = snapshot['id'] + snapshot = { + 'origin': origin, + 'visit': visit, + 'id': snapshot_id, + 'branches': copy.deepcopy(snapshot['branches']), + '_sorted_branch_names': sorted(snapshot['branches']) + } + if snapshot_id not in self._snapshots: + self._snapshots[snapshot_id] = snapshot + self._origin_visits[visit]['snapshot_id'] = snapshot_id + + def snapshot_get(self, snapshot_id): + """Get the content, possibly partial, of a snapshot with the given id + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + .. warning:: At most 1000 branches contained in the snapshot will be + returned for performance reasons. In order to browse the whole + set of branches, the method :meth:`snapshot_get_branches` + should be used instead. + + Args: + snapshot_id (bytes): identifier of the snapshot + Returns: + dict: a dict with three keys: + * **id**: identifier of the snapshot + * **branches**: a dict of branches contained in the snapshot + whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has less than 1000 + branches. + """ + return self.snapshot_get_branches(snapshot_id) + + def snapshot_get_by_origin_visit(self, origin, visit): + """Get the content, possibly partial, of a snapshot for the given origin visit + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + .. warning:: At most 1000 branches contained in the snapshot will be + returned for performance reasons. In order to browse the whole + set of branches, the method :meth:`snapshot_get_branches` + should be used instead. + + Args: + origin (int): the origin identifier + visit (int): the visit identifier + Returns: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: + * **id**: identifier of the snapshot + * **branches**: a dict of branches contained in the snapshot + whose keys are the branches' names. + * **next_branch**: the name of the first branch not returned + or :const:`None` if the snapshot has less than 1000 + branches. + + """ + if visit not in self._origin_visits: + return None + snapshot_id = self._origin_visits[visit]['snapshot_id'] + if snapshot_id: + return self.snapshot_get(snapshot_id) + else: + return None + + def snapshot_get_latest(self, origin, allowed_statuses=None): + """Get the content, possibly partial, of the latest snapshot for the + given origin, optionally only from visits that have one of the given + allowed_statuses + + The branches of the snapshot are iterated in the lexicographical + order of their names. + + .. 
+    def snapshot_get_latest(self, origin, allowed_statuses=None):
+        """Get the content, possibly partial, of the latest snapshot for the
+        given origin, optionally only from visits that have one of the given
+        allowed_statuses
+
+        The branches of the snapshot are iterated in the lexicographical
+        order of their names.
+
+        .. warning:: At most 1000 branches contained in the snapshot will be
+            returned for performance reasons. In order to browse the whole
+            set of branches, the method :meth:`snapshot_get_branches`
+            should be used instead.
+
+        Args:
+            origin (int): the origin identifier
+            allowed_statuses (list of str): list of visit statuses considered
+                to find the latest snapshot for the visit. For instance,
+                ``allowed_statuses=['full']`` will only consider visits that
+                have successfully run to completion.
+        Returns:
+            dict: None if no snapshot is found;
+              a dict with three keys otherwise:
+                * **id**: identifier of the snapshot
+                * **branches**: a dict of branches contained in the snapshot
+                  whose keys are the branches' names.
+                * **next_branch**: the name of the first branch not returned
+                  or :const:`None` if the snapshot has less than 1000
+                  branches.
+        """
+        if allowed_statuses is None:
+            visits_ts = list(itertools.chain.from_iterable(
+                self._origins[origin]['visits_ts'].values()))
+        else:
+            last_visits = self._origins[origin]['visits_ts']
+            visits_ts = []
+            for status in allowed_statuses:
+                visits_ts.extend(last_visits[status])
+
+        for visit_ts in sorted(visits_ts, reverse=True):
+            visit_id = self._origin_visit_key(
+                {'origin': origin, 'ts': visit_ts})
+            snapshot_id = self._origin_visits[visit_id]['snapshot_id']
+            snapshot = self.snapshot_get(snapshot_id)
+            if snapshot:
+                return snapshot
+
+        return None
+
+    def snapshot_get_branches(self, snapshot_id, branches_from=b'',
+                              branches_count=1000, target_types=None):
+        """Get the content, possibly partial, of a snapshot with the given id
+
+        The branches of the snapshot are iterated in the lexicographical
+        order of their names.
+
+        Args:
+            snapshot_id (bytes): identifier of the snapshot
+            branches_from (bytes): optional parameter used to skip branches
+                whose name is lesser than it before returning them
+            branches_count (int): optional parameter used to restrain
+                the amount of returned branches
+            target_types (list): optional parameter used to filter the
+                target types of branch to return (possible values that can be
+                contained in that list are `'content', 'directory',
+                'revision', 'release', 'snapshot', 'alias'`)
+        Returns:
+            dict: None if the snapshot does not exist;
+              a dict with three keys otherwise:
+                * **id**: identifier of the snapshot
+                * **branches**: a dict of branches contained in the snapshot
+                  whose keys are the branches' names.
+                * **next_branch**: the name of the first branch not returned
+                  or :const:`None` if the snapshot has less than
+                  `branches_count` branches after `branches_from` included.
+
+        """
+        snapshot = self._snapshots.get(snapshot_id, None)
+        if snapshot is None:
+            return None
+        sorted_branch_names = snapshot['_sorted_branch_names']
+        from_index = bisect.bisect_left(
+            sorted_branch_names, branches_from)
+        if target_types:
+            next_branch = None
+            branches = {}
+            for branch_name in sorted_branch_names[from_index:]:
+                branch = snapshot['branches'][branch_name]
+                if branch and branch['target_type'] in target_types:
+                    if len(branches) < branches_count:
+                        branches[branch_name] = branch
+                    else:
+                        next_branch = branch_name
+                        break
+        else:
+            # As there is no 'target_types', we can do that much faster
+            to_index = from_index + branches_count
+            returned_branch_names = sorted_branch_names[from_index:to_index]
+            branches = {branch_name: snapshot['branches'][branch_name]
+                        for branch_name in returned_branch_names}
+            if to_index >= len(sorted_branch_names):
+                next_branch = None
+            else:
+                next_branch = sorted_branch_names[to_index]
+        return {
+            'id': snapshot_id,
+            'branches': branches,
+            'next_branch': next_branch,
+        }
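+    # Illustrative sketch tying visits and snapshots together; `origin_key`
+    # and `visit` come from origin_add_one()/origin_visit_add() defined
+    # below, and the ids are hypothetical placeholders:
+    #
+    #   storage.snapshot_add(origin_key, visit['visit'], {
+    #       'id': b'<snp id>',
+    #       'branches': {b'master': {'target': b'<rev id>',
+    #                                'target_type': 'revision'}}})
+    #   part = storage.snapshot_get_branches(b'<snp id>', branches_count=1)
+    #   part['next_branch']  # first branch not returned, or None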
+ """ + snapshot = self._snapshots.get(snapshot_id, None) + if snapshot is None: + return None + sorted_branch_names = snapshot['_sorted_branch_names'] + from_index = bisect.bisect_left( + sorted_branch_names, branches_from) + if target_types: + next_branch = None + branches = {} + for branch_name in sorted_branch_names[from_index:]: + branch = snapshot['branches'][branch_name] + if branch and branch['target_type'] in target_types: + if len(branches) < branches_count: + branches[branch_name] = branch + else: + next_branch = branch_name + break + else: + # As there is no 'target_types', we can do that much faster + to_index = from_index + branches_count + returned_branch_names = sorted_branch_names[from_index:to_index] + branches = {branch_name: snapshot['branches'][branch_name] + for branch_name in returned_branch_names} + if to_index >= len(sorted_branch_names): + next_branch = None + else: + next_branch = sorted_branch_names[to_index] + return { + 'id': snapshot_id, + 'branches': branches, + 'next_branch': next_branch, + } + + def origin_add_one(self, origin): + """Add origin to the storage + + Args: + origin: dictionary representing the individual origin to add. This + dict has the following keys: + + - type (FIXME: enum TBD): the origin type ('git', 'wget', ...) + - url (bytes): the url the origin points to + + Returns: + the id of the added origin, or of the identical one that already + exists. + + """ + assert 'id' not in origin + assert 'visits_ts' not in origin + key = self._origin_key(origin) + origin['visits_ts'] = defaultdict(set) + if key not in self._origins: + self._origins[key] = copy.deepcopy(origin) + return key + + def origin_visit_add(self, origin, ts): + """Add an origin_visit for the origin at ts with status 'ongoing'. + + Args: + origin: Visited Origin id + ts: timestamp of such visit + + Returns: + dict: dictionary with keys origin and visit where: + + - origin: origin identifier + - visit: the visit identifier for the new visit occurrence + + """ + if isinstance(ts, str): + ts = dateutil.parser.parse(ts) + + status = 'ongoing' + + visit = { + 'origin': origin, + 'ts': ts, + 'status': status, + 'snapshot_id': None, + } + key = self._origin_visit_key(visit) + if key not in self._origin_visits: + self._origin_visits[key] = copy.deepcopy(visit) + self._origins[origin]['visits_ts'][status].add(ts) + + return { + 'origin': visit['origin'], + 'visit': key, + } + + def origin_visit_update(self, origin, visit_id, status, metadata=None): + """Update an origin_visit's status. + + Args: + origin: Visited Origin id + visit_id: Visit's id + status: Visit's new status + metadata: Data associated to the visit + + Returns: + None + + """ + old_status = self._origin_visits[visit_id]['status'] + self._origins[origin]['visits_ts'][old_status].remove(visit_id.ts) + self._origins[origin]['visits_ts'][status].add(visit_id.ts) + self._origin_visits[visit_id].update({ + 'status': status, + 'metadata': metadata}) + + def tool_add(self, tools): + """Add new tools to the storage. + + Args: + tools (iterable of :class:`dict`): Tool information to add to + storage. Each tool is a :class:`dict` with the following keys: + + - name (:class:`str`): name of the tool + - version (:class:`str`): version of the tool + - configuration (:class:`dict`): configuration of the tool, + must be json-encodable + + Returns: + `iterable` of :class:`dict`: All the tools inserted in storage + (including the internal ``id``). The order of the list is not + guaranteed to match the order of the initial list. 
+ + """ + inserted = [] + for tool in tools: + key = self._tool_key(tool) + assert 'id' not in tool + record = copy.deepcopy(tool) + record['id'] = key # TODO: remove this + if key not in self._tools: + self._tools[key] = record + inserted.append(copy.deepcopy(self._tools[key])) + + return inserted + + def tool_get(self, tool): + """Retrieve tool information. + + Args: + tool (dict): Tool information we want to retrieve from storage. + The dicts have the same keys as those used in :func:`tool_add`. + + Returns: + dict: The full tool information if it exists (``id`` included), + None otherwise. + + """ + return self._tools.get(self._tool_key(tool), None) diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -767,7 +767,8 @@ origin (int): the origin identifier visit (int): the visit identifier Returns: - dict: a dict with three keys: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: * **id**: identifier of the snapshot * **branches**: a dict of branches contained in the snapshot whose keys are the branches' names. @@ -853,7 +854,8 @@ contained in that list are `'content', 'directory', 'revision', 'release', 'snapshot', 'alias'`) Returns: - dict: a dict with three keys: + dict: None if the snapshot does not exist; + a dict with three keys otherwise: * **id**: identifier of the snapshot * **branches**: a dict of branches contained in the snapshot whose keys are the branches' names. @@ -910,7 +912,6 @@ - origin: origin identifier - visit: the visit identifier for the new visit occurrence - - ts (datetime.DateTime): the visit date """ if isinstance(ts, str): @@ -1247,8 +1248,8 @@ - configuration (:class:`dict`): configuration of the tool, must be json-encodable - Returns: - `iterable` of :class:`dict`: All the tools inserted in storage + Yields: + :class:`dict`: All the tools inserted in storage (including the internal ``id``). The order of the list is not guaranteed to match the order of the initial list. diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -10,11 +10,12 @@ from swh.core.tests.server_testing import ServerTestFixture from swh.storage.api.client import RemoteStorage from swh.storage.api.server import app -from swh.storage.tests.test_storage import CommonTestStorage +from swh.storage.tests.test_storage import CommonTestStorage, \ + StorageTestDbFixture class TestRemoteStorage(CommonTestStorage, ServerTestFixture, - unittest.TestCase): + StorageTestDbFixture, unittest.TestCase): """Test the remote storage API. This class doesn't define any tests as we want identical diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py new file mode 100644 --- /dev/null +++ b/swh/storage/tests/test_in_memory.py @@ -0,0 +1,24 @@ +# Copyright (C) 2018 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information +import unittest + +import pytest + +from swh.storage.in_memory import Storage + +from swh.storage.tests.test_storage import CommonTestStorage + + +@pytest.mark.xfail +class TestInMemoryStorage(CommonTestStorage, unittest.TestCase): + """Test the in-memory storage API + + This class doesn't define any tests as we want identical + functionality between local and remote storage. 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -767,7 +767,8 @@
             origin (int): the origin identifier
             visit (int): the visit identifier
         Returns:
-            dict: a dict with three keys:
+            dict: None if the snapshot does not exist;
+              a dict with three keys otherwise:
                 * **id**: identifier of the snapshot
                 * **branches**: a dict of branches contained in the snapshot
                   whose keys are the branches' names.
@@ -853,7 +854,8 @@
                 contained in that list are `'content', 'directory',
                 'revision', 'release', 'snapshot', 'alias'`)
         Returns:
-            dict: a dict with three keys:
+            dict: None if the snapshot does not exist;
+              a dict with three keys otherwise:
                 * **id**: identifier of the snapshot
                 * **branches**: a dict of branches contained in the snapshot
                   whose keys are the branches' names.
@@ -910,7 +912,6 @@
 
             - origin: origin identifier
             - visit: the visit identifier for the new visit occurrence
-            - ts (datetime.DateTime): the visit date
 
         """
         if isinstance(ts, str):
@@ -1247,8 +1248,8 @@
             - configuration (:class:`dict`): configuration of the tool,
               must be json-encodable
 
-        Returns:
-            `iterable` of :class:`dict`: All the tools inserted in storage
+        Yields:
+            :class:`dict`: All the tools inserted in storage
             (including the internal ``id``). The order of the list is not
             guaranteed to match the order of the initial list.
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -10,11 +10,12 @@
 from swh.core.tests.server_testing import ServerTestFixture
 from swh.storage.api.client import RemoteStorage
 from swh.storage.api.server import app
-from swh.storage.tests.test_storage import CommonTestStorage
+from swh.storage.tests.test_storage import CommonTestStorage, \
+    StorageTestDbFixture
 
 
 class TestRemoteStorage(CommonTestStorage, ServerTestFixture,
-                        unittest.TestCase):
+                        StorageTestDbFixture, unittest.TestCase):
     """Test the remote storage API.
 
     This class doesn't define any tests as we want identical
diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py
new file mode 100644
--- /dev/null
+++ b/swh/storage/tests/test_in_memory.py
@@ -0,0 +1,24 @@
+# Copyright (C) 2018  The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+import unittest
+
+import pytest
+
+from swh.storage.in_memory import Storage
+
+from swh.storage.tests.test_storage import CommonTestStorage
+
+
+@pytest.mark.xfail
+class TestInMemoryStorage(CommonTestStorage, unittest.TestCase):
+    """Test the in-memory storage API
+
+    This class doesn't define any tests as we want identical
+    functionality between local and in-memory storage. All the tests are
+    therefore defined in CommonTestStorage.
+    """
+    def setUp(self):
+        super().setUp()
+        self.storage = Storage()
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -7,7 +7,6 @@
 import datetime
 import unittest
 from collections import defaultdict
-from operator import itemgetter
 from unittest.mock import Mock, patch
 
 import psycopg2
@@ -19,7 +18,7 @@
 
 
 @pytest.mark.db
-class BaseTestStorage(StorageTestFixture):
+class StorageTestDbFixture(StorageTestFixture):
     def setUp(self):
         super().setUp()
 
@@ -29,6 +28,15 @@
 
         self.maxDiff = None
 
+    def tearDown(self):
+        self.reset_storage_tables()
+        super().tearDown()
+
+
+class TestStorageData:
+    def setUp(self):
+        super().setUp()
+
         self.cont = {
             'data': b'42\n',
             'length': 3,
@@ -509,12 +517,8 @@
             'next_branch': None
         }
 
-    def tearDown(self):
-        self.reset_storage_tables()
-        super().tearDown()
-
 
-class CommonTestStorage(BaseTestStorage):
+class CommonTestStorage(TestStorageData):
     """Base class for Storage testing.
 
     This class is used as-is to test local storage (see TestLocalStorage
@@ -526,7 +530,6 @@
     class twice.
 
     """
-
     @staticmethod
     def normalize_entity(entity):
         entity = copy.deepcopy(entity)
@@ -677,7 +680,7 @@
         stored_data = list(self.storage.directory_ls(self.dir['id']))
 
         data_to_store = []
-        for ent in sorted(self.dir['entries'], key=itemgetter('name')):
+        for ent in self.dir['entries']:
             data_to_store.append({
                 'dir_id': self.dir['id'],
                 'type': ent['type'],
@@ -691,7 +694,7 @@
                 'length': None,
             })
 
-        self.assertEqual(data_to_store, stored_data)
+        self.assertCountEqual(data_to_store, stored_data)
 
         after_missing = list(self.storage.directory_missing([self.dir['id']]))
         self.assertEqual([], after_missing)
@@ -1842,7 +1845,8 @@
         self.assertIsNotNone(o_m1)
 
 
-class TestLocalStorage(CommonTestStorage, unittest.TestCase):
+class TestLocalStorage(CommonTestStorage, StorageTestDbFixture,
+                       unittest.TestCase):
     """Test the local storage"""
 
     # Can only be tested with local storage as you can't mock
@@ -1921,7 +1925,8 @@
         self.assertEqual(missing, [self.cont['sha1']])
 
 
-class AlteringSchemaTest(BaseTestStorage, unittest.TestCase):
+class AlteringSchemaTest(TestStorageData, StorageTestDbFixture,
+                         unittest.TestCase):
     """This class is dedicated for the rare case where the schema
     needs to be altered dynamically.