diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -58,7 +58,7 @@ from swh.model.hashutil import DEFAULT_ALGORITHMS, hash_to_bytes, hash_to_hex from swh.storage.cassandra import CassandraStorage -from swh.storage.cassandra.model import BaseRow +from swh.storage.cassandra.model import BaseRow, ObjectCountRow from swh.storage.interface import ( ListOrder, PagedResult, @@ -220,12 +220,27 @@ ) +class InMemoryCqlRunner: + def __init__(self): + self._stat_counters = defaultdict(int) + + def increment_counter(self, object_type: str, nb: int): + self._stat_counters[object_type] += nb + + def stat_counters(self) -> Iterable[ObjectCountRow]: + for (object_type, count) in self._stat_counters.items(): + yield ObjectCountRow(partition_key=0, object_type=object_type, count=count) + + class InMemoryStorage(CassandraStorage): + _cql_runner: InMemoryCqlRunner # type: ignore + def __init__(self, journal_writer=None): self.reset() self.journal_writer = JournalWriter(journal_writer) def reset(self): + self._cql_runner = InMemoryCqlRunner() self._contents = {} self._content_indexes = defaultdict(lambda: defaultdict(set)) self._skipped_contents = {} @@ -311,6 +326,8 @@ self._contents[key] = attr.evolve(self._contents[key], data=None) content_add += 1 + self._cql_runner.increment_counter("content", content_add) + summary = { "content:add": content_add, } @@ -467,6 +484,8 @@ self._skipped_contents[key] = content summary["skipped_content:add"] += 1 + self._cql_runner.increment_counter("skipped_content", len(contents)) + return summary def skipped_content_add(self, content: List[SkippedContent]) -> Dict: @@ -499,6 +518,8 @@ self._directories[directory.id] = directory self._objects[directory.id].append(("directory", directory.id)) + self._cql_runner.increment_counter("directory", len(directories)) + return {"directory:add": count} def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]: @@ -576,6 +597,8 @@ self._objects[revision.id].append(("revision", revision.id)) count += 1 + self._cql_runner.increment_counter("revision", len(revisions)) + return {"revision:add": count} def revision_missing(self, revisions: List[Sha1Git]) -> Iterable[Sha1Git]: @@ -635,6 +658,8 @@ self._objects[rel.id].append(("release", rel.id)) self._releases[rel.id] = rel + self._cql_runner.increment_counter("release", len(to_add)) + return {"release:add": len(to_add)} def release_missing(self, releases: List[Sha1Git]) -> Iterable[Sha1Git]: @@ -661,6 +686,8 @@ self._objects[snapshot.id].append(("snapshot", snapshot.id)) count += 1 + self._cql_runner.increment_counter("snapshot", len(snapshots)) + return {"snapshot:add": count} def snapshot_missing(self, snapshots: List[Sha1Git]) -> Iterable[Sha1Git]: @@ -856,6 +883,8 @@ self.origin_add_one(origin) added += 1 + self._cql_runner.increment_counter("origin", len(origins)) + return {"origin:add": added} def origin_add_one(self, origin: Origin) -> str: @@ -904,6 +933,8 @@ ) all_visits.append(visit) + self._cql_runner.increment_counter("origin_visit", len(all_visits)) + return all_visits def _origin_visit_status_add_one(self, visit_status: OriginVisitStatus) -> None: @@ -1146,30 +1177,6 @@ else: return None - def stat_counters(self): - keys = ( - "content", - "directory", - "origin", - "origin_visit", - "person", - "release", - "revision", - "skipped_content", - "snapshot", - ) - stats = {key: 0 for key in keys} - stats.update( - collections.Counter( - obj_type - for (obj_type, obj_id) in itertools.chain(*self._objects.values()) - ) - ) - return stats - - def refresh_stat_counters(self): - pass - def raw_extrinsic_metadata_add(self, metadata: List[RawExtrinsicMetadata],) -> None: self.journal_writer.raw_extrinsic_metadata_add(metadata) for metadata_entry in metadata: diff --git a/swh/storage/replay.py b/swh/storage/replay.py --- a/swh/storage/replay.py +++ b/swh/storage/replay.py @@ -136,6 +136,6 @@ method(model_objs) elif object_type in ("directory", "revision", "release", "snapshot", "origin",): method = getattr(storage, object_type + "_add") - method(object_converter_fn[object_type](o) for o in objects) + method([object_converter_fn[object_type](o) for o in objects]) else: logger.warning("Received a series of %s, this should not happen", object_type)