diff --git a/requirements-swh.txt b/requirements-swh.txt
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,2 +1,2 @@
 swh.core >= 0.0.77
-swh.scheduler >= 0.0.58
+swh.scheduler >= 0.3.0
diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py
new file mode 100644
--- /dev/null
+++ b/swh/lister/pattern.py
@@ -0,0 +1,231 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from dataclasses import dataclass
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Tuple,
+    TypeVar,
+)
+
+from swh.core import config
+from swh.scheduler import get_scheduler
+from swh.scheduler import model
+
+
+@dataclass
+class ListerStats:
+    """Counters of the pages and origins seen during a lister run."""
+
+    pages: int = 0
+    origins: int = 0
+
+    def __add__(self, other: "ListerStats") -> "ListerStats":
+        return self.__class__(self.pages + other.pages, self.origins + other.origins)
+
+    def __iadd__(self, other: "ListerStats") -> "ListerStats":
+        self.pages += other.pages
+        self.origins += other.origins
+        # ``stats += other`` rebinds ``stats`` to this return value: it must be self
+        return self
+
+    def dict(self) -> Dict[str, int]:
+        return {"pages": self.pages, "origins": self.origins}
+
+
+StateType = TypeVar("StateType")
+PageType = TypeVar("PageType")
+
+
+class Lister(Generic[StateType, PageType]):
+    """The base class for a Software Heritage lister.
+
+    A lister scrapes a page by page list of origins from an upstream (a forge, the API
+    of a package manager, ...), and massages the results of that scrape into a list of
+    origins that are recorded by the scheduler backend.
+
+    The main loop of the lister, :meth:`run`, basically revolves around the
+    :meth:`get_pages` iterator, which sets up the lister state, then yields the scrape
+    results page by page. The :meth:`get_origins_from_page` method converts the pages
+    into a list of :class:`model.ListedOrigin`, sent to the scheduler at every page.
+
+    The :meth:`finalize_state` method is called at lister completion (successful or not)
+    to update the local :attr:`state` object before it's sent to the database. This
+    method must set the :attr:`state_updated` attribute if an updated state needs to be
+    sent to the scheduler backend. This method can call :meth:`get_state_from_scheduler`
+    to refresh and merge the lister state from the scheduler before it's finalized (and
+    minimize the risk of race conditions between concurrent runs of the lister).
+
+    The state of the lister is serialized and deserialized from the dict stored in the
+    scheduler backend, using the :meth:`state_from_dict` and :meth:`state_to_dict`
+    methods.
+
+    Args:
+      url: a URL representing this lister, e.g. the API's base URL
+      instance: the instance name used, in conjunction with :attr:`LISTER_NAME`, to
+        uniquely identify this lister instance.
+
+    Generic types:
+      - *StateType*: concrete state type; should usually be a :class:`dataclass` for
+        stricter typing
+      - *PageType*: type of scrape results; can usually be a :class:`requests.Response`,
+        or a :class:`dict`
+
+    """
+
+    LISTER_NAME: str = ""
+
+    def __init__(self, url: str, instance: str):
+        if not self.LISTER_NAME:
+            raise ValueError("Must set the LISTER_NAME attribute on Lister classes")
+
+        self.url = url
+        self.instance = instance
+
+        self.config = self.load_config()
+        self.scheduler = get_scheduler(**self.config["scheduler"])
+
+        # store the initial state of the lister
+        self.state = self.get_state_from_scheduler()
+        self.state_updated = False
+
+    def config_base_filename(self) -> str:
+        """Base filename of this lister's configuration: ``lister_<LISTER_NAME>``."""
+        return "lister_%s" % self.LISTER_NAME
+
+    def get_default_config(self) -> Dict[str, Tuple[str, Dict[str, Any]]]:
+        """Get the default config for this lister.
+
+        The only mandatory key is the `scheduler` key (which is returned)
+        """
+        return {
+            "scheduler": (
+                "dict",
+                {"cls": "remote", "args": {"url": "http://localhost:5008/"}},
+            ),
+        }
+
+    def load_config(self) -> Dict[str, Any]:
+        """Load the configuration from the configured filename"""
+        return config.SWHConfig().parse_config_file(
+            base_filename=self.config_base_filename(),
+            additional_configs=[self.get_default_config()],
+        )
+
+    def get_credentials(self) -> List[Dict[str, str]]:
+        """Get the credentials for the current instance of the lister"""
+        return (
+            self.config.get("credentials", {})
+            .get(self.LISTER_NAME, {})
+            .get(self.instance, [])
+        )
+
+    def run(self) -> ListerStats:
+        """Run the lister.
+
+        Returns:
+          A counter with the number of pages and origins seen for this run
+          of the lister.
+
+        """
+        full_stats = ListerStats()
+
+        try:
+            for page in self.get_pages():
+                full_stats.pages += 1
+                origins = self.get_origins_from_page(page)
+                full_stats.origins += self.send_origins(origins)
+        finally:
+            self.finalize_state()
+            if self.state_updated:
+                self.set_state_in_scheduler()
+
+        return full_stats
+
+    def get_state_from_scheduler(self) -> StateType:
+        """Update the state in the current instance from the state in the scheduler backend.
+
+        This updates :attr:`lister_obj`, and returns its (deserialized) current state,
+        to allow for comparison with the local state.
+        """
+        self.lister_obj = self.scheduler.get_or_create_lister(
+            name=self.LISTER_NAME, instance_name=self.instance
+        )
+        return self.state_from_dict(self.lister_obj.current_state)
+
+    def set_state_in_scheduler(self) -> None:
+        """Update the state in the scheduler backend from the state of the current instance.
+
+        Raises:
+          :class:`swh.scheduler.exc.StaleData` in case of a race condition between
+            concurrent listers.
+          ValueError: if :attr:`state` is ``None``.
+        """
+        if self.state is None:
+            raise ValueError("Current state unset!")
+        self.lister_obj.current_state = self.state_to_dict(self.state)
+        self.lister_obj = self.scheduler.update_lister(self.lister_obj)
+
+    # State management to/from the scheduler
+
+    def state_from_dict(self, d: Dict[str, Any]) -> StateType:
+        """Convert the state stored in the scheduler backend (as a dict),
+        to the concrete StateType for this lister."""
+        raise NotImplementedError
+
+    def state_to_dict(self, state: StateType) -> Dict[str, Any]:
+        """Convert the StateType for this lister to its serialization as dict for
+        storage in the scheduler.
+
+        Values must be JSON-compatible as that's what the backend database expects.
+        """
+        raise NotImplementedError
+
+    def finalize_state(self) -> None:
+        """Custom hook to finalize the lister state before returning.
+
+        This method must set :attr:`state_updated` if the updated state must be sent to
+        the scheduler backend.
+
+        If relevant, this method can use :meth:`get_state_from_scheduler` to merge the
+        current lister state with the one from the scheduler backend, reducing the risk
+        of race conditions if we're running concurrent listings.
+
+        """
+        pass
+
+    # Actual listing logic
+
+    def get_pages(self) -> Iterator[PageType]:
+        """Retrieve a list of pages of listed results. This is the main loop of the lister.
+
+        Returns:
+          an iterator of raw pages fetched from the platform currently being listed.
+        """
+        raise NotImplementedError
+
+    def get_origins_from_page(self, page: PageType) -> Iterator[model.ListedOrigin]:
+        """Extract a list of :class:`model.ListedOrigin` from a raw page of results.
+
+        Args:
+          page: a single page of results
+        Returns:
+          an iterator for the origins present on the given page of results
+        """
+        raise NotImplementedError
+
+    def send_origins(self, origins: Iterable[model.ListedOrigin]) -> int:
+        """Record a list of :class:`model.ListedOrigin` in the scheduler.
+
+        Returns:
+          the number of listed origins recorded in the scheduler
+        """
+        ret = self.scheduler.record_listed_origins(origins)
+        return len(ret)
diff --git a/swh/lister/tests/conftest.py b/swh/lister/tests/conftest.py
new file mode 100644
--- /dev/null
+++ b/swh/lister/tests/conftest.py
@@ -0,0 +1,7 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+# import the swh_scheduler fixture and its friends
+from swh.scheduler.tests.conftest import *  # noqa
diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py
new file mode 100644
--- /dev/null
+++ b/swh/lister/tests/test_pattern.py
@@ -0,0 +1,117 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from typing import Dict, Iterator, List
+
+import pytest
+
+from swh.scheduler.model import ListedOrigin
+from swh.lister import pattern
+
+
+@pytest.fixture
+def mock_get_scheduler(monkeypatch, swh_scheduler):
+    def get_scheduler(cls, args):
+        return swh_scheduler
+
+    monkeypatch.setattr(pattern, "get_scheduler", get_scheduler)
+
+
+class InstantiatableLister(pattern.Lister[Dict[str, str], List[Dict[str, str]]]):
+    """A lister that can only be instantiated, not run."""
+
+    LISTER_NAME = "test-pattern-lister"
+
+    def state_from_dict(self, d: Dict[str, str]) -> Dict[str, str]:
+        return d
+
+
+def test_instantiation(mock_get_scheduler, swh_scheduler):
+    lister = InstantiatableLister(url="https://example.com", instance="example.com")
+
+    # check the lister was registered in the scheduler backend
+    stored_lister = swh_scheduler.get_or_create_lister(
+        name="test-pattern-lister", instance_name="example.com"
+    )
+    assert stored_lister == lister.lister_obj
+
+    with pytest.raises(NotImplementedError):
+        lister.run()
+
+
+class RunnableLister(InstantiatableLister):
+    """A lister that can be run."""
+
+    def state_to_dict(self, state: Dict[str, str]) -> Dict[str, str]:
+        return state
+
+    def get_pages(self) -> Iterator[List[Dict[str, str]]]:
+        for pageno in range(2):
+            yield [
+                {"url": f"https://example.com/{pageno:02d}{i:03d}"} for i in range(10)
+            ]
+
+    def get_origins_from_page(
+        self, page: List[Dict[str, str]]
+    ) -> Iterator[ListedOrigin]:
+        for origin in page:
+            yield ListedOrigin(
+                lister_id=self.lister_obj.id, url=origin["url"], visit_type="git"
+            )
+
+    def finalize_state(self) -> None:
+        self.state["updated"] = "yes"
+        self.state_updated = True
+
+
+def test_run(mock_get_scheduler, swh_scheduler):
+    lister = RunnableLister(url="https://example.com", instance="example.com")
+
+    assert "updated" not in lister.state
+
+    update_date = lister.lister_obj.updated
+
+    run_result = lister.run()
+
+    assert run_result.pages == 2
+    assert run_result.origins == 20
+
+    stored_lister = swh_scheduler.get_or_create_lister(
+        name="test-pattern-lister", instance_name="example.com"
+    )
+
+    # Check that the finalize_state operation happened
+    assert stored_lister.updated > update_date
+    assert stored_lister.current_state["updated"] == "yes"
+
+    # Gather the origins that are supposed to be listed
+    lister_urls = sorted(
+        sum([[o["url"] for o in page] for page in lister.get_pages()], [])
+    )
+
+    # And check the state of origins in the scheduler
+    ret = swh_scheduler.get_listed_origins()
+    assert ret.next_page_token is None
+    assert len(ret.origins) == len(lister_urls)
+
+    for origin, expected_url in zip(ret.origins, lister_urls):
+        assert origin.url == expected_url
+        assert origin.lister_id == stored_lister.id
+
+
+def test_lister_stats():
+    stats = pattern.ListerStats()
+    initial = stats
+    stats += pattern.ListerStats(pages=1, origins=10)
+    stats += pattern.ListerStats(pages=2, origins=5)
+
+    # in-place addition updates the object itself instead of rebinding the name
+    assert stats is initial
+    assert stats.dict() == {"pages": 3, "origins": 15}
+
+    # regular addition returns a new object, leaving its operands untouched
+    total = stats + pattern.ListerStats(pages=1, origins=1)
+    assert total.dict() == {"pages": 4, "origins": 16}
+    assert stats.dict() == {"pages": 3, "origins": 15}