diff --git a/requirements.txt b/requirements.txt
index 2201aff..4168e02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,10 @@
# Add here external Python modules dependencies, one per line. Module names
# should match https://pypi.python.org/pypi names. For the full spec or
# dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html
click
iso8601
methodtools
PyYAML
types-click
types-PyYAML
+types-Werkzeug
diff --git a/setup.py b/setup.py
index 23bfc8b..56127f2 100755
--- a/setup.py
+++ b/setup.py
@@ -1,74 +1,75 @@
#!/usr/bin/env python3
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from io import open
from os import path
+from typing import List, Optional
from setuptools import find_packages, setup
here = path.abspath(path.dirname(__file__))
# Get the long description from the README file
with open(path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()
-def parse_requirements(name=None):
+def parse_requirements(name: Optional[str] = None) -> List[str]:
if name:
reqf = "requirements-%s.txt" % name
else:
reqf = "requirements.txt"
requirements = []
if not path.exists(reqf):
return requirements
with open(reqf) as f:
for line in f.readlines():
line = line.strip()
if not line or line.startswith("#"):
continue
requirements.append(line)
return requirements
# Edit this part to match your module.
# Full sample:
# https://forge.softwareheritage.org/diffusion/DCORE/browse/master/setup.py
setup(
name="swh.provenance",
description="Software Heritage code provenance",
long_description=long_description,
long_description_content_type="text/markdown",
python_requires=">=3.7",
author="Software Heritage developers",
author_email="swh-devel@inria.fr",
url="https://forge.softwareheritage.org/diffusion/222/",
packages=find_packages(), # packages's modules
install_requires=parse_requirements() + parse_requirements("swh"),
tests_require=parse_requirements("test"),
setup_requires=["setuptools-scm"],
use_scm_version=True,
extras_require={"testing": parse_requirements("test")},
include_package_data=True,
entry_points="""
[swh.cli.subcommands]
provenance=swh.provenance.cli
""",
classifiers=[
"Programming Language :: Python :: 3",
"Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Operating System :: OS Independent",
"Development Status :: 3 - Alpha",
],
project_urls={
"Bug Reports": "https://forge.softwareheritage.org/maniphest",
"Funding": "https://www.softwareheritage.org/donate",
"Source": "https://forge.softwareheritage.org/source/swh-provenance",
"Documentation": "https://docs.softwareheritage.org/devel/swh-provenance/",
},
)
diff --git a/swh/provenance/api/serializers.py b/swh/provenance/api/serializers.py
index ef3d83d..3ca884c 100644
--- a/swh/provenance/api/serializers.py
+++ b/swh/provenance/api/serializers.py
@@ -1,48 +1,49 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from dataclasses import asdict
-from typing import Callable, Dict, List, Tuple
+from enum import Enum
+from typing import Any, Callable, Dict, List, Tuple
from .. import interface
-def _encode_dataclass(obj):
+def _encode_dataclass(obj: Any) -> Dict[str, Any]:
return {
**asdict(obj),
"__type__": type(obj).__name__,
}
-def _decode_dataclass(d):
+def _decode_dataclass(d: Dict[str, Any]) -> Any:
return getattr(interface, d.pop("__type__"))(**d)
-def _encode_enum(obj):
+def _encode_enum(obj: Enum) -> Dict[str, Any]:
return {
"value": obj.value,
"__type__": type(obj).__name__,
}
-def _decode_enum(d):
+def _decode_enum(d: Dict[str, Any]) -> Enum:
return getattr(interface, d.pop("__type__"))(d["value"])
ENCODERS: List[Tuple[type, str, Callable]] = [
(interface.ProvenanceResult, "dataclass", _encode_dataclass),
(interface.RelationData, "dataclass", _encode_dataclass),
(interface.RevisionData, "dataclass", _encode_dataclass),
(interface.EntityType, "enum", _encode_enum),
(interface.RelationType, "enum", _encode_enum),
(set, "set", list),
]
DECODERS: Dict[str, Callable] = {
"dataclass": _decode_dataclass,
"enum": _decode_enum,
"set": set,
}
diff --git a/swh/provenance/api/server.py b/swh/provenance/api/server.py
index 40a67c7..814b760 100644
--- a/swh/provenance/api/server.py
+++ b/swh/provenance/api/server.py
@@ -1,143 +1,148 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
import os
+from typing import Any, Dict, List, Optional
+
+from werkzeug.routing import Rule
from swh.core import config
from swh.core.api import JSONFormatter, MsgpackFormatter, RPCServerApp, negotiate
from swh.provenance import get_provenance_storage
from swh.provenance.interface import ProvenanceStorageInterface
from .serializers import DECODERS, ENCODERS
-storage = None
+storage: Optional[ProvenanceStorageInterface] = None
-def get_global_provenance_storage():
+def get_global_provenance_storage() -> ProvenanceStorageInterface:
global storage
- if not storage:
+ if storage is None:
storage = get_provenance_storage(**app.config["provenance"]["storage"])
return storage
class ProvenanceStorageServerApp(RPCServerApp):
extra_type_decoders = DECODERS
extra_type_encoders = ENCODERS
app = ProvenanceStorageServerApp(
__name__,
backend_class=ProvenanceStorageInterface,
backend_factory=get_global_provenance_storage,
)
-def has_no_empty_params(rule):
+def has_no_empty_params(rule: Rule) -> bool:
return len(rule.defaults or ()) >= len(rule.arguments or ())
@app.route("/")
-def index():
+def index() -> str:
return """
Software Heritage provenance storage RPC server
You have reached the
Software Heritage
provenance storage RPC server.
See its
documentation
and API for more information
"""
@app.route("/site-map")
@negotiate(MsgpackFormatter)
@negotiate(JSONFormatter)
-def site_map():
+def site_map() -> List[Dict[str, Any]]:
links = []
for rule in app.url_map.iter_rules():
if has_no_empty_params(rule) and hasattr(
ProvenanceStorageInterface, rule.endpoint
):
links.append(
dict(
rule=rule.rule,
description=getattr(
ProvenanceStorageInterface, rule.endpoint
).__doc__,
)
)
# links is now a list of dicts, each with a 'rule' and a 'description' key
return links
-def load_and_check_config(config_path, type="local"):
+def load_and_check_config(
+ config_path: Optional[str], type: str = "local"
+) -> Dict[str, Any]:
"""Check the minimal configuration is set to run the api or raise an
error explanation.
Args:
config_path (Optional[str]): Path to the configuration file to load
type (str): configuration type. For 'local' type, more
checks are done.
Raises:
Error if the setup is not as expected
Returns:
configuration as a dict
"""
- if not config_path:
+ if config_path is None:
raise EnvironmentError("Configuration file must be defined")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file {config_path} does not exist")
cfg = config.read(config_path)
- pcfg = cfg.get("provenance")
- if not pcfg:
+ pcfg: Optional[Dict[str, Any]] = cfg.get("provenance")
+ if pcfg is None:
raise KeyError("Missing 'provenance' configuration")
- scfg = pcfg.get("storage")
- if not scfg:
+ scfg: Optional[Dict[str, Any]] = pcfg.get("storage")
+ if scfg is None:
raise KeyError("Missing 'provenance.storage' configuration")
if type == "local":
cls = scfg.get("cls")
if cls != "local":
raise ValueError(
"The provenance backend can only be started with a 'local' "
"configuration"
)
db = scfg.get("db")
if not db:
raise KeyError("Invalid configuration; missing 'db' config entry")
return cfg
-api_cfg = None
+api_cfg: Optional[Dict[str, Any]] = None
-def make_app_from_configfile():
+def make_app_from_configfile() -> ProvenanceStorageServerApp:
"""Run the WSGI app from the webserver, loading the configuration from
a configuration file.
SWH_CONFIG_FILENAME environment variable defines the
configuration path to load.
"""
global api_cfg
- if not api_cfg:
+ if api_cfg is None:
config_path = os.environ.get("SWH_CONFIG_FILENAME")
api_cfg = load_and_check_config(config_path)
app.config.update(api_cfg)
handler = logging.StreamHandler()
app.logger.addHandler(handler)
return app
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
index 39522a7..bc6fb6a 100644
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -1,227 +1,225 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
# WARNING: do not import unnecessary things here to keep cli startup time under
# control
from datetime import datetime, timezone
import os
from typing import Any, Dict, Generator, Optional, Tuple
import click
import iso8601
import yaml
from swh.core import config
from swh.core.cli import CONTEXT_SETTINGS
from swh.core.cli import swh as swh_cli_group
from swh.model.hashutil import hash_to_bytes, hash_to_hex
from swh.model.model import Sha1Git
# All generic config code should reside in swh.core.config
CONFIG_ENVVAR = "SWH_CONFIG_FILENAME"
DEFAULT_PATH = os.environ.get(CONFIG_ENVVAR, None)
DEFAULT_CONFIG: Dict[str, Any] = {
"provenance": {
"archive": {
# "cls": "api",
# "storage": {
# "cls": "remote",
# "url": "http://uffizi.internal.softwareheritage.org:5002",
# }
"cls": "direct",
"db": {
"host": "db.internal.softwareheritage.org",
"dbname": "softwareheritage",
"user": "guest",
},
},
"storage": {
"cls": "local",
"db": {"host": "localhost", "dbname": "provenance"},
},
}
}
CONFIG_FILE_HELP = f"""
\b Configuration can be loaded from a yaml file given either as --config-file
option or the {CONFIG_ENVVAR} environment variable. If no configuration file
is specified, use the following default configuration::
\b
{yaml.dump(DEFAULT_CONFIG)}"""
PROVENANCE_HELP = f"""Software Heritage provenance index database tools
{CONFIG_FILE_HELP}
"""
@swh_cli_group.group(
name="provenance", context_settings=CONTEXT_SETTINGS, help=PROVENANCE_HELP
)
@click.option(
"-C",
"--config-file",
default=None,
type=click.Path(exists=True, dir_okay=False, path_type=str),
help="""YAML configuration file.""",
)
@click.option(
"-P",
"--profile",
default=None,
type=click.Path(exists=False, dir_okay=False, path_type=str),
help="""Enable profiling to specified file.""",
)
@click.pass_context
def cli(ctx: click.core.Context, config_file: Optional[str], profile: str) -> None:
if (
config_file is None
and DEFAULT_PATH is not None
and config.config_exists(DEFAULT_PATH)
):
config_file = DEFAULT_PATH
if config_file is None:
conf = DEFAULT_CONFIG
else:
# read_raw_config do not fail on ENOENT
if not os.path.exists(config_file):
raise FileNotFoundError(config_file)
conf = yaml.safe_load(open(config_file, "rb"))
ctx.ensure_object(dict)
ctx.obj["config"] = conf
if profile:
import atexit
import cProfile
print("Profiling...")
pr = cProfile.Profile()
pr.enable()
def exit() -> None:
pr.disable()
pr.dump_stats(profile)
atexit.register(exit)
@cli.command(name="iter-revisions")
@click.argument("filename")
@click.option("-a", "--track-all", default=True, type=bool)
@click.option("-l", "--limit", type=int)
@click.option("-m", "--min-depth", default=1, type=int)
@click.option("-r", "--reuse", default=True, type=bool)
@click.pass_context
def iter_revisions(
ctx: click.core.Context,
filename: str,
track_all: bool,
limit: Optional[int],
min_depth: int,
reuse: bool,
) -> None:
# TODO: add file size filtering
"""Process a provided list of revisions."""
from . import get_archive, get_provenance
from .revision import CSVRevisionIterator, revision_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
revisions_provider = generate_revision_tuples(filename)
revisions = CSVRevisionIterator(revisions_provider, limit=limit)
for revision in revisions:
revision_add(
provenance,
archive,
[revision],
trackall=track_all,
lower=reuse,
mindepth=min_depth,
)
def generate_revision_tuples(
filename: str,
) -> Generator[Tuple[Sha1Git, datetime, Sha1Git], None, None]:
for line in open(filename, "r"):
if line.strip():
revision, date, root = line.strip().split(",")
yield (
hash_to_bytes(revision),
iso8601.parse_date(date, default_timezone=timezone.utc),
hash_to_bytes(root),
)
@cli.command(name="iter-origins")
@click.argument("filename")
@click.option("-l", "--limit", type=int)
@click.pass_context
def iter_origins(ctx: click.core.Context, filename: str, limit: Optional[int]) -> None:
"""Process a provided list of origins."""
from . import get_archive, get_provenance
from .origin import CSVOriginIterator, origin_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
origins_provider = generate_origin_tuples(filename)
origins = CSVOriginIterator(origins_provider, limit=limit)
for origin in origins:
origin_add(provenance, archive, [origin])
def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]:
for line in open(filename, "r"):
if line.strip():
url, snapshot = line.strip().split(",")
yield (url, hash_to_bytes(snapshot))
@cli.command(name="find-first")
@click.argument("swhid")
@click.pass_context
def find_first(ctx: click.core.Context, swhid: str) -> None:
"""Find first occurrence of the requested blob."""
from . import get_provenance
provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
- # TODO: return a dictionary with proper keys for each field
occur = provenance.content_find_first(hash_to_bytes(swhid))
if occur is not None:
print(
f"swh:1:cnt:{hash_to_hex(occur.content)}, "
f"swh:1:rev:{hash_to_hex(occur.revision)}, "
f"{occur.date}, "
f"{occur.origin}, "
f"{os.fsdecode(occur.path)}"
)
else:
print(f"Cannot find a content with the id {swhid}")
@cli.command(name="find-all")
@click.argument("swhid")
@click.option("-l", "--limit", type=int)
@click.pass_context
def find_all(ctx: click.core.Context, swhid: str, limit: Optional[int]) -> None:
"""Find all occurrences of the requested blob."""
from . import get_provenance
provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
- # TODO: return a dictionary with proper keys for each field
for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
print(
f"swh:1:cnt:{hash_to_hex(occur.content)}, "
f"swh:1:rev:{hash_to_hex(occur.revision)}, "
f"{occur.date}, "
f"{occur.origin}, "
f"{os.fsdecode(occur.path)}"
)
diff --git a/swh/provenance/postgresql/archive.py b/swh/provenance/postgresql/archive.py
index 265e874..69f5012 100644
--- a/swh/provenance/postgresql/archive.py
+++ b/swh/provenance/postgresql/archive.py
@@ -1,115 +1,115 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from typing import Any, Dict, Iterable, List
from methodtools import lru_cache
-import psycopg2
+import psycopg2.extensions
from swh.model.model import Sha1Git
from swh.storage.postgresql.storage import Storage
class ArchivePostgreSQL:
def __init__(self, conn: psycopg2.extensions.connection) -> None:
self.conn = conn
self.storage = Storage(conn, objstorage={"cls": "memory"})
def directory_ls(self, id: Sha1Git) -> Iterable[Dict[str, Any]]:
entries = self._directory_ls(id)
yield from entries
@lru_cache(maxsize=100000)
def _directory_ls(self, id: Sha1Git) -> List[Dict[str, Any]]:
# TODO: add file size filtering
with self.conn.cursor() as cursor:
cursor.execute(
"""
WITH
dir AS (SELECT id AS dir_id, dir_entries, file_entries, rev_entries
FROM directory WHERE id=%s),
ls_d AS (SELECT dir_id, UNNEST(dir_entries) AS entry_id FROM dir),
ls_f AS (SELECT dir_id, UNNEST(file_entries) AS entry_id FROM dir),
ls_r AS (SELECT dir_id, UNNEST(rev_entries) AS entry_id FROM dir)
(SELECT 'dir'::directory_entry_type AS type, e.target, e.name,
NULL::sha1_git
FROM ls_d
LEFT JOIN directory_entry_dir e ON ls_d.entry_id=e.id)
UNION
(WITH known_contents AS
(SELECT 'file'::directory_entry_type AS type, e.target, e.name,
c.sha1_git
FROM ls_f
LEFT JOIN directory_entry_file e ON ls_f.entry_id=e.id
INNER JOIN content c ON e.target=c.sha1_git)
SELECT * FROM known_contents
UNION
(SELECT 'file'::directory_entry_type AS type, e.target, e.name,
c.sha1_git
FROM ls_f
LEFT JOIN directory_entry_file e ON ls_f.entry_id=e.id
LEFT JOIN skipped_content c ON e.target=c.sha1_git
WHERE NOT EXISTS (
SELECT 1 FROM known_contents
WHERE known_contents.sha1_git=e.target
)
)
)
""",
(id,),
)
return [
{"type": row[0], "target": row[1], "name": row[2]}
for row in cursor.fetchall()
]
def revision_get_parents(self, id: Sha1Git) -> Iterable[Sha1Git]:
with self.conn.cursor() as cursor:
cursor.execute(
"""
SELECT RH.parent_id::bytea
FROM revision_history AS RH
WHERE RH.id=%s
ORDER BY RH.parent_rank
""",
(id,),
)
# A revision may have several parents; rows come back ordered by parent_rank
yield from (row[0] for row in cursor.fetchall())
def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
with self.conn.cursor() as cursor:
cursor.execute(
"""
WITH
snaps AS (SELECT object_id FROM snapshot WHERE snapshot.id=%s),
heads AS ((SELECT R.id, R.date
FROM snaps
JOIN snapshot_branches AS BS
ON (snaps.object_id=BS.snapshot_id)
JOIN snapshot_branch AS B
ON (BS.branch_id=B.object_id)
JOIN revision AS R
ON (B.target=R.id)
WHERE B.target_type='revision'::snapshot_target)
UNION
(SELECT RV.id, RV.date
FROM snaps
JOIN snapshot_branches AS BS
ON (snaps.object_id=BS.snapshot_id)
JOIN snapshot_branch AS B
ON (BS.branch_id=B.object_id)
JOIN release AS RL
ON (B.target=RL.id)
JOIN revision AS RV
ON (RL.target=RV.id)
WHERE B.target_type='release'::snapshot_target
AND RL.target_type='revision'::object_type)
ORDER BY date, id)
SELECT id FROM heads
""",
(id,),
)
yield from (row[0] for row in cursor.fetchall())
diff --git a/swh/provenance/postgresql/provenancedb.py b/swh/provenance/postgresql/provenancedb.py
index 8f52425..2319e2e 100644
--- a/swh/provenance/postgresql/provenancedb.py
+++ b/swh/provenance/postgresql/provenancedb.py
@@ -1,367 +1,367 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime
import itertools
import logging
from typing import Dict, Generator, Iterable, Optional, Set, Tuple
-import psycopg2
+import psycopg2.extensions
import psycopg2.extras
from typing_extensions import Literal
from swh.core.db import BaseDb
from swh.model.model import Sha1Git
from ..interface import (
EntityType,
ProvenanceResult,
RelationData,
RelationType,
RevisionData,
)
class ProvenanceDB:
def __init__(
self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False
- ):
+ ) -> None:
BaseDb.adapt_conn(conn)
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
conn.set_session(autocommit=True)
self.conn = conn
self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
# XXX: not sure this is the best place to do it!
sql = "SET timezone TO 'UTC'"
self.cursor.execute(sql)
self._flavor: Optional[str] = None
self.raise_on_commit = raise_on_commit
@property
def flavor(self) -> str:
if self._flavor is None:
sql = "SELECT swh_get_dbflavor() AS flavor"
self.cursor.execute(sql)
self._flavor = self.cursor.fetchone()["flavor"]
assert self._flavor is not None
return self._flavor
def with_path(self) -> bool:
return "with-path" in self.flavor
@property
def denormalized(self) -> bool:
return "denormalized" in self.flavor
def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
return self._entity_set_date("content", dates)
def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
return self._entity_get_date("content", ids)
def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
return self._entity_set_date("directory", dates)
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
return self._entity_get_date("directory", ids)
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
sql = f"SELECT sha1 FROM {entity.value}"
self.cursor.execute(sql)
return {row["sha1"] for row in self.cursor.fetchall()}
def location_get(self) -> Set[bytes]:
sql = "SELECT encode(location.path::bytea, 'escape') AS path FROM location"
self.cursor.execute(sql)
return {row["path"] for row in self.cursor.fetchall()}
def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
try:
if urls:
sql = """
LOCK TABLE ONLY origin;
INSERT INTO origin(sha1, url) VALUES %s
ON CONFLICT DO NOTHING
"""
psycopg2.extras.execute_values(self.cursor, sql, urls.items())
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
logging.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
urls: Dict[Sha1Git, str] = {}
sha1s = tuple(ids)
if sha1s:
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT sha1, url
FROM origin
WHERE sha1 IN ({values})
"""
self.cursor.execute(sql, sha1s)
urls.update(
(row["sha1"], row["url"].decode()) for row in self.cursor.fetchall()
)
return urls
def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
return self._entity_set_date("revision", dates)
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
sql = "SELECT * FROM swh_provenance_content_find_first(%s)"
self.cursor.execute(sql, (id,))
row = self.cursor.fetchone()
return ProvenanceResult(**row) if row is not None else None
def content_find_all(
self, id: Sha1Git, limit: Optional[int] = None
) -> Generator[ProvenanceResult, None, None]:
sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)"
self.cursor.execute(sql, (id, limit))
yield from (ProvenanceResult(**row) for row in self.cursor.fetchall())
def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool:
try:
if origins:
sql = """
LOCK TABLE ONLY revision;
INSERT INTO revision(sha1, origin)
(SELECT V.rev AS sha1, O.id AS origin
FROM (VALUES %s) AS V(rev, org)
JOIN origin AS O ON (O.sha1=V.org))
ON CONFLICT (sha1) DO
UPDATE SET origin=EXCLUDED.origin
"""
psycopg2.extras.execute_values(self.cursor, sql, origins.items())
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
logging.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
result: Dict[Sha1Git, RevisionData] = {}
sha1s = tuple(ids)
if sha1s:
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT sha1, date, origin
FROM revision
WHERE sha1 IN ({values})
"""
self.cursor.execute(sql, sha1s)
result.update(
(row["sha1"], RevisionData(date=row["date"], origin=row["origin"]))
for row in self.cursor.fetchall()
)
return result
def relation_add(
self, relation: RelationType, data: Iterable[RelationData]
) -> bool:
try:
rows = tuple((rel.src, rel.dst, rel.path) for rel in data)
if rows:
table = relation.value
src, *_, dst = table.split("_")
if src != "origin":
# Origin entries should be inserted previously as they require extra
# non-null information
srcs = tuple(set((sha1,) for (sha1, _, _) in rows))
sql = f"""
LOCK TABLE ONLY {src};
INSERT INTO {src}(sha1) VALUES %s
ON CONFLICT DO NOTHING
"""
psycopg2.extras.execute_values(self.cursor, sql, srcs)
if dst != "origin":
# Origin entries should be inserted previously as they require extra
# non-null information
dsts = tuple(set((sha1,) for (_, sha1, _) in rows))
sql = f"""
LOCK TABLE ONLY {dst};
INSERT INTO {dst}(sha1) VALUES %s
ON CONFLICT DO NOTHING
"""
psycopg2.extras.execute_values(self.cursor, sql, dsts)
joins = [
f"INNER JOIN {src} AS S ON (S.sha1=V.src)",
f"INNER JOIN {dst} AS D ON (D.sha1=V.dst)",
]
nope = (RelationType.REV_BEFORE_REV, RelationType.REV_IN_ORG)
selected = ["S.id"]
if self.denormalized and relation not in nope:
selected.append("ARRAY_AGG(D.id)")
else:
selected.append("D.id")
if self._relation_uses_location_table(relation):
locations = tuple(set((path,) for (_, _, path) in rows))
sql = """
LOCK TABLE ONLY location;
INSERT INTO location(path) VALUES %s
ON CONFLICT (path) DO NOTHING
"""
psycopg2.extras.execute_values(self.cursor, sql, locations)
joins.append("INNER JOIN location AS L ON (L.path=V.path)")
if self.denormalized:
selected.append("ARRAY_AGG(L.id)")
else:
selected.append("L.id")
sql_l = [
f"INSERT INTO {table}",
f" SELECT {', '.join(selected)}",
" FROM (VALUES %s) AS V(src, dst, path)",
*joins,
]
if self.denormalized and relation not in nope:
sql_l.append("GROUP BY S.id")
sql_l.append(
f"""ON CONFLICT ({src}) DO UPDATE
SET {dst}=ARRAY(
SELECT UNNEST({table}.{dst} || excluded.{dst})),
location=ARRAY(
SELECT UNNEST({relation.value}.location || excluded.location))
"""
)
else:
sql_l.append("ON CONFLICT DO NOTHING")
sql = "\n".join(sql_l)
psycopg2.extras.execute_values(self.cursor, sql, rows)
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
logging.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
def relation_get(
self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
) -> Set[RelationData]:
return self._relation_get(relation, ids, reverse)
def relation_get_all(self, relation: RelationType) -> Set[RelationData]:
return self._relation_get(relation, None)
def _entity_get_date(
self,
entity: Literal["content", "directory", "revision"],
ids: Iterable[Sha1Git],
) -> Dict[Sha1Git, datetime]:
dates: Dict[Sha1Git, datetime] = {}
sha1s = tuple(ids)
if sha1s:
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT sha1, date
FROM {entity}
WHERE sha1 IN ({values})
"""
self.cursor.execute(sql, sha1s)
dates.update((row["sha1"], row["date"]) for row in self.cursor.fetchall())
return dates
def _entity_set_date(
self,
entity: Literal["content", "directory", "revision"],
data: Dict[Sha1Git, datetime],
) -> bool:
try:
if data:
sql = f"""
LOCK TABLE ONLY {entity};
INSERT INTO {entity}(sha1, date) VALUES %s
ON CONFLICT (sha1) DO
UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date)
"""
psycopg2.extras.execute_values(self.cursor, sql, data.items())
return True
except: # noqa: E722
# Unexpected error occurred, rollback all changes and log message
logging.exception("Unexpected error")
if self.raise_on_commit:
raise
return False
def _relation_get(
self,
relation: RelationType,
ids: Optional[Iterable[Sha1Git]],
reverse: bool = False,
) -> Set[RelationData]:
result: Set[RelationData] = set()
sha1s: Optional[Tuple[Tuple[Sha1Git, ...]]]
if ids is not None:
sha1s = (tuple(ids),)
where = f"WHERE {'S' if not reverse else 'D'}.sha1 IN %s"
else:
sha1s = None
where = ""
aggreg_dst = self.denormalized and relation in (
RelationType.CNT_EARLY_IN_REV,
RelationType.CNT_IN_DIR,
RelationType.DIR_IN_REV,
)
if sha1s is None or sha1s[0]:
table = relation.value
src, *_, dst = table.split("_")
# TODO: improve this!
if src == "revision" and dst == "revision":
src_field = "prev"
dst_field = "next"
else:
src_field = src
dst_field = dst
if aggreg_dst:
revloc = f"UNNEST(R.{dst_field}) AS dst"
if self._relation_uses_location_table(relation):
revloc += ", UNNEST(R.location) AS path"
else:
revloc = f"R.{dst_field} AS dst"
if self._relation_uses_location_table(relation):
revloc += ", R.location AS path"
inner_sql = f"""
SELECT S.sha1 AS src, {revloc}
FROM {table} AS R
INNER JOIN {src} AS S ON (S.id=R.{src_field})
{where}
"""
if self._relation_uses_location_table(relation):
loc = "L.path AS path"
else:
loc = "NULL AS path"
sql = f"""
SELECT CL.src, D.sha1 AS dst, {loc}
FROM ({inner_sql}) AS CL
INNER JOIN {dst} AS D ON (D.id=CL.dst)
"""
if self._relation_uses_location_table(relation):
sql += "INNER JOIN location AS L ON (L.id=CL.path)"
self.cursor.execute(sql, sha1s)
result.update(RelationData(**row) for row in self.cursor.fetchall())
return result
def _relation_uses_location_table(self, relation: RelationType) -> bool:
if self.with_path():
src = relation.value.split("_")[0]
return src in ("content", "directory")
return False
diff --git a/swh/provenance/storage/archive.py b/swh/provenance/storage/archive.py
index 29b28c3..305cedf 100644
--- a/swh/provenance/storage/archive.py
+++ b/swh/provenance/storage/archive.py
@@ -1,63 +1,63 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime
from typing import Any, Dict, Iterable, Set, Tuple
from swh.model.model import ObjectType, Sha1Git, TargetType
from swh.storage.interface import StorageInterface
class ArchiveStorage:
- def __init__(self, storage: StorageInterface):
+ def __init__(self, storage: StorageInterface) -> None:
self.storage = storage
def directory_ls(self, id: Sha1Git) -> Iterable[Dict[str, Any]]:
# TODO: add file size filtering
for entry in self.storage.directory_ls(id):
yield {
"name": entry["name"],
"target": entry["target"],
"type": entry["type"],
}
def revision_get_parents(self, id: Sha1Git) -> Iterable[Sha1Git]:
rev = self.storage.revision_get([id])[0]
if rev is not None:
yield from rev.parents
def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
from swh.core.utils import grouper
from swh.storage.algos.snapshot import snapshot_get_all_branches
snapshot = snapshot_get_all_branches(self.storage, id)
assert snapshot is not None
targets_set = set()
releases_set = set()
if snapshot is not None:
for branch in snapshot.branches:
if snapshot.branches[branch].target_type == TargetType.REVISION:
targets_set.add(snapshot.branches[branch].target)
elif snapshot.branches[branch].target_type == TargetType.RELEASE:
releases_set.add(snapshot.branches[branch].target)
batchsize = 100
for releases in grouper(releases_set, batchsize):
targets_set.update(
release.target
for release in self.storage.release_get(list(releases))
if release is not None and release.target_type == ObjectType.REVISION
)
revisions: Set[Tuple[datetime, Sha1Git]] = set()
for targets in grouper(targets_set, batchsize):
revisions.update(
(revision.date.to_datetime(), revision.id)
for revision in self.storage.revision_get(list(targets))
if revision is not None and revision.date is not None
)
yield from (head for _, head in sorted(revisions))
diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
index 495b528..db16d0b 100644
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -1,157 +1,153 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from os import path
from typing import Any, Dict, Iterable, Iterator
+from _pytest.fixtures import SubRequest
import msgpack
-import psycopg2
+import psycopg2.extensions
import pytest
from swh.journal.serializers import msgpack_ext_hook
from swh.model.tests.swh_model_data import TEST_OBJECTS
from swh.provenance import get_provenance, get_provenance_storage
from swh.provenance.api.client import RemoteProvenanceStorage
import swh.provenance.api.server as server
from swh.provenance.archive import ArchiveInterface
from swh.provenance.interface import ProvenanceInterface, ProvenanceStorageInterface
from swh.provenance.postgresql.archive import ArchivePostgreSQL
from swh.provenance.storage.archive import ArchiveStorage
from swh.storage.postgresql.storage import Storage
from swh.storage.replay import process_replay_objects
@pytest.fixture(
params=[
"with-path",
"without-path",
"with-path-denormalized",
"without-path-denormalized",
]
)
def populated_db(
- request, # TODO: add proper type annotation
+ request: SubRequest,
postgresql: psycopg2.extensions.connection,
) -> Dict[str, str]:
"""return a working and initialized provenance db"""
from swh.core.cli.db import populate_database_for_package
- # flavor = "with-path" if request.param == "client-server" else request.param
populate_database_for_package(
"swh.provenance", postgresql.dsn, flavor=request.param
)
- return {
- k: v
- for (k, v) in (item.split("=") for item in postgresql.dsn.split())
- if k != "options"
- }
+ return postgresql.get_dsn_parameters()
# the Flask app used as server in these tests
@pytest.fixture
-def app(populated_db: Dict[str, str]):
+def app(populated_db: Dict[str, str]) -> Iterator[server.ProvenanceStorageServerApp]:
assert hasattr(server, "storage")
server.storage = get_provenance_storage(cls="local", db=populated_db)
yield server.app
# the RPCClient class used as client used in these tests
@pytest.fixture
-def swh_rpc_client_class():
+def swh_rpc_client_class() -> type:
return RemoteProvenanceStorage
@pytest.fixture(params=["local", "remote"])
def provenance(
- request, # TODO: add proper type annotation
+ request: SubRequest,
populated_db: Dict[str, str],
swh_rpc_client: RemoteProvenanceStorage,
) -> ProvenanceInterface:
- """return a working and initialized provenance db"""
+ """Return a working and initialized ProvenanceInterface object"""
if request.param == "remote":
from swh.provenance.provenance import Provenance
assert isinstance(swh_rpc_client, ProvenanceStorageInterface)
return Provenance(swh_rpc_client)
else:
# in test sessions, we DO want to raise any exception occurring at commit time
prov = get_provenance(cls=request.param, db=populated_db, raise_on_commit=True)
return prov
@pytest.fixture
def swh_storage_with_objects(swh_storage: Storage) -> Storage:
    """Pre-populate *swh_storage* with a few objects of every type.

    The inserted content comes from swh.model.tests.swh_model_data
    (TEST_OBJECTS).
    """
    object_types = (
        "content",
        "skipped_content",
        "directory",
        "revision",
        "release",
        "snapshot",
        "origin",
        "origin_visit",
        "origin_visit_status",
    )
    for object_type in object_types:
        add_objects = getattr(swh_storage, f"{object_type}_add")
        add_objects(TEST_OBJECTS[object_type])
    return swh_storage
@pytest.fixture
def archive_direct(swh_storage_with_objects: Storage) -> ArchiveInterface:
    """Return an ArchiveInterface working directly on the storage's db connection."""
    db_connection = swh_storage_with_objects.get_db().conn
    return ArchivePostgreSQL(db_connection)
@pytest.fixture
def archive_api(swh_storage_with_objects: Storage) -> ArchiveInterface:
    """Return an ArchiveInterface that goes through the storage API."""
    storage = swh_storage_with_objects
    return ArchiveStorage(storage)
@pytest.fixture(params=["archive", "db"])
def archive(
    request: SubRequest, swh_storage_with_objects: Storage
) -> Iterator[ArchiveInterface]:
    """Return an ArchiveInterface object, either the storage-API based
    ArchiveStorage or the direct-database ArchivePostgreSQL implementation,
    depending on the fixture parameter."""
    # this is a workaround to prevent tests from hanging because of an unclosed
    # transaction.
    # TODO: refactor the ArchivePostgreSQL to properly deal with
    # transactions and get rid of this fixture
    if request.param == "db":
        archive = ArchivePostgreSQL(conn=swh_storage_with_objects.get_db().conn)
        try:
            yield archive
        finally:
            # roll back even when the test body raised, so a failing test does
            # not leave an open transaction that makes later tests hang
            archive.conn.rollback()
    else:
        yield ArchiveStorage(swh_storage_with_objects)
def get_datafile(fname: str) -> str:
    """Return the path of *fname* inside this package's ``data`` directory."""
    data_dir = path.join(path.dirname(__file__), "data")
    return path.join(data_dir, fname)
def load_repo_data(repo: str) -> Dict[str, Any]:
    """Deserialize the ``{repo}.msgpack`` test dataset into per-type lists.

    Returns a mapping from object type name to the list of deserialized
    objects of that type, in file order.
    """
    result: Dict[str, Any] = {}
    with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj:
        unpacker = msgpack.Unpacker(
            fobj,
            raw=False,
            ext_hook=msgpack_ext_hook,
            strict_map_key=False,
            timestamp=3,  # convert Timestamp in datetime objects (tz UTC)
        )
        for object_type, object_dict in unpacker:
            result.setdefault(object_type, []).append(object_dict)
    return result
def filter_dict(d: Dict[Any, Any], keys: Iterable[Any]) -> Dict[Any, Any]:
    """Return a copy of *d* restricted to the entries whose key is in *keys*.

    *keys* is materialized once up-front: this makes the function correct for
    one-shot iterators/generators, where repeated ``in`` membership tests
    would otherwise silently exhaust the iterable after the first lookup.
    """
    wanted = tuple(keys)
    return {k: v for (k, v) in d.items() if k in wanted}
def fill_storage(storage: Storage, data: Dict[str, Any]) -> None:
    """Feed *data* (a mapping from object type to object list) through the
    journal replayer to populate *storage*."""
    process_replay_objects(data, storage=storage)
diff --git a/swh/provenance/tests/data/generate_repo.py b/swh/provenance/tests/data/generate_repo.py
index 734838b..55afc31 100644
--- a/swh/provenance/tests/data/generate_repo.py
+++ b/swh/provenance/tests/data/generate_repo.py
@@ -1,115 +1,116 @@
import os
import pathlib
import shutil
from subprocess import PIPE, check_call, check_output
+from typing import Any, Dict, List
import click
import yaml
-def clean_wd():
+def clean_wd() -> None:
_, dirnames, filenames = next(os.walk("."))
for d in dirnames:
if not d.startswith(".git"):
shutil.rmtree(d)
for f in filenames:
if not f.startswith(".git"):
os.unlink(f)
-def print_ids():
+def print_ids() -> None:
revid = check_output(["git", "rev-parse", "HEAD"]).decode().strip()
ts, msg = (
check_output(["git", "log", "-1", '--format="%at %s"'])
.decode()
.strip()[1:-1]
.split()
)
print(f"{ts}.0 {revid} {msg}")
print(f"{msg:<5} | {'':>5} | {'':>20} | R {revid} | {ts}.0")
for currentpath, dirnames, filenames in os.walk("."):
if currentpath == ".":
output = check_output(["git", "cat-file", "-p", "HEAD"]).decode()
dirhash = output.splitlines()[0].split()[1]
else:
currentpath = currentpath[2:]
output = check_output(["git", "ls-tree", "HEAD", currentpath]).decode()
dirhash = output.split()[2]
print(f"{'':>5} | {'':>5} | {currentpath:<20} | D {dirhash} | 0.0")
for fname in filenames:
fname = os.path.join(currentpath, fname)
output = check_output(["git", "ls-tree", "HEAD", fname]).decode()
fhash = output.split()[2]
print(f"{'':>5} | {'':>5} | {fname:<20} | C {fhash} | 0.0")
if ".git" in dirnames:
dirnames.remove(".git")
-def generate_repo(repo_desc, output_dir):
+def generate_repo(repo_desc: List[Dict[str, Any]], output_dir: str) -> None:
check_call(["git", "init", output_dir], stdout=PIPE, stderr=PIPE)
os.chdir(output_dir)
os.environ.update(
{
"GIT_AUTHOR_NAME": "SWH",
"GIT_AUTHOR_EMAIL": "contact@softwareheritage.org",
"GIT_COMMITTER_NAME": "SWH",
"GIT_COMMITTER_EMAIL": "contact@softwareheritage.org",
}
)
for rev_d in repo_desc:
parents = rev_d.get("parents")
if parents:
# move at the proper (first) parent position, if any
check_call(["git", "checkout", parents[0]], stdout=PIPE)
# give a branch name (the msg) to each commit to make it easier to
# navigate in history
check_call(["git", "checkout", "-b", rev_d["msg"]], stdout=PIPE)
if parents and len(parents) > 1:
# it's a merge
check_call(["git", "merge", "--no-commit", *parents[1:]], stdout=PIPE)
clean_wd()
for path, content in rev_d["content"].items():
p = pathlib.Path(path)
p.parent.mkdir(parents=True, exist_ok=True)
p.write_text(content)
os.environ.update(
{
"GIT_AUTHOR_DATE": str(rev_d["date"]),
"GIT_COMMITTER_DATE": str(rev_d["date"]),
}
)
check_call(["git", "add", "."], stdout=PIPE)
check_call(
[
"git",
"commit",
"--all",
"--allow-empty",
"-m",
rev_d["msg"],
],
stdout=PIPE,
)
print_ids()
@click.command(name="generate-repo")
@click.argument("input-file")
@click.argument("output-dir")
@click.option("-C", "--clean-output/--no-clean-output", default=False)
-def main(input_file, output_dir, clean_output):
+def main(input_file: str, output_dir: str, clean_output: bool) -> None:
repo_desc = yaml.load(open(input_file))
if clean_output and os.path.exists(output_dir):
shutil.rmtree(output_dir)
generate_repo(repo_desc, output_dir)
if __name__ == "__main__":
main()
diff --git a/swh/provenance/tests/data/generate_storage_from_git.py b/swh/provenance/tests/data/generate_storage_from_git.py
index a9e5ddf..e31c54a 100644
--- a/swh/provenance/tests/data/generate_storage_from_git.py
+++ b/swh/provenance/tests/data/generate_storage_from_git.py
@@ -1,115 +1,119 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from datetime import datetime, timezone
import os
from subprocess import check_output
+from typing import Dict, Optional
import click
import yaml
from swh.loader.git.from_disk import GitLoaderFromDisk
from swh.model.hashutil import hash_to_bytes
from swh.model.model import (
Origin,
OriginVisit,
OriginVisitStatus,
Snapshot,
SnapshotBranch,
TargetType,
)
from swh.storage import get_storage
+from swh.storage.interface import StorageInterface
-def load_git_repo(url, directory, storage):
+def load_git_repo(
+ url: str, directory: str, storage: StorageInterface
+) -> Dict[str, str]:
visit_date = datetime.now(tz=timezone.utc)
loader = GitLoaderFromDisk(
url=url,
directory=directory,
visit_date=visit_date,
storage=storage,
)
return loader.load()
-def pop_key(d, k):
- d.pop(k)
- return d
-
-
@click.command()
@click.option("-o", "--output", default=None, help="output file")
@click.option(
"-v",
"--visits",
type=click.File(mode="rb"),
default=None,
help="additional visits to generate.",
)
@click.argument("git-repo", type=click.Path(exists=True, file_okay=False))
-def main(output, visits, git_repo):
+def main(output: Optional[str], visits: bytes, git_repo: str) -> None:
"simple tool to generate the git_repo.msgpack dataset file used in some tests"
if output is None:
output = f"{git_repo}.msgpack"
with open(output, "wb") as outstream:
sto = get_storage(
cls="memory", journal_writer={"cls": "stream", "output_stream": outstream}
)
if git_repo.endswith("/"):
git_repo = git_repo[:-1]
reponame = os.path.basename(git_repo)
load_git_repo(f"https://{reponame}", git_repo, sto)
if visits:
# retrieve all branches from the actual git repo
all_branches = {
ref: sha1
for sha1, ref in (
line.strip().split()
for line in check_output(["git", "-C", git_repo, "show-ref"])
.decode()
.splitlines()
)
}
for visit in yaml.full_load(visits):
# add the origin (if it already exists, this is a noop)
sto.origin_add([Origin(url=visit["origin"])])
# add a new visit for this origin
- visit_id = sto.origin_visit_add(
- [
- OriginVisit(
- origin=visit["origin"],
- date=datetime.fromtimestamp(visit["date"], tz=timezone.utc),
- type="git",
- )
- ]
+ visit_id = list(
+ sto.origin_visit_add(
+ [
+ OriginVisit(
+ origin=visit["origin"],
+ date=datetime.fromtimestamp(
+ visit["date"], tz=timezone.utc
+ ),
+ type="git",
+ )
+ ]
+ )
)[0].visit
+ assert visit_id is not None
# add a snapshot with branches from the input file
branches = {
f"refs/heads/{name}".encode(): SnapshotBranch(
target=hash_to_bytes(all_branches[f"refs/heads/{name}"]),
target_type=TargetType.REVISION,
)
for name in visit["branches"]
}
snap = Snapshot(branches=branches)
sto.snapshot_add([snap])
# add a "closing" origin visit status update referencing the snapshot
status = OriginVisitStatus(
origin=visit["origin"],
visit=visit_id,
date=datetime.fromtimestamp(visit["date"], tz=timezone.utc),
status="full",
snapshot=snap.id,
)
sto.origin_visit_status_add([status])
click.echo(f"Serialized the storage made from {reponame} in {output}")
if __name__ == "__main__":
main()
diff --git a/swh/provenance/tests/test_cli.py b/swh/provenance/tests/test_cli.py
index eb8c1b3..cceb30a 100644
--- a/swh/provenance/tests/test_cli.py
+++ b/swh/provenance/tests/test_cli.py
@@ -1,103 +1,104 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from typing import Set
+from _pytest.monkeypatch import MonkeyPatch
from click.testing import CliRunner
-import psycopg2
+import psycopg2.extensions
import pytest
from swh.core.cli import swh as swhmain
import swh.core.cli.db # noqa ; ensure cli is loaded
from swh.core.db import BaseDb
import swh.provenance.cli # noqa ; ensure cli is loaded
def test_cli_swh_db_help() -> None:
    """The 'swh provenance' CLI group advertises all expected subcommands."""
    result = CliRunner().invoke(swhmain, ["provenance", "-h"])
    assert result.exit_code == 0
    assert "Commands:" in result.output
    commands_section = result.output.split("Commands:")[1]
    expected = ("find-all", "find-first", "iter-origins", "iter-revisions")
    for command in expected:
        assert f" {command} " in commands_section
# Tables expected to exist in an initialized provenance database.
TABLES = {
    "dbflavor",
    "dbversion",
    "content",
    "content_in_revision",
    "content_in_directory",
    "directory",
    "directory_in_revision",
    # NOTE(review): "location" is listed here, which makes the
    # `TABLES | {"location"}` union in the parametrize below a no-op —
    # confirm whether the "without-path" flavor is really expected to
    # have a location table, or whether this entry should be removed.
    "location",
    "origin",
    "revision",
    "revision_before_revision",
    "revision_in_origin",
}
@pytest.mark.parametrize(
"flavor, dbtables", (("with-path", TABLES | {"location"}), ("without-path", TABLES))
)
def test_cli_db_create_and_init_db_with_flavor(
- monkeypatch, # TODO: add proper type annotation
+ monkeypatch: MonkeyPatch,
postgresql: psycopg2.extensions.connection,
flavor: str,
dbtables: Set[str],
) -> None:
"""Test that 'swh db init provenance' works with flavors
for both with-path and without-path flavors"""
dbname = f"{flavor}-db"
# DB creation using 'swh db create'
db_params = postgresql.get_dsn_parameters()
monkeypatch.setenv("PGHOST", db_params["host"])
monkeypatch.setenv("PGUSER", db_params["user"])
monkeypatch.setenv("PGPORT", db_params["port"])
result = CliRunner().invoke(swhmain, ["db", "create", "-d", dbname, "provenance"])
assert result.exit_code == 0, result.output
# DB init using 'swh db init'
result = CliRunner().invoke(
swhmain, ["db", "init", "-d", dbname, "--flavor", flavor, "provenance"]
)
assert result.exit_code == 0, result.output
assert f"(flavor {flavor})" in result.output
db_params["dbname"] = dbname
cnx = BaseDb.connect(**db_params).conn
# check the DB looks OK (check for db_flavor and expected tables)
with cnx.cursor() as cur:
cur.execute("select swh_get_dbflavor()")
assert cur.fetchone() == (flavor,)
cur.execute(
"select table_name from information_schema.tables "
"where table_schema = 'public' "
f"and table_catalog = '{dbname}'"
)
tables = set(x for (x,) in cur.fetchall())
assert tables == dbtables
def test_cli_init_db_default_flavor(postgresql: psycopg2.extensions.connection) -> None:
    """Check that 'swh db init provenance' defaults to a with-path flavored DB."""
    result = CliRunner().invoke(
        swhmain, ["db", "init", "-d", postgresql.dsn, "provenance"]
    )
    assert result.exit_code == 0, result.output
    with postgresql.cursor() as cursor:
        cursor.execute("select swh_get_dbflavor()")
        flavor_row = cursor.fetchone()
    assert flavor_row == ("with-path",)
diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py
index 76bc1e9..b8d83b3 100644
--- a/swh/provenance/tests/test_provenance_storage.py
+++ b/swh/provenance/tests/test_provenance_storage.py
@@ -1,46 +1,46 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import inspect
from ..interface import ProvenanceInterface, ProvenanceStorageInterface
-def test_types(provenance: ProvenanceInterface):
+def test_types(provenance: ProvenanceInterface) -> None:
"""Checks all methods of ProvenanceStorageInterface are implemented by this
backend, and that they have the same signature."""
# Create an instance of the protocol (which cannot be instantiated
# directly, so this creates a subclass, then instantiates it)
interface = type("_", (ProvenanceStorageInterface,), {})()
storage = provenance.storage
assert "content_find_first" in dir(interface)
missing_methods = []
for meth_name in dir(interface):
if meth_name.startswith("_"):
continue
interface_meth = getattr(interface, meth_name)
try:
concrete_meth = getattr(storage, meth_name)
except AttributeError:
if not getattr(interface_meth, "deprecated_endpoint", False):
# The backend is missing a (non-deprecated) endpoint
missing_methods.append(meth_name)
continue
expected_signature = inspect.signature(interface_meth)
actual_signature = inspect.signature(concrete_meth)
assert expected_signature == actual_signature, meth_name
assert missing_methods == []
# If all the assertions above succeed, then this one should too.
# But there's no harm in double-checking.
# And we could replace the assertions above by this one, but unlike
# the assertions above, it doesn't explain what is missing.
assert isinstance(storage, ProvenanceStorageInterface)