diff --git a/swh/loader/git/loader.py b/swh/loader/git/loader.py
--- a/swh/loader/git/loader.py
+++ b/swh/loader/git/loader.py
@@ -8,8 +8,10 @@
 import logging
 import os
 import pickle
+import signal
 import sys
 from tempfile import SpooledTemporaryFile
+import tracemalloc
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Type
 
 import dulwich.client
@@ -38,6 +40,39 @@
 from .utils import HexBytes
 
 logger = logging.getLogger(__name__)
+tracemalloc_logger = logging.getLogger(__name__ + ".tracemalloc")
+
+
+def tracemalloc_enabled() -> bool:
+    """Return whether memory snapshots should be taken and logged.
+
+    ``tracemalloc.take_snapshot()`` raises RuntimeError unless tracing was started
+    (``tracemalloc.start()``, ``PYTHONTRACEMALLOC=1`` or ``-X tracemalloc``), and
+    snapshots are only ever logged at DEBUG level: skip the costly snapshot when
+    either condition does not hold.
+    """
+    return tracemalloc.is_tracing() and tracemalloc_logger.isEnabledFor(logging.DEBUG)
+
+
+def log_tracemalloc(msg: str, snapshot: tracemalloc.Snapshot):
+    """Log the 10 biggest allocation sites (grouped by source line) of snapshot."""
+    top_stats = snapshot.statistics("lineno")
+
+    tracemalloc_logger.debug("[ Top 10 memory users %s ]", msg)
+    for stat in top_stats[:10]:
+        tracemalloc_logger.debug(stat)
+
+
+def log_tracemalloc_diff(
+    msg: str, snapshot1: tracemalloc.Snapshot, snapshot2: tracemalloc.Snapshot
+):
+    """Log the 10 biggest allocation changes (grouped by source line) between
+    snapshot1 (older) and snapshot2 (newer)."""
+    top_stats = snapshot2.compare_to(snapshot1, "lineno")
+
+    tracemalloc_logger.debug("[ Top 10 differences after %s ]", msg)
+    for stat in top_stats[:10]:
+        tracemalloc_logger.debug(stat)
 
 
 class RepoRepresentation:
@@ -145,6 +180,44 @@
         self.symbolic_refs: Dict[bytes, HexBytes] = {}
         self.ref_object_types: Dict[bytes, Optional[TargetType]] = {}
 
+        # Baseline for the next do_tracemalloc() diff; stays None until a
+        # snapshot could be taken (see tracemalloc_enabled()).
+        self.tracemalloc_snapshot: Optional[tracemalloc.Snapshot] = None
+        if tracemalloc_enabled():
+            self.tracemalloc_snapshot = tracemalloc.take_snapshot()
+        if tracemalloc.is_tracing():
+            # Allow dumping memory statistics on demand (`kill -USR1 <pid>`).
+            # Only installed when tracing, so other SIGUSR1 handlers are left
+            # alone otherwise; signal.signal() raises ValueError outside the
+            # main thread, where on-demand dumps are simply unavailable.
+            try:
+                signal.signal(signal.SIGUSR1, self.tracemalloc_handler)
+            except ValueError:
+                tracemalloc_logger.debug(
+                    "Cannot install SIGUSR1 handler outside of the main thread"
+                )
+
+    def do_tracemalloc(self, msg: str):
+        """Log the top allocation sites now, and the change since the previous
+        snapshot, both tagged with msg.
+
+        No-op unless tracemalloc is tracing and the tracemalloc logger is
+        enabled for DEBUG.
+        """
+        if not tracemalloc_enabled():
+            return
+        tracemalloc_snapshot = tracemalloc.take_snapshot()
+        log_tracemalloc(msg, tracemalloc_snapshot)
+        if self.tracemalloc_snapshot is not None:
+            log_tracemalloc_diff(
+                msg, self.tracemalloc_snapshot, tracemalloc_snapshot,
+            )
+        self.tracemalloc_snapshot = tracemalloc_snapshot
+
+    def tracemalloc_handler(self, _signum, _frame):
+        """SIGUSR1 handler: log memory statistics on demand."""
+        self.do_tracemalloc("on_signal")
+
     def fetch_pack_from_origin(
         self,
         origin_url: str,
@@ -216,13 +289,17 @@
         # not support it and do not fetch any refs
         self.dumb = transport_url.startswith("http") and client.dumb
 
-        return FetchPackReturn(
+        ret = FetchPackReturn(
             remote_refs=utils.filter_refs(remote_refs),
             symbolic_refs=utils.filter_refs(symbolic_refs),
             pack_buffer=pack_buffer,
             pack_size=pack_size,
         )
 
+        self.do_tracemalloc("fetch_pack_from_origin")
+
+        return ret
+
     def prepare_origin_visit(self) -> None:
         self.visit_date = datetime.datetime.now(tz=datetime.timezone.utc)
         self.origin = Origin(url=self.origin_url)
@@ -357,6 +434,8 @@
                 count += 1
             logger.debug("packfile_read_count_%s=%s", object_type.decode(), count)
 
+        self.do_tracemalloc("iter_objects")
+
     def get_contents(self) -> Iterable[BaseContent]:
         """Format the blobs from the git repository as swh contents"""
         for raw_obj in self.iter_objects(b"blob"):