Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F9123613
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
134 KB
Subscribers
None
View Options
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
index 8928e59..907183c 100644
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -1,517 +1,518 @@
# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime
+import hashlib
import logging
import os
from types import TracebackType
from typing import Dict, Generator, Iterable, Optional, Set, Tuple, Type
from typing_extensions import Literal, TypedDict
from swh.core.statsd import statsd
from swh.model.model import Sha1Git
from .interface import ProvenanceInterface
from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry
from .storage.interface import (
DirectoryData,
ProvenanceResult,
ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
)
from .util import path_normalize
# Module-level logger of the provenance front-end.
LOGGER = logging.getLogger(__name__)
# statsd metric names: a timing metric recorded by the @statsd.timed flush
# methods, and a counter incremented on every storage write retry.
BACKEND_DURATION_METRIC = "swh_provenance_backend_duration_seconds"
BACKEND_OPERATIONS_METRIC = "swh_provenance_backend_operations_total"
class DatetimeCache(TypedDict):
"""Dates cached per sha1 id, plus the ids whose date was set locally."""
# date per sha1 id; ids looked up without a known date are cached as None
data: Dict[Sha1Git, Optional[datetime]] # None means unknown
# ids set through this front-end; only these are written on flush
added: Set[Sha1Git]
class OriginCache(TypedDict):
"""Origin urls cached per origin sha1 id."""
# url per origin sha1 id
data: Dict[Sha1Git, str]
# ids registered through origin_add; only these are written on flush
added: Set[Sha1Git]
class RevisionCache(TypedDict):
"""Preferred origin sha1 id cached per revision sha1 id."""
# NOTE(review): the Provenance flush methods do not read this cache, so
# preferred origins set locally do not seem to be persisted — confirm.
data: Dict[Sha1Git, Sha1Git]
# revisions whose preferred origin was set through this front-end
added: Set[Sha1Git]
class ProvenanceCache(TypedDict):
"""In-memory state accumulated by Provenance between two flushes."""
# dates of contents, of directories in the isochrone frontier and of revisions
content: DatetimeCache
directory: DatetimeCache
directory_flatten: Dict[Sha1Git, Optional[bool]] # None means unknown
revision: DatetimeCache
# below are insertion caches only
# each entry is (source sha1, destination sha1, normalized path)
content_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]]
content_in_directory: Set[Tuple[Sha1Git, Sha1Git, bytes]]
directory_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]]
# the fields below are for the origin layer
origin: OriginCache
revision_origin: RevisionCache
# revision sha1 -> set of head sha1s it precedes
revision_before_revision: Dict[Sha1Git, Set[Sha1Git]]
# (revision sha1, origin sha1) pairs
revision_in_origin: Set[Tuple[Sha1Git, Sha1Git]]
def new_cache() -> ProvenanceCache:
    """Build an empty ProvenanceCache whose containers are all freshly allocated."""

    def empty_dates() -> DatetimeCache:
        # a new dict/set pair on every call, so no two fields share containers
        return DatetimeCache(data={}, added=set())

    return ProvenanceCache(
        content=empty_dates(),
        directory=empty_dates(),
        directory_flatten={},
        revision=empty_dates(),
        content_in_revision=set(),
        content_in_directory=set(),
        directory_in_revision=set(),
        origin=OriginCache(data={}, added=set()),
        revision_origin=RevisionCache(data={}, added=set()),
        revision_before_revision={},
        revision_in_origin=set(),
    )
class Provenance:
"""Caching front-end over a ProvenanceStorageInterface.

Entities and relations are accumulated in an in-memory ProvenanceCache and
written to the storage by flush().
"""
# flush_if_necessary() flushes once the total number of cached elements, as
# counted by _get_cache_stats(), is strictly greater than this value
MAX_CACHE_ELEMENTS = 40000
def __init__(self, storage: ProvenanceStorageInterface) -> None:
    """Wrap ``storage`` with an empty cache; the storage is not opened here."""
    self.cache = new_cache()
    self.storage = storage
def __enter__(self) -> ProvenanceInterface:
    """Open the underlying storage and hand this object to the ``with`` body."""
    provenance = self
    provenance.open()
    return provenance
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> None:
    """Close the underlying storage; a pending exception is never suppressed."""
    self.close()
    return None
def _flush_limit_reached(self) -> bool:
    """Tell whether the cache holds strictly more than MAX_CACHE_ELEMENTS entries."""
    total = 0
    for count in self._get_cache_stats().values():
        total += count
    return total > self.MAX_CACHE_ELEMENTS
def _get_cache_stats(self) -> Dict[str, int]:
    """Count the cached elements of each cache field.

    Fields shaped like ``{"data": ..., "added": ...}`` are measured by the size
    of their ``data`` mapping; every other field by its own length.
    """
    stats: Dict[str, int] = {}
    for name, value in self.cache.items():
        has_data = isinstance(value, dict) and value.get("data") is not None
        stats[name] = len(value["data"]) if has_data else len(value)  # type: ignore
    return stats
def clear_caches(self) -> None:
    """Forget every cached element by swapping in a brand-new cache."""
    fresh = new_cache()
    self.cache = fresh
def close(self) -> None:
    """Close the underlying storage."""
    storage = self.storage
    storage.close()
@statsd.timed(metric=BACKEND_DURATION_METRIC, tags={"method": "flush"})
def flush(self) -> None:
    """Write all cached data to the storage, then empty the cache.

    The revision-content layer is written before the origin-revision layer.
    """
    for write_layer in (
        self.flush_revision_content_layer,
        self.flush_origin_revision_layer,
    ):
        write_layer()
    self.clear_caches()
def flush_if_necessary(self) -> bool:
    """Flush when the cache grew beyond MAX_CACHE_ELEMENTS.

    Returns True if a flush was performed, False otherwise.
    """
    LOGGER.debug("Cache stats: %s", self._get_cache_stats())
    if not self._flush_limit_reached():
        return False
    self.flush()
    return True
@statsd.timed(
metric=BACKEND_DURATION_METRIC, tags={"method": "flush_origin_revision"}
)
def flush_origin_revision_layer(self) -> None:
"""Write the cached origin-revision layer to the storage.

Writes, in this order: the origin urls registered via origin_add, revision
entities (sources of revision_before_revision, and revisions of
revision_in_origin together with their origin), REV_BEFORE_REV relations and
finally REV_IN_ORG relations. Each write is retried until the storage reports
success, logging a warning and incrementing a statsd counter on every retry;
there is no retry limit.
"""
# Origins and revisions should be inserted first so that internal ids'
# resolution works properly.
urls = {
sha1: url
for sha1, url in self.cache["origin"]["data"].items()
if sha1 in self.cache["origin"]["added"]
}
if urls:
while not self.storage.origin_add(urls):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_origin_revision_retry_origin"},
)
LOGGER.warning(
"Unable to write origins urls to the storage. Retrying..."
)
# Later keys overwrite earlier ones in this dict literal, so a revision present
# in both relations is registered with the origin from revision_in_origin.
rev_orgs = {
# Destinations in this relation should match origins in the next one
**{
src: RevisionData(date=None, origin=None)
for src in self.cache["revision_before_revision"]
},
**{
# This relation comes second so that non-None origins take precedence
src: RevisionData(date=None, origin=org)
for src, org in self.cache["revision_in_origin"]
},
}
if rev_orgs:
while not self.storage.revision_add(rev_orgs):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_origin_revision_retry_revision"},
)
LOGGER.warning(
"Unable to write revision entities to the storage. Retrying..."
)
# Second, flat models for revisions' histories (ie. revision-before-revision).
if self.cache["revision_before_revision"]:
rev_before_rev = {
src: {RelationData(dst=dst, path=None) for dst in dsts}
for src, dsts in self.cache["revision_before_revision"].items()
}
while not self.storage.relation_add(
RelationType.REV_BEFORE_REV, rev_before_rev
):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={
"method": "flush_origin_revision_retry_revision_before_revision"
},
)
LOGGER.warning(
"Unable to write %s rows to the storage. Retrying...",
RelationType.REV_BEFORE_REV,
)
# Heads (ie. revision-in-origin entries) should be inserted once flat models for
# their histories were already added. This is to guarantee consistent results if
# something needs to be reprocessed due to a failure: already inserted heads
# won't get reprocessed in such a case.
if self.cache["revision_in_origin"]:
rev_in_org: Dict[Sha1Git, Set[RelationData]] = {}
for src, dst in self.cache["revision_in_origin"]:
rev_in_org.setdefault(src, set()).add(RelationData(dst=dst, path=None))
while not self.storage.relation_add(RelationType.REV_IN_ORG, rev_in_org):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_origin_revision_retry_revision_in_origin"},
)
LOGGER.warning(
"Unable to write %s rows to the storage. Retrying...",
RelationType.REV_IN_ORG,
)
@statsd.timed(
metric=BACKEND_DURATION_METRIC, tags={"method": "flush_revision_content"}
)
def flush_revision_content_layer(self) -> None:
"""Write the cached revision-content layer to the storage.

Writes, in this order: content dates, directory dates (with flat=False),
revision entities without date, locations (paths), the CNT_EARLY_IN_REV,
CNT_IN_DIR and DIR_IN_REV relations, and finally the acknowledgements
(flattened directories and revision dates). Each write is retried until the
storage reports success, logging a warning and incrementing a statsd counter
on every retry; there is no retry limit.
"""
# Register in the storage all entities, to ensure the coming relations can
# properly resolve any internal reference if needed. Content and directory
# entries may safely be registered with their associated dates. In contrast,
# revision entries should be registered without date, as it is used to
# acknowledge that the flushing was successful. Also, directories are
# registered with their flatten flag not set.
cnt_dates = {
sha1: date
for sha1, date in self.cache["content"]["data"].items()
if sha1 in self.cache["content"]["added"] and date is not None
}
if cnt_dates:
while not self.storage.content_add(cnt_dates):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_content_date"},
)
LOGGER.warning(
"Unable to write content dates to the storage. Retrying..."
)
dir_dates = {
sha1: DirectoryData(date=date, flat=False)
for sha1, date in self.cache["directory"]["data"].items()
if sha1 in self.cache["directory"]["added"] and date is not None
}
if dir_dates:
while not self.storage.directory_add(dir_dates):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_directory_date"},
)
LOGGER.warning(
"Unable to write directory dates to the storage. Retrying..."
)
# NOTE(review): revisions whose cached date is None are skipped both here and in
# rev_dates below, so this method never registers them.
revs = {
- sha1
+ sha1: RevisionData(date=None, origin=None)
for sha1, date in self.cache["revision"]["data"].items()
if sha1 in self.cache["revision"]["added"] and date is not None
}
if revs:
while not self.storage.revision_add(revs):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_revision_none"},
)
LOGGER.warning(
"Unable to write revision entities to the storage. Retrying..."
)
# Locations are keyed by the sha1 digest of each path, as location_add expects.
paths = {
- path
+ hashlib.sha1(path).digest(): path
for _, _, path in self.cache["content_in_revision"]
| self.cache["content_in_directory"]
| self.cache["directory_in_revision"]
}
if paths:
while not self.storage.location_add(paths):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_location"},
)
LOGGER.warning(
"Unable to write locations entities to the storage. Retrying..."
)
# For this layer, relations need to be inserted first so that, in case of
# failure, reprocessing the input does not generate an inconsistent database.
if self.cache["content_in_revision"]:
cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
for src, dst, path in self.cache["content_in_revision"]:
cnt_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path))
while not self.storage.relation_add(
RelationType.CNT_EARLY_IN_REV, cnt_in_rev
):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_content_in_revision"},
)
LOGGER.warning(
"Unable to write %s rows to the storage. Retrying...",
RelationType.CNT_EARLY_IN_REV,
)
if self.cache["content_in_directory"]:
cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {}
for src, dst, path in self.cache["content_in_directory"]:
cnt_in_dir.setdefault(src, set()).add(RelationData(dst=dst, path=path))
while not self.storage.relation_add(RelationType.CNT_IN_DIR, cnt_in_dir):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={
"method": "flush_revision_content_retry_content_in_directory"
},
)
LOGGER.warning(
"Unable to write %s rows to the storage. Retrying...",
RelationType.CNT_IN_DIR,
)
if self.cache["directory_in_revision"]:
dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
for src, dst, path in self.cache["directory_in_revision"]:
dir_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path))
while not self.storage.relation_add(RelationType.DIR_IN_REV, dir_in_rev):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={
"method": "flush_revision_content_retry_directory_in_revision"
},
)
LOGGER.warning(
"Unable to write %s rows to the storage. Retrying...",
RelationType.DIR_IN_REV,
)
# After relations, flatten flags for directories can be safely set (if
# applicable) acknowledging those directories that have already been flattened.
# Similarly, dates for the revisions are set to acknowledge that these revisions
# won't need to be reprocessed in case of failure.
# Unlike dir_dates above, this does not require the directory to be in "added":
# dates fetched from the storage (see directory_already_flattenned) qualify too.
dir_acks = {
sha1: DirectoryData(
date=date, flat=self.cache["directory_flatten"].get(sha1) or False
)
for sha1, date in self.cache["directory"]["data"].items()
if self.cache["directory_flatten"].get(sha1) and date is not None
}
if dir_acks:
while not self.storage.directory_add(dir_acks):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_directory_ack"},
)
LOGGER.warning(
"Unable to write directory dates to the storage. Retrying..."
)
rev_dates = {
sha1: RevisionData(date=date, origin=None)
for sha1, date in self.cache["revision"]["data"].items()
if sha1 in self.cache["revision"]["added"] and date is not None
}
if rev_dates:
while not self.storage.revision_add(rev_dates):
statsd.increment(
metric=BACKEND_OPERATIONS_METRIC,
tags={"method": "flush_revision_content_retry_revision_date"},
)
LOGGER.warning(
"Unable to write revision dates to the storage. Retrying..."
)
def content_add_to_directory(
    self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes
) -> None:
    """Cache that ``blob`` occurs in ``directory`` at ``prefix``/``blob.name``.

    The path is normalized with path_normalize; the relation is only written to
    the storage by the next flush.
    """
    location = path_normalize(os.path.join(prefix, blob.name))
    entry = (blob.id, directory.id, location)
    self.cache["content_in_directory"].add(entry)
def content_add_to_revision(
    self, revision: RevisionEntry, blob: FileEntry, prefix: bytes
) -> None:
    """Cache that ``blob`` occurs in ``revision`` at ``prefix``/``blob.name``.

    The path is normalized with path_normalize; the relation is only written to
    the storage by the next flush.
    """
    location = path_normalize(os.path.join(prefix, blob.name))
    entry = (blob.id, revision.id, location)
    self.cache["content_in_revision"].add(entry)
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
    """Ask the storage for the first occurrence of blob ``id`` (None if unknown)."""
    storage = self.storage
    return storage.content_find_first(id)
def content_find_all(
    self, id: Sha1Git, limit: Optional[int] = None
) -> Generator[ProvenanceResult, None, None]:
    """Yield the occurrences of blob ``id`` reported by the storage.

    ``limit`` is forwarded unchanged to the storage.
    """
    for occurrence in self.storage.content_find_all(id, limit=limit):
        yield occurrence
def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]:
    """Return the known date of ``blob``, or None when it is unknown."""
    known = self.get_dates("content", [blob.id])
    return known.get(blob.id)
def content_get_early_dates(
    self, blobs: Iterable[FileEntry]
) -> Dict[Sha1Git, datetime]:
    """Return the known dates of ``blobs``; blobs with unknown date are omitted."""
    sha1s = [entry.id for entry in blobs]
    return self.get_dates("content", sha1s)
def content_set_early_date(self, blob: FileEntry, date: datetime) -> None:
    """Cache ``date`` for ``blob`` and mark it to be written on the next flush."""
    contents = self.cache["content"]
    contents["data"][blob.id] = date
    contents["added"].add(blob.id)
def directory_add_to_revision(
    self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes
) -> None:
    """Cache that ``directory`` occurs in ``revision`` at the normalized ``path``."""
    normalized = path_normalize(path)
    self.cache["directory_in_revision"].add((directory.id, revision.id, normalized))
def directory_already_flattenned(self, directory: DirectoryEntry) -> Optional[bool]:
    """Return the flatten flag of ``directory``, or None when it is unknown.

    The storage is queried at most once per directory until the cache is
    cleared. When the storage knows the directory, its stored date is also
    cached so that it is available when flushing.
    """
    flags = self.cache["directory_flatten"]
    key = directory.id
    if key in flags:
        return flags.get(key)
    # remember the lookup even if the storage knows nothing about this directory
    flags[key] = None
    found = self.storage.directory_get([key])
    if key in found:
        stored = found[key]
        flags[key] = stored.flat
        # date is kept to ensure we have it available when flushing
        self.cache["directory"]["data"][key] = stored.date
    return flags.get(key)
def directory_flag_as_flattenned(self, directory: DirectoryEntry) -> None:
    """Mark ``directory`` as flattened in the cache (acknowledged on flush)."""
    flags = self.cache["directory_flatten"]
    flags[directory.id] = True
def directory_get_date_in_isochrone_frontier(
    self, directory: DirectoryEntry
) -> Optional[datetime]:
    """Return the known frontier date of ``directory``, or None when unknown."""
    known = self.get_dates("directory", [directory.id])
    return known.get(directory.id)
def directory_get_dates_in_isochrone_frontier(
    self, dirs: Iterable[DirectoryEntry]
) -> Dict[Sha1Git, datetime]:
    """Return the known frontier dates of ``dirs``; unknown ones are omitted."""
    sha1s = [entry.id for entry in dirs]
    return self.get_dates("directory", sha1s)
def directory_set_date_in_isochrone_frontier(
    self, directory: DirectoryEntry, date: datetime
) -> None:
    """Cache the frontier ``date`` of ``directory`` and mark it for the next flush."""
    frontier = self.cache["directory"]
    frontier["data"][directory.id] = date
    frontier["added"].add(directory.id)
def get_dates(
    self,
    entity: Literal["content", "directory", "revision"],
    ids: Iterable[Sha1Git],
) -> Dict[Sha1Git, datetime]:
    """Return the known dates of the ``entity`` objects identified by ``ids``.

    Ids not yet present in the local cache are fetched from the storage in a
    single call. Every requested id ends up in the cache: ids the storage has
    no date for are cached as None ("unknown"), so they are not fetched again
    until the cache is cleared.

    Args:
        entity: which DatetimeCache of the ProvenanceCache to use.
        ids: sha1 ids to look up; any iterable is accepted (it is consumed once).

    Returns:
        A mapping containing only the ids whose date is known.
    """
    # ``ids`` is walked twice below; materialize it so one-shot iterables work.
    sha1s = list(ids)
    data = self.cache[entity]["data"]
    # Look ids up in the date mapping itself. A DatetimeCache only has the keys
    # "data" and "added", so testing membership on it would treat every id as
    # missing, query the storage each time, and overwrite locally-set dates
    # with the stored ones.
    missing = {sha1 for sha1 in sha1s if sha1 not in data}
    if missing:
        if entity == "content":
            data.update(self.storage.content_get(missing))
        elif entity == "directory":
            data.update(
                (sha1, entry.date)
                for sha1, entry in self.storage.directory_get(missing).items()
            )
        elif entity == "revision":
            data.update(
                (sha1, entry.date)
                for sha1, entry in self.storage.revision_get(missing).items()
            )
    dates: Dict[Sha1Git, datetime] = {}
    for sha1 in sha1s:
        # cache unknown ids as None so that they are not looked up again
        date = data.setdefault(sha1, None)
        if date is not None:
            dates[sha1] = date
    return dates
def open(self) -> None:
    """Open the underlying storage."""
    storage = self.storage
    storage.open()
def origin_add(self, origin: OriginEntry) -> None:
    """Cache the url of ``origin`` and mark it to be written on the next flush."""
    origins = self.cache["origin"]
    origins["data"][origin.id] = origin.url
    origins["added"].add(origin.id)
def revision_add(self, revision: RevisionEntry) -> None:
    """Cache the date of ``revision`` and mark it to be written on the next flush."""
    revisions = self.cache["revision"]
    revisions["data"][revision.id] = revision.date
    revisions["added"].add(revision.id)
def revision_add_before_revision(
    self, head_id: Sha1Git, revision_id: Sha1Git
) -> None:
    """Cache that ``revision_id`` comes before the head revision ``head_id``.

    Flushed as a REV_BEFORE_REV relation from ``revision_id`` to ``head_id``.
    """
    heads = self.cache["revision_before_revision"].setdefault(revision_id, set())
    heads.add(head_id)
def revision_add_to_origin(
    self, origin: OriginEntry, revision: RevisionEntry
) -> None:
    """Cache that ``revision`` is a head of ``origin`` (a REV_IN_ORG relation)."""
    pair = (revision.id, origin.id)
    self.cache["revision_in_origin"].add(pair)
def revision_is_head(self, revision: RevisionEntry) -> bool:
    """Tell whether the storage has a REV_IN_ORG entry for ``revision``.

    Only the storage is consulted: heads cached by revision_add_to_origin but
    not flushed yet are not taken into account.
    """
    found = self.storage.relation_get(RelationType.REV_IN_ORG, [revision.id])
    return bool(found)
def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]:
    """Return the known date of ``revision``, or None when it is unknown."""
    known = self.get_dates("revision", [revision.id])
    return known.get(revision.id)
def revision_get_preferred_origin(self, revision_id: Sha1Git) -> Optional[Sha1Git]:
    """Return the preferred origin sha1 of ``revision_id``, or None if unknown.

    A non-None origin found in the storage is cached; a missing one is not, so
    the storage is queried again on the next call for that revision.
    """
    known = self.cache["revision_origin"]["data"]
    if revision_id not in known:
        stored = self.storage.revision_get([revision_id])
        origin = stored[revision_id].origin if revision_id in stored else None
        if origin is not None:
            known[revision_id] = origin
    return known.get(revision_id)
def revision_set_preferred_origin(
    self, origin: OriginEntry, revision_id: Sha1Git
) -> None:
    """Cache ``origin`` as the preferred origin of ``revision_id``.

    NOTE(review): neither flush method of this class reads the
    ``revision_origin`` cache, so this value does not appear to be persisted
    to the storage — confirm.
    """
    preferred = self.cache["revision_origin"]
    preferred["data"][revision_id] = origin.id
    preferred["added"].add(revision_id)
diff --git a/swh/provenance/storage/interface.py b/swh/provenance/storage/interface.py
index a59cca1..c0ba7d1 100644
--- a/swh/provenance/storage/interface.py
+++ b/swh/provenance/storage/interface.py
@@ -1,231 +1,229 @@
# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
import enum
from types import TracebackType
-from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union
+from typing import Dict, Generator, Iterable, List, Optional, Set, Type
from typing_extensions import Protocol, runtime_checkable
from swh.core.api import remote_api_endpoint
from swh.model.model import Sha1Git
class EntityType(enum.Enum):
"""Kinds of entities of the provenance model.

The PostgreSQL backend uses the values as table names (see entity_get_all).
"""
CONTENT = "content"
DIRECTORY = "directory"
REVISION = "revision"
ORIGIN = "origin"
class RelationType(enum.Enum):
"""Relations between entities of the provenance model.

Member names read as source -> destination: e.g. CNT_EARLY_IN_REV relates a
content (source) to a revision (destination).
"""
CNT_EARLY_IN_REV = "content_in_revision"
CNT_IN_DIR = "content_in_directory"
DIR_IN_REV = "directory_in_revision"
REV_IN_ORG = "revision_in_origin"
REV_BEFORE_REV = "revision_before_revision"
@dataclass(eq=True, frozen=True)
class ProvenanceResult:
"""One occurrence of a content, as returned by content_find_first/_all.

`revision` is the revision the content was found in, `date` that revision's
date, `path` the location of the content and `origin` an optional origin
(presumably its url, given the `str` type — confirm).
"""
content: Sha1Git
revision: Sha1Git
date: datetime
origin: Optional[str]
path: bytes
@dataclass(eq=True, frozen=True)
class DirectoryData:
"""Object representing the data associated to a directory in the provenance model,
where `date` is the date of the directory in the isochrone frontier, and `flat` is a
flag acknowledging that a flat model for the elements outside the frontier has
already been created.
"""
date: datetime
flat: bool
@dataclass(eq=True, frozen=True)
class RevisionData:
"""Object representing the data associated to a revision in the provenance model,
where `date` is the optional date of the revision (specifying it acknowledges that
the revision was already processed by the revision-content algorithm); and `origin`
identifies (by sha1 id) the preferred origin for the revision, if any.
"""
date: Optional[datetime]
origin: Optional[Sha1Git]
@dataclass(eq=True, frozen=True)
class RelationData:
"""Object representing a relation entry in the provenance model, where `dst` is
the sha1 id of the destination entity and `path` is optional depending on the
relation being represented. The source entity is not stored here: it is the key
under which the entry appears in `relation_add`/`relation_get` mappings.
"""
dst: Sha1Git
path: Optional[bytes]
@runtime_checkable
class ProvenanceStorageInterface(Protocol):
"""Structural interface implemented by provenance storage backends.

Methods decorated with `remote_api_endpoint` are tagged with an endpoint name,
presumably used by swh.core.api to expose them remotely (confirm there).
"""
def __enter__(self) -> ProvenanceStorageInterface:
...
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
...
@remote_api_endpoint("close")
def close(self) -> None:
"""Close connection to the storage and release resources."""
...
@remote_api_endpoint("content_add")
def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
"""Add blobs identified by sha1 ids, with an associated date (as paired in
`cnts`) to the provenance storage. Return a boolean stating whether the
information was successfully stored.
"""
...
@remote_api_endpoint("content_find_first")
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
"""Retrieve the first occurrence of the blob identified by `id`."""
...
@remote_api_endpoint("content_find_all")
def content_find_all(
self, id: Sha1Git, limit: Optional[int] = None
) -> Generator[ProvenanceResult, None, None]:
"""Retrieve all the occurrences of the blob identified by `id`."""
...
@remote_api_endpoint("content_get")
def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
"""Retrieve the associated date for each blob sha1 in `ids`."""
...
@remote_api_endpoint("directory_add")
def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
"""Add directories identified by sha1 ids, with associated date and (optional)
flatten flag (as paired in `dirs`) to the provenance storage. If the flatten
flag is set to None, the previous value present in the storage is preserved.
Return a boolean stating if the information was successfully stored.
"""
# NOTE(review): DirectoryData.flat is declared as `bool`, so the "None" case
# described above cannot be expressed with that type; confirm it still applies.
...
@remote_api_endpoint("directory_get")
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]:
"""Retrieve the associated date and flatten flag for each directory sha1 in
`ids`. Directories that have no associated date are not present in the
resulting dictionary.
"""
...
@remote_api_endpoint("directory_iter_not_flattenned")
def directory_iter_not_flattenned(
self, limit: int, start_id: Sha1Git
) -> List[Sha1Git]:
"""Retrieve the unflattenned directories after ``start_id`` up to ``limit`` entries."""
...
@remote_api_endpoint("entity_get_all")
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
"""Retrieve all sha1 ids for entities of type `entity` present in the provenance
model. This method is used only in tests.
"""
...
@remote_api_endpoint("location_add")
- def location_add(self, paths: Iterable[bytes]) -> bool:
+ def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool:
"""Register the given `paths` in the storage. `paths` maps the sha1 digest of
each path to the path itself. Return a boolean stating if the information was
successfully stored.
"""
...
@remote_api_endpoint("location_get_all")
- def location_get_all(self) -> Set[bytes]:
+ def location_get_all(self) -> Dict[Sha1Git, bytes]:
"""Retrieve all paths present in the provenance model, keyed by the sha1
digest of each path. This method is used only in tests."""
...
@remote_api_endpoint("open")
def open(self) -> None:
"""Open connection to the storage and allocate necessary resources."""
...
@remote_api_endpoint("origin_add")
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
"""Add origins identified by sha1 ids, with their corresponding url (as paired
in `orgs`) to the provenance storage. Return a boolean stating if the
information was successfully stored.
"""
...
@remote_api_endpoint("origin_get")
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
"""Retrieve the associated url for each origin sha1 in `ids`."""
...
@remote_api_endpoint("revision_add")
- def revision_add(
- self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]]
- ) -> bool:
+ def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool:
"""Add revisions identified by sha1 ids, with optional associated date or origin
(as paired in `revs`) to the provenance storage. Return a boolean stating if the
information was successfully stored.
"""
...
@remote_api_endpoint("revision_get")
def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
"""Retrieve the associated date and origin for each revision sha1 in `ids`. If
some revision has no associated date nor origin, it is not present in the
resulting dictionary.
"""
...
@remote_api_endpoint("relation_add")
def relation_add(
self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
) -> bool:
"""Add entries in the selected `relation`. This method assumes all entities
being related are already registered in the storage. See `content_add`,
`directory_add`, `origin_add`, and `revision_add`.
"""
...
@remote_api_endpoint("relation_get")
def relation_get(
self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
) -> Dict[Sha1Git, Set[RelationData]]:
"""Retrieve all entries in the selected `relation` whose source entities are
identified by some sha1 id in `ids`. If `reverse` is set, destination entities
are matched instead.
"""
...
@remote_api_endpoint("relation_get_all")
def relation_get_all(
self, relation: RelationType
) -> Dict[Sha1Git, Set[RelationData]]:
"""Retrieve all entries in the selected `relation` that are present in the
provenance model. This method is used only in tests.
"""
...
@remote_api_endpoint("with_path")
def with_path(self) -> bool:
"""Tell whether this storage records paths (locations); e.g. the PostgreSQL
backend only writes in `location_add` when this returns True.
"""
...
diff --git a/swh/provenance/storage/postgresql.py b/swh/provenance/storage/postgresql.py
index cccc5d5..f65f092 100644
--- a/swh/provenance/storage/postgresql.py
+++ b/swh/provenance/storage/postgresql.py
@@ -1,402 +1,398 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
+from hashlib import sha1
import itertools
import logging
from types import TracebackType
-from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union
+from typing import Dict, Generator, Iterable, List, Optional, Set, Type
import psycopg2.extensions
import psycopg2.extras
from swh.core.db import BaseDb
from swh.core.statsd import statsd
from swh.model.model import Sha1Git
from swh.provenance.storage.interface import (
DirectoryData,
EntityType,
ProvenanceResult,
ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
)
# Module-level logger; handle_raise_on_commit logs unexpected errors here.
LOGGER = logging.getLogger(__name__)
# statsd timing metric recorded by the @statsd.timed storage methods.
STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds"
def handle_raise_on_commit(f):
    """Decorate a storage write method so that its failures are reported as False.

    Any ``Exception`` raised by ``f`` is logged. It is then re-raised when the
    storage was created with ``raise_on_commit=True``; otherwise the wrapper
    returns ``False`` so that the caller may retry. Rolling back the failed
    transaction is left to the ``transaction()`` context manager used by the
    wrapped methods.

    ``BaseException`` subclasses that are not ``Exception`` (KeyboardInterrupt,
    SystemExit, ...) always propagate: turning them into ``False`` would make
    the retry-until-success callers (``while not storage.xxx_add(...)``) loop
    forever instead of stopping.
    """

    @wraps(f)
    def handle(self, *args, **kwargs):
        try:
            return f(self, *args, **kwargs)
        except Exception:
            # Unexpected error occurred, log it with its traceback
            LOGGER.exception("Unexpected error")
            if self.raise_on_commit:
                raise
            return False

    return handle
class ProvenanceStoragePostgreSql:
"""PostgreSQL implementation of the provenance storage.

Connection parameters are stored by the constructor; the connection itself is
only established by open().
"""
# NOTE(review): presumably the database schema version this code expects — confirm.
current_version = 3
def __init__(
    self, page_size: Optional[int] = None, raise_on_commit: bool = False, **kwargs
) -> None:
    """Store the settings; no connection is made until open() is called.

    ``page_size`` is the page size passed to ``execute_values`` (all rows in one
    page when unset); ``raise_on_commit`` makes write failures raise instead of
    returning False; the remaining keyword arguments are connection parameters
    given to ``BaseDb.connect`` by open().
    """
    self.raise_on_commit = raise_on_commit
    self.page_size = page_size
    self.conn_args = kwargs
    self._flavor: Optional[str] = None
    self.conn: Optional[psycopg2.extensions.connection] = None
def __enter__(self) -> ProvenanceStorageInterface:
    """Open the database connection and hand this storage to the ``with`` body."""
    storage = self
    storage.open()
    return storage
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> None:
    """Close the database connection; a pending exception is never suppressed."""
    self.close()
    return None
@contextmanager
def transaction(
    self, readonly: bool = False
) -> Generator[psycopg2.extras.RealDictCursor, None, None]:
    """Yield a dict-row cursor inside a transaction on the open connection.

    The session is switched to ``readonly`` first. Through psycopg2's connection
    context manager, the transaction is committed when the block exits normally
    and rolled back when it raises.

    Raises:
        RuntimeError: if open() has not been called.
    """
    conn = self.conn
    if conn is None:
        raise RuntimeError(
            "Tried to access ProvenanceStoragePostgreSQL transaction() without opening it"
        )
    conn.set_session(readonly=readonly)
    with conn, conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
        yield cur
@property
def flavor(self) -> str:
    """Database flavor reported by ``swh_get_dbflavor()``, queried once then cached."""
    if self._flavor is not None:
        return self._flavor
    with self.transaction(readonly=True) as cursor:
        cursor.execute("SELECT swh_get_dbflavor() AS flavor")
        row = cursor.fetchone()
        assert row  # please mypy
        self._flavor = row["flavor"]
    assert self._flavor is not None
    return self._flavor
@property
def denormalized(self) -> bool:
    """True when the database flavor name contains "denormalized"."""
    flavor = self.flavor
    return flavor.find("denormalized") != -1
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"})
def close(self) -> None:
    """Close the database connection; open() must have been called before."""
    conn = self.conn
    assert conn is not None
    conn.close()
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"})
@handle_raise_on_commit
def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
    """Upsert content dates, keeping the earliest date when a sha1 already exists.

    Returns True once done; errors are handled by handle_raise_on_commit.
    """
    if not cnts:
        return True
    sql = """
INSERT INTO content(sha1, date) VALUES %s
ON CONFLICT (sha1) DO
UPDATE SET date=LEAST(EXCLUDED.date,content.date)
"""
    with self.transaction() as cursor:
        psycopg2.extras.execute_values(
            cursor, sql, argslist=cnts.items(), page_size=self.page_size or len(cnts)
        )
    return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"})
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
    """Return the first occurrence of blob ``id``, or None if it is unknown."""
    sql = "SELECT * FROM swh_provenance_content_find_first(%s)"
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=(id,))
        row = cursor.fetchone()
        if row is None:
            return None
        return ProvenanceResult(**row)
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"})
def content_find_all(
    self, id: Sha1Git, limit: Optional[int] = None
) -> Generator[ProvenanceResult, None, None]:
    """Yield the occurrences of blob ``id``; ``limit`` is passed to the SQL function."""
    sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)"
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=(id, limit))
        for row in cursor:
            yield ProvenanceResult(**row)
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"})
def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
    """Return the date of each blob of ``ids`` that has a non-NULL date."""
    sha1s = tuple(ids)
    if not sha1s:
        return {}
    # TODO: consider splitting this query in several ones if sha1s is too big!
    placeholders = ", ".join(["%s"] * len(sha1s))
    sql = f"""
SELECT sha1, date
FROM content
WHERE sha1 IN ({placeholders})
AND date IS NOT NULL
"""
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=sha1s)
        return {row["sha1"]: row["date"] for row in cursor}
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"})
@handle_raise_on_commit
def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
    """Upsert directories: keep the earliest date and OR the flatten flags.

    Returns True once done; errors are handled by handle_raise_on_commit.
    """
    rows = [(sha1, entry.date, entry.flat) for sha1, entry in dirs.items()]
    if not rows:
        return True
    sql = """
INSERT INTO directory(sha1, date, flat) VALUES %s
ON CONFLICT (sha1) DO
UPDATE SET
date=LEAST(EXCLUDED.date, directory.date),
flat=(EXCLUDED.flat OR directory.flat)
"""
    with self.transaction() as cursor:
        psycopg2.extras.execute_values(
            cur=cursor, sql=sql, argslist=rows, page_size=self.page_size or len(rows)
        )
    return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"})
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]:
    """Return date and flatten flag of each directory of ``ids`` with a non-NULL date."""
    sha1s = tuple(ids)
    if not sha1s:
        return {}
    # TODO: consider splitting this query in several ones if sha1s is too big!
    placeholders = ", ".join(["%s"] * len(sha1s))
    sql = f"""
SELECT sha1, date, flat
FROM directory
WHERE sha1 IN ({placeholders})
AND date IS NOT NULL
"""
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=sha1s)
        return {
            row["sha1"]: DirectoryData(date=row["date"], flat=row["flat"])
            for row in cursor
        }
@statsd.timed(
    metric=STORAGE_DURATION_METRIC, tags={"method": "directory_iter_not_flattenned"}
)
def directory_iter_not_flattenned(
    self, limit: int, start_id: Sha1Git
) -> List[Sha1Git]:
    """Return up to ``limit`` unflattened directory sha1s greater than ``start_id``, sorted."""
    sql = """
SELECT sha1 FROM directory
WHERE flat=false AND sha1>%s ORDER BY sha1 LIMIT %s
"""
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=(start_id, limit))
        found = [row["sha1"] for row in cursor]
    return found
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"})
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
    """Return every sha1 stored in the table named after ``entity``."""
    with self.transaction(readonly=True) as cursor:
        # the table name comes from the fixed EntityType values, not from user input
        cursor.execute(f"SELECT sha1 FROM {entity.value}")
        sha1s = {row["sha1"] for row in cursor}
    return sha1s
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"})
@handle_raise_on_commit
- def location_add(self, paths: Iterable[bytes]) -> bool:
+ def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool:
"""Insert the given paths into the location table (ON CONFLICT DO NOTHING).

Only the mapping values (the paths) are written; the sha1 keys are not stored
by this backend. Nothing is written when with_path() is False. Returns True
once done; errors are handled by handle_raise_on_commit.
"""
if self.with_path():
- values = [(path,) for path in paths]
+ values = [(path,) for path in paths.values()]
if values:
sql = """
INSERT INTO location(path) VALUES %s
ON CONFLICT DO NOTHING
"""
page_size = self.page_size or len(values)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cursor, sql, argslist=values, page_size=page_size
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"})
- def location_get_all(self) -> Set[bytes]:
+ def location_get_all(self) -> Dict[Sha1Git, bytes]:
"""Return every stored path, keyed by the sha1 digest of the path.

The keys are recomputed here from the paths, since this backend stores only
the paths themselves (see location_add).
"""
with self.transaction(readonly=True) as cursor:
cursor.execute("SELECT location.path AS path FROM location")
- return {row["path"] for row in cursor}
+ return {sha1(row["path"]).digest(): row["path"] for row in cursor}
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"})
@handle_raise_on_commit
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
    """Insert ``sha1 -> url`` origins, skipping those already stored.

    An empty mapping is a no-op. Always returns ``True``.
    """
    if not orgs:
        return True
    statement = """
INSERT INTO origin(sha1, url) VALUES %s
ON CONFLICT DO NOTHING
"""
    # A falsy page_size means: send every row in a single page.
    rows_per_page = self.page_size or len(orgs)
    with self.transaction() as cursor:
        psycopg2.extras.execute_values(
            cur=cursor,
            sql=statement,
            argslist=orgs.items(),
            page_size=rows_per_page,
        )
    return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"})
def open(self) -> None:
    """Connect to the database and set the session timezone to UTC."""
    self.conn = BaseDb.connect(**self.conn_args).conn
    BaseDb.adapt_conn(self.conn)
    utc_statement = "SET timezone TO 'UTC'"
    with self.transaction() as cursor:
        cursor.execute(utc_statement)
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"})
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
    """Return the URL of each requested origin that is known.

    Unknown ids are simply absent from the result; no ids gives ``{}``.
    """
    sha1s = tuple(ids)
    if not sha1s:
        return {}
    # TODO: consider splitting this query in several ones if sha1s is too big!
    placeholders = ", ".join(itertools.repeat("%s", len(sha1s)))
    sql = f"""
SELECT sha1, url
FROM origin
WHERE sha1 IN ({placeholders})
"""
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=sha1s)
        found = {record["sha1"]: record["url"] for record in cursor}
    return found
# Upsert revisions from a ``sha1 -> RevisionData`` mapping.
# The origin id is resolved from the origin sha1 through a LEFT JOIN, so an
# unknown origin yields NULL. On conflict the stored date becomes
# LEAST(new, old) (PostgreSQL's LEAST ignores NULL arguments), and the origin
# is replaced only when the incoming one is non-NULL (COALESCE).
# An empty mapping is a no-op; always returns True.
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"})
@handle_raise_on_commit
- def revision_add(
- self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]]
- ) -> bool:
- if isinstance(revs, dict):
+ def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool:
+ if revs:
data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()]
- else:
- data = [(sha1, None, None) for sha1 in revs]
- if data:
sql = """
INSERT INTO revision(sha1, date, origin)
(SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin
FROM (VALUES %s) AS V(rev, date, org)
LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git))
ON CONFLICT (sha1) DO
UPDATE SET
date=LEAST(EXCLUDED.date, revision.date),
origin=COALESCE(EXCLUDED.origin, revision.origin)
"""
# A falsy page_size means: send every row in a single page.
page_size = self.page_size or len(data)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cur=cursor, sql=sql, argslist=data, page_size=page_size
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"})
def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
    """Return date/origin information for the requested revisions.

    Revisions that have neither a date nor an origin are left out.
    """
    sha1s = tuple(ids)
    if not sha1s:
        return {}
    # TODO: consider splitting this query in several ones if sha1s is too big!
    placeholders = ", ".join(itertools.repeat("%s", len(sha1s)))
    sql = f"""
SELECT R.sha1, R.date, O.sha1 AS origin
FROM revision AS R
LEFT JOIN origin AS O ON (O.id=R.origin)
WHERE R.sha1 IN ({placeholders})
AND (R.date is not NULL OR O.sha1 is not NULL)
"""
    found: Dict[Sha1Git, RevisionData] = {}
    with self.transaction(readonly=True) as cursor:
        cursor.execute(query=sql, vars=sha1s)
        for record in cursor:
            found[record["sha1"]] = RevisionData(
                date=record["date"], origin=record["origin"]
            )
    return found
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"})
@handle_raise_on_commit
def relation_add(
    self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
) -> bool:
    """Insert ``(src, dst, path)`` rows for ``relation``.

    Rows are staged in a temporary table, then merged by the
    ``swh_provenance_relation_add_from_temp`` SQL function. Empty input is a
    no-op. Always returns ``True``.
    """
    rows = []
    for src, dsts in data.items():
        rows.extend((src, rel.dst, rel.path) for rel in dsts)
    if not rows:
        return True
    rel_table = relation.value
    # The relation table name's first/last "_"-separated parts name the
    # source and destination tables.
    src_table, *_, dst_table = rel_table.split("_")
    page_size = self.page_size or len(rows)
    # Put the next three queries in a manual single transaction:
    # they use the same temp table
    with self.transaction() as cursor:
        cursor.execute("SELECT swh_mktemp_relation_add()")
        psycopg2.extras.execute_values(
            cur=cursor,
            sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s",
            argslist=rows,
            page_size=page_size,
        )
        cursor.execute(
            query="SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)",
            vars=(rel_table, src_table, dst_table),
        )
    return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"})
def relation_get(
    self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
) -> Dict[Sha1Git, Set[RelationData]]:
    """Return rows of ``relation`` filtered on ``ids``.

    ``ids`` are matched against the source column, or against the
    destination column when ``reverse`` is true.
    """
    return self._relation_get(relation, ids, reverse=reverse)
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"})
def relation_get_all(
    self, relation: RelationType
) -> Dict[Sha1Git, Set[RelationData]]:
    """Return every row of ``relation``, without any id filter."""
    return self._relation_get(relation, ids=None)
def _relation_get(
    self,
    relation: RelationType,
    ids: Optional[Iterable[Sha1Git]],
    reverse: bool = False,
) -> Dict[Sha1Git, Set[RelationData]]:
    """Fetch rows of ``relation`` through ``swh_provenance_relation_get``.

    :param relation: the relation whose table is queried
    :param ids: sha1s to filter on; ``None`` means no filtering at all
    :param reverse: when filtering, match ``ids`` against the destination
        column instead of the source column
    :return: mapping from each row's ``src`` to its set of ``RelationData``;
        empty when ``ids`` is an empty iterable
    """
    result: Dict[Sha1Git, Set[RelationData]] = {}
    sha1s: List[Sha1Git]
    # ``filter_mode`` (not ``filter``) to avoid shadowing the builtin.
    if ids is None:
        sha1s = []
        filter_mode = "no-filter"
    else:
        sha1s = list(ids)
        filter_mode = "filter-dst" if reverse else "filter-src"
        if not sha1s:
            # Filtering on an empty id list cannot match anything.
            return result
    rel_table = relation.value
    src_table, *_, dst_table = rel_table.split("_")
    sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)"
    with self.transaction(readonly=True) as cursor:
        cursor.execute(
            query=sql, vars=(rel_table, src_table, dst_table, filter_mode, sha1s)
        )
        for row in cursor:
            src = row.pop("src")
            result.setdefault(src, set()).add(RelationData(**row))
    return result
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "with_path"})
def with_path(self) -> bool:
    """Tell whether this storage's flavor includes ``"with-path"``."""
    flavor = self.flavor
    return "with-path" in flavor
diff --git a/swh/provenance/storage/rabbitmq/client.py b/swh/provenance/storage/rabbitmq/client.py
index b5d993a..f45c690 100644
--- a/swh/provenance/storage/rabbitmq/client.py
+++ b/swh/provenance/storage/rabbitmq/client.py
@@ -1,507 +1,508 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
import functools
import inspect
import logging
import queue
import threading
import time
from types import TracebackType
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
import uuid
import pika
import pika.channel
import pika.connection
import pika.frame
import pika.spec
from swh.core.api.serializers import encode_data_client as encode_data
from swh.core.api.serializers import msgpack_loads as decode_data
from swh.core.statsd import statsd
from swh.provenance.storage import get_provenance_storage
from swh.provenance.storage.interface import (
ProvenanceStorageInterface,
RelationData,
RelationType,
)
from .serializers import DECODERS, ENCODERS
from .server import ProvenanceStorageRabbitMQServer
# Log record format string; not referenced within this visible chunk.
LOG_FORMAT = (
    "%(levelname) -10s %(asctime)s %(name) -30s "
    "%(funcName) -35s %(lineno) -5d: %(message)s"
)
LOGGER = logging.getLogger(__name__)
# statsd timer name shared by all client-side storage methods.
STORAGE_DURATION_METRIC = "swh_provenance_storage_rabbitmq_duration_seconds"
class ResponseTimeout(Exception):
    """No response with the expected correlation id arrived in time."""


class TerminateSignal(Exception):
    """Raised inside the client's ioloop to make its thread shut down."""
# Group a write method's payload by RabbitMQ routing key.
# - With ``relation``: ``data`` must be a ``src -> {RelationData}`` dict and
#   each (src, dst, path) triple is routed by its src.
# - Without: ``data`` is either a dict (items kept as (key, value) tuples and
#   routed by key) or an iterable of ids (wrapped as 1-tuples).
# Returns ``routing_key -> set of tuples``. Because sets are used, duplicate
# items collapse and dict values must be hashable.
def split_ranges(
data: Iterable[bytes], meth_name: str, relation: Optional[RelationType] = None
) -> Dict[str, Set[Tuple[Any, ...]]]:
ranges: Dict[str, Set[Tuple[Any, ...]]] = {}
if relation is not None:
assert isinstance(data, dict), "Relation data must be provided in a dictionary"
for src, dsts in data.items():
key = ProvenanceStorageRabbitMQServer.get_routing_key(
src, meth_name, relation
)
for rel in dsts:
assert isinstance(
rel, RelationData
), "Values in the dictionary must be RelationData structures"
ranges.setdefault(key, set()).add((src, rel.dst, rel.path))
else:
items: Union[Set[Tuple[bytes, Any]], Set[Tuple[bytes]]]
if isinstance(data, dict):
items = set(data.items())
else:
+ # TODO this is probably not used any more
items = {(item,) for item in data}
# NOTE(review): ``id`` shadows the builtin within this loop.
for id, *rest in items:
key = ProvenanceStorageRabbitMQServer.get_routing_key(id, meth_name)
ranges.setdefault(key, set()).add((id, *rest))
return ranges
# Metaclass that generates the client's storage methods from its backend
# class: read methods delegate to ``self._storage``, write methods publish
# the payload over RabbitMQ and wait for acks.
class MetaRabbitMQClient(type):
def __new__(cls, name, bases, attributes):
# For each method of the backend class (ProvenanceStorageInterface for the
# client below) carrying an ``_endpoint_path`` attribute (presumably set by
# @remote_api_endpoint), add a same-named method to the class being
# created, with the same signature and documentation.
#
# Note that, despite the usage of decorator magic (eg. functools.wraps),
# the backend class's own implementation is never called: reads go to
# self._storage and writes are sent through RabbitMQ.
backend_class = attributes.get("backend_class", None)
# Fall back to the first base class that defines ``backend_class``.
for base in bases:
if backend_class is not None:
break
backend_class = getattr(base, "backend_class", None)
if backend_class:
for meth_name, meth in backend_class.__dict__.items():
if hasattr(meth, "_endpoint_path"):
cls.__add_endpoint(meth_name, meth, attributes)
return super().__new__(cls, name, bases, attributes)
@staticmethod
def __add_endpoint(meth_name, meth, attributes):
wrapped_meth = inspect.unwrap(meth)
# Read path: forward the call, with its arguments bound by name, to the
# local storage object.
@functools.wraps(meth) # Copy signature and doc
def meth_(*args, **kwargs):
with statsd.timed(
metric=STORAGE_DURATION_METRIC, tags={"method": meth_name}
):
# Match arguments and parameters
data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
# Remove arguments that should not be passed
self = data.pop("self")
# Call storage method with remaining arguments
return getattr(self._storage, meth_name)(**data)
# Write path: split the payload by routing key, publish it in batches from
# the ioloop thread, then block until every item has been acked.
@functools.wraps(meth) # Copy signature and doc
def write_meth_(*args, **kwargs):
with statsd.timed(
metric=STORAGE_DURATION_METRIC, tags={"method": meth_name}
):
# Match arguments and parameters
post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs)
try:
# Remove arguments that should not be passed
self = post_data.pop("self")
relation = post_data.pop("relation", None)
# Besides self and the optional relation, write methods must take
# exactly one payload argument.
assert len(post_data) == 1
data = next(iter(post_data.values()))
ranges = split_ranges(data, meth_name, relation)
# Acks are counted per item, not per batch.
acks_expected = sum(len(items) for items in ranges.values())
# A fresh id lets wait_for_response() discard stale replies.
self._correlation_id = str(uuid.uuid4())
exchange = ProvenanceStorageRabbitMQServer.get_exchange(
meth_name, relation
)
try:
# Ask stop() to wait until this request has completed.
self._delay_close = True
for routing_key, items in ranges.items():
items_list = list(items)
batches = (
items_list[idx : idx + self._batch_size]
for idx in range(0, len(items_list), self._batch_size)
)
for batch in batches:
# FIXME: this is running in a different thread! Hence, if
# self._connection drops, there is no guarantee that the
# request can be sent for the current elements. This
# situation should be handled properly.
self._connection.ioloop.add_callback_threadsafe(
functools.partial(
ProvenanceStorageRabbitMQClient.request,
channel=self._channel,
reply_to=self._callback_queue,
exchange=exchange,
routing_key=routing_key,
correlation_id=self._correlation_id,
data=batch,
)
)
return self.wait_for_acks(meth_name, acks_expected)
finally:
self._delay_close = False
# Any failure (BaseException, so KeyboardInterrupt too) terminates the
# client and reports the write as failed.
except BaseException as ex:
self.request_termination(str(ex))
return False
# Never override a method explicitly defined on the class.
if meth_name not in attributes:
attributes[meth_name] = (
write_meth_
if ProvenanceStorageRabbitMQServer.is_write_method(meth_name)
else meth_
)
# RabbitMQ-backed implementation of ProvenanceStorageInterface.
# Read-only methods are delegated to a local storage object built from
# storage_config. Write methods (generated by MetaRabbitMQClient) split their
# payload by routing key, publish it in batches and wait for per-item acks.
# The pika SelectConnection ioloop runs in this object's own thread (run()).
class ProvenanceStorageRabbitMQClient(threading.Thread, metaclass=MetaRabbitMQClient):
backend_class = ProvenanceStorageInterface
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
def __init__(
self,
url: str,
storage_config: Dict[str, Any],
batch_size: int = 100,
prefetch_count: int = 100,
wait_min: float = 60,
wait_per_batch: float = 10,
) -> None:
"""Set up the client object, passing in the URL we will use to connect to
RabbitMQ, and the connection information for the local storage object used
for read-only operations.
:param str url: The URL for connecting to RabbitMQ
:param dict storage_config: Configuration parameters for the underlying
``ProvenanceStorage`` object expected by
``swh.provenance.get_provenance_storage``
:param int batch_size: Max amount of elements per package (after range
splitting) for writing operations
:param int prefetch_count: Prefetch value for the RabbitMQ connection when
receiving ack packages
:param float wait_min: Min waiting time for response on a writing operation, in
seconds
:param float wait_per_batch: Waiting time for response per batch of items on a
writing operation, in seconds
"""
super().__init__()
# pika connection/channel state, (re)created from the ioloop thread.
self._connection = None
# Name of the reply queue; open() polls until it is set.
self._callback_queue: Optional[str] = None
self._channel = None
self._closing = False
self._consumer_tag = None
self._consuming = False
# Id of the request currently awaiting acks; replies carrying any other
# id are discarded by wait_for_response().
self._correlation_id = str(uuid.uuid4())
self._prefetch_count = prefetch_count
self._batch_size = batch_size
# Filled by on_response() (ioloop thread), drained by wait_for_response().
self._response_queue: queue.Queue = queue.Queue()
# Local storage used for read-only operations.
self._storage = get_provenance_storage(**storage_config)
self._url = url
self._wait_min = wait_min
self._wait_per_batch = wait_per_batch
# True while a write request is in flight; stop() waits for it to clear.
self._delay_close = False
def __enter__(self) -> ProvenanceStorageInterface:
self.open()
assert isinstance(self, ProvenanceStorageInterface)
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
# Start the ioloop thread, block (polling every 0.1 s) until the reply
# queue has been declared, then open the local read-side storage.
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"})
def open(self) -> None:
self.start()
while self._callback_queue is None:
time.sleep(0.1)
self._storage.open()
# Ask the ioloop to terminate, wait for the thread, close local storage.
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"})
def close(self) -> None:
assert self._connection is not None
self._connection.ioloop.add_callback_threadsafe(self.request_termination)
self.join()
self._storage.close()
# Thread-safely schedule a callback that raises TerminateSignal inside the
# ioloop; run() catches it and shuts the client down.
def request_termination(self, reason: str = "Normal shutdown") -> None:
assert self._connection is not None
def termination_callback():
raise TerminateSignal(reason)
self._connection.ioloop.add_callback_threadsafe(termination_callback)
# Create the asynchronous connection; the on_* callbacks drive the setup
# chain: connection -> channel -> reply queue -> QoS -> consume.
def connect(self) -> pika.SelectConnection:
LOGGER.info("Connecting to %s", self._url)
return pika.SelectConnection(
parameters=pika.URLParameters(self._url),
on_open_callback=self.on_connection_open,
on_open_error_callback=self.on_connection_open_error,
on_close_callback=self.on_connection_closed,
)
def close_connection(self) -> None:
assert self._connection is not None
self._consuming = False
if self._connection.is_closing or self._connection.is_closed:
LOGGER.info("Connection is closing or already closed")
else:
LOGGER.info("Closing connection")
self._connection.close()
def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
LOGGER.info("Connection opened")
self.open_channel()
# Stop the ioloop after 5 s; run() then loops and reconnects.
def on_connection_open_error(
self, _unused_connection: pika.SelectConnection, err: Exception
) -> None:
LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
assert self._connection is not None
self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
# Stop the ioloop now when closing on purpose, otherwise after 5 s so that
# run() reconnects.
def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
assert self._connection is not None
self._channel = None
if self._closing:
self._connection.ioloop.stop()
else:
LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
self._connection.ioloop.call_later(5, self._connection.ioloop.stop)
def open_channel(self) -> None:
LOGGER.debug("Creating a new channel")
assert self._connection is not None
self._connection.channel(on_open_callback=self.on_channel_open)
def on_channel_open(self, channel: pika.channel.Channel) -> None:
LOGGER.debug("Channel opened")
self._channel = channel
LOGGER.debug("Adding channel close callback")
assert self._channel is not None
self._channel.add_on_close_callback(callback=self.on_channel_closed)
self.setup_queue()
# A closed channel is unusable here, so close the connection as well.
# NOTE(review): "%i" relies on the Channel object supporting int()
# conversion - confirm for the pika version in use.
def on_channel_closed(
self, channel: pika.channel.Channel, reason: Exception
) -> None:
LOGGER.warning("Channel %i was closed: %s", channel, reason)
self.close_connection()
# Declare an exclusive, broker-named (queue="") queue for replies.
def setup_queue(self) -> None:
LOGGER.debug("Declaring callback queue")
assert self._channel is not None
self._channel.queue_declare(
queue="", exclusive=True, callback=self.on_queue_declare_ok
)
# The broker-generated queue name becomes ``reply_to`` for every request;
# setting it also unblocks open().
def on_queue_declare_ok(self, frame: pika.frame.Method) -> None:
LOGGER.debug("Binding queue to default exchanger")
assert self._channel is not None
self._callback_queue = frame.method.queue
self._channel.basic_qos(
prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
)
def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
LOGGER.debug("QOS set to: %d", self._prefetch_count)
self.start_consuming()
def start_consuming(self) -> None:
LOGGER.debug("Issuing consumer related RPC commands")
LOGGER.debug("Adding consumer cancellation callback")
assert self._channel is not None
self._channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
assert self._callback_queue is not None
self._consumer_tag = self._channel.basic_consume(
queue=self._callback_queue, on_message_callback=self.on_response
)
self._consuming = True
def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
LOGGER.debug("Consumer was cancelled remotely, shutting down: %r", method_frame)
if self._channel:
self._channel.close()
# Runs on the ioloop thread: queue (correlation_id, decoded body) for
# wait_for_response() and ack the message.
def on_response(
self,
channel: pika.channel.Channel,
deliver: pika.spec.Basic.Deliver,
properties: pika.spec.BasicProperties,
body: bytes,
) -> None:
LOGGER.debug(
"Received message # %s from %s: %s",
deliver.delivery_tag,
properties.app_id,
body,
)
self._response_queue.put(
(
properties.correlation_id,
decode_data(body, extra_decoders=self.extra_type_decoders),
)
)
LOGGER.debug("Acknowledging message %s", deliver.delivery_tag)
channel.basic_ack(delivery_tag=deliver.delivery_tag)
def stop_consuming(self) -> None:
if self._channel:
LOGGER.debug("Sending a Basic.Cancel RPC command to RabbitMQ")
self._channel.basic_cancel(self._consumer_tag, self.on_cancel_ok)
def on_cancel_ok(self, _unused_frame: pika.frame.Method) -> None:
self._consuming = False
LOGGER.debug(
"RabbitMQ acknowledged the cancellation of the consumer: %s",
self._consumer_tag,
)
LOGGER.debug("Closing the channel")
assert self._channel is not None
self._channel.close()
# Thread body: run the ioloop, reconnecting until _closing is set. On
# TerminateSignal or any other exception, stop() and restart the ioloop so
# pending close callbacks can complete.
def run(self) -> None:
while not self._closing:
try:
self._connection = self.connect()
assert self._connection is not None
self._connection.ioloop.start()
except KeyboardInterrupt:
LOGGER.info("Connection closed by keyboard interruption, reopening")
if self._connection is not None:
self._connection.ioloop.stop()
except TerminateSignal as ex:
LOGGER.info("Termination requested: %s", ex)
self.stop()
if self._connection is not None and not self._connection.is_closed:
# Finish closing
self._connection.ioloop.start()
except BaseException as ex:
LOGGER.warning("Unexpected exception, terminating: %s", ex)
self.stop()
if self._connection is not None and not self._connection.is_closed:
# Finish closing
self._connection.ioloop.start()
LOGGER.info("Stopped")
# Wait (exponential backoff capped at 60 s) while a write is in flight,
# then cancel the consumer and let the ioloop finish closing.
# NOTE(review): stop() is reached from run() after the ioloop has already
# unwound, so no acks can be processed while waiting here; the pending write
# appears to end only when wait_for_acks() times out. Also "%2f" below was
# probably meant to be "%.2f".
def stop(self) -> None:
assert self._connection is not None
if not self._closing:
if self._delay_close:
LOGGER.info("Delaying termination: waiting for a pending request")
delay_start = time.monotonic()
wait = 1
while self._delay_close:
if wait >= 32:
LOGGER.warning(
"Still waiting for pending request (for %2f seconds)...",
time.monotonic() - delay_start,
)
time.sleep(wait)
wait = min(wait * 2, 60)
self._closing = True
LOGGER.info("Stopping")
if self._consuming:
self.stop_consuming()
self._connection.ioloop.start()
else:
self._connection.ioloop.stop()
LOGGER.info("Stopped")
# Publish one msgpack-encoded request (``kwargs`` is the body); replies go
# to ``reply_to`` tagged with ``correlation_id``.
@staticmethod
def request(
channel: pika.channel.Channel,
reply_to: str,
exchange: str,
routing_key: str,
correlation_id: str,
**kwargs,
) -> None:
channel.basic_publish(
exchange=exchange,
routing_key=routing_key,
properties=pika.BasicProperties(
content_type="application/msgpack",
correlation_id=correlation_id,
reply_to=reply_to,
),
body=encode_data(
kwargs,
extra_encoders=ProvenanceStorageRabbitMQClient.extra_type_encoders,
),
)
# Sum the ack counts of successive responses until acks_expected is
# reached. The overall timeout scales with the number of batches and is never
# below wait_min. Returns False on timeout.
def wait_for_acks(self, meth_name: str, acks_expected: int) -> bool:
acks_received = 0
timeout = max(
(acks_expected / self._batch_size) * self._wait_per_batch,
self._wait_min,
)
start = time.monotonic()
end = start + timeout
while acks_received < acks_expected:
local_timeout = end - time.monotonic()
# Always allow at least 1 s per round, even past the deadline.
if local_timeout < 1.0:
local_timeout = 1.0
try:
acks_received += self.wait_for_response(timeout=local_timeout)
except ResponseTimeout:
LOGGER.warning(
"Timed out waiting for acks in %s, %s received, %s expected (in %ss)",
meth_name,
acks_received,
acks_expected,
time.monotonic() - start,
)
return False
return acks_received == acks_expected
# Return the next response for the current correlation id, silently
# dropping responses to older requests. Raises ResponseTimeout when the
# queue stays empty. NOTE(review): the 1 s floor means the total wait can
# exceed ``timeout`` while stale responses keep arriving.
def wait_for_response(self, timeout: float = 120.0) -> Any:
start = time.monotonic()
end = start + timeout
while True:
try:
local_timeout = end - time.monotonic()
if local_timeout < 1.0:
local_timeout = 1.0
correlation_id, response = self._response_queue.get(
timeout=local_timeout
)
if correlation_id == self._correlation_id:
return response
except queue.Empty:
raise ResponseTimeout
diff --git a/swh/provenance/storage/rabbitmq/server.py b/swh/provenance/storage/rabbitmq/server.py
index 1e2b072..2dc7052 100644
--- a/swh/provenance/storage/rabbitmq/server.py
+++ b/swh/provenance/storage/rabbitmq/server.py
@@ -1,738 +1,738 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import Counter
from datetime import datetime
from enum import Enum
import functools
import logging
import multiprocessing
import os
import queue
import threading
from typing import Any, Callable
from typing import Counter as TCounter
from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast
import pika
import pika.channel
import pika.connection
import pika.exceptions
from pika.exchange_type import ExchangeType
import pika.frame
import pika.spec
from swh.core import config
from swh.core.api.serializers import encode_data_client as encode_data
from swh.core.api.serializers import msgpack_loads as decode_data
from swh.model.hashutil import hash_to_hex
from swh.model.model import Sha1Git
from swh.provenance.storage.interface import (
DirectoryData,
EntityType,
RelationData,
RelationType,
RevisionData,
)
from swh.provenance.util import path_id
from .serializers import DECODERS, ENCODERS
# Log record format string; not referenced within this visible chunk.
LOG_FORMAT = (
    "%(levelname) -10s %(asctime)s %(name) -30s "
    "%(funcName) -35s %(lineno) -5d: %(message)s"
)
LOGGER = logging.getLogger(__name__)
# Sentinel pushed onto each request queue to stop its request thread.
TERMINATE = object()
class ServerCommand(Enum):
    """Messages exchanged with a worker over its multiprocessing queues.

    ``TERMINATE`` is read from the worker's ``command`` queue; ``CONSUMING``
    is posted on its ``signal`` queue once consumers are started.
    """

    TERMINATE = "terminate"
    CONSUMING = "consuming"


class TerminateSignal(BaseException):
    """Raised in the worker's ioloop to stop it.

    Derives from BaseException, presumably so that ``except Exception``
    handlers do not swallow it.
    """
def resolve_dates(dates: Iterable[Tuple[Sha1Git, datetime]]) -> Dict[Sha1Git, datetime]:
    """Merge ``(sha1, date)`` pairs, keeping the earliest date of each sha1.

    Keys appear in the order of their first occurrence in ``dates``.
    """
    earliest: Dict[Sha1Git, datetime] = {}
    for key, when in dates:
        best = earliest.get(key, when)
        earliest[key] = when if when < best else best
    return earliest
def resolve_directory(
    data: Iterable[Tuple[Sha1Git, DirectoryData]]
) -> Dict[Sha1Git, DirectoryData]:
    """Merge directory entries per sha1.

    The earliest date wins, and a truthy ``flat`` from any entry overrides
    the stored one.
    """
    merged: Dict[Sha1Git, DirectoryData] = {}
    for sha1, incoming in data:
        known = merged.setdefault(sha1, incoming)
        earlier = incoming.date < known.date
        if not earlier and not incoming.flat:
            continue
        candidate = DirectoryData(
            date=incoming.date if earlier else known.date,
            flat=incoming.flat if incoming.flat else known.flat,
        )
        if candidate != known:
            merged[sha1] = candidate
    return merged
def resolve_revision(
    data: Iterable[Union[Tuple[Sha1Git, RevisionData], Tuple[Sha1Git]]]
) -> Dict[Sha1Git, RevisionData]:
    """Merge revision rows per sha1.

    A row is either ``(sha1,)``, which only registers the revision with an
    empty ``RevisionData``, or ``(sha1, RevisionData)``. The earliest non-None
    date wins, and a non-None origin replaces the stored one.
    """
    merged: Dict[Sha1Git, RevisionData] = {}
    for row in data:
        sha1 = row[0]
        known = merged.setdefault(sha1, RevisionData(date=None, origin=None))
        if len(row) <= 1:
            continue
        incoming = cast(Tuple[Sha1Git, RevisionData], row)[1]
        date = known.date
        if incoming.date is not None and (date is None or incoming.date < date):
            date = incoming.date
        origin = known.origin if incoming.origin is None else incoming.origin
        candidate = RevisionData(date=date, origin=origin)
        if candidate != known:
            merged[sha1] = candidate
    return merged
def resolve_relation(
    data: Iterable[Tuple[Sha1Git, Sha1Git, bytes]]
) -> Dict[Sha1Git, Set[RelationData]]:
    """Group ``(src, dst, path)`` triples into ``src -> {RelationData}``."""
    grouped: Dict[Sha1Git, Set[RelationData]] = {}
    for src, dst, path in data:
        if src not in grouped:
            grouped[src] = set()
        grouped[src].add(RelationData(dst=dst, path=path))
    return grouped
class ProvenanceStorageRabbitMQWorker(multiprocessing.Process):
EXCHANGE_TYPE = ExchangeType.direct
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
def __init__(
    self,
    url: str,
    exchange: str,
    range: int,
    storage_config: Dict[str, Any],
    batch_size: int = 100,
    prefetch_count: int = 100,
) -> None:
    """Set up a worker process serving one ID range of one exchange.

    :param str url: The URL for connecting to RabbitMQ
    :param str exchange: The name of the RabbitMQ exchange to use
    :param int range: The ID range to operate on (also used, in hex, in the
        process name)
    :param dict storage_config: Configuration parameters for the underlying
        ``ProvenanceStorage`` object expected by
        ``swh.provenance.get_provenance_storage``
    :param int batch_size: Max amount of elements per call to the underlying
        storage
    :param int prefetch_count: Prefetch value for the RabbitMQ connection when
        receiving messages
    """
    super().__init__(name=f"{exchange}_{range:x}")
    # Connection state, (re)built from run().
    self._connection = None
    self._channel = None
    self._closing = False
    self._url = url
    self._prefetch_count = prefetch_count
    # Per-binding-key queue and consumer bookkeeping.
    self._exchange = exchange
    self._binding_keys = list(
        ProvenanceStorageRabbitMQServer.get_binding_keys(self._exchange, range)
    )
    self._consumer_tag: Dict[str, str] = {}
    self._consuming: Dict[str, bool] = {}
    self._queues: Dict[str, str] = {}
    # Storage settings used by the request threads.
    self._storage_config = storage_config
    self._batch_size = batch_size
    # Queues for talking to the process that created this worker.
    self.command: multiprocessing.Queue = multiprocessing.Queue()
    self.signal: multiprocessing.Queue = multiprocessing.Queue()
def connect(self) -> pika.SelectConnection:
    """Open an asynchronous connection whose events drive the on_* callbacks."""
    LOGGER.info("Connecting to %s", self._url)
    params = pika.URLParameters(self._url)
    return pika.SelectConnection(
        parameters=params,
        on_open_callback=self.on_connection_open,
        on_open_error_callback=self.on_connection_open_error,
        on_close_callback=self.on_connection_closed,
    )
def close_connection(self) -> None:
    """Mark every consumer as stopped and close the connection if still open."""
    assert self._connection is not None
    self._consuming = dict.fromkeys(self._binding_keys, False)
    connection = self._connection
    if not (connection.is_closing or connection.is_closed):
        LOGGER.info("Closing connection")
        connection.close()
    else:
        LOGGER.info("Connection is closing or already closed")
def on_connection_open(self, _unused_connection: pika.SelectConnection) -> None:
    """Connection established: continue the setup chain with a channel."""
    LOGGER.info("Connection opened")
    self.open_channel()
def on_connection_open_error(
    self, _unused_connection: pika.SelectConnection, err: Exception
) -> None:
    """Stop the ioloop in 5 seconds so that run() tries to reconnect."""
    LOGGER.error("Connection open failed, reopening in 5 seconds: %s", err)
    assert self._connection is not None
    ioloop = self._connection.ioloop
    ioloop.call_later(5, ioloop.stop)
def on_connection_closed(self, _unused_connection: pika.SelectConnection, reason):
    """Drop the channel; stop the ioloop now if closing, else in 5 s to reconnect."""
    assert self._connection is not None
    self._channel = None
    ioloop = self._connection.ioloop
    if not self._closing:
        LOGGER.warning("Connection closed, reopening in 5 seconds: %s", reason)
        ioloop.call_later(5, ioloop.stop)
        return
    ioloop.stop()
def open_channel(self) -> None:
    """Request a new channel; ``on_channel_open`` runs once it is ready."""
    LOGGER.info("Creating a new channel")
    connection = self._connection
    assert connection is not None
    connection.channel(on_open_callback=self.on_channel_open)
def on_channel_open(self, channel: pika.channel.Channel) -> None:
    """Keep the channel, watch for its closure, then declare the exchange."""
    LOGGER.info("Channel opened")
    self._channel = channel
    LOGGER.info("Adding channel close callback")
    assert channel is not None
    channel.add_on_close_callback(callback=self.on_channel_closed)
    self.setup_exchange()
def on_channel_closed(
    self, channel: pika.channel.Channel, reason: Exception
) -> None:
    """The only channel is gone, so tear down the connection as well."""
    LOGGER.warning("Channel %i was closed: %s", channel, reason)
    self.close_connection()
def setup_exchange(self) -> None:
    """Declare the exchange; the chain continues in ``on_exchange_declare_ok``."""
    LOGGER.info("Declaring exchange %s", self._exchange)
    channel = self._channel
    assert channel is not None
    channel.exchange_declare(
        exchange=self._exchange,
        exchange_type=self.EXCHANGE_TYPE,
        callback=self.on_exchange_declare_ok,
    )
def on_exchange_declare_ok(self, _unused_frame: pika.frame.Method) -> None:
    """Exchange ready: declare the per-binding-key queues."""
    LOGGER.info("Exchange declared: %s", self._exchange)
    self.setup_queues()
def setup_queues(self) -> None:
    """Declare one queue per binding key, named after the key itself."""
    for key in self._binding_keys:
        LOGGER.info("Declaring queue %s", key)
        assert self._channel is not None
        self._channel.queue_declare(
            queue=key,
            callback=functools.partial(self.on_queue_declare_ok, binding_key=key),
        )
def on_queue_declare_ok(self, frame: pika.frame.Method, binding_key: str) -> None:
    """Remember the declared queue and bind it to the exchange by ``binding_key``."""
    queue_name = frame.method.queue
    LOGGER.info(
        "Binding queue %s to exchange %s with routing key %s",
        queue_name,
        self._exchange,
        binding_key,
    )
    assert self._channel is not None
    self._queues[binding_key] = queue_name
    self._channel.queue_bind(
        queue=queue_name,
        exchange=self._exchange,
        routing_key=binding_key,
        callback=functools.partial(self.on_bind_ok, queue_name=queue_name),
    )
def on_bind_ok(self, _unused_frame: pika.frame.Method, queue_name: str) -> None:
    """A queue is bound: apply the QoS prefetch (runs once per bound queue)."""
    LOGGER.info("Queue bound: %s", queue_name)
    channel = self._channel
    assert channel is not None
    channel.basic_qos(
        prefetch_count=self._prefetch_count, callback=self.on_basic_qos_ok
    )
def on_basic_qos_ok(self, _unused_frame: pika.frame.Method) -> None:
    """QoS acknowledged: start consuming."""
    LOGGER.info("QOS set to: %d", self._prefetch_count)
    self.start_consuming()
def start_consuming(self) -> None:
    """Start a consumer on each binding key's queue that has none yet.

    This is reached once per bound queue (``on_bind_ok`` ->
    ``on_basic_qos_ok``), so it must be idempotent. Keys whose recorded
    consumer tag is still active on the current channel are skipped;
    otherwise every call would stack duplicate consumers on the same queues
    and overwrite their tags, leaving them uncancellable by
    ``stop_consuming``. After a reconnection the new channel has no active
    tags, so every key is consumed again. ``ServerCommand.CONSUMING`` is
    signalled only when new consumers were actually started.
    """
    channel = self._channel
    assert channel is not None
    active_tags = set(channel.consumer_tags)
    pending = [
        binding_key
        for binding_key in self._binding_keys
        if self._consumer_tag.get(binding_key) not in active_tags
    ]
    if not pending:
        return
    LOGGER.info("Issuing consumer related RPC commands")
    LOGGER.info("Adding consumer cancellation callback")
    channel.add_on_cancel_callback(callback=self.on_consumer_cancelled)
    for binding_key in pending:
        self._consumer_tag[binding_key] = channel.basic_consume(
            queue=self._queues[binding_key], on_message_callback=self.on_request
        )
        self._consuming[binding_key] = True
    self.signal.put(ServerCommand.CONSUMING)
def on_consumer_cancelled(self, method_frame: pika.frame.Method) -> None:
    """The broker cancelled a consumer: close the channel (if any) to shut down."""
    LOGGER.info("Consumer was cancelled remotely, shutting down: %r", method_frame)
    channel = self._channel
    if channel:
        channel.close()
def on_request(
    self,
    channel: pika.channel.Channel,
    deliver: pika.spec.Basic.Deliver,
    properties: pika.spec.BasicProperties,
    body: bytes,
) -> None:
    """Fan a received batch out to the request queue of its routing key.

    Each item is queued with ``(correlation_id, reply_to)`` so the request
    thread can send the ack count back to the right client. The AMQP message
    is acknowledged once all of its items are queued.
    """
    LOGGER.info(
        "Received message # %s from %s: %s",
        deliver.delivery_tag,
        properties.app_id,
        body,
    )
    payload = decode_data(data=body, extra_decoders=self.extra_type_decoders)
    reply_info = (properties.correlation_id, properties.reply_to)
    # XXX: decoding yields lists where the client sent tuples, hence tuple().
    for item in payload["data"]:
        self._request_queues[deliver.routing_key].put((tuple(item), reply_info))
    LOGGER.info("Acknowledging message %s", deliver.delivery_tag)
    channel.basic_ack(delivery_tag=deliver.delivery_tag)
def stop_consuming(self) -> None:
    """Ask the broker to cancel every consumer; replies reach ``on_cancel_ok``."""
    if not self._channel:
        return
    LOGGER.info("Sending a Basic.Cancel RPC command to RabbitMQ")
    for key in self._binding_keys:
        self._channel.basic_cancel(
            self._consumer_tag[key],
            callback=functools.partial(self.on_cancel_ok, binding_key=key),
        )
def on_cancel_ok(self, _unused_frame: pika.frame.Method, binding_key: str) -> None:
    """Handle the broker's Basic.CancelOk for the consumer of ``binding_key``.

    The log line reports the cancelled consumer's tag (the original code
    logged the just-reset ``False`` flag). The channel is closed only once no
    consumer of this worker is still active. This keeps an early close from
    cutting short cancellations still in flight for other binding keys, and
    avoids calling ``close()`` again on a channel that is already closing.
    """
    self._consuming[binding_key] = False
    LOGGER.info(
        "RabbitMQ acknowledged the cancellation of the consumer: %s",
        self._consumer_tag.get(binding_key),
    )
    if any(self._consuming.values()):
        return
    LOGGER.info("Closing the channel")
    assert self._channel is not None
    self._channel.close()
# Process entry point: start the command thread and one request thread per
# binding key, then run the pika ioloop, reconnecting until _closing is set.
# On exit, push TERMINATE to every request queue and join all threads.
def run(self) -> None:
self._command_thread = threading.Thread(target=self.run_command_thread)
self._command_thread.start()
# One local queue + consumer thread per binding key; on_request() feeds the
# queues from the ioloop thread.
self._request_queues: Dict[str, queue.Queue] = {}
self._request_threads: Dict[str, threading.Thread] = {}
for binding_key in self._binding_keys:
meth_name, relation = ProvenanceStorageRabbitMQServer.get_meth_name(
binding_key
)
self._request_queues[binding_key] = queue.Queue()
self._request_threads[binding_key] = threading.Thread(
target=self.run_request_thread,
args=(binding_key, meth_name, relation),
)
self._request_threads[binding_key].start()
while not self._closing:
try:
self._connection = self.connect()
assert self._connection is not None
self._connection.ioloop.start()
except KeyboardInterrupt:
LOGGER.info("Connection closed by keyboard interruption, reopening")
if self._connection is not None:
self._connection.ioloop.stop()
except TerminateSignal as ex:
LOGGER.info("Termination requested: %s", ex)
self.stop()
if self._connection is not None and not self._connection.is_closed:
# Finish closing
self._connection.ioloop.start()
except BaseException as ex:
LOGGER.warning("Unexpected exception, terminating: %s", ex)
self.stop()
if self._connection is not None and not self._connection.is_closed:
# Finish closing
self._connection.ioloop.start()
# Wake every request thread with the sentinel, then wait for all helpers.
for binding_key in self._binding_keys:
self._request_queues[binding_key].put(TERMINATE)
for binding_key in self._binding_keys:
self._request_threads[binding_key].join()
self._command_thread.join()
LOGGER.info("Stopped")
def run_command_thread(self) -> None:
    """Serve the commands sent to this worker through ``self.command``.

    ``ServerCommand.TERMINATE`` requests the termination of the worker. Any
    exception raised while handling a command also requests termination, with
    the exception text as reason. Both cases end this thread.

    The command queue is polled with a timeout, and the loop also ends once
    the worker is closing (``self._closing``). This lets ``run()`` join this
    thread even when the shutdown was not triggered by a command, e.g. when a
    request thread requested termination.
    """
    while not self._closing:
        try:
            # Poll instead of blocking forever, so that ``self._closing`` is
            # re-checked regularly.
            command = self.command.get(timeout=1.0)
            if command == ServerCommand.TERMINATE:
                self.request_termination()
                break
        except queue.Empty:
            continue
        except BaseException as ex:
            self.request_termination(str(ex))
            break
def request_termination(self, reason: str = "Normal shutdown") -> None:
    """Ask the worker to terminate, giving ``reason``.

    Schedules, with the I/O loop's ``add_callback_threadsafe``, a callback
    that raises ``TerminateSignal(reason)``. ``run()`` handles that signal by
    stopping the worker. It is called from the command and request threads.
    """
    assert self._connection is not None

    def raise_terminate_signal():
        raise TerminateSignal(reason)

    self._connection.ioloop.add_callback_threadsafe(raise_terminate_signal)
def run_request_thread(
    self, binding_key: str, meth_name: str, relation: Optional[RelationType]
) -> None:
    """Write the requests queued for ``binding_key`` to the underlying storage.

    Batching:
        Elements are taken from ``self._request_queues[binding_key]`` in
        batches of at most ``self._batch_size``. A batch also ends as soon as
        no element arrives within 0.1s. Each element is an
        ``(item, (correlation_id, reply_to))`` pair.

    Writing:
        The items of a batch are merged with the conflict resolution function
        of ``meth_name`` and passed to ``storage.<meth_name>``, preceded by
        ``relation`` when it is not None.

    Replies:
        When the storage call succeeds, a response carrying the number of
        elements it contributed is scheduled for every
        (correlation_id, reply_to) pair. Otherwise the elements are put back
        into the queue.

    Termination:
        The thread ends when it dequeues ``TERMINATE``, after a last attempt
        at writing the elements already collected for the current batch. It
        also ends on any exception, in which case termination of the worker
        is requested.
    """
    from swh.provenance import get_provenance_storage

    with get_provenance_storage(**self._storage_config) as storage:
        request_queue = self._request_queues[binding_key]
        merge_items = ProvenanceStorageRabbitMQWorker.get_conflicts_func(meth_name)
        while True:
            terminate = False
            elements = []
            while True:
                try:
                    # TODO: consider reducing this timeout or removing it
                    elem = request_queue.get(timeout=0.1)
                    if elem is TERMINATE:
                        terminate = True
                        break
                    elements.append(elem)
                except queue.Empty:
                    break
                if len(elements) >= self._batch_size:
                    break
            # Elements collected before TERMINATE are written too, instead of
            # being dropped.
            if elements:
                try:
                    items, props = zip(*elements)
                    acks_count: TCounter[Tuple[str, str]] = Counter(props)
                    data = merge_items(items)
                    args = (relation, data) if relation is not None else (data,)
                    if getattr(storage, meth_name)(*args):
                        for (correlation_id, reply_to), count in acks_count.items():
                            # FIXME: this is running in a different thread! Hence, if
                            # self._connection drops, there is no guarantee that the
                            # response can be sent for the current elements. This
                            # situation should be handled properly.
                            assert self._connection is not None
                            self._connection.ioloop.add_callback_threadsafe(
                                functools.partial(
                                    ProvenanceStorageRabbitMQWorker.respond,
                                    channel=self._channel,
                                    correlation_id=correlation_id,
                                    reply_to=reply_to,
                                    response=count,
                                )
                            )
                    else:
                        LOGGER.warning(
                            "Unable to process elements for queue %s", binding_key
                        )
                        for elem in elements:
                            request_queue.put(elem)
                except BaseException as ex:
                    self.request_termination(str(ex))
                    break
            if terminate:
                break
def stop(self) -> None:
    """Start a graceful shutdown of the worker's RabbitMQ connection.

    Only the first call has an effect: it marks the worker as closing. Then:

    - if a consumer is still active, the consumers are cancelled and the I/O
      loop is run so the cancellation can complete;
    - otherwise the I/O loop is stopped.
    """
    assert self._connection is not None
    if not self._closing:
        self._closing = True
        LOGGER.info("Stopping")
        # ``_consuming`` maps binding keys to a consuming flag: check the
        # flags, not the (always truthy) binding keys.
        if any(self._consuming.values()):
            self.stop_consuming()
            self._connection.ioloop.start()
        else:
            self._connection.ioloop.stop()
        LOGGER.info("Stopped")
@staticmethod
def get_conflicts_func(meth_name: str) -> Callable[[Iterable[Any]], Any]:
"""Return the function merging a batch of items queued for `meth_name`.
Its result is the data passed to the storage method (see merge_items in
run_request_thread). Unknown method names get, with a warning, the
identity function.
"""
if meth_name == "content_add":
return resolve_dates
elif meth_name == "directory_add":
return resolve_directory
elif meth_name == "location_add":
# Items must be key/value pairs for dict(); presumably (path sha1, path),
# matching the {sha1: path} mapping location_add takes in the tests.
- return lambda data: set(data) # just remove duplicates
+ return lambda data: dict(data)
elif meth_name == "origin_add":
return lambda data: dict(data) # last processed value is good enough
elif meth_name == "revision_add":
return resolve_revision
elif meth_name == "relation_add":
return resolve_relation
else:
LOGGER.warning(
"Unexpected conflict resolution function request for method %s",
meth_name,
)
return lambda x: x
@staticmethod
def respond(
    channel: pika.channel.Channel,
    correlation_id: str,
    reply_to: str,
    response: Any,
):
    """Publish ``response`` on ``channel`` to the ``reply_to`` queue.

    The message is sent through the default exchange (``""``) with
    ``reply_to`` as routing key. It carries ``correlation_id`` and the
    ``application/msgpack`` content type. Its body is ``response`` encoded
    with ``encode_data`` and the server's extra type encoders.
    """
    reply_properties = pika.BasicProperties(
        content_type="application/msgpack",
        correlation_id=correlation_id,
    )
    reply_body = encode_data(
        response,
        extra_encoders=ProvenanceStorageRabbitMQServer.extra_type_encoders,
    )
    channel.basic_publish(
        exchange="",
        routing_key=reply_to,
        properties=reply_properties,
        body=reply_body,
    )
class ProvenanceStorageRabbitMQServer:
    """Pool of RabbitMQ workers writing provenance data to a storage.

    One ``ProvenanceStorageRabbitMQWorker`` is created for every
    (exchange, range) pair given by ``get_exchanges`` and ``get_ranges``. The
    static helpers define how exchanges, binding keys and routing keys are
    named.
    """

    extra_type_decoders = DECODERS
    extra_type_encoders = ENCODERS
    queue_count = 16

    def __init__(
        self,
        url: str,
        storage_config: Dict[str, Any],
        batch_size: int = 100,
        prefetch_count: int = 100,
    ) -> None:
        """Create, without starting them, one worker per exchange and range.

        :param str url: URL used to connect to RabbitMQ
        :param dict storage_config: configuration of the underlying
            ``ProvenanceStorage`` object, as expected by
            ``swh.provenance.get_provenance_storage``
        :param int batch_size: maximum number of elements per call to the
            underlying storage
        :param int prefetch_count: prefetch value of the RabbitMQ connection
            when receiving messages
        """
        server = ProvenanceStorageRabbitMQServer
        self._workers: List[ProvenanceStorageRabbitMQWorker] = [
            ProvenanceStorageRabbitMQWorker(
                url=url,
                exchange=exchange,
                range=range_idx,
                storage_config=storage_config,
                batch_size=batch_size,
                prefetch_count=prefetch_count,
            )
            for exchange in server.get_exchanges()
            for range_idx in server.get_ranges(exchange)
        ]
        self._running = False

    def start(self) -> None:
        """Start every worker and wait for each one to report it is consuming.

        Each worker gets up to 60 seconds to report. If one does not, the
        whole server is stopped. Does nothing if the server is already
        running.
        """
        if self._running:
            return
        self._running = True
        for worker in self._workers:
            worker.start()
        for worker in self._workers:
            try:
                signal = worker.signal.get(timeout=60)
                assert signal == ServerCommand.CONSUMING
            except queue.Empty:
                LOGGER.error(
                    "Could not initialize worker %s. Leaving...", worker.name
                )
                self.stop()
                return
        LOGGER.info("Start serving")

    def stop(self) -> None:
        """Send TERMINATE to every worker and wait for all of them to finish.

        Does nothing if the server is not running.
        """
        if not self._running:
            return
        for worker in self._workers:
            worker.command.put(ServerCommand.TERMINATE)
        for worker in self._workers:
            worker.join()
        LOGGER.info("Stop serving")
        self._running = False

    @staticmethod
    def _format_key(
        meth_name: str, relation: Optional[RelationType], idx: int
    ) -> str:
        # Common "<method>.<relation or 'unknown'>.<hex index>" format of
        # binding and routing keys.
        if relation is None:
            assert (
                meth_name != "relation_add"
            ), "'relation_add' requires 'relation' to be provided"
            relation_part = "unknown"
        else:
            assert (
                meth_name == "relation_add"
            ), f"'{meth_name}' requires 'relation' to be None"
            relation_part = relation.value
        return f"{meth_name}.{relation_part}.{idx:x}".lower()

    @staticmethod
    def get_binding_keys(exchange: str, range: int) -> Iterator[str]:
        """Yield the binding key of every method served by ``exchange``."""
        server = ProvenanceStorageRabbitMQServer
        for meth_name, relation in server.get_meth_names(exchange):
            yield server._format_key(meth_name, relation, range)

    @staticmethod
    def get_exchange(meth_name: str, relation: Optional[RelationType] = None) -> str:
        """Return the exchange serving ``meth_name`` (and ``relation``).

        It is the first ``_``-separated part of the relation value for
        ``relation_add``, and of the method name otherwise.
        """
        if meth_name == "relation_add":
            assert (
                relation is not None
            ), "'relation_add' requires 'relation' to be provided"
            name = relation.value
        else:
            assert relation is None, f"'{meth_name}' requires 'relation' to be None"
            name = meth_name
        return name.split("_")[0]

    @staticmethod
    def get_exchanges() -> Iterator[str]:
        """Yield the value of every EntityType, then "location"."""
        for entity in EntityType:
            yield entity.value
        yield "location"

    @staticmethod
    def get_meth_name(
        binding_key: str,
    ) -> Tuple[str, Optional[RelationType]]:
        """Split a binding key into method name and relation.

        The relation is None when the key's relation part is "unknown".
        """
        meth_name, relation_name, *_ = binding_key.split(".")
        if relation_name == "unknown":
            return meth_name, None
        return meth_name, RelationType(relation_name)

    @staticmethod
    def get_meth_names(
        exchange: str,
    ) -> Iterator[Tuple[str, Optional[RelationType]]]:
        """Yield the (method name, relation) pairs served by ``exchange``.

        Nothing is yielded for an unknown exchange.
        """
        methods: List[Tuple[str, Optional[RelationType]]]
        if exchange == EntityType.CONTENT.value:
            methods = [
                ("content_add", None),
                ("relation_add", RelationType.CNT_EARLY_IN_REV),
                ("relation_add", RelationType.CNT_IN_DIR),
            ]
        elif exchange == EntityType.DIRECTORY.value:
            methods = [
                ("directory_add", None),
                ("relation_add", RelationType.DIR_IN_REV),
            ]
        elif exchange == EntityType.ORIGIN.value:
            methods = [("origin_add", None)]
        elif exchange == EntityType.REVISION.value:
            methods = [
                ("revision_add", None),
                ("relation_add", RelationType.REV_BEFORE_REV),
                ("relation_add", RelationType.REV_IN_ORG),
            ]
        elif exchange == "location":
            methods = [("location_add", None)]
        else:
            methods = []
        yield from methods

    @staticmethod
    def get_ranges(unused_exchange: str) -> Iterator[int]:
        """Yield the range indices 0 .. queue_count - 1 (same for every exchange)."""
        # XXX: we might want to have a different range per exchange
        for idx in range(ProvenanceStorageRabbitMQServer.queue_count):
            yield idx

    @staticmethod
    def get_routing_key(
        item: bytes, meth_name: str, relation: Optional[RelationType] = None
    ) -> str:
        """Return the routing key for ``item`` sent to ``meth_name``.

        The range index is the first hex digit of the item's identifier, modulo
        ``queue_count``. The identifier is ``path_id(item)`` for location
        methods, and the item itself otherwise.
        """
        server = ProvenanceStorageRabbitMQServer
        if meth_name.startswith("location"):
            hashid = path_id(item).hex()
        else:
            hashid = hash_to_hex(item)
        idx = int(hashid[0], 16) % server.queue_count
        return server._format_key(meth_name, relation, idx)

    @staticmethod
    def is_write_method(meth_name: str) -> bool:
        """Tell whether ``meth_name`` is a write (``*_add``) method."""
        return "_add" in meth_name
def load_and_check_config(
    config_path: Optional[str], type: str = "local"
) -> Dict[str, Any]:
    """Load a configuration file and check it can be used to run the server.

    Args:
        config_path (str): path of the configuration file to load
        type (str): configuration type. For the default "local" type, the
            storage configuration must also use the "postgresql" class and
            have a non-empty "db" entry.

    Raises:
        EnvironmentError: if ``config_path`` is None
        FileNotFoundError: if ``config_path`` does not exist
        KeyError: if "provenance", "provenance.rabbitmq" or
            "provenance.rabbitmq.storage_config" is missing, or (for "local")
            if "db" is missing or empty
        ValueError: for "local", if the storage class is not "postgresql"

    Returns:
        the whole configuration, as a dict
    """
    if config_path is None:
        raise EnvironmentError("Configuration file must be defined")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file {config_path} does not exist")
    cfg = config.read(config_path)

    # Descend provenance -> rabbitmq -> storage_config, reporting the dotted
    # path of the first missing section.
    section: Dict[str, Any] = cfg
    visited: List[str] = []
    for key in ("provenance", "rabbitmq", "storage_config"):
        visited.append(key)
        child: Optional[Dict[str, Any]] = section.get(key)
        if child is None:
            dotted = ".".join(visited)
            raise KeyError(f"Missing '{dotted}' configuration")
        section = child
    storage_cfg = section

    if type == "local":
        if storage_cfg.get("cls") != "postgresql":
            raise ValueError(
                "The provenance backend can only be started with a 'postgresql' "
                "configuration"
            )
        if not storage_cfg.get("db"):
            raise KeyError("Invalid configuration; missing 'db' config entry")
    return cfg
def make_server_from_configfile() -> ProvenanceStorageRabbitMQServer:
    """Build the server from the file named by $SWH_CONFIG_FILENAME.

    The file is validated with ``load_and_check_config``. Its
    ``provenance.rabbitmq`` section provides the server's keyword arguments.
    """
    server_cfg = load_and_check_config(os.environ.get("SWH_CONFIG_FILENAME"))
    rabbitmq_cfg = server_cfg["provenance"]["rabbitmq"]
    return ProvenanceStorageRabbitMQServer(**rabbitmq_cfg)
diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py
index 60c5126..a76f922 100644
--- a/swh/provenance/tests/test_provenance_storage.py
+++ b/swh/provenance/tests/test_provenance_storage.py
@@ -1,521 +1,531 @@
# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime, timezone
+import hashlib
import inspect
import os
from typing import Any, Dict, Iterable, Optional, Set, Tuple
import pytest
from swh.model.hashutil import hash_to_bytes
from swh.model.model import Origin, Sha1Git
from swh.provenance.algos.origin import origin_add
from swh.provenance.algos.revision import revision_add
from swh.provenance.archive import ArchiveInterface
from swh.provenance.interface import ProvenanceInterface
from swh.provenance.model import OriginEntry, RevisionEntry
from swh.provenance.provenance import Provenance
from swh.provenance.storage.interface import (
DirectoryData,
EntityType,
ProvenanceResult,
ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
)
from swh.provenance.tests.conftest import fill_storage, load_repo_data, ts2dt
class TestProvenanceStorage:
"""Tests run against each `ProvenanceStorageInterface` implementation given
by the `provenance_storage` fixture."""
def test_provenance_storage_content(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests content methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Add all content present in the current repo to the storage, just assigning their
# creation dates. Then check that the returned results when querying are the same.
cnt_dates = {
cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"])
}
assert provenance_storage.content_add(cnt_dates)
assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates
assert provenance_storage.entity_get_all(EntityType.CONTENT) == set(
cnt_dates.keys()
)
def test_provenance_storage_directory(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests directory methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Of all directories present in the current repo, only assign a date to those
# containing blobs (picking the max date among the available ones). Then check that
# the returned results when querying are the same.
def getmaxdate(
directory: Dict[str, Any], contents: Iterable[Dict[str, Any]]
) -> Optional[datetime]:
dates = [
content["ctime"]
for entry in directory["entries"]
for content in contents
if entry["type"] == "file" and entry["target"] == content["sha1_git"]
]
return max(dates) if dates else None
flat_values = (False, True)
dir_dates = {}
for idx, dir in enumerate(data["directory"]):
date = getmaxdate(dir, data["content"])
if date is not None:
dir_dates[dir["id"]] = DirectoryData(
date=date, flat=flat_values[idx % 2]
)
assert provenance_storage.directory_add(dir_dates)
assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates
assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == set(
dir_dates.keys()
)
def test_provenance_storage_location(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests location methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Add all names of entries present in the directories of the current repo as paths
# to the storage. Then check that the returned results when querying are the same.
# location_add takes a mapping from the SHA1 of each path to the path itself.
- paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]}
+ paths = {
+ hashlib.sha1(entry["name"]).digest(): entry["name"]
+ for dir in data["directory"]
+ for entry in dir["entries"]
+ }
assert provenance_storage.location_add(paths)
# Storages without path support are expected to return an empty (falsy) result.
if provenance_storage.with_path():
assert provenance_storage.location_get_all() == paths
else:
- assert provenance_storage.location_get_all() == set()
+ assert not provenance_storage.location_get_all()
@pytest.mark.origin_layer
def test_provenance_storage_origin(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests origin methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Test origin methods.
# Add all origins present in the current repo to the storage. Then check that the
# returned results when querying are the same.
orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]}
assert orgs
assert provenance_storage.origin_add(orgs)
assert provenance_storage.origin_get(set(orgs.keys())) == orgs
assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys())
def test_provenance_storage_revision(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests revision methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Test revision methods.
# Add all revisions present in the current repo to the storage, assigning their
# dates and an arbitrary origin to each one. Then check that the returned results
# when querying are the same.
origin = Origin(url=next(iter(data["origin"]))["url"])
# Origin must be inserted in advance.
assert provenance_storage.origin_add({origin.id: origin.url})
# Revisions get a date at odd indices and an origin at indices not divisible by
# 3, so those at multiples of 6 get neither.
- revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0}
+ revs = {rev["id"] for idx, rev in enumerate(data["revision"])}
rev_data = {
rev["id"]: RevisionData(
date=ts2dt(rev["date"]) if idx % 2 != 0 else None,
origin=origin.id if idx % 3 != 0 else None,
)
for idx, rev in enumerate(data["revision"])
- if idx % 6 != 0
}
assert revs
- assert provenance_storage.revision_add(revs)
assert provenance_storage.revision_add(rev_data)
# revision_get is expected to leave out revisions with neither date nor origin,
# while entity_get_all still lists every added revision.
- assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data
- assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set(
- rev_data.keys()
- )
+ assert provenance_storage.revision_get(set(rev_data.keys())) == {
+ k: v
+ for (k, v) in rev_data.items()
+ if v.date is not None or v.origin is not None
+ }
+ assert provenance_storage.entity_get_all(EntityType.REVISION) == set(rev_data)
def test_provenance_storage_relation_revision_layer(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests relation methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Test content-in-revision relation.
# Create flat models of every root directory for the revisions in the dataset.
cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
for rev in data["revision"]:
root = next(
subdir
for subdir in data["directory"]
if subdir["id"] == rev["directory"]
)
for cnt, rel in dircontent(data, rev["id"], root):
cnt_in_rev.setdefault(cnt, set()).add(rel)
relation_add_and_compare_result(
provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev
)
# Test content-in-directory relation.
# Create flat models for every directory in the dataset.
cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {}
for dir in data["directory"]:
for cnt, rel in dircontent(data, dir["id"], dir):
cnt_in_dir.setdefault(cnt, set()).add(rel)
relation_add_and_compare_result(
provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir
)
# Test content-in-directory relation.
# Add root directories to their correspondent revision in the dataset.
dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
for rev in data["revision"]:
dir_in_rev.setdefault(rev["directory"], set()).add(
RelationData(dst=rev["id"], path=b".")
)
relation_add_and_compare_result(
provenance_storage, RelationType.DIR_IN_REV, dir_in_rev
)
@pytest.mark.origin_layer
def test_provenance_storage_relation_orign_layer(
self,
provenance_storage: ProvenanceStorageInterface,
) -> None:
"""Tests relation methods for every `ProvenanceStorageInterface` implementation."""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
# Test revision-in-origin relation.
# Origins must be inserted in advance (cannot be done by `entity_add` inside
# `relation_add_and_compare_result`).
orgs = {Origin(url=org["url"]).id: org["url"] for org in data["origin"]}
assert provenance_storage.origin_add(orgs)
# Add all revisions that are head of some snapshot branch to the corresponding
# origin.
rev_in_org: Dict[Sha1Git, Set[RelationData]] = {}
for status in data["origin_visit_status"]:
if status["snapshot"] is not None:
for snapshot in data["snapshot"]:
if snapshot["id"] == status["snapshot"]:
for branch in snapshot["branches"].values():
if branch["target_type"] == "revision":
rev_in_org.setdefault(branch["target"], set()).add(
RelationData(
dst=Origin(url=status["origin"]).id,
path=None,
)
)
relation_add_and_compare_result(
provenance_storage, RelationType.REV_IN_ORG, rev_in_org
)
# Test revision-before-revision relation.
# For each revision in the data set add an entry for each parent to the relation.
rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {}
for rev in data["revision"]:
for parent in rev["parents"]:
rev_before_rev.setdefault(parent, set()).add(
RelationData(dst=rev["id"], path=None)
)
relation_add_and_compare_result(
provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev
)
def test_provenance_storage_find_revision_layer(
self,
provenance: ProvenanceInterface,
provenance_storage: ProvenanceStorageInterface,
archive: ArchiveInterface,
) -> None:
"""Tests `content_find_first` and `content_find_all` methods for every
`ProvenanceStorageInterface` implementation.
"""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
fill_storage(archive.storage, data)
# Test content_find_first and content_find_all, first only executing the
# revision-content algorithm, then adding the origin-revision layer.
def adapt_result(
result: Optional[ProvenanceResult], with_path: bool
) -> Optional[ProvenanceResult]:
if result is not None:
return ProvenanceResult(
result.content,
result.revision,
result.date,
result.origin,
result.path if with_path else b"",
)
return result
# Execute the revision-content algorithm on both storages.
revisions = [
RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"])
for rev in data["revision"]
]
revision_add(provenance, archive, revisions)
revision_add(Provenance(provenance_storage), archive, revisions)
assert adapt_result(
ProvenanceResult(
content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
date=datetime.fromtimestamp(1000000000.0, timezone.utc),
origin=None,
path=b"A/B/C/a",
),
provenance_storage.with_path(),
) == provenance_storage.content_find_first(
hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
)
for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
assert adapt_result(
provenance.storage.content_find_first(cnt),
provenance_storage.with_path(),
) == provenance_storage.content_find_first(cnt)
assert {
adapt_result(occur, provenance_storage.with_path())
for occur in provenance.storage.content_find_all(cnt)
} == set(provenance_storage.content_find_all(cnt))
@pytest.mark.origin_layer
def test_provenance_storage_find_origin_layer(
self,
provenance: ProvenanceInterface,
provenance_storage: ProvenanceStorageInterface,
archive: ArchiveInterface,
) -> None:
"""Tests `content_find_first` and `content_find_all` methods for every
`ProvenanceStorageInterface` implementation.
"""
# Read data/README.md for more details on how these datasets are generated.
data = load_repo_data("cmdbts2")
fill_storage(archive.storage, data)
# Execute the revision-content algorithm on both storages.
revisions = [
RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"])
for rev in data["revision"]
]
revision_add(provenance, archive, revisions)
revision_add(Provenance(provenance_storage), archive, revisions)
# Test content_find_first and content_find_all, first only executing the
# revision-content algorithm, then adding the origin-revision layer.
def adapt_result(
result: Optional[ProvenanceResult], with_path: bool
) -> Optional[ProvenanceResult]:
if result is not None:
return ProvenanceResult(
result.content,
result.revision,
result.date,
result.origin,
result.path if with_path else b"",
)
return result
# Execute the origin-revision algorithm on both storages.
origins = [
OriginEntry(url=sta["origin"], snapshot=sta["snapshot"])
for sta in data["origin_visit_status"]
if sta["snapshot"] is not None
]
origin_add(provenance, archive, origins)
origin_add(Provenance(provenance_storage), archive, origins)
assert adapt_result(
ProvenanceResult(
content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
date=datetime.fromtimestamp(1000000000.0, timezone.utc),
origin="https://cmdbts2",
path=b"A/B/C/a",
),
provenance_storage.with_path(),
) == provenance_storage.content_find_first(
hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
)
for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
assert adapt_result(
provenance.storage.content_find_first(cnt),
provenance_storage.with_path(),
) == provenance_storage.content_find_first(cnt)
assert {
adapt_result(occur, provenance_storage.with_path())
for occur in provenance.storage.content_find_all(cnt)
} == set(provenance_storage.content_find_all(cnt))
def test_types(self, provenance_storage: ProvenanceStorageInterface) -> None:
"""Checks all methods of ProvenanceStorageInterface are implemented by this
backend, and that they have the same signature."""
# Create an instance of the protocol (which cannot be instantiated
# directly, so this creates a subclass, then instantiates it)
interface = type("_", (ProvenanceStorageInterface,), {})()
assert "content_find_first" in dir(interface)
missing_methods = []
for meth_name in dir(interface):
if meth_name.startswith("_"):
continue
interface_meth = getattr(interface, meth_name)
try:
concrete_meth = getattr(provenance_storage, meth_name)
except AttributeError:
if not getattr(interface_meth, "deprecated_endpoint", False):
# The backend is missing a (non-deprecated) endpoint
missing_methods.append(meth_name)
continue
expected_signature = inspect.signature(interface_meth)
actual_signature = inspect.signature(concrete_meth)
assert expected_signature == actual_signature, meth_name
assert missing_methods == []
# If all the assertions above succeed, then this one should too.
# But there's no harm in double-checking.
# And we could replace the assertions above by this one, but unlike
# the assertions above, it doesn't explain what is missing.
assert isinstance(provenance_storage, ProvenanceStorageInterface)
def dircontent(
    data: Dict[str, Any],
    ref: Sha1Git,
    dir: Dict[str, Any],
    prefix: bytes = b"",
) -> Iterable[Tuple[Sha1Git, RelationData]]:
    """Collect every file reachable from ``dir`` as a relation towards ``ref``.

    Returns a set of ``(file target, RelationData(dst=ref, path=...))`` pairs.
    ``path`` is the file's path below ``dir``, joined to ``prefix``.
    Sub-directories are looked up by id in ``data["directory"]`` and walked
    recursively.
    """
    found: Set[Tuple[Sha1Git, RelationData]] = set()
    for entry in dir["entries"]:
        if entry["type"] == "file":
            file_path = os.path.join(prefix, entry["name"])
            found.add((entry["target"], RelationData(dst=ref, path=file_path)))
        elif entry["type"] == "dir":
            subdir = next(
                candidate
                for candidate in data["directory"]
                if candidate["id"] == entry["target"]
            )
            found.update(
                dircontent(data, ref, subdir, os.path.join(prefix, entry["name"]))
            )
    return found
def entity_add(
    storage: ProvenanceStorageInterface, entity: EntityType, ids: Set[Sha1Git]
) -> bool:
    """Insert placeholder entries for ``ids`` in ``storage`` for ``entity``.

    Placeholders per entity type:

    - contents get the current UTC time as date;
    - directories get the current UTC time as date and are marked non-flat;
    - revisions get neither date nor origin.

    Other entity types are rejected because inserting them needs more data
    (e.g. origins need their URL).

    Returns the result of the storage ``*_add`` call.
    Raises ValueError for an unsupported entity type.
    """
    now = datetime.now(tz=timezone.utc)
    if entity == EntityType.CONTENT:
        return storage.content_add({sha1: now for sha1 in ids})
    if entity == EntityType.DIRECTORY:
        return storage.directory_add(
            {sha1: DirectoryData(date=now, flat=False) for sha1 in ids}
        )
    if entity == EntityType.REVISION:
        return storage.revision_add(
            {sha1: RevisionData(date=None, origin=None) for sha1 in ids}
        )
    # Do not silently insert revisions for, e.g., origin ids.
    raise ValueError(f"Unsupported entity type: {entity}")
def relation_add_and_compare_result(
storage: ProvenanceStorageInterface,
relation: RelationType,
data: Dict[Sha1Git, Set[RelationData]],
) -> None:
"""Add `data` to `relation` in `storage` and check it can be read back.
Sources and destinations (except origins) are inserted first with
placeholder values, and so are the paths when the storage supports them.
The result is then checked with relation_get from every source,
relation_get(reverse=True) from every destination, and relation_get_all.
"""
# Source, destinations and locations must be added in advance.
src, *_, dst = relation.value.split("_")
srcs = {sha1 for sha1 in data}
if src != "origin":
assert entity_add(storage, EntityType(src), srcs)
dsts = {rel.dst for rels in data.values() for rel in rels}
if dst != "origin":
assert entity_add(storage, EntityType(dst), dsts)
if storage.with_path():
# location_add takes a mapping from the SHA1 of each path to the path itself.
assert storage.location_add(
- {rel.path for rels in data.values() for rel in rels if rel.path is not None}
+ {
+ hashlib.sha1(rel.path).digest(): rel.path
+ for rels in data.values()
+ for rel in rels
+ if rel.path is not None
+ }
)
assert data
assert storage.relation_add(relation, data)
for src_sha1 in srcs:
relation_compare_result(
storage.relation_get(relation, [src_sha1]),
{src_sha1: data[src_sha1]},
storage.with_path(),
)
for dst_sha1 in dsts:
relation_compare_result(
storage.relation_get(relation, [dst_sha1], reverse=True),
{
src_sha1: {
RelationData(dst=dst_sha1, path=rel.path)
for rel in rels
if dst_sha1 == rel.dst
}
for src_sha1, rels in data.items()
if dst_sha1 in {rel.dst for rel in rels}
},
storage.with_path(),
)
relation_compare_result(
storage.relation_get_all(relation), data, storage.with_path()
)
def relation_compare_result(
    computed: Dict[Sha1Git, Set[RelationData]],
    expected: Dict[Sha1Git, Set[RelationData]],
    with_path: bool,
) -> None:
    """Assert that ``computed`` matches ``expected``.

    When ``with_path`` is false, the paths of ``expected`` are replaced by
    None before comparing. Storages without path support are expected to
    return no paths.
    """
    normalized: Dict[Sha1Git, Set[RelationData]] = {}
    for src_sha1, rels in expected.items():
        normalized[src_sha1] = {
            RelationData(dst=rel.dst, path=rel.path if with_path else None)
            for rel in rels
        }
    assert normalized == computed
diff --git a/swh/provenance/tests/test_revision_content_layer.py b/swh/provenance/tests/test_revision_content_layer.py
index abdf3b5..a3592d0 100644
--- a/swh/provenance/tests/test_revision_content_layer.py
+++ b/swh/provenance/tests/test_revision_content_layer.py
@@ -1,482 +1,482 @@
# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import re
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple
import pytest
from typing_extensions import TypedDict
from swh.model.hashutil import hash_to_bytes
from swh.model.model import Sha1Git
from swh.provenance.algos.directory import directory_add
from swh.provenance.algos.revision import revision_add
from swh.provenance.archive import ArchiveInterface
from swh.provenance.interface import ProvenanceInterface
from swh.provenance.model import DirectoryEntry, RevisionEntry
from swh.provenance.storage.interface import EntityType, RelationType
from swh.provenance.tests.conftest import (
fill_storage,
get_datafile,
load_repo_data,
ts2dt,
)
# Relation parsed from a synthetic file:
#   prefix: for D-C relations, path of the preceding R-D relation (or None)
#   path:   location of the relation
#   src:    sha1 of the source of the relation
#   dst:    sha1 of the destination of the relation
#   rel_ts: timestamp of the destination (related to the revision timestamp)
SynthRelation = TypedDict(
    "SynthRelation",
    {
        "prefix": Optional[str],
        "path": str,
        "src": Sha1Git,
        "dst": Sha1Git,
        "rel_ts": float,
    },
)
# Revision parsed from a synthetic file:
#   sha1: sha1 of the revision
#   date: timestamp of the revision
#   msg:  name of the revision in the synthetic file (e.g. "R00")
#   R_C / R_D / D_C: relations of each kind introduced by the revision
SynthRevision = TypedDict(
    "SynthRevision",
    {
        "sha1": Sha1Git,
        "date": float,
        "msg": str,
        "R_C": List[SynthRelation],
        "R_D": List[SynthRelation],
        "D_C": List[SynthRelation],
    },
)
def synthetic_revision_content_result(filename: str) -> Iterator[SynthRevision]:
    """Generate the synthetic revisions described in a data file.

    ``filename`` is resolved with ``get_datafile`` and read as UTF-8 text.
    Yields one SynthRevision (typed dict) per revision, with:
        "sha1": (Sha1Git) sha1 of the revision,
        "date": (float) timestamp of the revision,
        "msg": (str) name of the revision in the synthetic file (e.g. "R00"),
        "R_C": (list) new R---C relations added by this revision
        "R_D": (list) new R-D relations added by this revision
        "D_C": (list) new D-C relations added by this revision

    Each relation above is a SynthRelation typed dict with:
        "prefix": (str or None) for D-C relations, path of the preceding R-D
                  relation (None if an R---C relation came in between);
                  None for other relations
        "path": (str) location
        "src": (Sha1Git) sha1 of the source of the relation
        "dst": (Sha1Git) sha1 of the destination of the relation
        "rel_ts": (float) timestamp of the target of the relation
                  (related to the timestamp of the revision)
    """
    # Explicit encoding, so that reading does not depend on the machine locale.
    with open(get_datafile(filename), "r", encoding="utf-8") as fobj:
        yield from _parse_synthetic_revision_content_file(fobj)
def _parse_synthetic_revision_content_file(
    fobj: Iterable[str],
) -> Iterator[SynthRevision]:
    """Parse the lines of a 'synthetic' file into one SynthRevision per revision.

    Lines that do not match the ``|``-separated row format are ignored. A row
    carrying a revision name ("R" followed by 2 to 4 digits) starts a new
    revision. The matching rows that follow it are that revision's relations.
    """
    regs = [
        "(?P<revname>R[0-9]{2,4})?",
        "(?P<reltype>[^| ]*)",
        "([+] )?(?P<path>[^| +]*?)[/]?",
        "(?P<type>[RDC]) (?P<sha1>[0-9a-f]{40})",
        # The decimal separator must be a literal dot: an unescaped "." would
        # accept any character (e.g. "10x5").
        "(?P<ts>-?[0-9]+([.][0-9]+)?)",
    ]
    regex = re.compile("^ *" + r" *[|] *".join(regs) + r" *(#.*)?$")
    current_rev: List[dict] = []
    for m in (regex.match(line) for line in fobj):
        if m:
            d = m.groupdict()
            if d["revname"]:
                if current_rev:
                    yield _mk_synth_rev(current_rev)
                current_rev.clear()
            current_rev.append(d)
    if current_rev:
        yield _mk_synth_rev(current_rev)
def _mk_synth_rev(synth_rev: List[Dict[str, str]]) -> SynthRevision:
    """Build a SynthRevision from the parsed rows of a single revision.

    The first row describes the revision itself; its type must be "R". Each
    following "R---C", "R-D" or "D-C" row is checked against the expected
    target type and appended to the matching relation list. Rows with other
    relation types are ignored.

    A D-C relation takes as source the destination of the last R-D relation,
    and as prefix that relation's path. The prefix is reset to None by any
    R---C relation.
    """
    head = synth_rev[0]
    assert head["type"] == "R"
    rev = SynthRevision(
        sha1=hash_to_bytes(head["sha1"]),
        date=float(head["ts"]),
        msg=head["revname"],
        R_C=[],
        R_D=[],
        D_C=[],
    )
    # relation type -> (expected target type, list receiving the relation)
    targets: Dict[str, Tuple[str, List[SynthRelation]]] = {
        "R---C": ("C", rev["R_C"]),
        "R-D": ("D", rev["R_D"]),
        "D-C": ("C", rev["D_C"]),
    }
    dir_path: Optional[str] = None
    for row in synth_rev[1:]:
        reltype = row["reltype"]
        if reltype not in targets:
            continue
        expected_type, relations = targets[reltype]
        assert row["type"] == expected_type
        from_dir = reltype == "D-C"
        relations.append(
            SynthRelation(
                prefix=dir_path if from_dir else None,
                path=row["path"],
                src=rev["R_D"][-1]["dst"] if from_dir else rev["sha1"],
                dst=hash_to_bytes(row["sha1"]),
                rel_ts=float(row["ts"]),
            )
        )
        if reltype == "R-D":
            dir_path = row["path"]
        elif reltype == "R---C":
            dir_path = None
    return rev
@pytest.mark.parametrize(
"repo, lower, mindepth, flatten",
(
("cmdbts2", True, 1, True),
("cmdbts2", True, 1, False),
("cmdbts2", False, 1, True),
("cmdbts2", False, 1, False),
("cmdbts2", True, 2, True),
("cmdbts2", True, 2, False),
("cmdbts2", False, 2, True),
("cmdbts2", False, 2, False),
("out-of-order", True, 1, True),
("out-of-order", True, 1, False),
),
)
def test_revision_content_result(
provenance: ProvenanceInterface,
archive: ArchiveInterface,
repo: str,
lower: bool,
mindepth: int,
flatten: bool,
) -> None:
"""Index the revisions listed in the synthetic file one at a time and check
the provenance storage after each one.

When ``flatten`` is true each revision is added with ``revision_add``
directly. Otherwise it is added with ``flatten=False``, and the directories
it newly created (which must not be flattened yet) are then flattened
explicitly with ``directory_add``.

After every revision, the expected rows accumulated from the synthetic file
are compared with the storage content, for entity tables and relation
tables. Relation timestamps are checked as offsets from the revision date.
"""
# read data/README.md for more details on how these datasets are generated
data = load_repo_data(repo)
fill_storage(archive.storage, data)
syntheticfile = get_datafile(
f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
)
# archived revisions indexed by id, so each synthetic entry can be looked up
revisions = {rev["id"]: rev for rev in data["revision"]}
# expected content of each provenance table, accumulated across revisions
rows: Dict[str, Set[Any]] = {
"content": set(),
"content_in_directory": set(),
"content_in_revision": set(),
"directory": set(),
"directory_in_revision": set(),
"location": set(),
"revision": set(),
}
# relation paths are only expected when the storage backend records paths
def maybe_path(path: str) -> Optional[bytes]:
if provenance.storage.with_path():
return path.encode("utf-8")
return None
for synth_rev in synthetic_revision_content_result(syntheticfile):
revision = revisions[synth_rev["sha1"]]
entry = RevisionEntry(
id=revision["id"],
date=ts2dt(revision["date"]),
root=revision["directory"],
)
if flatten:
revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth)
else:
# snapshot the known directories so the ones created by this
# revision_add call can be identified afterwards
prev_directories = provenance.storage.entity_get_all(EntityType.DIRECTORY)
revision_add(
provenance,
archive,
[entry],
lower=lower,
mindepth=mindepth,
flatten=False,
)
directories = [
DirectoryEntry(id=sha1)
for sha1 in provenance.storage.entity_get_all(
EntityType.DIRECTORY
).difference(prev_directories)
]
# with flatten=False the new directories must not be flattened yet;
# directory_add is then used to flatten them explicitly
for directory in directories:
assert not provenance.directory_already_flattenned(directory)
directory_add(provenance, archive, directories)
# each "entry" in the synth file is one new revision
rows["revision"].add(synth_rev["sha1"])
assert rows["revision"] == provenance.storage.entity_get_all(
EntityType.REVISION
), synth_rev["msg"]
# check the timestamp of the revision
rev_ts = synth_rev["date"]
rev_data = provenance.storage.revision_get([synth_rev["sha1"]])[
synth_rev["sha1"]
]
assert (
rev_data.date is not None and rev_ts == rev_data.date.timestamp()
), synth_rev["msg"]
# this revision might have added new content objects
rows["content"] |= set(x["dst"] for x in synth_rev["R_C"])
rows["content"] |= set(x["dst"] for x in synth_rev["D_C"])
assert rows["content"] == provenance.storage.entity_get_all(
EntityType.CONTENT
), synth_rev["msg"]
# check for R-C (direct) entries
# these are added directly in the content_early_in_rev table
rows["content_in_revision"] |= set(
(x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"]
)
assert rows["content_in_revision"] == {
(src, rel.dst, rel.path)
for src, rels in provenance.storage.relation_get_all(
RelationType.CNT_EARLY_IN_REV
).items()
for rel in rels
}, synth_rev["msg"]
# check timestamps
# (the content date must equal the revision date plus the relation offset)
for rc in synth_rev["R_C"]:
assert (
rev_ts + rc["rel_ts"]
== provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp()
), synth_rev["msg"]
# check directories
# each directory stored in the provenance index is an entry
# in the "directory" table...
rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"])
assert rows["directory"] == provenance.storage.entity_get_all(
EntityType.DIRECTORY
), synth_rev["msg"]
# ... + a number of rows in the "directory_in_rev" table...
# check for R-D entries
rows["directory_in_revision"] |= set(
(x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"]
)
assert rows["directory_in_revision"] == {
(src, rel.dst, rel.path)
for src, rels in provenance.storage.relation_get_all(
RelationType.DIR_IN_REV
).items()
for rel in rels
}, synth_rev["msg"]
# check timestamps
# (also check that, whichever flatten mode was used, every R-D directory
# ends up flattened)
for rd in synth_rev["R_D"]:
dir_data = provenance.storage.directory_get([rd["dst"]])[rd["dst"]]
assert rev_ts + rd["rel_ts"] == dir_data.date.timestamp(), synth_rev["msg"]
assert dir_data.flat, synth_rev["msg"]
# ... + a number of rows in the "content_in_dir" table
# for content of the directory.
# check for D-C entries
rows["content_in_directory"] |= set(
(x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"]
)
assert rows["content_in_directory"] == {
(src, rel.dst, rel.path)
for src, rels in provenance.storage.relation_get_all(
RelationType.CNT_IN_DIR
).items()
for rel in rels
}, synth_rev["msg"]
# check timestamps
for dc in synth_rev["D_C"]:
assert (
rev_ts + dc["rel_ts"]
== provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp()
), synth_rev["msg"]
if provenance.storage.with_path():
# check for location entries
# (every R-C, D-C and R-D path seen so far must be a known location)
rows["location"] |= set(x["path"].encode() for x in synth_rev["R_C"])
rows["location"] |= set(x["path"].encode() for x in synth_rev["D_C"])
rows["location"] |= set(x["path"].encode() for x in synth_rev["R_D"])
- assert rows["location"] == provenance.storage.location_get_all(), synth_rev[
- "msg"
- ]
+ assert rows["location"] == set(
+ provenance.storage.location_get_all().values()
+ ), synth_rev["msg"]
@pytest.mark.parametrize(
    "repo, lower, mindepth",
    (
        ("cmdbts2", True, 1),
        ("cmdbts2", False, 1),
        ("cmdbts2", True, 2),
        ("cmdbts2", False, 2),
        ("out-of-order", True, 1),
    ),
)
@pytest.mark.parametrize("batch", (True, False))
def test_provenance_heuristics_content_find_all(
    provenance: ProvenanceInterface,
    archive: ArchiveInterface,
    repo: str,
    lower: bool,
    mindepth: int,
    batch: bool,
) -> None:
    """Check ``content_find_all`` against the expected synthetic results.

    All revisions of ``repo`` are indexed, either with a single
    ``revision_add`` call (``batch``) or with one call per revision. Then, for
    every content referenced by an R---C or D-C relation of the synthetic
    file, the occurrences reported by ``content_find_all`` must match the
    expected ones. When the storage records paths, their count must match too.
    """
    # read data/README.md for more details on how these datasets are generated
    data = load_repo_data(repo)
    fill_storage(archive.storage, data)
    entries = [
        RevisionEntry(
            id=rev["id"],
            date=ts2dt(rev["date"]),
            root=rev["directory"],
        )
        for rev in data["revision"]
    ]

    def maybe_path(path: str) -> str:
        # occurrences only carry a path when the storage backend records them
        return path if provenance.storage.with_path() else ""

    # either one call covering every revision, or one call per revision
    chunks = [entries] if batch else [[entry] for entry in entries]
    for chunk in chunks:
        revision_add(provenance, archive, chunk, lower=lower, mindepth=mindepth)

    syntheticfile = get_datafile(
        f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
    )
    # content sha1 (hex) -> expected (rev_id, rev_ts, origin, path) tuples
    expected_occurrences: Dict[str, List[Tuple[str, float, Optional[str], str]]] = {}

    def expect(content: bytes, rev_id: str, rev_ts: float, path: str) -> None:
        bucket = expected_occurrences.setdefault(content.hex(), [])
        bucket.append((rev_id, rev_ts, None, maybe_path(path)))

    for synth_rev in synthetic_revision_content_result(syntheticfile):
        rev_id = synth_rev["sha1"].hex()
        rev_ts = synth_rev["date"]
        for rc in synth_rev["R_C"]:
            expect(rc["dst"], rev_id, rev_ts, rc["path"])
        for dc in synth_rev["D_C"]:
            assert dc["prefix"] is not None  # to please mypy
            # D-C paths are relative to the directory of the enclosing R-D
            expect(dc["dst"], rev_id, rev_ts, dc["prefix"] + "/" + dc["path"])

    for content_id, occurrences in expected_occurrences.items():
        expected = [(content_id,) + occurrence for occurrence in occurrences]
        found = [
            (
                hit.content.hex(),
                hit.revision.hex(),
                hit.date.timestamp(),
                hit.origin,
                hit.path.decode(),
            )
            for hit in provenance.content_find_all(hash_to_bytes(content_id))
        ]
        if provenance.storage.with_path():
            # only meaningful with stored paths: without them, a content that
            # appears several times in one revision may be reported only once
            assert len(found) == len(expected)
        assert set(found) == set(expected)
@pytest.mark.parametrize(
    "repo, lower, mindepth",
    (
        ("cmdbts2", True, 1),
        ("cmdbts2", False, 1),
        ("cmdbts2", True, 2),
        ("cmdbts2", False, 2),
        ("out-of-order", True, 1),
    ),
)
@pytest.mark.parametrize("batch", (True, False))
def test_provenance_heuristics_content_find_first(
    provenance: ProvenanceInterface,
    archive: ArchiveInterface,
    repo: str,
    lower: bool,
    mindepth: int,
    batch: bool,
) -> None:
    """Check ``content_find_first`` against the expected synthetic results.

    All revisions of ``repo`` are indexed, either with a single
    ``revision_add`` call (``batch``) or with one call per revision. The
    earliest expected occurrence of each content is derived from the R---C
    relations of the synthetic file and compared with ``content_find_first``.
    """
    # read data/README.md for more details on how these datasets are generated
    data = load_repo_data(repo)
    fill_storage(archive.storage, data)
    revisions = [
        RevisionEntry(
            id=revision["id"],
            date=ts2dt(revision["date"]),
            root=revision["directory"],
        )
        for revision in data["revision"]
    ]
    if batch:
        revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth)
    else:
        for revision in revisions:
            revision_add(
                provenance, archive, [revision], lower=lower, mindepth=mindepth
            )

    syntheticfile = get_datafile(
        f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
    )
    # Maps blob_id (hex) -> (rev_id, rev_ts, [path, ...]) of its earliest
    # occurrence. The path element is a list because a content can be added at
    # several places in a single revision (or in revisions sharing the same
    # timestamp). In that case content_find_first() returns one of those
    # paths, and we have no guarantee which one it will be.
    expected_first: Dict[str, Tuple[str, float, List[str]]] = {}
    for synth_rev in synthetic_revision_content_result(syntheticfile):
        rev_id = synth_rev["sha1"].hex()
        rev_ts = synth_rev["date"]
        for rc in synth_rev["R_C"]:
            sha1 = rc["dst"].hex()
            if sha1 not in expected_first:
                # the first R---C occurrence is expected at the revision date
                assert rc["rel_ts"] == 0
                expected_first[sha1] = (rev_id, rev_ts, [rc["path"]])
            elif rev_ts == expected_first[sha1][1]:
                expected_first[sha1][2].append(rc["path"])
            elif rev_ts < expected_first[sha1][1]:
                # revisions may not be listed in date order: an earlier-dated
                # occurrence replaces the one recorded so far
                expected_first[sha1] = (rev_id, rev_ts, [rc["path"]])
        for dc in synth_rev["D_C"]:
            # Use the D-C relation's own destination. The R---C loop variable
            # ``rc`` is stale here, and unbound if no R---C was seen yet.
            sha1 = dc["dst"].hex()
            assert sha1 in expected_first
            # nothing else to do: this content cannot be a "first seen file"

    for content_id, (rev_id, ts, paths) in expected_first.items():
        occur = provenance.content_find_first(hash_to_bytes(content_id))
        assert occur is not None
        assert occur.content.hex() == content_id
        assert occur.revision.hex() == rev_id
        assert occur.date.timestamp() == ts
        assert occur.origin is None
        if provenance.storage.with_path():
            assert occur.path.decode() in paths
File Metadata
Details
Attached
Mime Type
text/x-diff
Expires
Sat, Jun 21, 5:47 PM (1 w, 5 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3245920
Attached To
rDPROV Provenance database
Event Timeline
Log In to Comment