diff --git a/swh/storage/cassandra/converters.py b/swh/storage/cassandra/converters.py --- a/swh/storage/cassandra/converters.py +++ b/swh/storage/cassandra/converters.py @@ -8,12 +8,15 @@ import attr +from typing import Dict + from swh.model.model import ( RevisionType, ObjectType, Revision, Release, ) - +from swh.model.hashutil import DEFAULT_ALGORITHMS from ..converters import git_headers_to_db, db_to_git_headers +from .common import Row def revision_to_db(revision: Revision) -> Revision: @@ -61,3 +64,13 @@ target_type=ObjectType(release.target_type), ) return release + + +def row_to_content_hashes(row: Row) -> Dict[str, bytes]: + """Convert cassandra row to a content hashes + + """ + hashes = {} + for algo in DEFAULT_ALGORITHMS: + hashes[algo] = getattr(row, algo) + return hashes diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py --- a/swh/storage/cassandra/storage.py +++ b/swh/storage/cassandra/storage.py @@ -26,6 +26,7 @@ from .common import TOKEN_BEGIN, TOKEN_END from .converters import ( revision_to_db, revision_from_db, release_to_db, release_from_db, + row_to_content_hashes, ) from .cql import CqlRunner from .schema import HASH_ALGORITHMS @@ -93,7 +94,11 @@ algo, content.get_hash(algo)) if len(pks) > 1: # There are more than the one we just inserted. - raise HashCollision(algo, content.get_hash(algo), pks) + colliding_content_hashes = [ + row_to_content_hashes(pk) for pk in pks + ] + raise HashCollision( + algo, content.get_hash(algo), colliding_content_hashes) summary = { 'content:add': content_add, diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -92,7 +92,9 @@ hash_ = content.get_hash(algorithm) if hash_ in self._content_indexes[algorithm]\ and (algorithm not in {'blake2s256', 'sha256'}): - raise HashCollision(algorithm, hash_, key) + colliding_content_hashes = [content.hashes()] + raise HashCollision( + algorithm, hash_, colliding_content_hashes) for algorithm in DEFAULT_ALGORITHMS: hash_ = content.get_hash(algorithm) self._content_indexes[algorithm][hash_].add(key) diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -33,7 +33,9 @@ from .exc import StorageArgumentException, StorageDBError from .algos import diff from .metrics import timed, send_metric, process_metrics -from .utils import get_partition_bounds_bytes +from .utils import ( + get_partition_bounds_bytes, extract_collision_hash +) from .writer import JournalWriter @@ -158,14 +160,27 @@ except psycopg2.IntegrityError as e: if e.diag.sqlstate == '23505' and \ e.diag.table_name == 'content': - constraint_to_hash_name = { - 'content_pkey': 'sha1', - 'content_sha1_git_idx': 'sha1_git', - 'content_sha256_idx': 'sha256', + message_detail = e.diag.message_detail + if message_detail: + hash_name, hash_id = extract_collision_hash(message_detail) + collision_contents_hashes = [ + c.hashes() for c in content + if c.get_hash(hash_name) == hash_id + ] + else: + constraint_to_hash_name = { + 'content_pkey': 'sha1', + 'content_sha1_git_idx': 'sha1_git', + 'content_sha256_idx': 'sha256', } - colliding_hash_name = constraint_to_hash_name \ - .get(e.diag.constraint_name) - raise HashCollision(colliding_hash_name) from None + hash_name = constraint_to_hash_name \ + .get(e.diag.constraint_name) + hash_id = None + collision_contents_hashes = None + + raise HashCollision( + hash_name, hash_id, collision_contents_hashes + ) from None else: raise diff --git a/swh/storage/tests/test_cassandra_converters.py b/swh/storage/tests/test_cassandra_converters.py new file mode 100644 --- /dev/null +++ b/swh/storage/tests/test_cassandra_converters.py @@ -0,0 +1,37 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from collections import namedtuple +from typing import List + +from swh.storage.cassandra import converters + +from swh.model.hashutil import DEFAULT_ALGORITHMS + + +# Test purposes +field_names: List[str] = list(DEFAULT_ALGORITHMS) +Row = namedtuple('Row', field_names) # type: ignore + + +def test_row_to_content_hashes(): + for row in [Row( + sha1=b'4\x972t\xcc\xefj\xb4\xdf\xaa\xf8e\x99y/\xa9\xc3\xfeF\x89', + sha1_git=b'\xd8\x1c\xc0q\x0e\xb6\xcf\x9e\xfd[\x92\n\x84S\xe1\xe0qW\xb6\xcd', # noqa + sha256=b'g6P\xf96\xcb;\n/\x93\xce\t\xd8\x1b\xe1\x07H\xb1\xb2\x03\xc1\x9e\x81v\xb4\xee\xfc\x19d\xa0\xcf:', # noqa + blake2s256=b"\xd5\xfe\x199We'\xe4,\xfdv\xa9EZ$2\xfe\x7fVf\x95dW}\xd9 bool: return n > 0 and n & (n-1) == 0 @@ -40,3 +44,25 @@ end = None if i == n-1 \ else (partition_size*(i+1)).to_bytes(nb_bytes, 'big') return (start, end) + + +def extract_collision_hash(error_message: str) -> Optional[Tuple[str, bytes]]: + """Utilities to extract the hash information from a hash collision error. + + Hash collision error message are of the form: + 'Key ()=([^)]+)\)=\(\\x(?P[a-f0-9]+)\) \w*' + result = re.match(pattern, error_message) + if result: + hash_type = result.group('type') + hash_id = result.group('id') + return hash_type, hash_to_bytes(hash_id) + return None