diff --git a/swh/provenance/archive.py b/swh/provenance/archive.py
--- a/swh/provenance/archive.py
+++ b/swh/provenance/archive.py
@@ -1,27 +1,45 @@
-from typing import Any, Dict, Iterable, List
+from typing import Any, Dict, Iterable

 from typing_extensions import Protocol, runtime_checkable

+from swh.model.model import Revision, Sha1
+

 @runtime_checkable
 class ArchiveInterface(Protocol):
-    def directory_ls(self, id: bytes) -> List[Dict[str, Any]]:
-        ...
+    def directory_ls(self, id: Sha1) -> Iterable[Dict[str, Any]]:
+        """List entries for one directory.

-    def iter_origins(self):
-        ...
+        Args:
+            id: sha1 id of the directory from which to list entries.

-    def iter_origin_visits(self, origin: str):
-        ...
+        Yields:
+            the entries of the given directory.

-    def iter_origin_visit_statuses(self, origin: str, visit: int):
+        """
         ...

-    def release_get(self, ids: Iterable[bytes]):
-        ...
+    def revision_get(self, ids: Iterable[Sha1]) -> Iterable[Revision]:
+        """Given a list of sha1 ids, yield the corresponding revisions.

-    def revision_get(self, ids: Iterable[bytes]):
+        Args:
+            ids: sha1 ids of the revisions to retrieve.
+
+        Yields:
+            revisions matching the identifiers. If a revision does
+            not exist, the provided sha1 is simply ignored.
+
+        """
         ...

-    def snapshot_get_all_branches(self, snapshot: bytes):
+    def snapshot_get_heads(self, id: Sha1) -> Iterable[Sha1]:
+        """List all revisions pointed to by one snapshot.
+
+        Args:
+            id: the snapshot's sha1 id.
+
+        Yields:
+            sha1 ids of the revisions found.
+
+        """
         ...
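Because ArchiveInterface is a runtime-checkable Protocol, conformance is structural: any object exposing these three methods passes an isinstance() check, with no subclassing involved. A minimal sketch of a hypothetical in-memory stand-in (the class name and dict-based layout are illustrative only, not part of this patch; Sha1 is a bytes alias in swh.model.model):

```python
from typing import Any, Dict, Iterable, List

from swh.provenance.archive import ArchiveInterface


class InMemoryArchive:
    """Hypothetical stand-in backed by plain dicts, for illustration only."""

    def __init__(
        self,
        directories: Dict[bytes, List[Dict[str, Any]]],
        revisions: Dict[bytes, Any],
        heads: Dict[bytes, List[bytes]],
    ) -> None:
        self.directories = directories
        self.revisions = revisions
        self.heads = heads

    def directory_ls(self, id: bytes) -> Iterable[Dict[str, Any]]:
        yield from self.directories.get(id, [])

    def revision_get(self, ids: Iterable[bytes]) -> Iterable[Any]:
        # Unknown sha1s are silently skipped, as the docstring requires.
        for id in ids:
            if id in self.revisions:
                yield self.revisions[id]

    def snapshot_get_heads(self, id: bytes) -> Iterable[bytes]:
        yield from self.heads.get(id, [])


# Structural check: no inheritance from ArchiveInterface needed.
assert isinstance(InMemoryArchive({}, {}, {}), ArchiveInterface)
```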
diff --git a/swh/provenance/model.py b/swh/provenance/model.py
--- a/swh/provenance/model.py
+++ b/swh/provenance/model.py
@@ -4,10 +4,7 @@
 # See top-level LICENSE file for more information

 from datetime import datetime
-from typing import Iterable, Iterator, List, Optional, Set
-
-from swh.core.utils import grouper
-from swh.model.model import ObjectType, TargetType
+from typing import Iterable, Iterator, List, Optional

 from .archive import ArchiveInterface
@@ -17,42 +14,17 @@
         self, url: str, date: datetime, snapshot: bytes, id: Optional[int] = None
     ):
         self.url = url
-        self.date = date
+        # TODO: this is probably not needed and will be removed!
+        # self.date = date
         self.snapshot = snapshot
         self.id = id
         self._revisions: Optional[List[RevisionEntry]] = None

     def retrieve_revisions(self, archive: ArchiveInterface):
         if self._revisions is None:
-            snapshot = archive.snapshot_get_all_branches(self.snapshot)
-            assert snapshot is not None
-            targets_set = set()
-            releases_set = set()
-            if snapshot is not None:
-                for branch in snapshot.branches:
-                    if snapshot.branches[branch].target_type == TargetType.REVISION:
-                        targets_set.add(snapshot.branches[branch].target)
-                    elif snapshot.branches[branch].target_type == TargetType.RELEASE:
-                        releases_set.add(snapshot.branches[branch].target)
-
-            batchsize = 100
-            for releases in grouper(releases_set, batchsize):
-                targets_set.update(
-                    release.target
-                    for release in archive.revision_get(releases)
-                    if release is not None
-                    and release.target_type == ObjectType.REVISION
-                )
-
-            revisions: Set[RevisionEntry] = set()
-            for targets in grouper(targets_set, batchsize):
-                revisions.update(
-                    RevisionEntry(revision.id)
-                    for revision in archive.revision_get(targets)
-                    if revision is not None
-                )
-
-            self._revisions = list(revisions)
+            self._revisions = [
+                RevisionEntry(rev) for rev in archive.snapshot_get_heads(self.snapshot)
+            ]

     @property
     def revisions(self) -> Iterator["RevisionEntry"]:
@@ -93,7 +65,7 @@
                     parents=rev.parents,
                 )
                 for rev in archive.revision_get(self._parents)
-                if rev is not None and rev.date is not None
+                if rev.date is not None
             ]
             yield from self._nodes
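After this change, retrieve_revisions no longer walks snapshot branches itself; it defers to the archive and caches the result in self._revisions, so repeated calls are free. A usage sketch, reusing the hypothetical InMemoryArchive from above (and assuming RevisionEntry exposes its sha1 as .id, as the removed code's RevisionEntry(revision.id) suggests):

```python
from datetime import datetime, timezone

from swh.provenance.model import OriginEntry

# Made-up ids: one snapshot whose only head is a single revision.
head = b"\x01" * 20
snapshot_sha1 = b"\x00" * 20
archive = InMemoryArchive({}, {}, {snapshot_sha1: [head]})

origin = OriginEntry(
    url="https://example.org/repo.git",
    date=datetime.now(timezone.utc),
    snapshot=snapshot_sha1,
)
origin.retrieve_revisions(archive)  # one snapshot_get_heads() call, then cached
assert [rev.id for rev in origin.revisions] == [head]
```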
diff --git a/swh/provenance/postgresql/archive.py b/swh/provenance/postgresql/archive.py
--- a/swh/provenance/postgresql/archive.py
+++ b/swh/provenance/postgresql/archive.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, Iterable, List
+from typing import Any, Dict, Iterable, List, Set

 from methodtools import lru_cache
 import psycopg2

-from swh.model.model import Revision
+from swh.model.model import ObjectType, Revision, Sha1, TargetType
 from swh.storage.postgresql.storage import Storage
@@ -12,14 +12,14 @@
         self.conn = conn
         self.storage = Storage(conn, objstorage={"cls": "memory"})

-    def directory_ls(self, id: bytes) -> List[Dict[str, Any]]:
+    def directory_ls(self, id: Sha1) -> Iterable[Dict[str, Any]]:
         # TODO: only call directory_ls_internal if the id is not being queried by
         # someone else. Otherwise wait until results get properly cached.
         entries = self.directory_ls_internal(id)
-        return entries
+        yield from entries

     @lru_cache(maxsize=100000)
-    def directory_ls_internal(self, id: bytes) -> List[Dict[str, Any]]:
+    def directory_ls_internal(self, id: Sha1) -> List[Dict[str, Any]]:
         # TODO: add file size filtering
         with self.conn.cursor() as cursor:
             cursor.execute(
@@ -62,28 +62,7 @@
                 for row in cursor.fetchall()
             ]

-    def iter_origins(self):
-        from swh.storage.algos.origin import iter_origins
-
-        yield from iter_origins(self.storage)
-
-    def iter_origin_visits(self, origin: str):
-        from swh.storage.algos.origin import iter_origin_visits
-
-        # TODO: filter unused fields
-        yield from iter_origin_visits(self.storage, origin)
-
-    def iter_origin_visit_statuses(self, origin: str, visit: int):
-        from swh.storage.algos.origin import iter_origin_visit_statuses
-
-        # TODO: filter unused fields
-        yield from iter_origin_visit_statuses(self.storage, origin, visit)
-
-    def release_get(self, ids: Iterable[bytes]):
-        # TODO: filter unused fields
-        yield from self.storage.release_get(list(ids))
-
-    def revision_get(self, ids: Iterable[bytes]):
+    def revision_get(self, ids: Iterable[Sha1]) -> Iterable[Revision]:
         with self.conn.cursor() as cursor:
             psycopg2.extras.execute_values(
                 cursor,
@@ -117,8 +96,39 @@
                 }
             )

-    def snapshot_get_all_branches(self, snapshot: bytes):
+    def snapshot_get_heads(self, id: Sha1) -> Iterable[Sha1]:
+        # TODO: this code is duplicated here (same as in swh.provenance.storage.archive)
+        # but it's just temporary. This method should actually perform a direct query to
+        # the SQL db of the archive.
+        from swh.core.utils import grouper
         from swh.storage.algos.snapshot import snapshot_get_all_branches

-        # TODO: filter unused fields
-        return snapshot_get_all_branches(self.storage, snapshot)
+        snapshot = snapshot_get_all_branches(self.storage, id)
+        assert snapshot is not None
+
+        targets_set = set()
+        releases_set = set()
+        if snapshot is not None:
+            for branch in snapshot.branches:
+                if snapshot.branches[branch].target_type == TargetType.REVISION:
+                    targets_set.add(snapshot.branches[branch].target)
+                elif snapshot.branches[branch].target_type == TargetType.RELEASE:
+                    releases_set.add(snapshot.branches[branch].target)
+
+        batchsize = 100
+        for releases in grouper(releases_set, batchsize):
+            targets_set.update(
+                release.target
+                for release in self.storage.release_get(releases)
+                if release is not None and release.target_type == ObjectType.REVISION
+            )
+
+        revisions: Set[Sha1] = set()
+        for targets in grouper(targets_set, batchsize):
+            revisions.update(
+                revision.id
+                for revision in self.storage.revision_get(targets)
+                if revision is not None
+            )
+
+        yield from revisions
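The directory_ls / directory_ls_internal split is what keeps the lru_cache sound: the cache must hold the materialized List, because caching the generator itself would hand every subsequent cache hit the same exhausted iterator. A minimal demonstration of that pitfall (using functools.lru_cache on a free function for brevity, instead of the methodtools variant the patch uses on methods):

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def cached_entries():
    # A cached generator function: the cache stores the generator object
    # produced by the first call, not the values it yields.
    yield from [{"name": b"README"}, {"name": b"src"}]


first = list(cached_entries())   # [{'name': b'README'}, {'name': b'src'}]
second = list(cached_entries())  # [] -- same generator object, already exhausted
assert first and not second
```

Caching the list in directory_ls_internal and re-yielding from it in directory_ls avoids this while keeping the public signature a lazy Iterable.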
diff --git a/swh/provenance/storage/archive.py b/swh/provenance/storage/archive.py
--- a/swh/provenance/storage/archive.py
+++ b/swh/provenance/storage/archive.py
@@ -1,8 +1,6 @@
-from typing import Any, Dict, Iterable, List
-
-# from functools import lru_cache
-from methodtools import lru_cache
+from typing import Any, Dict, Iterable, Set
+from swh.model.model import ObjectType, Revision, Sha1, TargetType

 from swh.storage.interface import StorageInterface
@@ -10,38 +8,46 @@
     def __init__(self, storage: StorageInterface):
         self.storage = storage

-    @lru_cache(maxsize=100000)
-    def directory_ls(self, id: bytes) -> List[Dict[str, Any]]:
-        # TODO: filter unused fields
-        return [entry for entry in self.storage.directory_ls(id)]
-
-    def iter_origins(self):
-        from swh.storage.algos.origin import iter_origins
-
-        yield from iter_origins(self.storage)
-
-    def iter_origin_visits(self, origin: str):
-        from swh.storage.algos.origin import iter_origin_visits
-
+    def directory_ls(self, id: Sha1) -> Iterable[Dict[str, Any]]:
         # TODO: filter unused fields
-        yield from iter_origin_visits(self.storage, origin)
+        yield from self.storage.directory_ls(id)

-    def iter_origin_visit_statuses(self, origin: str, visit: int):
-        from swh.storage.algos.origin import iter_origin_visit_statuses
-
-        # TODO: filter unused fields
-        yield from iter_origin_visit_statuses(self.storage, origin, visit)
-
-    def release_get(self, ids: Iterable[bytes]):
+    def revision_get(self, ids: Iterable[Sha1]) -> Iterable[Revision]:
         # TODO: filter unused fields
-        yield from self.storage.release_get(list(ids))
+        yield from (
+            rev for rev in self.storage.revision_get(list(ids)) if rev is not None
+        )

-    def revision_get(self, ids: Iterable[bytes]):
-        # TODO: filter unused fields
-        yield from self.storage.revision_get(list(ids))
-
-    def snapshot_get_all_branches(self, snapshot: bytes):
+    def snapshot_get_heads(self, id: Sha1) -> Iterable[Sha1]:
+        from swh.core.utils import grouper
         from swh.storage.algos.snapshot import snapshot_get_all_branches

-        # TODO: filter unused fields
-        return snapshot_get_all_branches(self.storage, snapshot)
+        snapshot = snapshot_get_all_branches(self.storage, id)
+        assert snapshot is not None
+
+        targets_set = set()
+        releases_set = set()
+        if snapshot is not None:
+            for branch in snapshot.branches:
+                if snapshot.branches[branch].target_type == TargetType.REVISION:
+                    targets_set.add(snapshot.branches[branch].target)
+                elif snapshot.branches[branch].target_type == TargetType.RELEASE:
+                    releases_set.add(snapshot.branches[branch].target)
+
+        batchsize = 100
+        for releases in grouper(releases_set, batchsize):
+            targets_set.update(
+                release.target
+                for release in self.storage.release_get(list(releases))
+                if release is not None and release.target_type == ObjectType.REVISION
+            )
+
+        revisions: Set[Sha1] = set()
+        for targets in grouper(targets_set, batchsize):
+            revisions.update(
+                revision.id
+                for revision in self.storage.revision_get(list(targets))
+                if revision is not None
+            )
+
+        yield from revisions
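Both implementations of snapshot_get_heads resolve heads in two hops: branch targets that are revisions are kept directly, while release branches are dereferenced, in batches of 100, to the revisions they tag. A hypothetical wiring sketch against an in-memory storage (the class name ArchiveStorage is an assumption based on the surrounding file, and the snapshot id is a placeholder that would have to exist in the storage):

```python
from swh.provenance.storage.archive import ArchiveStorage  # assumed class name
from swh.storage import get_storage

# Any StorageInterface implementation works the same way; a memory-backed
# instance would be populated elsewhere (e.g. by a loader or test fixture).
storage = get_storage("memory")
archive = ArchiveStorage(storage)

snapshot_sha1 = b"\x00" * 20  # placeholder id, assumed present in the storage
for head in archive.snapshot_get_heads(snapshot_sha1):
    print(head.hex())  # sha1 of a head revision (branch target or release target)
```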