diff --git a/swh/provenance/backend.py b/swh/provenance/backend.py --- a/swh/provenance/backend.py +++ b/swh/provenance/backend.py @@ -8,7 +8,12 @@ from swh.model.model import Sha1Git from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry -from .provenance import ProvenanceResult, ProvenanceStorageInterface, RelationType +from .provenance import ( + ProvenanceResult, + ProvenanceStorageInterface, + RelationData, + RelationType, +) class DatetimeCache(TypedDict): @@ -71,7 +76,11 @@ # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. while not self.storage.relation_add( - RelationType.CNT_EARLY_IN_REV, self.cache["content_in_revision"] + RelationType.CNT_EARLY_IN_REV, + ( + RelationData(src=src, dst=dst, path=path) + for src, dst, path in self.cache["content_in_revision"] + ), ): logging.warning( f"Unable to write {RelationType.CNT_EARLY_IN_REV} rows to the storage. " @@ -79,7 +88,11 @@ ) while not self.storage.relation_add( - RelationType.CNT_IN_DIR, self.cache["content_in_directory"] + RelationType.CNT_IN_DIR, + ( + RelationData(src=src, dst=dst, path=path) + for src, dst, path in self.cache["content_in_directory"] + ), ): logging.warning( f"Unable to write {RelationType.CNT_IN_DIR} rows to the storage. " @@ -87,7 +100,11 @@ ) while not self.storage.relation_add( - RelationType.DIR_IN_REV, self.cache["directory_in_revision"] + RelationType.DIR_IN_REV, + ( + RelationData(src=src, dst=dst, path=path) + for src, dst, path in self.cache["directory_in_revision"] + ), ): logging.warning( f"Unable to write {RelationType.DIR_IN_REV} rows to the storage. " @@ -145,31 +162,34 @@ ) # Second, flat models for revisions' histories (ie. revision-before-revision). - rbr_data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = sum( + data: Iterable[RelationData] = sum( [ [ - (prev, next, None) + RelationData(src=prev, dst=next, path=None) for next in self.cache["revision_before_revision"][prev] ] for prev in self.cache["revision_before_revision"] ], [], ) - while not self.storage.relation_add(RelationType.REV_BEFORE_REV, rbr_data): + while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): logging.warning( f"Unable to write {RelationType.REV_BEFORE_REV} rows to the storage. " - f"Data: {rbr_data}. Retrying..." + f"Data: {data}. Retrying..." ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. - rio_data = [(rev, org, None) for rev, org in self.cache["revision_in_origin"]] - while not self.storage.relation_add(RelationType.REV_IN_ORG, rio_data): + data = ( + RelationData(src=rev, dst=org, path=None) + for rev, org in self.cache["revision_in_origin"] + ) + while not self.storage.relation_add(RelationType.REV_IN_ORG, data): logging.warning( f"Unable to write {RelationType.REV_IN_ORG} rows to the storage. " - f"Data: {rio_data}. Retrying..." + f"Data: {data}. Retrying..." ) # Finally, preferred origins for the visited revisions are set (this step can be @@ -254,9 +274,9 @@ if missing_ids: if entity == "revision": updated = { - id: date - for id, (date, _) in self.storage.revision_get(missing_ids).items() - if date is not None + id: rev.date + for id, rev in self.storage.revision_get(missing_ids).items() + if rev.date is not None } else: updated = getattr(self.storage, f"{entity}_get")(missing_ids) @@ -298,7 +318,7 @@ if revision.id not in cache: ret = self.storage.revision_get([revision.id]) if revision.id in ret: - origin = ret[revision.id][1] # TODO: make this not a tuple + origin = ret[revision.id].origin if origin is not None: cache[revision.id] = origin return cache.get(revision.id) diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py --- a/swh/provenance/postgresql/provenancedb_base.py +++ b/swh/provenance/postgresql/provenancedb_base.py @@ -10,7 +10,13 @@ from swh.core.db import BaseDb from swh.model.model import Sha1Git -from ..provenance import ProvenanceResult, RelationType +from ..provenance import ( + EntityType, + ProvenanceResult, + RelationData, + RelationType, + RevisionData, +) class ProvenanceDBBase: @@ -60,6 +66,16 @@ def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: + sql = f"SELECT sha1 FROM {entity.value}" + self.cursor.execute(sql) + return {row["sha1"] for row in self.cursor.fetchall()} + + def location_get(self) -> Set[bytes]: + sql = "SELECT encode(location.path::bytea, 'escape') AS path FROM location" + self.cursor.execute(sql) + return {row["path"] for row in self.cursor.fetchall()} + def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: try: if urls: @@ -115,10 +131,8 @@ raise return False - def revision_get( - self, ids: Iterable[Sha1Git] - ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]: - result: Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]] = {} + def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: + result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: values = ", ".join(itertools.repeat("%s", len(sha1s))) @@ -129,25 +143,24 @@ """ self.cursor.execute(sql, sha1s) result.update( - (row["sha1"], (row["date"], row["origin"])) + (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in self.cursor.fetchall() ) return result def relation_add( - self, - relation: RelationType, - data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]], + self, relation: RelationType, data: Iterable[RelationData] ) -> bool: try: - if data: + rows = tuple((rel.src, rel.dst, rel.path) for rel in data) + if rows: table = relation.value src, *_, dst = table.split("_") if src != "origin": # Origin entries should be inserted previously as they require extra # non-null information - srcs = tuple(set((sha1,) for (sha1, _, _) in data)) + srcs = tuple(set((sha1,) for (sha1, _, _) in rows)) sql = f""" LOCK TABLE ONLY {src}; INSERT INTO {src}(sha1) VALUES %s @@ -157,7 +170,7 @@ if dst != "origin": # Origin entries should be inserted previously as they require extra # non-null information - dsts = tuple(set((sha1,) for (_, sha1, _) in data)) + dsts = tuple(set((sha1,) for (_, sha1, _) in rows)) sql = f""" LOCK TABLE ONLY {dst}; INSERT INTO {dst}(sha1) VALUES %s @@ -171,7 +184,7 @@ selected = ["S.id", "D.id"] if self._relation_uses_location_table(relation): - locations = tuple(set((path,) for (_, _, path) in data)) + locations = tuple(set((path,) for (_, _, path) in rows)) sql = """ LOCK TABLE ONLY location; INSERT INTO location(path) VALUES %s @@ -190,7 +203,7 @@ '''.join(joins)}) ON CONFLICT DO NOTHING """ - psycopg2.extras.execute_values(self.cursor, sql, data) + psycopg2.extras.execute_values(self.cursor, sql, rows) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message @@ -201,45 +214,11 @@ def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: - result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set() - sha1s = tuple(ids) - if sha1s: - table = relation.value - src, *_, dst = table.split("_") - - # TODO: improve this! - if src == "revision" and dst == "revision": - src_field = "prev" - dst_field = "next" - else: - src_field = src - dst_field = dst - - joins = [ - f"INNER JOIN {src} AS S ON (S.id=R.{src_field})", - f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})", - ] - selected = ["S.sha1 AS src", "D.sha1 AS dst"] - selector = "S.sha1" if not reverse else "D.sha1" - - if self._relation_uses_location_table(relation): - joins.append("INNER JOIN location AS L ON (L.id=R.location)") - selected.append("L.path AS path") - else: - selected.append("NULL AS path") + ) -> Set[RelationData]: + return self._relation_get(relation, ids, reverse) - sql = f""" - SELECT {", ".join(selected)} - FROM {table} AS R - {" ".join(joins)} - WHERE {selector} IN %s - """ - self.cursor.execute(sql, (sha1s,)) - result.update( - (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall() - ) - return result + def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + return self._relation_get(relation, None) def _entity_get_date( self, @@ -281,5 +260,55 @@ raise return False + def _relation_get( + self, + relation: RelationType, + ids: Optional[Iterable[Sha1Git]], + reverse: bool = False, + ) -> Set[RelationData]: + result: Set[RelationData] = set() + + sha1s: Optional[Tuple[Tuple[Sha1Git, ...]]] + if ids is not None: + sha1s = (tuple(ids),) + where = f"WHERE {'S.sha1' if not reverse else 'D.sha1'} IN %s" + else: + sha1s = None + where = "" + + if sha1s is None or sha1s[0]: + table = relation.value + src, *_, dst = table.split("_") + + # TODO: improve this! + if src == "revision" and dst == "revision": + src_field = "prev" + dst_field = "next" + else: + src_field = src + dst_field = dst + + joins = [ + f"INNER JOIN {src} AS S ON (S.id=R.{src_field})", + f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})", + ] + selected = ["S.sha1 AS src", "D.sha1 AS dst"] + + if self._relation_uses_location_table(relation): + joins.append("INNER JOIN location AS L ON (L.id=R.location)") + selected.append("L.path AS path") + else: + selected.append("NULL AS path") + + sql = f""" + SELECT {", ".join(selected)} + FROM {table} AS R + {" ".join(joins)} + {where} + """ + self.cursor.execute(sql, sha1s) + result.update(RelationData(**row) for row in self.cursor.fetchall()) + return result + def _relation_uses_location_table(self, relation: RelationType) -> bool: ... diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,6 +1,7 @@ +from dataclasses import dataclass from datetime import datetime import enum -from typing import Dict, Generator, Iterable, Optional, Set, Tuple +from typing import Dict, Generator, Iterable, Optional, Set from typing_extensions import Protocol, runtime_checkable @@ -9,6 +10,13 @@ from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry +class EntityType(enum.Enum): + CONTENT = "content" + DIRECTORY = "directory" + REVISION = "revision" + ORIGIN = "origin" + + class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" @@ -17,20 +25,37 @@ REV_BEFORE_REV = "revision_before_revision" +@dataclass(eq=True, frozen=True) class ProvenanceResult: - def __init__( - self, - content: Sha1Git, - revision: Sha1Git, - date: datetime, - origin: Optional[str], - path: bytes, - ) -> None: - self.content = content - self.revision = revision - self.date = date - self.origin = origin - self.path = path + content: Sha1Git + revision: Sha1Git + date: datetime + origin: Optional[str] + path: bytes + + +@dataclass(eq=True, frozen=True) +class RevisionData: + """Object representing the data associated to a revision in the provenance model, + where `date` is the optional date of the revision (specifying it acknowledges that + the revision was already processed by the revision-content algorithm); and `origin` + identifies the preferred origin for the revision, if any. + """ + + date: Optional[datetime] + origin: Optional[Sha1Git] + + +@dataclass(eq=True, frozen=True) +class RelationData: + """Object representing a relation entry in the provenance model, where `src` and + `dst` are the sha1 ids of the entities being related, and `path` is optional + depending on the relation being represented. + """ + + src: Sha1Git + dst: Sha1Git + path: Optional[bytes] @runtime_checkable @@ -72,6 +97,16 @@ """ ... + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: + """Retrieve all sha1 ids for entities of type `entity` present in the provenance + model. + """ + ... + + def location_get(self) -> Set[bytes]: + """Retrieve all paths present in the provenance model.""" + ... + def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: """Associate urls to origins identified by sha1 ids, as paired in `urls`. Return a boolean stating whether the information was successfully stored. @@ -97,9 +132,7 @@ """ ... - def revision_get( - self, ids: Iterable[Sha1Git] - ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]: + def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: """Retrieve the associated date and origin for each revision sha1 in `ids`. If some revision has no associated date nor origin, it is not present in the resulting dictionary. @@ -107,25 +140,26 @@ ... def relation_add( - self, - relation: RelationType, - data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]], + self, relation: RelationType, data: Iterable[RelationData] ) -> bool: - """Add entries in the selected `relation`. Each tuple in `data` is of the from - (`src`, `dst`, `path`), where `src` and `dst` are the sha1 ids of the entities - being related, and `path` is optional depending on the selected `relation`. - """ + """Add entries in the selected `relation`.""" ... def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: - """Retrieve all tuples in the selected `relation` whose source entities are + ) -> Set[RelationData]: + """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. """ ... + def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + """Retrieve all entries in the selected `relation` that are present in the + provenance model. + """ + ... + @runtime_checkable class ProvenanceInterface(Protocol): diff --git a/swh/provenance/tests/test_provenance_heuristics.py b/swh/provenance/tests/test_provenance_heuristics.py --- a/swh/provenance/tests/test_provenance_heuristics.py +++ b/swh/provenance/tests/test_provenance_heuristics.py @@ -3,18 +3,15 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from datetime import datetime from typing import Any, Dict, List, Optional, Set, Tuple -import psycopg2 import pytest from swh.model.hashutil import hash_to_bytes -from swh.model.model import Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.model import RevisionEntry from swh.provenance.postgresql.provenancedb_base import ProvenanceDBBase -from swh.provenance.provenance import ProvenanceInterface +from swh.provenance.provenance import EntityType, ProvenanceInterface, RelationType from swh.provenance.revision import revision_add from swh.provenance.tests.conftest import ( fill_storage, @@ -26,77 +23,6 @@ from swh.storage.postgresql.storage import Storage -def sha1s(cur: psycopg2.extensions.cursor, table: str) -> Set[Sha1Git]: - """return the 'sha1' column from the DB 'table' (as hex) - - 'cur' is a cursor to the provenance index DB. - """ - cur.execute(f"SELECT sha1 FROM {table}") - return set(row["sha1"].hex() for row in cur.fetchall()) - - -def locations(cur: psycopg2.extensions.cursor) -> Set[bytes]: - """return the 'path' column from the DB location table - - 'cur' is a cursor to the provenance index DB. - """ - cur.execute("SELECT encode(location.path::bytea, 'escape') AS path FROM location") - return set(row["path"] for row in cur.fetchall()) - - -def relations( - cur: psycopg2.extensions.cursor, src: str, dst: str -) -> Set[Tuple[Sha1Git, Sha1Git, bytes]]: - """return the triplets ('sha1', 'sha1', 'path') from the DB - - for the relation between 'src' table and 'dst' table - (i.e. for C-R, C-D and D-R relations). - - 'cur' is a cursor to the provenance index DB. - """ - relation = f"{src}_in_{dst}" - cur.execute("SELECT swh_get_dbflavor() AS flavor") - with_path = cur.fetchone()["flavor"] == "with-path" - - # note that the columns have the same name as the relations they refer to, - # so we can write things like "rel.{dst}=src.id" in the query below - if with_path: - cur.execute( - f""" - SELECT encode(src.sha1::bytea, 'hex') AS src, - encode(dst.sha1::bytea, 'hex') AS dst, - encode(location.path::bytea, 'escape') AS path - FROM {relation} as relation - INNER JOIN {src} AS src ON (relation.{src} = src.id) - INNER JOIN {dst} AS dst ON (relation.{dst} = dst.id) - INNER JOIN location ON (relation.location = location.id) - """ - ) - else: - cur.execute( - f""" - SELECT encode(src.sha1::bytea, 'hex') AS src, - encode(dst.sha1::bytea, 'hex') AS dst, - '' AS path - FROM {relation} as relation - INNER JOIN {src} AS src ON (src.id = relation.{src}) - INNER JOIN {dst} AS dst ON (dst.id = relation.{dst}) - """ - ) - return set((row["src"], row["dst"], row["path"]) for row in cur.fetchall()) - - -def get_timestamp( - cur: psycopg2.extensions.cursor, table: str, sha1: Sha1Git -) -> List[datetime]: - """return the date for the 'sha1' from the DB 'table' (as hex) - - 'cur' is a cursor to the provenance index DB. - """ - cur.execute(f"SELECT date FROM {table} WHERE sha1=%s", (sha1,)) - return [row["date"].timestamp() for row in cur.fetchall()] - - @pytest.mark.parametrize( "repo, lower, mindepth", ( @@ -133,14 +59,12 @@ "location": set(), "revision": set(), } - assert isinstance(provenance.storage, ProvenanceDBBase) - cursor = provenance.storage.cursor - def maybe_path(path: str) -> str: + def maybe_path(path: str) -> Optional[bytes]: assert isinstance(provenance.storage, ProvenanceDBBase) if provenance.storage.with_path: - return path - return "" + return path.encode("utf-8") + return None for synth_rev in synthetic_result(syntheticfile): revision = revisions[synth_rev["sha1"]] @@ -152,77 +76,94 @@ revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth) # each "entry" in the synth file is one new revision - rows["revision"].add(synth_rev["sha1"].hex()) - assert rows["revision"] == sha1s(cursor, "revision"), synth_rev["msg"] + rows["revision"].add(synth_rev["sha1"]) + assert rows["revision"] == provenance.storage.entity_get_all( + EntityType.REVISION + ), synth_rev["msg"] # check the timestamp of the revision rev_ts = synth_rev["date"] - assert get_timestamp(cursor, "revision", synth_rev["sha1"]) == [ - rev_ts - ], synth_rev["msg"] + rev_data = provenance.storage.revision_get([synth_rev["sha1"]])[ + synth_rev["sha1"] + ] + assert ( + rev_data.date is not None and rev_ts == rev_data.date.timestamp() + ), synth_rev["msg"] # this revision might have added new content objects - rows["content"] |= set(x["dst"].hex() for x in synth_rev["R_C"]) - rows["content"] |= set(x["dst"].hex() for x in synth_rev["D_C"]) - assert rows["content"] == sha1s(cursor, "content"), synth_rev["msg"] + rows["content"] |= set(x["dst"] for x in synth_rev["R_C"]) + rows["content"] |= set(x["dst"] for x in synth_rev["D_C"]) + assert rows["content"] == provenance.storage.entity_get_all( + EntityType.CONTENT + ), synth_rev["msg"] # check for R-C (direct) entries # these are added directly in the content_early_in_rev table rows["content_in_revision"] |= set( - (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"])) - for x in synth_rev["R_C"] + (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"] ) - assert rows["content_in_revision"] == relations( - cursor, "content", "revision" - ), synth_rev["msg"] + assert rows["content_in_revision"] == { + (rel.src, rel.dst, rel.path) + for rel in provenance.storage.relation_get_all( + RelationType.CNT_EARLY_IN_REV + ) + }, synth_rev["msg"] # check timestamps for rc in synth_rev["R_C"]: - assert get_timestamp(cursor, "content", rc["dst"]) == [ + assert ( rev_ts + rc["rel_ts"] - ], synth_rev["msg"] + == provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp() + ), synth_rev["msg"] # check directories # each directory stored in the provenance index is an entry # in the "directory" table... - rows["directory"] |= set(x["dst"].hex() for x in synth_rev["R_D"]) - assert rows["directory"] == sha1s(cursor, "directory"), synth_rev["msg"] + rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"]) + assert rows["directory"] == provenance.storage.entity_get_all( + EntityType.DIRECTORY + ), synth_rev["msg"] # ... + a number of rows in the "directory_in_rev" table... # check for R-D entries rows["directory_in_revision"] |= set( - (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"])) - for x in synth_rev["R_D"] + (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"] ) - assert rows["directory_in_revision"] == relations( - cursor, "directory", "revision" - ), synth_rev["msg"] + assert rows["directory_in_revision"] == { + (rel.src, rel.dst, rel.path) + for rel in provenance.storage.relation_get_all(RelationType.DIR_IN_REV) + }, synth_rev["msg"] # check timestamps for rd in synth_rev["R_D"]: - assert get_timestamp(cursor, "directory", rd["dst"]) == [ + assert ( rev_ts + rd["rel_ts"] - ], synth_rev["msg"] + == provenance.storage.directory_get([rd["dst"]])[rd["dst"]].timestamp() + ), synth_rev["msg"] # ... + a number of rows in the "content_in_dir" table # for content of the directory. # check for D-C entries rows["content_in_directory"] |= set( - (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"])) - for x in synth_rev["D_C"] + (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"] ) - assert rows["content_in_directory"] == relations( - cursor, "content", "directory" - ), synth_rev["msg"] + assert rows["content_in_directory"] == { + (rel.src, rel.dst, rel.path) + for rel in provenance.storage.relation_get_all(RelationType.CNT_IN_DIR) + }, synth_rev["msg"] # check timestamps for dc in synth_rev["D_C"]: - assert get_timestamp(cursor, "content", dc["dst"]) == [ + assert ( rev_ts + dc["rel_ts"] - ], synth_rev["msg"] + == provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp() + ), synth_rev["msg"] + assert isinstance(provenance.storage, ProvenanceDBBase) if provenance.storage.with_path: # check for location entries rows["location"] |= set(x["path"] for x in synth_rev["R_C"]) rows["location"] |= set(x["path"] for x in synth_rev["D_C"]) rows["location"] |= set(x["path"] for x in synth_rev["R_D"]) - assert rows["location"] == locations(cursor), synth_rev["msg"] + assert rows["location"] == provenance.storage.location_get(), synth_rev[ + "msg" + ] @pytest.mark.parametrize( @@ -235,6 +176,7 @@ ("out-of-order", True, 1), ), ) +@pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_all( provenance: ProvenanceInterface, swh_storage: Storage, @@ -242,6 +184,7 @@ repo: str, lower: bool, mindepth: int, + batch: bool, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) @@ -261,11 +204,13 @@ return path return "" - # XXX adding all revisions at once should be working just fine, but it does not... - # revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) - # ...so add revisions one at a time for now - for revision in revisions: - revision_add(provenance, archive, [revision], lower=lower, mindepth=mindepth) + if batch: + revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) + else: + for revision in revisions: + revision_add( + provenance, archive, [revision], lower=lower, mindepth=mindepth + ) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt" @@ -316,6 +261,7 @@ ("out-of-order", True, 1), ), ) +@pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_first( provenance: ProvenanceInterface, swh_storage: Storage, @@ -323,6 +269,7 @@ repo: str, lower: bool, mindepth: int, + batch: bool, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) @@ -336,11 +283,13 @@ for revision in data["revision"] ] - # XXX adding all revisions at once should be working just fine, but it does not... - # revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) - # ...so add revisions one at a time for now - for revision in revisions: - revision_add(provenance, archive, [revision], lower=lower, mindepth=mindepth) + if batch: + revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) + else: + for revision in revisions: + revision_add( + provenance, archive, [revision], lower=lower, mindepth=mindepth + ) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"