diff --git a/swh/storage/cassandra/cql.py b/swh/storage/cassandra/cql.py
--- a/swh/storage/cassandra/cql.py
+++ b/swh/storage/cassandra/cql.py
@@ -625,6 +625,17 @@
             self._execute_with_retries(statement, [directory_ids]),
         )
 
+    @_prepared_select_statement(
+        DirectoryEntryRow, "WHERE directory_id = ? AND name >= ? LIMIT ?"
+    )
+    def directory_entry_get_from_name(
+        self, directory_id: Sha1Git, from_: bytes, limit: int, *, statement
+    ) -> Iterable[DirectoryEntryRow]:
+        return map(
+            DirectoryEntryRow.from_dict,
+            self._execute_with_retries(statement, [directory_id, from_, limit]),
+        )
+
     ##########################
     # 'snapshot' table
     ##########################
diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -518,6 +518,30 @@
     ) -> Iterable[Dict[str, Any]]:
         yield from self._directory_ls(directory, recursive)
 
+    def directory_get_entries(
+        self,
+        directory_id: Sha1Git,
+        page_token: Optional[bytes] = None,
+        limit: int = 1000,
+    ) -> Optional[PagedResult[DirectoryEntry]]:
+        if self.directory_missing([directory_id]):
+            return None
+
+        entries_from: bytes = page_token or b""
+        rows = self._cql_runner.directory_entry_get_from_name(
+            directory_id, entries_from, limit + 1
+        )
+        entries = [
+            DirectoryEntry.from_dict(remove_keys(row.to_dict(), ("directory_id",)))
+            for row in rows
+        ]
+        if len(entries) > limit:
+            last_entry = entries.pop()
+            next_page_token = last_entry.name
+        else:
+            next_page_token = None
+        return PagedResult(results=entries, next_page_token=next_page_token)
+
     def directory_get_random(self) -> Sha1Git:
         directory = self._cql_runner.directory_get_random()
         assert directory, "Could not find any directory"
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -6,6 +6,7 @@
 from collections import defaultdict
 import datetime
 import functools
+import itertools
 import random
 from typing import (
     Any,
@@ -301,6 +302,16 @@
         for id_ in directory_ids:
             yield from self._directory_entries.get_from_partition_key((id_,))
 
+    def directory_entry_get_from_name(
+        self, directory_id: Sha1Git, from_: bytes, limit: int
+    ) -> Iterable[DirectoryEntryRow]:
+        # Get all entries
+        entries = self._directory_entries.get_from_partition_key((directory_id,))
+        # Filter out the ones before from_
+        entries = itertools.dropwhile(lambda entry: entry.name < from_, entries)
+        # Apply limit
+        return itertools.islice(entries, limit)
+
     ##########################
     # 'revision' table
     ##########################
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -15,6 +15,7 @@
 from swh.model.model import (
     Content,
     Directory,
+    DirectoryEntry,
     ExtID,
     MetadataAuthority,
     MetadataAuthorityType,
@@ -426,6 +427,31 @@
         """
         ...
 
+    @remote_api_endpoint("directory/get_entries")
+    def directory_get_entries(
+        self,
+        directory_id: Sha1Git,
+        page_token: Optional[bytes] = None,
+        limit: int = 1000,
+    ) -> Optional[PagedResult[DirectoryEntry]]:
+        """Get the content, possibly partial, of a directory with the given id
+
+        The entries of the directory are not guaranteed to be returned in any
+        particular order.
+
+        The number of results is not guaranteed to be lower than the ``limit``.
+
+        Args:
+            directory_id: identifier of the directory
+            page_token: opaque string used to get the next results of a search
+            limit: Number of entries to return
+
+        Returns:
+            None if the directory does not exist; a page of DirectoryEntry
+            objects otherwise.
+        """
+        ...
+
     @remote_api_endpoint("directory/get_random")
     def directory_get_random(self) -> Sha1Git:
         """Finds a random directory id.
diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -14,7 +14,7 @@
 from swh.core.db.db_utils import jsonize as _jsonize
 from swh.core.db.db_utils import stored_procedure
 from swh.model.identifiers import ObjectType
-from swh.model.model import SHA1_SIZE, OriginVisit, OriginVisitStatus
+from swh.model.model import SHA1_SIZE, OriginVisit, OriginVisitStatus, Sha1Git
 from swh.storage.interface import ListOrder
 
 logger = logging.getLogger(__name__)
@@ -403,6 +403,15 @@
             return None
         return data
 
+    directory_get_entries_cols = ["type", "target", "name", "perms"]
+
+    def directory_get_entries(self, directory: Sha1Git, cur=None) -> List[Tuple]:
+        cur = self._cursor(cur)
+        cur.execute(
+            "SELECT * FROM swh_directory_get_entries(%s::sha1_git)", (directory,)
+        )
+        return list(cur)
+
     def directory_get_random(self, cur=None):
         return self._get_random_row_from_table("directory", ["id"], "id", cur)
 
diff --git a/swh/storage/postgresql/storage.py b/swh/storage/postgresql/storage.py
--- a/swh/storage/postgresql/storage.py
+++ b/swh/storage/postgresql/storage.py
@@ -24,6 +24,7 @@
     SHA1_SIZE,
     Content,
     Directory,
+    DirectoryEntry,
     ExtID,
     MetadataAuthority,
     MetadataAuthorityType,
@@ -549,6 +550,31 @@
     def directory_get_random(self, db=None, cur=None) -> Sha1Git:
         return db.directory_get_random(cur)
 
+    @db_transaction()
+    def directory_get_entries(
+        self,
+        directory_id: Sha1Git,
+        page_token: Optional[bytes] = None,
+        limit: int = 1000,
+        db=None,
+        cur=None,
+    ) -> Optional[PagedResult[DirectoryEntry]]:
+        if list(self.directory_missing([directory_id], db=db, cur=cur)):
+            return None
+
+        if page_token is not None:
+            raise StorageArgumentException("Unsupported page token")
+
+        # TODO: actually paginate
+        rows = db.directory_get_entries(directory_id, cur=cur)
+        return PagedResult(
+            results=[
+                DirectoryEntry(**dict(zip(db.directory_get_entries_cols, row)))
+                for row in rows
+            ],
+            next_page_token=None,
+        )
+
     @timed
     @process_metrics
     @db_transaction()
diff --git a/swh/storage/sql/40-funcs.sql b/swh/storage/sql/40-funcs.sql
--- a/swh/storage/sql/40-funcs.sql
+++ b/swh/storage/sql/40-funcs.sql
@@ -414,6 +414,34 @@
 end
 $$;
 
+-- Returns the entries in a directory, without joining with their target tables
+create or replace function swh_directory_get_entries(dir_id sha1_git)
+    returns table (
+        type directory_entry_type, target sha1_git, name unix_path, perms file_perms
+    )
+    language sql
+    stable
+as $$
+    with dir as (
+        select id as dir_id, dir_entries, file_entries, rev_entries
+        from directory
+        where id = dir_id),
+    ls_d as (select dir_id, unnest(dir_entries) as entry_id from dir),
+    ls_f as (select dir_id, unnest(file_entries) as entry_id from dir),
+    ls_r as (select dir_id, unnest(rev_entries) as entry_id from dir)
+    (select 'dir'::directory_entry_type, e.target, e.name, e.perms
+     from ls_d
+     left join directory_entry_dir e on ls_d.entry_id = e.id)
+    union
+    (select 'file'::directory_entry_type, e.target, e.name, e.perms
+     from ls_f
+     left join directory_entry_file e on ls_f.entry_id = e.id)
+    union
+    (select 'rev'::directory_entry_type, e.target, e.name, e.perms
+     from ls_r
+     left join directory_entry_rev e on ls_r.entry_id = e.id)
+$$;
+
 -- List all revision IDs starting from a given revision, going back in time
 --
 -- TODO ordering: should be breadth-first right now (what do we want?)
diff --git a/swh/storage/tests/storage_tests.py b/swh/storage/tests/storage_tests.py
--- a/swh/storage/tests/storage_tests.py
+++ b/swh/storage/tests/storage_tests.py
@@ -17,6 +17,7 @@
 from hypothesis import HealthCheck, given, settings, strategies
 import pytest
 
+from swh.core.api.classes import stream_results
 from swh.model import from_disk
 from swh.model.hashutil import hash_to_bytes
 from swh.model.hypothesis_strategies import objects
@@ -909,6 +910,37 @@
         )
         assert actual_entry is None
 
+    def test_directory_get_entries_pagination(self, swh_storage, sample_data):
+        # Note: this test assumes entries are returned in lexicographic order,
+        # which is not actually guaranteed by the interface.
+        dir_ = sample_data.directory3
+        entries = sorted(dir_.entries, key=lambda entry: entry.name)
+        swh_storage.directory_add(sample_data.directories)
+
+        # No pagination needed
+        actual_data = swh_storage.directory_get_entries(dir_.id)
+        assert actual_data == PagedResult(results=entries, next_page_token=None)
+
+        # A little pagination
+        actual_data = swh_storage.directory_get_entries(dir_.id, limit=2)
+        assert actual_data.results == entries[0:2]
+        assert actual_data.next_page_token is not None
+
+        actual_data = swh_storage.directory_get_entries(
+            dir_.id, page_token=actual_data.next_page_token
+        )
+        assert actual_data == PagedResult(results=entries[2:], next_page_token=None)
+
+    @pytest.mark.parametrize("limit", [1, 2, 3, 4, 5])
+    def test_directory_get_entries(self, swh_storage, sample_data, limit):
+        dir_ = sample_data.directory3
+        swh_storage.directory_add(sample_data.directories)
+
+        actual_data = list(
+            stream_results(swh_storage.directory_get_entries, dir_.id, limit=limit,)
+        )
+        assert sorted(actual_data) == sorted(dir_.entries)
+
     def test_directory_get_random(self, swh_storage, sample_data):
         dir1, dir2, dir3 = sample_data.directories[:3]
         swh_storage.directory_add([dir1, dir2, dir3])
diff --git a/swh/storage/tests/test_cassandra.py b/swh/storage/tests/test_cassandra.py
--- a/swh/storage/tests/test_cassandra.py
+++ b/swh/storage/tests/test_cassandra.py
@@ -448,6 +448,7 @@
         # be considered not written.
         assert swh_storage.directory_missing([directory.id]) == [directory.id]
         assert list(swh_storage.directory_ls(directory.id)) == []
+        assert swh_storage.directory_get_entries(directory.id) is None
 
     def test_snapshot_add_atomic(self, swh_storage, sample_data, mocker):
         """Checks that a crash occurring after some snapshot branches were written
diff --git a/swh/storage/tests/test_postgresql.py b/swh/storage/tests/test_postgresql.py
--- a/swh/storage/tests/test_postgresql.py
+++ b/swh/storage/tests/test_postgresql.py
@@ -12,7 +12,7 @@
 import pytest
 
 from swh.storage.postgresql.db import Db
-from swh.storage.tests.storage_tests import TestStorage  # noqa
+from swh.storage.tests.storage_tests import TestStorage as _TestStorage
 from swh.storage.tests.storage_tests import TestStorageGeneratedData  # noqa
 from swh.storage.utils import now
 
@@ -24,6 +24,14 @@
         yield db, cur
 
 
+class TestStorage(_TestStorage):
+    @pytest.mark.skip(
+        "Directory pagination is not implemented in the postgresql backend yet."
+    )
+    def test_directory_get_entries_pagination(self):
+        pass
+
+
 @pytest.mark.db
 class TestLocalStorage:
     """Test the local storage"""