diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -25,15 +25,138 @@
 OriginVisitKey = collections.namedtuple('OriginVisitKey', 'origin date')
 
 
+class DuplicateKey(Exception):
+    pass
+
+
+_TABLE_GET_NO_DEFAULT = object()
+"""Unique object, to indicate no default is supplied to get()."""
+
+
+TYPES_WHITELIST = (int, float, str, bytes, tuple, frozenset,
+                   datetime.datetime, None.__class__)
+
+
+class _frozen_dict(tuple):
+    """Class to mark objects that should be unfrozen as dicts."""
+    pass
+
+
+def freeze(value):
+    """Recursively turn lists into tuples and dicts into _frozen_dict.
+
+    Tuples and other whitelisted values are returned unchanged; any
+    other type raises TypeError."""
+    if isinstance(value, list):
+        return tuple(map(freeze, value))
+    elif isinstance(value, dict):
+        # freeze values too, so nested lists and dicts become immutable
+        return _frozen_dict((k, freeze(v)) for (k, v) in value.items())
+    elif isinstance(value, TYPES_WHITELIST):
+        return value
+    else:
+        raise TypeError('Type of {} is not allowed in tables.'
+                        .format(value.__class__))
+
+
+def unfreeze(value):
+    """Undo freeze(): recursively turn _frozen_dict objects into dicts
+    and all other tuples (including frozen lists) into lists."""
+    if isinstance(value, _frozen_dict):
+        return {k: unfreeze(v) for (k, v) in value}
+    elif isinstance(value, tuple):
+        return list(map(unfreeze, value))
+    elif isinstance(value, TYPES_WHITELIST):
+        return value
+    else:
+        # raise rather than assert: asserts are stripped by python -O
+        raise TypeError('{!r} is not a valid frozen object.'.format(value))
+
+
+def freeze_args(arguments):
+    if arguments is None:
+        return None
+    res = {}
+    for (key, value) in arguments.items():
+        try:
+            res[key] = freeze(value)
+        except TypeError:
+            raise TypeError('Type of {} is not allowed in tables.'
+                            .format(key))
+    return res
+
+
+def unfreeze_record(rec):
+    if rec is None:
+        return None
+    d = {}
+    for (key, value) in rec._asdict().items():
+        d[key] = unfreeze(value)
+    return d
+
+
+class Table:
+    """In-memory table of records, indexed by their key columns.
+
+    Values are stored frozen (see freeze()) and get() returns freshly
+    unfrozen dicts."""
+    def __init__(self, name, keys, columns):
+        if not set(keys).isdisjoint(set(columns)):
+            raise ValueError('%r is not disjoint from %r' % (keys, columns))
+        if 'default' in keys:
+            raise ValueError('"default" is not a valid key name.')
+        self._keys = keys
+        self._columns = columns
+        self.Key = collections.namedtuple(name + 'Key', keys)
+        self.Record = collections.namedtuple(name + 'Record', keys + columns)
+        self._data = {}
+
+    def _make_key_from_record(self, record):
+        return self.Key(*(getattr(record, col) for col in self._keys))
+
+    def add(self, **kwargs):
+        record = self.Record(**freeze_args(kwargs))
+        key = self._make_key_from_record(record)
+        if key in self._data:
+            raise DuplicateKey(key)
+        self._data[key] = record
+        return key
+
+    def contains(self, **kwargs):
+        # freeze the key exactly like add() does, so lookups match
+        key = self.Key(**freeze_args(kwargs))
+        return key in self._data
+
+    def get(self, *, default=_TABLE_GET_NO_DEFAULT, **kwargs):
+        """Return the record with the given key, as a dict.
+
+        If there is no such record, return `default` (unchanged), or raise
+        KeyError when no default was given."""
+        key = self.Key(**freeze_args(kwargs))
+        if key in self._data:
+            return unfreeze_record(self._data[key])
+        elif default is _TABLE_GET_NO_DEFAULT:
+            raise KeyError(key)
+        else:
+            return default
+
+
 
 class Storage:
     def __init__(self):
         self._contents = {}
         self._contents_data = {}
         self._content_indexes = defaultdict(lambda: defaultdict(set))
-        self._directories = {}
-        self._revisions = {}
-        self._releases = {}
+        self._directories = Table('Directories', ['id'], ['entries'])
+        self._revisions = Table(
+            'Revisions', ['id'],
+            ['message', 'author', 'date', 'committer', 'committer_date',
+             'parents', 'type', 'directory', 'metadata', 'synthetic'])
+        self._releases = Table(
+            'Releases', ['id'],
+            ['name', 'author', 'date', 'target',
+             'target_type', 'message', 'synthetic'])
+
         self._snapshots = {}
         self._origins = {}
         self._origin_visits = {}
@@ -189,8 +312,8 @@
                       - perms (int): entry permissions
         """
         for directory in directories:
-            if directory['id'] not in self._directories:
-                self._directories[directory['id']] = copy.deepcopy(directory)
+            if not self._directories.contains(id=directory['id']):
+                self._directories.add(**directory)
                 self._objects[directory['id']].append(
                     ('directory', directory['id']))
 
@@ -205,7 +328,7 @@
 
         """
         for id in directory_ids:
-            if id not in self._directories:
+            if not self._directories.contains(id=id):
                 yield id
 
     def _join_dentry_to_content(self, dentry):
@@ -236,8 +359,8 @@
             List of entries for such directory.
 
         """
-        if directory_id in self._directories:
-            for entry in self._directories[directory_id]['entries']:
+        if self._directories.contains(id=directory_id):
+            for entry in self._directories.get(id=directory_id)['entries']:
                 ret = self._join_dentry_to_content(entry)
                 ret['dir_id'] = directory_id
                 yield ret
@@ -311,11 +434,12 @@
 
         """
         for revision in revisions:
-            if revision['id'] not in self._revisions:
-                self._revisions[revision['id']] = rev = copy.deepcopy(revision)
+            if not self._revisions.contains(id=revision['id']):
+                rev = revision.copy()
                 rev['date'] = normalize_timestamp(rev.get('date'))
                 rev['committer_date'] = normalize_timestamp(
                     rev.get('committer_date'))
+                self._revisions.add(**rev)
                 self._objects[revision['id']].append(
                     ('revision', revision['id']))
 
@@ -330,12 +454,12 @@
 
         """
         for id in revision_ids:
-            if id not in self._revisions:
+            if not self._revisions.contains(id=id):
                 yield id
 
     def revision_get(self, revision_ids):
         for id in revision_ids:
-            yield copy.deepcopy(self._revisions.get(id))
+            yield self._revisions.get(id=id, default=None)
 
     def _get_parent_revs(self, rev_id, seen, limit):
         if limit and len(seen) >= limit:
@@ -343,8 +467,9 @@
         if rev_id in seen:
             return
         seen.add(rev_id)
-        yield self._revisions[rev_id]
-        for parent in self._revisions[rev_id]['parents']:
+        rev = self._revisions.get(id=rev_id)
+        yield rev
+        for parent in rev['parents']:
             yield from self._get_parent_revs(parent, seen, limit)
 
     def revision_log(self, revision_ids, limit=None):
@@ -400,7 +525,8 @@
             rel['date'] = normalize_timestamp(rel['date'])
             self._objects[rel['id']].append(
                 ('release', rel['id']))
-        self._releases.update((rel['id'], rel) for rel in releases)
+            if not self._releases.contains(id=rel['id']):
+                self._releases.add(**rel)
 
     def release_missing(self, releases):
         """List releases missing from storage
@@ -412,7 +538,9 @@
             a list of missing release ids
 
         """
-        yield from (rel for rel in releases if rel not in self._releases)
+        for rel_id in releases:
+            if not self._releases.contains(id=rel_id):
+                yield rel_id
 
     def release_get(self, releases):
         """Given a list of sha1, return the releases's information
@@ -427,7 +555,8 @@
             ValueError: if the keys does not match (url and type) nor id.
 
         """
-        yield from map(self._releases.__getitem__, releases)
+        for rel_id in releases:
+            yield self._releases.get(id=rel_id, default=None)
 
     def snapshot_add(self, origin, visit, snapshot):
         """Add a snapshot for the given origin/visit couple