# Review notes (free-form preamble; patch tools generally ignore text that
# precedes the first "diff --git" header).
#
# * NOTE(review): as stored here, this patch appears to have lost its line
#   breaks and indentation (each physical line holds many diff lines), so it
#   will likely not apply as-is. Regenerate it from the source commit.
#
# * Summary:
#   - converters.py replaces the attr.evolve()-based conversions with a dict
#     subclass, CassObject, built from attr.asdict(..., recurse=False). This
#     affects revision_to_db and release_to_db.
#   - Model objects are now rebuilt from keyword arguments, via
#     revision_from_db(**kwargs) and release_from_db(**kwargs).
#   - storage.py callers now pass **row._asdict() directly and index
#     revobject['parents'] / revobject['id'].
#
# * Defect in CassObject: `__getattr__ = dict.__getitem__` raises KeyError,
#   not AttributeError, when the key is missing. As a result, all of these
#   raise KeyError instead of falling back:
#   - hasattr(obj, name)
#   - getattr(obj, name, default)
#   - copy.deepcopy(obj), which probes __deepcopy__ via getattr with a
#     default.
#   Suggested replacement:
#
#       class CassObject(dict):
#           def __getattr__(self, key):
#               try:
#                   return self[key]
#               except KeyError:
#                   raise AttributeError(key) from None
#
#   This adds lines to the second converters.py hunk. The new-side line
#   count in its "-19,51 +19,44" header must be updated accordingly.
#
# * revision_to_db deep-copies the shallow attr.asdict() result. This keeps
#   the input Revision's metadata unmodified while extra_headers is
#   rewritten, as the patch's own comment explains.
#
# * storage.py: `if revobject:` tests a CassObject built from every Revision
#   field. Presumably it is never empty, so the check looks redundant --
#   TODO confirm.
#
# * Style nit: `from copy import deepcopy` is added after the third-party
#   `import attr`. PEP 8 places stdlib imports (copy, json) in their own
#   group before third-party ones.
diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -3,11 +3,11 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import copy import json import attr +from copy import deepcopy from typing import Dict from swh.model.model import ( @@ -19,51 +19,44 @@ from .common import Row -def revision_to_db(revision: Revision) -> Revision: +class CassObject(dict): + __getattr__ = dict.__getitem__ + + +def revision_to_db(revision: Revision) -> CassObject: + # we use a deepcopy of the dict because we do not want to recurse the + # Model->dict conversion (to keep Timestamp & al. entities), BUT we do not + # want to modify original metadata (embedded in the Model entity), so we + # non-recursively convert it as a dict but make a deep copy. + db_revision = CassObject(deepcopy(attr.asdict(revision, recurse=False))) metadata = revision.metadata if metadata and 'extra_headers' in metadata: - metadata = copy.deepcopy(metadata) - metadata['extra_headers'] = git_headers_to_db( + db_revision['metadata']['extra_headers'] = git_headers_to_db( metadata['extra_headers']) + db_revision['metadata'] = json.dumps(db_revision['metadata']) + db_revision['type'] = db_revision['type'].value + return db_revision - revision = attr.evolve( - revision, - type=revision.type.value, - metadata=json.dumps(metadata), - ) - - return revision - -def revision_from_db(revision) -> Revision: - metadata = json.loads(revision.metadata) +def revision_from_db(**kwargs) -> Revision: + kwargs['metadata'] = metadata = json.loads(kwargs['metadata']) if metadata and 'extra_headers' in metadata: extra_headers = db_to_git_headers( metadata['extra_headers']) metadata['extra_headers'] = extra_headers - rev = attr.evolve( - revision, - type=RevisionType(revision.type), - metadata=metadata, - ) - - return rev + kwargs['type'] = 
RevisionType(kwargs['type']) + return Revision(**kwargs) -def release_to_db(release: Release) -> Release: - release = attr.evolve( - release, - target_type=release.target_type.value, - ) - return release +def release_to_db(release: Release) -> CassObject: + db_release = CassObject(attr.asdict(release, recurse=False)) + db_release['target_type'] = release.target_type.value + return db_release -def release_from_db(release: Release) -> Release: - release = attr.evolve( - release, - target_type=ObjectType(release.target_type), - ) - return release +def release_from_db(**kwargs) -> Release: + kwargs['target_type'] = ObjectType(kwargs['target_type']) + return Release(**kwargs) def row_to_content_hashes(row: Row) -> Dict[str, bytes]: diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -414,22 +414,21 @@ # Filter-out revisions already in the database missing = self.revision_missing([rev.id for rev in revisions]) revisions = [rev for rev in revisions if rev.id in missing] - self.journal_writer.revision_add(revisions) for revision in revisions: - revision = revision_to_db(revision) - if revision: + revobject = revision_to_db(revision) + if revobject: # Add parents first - for (rank, parent) in enumerate(revision.parents): + for (rank, parent) in enumerate(revobject['parents']): self._cql_runner.revision_parent_add_one( - revision.id, rank, parent) + revobject['id'], rank, parent) # Then write the main revision row. # Writing this after all parents were written ensures that # read endpoints don't return a partial view while writing # the parents - self._cql_runner.revision_add_one(revision) + self._cql_runner.revision_add_one(revobject) return {'revision:add': len(revisions)} @@ -448,10 +447,7 @@ # parent_rank is the clustering key, so results are already # sorted by rank. 
parents = [row.parent_id for row in parent_rows] - - rev = Revision(**row._asdict(), parents=parents) - - rev = revision_from_db(rev) + rev = revision_from_db(**row._asdict(), parents=parents) revs[rev.id] = rev.to_dict() for rev_id in revisions: @@ -487,8 +483,8 @@ if short: yield (row.id, parents) else: - rev = revision_from_db(Revision( - **row._asdict(), parents=parents)) + rev = revision_from_db( + **row._asdict(), parents=parents) yield rev.to_dict() yield from self._get_parent_revs(parents, seen, limit, short) @@ -511,8 +507,7 @@ for release in releases: if release: - release = release_to_db(release) - self._cql_runner.release_add_one(release) + self._cql_runner.release_add_one(release_to_db(release)) return {'release:add': len(missing)} @@ -523,8 +518,7 @@ rows = self._cql_runner.release_get(releases) rels = {} for row in rows: - release = Release(**row._asdict()) - release = release_from_db(release) + release = release_from_db(**row._asdict()) rels[row.id] = release.to_dict() for rel_id in releases: