
diff --git a/mypy.ini b/mypy.ini
index 17eee37..53b0ffb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,37 +1,39 @@
[mypy]
namespace_packages = True
warn_unused_ignores = True
exclude = swh/provenance/tools/
-
# 3rd party libraries without stubs (yet)
[mypy-bson.*]
ignore_missing_imports = True
+[mypy-confluent_kafka.*]
+ignore_missing_imports = True
+
[mypy-iso8601.*]
ignore_missing_imports = True
[mypy-methodtools.*]
ignore_missing_imports = True
[mypy-msgpack.*]
ignore_missing_imports = True
[mypy-pika.*]
ignore_missing_imports = True
[mypy-pkg_resources.*]
ignore_missing_imports = True
[mypy-pymongo.*]
ignore_missing_imports = True
[mypy-pytest.*]
ignore_missing_imports = True
[mypy-pytest_postgresql.*]
ignore_missing_imports = True
[mypy-psycopg2.*]
ignore_missing_imports = True
diff --git a/requirements-swh.txt b/requirements-swh.txt
index 8cec4fc..840225b 100644
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,5 +1,6 @@
# Add here internal Software Heritage dependencies, one per line.
swh.core[db,http] >= 0.14
swh.model >= 2.6.1
swh.storage
swh.graph
+swh.journal
diff --git a/requirements-test.txt b/requirements-test.txt
index 934cbf1..9d2d915 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,7 +1,8 @@
pytest
pytest-mongodb
pytest-rabbitmq
swh.loader.git >= 0.8
swh.journal >= 0.8
swh.storage >= 0.40
swh.graph >= 0.3.2
+types-Deprecated
diff --git a/requirements.txt b/requirements.txt
index 669a745..6b70f34 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,15 @@
# Add here external Python modules dependencies, one per line. Module names
# should match https://pypi.python.org/pypi names. For the full spec or
# dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html
click
+deprecated
iso8601
methodtools
mongomock
pika
pymongo
PyYAML
types-click
types-PyYAML
zmq
diff --git a/swh/provenance/cli.py b/swh/provenance/cli.py
index d4ec64e..8f51ac2 100644
--- a/swh/provenance/cli.py
+++ b/swh/provenance/cli.py
@@ -1,352 +1,427 @@
-# Copyright (C) 2021 The Software Heritage developers
+# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
# WARNING: do not import unnecessary things here to keep cli startup time under
# control
from datetime import datetime, timezone
+from functools import partial
import os
from typing import Any, Dict, Generator, Optional, Tuple
import click
+from deprecated import deprecated
import iso8601
import yaml
from swh.core import config
from swh.core.cli import CONTEXT_SETTINGS
from swh.core.cli import swh as swh_cli_group
from swh.model.hashutil import hash_to_bytes, hash_to_hex
from swh.model.model import Sha1Git
# All generic config code should reside in swh.core.config
CONFIG_ENVVAR = "SWH_CONFIG_FILENAME"
DEFAULT_PATH = os.environ.get(CONFIG_ENVVAR, None)
DEFAULT_CONFIG: Dict[str, Any] = {
"provenance": {
"archive": {
# Storage API based Archive object
# "cls": "api",
# "storage": {
# "cls": "remote",
# "url": "http://uffizi.internal.softwareheritage.org:5002",
# }
# Direct access Archive object
"cls": "direct",
"db": {
"host": "belvedere.internal.softwareheritage.org",
"port": 5432,
"dbname": "softwareheritage",
"user": "guest",
},
},
"storage": {
# Local PostgreSQL Storage
# "cls": "postgresql",
# "db": {
# "host": "localhost",
# "user": "postgres",
# "password": "postgres",
# "dbname": "provenance",
# },
# Local MongoDB Storage
# "cls": "mongodb",
# "db": {
# "dbname": "provenance",
# },
# Remote RabbitMQ/PostgreSQL Storage
"cls": "rabbitmq",
"url": "amqp://localhost:5672/%2f",
"storage_config": {
"cls": "postgresql",
"db": {
"host": "localhost",
"user": "postgres",
"password": "postgres",
"dbname": "provenance",
},
},
"batch_size": 100,
"prefetch_count": 100,
},
}
}
CONFIG_FILE_HELP = f"""
\b Configuration can be loaded from a yaml file given either as --config-file
option or the {CONFIG_ENVVAR} environment variable. If no configuration file
is specified, use the following default configuration::
\b
{yaml.dump(DEFAULT_CONFIG)}"""
PROVENANCE_HELP = f"""Software Heritage provenance index database tools
{CONFIG_FILE_HELP}
"""
@swh_cli_group.group(
name="provenance", context_settings=CONTEXT_SETTINGS, help=PROVENANCE_HELP
)
@click.option(
"-C",
"--config-file",
default=None,
type=click.Path(exists=True, dir_okay=False, path_type=str),
help="""YAML configuration file.""",
)
@click.option(
"-P",
"--profile",
default=None,
type=click.Path(exists=False, dir_okay=False, path_type=str),
help="""Enable profiling to specified file.""",
)
@click.pass_context
def cli(ctx: click.core.Context, config_file: Optional[str], profile: str) -> None:
if (
config_file is None
and DEFAULT_PATH is not None
and config.config_exists(DEFAULT_PATH)
):
config_file = DEFAULT_PATH
if config_file is None:
conf = DEFAULT_CONFIG
else:
# read_raw_config does not fail on ENOENT
if not os.path.exists(config_file):
raise FileNotFoundError(config_file)
conf = yaml.safe_load(open(config_file, "rb"))
ctx.ensure_object(dict)
ctx.obj["config"] = conf
if profile:
import atexit
import cProfile
print("Profiling...")
pr = cProfile.Profile()
pr.enable()
def exit() -> None:
pr.disable()
pr.dump_stats(profile)
atexit.register(exit)
+@cli.group(name="origin")
+@click.pass_context
+def origin(ctx: click.core.Context):
+ from . import get_archive, get_provenance
+
+ archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
+ provenance = get_provenance(**ctx.obj["config"]["provenance"]["storage"])
+
+ ctx.obj["provenance"] = provenance
+ ctx.obj["archive"] = archive
+
+
+@origin.command(name="from-csv")
+@click.argument("filename", type=click.Path(exists=True))
+@click.option(
+ "-l",
+ "--limit",
+ type=int,
+ help="""Limit the amount of entries (origins) to read from the input file.""",
+)
+@click.pass_context
+def origin_from_csv(ctx: click.core.Context, filename: str, limit: Optional[int]):
+ from .origin import CSVOriginIterator, origin_add
+
+ provenance = ctx.obj["provenance"]
+ archive = ctx.obj["archive"]
+
+ origins_provider = generate_origin_tuples(filename)
+ origins = CSVOriginIterator(origins_provider, limit=limit)
+
+ with provenance:
+ for origin in origins:
+ origin_add(provenance, archive, [origin])
+
+
+@origin.command(name="from-journal")
+@click.pass_context
+def origin_from_journal(ctx: click.core.Context):
+ from swh.journal.client import get_journal_client
+
+ from .journal_client import process_journal_objects
+
+ provenance = ctx.obj["provenance"]
+ archive = ctx.obj["archive"]
+
+ journal_cfg = ctx.obj["config"].get("journal_client", {})
+
+ worker_fn = partial(
+ process_journal_objects,
+ archive=archive,
+ provenance=provenance,
+ )
+
+ cls = journal_cfg.pop("cls", None) or "kafka"
+ client = get_journal_client(
+ cls,
+ **{
+ **journal_cfg,
+ "object_types": ["origin_visit_status"],
+ },
+ )
+
+ try:
+ client.process(worker_fn)
+ except KeyboardInterrupt:
+ ctx.exit(0)
+ else:
+ print("Done.")
+ finally:
+ client.close()
+
+
@cli.command(name="iter-frontiers")
@click.argument("filename")
@click.option(
"-l",
"--limit",
type=int,
help="""Limit the amount of entries (directories) to read from the input file.""",
)
@click.option(
"-s",
"--min-size",
default=0,
type=int,
help="""Set the minimum size (in bytes) of files to be indexed. """
"""Any smaller file will be ignored.""",
)
@click.pass_context
def iter_frontiers(
ctx: click.core.Context,
filename: str,
limit: Optional[int],
min_size: int,
) -> None:
"""Process a provided list of directories in the isochrone frontier."""
from . import get_archive, get_provenance
from .directory import CSVDirectoryIterator, directory_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
directories_provider = generate_directory_ids(filename)
directories = CSVDirectoryIterator(directories_provider, limit=limit)
with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
for directory in directories:
directory_add(
provenance,
archive,
[directory],
minsize=min_size,
)
def generate_directory_ids(
filename: str,
) -> Generator[Sha1Git, None, None]:
for line in open(filename, "r"):
if line.strip():
yield hash_to_bytes(line.strip())
@cli.command(name="iter-revisions")
@click.argument("filename")
@click.option(
"-a",
"--track-all",
default=True,
type=bool,
help="""Index all occurrences of files in the development history.""",
)
@click.option(
"-f",
"--flatten",
default=True,
type=bool,
help="""Create flat models for directories in the isochrone frontier.""",
)
@click.option(
"-l",
"--limit",
type=int,
help="""Limit the amount of entries (revisions) to read from the input file.""",
)
@click.option(
"-m",
"--min-depth",
default=1,
type=int,
help="""Set minimum depth (in the directory tree) at which an isochrone """
"""frontier can be defined.""",
)
@click.option(
"-r",
"--reuse",
default=True,
type=bool,
help="""Prioritize the usage of previously defined isochrone frontiers """
"""whenever possible.""",
)
@click.option(
"-s",
"--min-size",
default=0,
type=int,
help="""Set the minimum size (in bytes) of files to be indexed. """
"""Any smaller file will be ignored.""",
)
@click.pass_context
def iter_revisions(
ctx: click.core.Context,
filename: str,
track_all: bool,
flatten: bool,
limit: Optional[int],
min_depth: int,
reuse: bool,
min_size: int,
) -> None:
"""Process a provided list of revisions."""
from . import get_archive, get_provenance
from .revision import CSVRevisionIterator, revision_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
revisions_provider = generate_revision_tuples(filename)
revisions = CSVRevisionIterator(revisions_provider, limit=limit)
with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
for revision in revisions:
revision_add(
provenance,
archive,
[revision],
trackall=track_all,
flatten=flatten,
lower=reuse,
mindepth=min_depth,
minsize=min_size,
)
def generate_revision_tuples(
filename: str,
) -> Generator[Tuple[Sha1Git, datetime, Sha1Git], None, None]:
for line in open(filename, "r"):
if line.strip():
revision, date, root = line.strip().split(",")
yield (
hash_to_bytes(revision),
iso8601.parse_date(date, default_timezone=timezone.utc),
hash_to_bytes(root),
)
@cli.command(name="iter-origins")
@click.argument("filename")
@click.option(
"-l",
"--limit",
type=int,
help="""Limit the amount of entries (origins) to read from the input file.""",
)
@click.pass_context
+@deprecated(version="0.0.1", reason="Use `swh provenance origin from-csv` instead")
def iter_origins(ctx: click.core.Context, filename: str, limit: Optional[int]) -> None:
"""Process a provided list of origins."""
from . import get_archive, get_provenance
from .origin import CSVOriginIterator, origin_add
archive = get_archive(**ctx.obj["config"]["provenance"]["archive"])
origins_provider = generate_origin_tuples(filename)
origins = CSVOriginIterator(origins_provider, limit=limit)
with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
for origin in origins:
origin_add(provenance, archive, [origin])
def generate_origin_tuples(filename: str) -> Generator[Tuple[str, bytes], None, None]:
for line in open(filename, "r"):
if line.strip():
url, snapshot = line.strip().split(",")
yield (url, hash_to_bytes(snapshot))
@cli.command(name="find-first")
@click.argument("swhid")
@click.pass_context
def find_first(ctx: click.core.Context, swhid: str) -> None:
"""Find first occurrence of the requested blob."""
from . import get_provenance
with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
occur = provenance.content_find_first(hash_to_bytes(swhid))
if occur is not None:
print(
f"swh:1:cnt:{hash_to_hex(occur.content)}, "
f"swh:1:rev:{hash_to_hex(occur.revision)}, "
f"{occur.date}, "
f"{occur.origin}, "
f"{os.fsdecode(occur.path)}"
)
else:
print(f"Cannot find a content with the id {swhid}")
@cli.command(name="find-all")
@click.argument("swhid")
@click.option(
"-l", "--limit", type=int, help="""Limit the amount results to be retrieved."""
)
@click.pass_context
def find_all(ctx: click.core.Context, swhid: str, limit: Optional[int]) -> None:
"""Find all occurrences of the requested blob."""
from . import get_provenance
with get_provenance(**ctx.obj["config"]["provenance"]["storage"]) as provenance:
for occur in provenance.content_find_all(hash_to_bytes(swhid), limit=limit):
print(
f"swh:1:cnt:{hash_to_hex(occur.content)}, "
f"swh:1:rev:{hash_to_hex(occur.revision)}, "
f"{occur.date}, "
f"{occur.origin}, "
f"{os.fsdecode(occur.path)}"
)
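
The new `origin from-journal` subcommand above only needs a `journal_client` section next to the existing `provenance` one. A minimal configuration sketch, assuming placeholder broker, consumer group and topic prefix (the provenance part follows the DEFAULT_CONFIG shape shown earlier):

# Illustrative sketch only: a configuration file for
#   swh provenance -C provenance.yml origin from-journal
# Broker, group_id and prefix are placeholders; the journal_client keys mirror
# the ones read by origin_from_journal() above.
import yaml

config = {
    "journal_client": {
        "cls": "kafka",
        "brokers": ["localhost:9092"],        # placeholder Kafka broker
        "group_id": "swh-provenance-origin",  # placeholder consumer group
        "prefix": "swh.journal.objects",      # placeholder topic prefix
    },
    "provenance": {
        "archive": {
            "cls": "api",
            "storage": {"cls": "remote", "url": "http://localhost:5002"},
        },
        "storage": {
            "cls": "postgresql",
            "db": {"host": "localhost", "dbname": "provenance"},
        },
    },
}

with open("provenance.yml", "w") as f:
    yaml.safe_dump(config, f)
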
diff --git a/swh/provenance/journal_client.py b/swh/provenance/journal_client.py
new file mode 100644
index 0000000..8dda133
--- /dev/null
+++ b/swh/provenance/journal_client.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from swh.provenance.interface import ProvenanceInterface
+from swh.provenance.model import OriginEntry
+from swh.provenance.origin import origin_add
+from swh.storage.interface import StorageInterface
+
+
+def process_journal_objects(
+ messages, *, provenance: ProvenanceInterface, archive: StorageInterface
+) -> None:
+ """Worker function for `JournalClient.process(worker_fn)`."""
+ assert set(messages) == {"origin_visit_status"}, set(messages)
+ origin_entries = [
+ OriginEntry(url=visit["origin"], snapshot=visit["snapshot"])
+ for visit in messages["origin_visit_status"]
+ if visit["snapshot"] is not None
+ ]
+ origin_add(provenance, archive, origin_entries)
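
A hedged sketch of the `messages` mapping the worker above consumes; only the two fields it reads are shown, and the URL and snapshot id are placeholders:

# Example input for process_journal_objects(); values are placeholders and
# real origin_visit_status messages carry more fields than the two read here.
messages = {
    "origin_visit_status": [
        {
            "origin": "https://example.org/repo.git",
            "snapshot": bytes.fromhex("5f577c4d4e5a1d0bca64f78facfb891933b17d94"),
        },
        # visits without a snapshot are filtered out by the worker
        {"origin": "https://example.org/empty.git", "snapshot": None},
    ]
}
# process_journal_objects(messages, provenance=provenance, archive=archive)
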
diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py
index 17d313d..5e676b3 100644
--- a/swh/provenance/postgresql/provenance.py
+++ b/swh/provenance/postgresql/provenance.py
@@ -1,379 +1,384 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from functools import wraps
import itertools
import logging
from types import TracebackType
from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union
import psycopg2.extensions
import psycopg2.extras
from swh.core.db import BaseDb
from swh.core.statsd import statsd
from swh.model.model import Sha1Git
from ..interface import (
DirectoryData,
EntityType,
ProvenanceResult,
ProvenanceStorageInterface,
RelationData,
RelationType,
RevisionData,
)
LOGGER = logging.getLogger(__name__)
STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds"
def handle_raise_on_commit(f):
@wraps(f)
def handle(self, *args, **kwargs):
try:
return f(self, *args, **kwargs)
except BaseException as ex:
# Unexpected error occurred, rollback all changes and log message
LOGGER.exception("Unexpected error")
if self.raise_on_commit:
raise ex
return False
return handle
class ProvenanceStoragePostgreSql:
def __init__(
self, page_size: Optional[int] = None, raise_on_commit: bool = False, **kwargs
) -> None:
+ self.conn: Optional[psycopg2.extensions.connection] = None
self.conn_args = kwargs
self._flavor: Optional[str] = None
self.page_size = page_size
self.raise_on_commit = raise_on_commit
def __enter__(self) -> ProvenanceStorageInterface:
self.open()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.close()
@contextmanager
def transaction(
self, readonly: bool = False
) -> Generator[psycopg2.extras.RealDictCursor, None, None]:
+ if self.conn is None: # somehow, "implicit" __enter__ call did not happen
+ self.open()
+ assert self.conn is not None
self.conn.set_session(readonly=readonly)
with self.conn:
with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
yield cur
@property
def flavor(self) -> str:
if self._flavor is None:
with self.transaction(readonly=True) as cursor:
cursor.execute("SELECT swh_get_dbflavor() AS flavor")
self._flavor = cursor.fetchone()["flavor"]
assert self._flavor is not None
return self._flavor
@property
def denormalized(self) -> bool:
return "denormalized" in self.flavor
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"})
def close(self) -> None:
+ assert self.conn is not None
self.conn.close()
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"})
@handle_raise_on_commit
def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
if cnts:
sql = """
INSERT INTO content(sha1, date) VALUES %s
ON CONFLICT (sha1) DO
UPDATE SET date=LEAST(EXCLUDED.date,content.date)
"""
page_size = self.page_size or len(cnts)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cursor, sql, argslist=cnts.items(), page_size=page_size
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"})
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
sql = "SELECT * FROM swh_provenance_content_find_first(%s)"
with self.transaction(readonly=True) as cursor:
cursor.execute(query=sql, vars=(id,))
row = cursor.fetchone()
return ProvenanceResult(**row) if row is not None else None
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"})
def content_find_all(
self, id: Sha1Git, limit: Optional[int] = None
) -> Generator[ProvenanceResult, None, None]:
sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)"
with self.transaction(readonly=True) as cursor:
cursor.execute(query=sql, vars=(id, limit))
yield from (ProvenanceResult(**row) for row in cursor)
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"})
def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
dates: Dict[Sha1Git, datetime] = {}
sha1s = tuple(ids)
if sha1s:
# TODO: consider splitting this query in several ones if sha1s is too big!
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT sha1, date
FROM content
WHERE sha1 IN ({values})
AND date IS NOT NULL
"""
with self.transaction(readonly=True) as cursor:
cursor.execute(query=sql, vars=sha1s)
dates.update((row["sha1"], row["date"]) for row in cursor)
return dates
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"})
@handle_raise_on_commit
def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
data = [(sha1, rev.date, rev.flat) for sha1, rev in dirs.items()]
if data:
sql = """
INSERT INTO directory(sha1, date, flat) VALUES %s
ON CONFLICT (sha1) DO
UPDATE SET
date=LEAST(EXCLUDED.date, directory.date),
flat=(EXCLUDED.flat OR directory.flat)
"""
page_size = self.page_size or len(data)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cur=cursor, sql=sql, argslist=data, page_size=page_size
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"})
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]:
result: Dict[Sha1Git, DirectoryData] = {}
sha1s = tuple(ids)
if sha1s:
# TODO: consider splitting this query in several ones if sha1s is too big!
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT sha1, date, flat
FROM directory
WHERE sha1 IN ({values})
AND date IS NOT NULL
"""
with self.transaction(readonly=True) as cursor:
cursor.execute(query=sql, vars=sha1s)
result.update(
(row["sha1"], DirectoryData(date=row["date"], flat=row["flat"]))
for row in cursor
)
return result
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"})
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
with self.transaction(readonly=True) as cursor:
cursor.execute(f"SELECT sha1 FROM {entity.value}")
return {row["sha1"] for row in cursor}
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"})
@handle_raise_on_commit
def location_add(self, paths: Iterable[bytes]) -> bool:
if self.with_path():
values = [(path,) for path in paths]
if values:
sql = """
INSERT INTO location(path) VALUES %s
ON CONFLICT DO NOTHING
"""
page_size = self.page_size or len(values)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cursor, sql, argslist=values, page_size=page_size
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"})
def location_get_all(self) -> Set[bytes]:
with self.transaction(readonly=True) as cursor:
cursor.execute("SELECT location.path AS path FROM location")
return {row["path"] for row in cursor}
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"})
@handle_raise_on_commit
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
if orgs:
sql = """
INSERT INTO origin(sha1, url) VALUES %s
ON CONFLICT DO NOTHING
"""
page_size = self.page_size or len(orgs)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cur=cursor,
sql=sql,
argslist=orgs.items(),
page_size=page_size,
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"})
def open(self) -> None:
self.conn = BaseDb.connect(**self.conn_args).conn
BaseDb.adapt_conn(self.conn)
with self.transaction() as cursor:
cursor.execute("SET timezone TO 'UTC'")
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"})
def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
urls: Dict[Sha1Git, str] = {}
sha1s = tuple(ids)
if sha1s:
# TODO: consider splitting this query in several ones if sha1s is too big!
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT sha1, url
FROM origin
WHERE sha1 IN ({values})
"""
with self.transaction(readonly=True) as cursor:
cursor.execute(query=sql, vars=sha1s)
urls.update((row["sha1"], row["url"]) for row in cursor)
return urls
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"})
@handle_raise_on_commit
def revision_add(
self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]]
) -> bool:
if isinstance(revs, dict):
data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()]
else:
data = [(sha1, None, None) for sha1 in revs]
if data:
sql = """
INSERT INTO revision(sha1, date, origin)
(SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin
FROM (VALUES %s) AS V(rev, date, org)
LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git))
ON CONFLICT (sha1) DO
UPDATE SET
date=LEAST(EXCLUDED.date, revision.date),
origin=COALESCE(EXCLUDED.origin, revision.origin)
"""
page_size = self.page_size or len(data)
with self.transaction() as cursor:
psycopg2.extras.execute_values(
cur=cursor, sql=sql, argslist=data, page_size=page_size
)
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"})
def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
result: Dict[Sha1Git, RevisionData] = {}
sha1s = tuple(ids)
if sha1s:
# TODO: consider splitting this query in several ones if sha1s is too big!
values = ", ".join(itertools.repeat("%s", len(sha1s)))
sql = f"""
SELECT R.sha1, R.date, O.sha1 AS origin
FROM revision AS R
LEFT JOIN origin AS O ON (O.id=R.origin)
WHERE R.sha1 IN ({values})
AND (R.date is not NULL OR O.sha1 is not NULL)
"""
with self.transaction(readonly=True) as cursor:
cursor.execute(query=sql, vars=sha1s)
result.update(
(row["sha1"], RevisionData(date=row["date"], origin=row["origin"]))
for row in cursor
)
return result
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"})
@handle_raise_on_commit
def relation_add(
self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
) -> bool:
rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts]
if rows:
rel_table = relation.value
src_table, *_, dst_table = rel_table.split("_")
page_size = self.page_size or len(rows)
# Put the next three queries in a manual single transaction:
# they use the same temp table
with self.transaction() as cursor:
cursor.execute("SELECT swh_mktemp_relation_add()")
psycopg2.extras.execute_values(
cur=cursor,
sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s",
argslist=rows,
page_size=page_size,
)
sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)"
cursor.execute(query=sql, vars=(rel_table, src_table, dst_table))
return True
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"})
def relation_get(
self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
) -> Dict[Sha1Git, Set[RelationData]]:
return self._relation_get(relation, ids, reverse)
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"})
def relation_get_all(
self, relation: RelationType
) -> Dict[Sha1Git, Set[RelationData]]:
return self._relation_get(relation, None)
def _relation_get(
self,
relation: RelationType,
ids: Optional[Iterable[Sha1Git]],
reverse: bool = False,
) -> Dict[Sha1Git, Set[RelationData]]:
result: Dict[Sha1Git, Set[RelationData]] = {}
sha1s: List[Sha1Git]
if ids is not None:
sha1s = list(ids)
filter = "filter-src" if not reverse else "filter-dst"
else:
sha1s = []
filter = "no-filter"
if filter == "no-filter" or sha1s:
rel_table = relation.value
src_table, *_, dst_table = rel_table.split("_")
sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)"
with self.transaction(readonly=True) as cursor:
cursor.execute(
query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s)
)
for row in cursor:
src = row.pop("src")
result.setdefault(src, set()).add(RelationData(**row))
return result
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "with_path"})
def with_path(self) -> bool:
return "with-path" in self.flavor
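
With the `conn` initialisation and the lazy open added to `transaction()` above, the storage no longer requires an explicit `open()` before its first query. A short sketch, assuming placeholder connection parameters:

# Sketch of the lazy-open behaviour introduced in transaction(); the
# connection parameters below are placeholders.
from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql

storage = ProvenanceStoragePostgreSql(dbname="provenance", host="localhost")

# transaction() opens the connection on first use, so no explicit open() call:
paths = storage.location_get_all()
storage.close()

# The usual context-manager form keeps working:
with ProvenanceStoragePostgreSql(dbname="provenance", host="localhost") as s:
    print(s.flavor)
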
diff --git a/swh/provenance/tests/data/origins.csv b/swh/provenance/tests/data/origins.csv
new file mode 100644
index 0000000..e7e44bc
--- /dev/null
+++ b/swh/provenance/tests/data/origins.csv
@@ -0,0 +1 @@
+https://cmdbts2,5f577c4d4e5a1d0bca64f78facfb891933b17d94
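
Each row of this fixture pairs an origin URL with a snapshot id; a one-row parsing sketch mirroring `generate_origin_tuples()` from cli.py:

# Parsing the row above the way generate_origin_tuples() does in cli.py.
from swh.model.hashutil import hash_to_bytes

line = "https://cmdbts2,5f577c4d4e5a1d0bca64f78facfb891933b17d94"
url, snapshot = line.strip().split(",")
origin_tuple = (url, hash_to_bytes(snapshot))
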
diff --git a/swh/provenance/tests/test_cli.py b/swh/provenance/tests/test_cli.py
index 2efae68..18ded81 100644
--- a/swh/provenance/tests/test_cli.py
+++ b/swh/provenance/tests/test_cli.py
@@ -1,109 +1,164 @@
-# Copyright (C) 2021 The Software Heritage developers
+# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-from typing import Set
+from typing import Dict, List, Set
from _pytest.monkeypatch import MonkeyPatch
from click.testing import CliRunner
import psycopg2.extensions
import pytest
from swh.core.cli import swh as swhmain
import swh.core.cli.db # noqa ; ensure cli is loaded
from swh.core.db import BaseDb
from swh.core.db.db_utils import init_admin_extensions
+from swh.model.hashutil import MultiHash
import swh.provenance.cli # noqa ; ensure cli is loaded
+from swh.provenance.tests.conftest import fill_storage, load_repo_data
+from swh.storage.interface import StorageInterface
+
+from .conftest import get_datafile
+from .test_utils import invoke, write_configuration_path
def test_cli_swh_db_help() -> None:
# swhmain.add_command(provenance_cli)
result = CliRunner().invoke(swhmain, ["provenance", "-h"])
assert result.exit_code == 0
assert "Commands:" in result.output
commands = result.output.split("Commands:")[1]
for command in (
"find-all",
"find-first",
"iter-frontiers",
"iter-origins",
"iter-revisions",
):
assert f" {command} " in commands
TABLES = {
"dbflavor",
"dbmodule",
"dbversion",
"content",
"content_in_revision",
"content_in_directory",
"directory",
"directory_in_revision",
"location",
"origin",
"revision",
"revision_before_revision",
"revision_in_origin",
}
@pytest.mark.parametrize(
"flavor, dbtables", (("with-path", TABLES), ("without-path", TABLES))
)
def test_cli_db_create_and_init_db_with_flavor(
monkeypatch: MonkeyPatch,
postgresql: psycopg2.extensions.connection,
flavor: str,
dbtables: Set[str],
) -> None:
"""Test that 'swh db init provenance' works with flavors
for both with-path and without-path flavors"""
dbname = f"{flavor}-db"
# DB creation using 'swh db create'
db_params = postgresql.get_dsn_parameters()
monkeypatch.setenv("PGHOST", db_params["host"])
monkeypatch.setenv("PGUSER", db_params["user"])
monkeypatch.setenv("PGPORT", db_params["port"])
result = CliRunner().invoke(swhmain, ["db", "create", "-d", dbname, "provenance"])
assert result.exit_code == 0, result.output
# DB init using 'swh db init'
result = CliRunner().invoke(
swhmain, ["db", "init", "-d", dbname, "--flavor", flavor, "provenance"]
)
assert result.exit_code == 0, result.output
assert f"(flavor {flavor})" in result.output
db_params["dbname"] = dbname
cnx = BaseDb.connect(**db_params).conn
# check the DB looks OK (check for db_flavor and expected tables)
with cnx.cursor() as cur:
cur.execute("select swh_get_dbflavor()")
assert cur.fetchone() == (flavor,)
cur.execute(
"select table_name from information_schema.tables "
"where table_schema = 'public' "
f"and table_catalog = '{dbname}'"
)
tables = set(x for (x,) in cur.fetchall())
assert tables == dbtables
def test_cli_init_db_default_flavor(postgresql: psycopg2.extensions.connection) -> None:
"Test that 'swh db init provenance' defaults to a with-path flavored DB"
dbname = postgresql.dsn
init_admin_extensions("swh.provenance", dbname)
result = CliRunner().invoke(swhmain, ["db", "init", "-d", dbname, "provenance"])
assert result.exit_code == 0, result.output
with postgresql.cursor() as cur:
cur.execute("select swh_get_dbflavor()")
assert cur.fetchone() == ("with-path",)
+
+
+@pytest.mark.parametrize(
+ "subcommand",
+ (["origin", "from-csv"], ["iter-origins"]),
+)
+def test_cli_origin_from_csv(
+ swh_storage: StorageInterface,
+ subcommand: List[str],
+ swh_storage_backend_config: Dict,
+ provenance,
+ tmp_path,
+):
+ repo = "cmdbts2"
+ origin_url = f"https://{repo}"
+ data = load_repo_data(repo)
+ fill_storage(swh_storage, data)
+
+ assert len(data["origin"]) == 1
+ assert {"url": origin_url} in data["origin"]
+
+ cfg = {
+ "provenance": {
+ "archive": {
+ "cls": "api",
+ "storage": swh_storage_backend_config,
+ },
+ "storage": {
+ "cls": "postgresql",
+ # "db": provenance.storage.conn.dsn,
+ "db": provenance.storage.conn.get_dsn_parameters(),
+ },
+ },
+ }
+
+ config_path = write_configuration_path(cfg, tmp_path)
+
+ csv_filepath = get_datafile("origins.csv")
+ subcommand = subcommand + [csv_filepath]
+
+ result = invoke(subcommand, config_path)
+ assert result.exit_code == 0, f"Unexpected result: {result.output}"
+
+ origin_sha1 = MultiHash.from_data(
+ origin_url.encode(), hash_names=["sha1"]
+ ).digest()["sha1"]
+ actual_result = provenance.storage.origin_get([origin_sha1])
+
+ assert actual_result == {origin_sha1: origin_url}
diff --git a/swh/provenance/tests/test_journal_client.py b/swh/provenance/tests/test_journal_client.py
new file mode 100644
index 0000000..efc237e
--- /dev/null
+++ b/swh/provenance/tests/test_journal_client.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from typing import Dict
+
+from confluent_kafka import Consumer
+import pytest
+
+from swh.model.hashutil import MultiHash
+from swh.provenance.tests.conftest import fill_storage, load_repo_data
+from swh.storage.interface import StorageInterface
+
+from .test_utils import invoke, write_configuration_path
+
+
+@pytest.fixture
+def swh_storage_backend_config(swh_storage_backend_config, kafka_server, kafka_prefix):
+ writer_config = {
+ "cls": "kafka",
+ "brokers": [kafka_server],
+ "client_id": "kafka_writer",
+ "prefix": kafka_prefix,
+ "anonymize": False,
+ }
+ yield {**swh_storage_backend_config, "journal_writer": writer_config}
+
+
+def test_cli_origin_from_journal_client(
+ swh_storage: StorageInterface,
+ swh_storage_backend_config: Dict,
+ kafka_prefix: str,
+ kafka_server: str,
+ consumer: Consumer,
+ tmp_path: str,
+ provenance,
+) -> None:
+ """Test origin journal client cli"""
+
+ # Prepare storage data
+ data = load_repo_data("cmdbts2")
+ assert len(data["origin"]) == 1
+ origin_url = data["origin"][0]["url"]
+ fill_storage(swh_storage, data)
+
+ # Prepare configuration for cli call
+ swh_storage_backend_config.pop("journal_writer", None) # no need for that config
+ storage_config_dict = swh_storage_backend_config
+ cfg = {
+ "journal_client": {
+ "cls": "kafka",
+ "brokers": [kafka_server],
+ "group_id": "toto",
+ "prefix": kafka_prefix,
+ "object_types": ["origin_visit_status"],
+ "stop_on_eof": True,
+ },
+ "provenance": {
+ "archive": {
+ "cls": "api",
+ "storage": storage_config_dict,
+ },
+ "storage": {
+ "cls": "postgresql",
+ "db": provenance.storage.conn.get_dsn_parameters(),
+ },
+ },
+ }
+ config_path = write_configuration_path(cfg, tmp_path)
+
+ # call the cli 'swh provenance origin from-journal'
+ result = invoke(["origin", "from-journal"], config_path)
+ assert result.exit_code == 0, f"Unexpected result: {result.output}"
+
+ origin_sha1 = MultiHash.from_data(
+ origin_url.encode(), hash_names=["sha1"]
+ ).digest()["sha1"]
+ actual_result = provenance.storage.origin_get([origin_sha1])
+
+ assert actual_result == {origin_sha1: origin_url}
diff --git a/swh/provenance/tests/test_utils.py b/swh/provenance/tests/test_utils.py
new file mode 100644
index 0000000..9fe7ba2
--- /dev/null
+++ b/swh/provenance/tests/test_utils.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+from os.path import join
+from typing import Dict, List
+
+from click.testing import CliRunner, Result
+from yaml import safe_dump
+
+from swh.provenance.cli import cli
+
+
+def invoke(args: List[str], config_path: str, catch_exceptions: bool = False) -> Result:
+    """Invoke swh provenance subcommands"""
+ runner = CliRunner()
+ result = runner.invoke(cli, ["-C" + config_path] + args)
+ if not catch_exceptions and result.exception:
+ print(result.output)
+ raise result.exception
+ return result
+
+
+def write_configuration_path(config: Dict, tmp_path: str) -> str:
+    """Serialize yaml dict on disk given a configuration dict and a temporary path."""
+ config_path = join(str(tmp_path), "config.yml")
+ with open(config_path, "w") as f:
+ f.write(safe_dump(config))
+ return config_path
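
A hedged sketch of how these two helpers combine, as the tests above do; the configuration values, CSV file name and temporary directory are placeholders:

# Illustrative use of write_configuration_path() and invoke(); every value
# below is a placeholder.
from swh.provenance.tests.test_utils import invoke, write_configuration_path

cfg = {
    "provenance": {
        "archive": {"cls": "api", "storage": {"cls": "remote", "url": "http://localhost:5002"}},
        "storage": {"cls": "postgresql", "db": {"host": "localhost", "dbname": "provenance"}},
    },
}
config_path = write_configuration_path(cfg, "/tmp")
result = invoke(["origin", "from-csv", "origins.csv"], config_path, catch_exceptions=True)
print(result.exit_code, result.output)
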
