diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py index c688eeb..d962cd7 100644 --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,518 +1,529 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import hashlib import logging import os from types import TracebackType from typing import Dict, Generator, Iterable, Optional, Set, Tuple, Type from typing_extensions import Literal, TypedDict from swh.core.statsd import statsd from swh.model.model import Sha1Git from .interface import ProvenanceInterface from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry from .storage.interface import ( DirectoryData, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from .util import path_normalize LOGGER = logging.getLogger(__name__) BACKEND_DURATION_METRIC = "swh_provenance_backend_duration_seconds" BACKEND_OPERATIONS_METRIC = "swh_provenance_backend_operations_total" class DatetimeCache(TypedDict): data: Dict[Sha1Git, Optional[datetime]] # None means unknown added: Set[Sha1Git] class OriginCache(TypedDict): data: Dict[Sha1Git, str] added: Set[Sha1Git] class RevisionCache(TypedDict): data: Dict[Sha1Git, Sha1Git] added: Set[Sha1Git] class ProvenanceCache(TypedDict): content: DatetimeCache directory: DatetimeCache directory_flatten: Dict[Sha1Git, Optional[bool]] # None means unknown revision: DatetimeCache # below are insertion caches only - content_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] + content_in_revision: Set[Tuple[Sha1Git, Sha1Git, datetime, bytes]] content_in_directory: Set[Tuple[Sha1Git, Sha1Git, bytes]] directory_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] # these two are for the origin layer origin: OriginCache revision_origin: RevisionCache revision_before_revision: Dict[Sha1Git, Set[Sha1Git]] revision_in_origin: Set[Tuple[Sha1Git, Sha1Git]] def new_cache() -> ProvenanceCache: return ProvenanceCache( content=DatetimeCache(data={}, added=set()), directory=DatetimeCache(data={}, added=set()), directory_flatten={}, revision=DatetimeCache(data={}, added=set()), content_in_revision=set(), content_in_directory=set(), directory_in_revision=set(), origin=OriginCache(data={}, added=set()), revision_origin=RevisionCache(data={}, added=set()), revision_before_revision={}, revision_in_origin=set(), ) class Provenance: MAX_CACHE_ELEMENTS = 40000 def __init__(self, storage: ProvenanceStorageInterface) -> None: self.storage = storage self.cache = new_cache() def __enter__(self) -> ProvenanceInterface: self.open() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.close() def _flush_limit_reached(self) -> bool: return sum(self._get_cache_stats().values()) > self.MAX_CACHE_ELEMENTS def _get_cache_stats(self) -> Dict[str, int]: return { k: len(v["data"]) if (isinstance(v, dict) and v.get("data") is not None) else len(v) # type: ignore for (k, v) in self.cache.items() } def clear_caches(self) -> None: self.cache = new_cache() def close(self) -> None: self.storage.close() @statsd.timed(metric=BACKEND_DURATION_METRIC, tags={"method": "flush"}) def flush(self) -> None: self.flush_revision_content_layer() self.flush_origin_revision_layer() self.clear_caches() def 
flush_if_necessary(self) -> bool: """Flush if the number of cached information reached a limit.""" LOGGER.debug("Cache stats: %s", self._get_cache_stats()) if self._flush_limit_reached(): self.flush() return True else: return False @statsd.timed( metric=BACKEND_DURATION_METRIC, tags={"method": "flush_origin_revision"} ) def flush_origin_revision_layer(self) -> None: # Origins and revisions should be inserted first so that internal ids' # resolution works properly. urls = { sha1: url for sha1, url in self.cache["origin"]["data"].items() if sha1 in self.cache["origin"]["added"] } if urls: while not self.storage.origin_add(urls): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_origin_revision_retry_origin"}, ) LOGGER.warning( "Unable to write origins urls to the storage. Retrying..." ) rev_orgs = { # Destinations in this relation should match origins in the next one **{ src: RevisionData(date=None, origin=None) for src in self.cache["revision_before_revision"] }, **{ # This relation comes second so that non-None origins take precedence src: RevisionData(date=None, origin=org) for src, org in self.cache["revision_in_origin"] }, } if rev_orgs: while not self.storage.revision_add(rev_orgs): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_origin_revision_retry_revision"}, ) LOGGER.warning( "Unable to write revision entities to the storage. Retrying..." ) # Second, flat models for revisions' histories (ie. revision-before-revision). if self.cache["revision_before_revision"]: rev_before_rev = { src: {RelationData(dst=dst, path=None) for dst in dsts} for src, dsts in self.cache["revision_before_revision"].items() } while not self.storage.relation_add( RelationType.REV_BEFORE_REV, rev_before_rev ): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={ "method": "flush_origin_revision_retry_revision_before_revision" }, ) LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_BEFORE_REV, ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. if self.cache["revision_in_origin"]: rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} for src, dst in self.cache["revision_in_origin"]: rev_in_org.setdefault(src, set()).add(RelationData(dst=dst, path=None)) while not self.storage.relation_add(RelationType.REV_IN_ORG, rev_in_org): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_origin_revision_retry_revision_in_origin"}, ) LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_IN_ORG, ) @statsd.timed( metric=BACKEND_DURATION_METRIC, tags={"method": "flush_revision_content"} ) def flush_revision_content_layer(self) -> None: # Register in the storage all entities, to ensure the coming relations can # properly resolve any internal reference if needed. Content and directory # entries may safely be registered with their associated dates. In contrast, # revision entries should be registered without date, as it is used to # acknowledge that the flushing was successful. Also, directories are # registered with their flatten flag not set. 
cnt_dates = { sha1: date for sha1, date in self.cache["content"]["data"].items() if sha1 in self.cache["content"]["added"] and date is not None } if cnt_dates: while not self.storage.content_add(cnt_dates): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_content_date"}, ) LOGGER.warning( "Unable to write content dates to the storage. Retrying..." ) dir_dates = { sha1: DirectoryData(date=date, flat=False) for sha1, date in self.cache["directory"]["data"].items() if sha1 in self.cache["directory"]["added"] and date is not None } if dir_dates: while not self.storage.directory_add(dir_dates): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_directory_date"}, ) LOGGER.warning( "Unable to write directory dates to the storage. Retrying..." ) revs = { sha1: RevisionData(date=None, origin=None) for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } if revs: while not self.storage.revision_add(revs): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_revision_none"}, ) LOGGER.warning( "Unable to write revision entities to the storage. Retrying..." ) paths = { hashlib.sha1(path).digest(): path - for _, _, path in self.cache["content_in_revision"] - | self.cache["content_in_directory"] - | self.cache["directory_in_revision"] + for cache_table in ( + "content_in_revision", + "content_in_directory", + "directory_in_revision", + ) + for *_, path in self.cache[cache_table] # type: ignore } if paths: while not self.storage.location_add(paths): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_location"}, ) LOGGER.warning( "Unable to write locations entities to the storage. Retrying..." ) # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. if self.cache["content_in_revision"]: cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} - for src, dst, path in self.cache["content_in_revision"]: - cnt_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + for src, dst, dst_date, path in self.cache["content_in_revision"]: + cnt_in_rev.setdefault(src, set()).add( + RelationData(dst=dst, dst_date=dst_date, path=path) + ) while not self.storage.relation_add( RelationType.CNT_EARLY_IN_REV, cnt_in_rev ): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_content_in_revision"}, ) LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_EARLY_IN_REV, ) if self.cache["content_in_directory"]: cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} for src, dst, path in self.cache["content_in_directory"]: cnt_in_dir.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add(RelationType.CNT_IN_DIR, cnt_in_dir): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={ "method": "flush_revision_content_retry_content_in_directory" }, ) LOGGER.warning( "Unable to write %s rows to the storage. 
Retrying...", RelationType.CNT_IN_DIR, ) if self.cache["directory_in_revision"]: dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for src, dst, path in self.cache["directory_in_revision"]: dir_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add(RelationType.DIR_IN_REV, dir_in_rev): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={ "method": "flush_revision_content_retry_directory_in_revision" }, ) LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.DIR_IN_REV, ) # After relations, flatten flags for directories can be safely set (if # applicable) acknowledging those directories that have already be flattened. # Similarly, dates for the revisions are set to acknowledge that these revisions # won't need to be reprocessed in case of failure. dir_acks = { sha1: DirectoryData( date=date, flat=self.cache["directory_flatten"].get(sha1) or False ) for sha1, date in self.cache["directory"]["data"].items() if self.cache["directory_flatten"].get(sha1) and date is not None } if dir_acks: while not self.storage.directory_add(dir_acks): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_directory_ack"}, ) LOGGER.warning( "Unable to write directory dates to the storage. Retrying..." ) rev_dates = { sha1: RevisionData(date=date, origin=None) for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } if rev_dates: while not self.storage.revision_add(rev_dates): statsd.increment( metric=BACKEND_OPERATIONS_METRIC, tags={"method": "flush_revision_content_retry_revision_date"}, ) LOGGER.warning( "Unable to write revision dates to the storage. Retrying..." ) def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_directory"].add( (blob.id, directory.id, path_normalize(os.path.join(prefix, blob.name))) ) def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: + assert revision.date is not None self.cache["content_in_revision"].add( - (blob.id, revision.id, path_normalize(os.path.join(prefix, blob.name))) + ( + blob.id, + revision.id, + revision.date, + path_normalize(os.path.join(prefix, blob.name)), + ) ) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: return self.storage.content_find_first(id) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: yield from self.storage.content_find_all(id, limit=limit) def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: return self.get_dates("content", [blob.id]).get(blob.id) def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("content", [blob.id for blob in blobs]) def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: self.cache["content"]["data"][blob.id] = date self.cache["content"]["added"].add(blob.id) def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: self.cache["directory_in_revision"].add( (directory.id, revision.id, path_normalize(path)) ) def directory_already_flattened(self, directory: DirectoryEntry) -> Optional[bool]: cache = self.cache["directory_flatten"] if directory.id not in cache: cache.setdefault(directory.id, None) ret = self.storage.directory_get([directory.id]) if directory.id in ret: dir = 
ret[directory.id] cache[directory.id] = dir.flat # date is kept to ensure we have it available when flushing self.cache["directory"]["data"][directory.id] = dir.date return cache.get(directory.id) def directory_flag_as_flattened(self, directory: DirectoryEntry) -> None: self.cache["directory_flatten"][directory.id] = True def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: return self.get_dates("directory", [directory.id]).get(directory.id) def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("directory", [directory.id for directory in dirs]) def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: self.cache["directory"]["data"][directory.id] = date self.cache["directory"]["added"].add(directory.id) def get_dates( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: cache = self.cache[entity] missing_ids = set(id for id in ids if id not in cache) if missing_ids: if entity == "content": cache["data"].update(self.storage.content_get(missing_ids)) elif entity == "directory": cache["data"].update( { id: dir.date for id, dir in self.storage.directory_get(missing_ids).items() } ) elif entity == "revision": cache["data"].update( { id: rev.date for id, rev in self.storage.revision_get(missing_ids).items() } ) dates: Dict[Sha1Git, datetime] = {} for sha1 in ids: date = cache["data"].setdefault(sha1, None) if date is not None: dates[sha1] = date return dates def open(self) -> None: self.storage.open() def origin_add(self, origin: OriginEntry) -> None: self.cache["origin"]["data"][origin.id] = origin.url self.cache["origin"]["added"].add(origin.id) def revision_add(self, revision: RevisionEntry) -> None: self.cache["revision"]["data"][revision.id] = revision.date self.cache["revision"]["added"].add(revision.id) def revision_add_before_revision( self, head_id: Sha1Git, revision_id: Sha1Git ) -> None: self.cache["revision_before_revision"].setdefault(revision_id, set()).add( head_id ) def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_in_origin"].add((revision.id, origin.id)) def revision_is_head(self, revision: RevisionEntry) -> bool: return bool(self.storage.relation_get(RelationType.REV_IN_ORG, [revision.id])) def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: return self.get_dates("revision", [revision.id]).get(revision.id) def revision_get_preferred_origin(self, revision_id: Sha1Git) -> Optional[Sha1Git]: cache = self.cache["revision_origin"]["data"] if revision_id not in cache: ret = self.storage.revision_get([revision_id]) if revision_id in ret: origin = ret[revision_id].origin if origin is not None: cache[revision_id] = origin return cache.get(revision_id) def revision_set_preferred_origin( self, origin: OriginEntry, revision_id: Sha1Git ) -> None: self.cache["revision_origin"]["data"][revision_id] = origin.id self.cache["revision_origin"]["added"].add(revision_id) diff --git a/swh/provenance/sql/30-schema.sql b/swh/provenance/sql/30-schema.sql index 949b6f6..8879bae 100644 --- a/swh/provenance/sql/30-schema.sql +++ b/swh/provenance/sql/30-schema.sql @@ -1,126 +1,128 @@ -- a Git object ID, i.e., a Git-style salted SHA1 checksum create domain sha1_git as bytea check (length(value) = 20); -- UNIX path (absolute, relative, individual path component, etc.) 
create domain unix_path as bytea;

-- relation filter options for querying
create type rel_flt as enum (
    'filter-src',
    'filter-dst',
    'no-filter'
);
comment on type rel_flt is 'Relation get filter types';

-- entity tables
create table content
(
    id   bigserial primary key,     -- internal identifier of the content blob
    sha1 sha1_git unique not null,  -- intrinsic identifier of the content blob
    date timestamptz                -- timestamp of the revision where the blob appears early
);
comment on column content.id is 'Content internal identifier';
comment on column content.sha1 is 'Content intrinsic identifier';
comment on column content.date is 'Earliest timestamp for the content (first seen time)';

create table directory
(
    id   bigserial primary key,          -- internal identifier of the directory appearing in an isochrone inner frontier
    sha1 sha1_git unique not null,       -- intrinsic identifier of the directory
    date timestamptz,                    -- max timestamp among those of the directory children
    flat boolean not null default false  -- flag acknowledging if the directory is flattened in the model
);
comment on column directory.id is 'Directory internal identifier';
comment on column directory.sha1 is 'Directory intrinsic identifier';
comment on column directory.date is 'Latest timestamp for the content in the directory';

create table revision
(
    id     bigserial primary key,     -- internal identifier of the revision
    sha1   sha1_git unique not null,  -- intrinsic identifier of the revision
    date   timestamptz,               -- timestamp of the revision
    origin bigint                     -- id of the preferred origin
    -- foreign key (origin) references origin (id)
);
comment on column revision.id is 'Revision internal identifier';
comment on column revision.sha1 is 'Revision intrinsic identifier';
comment on column revision.date is 'Revision timestamp';
comment on column revision.origin is 'preferred origin for the revision';

create table location
(
    id   bigserial primary key,  -- internal identifier of the location
    path unix_path               -- path to the location
);
comment on column location.id is 'Location internal identifier';
comment on column location.path is 'Path to the location';

create table origin
(
    id   bigserial primary key,     -- internal identifier of the origin
    sha1 sha1_git unique not null,  -- intrinsic identifier of the origin
    url  text                       -- url of the origin
);
comment on column origin.id is 'Origin internal identifier';
comment on column origin.sha1 is 'Origin intrinsic identifier';
comment on column origin.url is 'URL of the origin';

-- relation tables
create table content_in_revision
(
    content  bigint not null,   -- internal identifier of the content blob
    revision bigint not null,   -- internal identifier of the revision where the blob appears for the first time
-   location bigint             -- location of the content relative to the revision's root directory
+   location bigint,            -- location of the content relative to the revision's root directory
+   revision_date timestamptz not null  -- date of the revision where the blob appears for the first time
    -- foreign key (content) references content (id),
    -- foreign key (revision) references revision (id),
    -- foreign key (location) references location (id)
);
comment on column content_in_revision.content is 'Content internal identifier';
comment on column content_in_revision.revision is 'Revision internal identifier';
comment on column content_in_revision.location is 'Location of content in revision';
+comment on column content_in_revision.revision_date is 'Date of the revision';

create table content_in_directory
(
    content   bigint not null,  -- internal identifier
of the content blob
    directory bigint not null,  -- internal identifier of the directory containing the blob
    location  bigint            -- location of the content relative to its parent directory in the isochrone frontier
    -- foreign key (content) references content (id),
    -- foreign key (directory) references directory (id),
    -- foreign key (location) references location (id)
);
comment on column content_in_directory.content is 'Content internal identifier';
comment on column content_in_directory.directory is 'Directory internal identifier';
comment on column content_in_directory.location is 'Location of content in directory';

create table directory_in_revision
(
    directory bigint not null,  -- internal identifier of the directory appearing in the revision
    revision  bigint not null,  -- internal identifier of the revision containing the directory
    location  bigint            -- location of the directory relative to the revision's root directory
    -- foreign key (directory) references directory (id),
    -- foreign key (revision) references revision (id),
    -- foreign key (location) references location (id)
);
comment on column directory_in_revision.directory is 'Directory internal identifier';
comment on column directory_in_revision.revision is 'Revision internal identifier';
comment on column directory_in_revision.location is 'Location of content in revision';

create table revision_in_origin
(
    revision bigint not null,  -- internal identifier of the revision pointed to by the origin
    origin   bigint not null   -- internal identifier of the origin that points to the revision
    -- foreign key (revision) references revision (id),
    -- foreign key (origin) references origin (id)
);
comment on column revision_in_origin.revision is 'Revision internal identifier';
comment on column revision_in_origin.origin is 'Origin internal identifier';

create table revision_before_revision
(
    prev bigserial not null,  -- internal identifier of the source revision
    next bigserial not null   -- internal identifier of the destination revision
    -- foreign key (prev) references revision (id),
    -- foreign key (next) references revision (id)
);
comment on column revision_before_revision.prev is 'Source revision internal identifier';
comment on column revision_before_revision.next is 'Destination revision internal identifier';
diff --git a/swh/provenance/sql/40-funcs.sql b/swh/provenance/sql/40-funcs.sql
index 30bd997..e85ffb6 100644
--- a/swh/provenance/sql/40-funcs.sql
+++ b/swh/provenance/sql/40-funcs.sql
@@ -1,191 +1,198 @@
create or replace function swh_mktemp_relation_add() returns void
    language sql
as $$
    create temp table tmp_relation_add (
        src sha1_git not null,
        dst sha1_git not null,
-       path unix_path
+       path unix_path,
+       dst_date timestamptz
    ) on commit drop
$$;

create or replace function swh_provenance_content_find_first(content_id sha1_git)
    returns table (
        content sha1_git,
        revision sha1_git,
        date timestamptz,
        origin text,
        path unix_path
    )
    language sql
    stable
as $$
    select C.sha1 as content,
           R.sha1 as revision,
-          R.date as date,
+          CR.revision_date as date,
           O.url as origin,
           L.path as path
    from content as C
    inner join content_in_revision as CR on (CR.content = C.id)
    inner join location as L on (L.id = CR.location)
    inner join revision as R on (R.id = CR.revision)
    left join origin as O on (O.id = R.origin)
    where C.sha1 = content_id
    order by date, revision, origin, path asc limit 1
$$;

create or replace function swh_provenance_content_find_all(content_id sha1_git, early_cut int)
    returns table (
        content sha1_git,
        revision sha1_git,
        date timestamptz,
        origin text,
        path
unix_path ) language sql stable as $$ (select C.sha1 as content, R.sha1 as revision, - R.date as date, + CR.revision_date as date, O.url as origin, L.path as path from content as C inner join content_in_revision as CR on (CR.content = C.id) inner join location as L on (L.id = CR.location) inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin) - where C.sha1 = content_id) + where C.sha1 = content_id + order by date, revision, origin, path limit early_cut) union (select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, case DL.path when '' then CL.path when '.' then CL.path else (DL.path || '/' || CL.path)::unix_path end as path from content as C inner join content_in_directory as CD on (CD.content = C.id) inner join directory_in_revision as DR on (DR.directory = CD.directory) inner join revision as R on (R.id = DR.revision) inner join location as CL on (CL.id = CD.location) inner join location as DL on (DL.id = DR.location) left join origin as O on (O.id = R.origin) - where C.sha1 = content_id) + where C.sha1 = content_id + order by date, revision, origin, path limit early_cut) order by date, revision, origin, path limit early_cut $$; create or replace function swh_provenance_relation_add_from_temp( rel_table regclass, src_table regclass, dst_table regclass ) returns void language plpgsql volatile as $$ declare select_fields text; join_location text; begin - if src_table in ('content'::regclass, 'directory'::regclass) then + case + when src_table = 'content'::regclass and dst_table = 'revision'::regclass then + select_fields := 'D.id, L.id, dst_date as revision_date'; + join_location := 'inner join location as L on (digest(L.path,''sha1'') = digest(V.path,''sha1''))'; + when src_table in ('content'::regclass, 'directory'::regclass) then select_fields := 'D.id, L.id'; join_location := 'inner join location as L on (digest(L.path,''sha1'') = digest(V.path,''sha1''))'; - else + else select_fields := 'D.id'; join_location := ''; - end if; + end case; execute format( 'insert into %s (sha1) select distinct src from tmp_relation_add where not exists (select 1 from %s where %s.sha1=tmp_relation_add.src) on conflict do nothing', src_table, src_table, src_table); execute format( 'insert into %s (sha1) select distinct dst from tmp_relation_add where not exists (select 1 from %s where %s.sha1=tmp_relation_add.dst) on conflict do nothing', dst_table, dst_table, dst_table); if src_table in ('content'::regclass, 'directory'::regclass) then insert into location(path) select distinct path from tmp_relation_add where not exists (select 1 from location where digest(location.path,'sha1')=digest(tmp_relation_add.path,'sha1') ) on conflict do nothing; end if; execute format( 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || join_location || ' on conflict do nothing', rel_table, src_table, dst_table ); end; $$; create or replace function swh_provenance_relation_get( rel_table regclass, src_table regclass, dst_table regclass, filter rel_flt, sha1s sha1_git[] ) returns table ( src sha1_git, dst sha1_git, path unix_path ) language plpgsql stable as $$ declare src_field text; dst_field text; join_location text; proj_location text; filter_result text; begin if rel_table = 'revision_before_revision'::regclass then src_field := 'prev'; dst_field := 'next'; else src_field := src_table::text; dst_field := dst_table::text; end if; if src_table in 
('content'::regclass, 'directory'::regclass) then join_location := 'inner join location as L on (L.id = R.location)'; proj_location := 'L.path'; else join_location := ''; proj_location := 'NULL::unix_path'; end if; case filter when 'filter-src'::rel_flt then filter_result := 'where S.sha1 = any($1)'; when 'filter-dst'::rel_flt then filter_result := 'where D.sha1 = any($1)'; else filter_result := ''; end case; return query execute format( 'select S.sha1 as src, D.sha1 as dst, ' || proj_location || ' as path from %s as R inner join %s as S on (S.id = R.' || src_field || ') inner join %s as D on (D.id = R.' || dst_field || ') ' || join_location || ' ' || filter_result, rel_table, src_table, dst_table ) using sha1s; end; $$; diff --git a/swh/provenance/sql/60-indexes.sql b/swh/provenance/sql/60-indexes.sql index ecbfa1a..f69df62 100644 --- a/swh/provenance/sql/60-indexes.sql +++ b/swh/provenance/sql/60-indexes.sql @@ -1,11 +1,11 @@ -- create unique indexes (instead of pkey) because location might be null for -- the without-path flavor -create unique index on content_in_revision(content, revision, location); +create unique index on content_in_revision(content, revision_date, revision, location); create unique index on directory_in_revision(directory, revision, location); create unique index on content_in_directory(content, directory, location); create unique index on location(digest(path, 'sha1')); create index on directory(sha1) where flat=false; alter table revision_in_origin add primary key (revision, origin); alter table revision_before_revision add primary key (prev, next); diff --git a/swh/provenance/sql/upgrades/006-step2.sql b/swh/provenance/sql/upgrades/006-step2.sql new file mode 100644 index 0000000..da1e692 --- /dev/null +++ b/swh/provenance/sql/upgrades/006-step2.sql @@ -0,0 +1,77 @@ +alter table content_in_revision alter column revision_date set not null; + +drop function if exists content_in_revision_add_date(bigint); +drop table if exists content_in_revision_add_date_progress; + + +create or replace function swh_provenance_content_find_first(content_id sha1_git) + returns table ( + content sha1_git, + revision sha1_git, + date timestamptz, + origin text, + path unix_path + ) + language sql + stable +as $$ + select C.sha1 as content, + R.sha1 as revision, + CR.revision_date as date, + O.url as origin, + L.path as path + from content as C + inner join content_in_revision as CR on (CR.content = C.id) + inner join location as L on (L.id = CR.location) + inner join revision as R on (R.id = CR.revision) + left join origin as O on (O.id = R.origin) + where C.sha1 = content_id + order by date, revision, origin, path asc limit 1 +$$; + +create or replace function swh_provenance_content_find_all(content_id sha1_git, early_cut int) + returns table ( + content sha1_git, + revision sha1_git, + date timestamptz, + origin text, + path unix_path + ) + language sql + stable +as $$ + (select C.sha1 as content, + R.sha1 as revision, + CR.revision_date as date, + O.url as origin, + L.path as path + from content as C + inner join content_in_revision as CR on (CR.content = C.id) + inner join location as L on (L.id = CR.location) + inner join revision as R on (R.id = CR.revision) + left join origin as O on (O.id = R.origin) + where C.sha1 = content_id + order by date, revision, origin, path limit early_cut) + union + (select C.sha1 as content, + R.sha1 as revision, + R.date as date, + O.url as origin, + case DL.path + when '' then CL.path + when '.' 
then CL.path + else (DL.path || '/' || CL.path)::unix_path + end as path + from content as C + inner join content_in_directory as CD on (CD.content = C.id) + inner join directory_in_revision as DR on (DR.directory = CD.directory) + inner join revision as R on (R.id = DR.revision) + inner join location as CL on (CL.id = CD.location) + inner join location as DL on (DL.id = DR.location) + left join origin as O on (O.id = R.origin) + where C.sha1 = content_id + order by date, revision, origin, path limit early_cut) + order by date, revision, origin, path limit early_cut +$$; + +drop index if exists content_in_revision_content_revision_date_revision_location_idx; diff --git a/swh/provenance/sql/upgrades/006.sql b/swh/provenance/sql/upgrades/006.sql new file mode 100644 index 0000000..6dc93ce --- /dev/null +++ b/swh/provenance/sql/upgrades/006.sql @@ -0,0 +1,119 @@ +alter table content_in_revision add column revision_date timestamptz; -- left this column null for now, needs a data migration + +comment on column content_in_revision.revision_date is 'Date of the revision'; + +create or replace function swh_mktemp_relation_add() returns void + language sql +as $$ + create temp table tmp_relation_add ( + src sha1_git not null, + dst sha1_git not null, + path unix_path, + dst_date timestamptz + ) on commit drop +$$; + +create or replace function swh_provenance_relation_add_from_temp( + rel_table regclass, src_table regclass, dst_table regclass +) + returns void + language plpgsql + volatile +as $$ + declare + select_fields text; + join_location text; + begin + case + when src_table = 'content'::regclass and dst_table = 'revision'::regclass then + select_fields := 'D.id, L.id, dst_date as revision_date'; + join_location := 'inner join location as L on (digest(L.path,''sha1'') = digest(V.path,''sha1''))'; + when src_table in ('content'::regclass, 'directory'::regclass) then + select_fields := 'D.id, L.id'; + join_location := 'inner join location as L on (digest(L.path,''sha1'') = digest(V.path,''sha1''))'; + else + select_fields := 'D.id'; + join_location := ''; + end case; + + execute format( + 'insert into %s (sha1) + select distinct src + from tmp_relation_add + where not exists (select 1 from %s where %s.sha1=tmp_relation_add.src) + on conflict do nothing', + src_table, src_table, src_table); + + execute format( + 'insert into %s (sha1) + select distinct dst + from tmp_relation_add + where not exists (select 1 from %s where %s.sha1=tmp_relation_add.dst) + on conflict do nothing', + dst_table, dst_table, dst_table); + + if src_table in ('content'::regclass, 'directory'::regclass) then + insert into location(path) + select distinct path + from tmp_relation_add + where not exists (select 1 from location + where digest(location.path,'sha1')=digest(tmp_relation_add.path,'sha1') + ) + on conflict do nothing; + end if; + + execute format( + 'insert into %s + select S.id, ' || select_fields || ' + from tmp_relation_add as V + inner join %s as S on (S.sha1 = V.src) + inner join %s as D on (D.sha1 = V.dst) + ' || join_location || ' + on conflict do nothing', + rel_table, src_table, dst_table + ); + end; +$$; + +create unique index on content_in_revision(content, revision_date, revision, location) where revision_date is not null; + +create table content_in_revision_add_date_progress ( + updated bigint not null, + last_content bigint not null, + date timestamptz not null, + single_row bool not null unique check (single_row != false) default true); + +create or replace function 
content_in_revision_add_date(num_contents bigint default 1000) +returns setof content_in_revision_add_date_progress +language sql +as $$ + with content_ids as ( + select distinct content + from content_in_revision + where content > coalesce((select last_content from content_in_revision_add_date_progress limit 1), 0) + order by content + limit content_in_revision_add_date.num_contents + ), + updated_rows as ( + update content_in_revision + set + revision_date = (select date from revision where id=content_in_revision.revision) + where + revision_date is null + and content > (select min(content) from content_ids) + and content <= (select max(content) from content_ids) + returning content + ), + updated_progress as ( + insert into content_in_revision_add_date_progress + (updated, last_content, date) + values + ((select count(*) from updated_rows), (select max(content) from updated_rows), now()) + on conflict (single_row) do update set + updated=content_in_revision_add_date_progress.updated+EXCLUDED.updated, + last_content=EXCLUDED.last_content, + date=EXCLUDED.date + returning * + ) + select * from updated_progress +$$; diff --git a/swh/provenance/storage/interface.py b/swh/provenance/storage/interface.py index ec1121d..32327a2 100644 --- a/swh/provenance/storage/interface.py +++ b/swh/provenance/storage/interface.py @@ -1,225 +1,228 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from dataclasses import dataclass from datetime import datetime import enum from types import TracebackType from typing import Dict, Generator, Iterable, List, Optional, Set, Type from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint from swh.model.model import Sha1Git class EntityType(enum.Enum): CONTENT = "content" DIRECTORY = "directory" REVISION = "revision" ORIGIN = "origin" class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" DIR_IN_REV = "directory_in_revision" REV_IN_ORG = "revision_in_origin" REV_BEFORE_REV = "revision_before_revision" @dataclass(eq=True, frozen=True) class ProvenanceResult: content: Sha1Git revision: Sha1Git date: datetime origin: Optional[str] path: bytes @dataclass(eq=True, frozen=True) class DirectoryData: """Object representing the data associated to a directory in the provenance model, where `date` is the date of the directory in the isochrone frontier, and `flat` is a flag acknowledging that a flat model for the elements outside the frontier has already been created. """ date: Optional[datetime] flat: bool @dataclass(eq=True, frozen=True) class RevisionData: """Object representing the data associated to a revision in the provenance model, where `date` is the optional date of the revision (specifying it acknowledges that the revision was already processed by the revision-content algorithm); and `origin` identifies the preferred origin for the revision, if any. """ date: Optional[datetime] origin: Optional[Sha1Git] @dataclass(eq=True, frozen=True) class RelationData: - """Object representing a relation entry in the provenance model, where `src` and - `dst` are the sha1 ids of the entities being related, and `path` is optional - depending on the relation being represented. 
+    """Object representing a relation entry in the provenance model, where `src`
+    and `dst` are the sha1 ids of the entities being related, and `path` is
+    optional depending on the relation being represented. `dst_date` is the
+    (denormalized) known date of the destination, if relevant (e.g. in
+    `content_in_revision`).
    """

    dst: Sha1Git
    path: Optional[bytes]
+    dst_date: Optional[datetime] = None


@runtime_checkable
class ProvenanceStorageInterface(Protocol):
    def __enter__(self) -> ProvenanceStorageInterface:
        ...

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        ...

    @remote_api_endpoint("close")
    def close(self) -> None:
        """Close connection to the storage and release resources."""
        ...

    @remote_api_endpoint("content_add")
    def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
        """Add blobs identified by sha1 ids, with an associated date (as paired in
        `cnts`) to the provenance storage. Return a boolean stating whether the
        information was successfully stored.
        """
        ...

    @remote_api_endpoint("content_find_first")
    def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
        """Retrieve the first occurrence of the blob identified by `id`."""
        ...

    @remote_api_endpoint("content_find_all")
    def content_find_all(
        self, id: Sha1Git, limit: Optional[int] = None
    ) -> Generator[ProvenanceResult, None, None]:
        """Retrieve all the occurrences of the blob identified by `id`."""
        ...

    @remote_api_endpoint("content_get")
    def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
        """Retrieve the associated date for each blob sha1 in `ids`."""
        ...

    @remote_api_endpoint("directory_add")
    def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
        """Add directories identified by sha1 ids, with associated date and
        (optional) flatten flag (as paired in `dirs`) to the provenance storage.
        If the flatten flag is set to None, the previous value present in the
        storage is preserved. Return a boolean stating if the information was
        successfully stored.
        """
        ...

    @remote_api_endpoint("directory_get")
    def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]:
        """Retrieve the associated date and (optional) flatten flag for each
        directory sha1 in `ids`. If a directory has no associated date, it is
        not present in the resulting dictionary.
        """
        ...

    @remote_api_endpoint("directory_iter_not_flattened")
    def directory_iter_not_flattened(
        self, limit: int, start_id: Sha1Git
    ) -> List[Sha1Git]:
        """Retrieve the unflattened directories after ``start_id`` up to ``limit``
        entries."""
        ...

    @remote_api_endpoint("entity_get_all")
    def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
        """Retrieve all sha1 ids for entities of type `entity` present in the
        provenance model. This method is used only in tests.
        """
        ...

    @remote_api_endpoint("location_add")
    def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool:
        """Register the given `paths` in the storage."""
        ...

    @remote_api_endpoint("location_get_all")
    def location_get_all(self) -> Dict[Sha1Git, bytes]:
        """Retrieve all paths present in the provenance model.
        This method is used only in tests."""
        ...

    @remote_api_endpoint("open")
    def open(self) -> None:
        """Open connection to the storage and allocate necessary resources."""
        ...

    @remote_api_endpoint("origin_add")
    def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
        """Add origins identified by sha1 ids, with their corresponding url (as
        paired in `orgs`) to the provenance storage.
Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("origin_get") def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: """Retrieve the associated url for each origin sha1 in `ids`.""" ... @remote_api_endpoint("revision_add") def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: """Add revisions identified by sha1 ids, with optional associated date or origin (as paired in `revs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("revision_get") def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: """Retrieve the associated date and origin for each revision sha1 in `ids`. If some revision has no associated date nor origin, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("relation_add") def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: """Add entries in the selected `relation`. This method assumes all entities being related are already registered in the storage. See `content_add`, `directory_add`, `origin_add`, and `revision_add`. """ ... @remote_api_endpoint("relation_get") def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. """ ... @remote_api_endpoint("relation_get_all") def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` that are present in the provenance model. This method is used only in tests. """ ... 
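For reference, a minimal sketch (not part of the patch) of how the new `dst_date` field travels from the `ProvenanceCache` tuple to the storage call; the sha1 values, date, and path below are made up for illustration, while the class and enum names come from the modules in this patch:

    from datetime import datetime, timezone

    from swh.provenance.storage.interface import RelationData, RelationType

    # made-up identifiers, date and path, for illustration only
    blob_sha1 = bytes.fromhex("11" * 20)
    rev_sha1 = bytes.fromhex("22" * 20)
    rev_date = datetime(2022, 1, 1, tzinfo=timezone.utc)

    # what Provenance.content_add_to_revision() now caches: a 4-tuple
    # carrying the revision date alongside the path
    cache_entry = (blob_sha1, rev_sha1, rev_date, b"src/main.c")

    # what flush_revision_content_layer() hands to the storage: the date is
    # attached to the relation itself as dst_date
    relation = RelationType.CNT_EARLY_IN_REV
    data = {
        blob_sha1: {RelationData(dst=rev_sha1, dst_date=rev_date, path=b"src/main.c")}
    }
    # with an open ProvenanceStorageInterface instance `storage`, this becomes:
    # storage.relation_add(relation, data)

Carrying the revision date on the relation is what lets the backend fill `content_in_revision.revision_date` directly, instead of joining on `revision` at query time.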
diff --git a/swh/provenance/storage/journal.py b/swh/provenance/storage/journal.py index 4fb2545..9d55000 100644 --- a/swh/provenance/storage/journal.py +++ b/swh/provenance/storage/journal.py @@ -1,164 +1,169 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from datetime import datetime import hashlib from types import TracebackType from typing import Dict, Generator, Iterable, List, Optional, Set, Type from swh.model.model import Sha1Git from swh.provenance.storage.interface import ( DirectoryData, EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) class JournalMessage: def __init__(self, id, value, add_id=True): self.id = id self.value = value self.add_id = add_id def anonymize(self): return None def unique_key(self): return self.id def to_dict(self): if self.add_id: return { "id": self.id, "value": self.value, } else: return self.value class ProvenanceStorageJournal: def __init__(self, storage, journal): self.storage = storage self.journal = journal def __enter__(self) -> ProvenanceStorageInterface: self.storage.__enter__() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: return self.storage.__exit__(exc_type, exc_val, exc_tb) def open(self) -> None: self.storage.open() def close(self) -> None: self.storage.close() def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: self.journal.write_additions( "content", [JournalMessage(key, value) for (key, value) in cnts.items()] ) return self.storage.content_add(cnts) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: return self.storage.content_find_first(id) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: return self.storage.content_find_all(id, limit) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self.storage.content_get(ids) def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: self.journal.write_additions( "directory", [ JournalMessage(key, value.date) for (key, value) in dirs.items() if value.date is not None ], ) return self.storage.directory_add(dirs) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: return self.storage.directory_get(ids) def directory_iter_not_flattened( self, limit: int, start_id: Sha1Git ) -> List[Sha1Git]: return self.storage.directory_iter_not_flattened(limit, start_id) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return self.storage.entity_get_all(entity) def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool: return self.storage.location_add(paths) def location_get_all(self) -> Dict[Sha1Git, bytes]: return self.storage.location_get_all() def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: self.journal.write_additions( "origin", [JournalMessage(key, value) for (key, value) in orgs.items()] ) return self.storage.origin_add(orgs) def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return self.storage.origin_get(ids) def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: self.journal.write_additions( "revision", [ JournalMessage(key, value.date) for (key, value) in revs.items() if value.date is not None ], ) 
return self.storage.revision_add(revs) def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: return self.storage.revision_get(ids) def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: messages = [] for src, relations in data.items(): for reldata in relations: key = hashlib.sha1(src + reldata.dst + (reldata.path or b"")).digest() messages.append( JournalMessage( key, - {"src": src, "dst": reldata.dst, "path": reldata.path}, + { + "src": src, + "dst": reldata.dst, + "path": reldata.path, + "dst_date": reldata.dst_date, + }, add_id=False, ) ) self.journal.write_additions(relation.value, messages) return self.storage.relation_add(relation, data) def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: return self.storage.relation_get(relation, ids, reverse) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: return self.storage.relation_get_all(relation) diff --git a/swh/provenance/storage/postgresql.py b/swh/provenance/storage/postgresql.py index 8448a72..c5622c4 100644 --- a/swh/provenance/storage/postgresql.py +++ b/swh/provenance/storage/postgresql.py @@ -1,392 +1,396 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from contextlib import contextmanager from datetime import datetime from functools import wraps from hashlib import sha1 import itertools import logging from types import TracebackType from typing import Dict, Generator, Iterable, List, Optional, Set, Type import psycopg2.extensions import psycopg2.extras from swh.core.db import BaseDb from swh.core.statsd import statsd from swh.model.model import Sha1Git from swh.provenance.storage.interface import ( DirectoryData, EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds" def handle_raise_on_commit(f): @wraps(f) def handle(self, *args, **kwargs): try: return f(self, *args, **kwargs) except BaseException as ex: # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise ex return False return handle class ProvenanceStoragePostgreSql: - current_version = 5 + current_version = 6 def __init__( self, page_size: Optional[int] = None, raise_on_commit: bool = False, db: str = "", ) -> None: self.conn: Optional[psycopg2.extensions.connection] = None self.dsn = db self._flavor: Optional[str] = None self.page_size = page_size self.raise_on_commit = raise_on_commit def __enter__(self) -> ProvenanceStorageInterface: self.open() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.close() @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extras.RealDictCursor, None, None]: if self.conn is None: raise RuntimeError( "Tried to access ProvenanceStoragePostgreSQL transaction() without opening it" ) self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> 
str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") flavor = cursor.fetchone() assert flavor # please mypy self._flavor = flavor["flavor"] assert self._flavor is not None return self._flavor @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) def close(self) -> None: assert self.conn is not None self.conn.close() @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) @handle_raise_on_commit def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: if cnts: sql = """ INSERT INTO content(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,content.date) """ page_size = self.page_size or len(cnts) with self.transaction() as cursor: psycopg2.extras.execute_values( cursor, sql, argslist=cnts.items(), page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM content WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) @handle_raise_on_commit def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: data = [(sha1, rev.date, rev.flat) for sha1, rev in dirs.items()] if data: sql = """ INSERT INTO directory(sha1, date, flat) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, directory.date), flat=(EXCLUDED.flat OR directory.flat) """ page_size = self.page_size or len(data) with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=data, page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: result: Dict[Sha1Git, DirectoryData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date, flat FROM directory WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], DirectoryData(date=row["date"], flat=row["flat"])) for row in cursor ) return result @statsd.timed( metric=STORAGE_DURATION_METRIC, tags={"method": "directory_iter_not_flattened"} ) def directory_iter_not_flattened( self, limit: int, start_id: Sha1Git ) -> List[Sha1Git]: sql = """ SELECT sha1 FROM directory WHERE flat=false AND sha1>%s ORDER BY sha1 LIMIT %s """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(start_id, limit)) return [row["sha1"] for row in cursor] @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) @handle_raise_on_commit def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool: values = [(path,) for path in paths.values()] if values: sql = """ INSERT INTO location(path) VALUES %s ON CONFLICT DO NOTHING """ page_size = self.page_size or len(values) with self.transaction() as cursor: psycopg2.extras.execute_values( cursor, sql, argslist=values, page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) def location_get_all(self) -> Dict[Sha1Git, bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {sha1(row["path"]).digest(): row["path"] for row in cursor} @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) @handle_raise_on_commit def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ page_size = self.page_size or len(orgs) with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=orgs.items(), page_size=page_size, ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) def open(self) -> None: self.conn = BaseDb.connect(self.dsn).conn BaseDb.adapt_conn(self.conn) with self.transaction() as cursor: cursor.execute("SET timezone TO 'UTC'") @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"}) def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) @handle_raise_on_commit def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: if revs: data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin FROM (VALUES %s) AS V(rev, date, org) LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, revision.date), origin=COALESCE(EXCLUDED.origin, revision.origin) """ page_size = self.page_size or len(data) with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=data, page_size=page_size ) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) @handle_raise_on_commit def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: - rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] + rows = [ + (src, rel.dst, rel.path, rel.dst_date) + for src, dsts in data.items() + for rel in dsts + ] if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") page_size = self.page_size or len(rows) # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, - sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", + sql="INSERT INTO tmp_relation_add(src, dst, path, dst_date) VALUES %s", argslist=rows, page_size=page_size, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Dict[Sha1Git, Set[RelationData]]: result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: 
sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) for row in cursor: src = row.pop("src") result.setdefault(src, set()).add(RelationData(**row)) return result diff --git a/swh/provenance/storage/rabbitmq/client.py b/swh/provenance/storage/rabbitmq/client.py index 12485b2..361bccf 100644 --- a/swh/provenance/storage/rabbitmq/client.py +++ b/swh/provenance/storage/rabbitmq/client.py @@ -1,501 +1,503 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations import functools import inspect import logging import queue import threading import time from types import TracebackType from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union import uuid import pika import pika.channel import pika.connection import pika.frame import pika.spec from swh.core.api.serializers import encode_data_client as encode_data from swh.core.api.serializers import msgpack_loads as decode_data from swh.core.statsd import statsd from swh.provenance.storage import get_provenance_storage from swh.provenance.storage.interface import ( ProvenanceStorageInterface, RelationData, RelationType, ) from .serializers import DECODERS, ENCODERS from .server import ProvenanceStorageRabbitMQServer LOG_FORMAT = ( "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " "-35s %(lineno) -5d: %(message)s" ) LOGGER = logging.getLogger(__name__) STORAGE_DURATION_METRIC = "swh_provenance_storage_rabbitmq_duration_seconds" class ResponseTimeout(Exception): pass class TerminateSignal(Exception): pass def split_ranges( data: Iterable[bytes], meth_name: str, relation: Optional[RelationType] = None ) -> Dict[str, Set[Tuple[Any, ...]]]: ranges: Dict[str, Set[Tuple[Any, ...]]] = {} if relation is not None: assert isinstance(data, dict), "Relation data must be provided in a dictionary" for src, dsts in data.items(): key = ProvenanceStorageRabbitMQServer.get_routing_key( src, meth_name, relation ) for rel in dsts: assert isinstance( rel, RelationData ), "Values in the dictionary must be RelationData structures" - ranges.setdefault(key, set()).add((src, rel.dst, rel.path)) + ranges.setdefault(key, set()).add( + (src, rel.dst, rel.path, rel.dst_date) + ) else: items: Union[Set[Tuple[bytes, Any]], Set[Tuple[bytes]]] if isinstance(data, dict): items = set(data.items()) else: # TODO this is probably not used any more items = {(item,) for item in data} for id, *rest in items: key = ProvenanceStorageRabbitMQServer.get_routing_key(id, meth_name) ranges.setdefault(key, set()).add((id, *rest)) return ranges class MetaRabbitMQClient(type): def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. 
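# Illustrative sketch: split_ranges buckets a write payload by routing key so
# that every RabbitMQ queue only receives the ids it is responsible for; for
# relation data the emitted tuples now carry dst_date as well. All concrete
# values below are hypothetical:
from datetime import datetime, timezone

from swh.provenance.storage.interface import RelationData, RelationType
from swh.provenance.storage.rabbitmq.client import split_ranges

data = {
    b"\x01" * 20: {  # source content id
        RelationData(
            dst=b"\x02" * 20,  # destination revision id
            path=b"README.md",
            dst_date=datetime(2021, 1, 1, tzinfo=timezone.utc),
        )
    }
}
ranges = split_ranges(data, "relation_add", RelationType.CNT_EARLY_IN_REV)
# => {"relation_add.content_in_revision.<n>": {(src, dst, path, dst_date)}}
# where <n> is the hex bucket that get_routing_key derives from the source id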
backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break backend_class = getattr(base, "backend_class", None) if backend_class: for meth_name, meth in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): with statsd.timed( metric=STORAGE_DURATION_METRIC, tags={"method": meth_name} ): # Match arguments and parameters data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = data.pop("self") # Call storage method with remaining arguments return getattr(self._storage, meth_name)(**data) @functools.wraps(meth) # Copy signature and doc def write_meth_(*args, **kwargs): with statsd.timed( metric=STORAGE_DURATION_METRIC, tags={"method": meth_name} ): # Match arguments and parameters post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) try: # Remove arguments that should not be passed self = post_data.pop("self") relation = post_data.pop("relation", None) assert len(post_data) == 1 data = next(iter(post_data.values())) ranges = split_ranges(data, meth_name, relation) acks_expected = sum(len(items) for items in ranges.values()) self._correlation_id = str(uuid.uuid4()) exchange = ProvenanceStorageRabbitMQServer.get_exchange( meth_name, relation ) try: self._delay_close = True for routing_key, items in ranges.items(): items_list = list(items) batches = ( items_list[idx : idx + self._batch_size] for idx in range(0, len(items_list), self._batch_size) ) for batch in batches: # FIXME: this is running in a different thread! Hence, if # self._connection drops, there is no guarantee that the # request can be sent for the current elements. This # situation should be handled properly. self._connection.ioloop.add_callback_threadsafe( functools.partial( ProvenanceStorageRabbitMQClient.request, channel=self._channel, reply_to=self._callback_queue, exchange=exchange, routing_key=routing_key, correlation_id=self._correlation_id, data=batch, ) ) return self.wait_for_acks(meth_name, acks_expected) finally: self._delay_close = False except BaseException as ex: self.request_termination(str(ex)) return False if meth_name not in attributes: attributes[meth_name] = ( write_meth_ if ProvenanceStorageRabbitMQServer.is_write_method(meth_name) else meth_ ) class ProvenanceStorageRabbitMQClient(threading.Thread, metaclass=MetaRabbitMQClient): backend_class = ProvenanceStorageInterface extra_type_decoders = DECODERS extra_type_encoders = ENCODERS def __init__( self, url: str, storage_config: Dict[str, Any], batch_size: int = 100, prefetch_count: int = 100, wait_min: float = 60, wait_per_batch: float = 10, ) -> None: """Setup the client object, passing in the URL we will use to connect to RabbitMQ, and the connection information for the local storage object used for read-only operations. 
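# Illustrative sketch: write methods wait until every published item has been
# acknowledged, and wait_for_acks (below) scales its timeout with the number of
# batches that were sent. With hypothetical settings batch_size=100,
# wait_per_batch=10s, wait_min=60s and 2500 items to write:
batch_size = 100
wait_per_batch = 10.0
wait_min = 60.0
acks_expected = 2500
timeout = max((acks_expected / batch_size) * wait_per_batch, wait_min)
assert timeout == 250.0  # 25 batches * 10s each, above the 60s floor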
:param str url: The URL for connecting to RabbitMQ :param dict storage_config: Configuration parameters for the underlying ``ProvenanceStorage`` object expected by ``swh.provenance.get_provenance_storage`` :param int batch_size: Max amount of elements per package (after range splitting) for writing operations :param int prefetch_count: Prefetch value for the RabbitMQ connection when receiving ack packages :param float wait_min: Min waiting time for response on a writing operation, in seconds :param float wait_per_batch: Waiting time for response per batch of items on a writing operation, in seconds """ super().__init__() self._connection = None self._callback_queue: Optional[str] = None self._channel = None self._closing = False self._consumer_tag = None self._consuming = False self._correlation_id = str(uuid.uuid4()) self._prefetch_count = prefetch_count self._batch_size = batch_size self._response_queue: queue.Queue = queue.Queue() self._storage = get_provenance_storage(**storage_config) self._url = url self._wait_min = wait_min self._wait_per_batch = wait_per_batch self._delay_close = False def __enter__(self) -> ProvenanceStorageInterface: self.open() assert isinstance(self, ProvenanceStorageInterface) return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.close() @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) def open(self) -> None: self.start() while self._callback_queue is None: time.sleep(0.1) self._storage.open() @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) def close(self) -> None: assert self._connection is not None self._connection.ioloop.add_callback_threadsafe(self.request_termination) self.join() self._storage.close() def request_termination(self, reason: str = "Normal shutdown") -> None: assert self._connection is not None def termination_callback(): raise TerminateSignal(reason) self._connection.ioloop.add_callback_threadsafe(termination_callback) def connect(self) -> pika.SelectConnection: LOGGER.info("Connecting to %s", self._url) return pika.SelectConnection( parameters=pika.URLParameters(self._url), on_open_callback=self.on_connection_open, on_open_error_callback=self.on_connection_open_error, on_close_callback=self.on_connection_closed, ) def close_connection(self) -> None: assert self._connection is not None self._consuming = False if self._connection.is_closing or self._connection.is_closed: LOGGER.info("Connection is closing or already closed") else: LOGGER.info("Closing connection") self._connection.close() def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: LOGGER.info("Connection opened") self.open_channel() def on_connection_open_error( self, _unused_connection: pika.SelectConnection, err: Exception ) -> None: LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) assert self._connection is not None self._connection.ioloop.call_later(5, self._connection.ioloop.stop) def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): assert self._connection is not None self._channel = None if self._closing: self._connection.ioloop.stop() else: LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) self._connection.ioloop.call_later(5, self._connection.ioloop.stop) def open_channel(self) -> None: LOGGER.debug("Creating a new channel") assert self._connection is not None self._connection.channel(on_open_callback=self.on_channel_open) def 
on_channel_open(self, channel: pika.channel.Channel) -> None: LOGGER.debug("Channel opened") self._channel = channel LOGGER.debug("Adding channel close callback") assert self._channel is not None self._channel.add_on_close_callback(callback=self.on_channel_closed) self.setup_queue() def on_channel_closed( self, channel: pika.channel.Channel, reason: Exception ) -> None: LOGGER.warning("Channel %i was closed: %s", channel, reason) self.close_connection() def setup_queue(self) -> None: LOGGER.debug("Declaring callback queue") assert self._channel is not None self._channel.queue_declare( queue="", exclusive=True, callback=self.on_queue_declare_ok ) def on_queue_declare_ok(self, frame: pika.frame.Method) -> None: LOGGER.debug("Binding queue to default exchanger") assert self._channel is not None self._callback_queue = frame.method.queue self._channel.basic_qos( prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok ) def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None: LOGGER.debug("QOS set to: %d", self._prefetch_count) self.start_consuming() def start_consuming(self) -> None: LOGGER.debug("Issuing consumer related RPC commands") LOGGER.debug("Adding consumer cancellation callback") assert self._channel is not None self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled) assert self._callback_queue is not None self._consumer_tag = self._channel.basic_consume( queue=self._callback_queue, on_message_callback=self.on_response ) self._consuming = True def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None: LOGGER.debug("Consumer was cancelled remotely, shutting down: %r", method_frame) if self._channel: self._channel.close() def on_response( self, channel: pika.channel.Channel, deliver: pika.spec.Basic.Deliver, properties: pika.spec.BasicProperties, body: bytes, ) -> None: self._response_queue.put( ( properties.correlation_id, decode_data(body, extra_decoders=self.extra_type_decoders), ) ) channel.basic_ack(delivery_tag=deliver.delivery_tag) def stop_consuming(self) -> None: if self._channel: LOGGER.debug("Sending a Basic.Cancel RPC command to RabbitMQ") self._channel.basic_cancel(self._consumer_tag, self.on_cancel_ok) def on_cancel_ok(self, _unused_frame: pika.frame.Method) -> None: self._consuming = False LOGGER.debug( "RabbitMQ acknowledged the cancellation of the consumer: %s", self._consumer_tag, ) LOGGER.debug("Closing the channel") assert self._channel is not None self._channel.close() def run(self) -> None: while not self._closing: try: self._connection = self.connect() assert self._connection is not None self._connection.ioloop.start() except KeyboardInterrupt: LOGGER.info("Connection closed by keyboard interruption, reopening") if self._connection is not None: self._connection.ioloop.stop() except TerminateSignal as ex: LOGGER.info("Termination requested: %s", ex) self.stop() if self._connection is not None and not self._connection.is_closed: # Finish closing self._connection.ioloop.start() except BaseException as ex: LOGGER.warning("Unexpected exception, terminating: %s", ex) self.stop() if self._connection is not None and not self._connection.is_closed: # Finish closing self._connection.ioloop.start() LOGGER.info("Stopped") def stop(self) -> None: assert self._connection is not None if not self._closing: if self._delay_close: LOGGER.info("Delaying termination: waiting for a pending request") delay_start = time.monotonic() wait = 1 while self._delay_close: if wait >= 32: LOGGER.warning( "Still waiting for pending request 
(for %2f seconds)...", time.monotonic() - delay_start, ) time.sleep(wait) wait = min(wait * 2, 60) self._closing = True LOGGER.info("Stopping") if self._consuming: self.stop_consuming() self._connection.ioloop.start() else: self._connection.ioloop.stop() LOGGER.info("Stopped") @staticmethod def request( channel: pika.channel.Channel, reply_to: str, exchange: str, routing_key: str, correlation_id: str, **kwargs, ) -> None: channel.basic_publish( exchange=exchange, routing_key=routing_key, properties=pika.BasicProperties( content_type="application/msgpack", correlation_id=correlation_id, reply_to=reply_to, ), body=encode_data( kwargs, extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders, ), ) def wait_for_acks(self, meth_name: str, acks_expected: int) -> bool: acks_received = 0 timeout = max( (acks_expected / self._batch_size) * self._wait_per_batch, self._wait_min, ) start = time.monotonic() end = start + timeout while acks_received < acks_expected: local_timeout = end - time.monotonic() if local_timeout < 1.0: local_timeout = 1.0 try: acks_received += self.wait_for_response(timeout=local_timeout) except ResponseTimeout: LOGGER.warning( "Timed out waiting for acks in %s, %s received, %s expected (in %ss)", meth_name, acks_received, acks_expected, time.monotonic() - start, ) return False return acks_received == acks_expected def wait_for_response(self, timeout: float = 120.0) -> Any: start = time.monotonic() end = start + timeout while True: try: local_timeout = end - time.monotonic() if local_timeout < 1.0: local_timeout = 1.0 correlation_id, response = self._response_queue.get( timeout=local_timeout ) if correlation_id == self._correlation_id: return response except queue.Empty: raise ResponseTimeout diff --git a/swh/provenance/storage/rabbitmq/server.py b/swh/provenance/storage/rabbitmq/server.py index 4cfd49c..ff748fb 100644 --- a/swh/provenance/storage/rabbitmq/server.py +++ b/swh/provenance/storage/rabbitmq/server.py @@ -1,728 +1,728 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import Counter from datetime import datetime from enum import Enum import functools import logging import multiprocessing import os import queue import threading from typing import Any, Callable from typing import Counter as TCounter from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast import pika import pika.channel import pika.connection import pika.exceptions from pika.exchange_type import ExchangeType import pika.frame import pika.spec from swh.core import config from swh.core.api.serializers import encode_data_client as encode_data from swh.core.api.serializers import msgpack_loads as decode_data from swh.model.hashutil import hash_to_hex from swh.model.model import Sha1Git from swh.provenance.storage.interface import ( DirectoryData, EntityType, RelationData, RelationType, RevisionData, ) from swh.provenance.util import path_id from .serializers import DECODERS, ENCODERS LOG_FORMAT = ( "%(levelname) -10s %(asctime)s %(name) -30s %(funcName) " "-35s %(lineno) -5d: %(message)s" ) LOGGER = logging.getLogger(__name__) TERMINATE = object() class ServerCommand(Enum): TERMINATE = "terminate" CONSUMING = "consuming" class TerminateSignal(BaseException): pass def resolve_dates(dates: Iterable[Tuple[Sha1Git, datetime]]) -> Dict[Sha1Git, datetime]: result: 
Dict[Sha1Git, datetime] = {} for sha1, date in dates: known = result.setdefault(sha1, date) if date < known: result[sha1] = date return result def resolve_directory( data: Iterable[Tuple[Sha1Git, DirectoryData]] ) -> Dict[Sha1Git, DirectoryData]: result: Dict[Sha1Git, DirectoryData] = {} for sha1, dir in data: known = result.setdefault(sha1, dir) value = known assert dir.date is not None assert known.date is not None if dir.date < known.date: value = DirectoryData(date=dir.date, flat=value.flat) if dir.flat: value = DirectoryData(date=value.date, flat=dir.flat) if value != known: result[sha1] = value return result def resolve_revision( data: Iterable[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] ) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} for row in data: sha1 = row[0] rev = ( cast(Tuple[Sha1Git, RevisionData], row)[1] if len(row) > 1 else RevisionData(date=None, origin=None) ) known = result.setdefault(sha1, RevisionData(date=None, origin=None)) value = known if rev.date is not None and (known.date is None or rev.date < known.date): value = RevisionData(date=rev.date, origin=value.origin) if rev.origin is not None: value = RevisionData(date=value.date, origin=rev.origin) if value != known: result[sha1] = value return result def resolve_relation( - data: Iterable[Tuple[Sha1Git, Sha1Git, bytes]] + data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes], Optional[datetime]]] ) -> Dict[Sha1Git, Set[RelationData]]: result: Dict[Sha1Git, Set[RelationData]] = {} - for src, dst, path in data: - result.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + for src, dst, path, dst_date in data: + result.setdefault(src, set()).add( + RelationData(dst=dst, path=path, dst_date=dst_date) + ) return result class ProvenanceStorageRabbitMQWorker(multiprocessing.Process): EXCHANGE_TYPE = ExchangeType.direct extra_type_decoders = DECODERS extra_type_encoders = ENCODERS def __init__( self, url: str, exchange: str, range: int, storage_config: Dict[str, Any], batch_size: int = 100, prefetch_count: int = 100, ) -> None: """Setup the worker object, passing in the URL we will use to connect to RabbitMQ, the exchange to use, the range id on which to operate, and the connection information for the underlying local storage object. 
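# Illustrative sketch: the resolve_* helpers above merge duplicate entries
# received for the same id before they reach the storage; e.g.
# resolve_directory keeps the earliest date while preserving a flat=True flag
# seen in any duplicate. Hypothetical values:
from datetime import datetime

from swh.provenance.storage.interface import DirectoryData
from swh.provenance.storage.rabbitmq.server import resolve_directory

merged = resolve_directory(
    [
        (b"\x0d" * 20, DirectoryData(date=datetime(2021, 1, 2), flat=True)),
        (b"\x0d" * 20, DirectoryData(date=datetime(2021, 1, 1), flat=False)),
    ]
)
assert merged == {b"\x0d" * 20: DirectoryData(date=datetime(2021, 1, 1), flat=True)}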
:param str url: The URL for connecting to RabbitMQ :param str exchange: The name of the RabbitMq exchange to use :param str range: The ID range to operate on :param dict storage_config: Configuration parameters for the underlying ``ProvenanceStorage`` object expected by ``swh.provenance.get_provenance_storage`` :param int batch_size: Max amount of elements call to the underlying storage :param int prefetch_count: Prefetch value for the RabbitMQ connection when receiving messaged """ super().__init__(name=f"{exchange}_{range:x}") self._connection = None self._channel = None self._closing = False self._consumer_tag: Dict[str, str] = {} self._consuming: Dict[str, bool] = {} self._prefetch_count = prefetch_count self._url = url self._exchange = exchange self._binding_keys = list( ProvenanceStorageRabbitMQServer.get_binding_keys(self._exchange, range) ) self._queues: Dict[str, str] = {} self._storage_config = storage_config self._batch_size = batch_size self.command: multiprocessing.Queue = multiprocessing.Queue() self.signal: multiprocessing.Queue = multiprocessing.Queue() def connect(self) -> pika.SelectConnection: LOGGER.info("Connecting to %s", self._url) return pika.SelectConnection( parameters=pika.URLParameters(self._url), on_open_callback=self.on_connection_open, on_open_error_callback=self.on_connection_open_error, on_close_callback=self.on_connection_closed, ) def close_connection(self) -> None: assert self._connection is not None self._consuming = {binding_key: False for binding_key in self._binding_keys} if self._connection.is_closing or self._connection.is_closed: LOGGER.info("Connection is closing or already closed") else: LOGGER.info("Closing connection") self._connection.close() def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None: LOGGER.info("Connection opened") self.open_channel() def on_connection_open_error( self, _unused_connection: pika.SelectConnection, err: Exception ) -> None: LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err) assert self._connection is not None self._connection.ioloop.call_later(5, self._connection.ioloop.stop) def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason): assert self._connection is not None self._channel = None if self._closing: self._connection.ioloop.stop() else: LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason) self._connection.ioloop.call_later(5, self._connection.ioloop.stop) def open_channel(self) -> None: LOGGER.info("Creating a new channel") assert self._connection is not None self._connection.channel(on_open_callback=self.on_channel_open) def on_channel_open(self, channel: pika.channel.Channel) -> None: LOGGER.info("Channel opened") self._channel = channel LOGGER.info("Adding channel close callback") assert self._channel is not None self._channel.add_on_close_callback(callback=self.on_channel_closed) self.setup_exchange() def on_channel_closed( self, channel: pika.channel.Channel, reason: Exception ) -> None: LOGGER.warning("Channel %i was closed: %s", channel, reason) self.close_connection() def setup_exchange(self) -> None: LOGGER.info("Declaring exchange %s", self._exchange) assert self._channel is not None self._channel.exchange_declare( exchange=self._exchange, exchange_type=self.EXCHANGE_TYPE, callback=self.on_exchange_declare_ok, ) def on_exchange_declare_ok(self, _unused_frame: pika.frame.Method) -> None: LOGGER.info("Exchange declared: %s", self._exchange) self.setup_queues() def setup_queues(self) -> None: for binding_key in 
self._binding_keys: LOGGER.info("Declaring queue %s", binding_key) assert self._channel is not None callback = functools.partial( self.on_queue_declare_ok, binding_key=binding_key, ) self._channel.queue_declare(queue=binding_key, callback=callback) def on_queue_declare_ok(self, frame: pika.frame.Method, binding_key: str) -> None: LOGGER.info( "Binding queue %s to exchange %s with routing key %s", frame.method.queue, self._exchange, binding_key, ) assert self._channel is not None callback = functools.partial(self.on_bind_ok, queue_name=frame.method.queue) self._queues[binding_key] = frame.method.queue self._channel.queue_bind( queue=frame.method.queue, exchange=self._exchange, routing_key=binding_key, callback=callback, ) def on_bind_ok(self, _unused_frame: pika.frame.Method, queue_name: str) -> None: LOGGER.info("Queue bound: %s", queue_name) assert self._channel is not None self._channel.basic_qos( prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok ) def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None: LOGGER.info("QOS set to: %d", self._prefetch_count) self.start_consuming() def start_consuming(self) -> None: LOGGER.info("Issuing consumer related RPC commands") LOGGER.info("Adding consumer cancellation callback") assert self._channel is not None self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled) for binding_key in self._binding_keys: self._consumer_tag[binding_key] = self._channel.basic_consume( queue=self._queues[binding_key], on_message_callback=self.on_request ) self._consuming[binding_key] = True self.signal.put(ServerCommand.CONSUMING) def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None: LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame) if self._channel: self._channel.close() def on_request( self, channel: pika.channel.Channel, deliver: pika.spec.Basic.Deliver, properties: pika.spec.BasicProperties, body: bytes, ) -> None: LOGGER.debug( "Received message # %s from %s: %s", deliver.delivery_tag, properties.app_id, body, ) - # XXX: for some reason this function is returning lists instead of tuples - # (the client send tuples) batch = decode_data(data=body, extra_decoders=self.extra_type_decoders)["data"] for item in batch: self._request_queues[deliver.routing_key].put( (tuple(item), (properties.correlation_id, properties.reply_to)) ) LOGGER.debug("Acknowledging message %s", deliver.delivery_tag) channel.basic_ack(delivery_tag=deliver.delivery_tag) def stop_consuming(self) -> None: if self._channel: LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ") for binding_key in self._binding_keys: callback = functools.partial(self.on_cancel_ok, binding_key=binding_key) self._channel.basic_cancel( self._consumer_tag[binding_key], callback=callback ) def on_cancel_ok(self, _unused_frame: pika.frame.Method, binding_key: str) -> None: self._consuming[binding_key] = False LOGGER.info( "RabbitMQ acknowledged the cancellation of the consumer: %s", self._consuming[binding_key], ) LOGGER.info("Closing the channel") assert self._channel is not None self._channel.close() def run(self) -> None: self._command_thread = threading.Thread(target=self.run_command_thread) self._command_thread.start() self._request_queues: Dict[str, queue.Queue] = {} self._request_threads: Dict[str, threading.Thread] = {} for binding_key in self._binding_keys: meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name( binding_key ) self._request_queues[binding_key] = queue.Queue() 
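# Illustrative sketch: each request thread is keyed by its binding key, which
# encodes the storage method, the relation (or "unknown") and the id range;
# get_meth_name reverses that encoding. Hypothetical binding key shown below:
from swh.provenance.storage.interface import RelationType
from swh.provenance.storage.rabbitmq.server import ProvenanceStorageRabbitMQServer

meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name(
    "relation_add.content_in_revision.a"
)
assert meth_name == "relation_add"
assert relation == RelationType.CNT_EARLY_IN_REV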
self._request_threads[binding_key] = threading.Thread( target=self.run_request_thread, args=(binding_key, meth_name, relation), ) self._request_threads[binding_key].start() while not self._closing: try: self._connection = self.connect() assert self._connection is not None self._connection.ioloop.start() except KeyboardInterrupt: LOGGER.info("Connection closed by keyboard interruption, reopening") if self._connection is not None: self._connection.ioloop.stop() except TerminateSignal as ex: LOGGER.info("Termination requested: %s", ex) self.stop() if self._connection is not None and not self._connection.is_closed: # Finish closing self._connection.ioloop.start() except BaseException as ex: LOGGER.warning("Unexpected exception, terminating: %s", ex) self.stop() if self._connection is not None and not self._connection.is_closed: # Finish closing self._connection.ioloop.start() for binding_key in self._binding_keys: self._request_queues[binding_key].put(TERMINATE) for binding_key in self._binding_keys: self._request_threads[binding_key].join() self._command_thread.join() LOGGER.info("Stopped") def run_command_thread(self) -> None: while True: try: command = self.command.get() if command == ServerCommand.TERMINATE: self.request_termination() break except queue.Empty: pass except BaseException as ex: self.request_termination(str(ex)) break def request_termination(self, reason: str = "Normal shutdown") -> None: assert self._connection is not None def termination_callback(): raise TerminateSignal(reason) self._connection.ioloop.add_callback_threadsafe(termination_callback) def run_request_thread( self, binding_key: str, meth_name: str, relation: Optional[RelationType] ) -> None: from swh.provenance import get_provenance_storage with get_provenance_storage(**self._storage_config) as storage: request_queue = self._request_queues[binding_key] merge_items = ProvenanceStorageRabbitMQWorker.get_conflicts_func(meth_name) while True: terminate = False elements = [] while True: try: # TODO: consider reducing this timeout or removing it elem = request_queue.get(timeout=0.1) if elem is TERMINATE: terminate = True break elements.append(elem) except queue.Empty: break if len(elements) >= self._batch_size: break if terminate: break if not elements: continue try: items, props = zip(*elements) acks_count: TCounter[Tuple[str, str]] = Counter(props) data = merge_items(items) args = (relation, data) if relation is not None else (data,) if getattr(storage, meth_name)(*args): for (correlation_id, reply_to), count in acks_count.items(): # FIXME: this is running in a different thread! Hence, if # self._connection drops, there is no guarantee that the # response can be sent for the current elements. This # situation should be handled properly. 
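# Illustrative sketch: acks_count above groups the drained queue elements by
# their (correlation_id, reply_to) metadata, so that each client request is
# answered with the number of items processed on its behalf. Hypothetical
# queue contents:
from collections import Counter

elements = [
    ((b"\x01" * 20, b"\x02" * 20, b"a", None), ("corr-1", "reply-q-1")),
    ((b"\x03" * 20, b"\x02" * 20, b"b", None), ("corr-1", "reply-q-1")),
    ((b"\x04" * 20, b"\x05" * 20, b"c", None), ("corr-2", "reply-q-2")),
]
items, props = zip(*elements)
acks_count = Counter(props)
assert acks_count == {("corr-1", "reply-q-1"): 2, ("corr-2", "reply-q-2"): 1}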
assert self._connection is not None self._connection.ioloop.add_callback_threadsafe( functools.partial( ProvenanceStorageRabbitMQWorker.respond, channel=self._channel, correlation_id=correlation_id, reply_to=reply_to, response=count, ) ) else: LOGGER.warning( "Unable to process elements for queue %s", binding_key ) for elem in elements: request_queue.put(elem) except BaseException as ex: self.request_termination(str(ex)) break def stop(self) -> None: assert self._connection is not None if not self._closing: self._closing = True LOGGER.info("Stopping") if any(self._consuming): self.stop_consuming() self._connection.ioloop.start() else: self._connection.ioloop.stop() LOGGER.info("Stopped") @staticmethod def get_conflicts_func(meth_name: str) -> Callable[[Iterable[Any]], Any]: if meth_name == "content_add": return resolve_dates elif meth_name == "directory_add": return resolve_directory elif meth_name == "location_add": return lambda data: dict(data) elif meth_name == "origin_add": return lambda data: dict(data) # last processed value is good enough elif meth_name == "revision_add": return resolve_revision elif meth_name == "relation_add": return resolve_relation else: LOGGER.warning( "Unexpected conflict resolution function request for method %s", meth_name, ) return lambda x: x @staticmethod def respond( channel: pika.channel.Channel, correlation_id: str, reply_to: str, response: Any, ): channel.basic_publish( exchange="", routing_key=reply_to, properties=pika.BasicProperties( content_type="application/msgpack", correlation_id=correlation_id, ), body=encode_data( response, extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders, ), ) class ProvenanceStorageRabbitMQServer: extra_type_decoders = DECODERS extra_type_encoders = ENCODERS queue_count = 16 def __init__( self, url: str, storage_config: Dict[str, Any], batch_size: int = 100, prefetch_count: int = 100, ) -> None: """Setup the server object, passing in the URL we will use to connect to RabbitMQ, and the connection information for the underlying local storage object. :param str url: The URL for connecting to RabbitMQ :param dict storage_config: Configuration parameters for the underlying ``ProvenanceStorage`` object expected by ``swh.provenance.get_provenance_storage`` :param int batch_size: Max amount of elements call to the underlying storage :param int prefetch_count: Prefetch value for the RabbitMQ connection when receiving messaged """ self._workers: List[ProvenanceStorageRabbitMQWorker] = [] for exchange in ProvenanceStorageRabbitMQServer.get_exchanges(): for range in ProvenanceStorageRabbitMQServer.get_ranges(exchange): worker = ProvenanceStorageRabbitMQWorker( url=url, exchange=exchange, range=range, storage_config=storage_config, batch_size=batch_size, prefetch_count=prefetch_count, ) self._workers.append(worker) self._running = False def start(self) -> None: if not self._running: self._running = True for worker in self._workers: worker.start() for worker in self._workers: try: signal = worker.signal.get(timeout=60) assert signal == ServerCommand.CONSUMING except queue.Empty: LOGGER.error( "Could not initialize worker %s. 
Leaving...", worker.name ) self.stop() return LOGGER.info("Start serving") def stop(self) -> None: if self._running: for worker in self._workers: worker.command.put(ServerCommand.TERMINATE) for worker in self._workers: worker.join() LOGGER.info("Stop serving") self._running = False @staticmethod def get_binding_keys(exchange: str, range: int) -> Iterator[str]: for meth_name, relation in ProvenanceStorageRabbitMQServer.get_meth_names( exchange ): if relation is None: assert ( meth_name != "relation_add" ), "'relation_add' requires 'relation' to be provided" yield f"{meth_name}.unknown.{range:x}".lower() else: assert ( meth_name == "relation_add" ), f"'{meth_name}' requires 'relation' to be None" yield f"{meth_name}.{relation.value}.{range:x}".lower() @staticmethod def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str: if meth_name == "relation_add": assert ( relation is not None ), "'relation_add' requires 'relation' to be provided" split = relation.value else: assert relation is None, f"'{meth_name}' requires 'relation' to be None" split = meth_name exchange, *_ = split.split("_") return exchange @staticmethod def get_exchanges() -> Iterator[str]: yield from [entity.value for entity in EntityType] + ["location"] @staticmethod def get_meth_name( binding_key: str, ) -> Tuple[str, Optional[RelationType]]: meth_name, relation, *_ = binding_key.split(".") return meth_name, (RelationType(relation) if relation != "unknown" else None) @staticmethod def get_meth_names( exchange: str, ) -> Iterator[Tuple[str, Optional[RelationType]]]: if exchange == EntityType.CONTENT.value: yield from [ ("content_add", None), ("relation_add", RelationType.CNT_EARLY_IN_REV), ("relation_add", RelationType.CNT_IN_DIR), ] elif exchange == EntityType.DIRECTORY.value: yield from [ ("directory_add", None), ("relation_add", RelationType.DIR_IN_REV), ] elif exchange == EntityType.ORIGIN.value: yield from [("origin_add", None)] elif exchange == EntityType.REVISION.value: yield from [ ("revision_add", None), ("relation_add", RelationType.REV_BEFORE_REV), ("relation_add", RelationType.REV_IN_ORG), ] elif exchange == "location": yield "location_add", None @staticmethod def get_ranges(unused_exchange: str) -> Iterator[int]: # XXX: we might want to have a different range per exchange yield from range(ProvenanceStorageRabbitMQServer.queue_count) @staticmethod def get_routing_key( item: bytes, meth_name: str, relation: Optional[RelationType] = None ) -> str: hashid = ( path_id(item).hex() if meth_name.startswith("location") else hash_to_hex(item) ) idx = int(hashid[0], 16) % ProvenanceStorageRabbitMQServer.queue_count if relation is None: assert ( meth_name != "relation_add" ), "'relation_add' requires 'relation' to be provided" return f"{meth_name}.unknown.{idx:x}".lower() else: assert ( meth_name == "relation_add" ), f"'{meth_name}' requires 'relation' to be None" return f"{meth_name}.{relation.value}.{idx:x}".lower() @staticmethod def is_write_method(meth_name: str) -> bool: return "_add" in meth_name def load_and_check_config( config_path: Optional[str], type: str = "local" ) -> Dict[str, Any]: """Check the minimal configuration is set to run the api or raise an error explanation. Args: config_path (str): Path to the configuration file to load type (str): configuration type. For 'local' type, more checks are done. 
Raises: Error if the setup is not as expected Returns: configuration as a dict """ if config_path is None: raise EnvironmentError("Configuration file must be defined") if not os.path.exists(config_path): raise FileNotFoundError(f"Configuration file {config_path} does not exist") cfg = config.read(config_path) pcfg: Optional[Dict[str, Any]] = cfg.get("provenance") if pcfg is None: raise KeyError("Missing 'provenance' configuration") rcfg: Optional[Dict[str, Any]] = pcfg.get("rabbitmq") if rcfg is None: raise KeyError("Missing 'provenance.rabbitmq' configuration") scfg: Optional[Dict[str, Any]] = rcfg.get("storage_config") if scfg is None: raise KeyError("Missing 'provenance.rabbitmq.storage_config' configuration") return cfg def make_server_from_configfile() -> ProvenanceStorageRabbitMQServer: config_path = os.environ.get("SWH_CONFIG_FILENAME") server_cfg = load_and_check_config(config_path) return ProvenanceStorageRabbitMQServer(**server_cfg["provenance"]["rabbitmq"]) diff --git a/swh/provenance/storage/replay.py b/swh/provenance/storage/replay.py index 0f19e7c..666a4f1 100644 --- a/swh/provenance/storage/replay.py +++ b/swh/provenance/storage/replay.py @@ -1,122 +1,125 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import defaultdict from datetime import datetime import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union try: from systemd.daemon import notify except ImportError: notify = None from swh.core.statsd import statsd from swh.journal.serializers import kafka_to_value from swh.provenance.storage.interface import ( DirectoryData, RelationData, RelationType, RevisionData, Sha1Git, ) from .interface import ProvenanceStorageInterface logger = logging.getLogger(__name__) REPLAY_OPERATIONS_METRIC = "swh_provenance_replayer_operations_total" REPLAY_DURATION_METRIC = "swh_provenance_replayer_duration_seconds" def cvrt_directory(msg_d): return (msg_d["id"], DirectoryData(date=msg_d["value"], flat=False)) def cvrt_revision(msg_d): return (msg_d["id"], RevisionData(date=msg_d["value"], origin=None)) def cvrt_default(msg_d): return (msg_d["id"], msg_d["value"]) def cvrt_relation(msg_d): - return (msg_d["src"], RelationData(dst=msg_d["dst"], path=msg_d["path"])) + return ( + msg_d["src"], + RelationData(dst=msg_d["dst"], path=msg_d["path"], dst_date=msg_d["dst_date"]), + ) OBJECT_CONVERTERS: Dict[str, Callable[[Dict], Tuple[bytes, Any]]] = { "directory": cvrt_directory, "revision": cvrt_revision, "content": cvrt_default, "content_in_revision": cvrt_relation, "content_in_directory": cvrt_relation, "directory_in_revision": cvrt_relation, } class ProvenanceObjectDeserializer: def __init__( self, raise_on_error: bool = False, reporter: Optional[Callable[[str, bytes], None]] = None, ): self.reporter = reporter self.raise_on_error = raise_on_error def convert(self, object_type: str, msg: bytes) -> Optional[Tuple[bytes, Any]]: dict_repr = kafka_to_value(msg) obj = OBJECT_CONVERTERS[object_type](dict_repr) return obj def report_failure(self, msg: bytes, obj: Dict): if self.reporter: self.reporter(obj["id"].hex(), msg) def process_replay_objects( all_objects: Dict[str, List[Tuple[bytes, Any]]], *, storage: ProvenanceStorageInterface, ) -> None: for object_type, objects in all_objects.items(): logger.debug("Inserting %s %s objects", len(objects), object_type) with 
statsd.timed(REPLAY_DURATION_METRIC, tags={"object_type": object_type}): _insert_objects(object_type, objects, storage) statsd.increment( REPLAY_OPERATIONS_METRIC, len(objects), tags={"object_type": object_type} ) if notify: notify("WATCHDOG=1") def _insert_objects( object_type: str, objects: List[Tuple[bytes, Any]], storage: ProvenanceStorageInterface, ) -> None: """Insert objects of type object_type in the storage.""" if object_type not in OBJECT_CONVERTERS: logger.warning("Received a series of %s, this should not happen", object_type) return if "_in_" in object_type: reldata = defaultdict(set) for k, v in objects: reldata[k].add(v) storage.relation_add(relation=RelationType(object_type), data=reldata) elif object_type in ("revision", "directory"): entitydata: Dict[Sha1Git, Union[RevisionData, DirectoryData]] = {} for k, v in objects: if k not in entitydata or entitydata[k].date > v.date: entitydata[k] = v getattr(storage, f"{object_type}_add")(entitydata) else: data: Dict[Sha1Git, datetime] = {} for k, v in objects: assert isinstance(v, datetime) if k not in data or data[k] > v: data[k] = v getattr(storage, f"{object_type}_add")(data) diff --git a/swh/provenance/tests/test_cli.py b/swh/provenance/tests/test_cli.py index 26ec93e..0be8712 100644 --- a/swh/provenance/tests/test_cli.py +++ b/swh/provenance/tests/test_cli.py @@ -1,191 +1,197 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import logging import re from typing import Dict, List from click.testing import CliRunner from confluent_kafka import Producer import psycopg2.extensions import pytest from swh.core.cli import swh as swhmain import swh.core.cli.db # noqa ; ensure cli is loaded from swh.core.db.db_utils import init_admin_extensions from swh.journal.serializers import key_to_kafka, value_to_kafka from swh.model.hashutil import MultiHash import swh.provenance.cli # noqa ; ensure cli is loaded from swh.provenance.storage.interface import EntityType, RelationType from swh.storage.interface import StorageInterface from .utils import fill_storage, get_datafile, invoke, load_repo_data logger = logging.getLogger(__name__) def test_cli_swh_db_help() -> None: # swhmain.add_command(provenance_cli) result = CliRunner().invoke(swhmain, ["provenance", "-h"]) assert result.exit_code == 0 assert "Commands:" in result.output commands = result.output.split("Commands:")[1] for command in ( "find-all", "find-first", "iter-frontiers", "iter-origins", "iter-revisions", ): assert f" {command} " in commands def test_cli_init_db_default_flavor(postgresql: psycopg2.extensions.connection) -> None: "Test that 'swh db init provenance' defaults to a normalized flavored DB" dbname = postgresql.dsn init_admin_extensions("swh.provenance", dbname) result = CliRunner().invoke(swhmain, ["db", "init", "-d", dbname, "provenance"]) assert result.exit_code == 0, result.output @pytest.mark.origin_layer @pytest.mark.parametrize( "subcommand", (["origin", "from-csv"], ["iter-origins"]), ) def test_cli_origin_from_csv( swh_storage: StorageInterface, subcommand: List[str], swh_storage_backend_config: Dict, provenance, tmp_path, ): repo = "cmdbts2" origin_url = f"https://{repo}" data = load_repo_data(repo) fill_storage(swh_storage, data) assert len(data["origin"]) >= 1 assert origin_url in [o["url"] for o in data["origin"]] cfg 
= { "provenance": { "archive": { "cls": "api", "storage": swh_storage_backend_config, }, "storage": { "cls": "postgresql", "db": provenance.storage.conn.dsn, }, }, } csv_filepath = get_datafile("origins.csv") subcommand = subcommand + [csv_filepath] result = invoke(subcommand, config=cfg) assert result.exit_code == 0, f"Unexpected result: {result.output}" origin_sha1 = MultiHash.from_data( origin_url.encode(), hash_names=["sha1"] ).digest()["sha1"] actual_result = provenance.storage.origin_get([origin_sha1]) assert actual_result == {origin_sha1: origin_url} @pytest.mark.kafka def test_replay( provenance_storage, provenance_postgresqldb: str, kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, ): kafka_prefix += ".swh.journal.provenance" producer = Producer( { "bootstrap.servers": kafka_server, "client.id": "test-producer", "acks": "all", } ) for i in range(10): date = datetime.fromtimestamp(i, tz=timezone.utc) cntkey = (b"cnt:" + bytes([i])).ljust(20, b"\x00") dirkey = (b"dir:" + bytes([i])).ljust(20, b"\x00") revkey = (b"rev:" + bytes([i])).ljust(20, b"\x00") loc = f"dir/{i}".encode() producer.produce( topic=kafka_prefix + ".content_in_revision", key=key_to_kafka(cntkey), - value=value_to_kafka({"src": cntkey, "dst": revkey, "path": loc}), + value=value_to_kafka( + {"src": cntkey, "dst": revkey, "path": loc, "dst_date": date} + ), ) producer.produce( topic=kafka_prefix + ".content_in_directory", key=key_to_kafka(cntkey), - value=value_to_kafka({"src": cntkey, "dst": dirkey, "path": loc}), + value=value_to_kafka( + {"src": cntkey, "dst": dirkey, "path": loc, "dst_date": None} + ), ) producer.produce( topic=kafka_prefix + ".directory_in_revision", key=key_to_kafka(dirkey), - value=value_to_kafka({"src": dirkey, "dst": revkey, "path": loc}), + value=value_to_kafka( + {"src": dirkey, "dst": revkey, "path": loc, "dst_date": None} + ), ) # now add dates to entities producer.produce( topic=kafka_prefix + ".content", key=key_to_kafka(cntkey), value=value_to_kafka({"id": cntkey, "value": date}), ) producer.produce( topic=kafka_prefix + ".directory", key=key_to_kafka(dirkey), value=value_to_kafka({"id": dirkey, "value": date}), ) producer.produce( topic=kafka_prefix + ".revision", key=key_to_kafka(revkey), value=value_to_kafka({"id": revkey, "value": date}), ) producer.flush() logger.debug("Flushed producer") config = { "provenance": { "storage": { "cls": "postgresql", "db": provenance_postgresqldb, }, "journal_client": { "cls": "kafka", "brokers": [kafka_server], "group_id": kafka_consumer_group, "prefix": kafka_prefix, "stop_on_eof": True, }, } } result = invoke(["replay"], config=config) expected = r"Done. 
processed 60 messages\n" assert result.exit_code == 0, result.output assert re.fullmatch(expected, result.output, re.MULTILINE), result.output assert len(provenance_storage.entity_get_all(EntityType.CONTENT)) == 10 assert len(provenance_storage.entity_get_all(EntityType.REVISION)) == 10 assert len(provenance_storage.entity_get_all(EntityType.DIRECTORY)) == 10 assert len(provenance_storage.location_get_all()) == 10 assert len(provenance_storage.relation_get_all(RelationType.CNT_EARLY_IN_REV)) == 10 assert len(provenance_storage.relation_get_all(RelationType.DIR_IN_REV)) == 10 assert len(provenance_storage.relation_get_all(RelationType.CNT_IN_DIR)) == 10 diff --git a/swh/provenance/tests/test_conflict_resolution.py b/swh/provenance/tests/test_conflict_resolution.py index 6c5e54b..d1b3944 100644 --- a/swh/provenance/tests/test_conflict_resolution.py +++ b/swh/provenance/tests/test_conflict_resolution.py @@ -1,169 +1,176 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union from swh.model.hashutil import hash_to_bytes from swh.model.model import Sha1Git from swh.provenance.storage.interface import DirectoryData, RelationData, RevisionData from swh.provenance.storage.rabbitmq.server import ( resolve_dates, resolve_directory, resolve_relation, resolve_revision, ) def test_resolve_dates() -> None: items: List[Tuple[Sha1Git, datetime]] = [ ( hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), datetime.fromtimestamp(1000000001), ), ( hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), datetime.fromtimestamp(1000000000), ), ] assert resolve_dates(items) == { hash_to_bytes( "20329687bb9c1231a7e05afe86160343ad49b494" ): datetime.fromtimestamp(1000000000) } def test_resolve_directory() -> None: items: List[Tuple[Sha1Git, DirectoryData]] = [ ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), DirectoryData(date=datetime.fromtimestamp(1000000002), flat=False), ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), DirectoryData(date=datetime.fromtimestamp(1000000001), flat=True), ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), DirectoryData(date=datetime.fromtimestamp(1000000000), flat=False), ), ] assert resolve_directory(items) == { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): DirectoryData( date=datetime.fromtimestamp(1000000000), flat=True ) } def test_resolve_revision_without_date() -> None: items: List[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] = [ (hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), RevisionData( date=None, origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8175"), ), ), ] assert resolve_revision(items) == { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): RevisionData( date=None, origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8175"), ) } def test_resolve_revision_without_origin() -> None: items: List[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] = [ (hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), RevisionData(date=datetime.fromtimestamp(1000000000), origin=None), ), ] assert resolve_revision(items) == { 
hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): RevisionData( date=datetime.fromtimestamp(1000000000), origin=None, ) } def test_resolve_revision_merge() -> None: items: List[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] = [ ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), RevisionData(date=datetime.fromtimestamp(1000000000), origin=None), ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), RevisionData( date=None, origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8175"), ), ), ] assert resolve_revision(items) == { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): RevisionData( date=datetime.fromtimestamp(1000000000), origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8175"), ) } def test_resolve_revision_keep_min_date() -> None: items: List[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]] = [ ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), RevisionData( date=datetime.fromtimestamp(1000000000), origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), ), ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), RevisionData( date=datetime.fromtimestamp(1000000001), origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8175"), ), ), ] assert resolve_revision(items) == { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): RevisionData( date=datetime.fromtimestamp(1000000000), origin=hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8175"), ) } def test_resolve_relation() -> None: - items: List[Tuple[Sha1Git, Sha1Git, bytes]] = [ + items: List[Tuple[Sha1Git, Sha1Git, Optional[bytes], Optional[datetime]]] = [ ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/1", + datetime.fromtimestamp(1000000001), ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/2", + None, ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/1", + datetime.fromtimestamp(1000000001), ), ] assert resolve_relation(items) == { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): { RelationData( - hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/1" + hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), + b"/path/1", + dst_date=datetime.fromtimestamp(1000000001), ), RelationData( - hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/2" + hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), + b"/path/2", + dst_date=None, ), } } diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py index 571787e..1068bcf 100644 --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -1,488 +1,499 @@ # Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import hashlib import inspect import os from typing import Any, Dict, Iterable, Optional, Set, Tuple import pytest from swh.model.hashutil import hash_to_bytes from swh.model.model import Origin, Sha1Git from swh.provenance.algos.origin import origin_add from swh.provenance.algos.revision import revision_add from swh.provenance.archive import ArchiveInterface from 
swh.provenance.interface import ProvenanceInterface from swh.provenance.model import OriginEntry, RevisionEntry from swh.provenance.provenance import Provenance from swh.provenance.storage.interface import ( DirectoryData, EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from .utils import fill_storage, load_repo_data, ts2dt class TestProvenanceStorage: def test_provenance_storage_content( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests content methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Add all content present in the current repo to the storage, just assigning their # creation dates. Then check that the returned results when querying are the same. cnt_dates = { cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"]) } assert provenance_storage.content_add(cnt_dates) assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates assert provenance_storage.entity_get_all(EntityType.CONTENT) == set( cnt_dates.keys() ) def test_provenance_storage_directory( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Of all directories present in the current repo, only assign a date to those # containing blobs (picking the max date among the available ones). Then check that # the returned results when querying are the same. def getmaxdate( directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] ) -> Optional[datetime]: dates = [ content["ctime"] for entry in directory["entries"] for content in contents if entry["type"] == "file" and entry["target"] == content["sha1_git"] ] return max(dates) if dates else None flat_values = (False, True) dir_dates = {} for idx, dir in enumerate(data["directory"]): date = getmaxdate(dir, data["content"]) if date is not None: dir_dates[dir["id"]] = DirectoryData( date=date, flat=flat_values[idx % 2] ) assert provenance_storage.directory_add(dir_dates) assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == set( dir_dates.keys() ) def test_provenance_storage_location( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests location methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Add all names of entries present in the directories of the current repo as paths # to the storage. Then check that the returned results when querying are the same. paths = { hashlib.sha1(entry["name"]).digest(): entry["name"] for dir in data["directory"] for entry in dir["entries"] } assert provenance_storage.location_add(paths) assert provenance_storage.location_get_all() == paths @pytest.mark.origin_layer def test_provenance_storage_origin( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test origin methods. # Add all origins present in the current repo to the storage. 
Then check that the # returned results when querying are the same. orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]} assert orgs assert provenance_storage.origin_add(orgs) assert provenance_storage.origin_get(set(orgs.keys())) == orgs assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys()) def test_provenance_storage_revision( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test revision methods. # Add all revisions present in the current repo to the storage, assigning their # dates and an arbitrary origin to each one. Then check that the returned results # when querying are the same. origin = Origin(url=next(iter(data["origin"]))["url"]) # Origin must be inserted in advance. assert provenance_storage.origin_add({origin.id: origin.url}) revs = {rev["id"] for idx, rev in enumerate(data["revision"])} rev_data = { rev["id"]: RevisionData( date=ts2dt(rev["date"]) if idx % 2 != 0 else None, origin=origin.id if idx % 3 != 0 else None, ) for idx, rev in enumerate(data["revision"]) } assert revs assert provenance_storage.revision_add(rev_data) assert provenance_storage.revision_get(set(rev_data.keys())) == { k: v for (k, v) in rev_data.items() if v.date is not None or v.origin is not None } assert provenance_storage.entity_get_all(EntityType.REVISION) == set(rev_data) def test_provenance_storage_relation_revision_layer( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset. cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: root = next( subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) - for cnt, rel in dircontent(data, rev["id"], root): + for cnt, rel in dircontent( + data=data, ref=rev["id"], dir=root, ref_date=ts2dt(rev["date"]) + ): cnt_in_rev.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev ) # Test content-in-directory relation. # Create flat models for every directory in the dataset. cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} for dir in data["directory"]: - for cnt, rel in dircontent(data, dir["id"], dir): + for cnt, rel in dircontent(data=data, ref=dir["id"], dir=dir): cnt_in_dir.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir ) # Test content-in-directory relation. # Add root directories to their correspondent revision in the dataset. dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: dir_in_rev.setdefault(rev["directory"], set()).add( RelationData(dst=rev["id"], path=b".") ) relation_add_and_compare_result( provenance_storage, RelationType.DIR_IN_REV, dir_in_rev ) @pytest.mark.origin_layer def test_provenance_storage_relation_origin_layer( self, provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. 
        data = load_repo_data("cmdbts2")

        # Test revision-in-origin relation.
        # Origins must be inserted in advance (cannot be done by `entity_add` inside
        # `relation_add_and_compare_result`).
        orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]}
        assert provenance_storage.origin_add(orgs)

        # Add all revisions that are heads of some snapshot branch to the
        # corresponding origin.
        rev_in_org: Dict[Sha1Git, Set[RelationData]] = {}
        for status in data["origin_visit_status"]:
            if status["snapshot"] is not None:
                for snapshot in data["snapshot"]:
                    if snapshot["id"] == status["snapshot"]:
                        for branch in snapshot["branches"].values():
                            if branch["target_type"] == "revision":
                                rev_in_org.setdefault(branch["target"], set()).add(
                                    RelationData(
                                        dst=Origin(url=status["origin"]).id,
                                        path=None,
                                    )
                                )

        relation_add_and_compare_result(
            provenance_storage, RelationType.REV_IN_ORG, rev_in_org
        )

        # Test revision-before-revision relation.
        # For each revision in the dataset, add an entry for each parent to the
        # relation.
        rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {}
        for rev in data["revision"]:
            for parent in rev["parents"]:
                rev_before_rev.setdefault(parent, set()).add(
                    RelationData(dst=rev["id"], path=None)
                )

        relation_add_and_compare_result(
            provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev
        )

    def test_provenance_storage_find_revision_layer(
        self,
        provenance: ProvenanceInterface,
        provenance_storage: ProvenanceStorageInterface,
        archive: ArchiveInterface,
    ) -> None:
        """Tests `content_find_first` and `content_find_all` methods for every
        `ProvenanceStorageInterface` implementation.
        """

        # Read data/README.md for more details on how these datasets are generated.
        data = load_repo_data("cmdbts2")
        fill_storage(archive.storage, data)

        # Test content_find_first and content_find_all with only the revision-content
        # algorithm executed (the origin-revision layer is exercised in the next test).

        # Execute the revision-content algorithm on both storages.
        revisions = [
            RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"])
            for rev in data["revision"]
        ]
        revision_add(provenance, archive, revisions)
        revision_add(Provenance(provenance_storage), archive, revisions)

        assert ProvenanceResult(
            content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
            revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
            date=datetime.fromtimestamp(1000000000.0, timezone.utc),
            origin=None,
            path=b"A/B/C/a",
        ) == provenance_storage.content_find_first(
            hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
        )

        for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
            assert provenance.storage.content_find_first(
                cnt
            ) == provenance_storage.content_find_first(cnt)
            assert set(provenance.storage.content_find_all(cnt)) == set(
                provenance_storage.content_find_all(cnt)
            )

    @pytest.mark.origin_layer
    def test_provenance_storage_find_origin_layer(
        self,
        provenance: ProvenanceInterface,
        provenance_storage: ProvenanceStorageInterface,
        archive: ArchiveInterface,
    ) -> None:
        """Tests `content_find_first` and `content_find_all` methods for every
        `ProvenanceStorageInterface` implementation.
        """

        # Read data/README.md for more details on how these datasets are generated.
        data = load_repo_data("cmdbts2")
        fill_storage(archive.storage, data)

        # Execute the revision-content algorithm on both storages.
        revisions = [
            RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"])
            for rev in data["revision"]
        ]
        revision_add(provenance, archive, revisions)
        revision_add(Provenance(provenance_storage), archive, revisions)

        # Test content_find_first and content_find_all, first only executing the
        # revision-content algorithm, then adding the origin-revision layer.

        # Execute the origin-revision algorithm on both storages.
        origins = [
            OriginEntry(url=sta["origin"], snapshot=sta["snapshot"])
            for sta in data["origin_visit_status"]
            if sta["snapshot"] is not None
        ]
        origin_add(provenance, archive, origins)
        origin_add(Provenance(provenance_storage), archive, origins)

        assert ProvenanceResult(
            content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
            revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
            date=datetime.fromtimestamp(1000000000.0, timezone.utc),
            origin="https://cmdbts2",
            path=b"A/B/C/a",
        ) == provenance_storage.content_find_first(
            hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
        )

        for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
            assert provenance.storage.content_find_first(
                cnt
            ) == provenance_storage.content_find_first(cnt)
            assert set(provenance.storage.content_find_all(cnt)) == set(
                provenance_storage.content_find_all(cnt)
            )

    def test_types(self, provenance_storage: ProvenanceStorageInterface) -> None:
        """Checks all methods of ProvenanceStorageInterface are implemented by this
        backend, and that they have the same signature."""
        # Create an instance of the protocol (which cannot be instantiated
        # directly, so this creates a subclass, then instantiates it)
        interface = type("_", (ProvenanceStorageInterface,), {})()
        assert "content_find_first" in dir(interface)

        missing_methods = []

        for meth_name in dir(interface):
            if meth_name.startswith("_"):
                continue
            interface_meth = getattr(interface, meth_name)
            try:
                concrete_meth = getattr(provenance_storage, meth_name)
            except AttributeError:
                if not getattr(interface_meth, "deprecated_endpoint", False):
                    # The backend is missing a (non-deprecated) endpoint
                    missing_methods.append(meth_name)
                continue

            expected_signature = inspect.signature(interface_meth)
            actual_signature = inspect.signature(concrete_meth)

            assert expected_signature == actual_signature, meth_name

        assert missing_methods == []

        # If all the assertions above succeed, then this one should too.
        # But there's no harm in double-checking.
        # And we could replace the assertions above by this one, but unlike
        # the assertions above, it doesn't explain what is missing.
        assert isinstance(provenance_storage, ProvenanceStorageInterface)


def dircontent(
    data: Dict[str, Any],
    ref: Sha1Git,
    dir: Dict[str, Any],
    prefix: bytes = b"",
+    ref_date: Optional[datetime] = None,
) -> Iterable[Tuple[Sha1Git, RelationData]]:
    content = {
        (
            entry["target"],
-            RelationData(dst=ref, path=os.path.join(prefix, entry["name"])),
+            RelationData(
+                dst=ref, path=os.path.join(prefix, entry["name"]), dst_date=ref_date
+            ),
        )
        for entry in dir["entries"]
        if entry["type"] == "file"
    }
    for entry in dir["entries"]:
        if entry["type"] == "dir":
            child = next(
                subdir
                for subdir in data["directory"]
                if subdir["id"] == entry["target"]
            )
            content.update(
-                dircontent(data, ref, child, os.path.join(prefix, entry["name"]))
+                dircontent(
+                    data=data,
+                    ref=ref,
+                    dir=child,
+                    prefix=os.path.join(prefix, entry["name"]),
+                    ref_date=ref_date,
+                )
            )
    return content


def entity_add(
    storage: ProvenanceStorageInterface, entity: EntityType, ids: Set[Sha1Git]
) -> bool:
    now = datetime.now(tz=timezone.utc)
    if entity == EntityType.CONTENT:
        return storage.content_add({sha1: now for sha1 in ids})
    elif entity == EntityType.DIRECTORY:
        return storage.directory_add(
            {sha1: DirectoryData(date=now, flat=False) for sha1 in ids}
        )
    else:  # entity == EntityType.REVISION:
        return storage.revision_add(
            {sha1: RevisionData(date=None, origin=None) for sha1 in ids}
        )


def relation_add_and_compare_result(
    storage: ProvenanceStorageInterface,
    relation: RelationType,
    data: Dict[Sha1Git, Set[RelationData]],
) -> None:
    # Source, destinations and locations must be added in advance.
    src, *_, dst = relation.value.split("_")
    srcs = {sha1 for sha1 in data}
    if src != "origin":
        assert entity_add(storage, EntityType(src), srcs)
    dsts = {rel.dst for rels in data.values() for rel in rels}
    if dst != "origin":
        assert entity_add(storage, EntityType(dst), dsts)
    assert storage.location_add(
        {
            hashlib.sha1(rel.path).digest(): rel.path
            for rels in data.values()
            for rel in rels
            if rel.path is not None
        }
    )

    assert data
    assert storage.relation_add(relation, data)

    for src_sha1 in srcs:
        relation_compare_result(
            storage.relation_get(relation, [src_sha1]),
            {src_sha1: data[src_sha1]},
        )
    for dst_sha1 in dsts:
        relation_compare_result(
            storage.relation_get(relation, [dst_sha1], reverse=True),
            {
                src_sha1: {
                    RelationData(dst=dst_sha1, path=rel.path)
                    for rel in rels
                    if dst_sha1 == rel.dst
                }
                for src_sha1, rels in data.items()
                if dst_sha1 in {rel.dst for rel in rels}
            },
        )
    relation_compare_result(
        storage.relation_get_all(relation),
        data,
    )


def relation_compare_result(
    computed: Dict[Sha1Git, Set[RelationData]],
    expected: Dict[Sha1Git, Set[RelationData]],
) -> None:
    assert {
        src_sha1: {RelationData(dst=rel.dst, path=rel.path) for rel in rels}
        for src_sha1, rels in expected.items()
    } == computed

diff --git a/swh/provenance/tests/test_split_ranges.py b/swh/provenance/tests/test_split_ranges.py
index 9a7ab41..762751c 100644
--- a/swh/provenance/tests/test_split_ranges.py
+++ b/swh/provenance/tests/test_split_ranges.py
@@ -1,137 +1,141 @@
# Copyright (C) 2021-2022  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from datetime import datetime

import pytest

from swh.model.hashutil import hash_to_bytes
from swh.provenance.storage.interface import RelationData, RelationType
from swh.provenance.storage.rabbitmq.client import split_ranges


def test_split_ranges_for_relation() -> None:
    data = {
hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): { RelationData( hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/1" ), RelationData( hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/2" ), }, hash_to_bytes("d0d8929936631ecbcf9147be6b8aa13b13b014e4"): { RelationData( hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/3" ), }, hash_to_bytes("c1d8929936631ecbcf9147be6b8aa13b13b014e4"): { RelationData( hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/4" ), }, } assert split_ranges(data, "relation_add", RelationType.CNT_EARLY_IN_REV) == { "relation_add.content_in_revision.c": { ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/1", + None, ), ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/2", + None, ), ( hash_to_bytes("c1d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/4", + None, ), }, "relation_add.content_in_revision.d": { ( hash_to_bytes("d0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/3", + None, ), }, } def test_split_ranges_error_for_relation() -> None: set_data = {hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4")} with pytest.raises(AssertionError) as ex: split_ranges(set_data, "relation_add", RelationType.CNT_EARLY_IN_REV) assert "Relation data must be provided in a dictionary" in str(ex.value) tuple_values = { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"): ( hash_to_bytes("3acef14580ea7fd42840ee905c5ce2b0ef9e8174"), b"/path/3", ) } with pytest.raises(AssertionError) as ex: split_ranges(tuple_values, "relation_add", RelationType.CNT_EARLY_IN_REV) assert "Values in the dictionary must be RelationData structures" in str(ex.value) @pytest.mark.parametrize( "entity", ("content", "directory", "origin", "revision"), ) def test_split_ranges_for_entity_without_data(entity: str) -> None: data = { hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("d0d8929936631ecbcf9147be6b8aa13b13b014e4"), hash_to_bytes("c1d8929936631ecbcf9147be6b8aa13b13b014e4"), } meth_name = f"{entity}_add" assert split_ranges(data, meth_name, None) == { f"{meth_name}.unknown.c": { (hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),), (hash_to_bytes("c1d8929936631ecbcf9147be6b8aa13b13b014e4"),), }, f"{meth_name}.unknown.d": { (hash_to_bytes("d0d8929936631ecbcf9147be6b8aa13b13b014e4"),), }, } @pytest.mark.parametrize( "entity", ("content", "directory", "origin", "revision"), ) def test_split_ranges_for_entity_with_data(entity: str) -> None: data = { hash_to_bytes( "c0d8929936631ecbcf9147be6b8aa13b13b014e4" ): datetime.fromtimestamp(1000000000), hash_to_bytes( "d0d8929936631ecbcf9147be6b8aa13b13b014e4" ): datetime.fromtimestamp(1000000001), hash_to_bytes( "c1d8929936631ecbcf9147be6b8aa13b13b014e4" ): datetime.fromtimestamp(1000000002), } meth_name = f"{entity}_add" assert split_ranges(data, meth_name, None) == { f"{meth_name}.unknown.c": { ( hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), datetime.fromtimestamp(1000000000), ), ( hash_to_bytes("c1d8929936631ecbcf9147be6b8aa13b13b014e4"), datetime.fromtimestamp(1000000002), ), }, f"{meth_name}.unknown.d": { ( hash_to_bytes("d0d8929936631ecbcf9147be6b8aa13b13b014e4"), datetime.fromtimestamp(1000000001), ), }, }