diff --git a/swh/storage/algos/revisions_walker.py b/swh/storage/algos/revisions_walker.py
--- a/swh/storage/algos/revisions_walker.py
+++ b/swh/storage/algos/revisions_walker.py
@@ -1,11 +1,41 @@
-# Copyright (C) 2018-2021  The Software Heritage developers
+# Copyright (C) 2018-2022  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from __future__ import annotations
+
 from abc import ABCMeta, abstractmethod
 from collections import deque
+import dataclasses
 import heapq
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, TypeVar
+
+from swh.model.model import Sha1Git
+
+if TYPE_CHECKING:
+    from swh.storage.interface import StorageInterface
+
+
+@dataclasses.dataclass
+class State:
+    """
+    Internal state of a revisions walker, as exported by
+    :meth:`RevisionsWalker.export_state` in order to resume a walk later.
+    """
+
+    # identifiers of the revisions already visited
+    done: Set[Sha1Git] = dataclasses.field(default_factory=set)
+    # revisions still to visit; the container layout depends on the
+    # walker: heap of (-commit_time, rev_id) tuples, queue or stack of ids
+    revs_to_visit: List = dataclasses.field(default_factory=list)
+    # last revision returned by the walker
+    last_rev: Optional[Dict] = None
+    # number of revisions returned so far
+    num_revs: int = 0
+    # identifiers of the revisions whose data were found missing
+    missing_revs: Set[Sha1Git] = dataclasses.field(default_factory=set)
+
 
 _revs_walker_classes = {}
 
@@ -18,6 +48,9 @@
         return newclass
 
 
+TWalker = TypeVar("TWalker", bound="RevisionsWalker")
+
+
 class RevisionsWalker(metaclass=_RevisionsWalkerMetaClass):
     """
     Abstract base class encapsulating the logic to walk across
@@ -48,32 +81,27 @@
     :func:`get_revisions_walker`.
 
     Args:
-        storage (swh.storage.interface.StorageInterface): instance of swh storage
-            (either local or remote)
-        rev_start (bytes): a revision identifier
-        max_revs (Optional[int]): maximum number of revisions to return
-        state (Optional[dict]): previous state of that revisions walker
+        storage: instance of swh storage (either local or remote)
+        rev_start: a revision identifier
+        max_revs: maximum number of revisions to return
+        state: previous state of that revisions walker
     """
 
-    def __init__(self, storage, rev_start, max_revs=None, state=None):
-        self._revs_to_visit = []
-        self._done = set()
-        self._revs = {}
-        self._last_rev = None
-        self._num_revs = 0
+    def __init__(
+        self,
+        storage: StorageInterface,
+        rev_start: Sha1Git,
+        max_revs: Optional[int] = None,
+        state: Optional[State] = None,
+    ):
+        self._revs: Dict[Sha1Git, Dict] = {}
         self._max_revs = max_revs
-        self._missing_revs = set()
-        if state:
-            self._revs_to_visit = state["revs_to_visit"]
-            self._done = state["done"]
-            self._last_rev = state["last_rev"]
-            self._num_revs = state["num_revs"]
-            self._missing_revs = state["missing_revs"]
+        self._state = state or State()
         self.storage = storage
         self.process_rev(rev_start)
 
     @abstractmethod
-    def process_rev(self, rev_id):
+    def process_rev(self, rev_id: Sha1Git) -> None:
         """
         Abstract method whose purpose is to process a newly visited revision
         during the walk.
@@ -82,25 +110,21 @@
         through a dfs on the revisions DAG).
 
         Args:
-            rev_id (bytes): the newly visited revision identifier
+            rev_id: the newly visited revision identifier
         """
         pass
 
     @abstractmethod
-    def get_next_rev_id(self):
+    def get_next_rev_id(self) -> Sha1Git:
         """
         Abstract method whose purpose is to return the next revision
         during the iteration.
         Derived classes must implement it according to the desired method
         to walk across the revisions history.
-
-        Returns:
-            dict: A dict describing a revision as returned by
-            :meth:`swh.storage.interface.StorageInterface.revision_get`
         """
         pass
 
-    def process_parent_revs(self, rev):
+    def process_parent_revs(self, rev: Dict) -> None:
         """
         Process the parents of a revision when it is iterated.
         The default implementation simply calls :meth:`process_rev`
@@ -113,7 +137,7 @@
         for parent_id in rev["parents"]:
             self.process_rev(parent_id)
 
-    def should_return(self, rev):
+    def should_return(self, rev: Dict) -> bool:
         """
         Filter out a revision to return if needed.
         Default implementation returns all iterated revisions.
@@ -127,7 +151,7 @@
         """
         return True
 
-    def is_finished(self):
+    def is_finished(self) -> bool:
         """
         Determine if the iteration is finished.
         This method is called at the beginning of each iteration loop.
@@ -135,13 +159,13 @@
         Returns:
             bool: Whether the iteration is finished
         """
-        if self._max_revs is not None and self._num_revs >= self._max_revs:
+        if self._max_revs is not None and self._state.num_revs >= self._max_revs:
             return True
-        if not self._revs_to_visit:
+        if not self._state.revs_to_visit:
             return True
         return False
 
-    def _get_rev(self, rev_id):
+    def _get_rev(self, rev_id: Sha1Git) -> Optional[Dict]:
         rev = self._revs.get(rev_id)
         if rev is None:
             # cache some revisions in advance to avoid sending too much
@@ -153,7 +177,7 @@
                 self._revs[rev["id"]] = rev
         return self._revs.get(rev_id)
 
-    def missing_revisions(self):
+    def missing_revisions(self) -> Set[Sha1Git]:
         """
         Return a set of revision identifiers whose associated data were
         found missing into the archive content while walking on the
@@ -162,9 +186,9 @@
         Returns:
             Set[bytes]: a set of revision identifiers
         """
-        return self._missing_revs
+        return self._state.missing_revs
 
-    def is_history_truncated(self):
+    def is_history_truncated(self) -> bool:
         """
         Return if the revision history generated so far has been truncated of not.
         A revision history might end up truncated if some revision
@@ -175,43 +199,38 @@
         """
         return len(self.missing_revisions()) > 0
 
-    def export_state(self):
+    def export_state(self) -> State:
         """
-        Export the internal state of that revision walker to a dict.
+        Export a shallow copy of the internal state of that revision walker.
         Its purpose is to continue the iteration in a pagination context.
 
         Returns:
-            dict: A dict containing the internal state of that revisions walker
+            State: a copy of the walker state, sharing its mutable containers
         """
-        return {
-            "revs_to_visit": self._revs_to_visit,
-            "done": self._done,
-            "last_rev": self._last_rev,
-            "num_revs": self._num_revs,
-            "missing_revs": self._missing_revs,
-        }
+        # snapshot scalar fields (num_revs, last_rev), share containers
+        return dataclasses.replace(self._state)
 
-    def __next__(self):
+    def __next__(self) -> Dict:
         if self.is_finished():
             raise StopIteration
-        while self._revs_to_visit:
+        while self._state.revs_to_visit:
             rev_id = self.get_next_rev_id()
-            if rev_id in self._done:
+            if rev_id in self._state.done:
                 continue
-            self._done.add(rev_id)
+            self._state.done.add(rev_id)
             rev = self._get_rev(rev_id)
             # revision data is missing, returned history will be truncated
             if rev is None:
-                self._missing_revs.add(rev_id)
+                self._state.missing_revs.add(rev_id)
                 continue
             self.process_parent_revs(rev)
             if self.should_return(rev):
-                self._num_revs += 1
-                self._last_rev = rev
+                self._state.num_revs += 1
+                self._state.last_rev = rev
                 return rev
         raise StopIteration
 
-    def __iter__(self):
+    def __iter__(self: TWalker) -> TWalker:
         return self
 
 
@@ -223,14 +242,14 @@
 
     rw_type = "committer_date"
 
-    def process_rev(self, rev_id):
+    def process_rev(self, rev_id: Sha1Git) -> None:
         """
         Add the revision to a priority queue according to the committer date.
 
         Args:
             rev_id (bytes): the newly visited revision identifier
         """
-        if rev_id not in self._done:
+        if rev_id not in self._state.done:
             rev = self._get_rev(rev_id)
             if rev is not None:
                 commit_time = (
@@ -238,22 +257,23 @@
                     if rev["committer_date"]
                     # allows to avoid failure with a revision without commit date
                     # and iterate on such revision before its parents
-                    else len(self._revs_to_visit)
+                    else len(self._state.revs_to_visit)
                 )
-                heapq.heappush(self._revs_to_visit, (-commit_time, rev_id))
+                heapq.heappush(
+                    self._state.revs_to_visit, (-commit_time, rev_id)
+                )  # type: ignore
             else:
-                self._missing_revs.add(rev_id)
+                self._state.missing_revs.add(rev_id)
 
-    def get_next_rev_id(self):
+    def get_next_rev_id(self) -> Sha1Git:
         """
         Return the smallest revision from the priority queue, i.e. the one
         with highest committer date.
 
         Returns:
-            dict: A dict describing a revision as returned by
-            :meth:`swh.storage.interface.StorageInterface.revision_get`
+            Sha1Git: the identifier of the next revision to visit
         """
-        _, rev_id = heapq.heappop(self._revs_to_visit)
+        _, rev_id = heapq.heappop(self._state.revs_to_visit)
         return rev_id
 
 
@@ -268,27 +288,26 @@
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._revs_to_visit = deque(self._revs_to_visit)
+        self._state.revs_to_visit = deque(self._state.revs_to_visit)  # type: ignore
 
-    def process_rev(self, rev_id):
+    def process_rev(self, rev_id: Sha1Git) -> None:
         """
         Append the revision to a queue.
 
         Args:
             rev_id (bytes): the newly visited revision identifier
         """
-        if rev_id not in self._done:
-            self._revs_to_visit.append(rev_id)
+        if rev_id not in self._state.done:
+            self._state.revs_to_visit.append(rev_id)
 
-    def get_next_rev_id(self):
+    def get_next_rev_id(self) -> Sha1Git:
         """
         Return the next revision from the queue.
 
         Returns:
-            dict: A dict describing a revision as returned by
-            :meth:`swh.storage.interface.StorageInterface.revision_get`
+            Sha1Git: the identifier of the next revision to visit
         """
-        return self._revs_to_visit.popleft()
+        return self._state.revs_to_visit.popleft()  # type: ignore
 
 
 class DFSPostRevisionsWalker(RevisionsWalker):
@@ -302,25 +321,24 @@
 
     rw_type = "dfs_post"
 
-    def process_rev(self, rev_id):
+    def process_rev(self, rev_id: Sha1Git) -> None:
         """
         Append the revision to a stack.
 
         Args:
             rev_id (bytes): the newly visited revision identifier
         """
-        if rev_id not in self._done:
-            self._revs_to_visit.append(rev_id)
+        if rev_id not in self._state.done:
+            self._state.revs_to_visit.append(rev_id)
 
-    def get_next_rev_id(self):
+    def get_next_rev_id(self) -> Sha1Git:
         """
         Return the next revision from the stack.
 
         Returns:
-            dict: A dict describing a revision as returned by
-            :meth:`swh.storage.interface.StorageInterface.revision_get`
+            Sha1Git: the identifier of the next revision to visit
         """
-        return self._revs_to_visit.pop()
+        return self._state.revs_to_visit.pop()
 
 
 class DFSRevisionsWalker(DFSPostRevisionsWalker):
@@ -334,7 +352,7 @@
 
     rw_type = "dfs"
 
-    def process_parent_revs(self, rev):
+    def process_parent_revs(self, rev: Dict) -> None:
         """
         Process the parents of a revision when it is iterated in
         the reversed order they are declared.