diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -57,6 +57,7 @@ ) from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex +from swh.storage.cassandra import CassandraStorage from swh.storage.cassandra.model import BaseRow from swh.storage.interface import ( ListOrder, @@ -219,7 +220,7 @@ ) -class InMemoryStorage: +class InMemoryStorage(CassandraStorage): def __init__(self, journal_writer=None): self.reset() self.journal_writer = JournalWriter(journal_writer) @@ -505,29 +506,10 @@ if id not in self._directories: yield id - def _join_dentry_to_content(self, dentry: Dict[str, Any]) -> Dict[str, Any]: - keys = ( - "status", - "sha1", - "sha1_git", - "sha256", - "length", - ) - ret = dict.fromkeys(keys) - ret.update(dentry) - if ret["type"] == "file": - # TODO: Make it able to handle more than one content - contents = self.content_find({"sha1_git": ret["target"]}) - if contents: - content = contents[0] - for key in keys: - ret[key] = getattr(content, key) - return ret - def _directory_ls(self, directory_id, recursive, prefix=b""): if directory_id in self._directories: for entry in self._directories[directory_id].entries: - ret = self._join_dentry_to_content(entry.to_dict()) + ret = self._join_dentry_to_content(entry) ret["name"] = prefix + ret["name"] ret["dir_id"] = directory_id yield ret @@ -610,7 +592,7 @@ else: yield None - def _get_parent_revs( + def __get_parent_revs( self, rev_id: Sha1Git, seen: Set[Sha1Git], limit: Optional[int] ) -> Iterable[Dict[str, Any]]: if limit and len(seen) >= limit: @@ -620,14 +602,14 @@ seen.add(rev_id) yield self._revisions[rev_id].to_dict() for parent in self._revisions[rev_id].parents: - yield from self._get_parent_revs(parent, seen, limit) + yield from self.__get_parent_revs(parent, seen, limit) def revision_log( self, revisions: List[Sha1Git], limit: Optional[int] = None ) -> Iterable[Optional[Dict[str, Any]]]: seen: Set[Sha1Git] = set() for rev_id in revisions: - yield from self._get_parent_revs(rev_id, seen, limit) + yield from self.__get_parent_revs(rev_id, seen, limit) def revision_shortlog( self, revisions: List[Sha1Git], limit: Optional[int] = None diff --git a/swh/storage/writer.py b/swh/storage/writer.py --- a/swh/storage/writer.py +++ b/swh/storage/writer.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Iterable +from typing import Any, Dict, Iterable from attr import evolve @@ -57,7 +57,7 @@ contents = [evolve(item, data=None) for item in contents] self.write_additions("content", contents) - def content_update(self, contents: Iterable[Content]) -> None: + def content_update(self, contents: Iterable[Dict[str, Any]]) -> None: if self.journal: raise NotImplementedError("content_update is not supported by the journal.")