diff --git a/swh/provenance/postgresql/archive.py b/swh/provenance/postgresql/archive.py
--- a/swh/provenance/postgresql/archive.py
+++ b/swh/provenance/postgresql/archive.py
@@ -17,46 +17,46 @@
     @lru_cache(maxsize=1000000)
     def directory_ls_internal(self, id: bytes) -> List[Dict[str, Any]]:
         # TODO: add file size filtering
-        cursor = self.conn.cursor()
-        cursor.execute(
-            """WITH
-            dir AS (SELECT id AS dir_id, dir_entries, file_entries, rev_entries
-                    FROM directory WHERE id=%s),
-            ls_d AS (SELECT dir_id, UNNEST(dir_entries) AS entry_id FROM dir),
-            ls_f AS (SELECT dir_id, UNNEST(file_entries) AS entry_id FROM dir),
-            ls_r AS (SELECT dir_id, UNNEST(rev_entries) AS entry_id FROM dir)
-            (SELECT 'dir'::directory_entry_type AS type, e.target, e.name,
-                    NULL::sha1_git
-             FROM ls_d
-             LEFT JOIN directory_entry_dir e ON ls_d.entry_id=e.id)
-            UNION
-            (WITH known_contents AS
-                (SELECT 'file'::directory_entry_type AS type, e.target, e.name,
-                        c.sha1_git
-                 FROM ls_f
-                 LEFT JOIN directory_entry_file e ON ls_f.entry_id=e.id
-                 INNER JOIN content c ON e.target=c.sha1_git)
-                SELECT * FROM known_contents
+        with self.conn.cursor() as cursor:
+            cursor.execute(
+                """WITH
+                dir AS (SELECT id AS dir_id, dir_entries, file_entries, rev_entries
+                        FROM directory WHERE id=%s),
+                ls_d AS (SELECT dir_id, UNNEST(dir_entries) AS entry_id FROM dir),
+                ls_f AS (SELECT dir_id, UNNEST(file_entries) AS entry_id FROM dir),
+                ls_r AS (SELECT dir_id, UNNEST(rev_entries) AS entry_id FROM dir)
+                (SELECT 'dir'::directory_entry_type AS type, e.target, e.name,
+                        NULL::sha1_git
+                 FROM ls_d
+                 LEFT JOIN directory_entry_dir e ON ls_d.entry_id=e.id)
                 UNION
-                (SELECT 'file'::directory_entry_type AS type, e.target, e.name,
-                        c.sha1_git
-                 FROM ls_f
-                 LEFT JOIN directory_entry_file e ON ls_f.entry_id=e.id
-                 LEFT JOIN skipped_content c ON e.target=c.sha1_git
-                 WHERE NOT EXISTS (
-                    SELECT 1 FROM known_contents
-                    WHERE known_contents.sha1_git=e.target
+                (WITH known_contents AS
+                    (SELECT 'file'::directory_entry_type AS type, e.target, e.name,
+                            c.sha1_git
+                     FROM ls_f
+                     LEFT JOIN directory_entry_file e ON ls_f.entry_id=e.id
+                     INNER JOIN content c ON e.target=c.sha1_git)
+                    SELECT * FROM known_contents
+                    UNION
+                    (SELECT 'file'::directory_entry_type AS type, e.target, e.name,
+                            c.sha1_git
+                     FROM ls_f
+                     LEFT JOIN directory_entry_file e ON ls_f.entry_id=e.id
+                     LEFT JOIN skipped_content c ON e.target=c.sha1_git
+                     WHERE NOT EXISTS (
+                        SELECT 1 FROM known_contents
+                        WHERE known_contents.sha1_git=e.target
+                        )
                     )
                 )
+                ORDER BY name
+                """,
+                (id,),
             )
-            ORDER BY name
-            """,
-            (id,),
-        )
-        return [
-            {"type": row[0], "target": row[1], "name": row[2]}
-            for row in cursor.fetchall()
-        ]
+            return [
+                {"type": row[0], "target": row[1], "name": row[2]}
+                for row in cursor.fetchall()
+            ]

     def iter_origins(self):
         raise NotImplementedError
diff --git a/swh/provenance/postgresql/provenancedb_with_path.py b/swh/provenance/postgresql/provenancedb_with_path.py
--- a/swh/provenance/postgresql/provenancedb_with_path.py
+++ b/swh/provenance/postgresql/provenancedb_with_path.py
@@ -1,6 +1,4 @@
 from datetime import datetime
-import itertools
-import operator
 import os
 from typing import Generator, Optional, Tuple
@@ -139,50 +137,48 @@
         )

     def insert_location(self, src0_table, src1_table, dst_table):
+        """Insert location entries in `dst_table` from the insert_cache
+
+        Also insert missing location entries in the 'location' table.
+        """
+        # TODO: find a better way of doing this; might be doable in a couple of
+        # SQL queries (one to insert missing entries in the 'location' table,
+        # one to insert entries in the dst_table)
+
         # Resolve src0 ids
-        src0_values = dict().fromkeys(
-            map(operator.itemgetter(0), self.insert_cache[dst_table])
-        )
-        values = ", ".join(itertools.repeat("%s", len(src0_values)))
+        src0_sha1s = tuple(set(sha1 for (sha1, _, _) in self.insert_cache[dst_table]))
         self.cursor.execute(
-            f"""SELECT sha1, id FROM {src0_table} WHERE sha1 IN ({values})""",
-            tuple(src0_values),
+            f"""SELECT sha1, id FROM {src0_table} WHERE sha1 IN %s""", (src0_sha1s,),
         )
-        src0_values = dict(self.cursor.fetchall())
+        src0_values = {bytes(sha1): id for (sha1, id) in self.cursor.fetchall()}

         # Resolve src1 ids
-        src1_values = dict().fromkeys(
-            map(operator.itemgetter(1), self.insert_cache[dst_table])
-        )
-        values = ", ".join(itertools.repeat("%s", len(src1_values)))
+        src1_sha1s = tuple(set(sha1 for (_, sha1, _) in self.insert_cache[dst_table]))
         self.cursor.execute(
-            f"""SELECT sha1, id FROM {src1_table} WHERE sha1 IN ({values})""",
-            tuple(src1_values),
+            f"""SELECT sha1, id FROM {src1_table} WHERE sha1 IN %s""", (src1_sha1s,),
        )
-        src1_values = dict(self.cursor.fetchall())
+        src1_values = {bytes(sha1): id for (sha1, id) in self.cursor.fetchall()}

-        # Resolve location ids
-        location = dict().fromkeys(
-            map(operator.itemgetter(2), self.insert_cache[dst_table])
+        # Insert missing locations
+        locations = tuple(set((loc,) for (_, _, loc) in self.insert_cache[dst_table]))
+        psycopg2.extras.execute_values(
+            self.cursor,
+            """INSERT INTO location(path) VALUES %s
+            ON CONFLICT (path) DO NOTHING
+            """,
+            locations,
         )
-        location = dict(
-            psycopg2.extras.execute_values(
-                self.cursor,
-                """LOCK TABLE ONLY location;
-                INSERT INTO location(path) VALUES %s
-                ON CONFLICT (path) DO
-                UPDATE SET path=EXCLUDED.path
-                RETURNING path, id""",
-                map(lambda path: (path,), location.keys()),
-                fetch=True,
-            )
+        # Fetch location ids
+        self.cursor.execute(
+            "SELECT path, id FROM location WHERE path IN %s", (locations,),
         )
+        loc_ids = {bytes(path): id for (path, id) in self.cursor.fetchall()}

         # Insert values in dst_table
-        rows = map(
-            lambda row: (src0_values[row[0]], src1_values[row[1]], location[row[2]]),
-            self.insert_cache[dst_table],
-        )
+        rows = [
+            (src0_values[sha1_src], src1_values[sha1_dst], loc_ids[loc])
+            for (sha1_src, sha1_dst, loc) in self.insert_cache[dst_table]
+        ]
         psycopg2.extras.execute_values(
             self.cursor,
             f"""INSERT INTO {dst_table} VALUES %s
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timezone
 import os
 from typing import Dict, Generator, List, Optional, Tuple
@@ -9,6 +9,8 @@
 from .origin import OriginEntry
 from .revision import RevisionEntry

+UTCMIN = datetime.min.replace(tzinfo=timezone.utc)
+

 # TODO: consider moving to path utils file together with normalize.
 def is_child(path: bytes, prefix: bytes) -> bool:
@@ -277,11 +279,12 @@
                     # Recursively analyse directory nodes.
                    stack.append(child)
                 else:
-                    maxdates = []
-                    for child in current.children:
-                        assert child.maxdate is not None
-                        maxdates.append(child.maxdate)
-                    current.maxdate = max(maxdates) if maxdates else datetime.min
+                    maxdates = [
+                        child.maxdate
+                        for child in current.children
+                        if child.maxdate is not None  # mostly to please mypy
+                    ]
+                    current.maxdate = max(maxdates) if maxdates else UTCMIN
             else:
                 # Directory node in the frontier, just use its known date.
                 current.maxdate = current.date
@@ -314,9 +317,7 @@
                 )
                 provenance.directory_add_to_revision(revision, current.entry, path)
                 directory_process_content(
-                    provenance,
-                    directory=current.entry,
-                    relative=current.entry,
+                    provenance, directory=current.entry, relative=current.entry,
                 )
             else:
                 # No point moving the frontier here. Either there are no files or they
diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py
--- a/swh/provenance/revision.py
+++ b/swh/provenance/revision.py
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timezone
 from itertools import islice
 import threading
 from typing import Iterable, Iterator, Optional, Tuple
@@ -20,6 +20,7 @@
         self.archive = archive
         self.id = id
         self.date = date
+        assert self.date is None or self.date.tzinfo is not None
         self.parents = parents
         self.root = root
@@ -78,11 +79,11 @@
     def __next__(self):
         with self.mutex:
             id, date, root = next(self.revisions)
+            date = datetime.fromisoformat(date)
+            if date.tzinfo is None:
+                date = date.replace(tzinfo=timezone.utc)
             return RevisionEntry(
-                self.archive,
-                hash_to_bytes(id),
-                date=datetime.fromisoformat(date),
-                root=hash_to_bytes(root),
+                self.archive, hash_to_bytes(id), date=date, root=hash_to_bytes(root),
             )
diff --git a/swh/provenance/tests/test_provenance_db.py b/swh/provenance/tests/test_provenance_db.py
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/test_provenance_db.py
@@ -0,0 +1,81 @@
+# Copyright (C) 2021  The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import datetime
+
+import pytest
+
+from swh.model.model import TimestampWithTimezone
+from swh.model.tests.swh_model_data import TEST_OBJECTS
+from swh.provenance.origin import OriginEntry
+from swh.provenance.postgresql.archive import ArchivePostgreSQL
+from swh.provenance.provenance import origin_add, revision_add
+from swh.provenance.revision import RevisionEntry
+
+
+def ts2dt(ts: TimestampWithTimezone) -> datetime.datetime:
+    timestamp = datetime.datetime.fromtimestamp(
+        ts.timestamp.seconds, datetime.timezone.utc
+    )
+    return timestamp.replace(microsecond=ts.timestamp.microseconds)
+
+
+def test_provenance_origin_add(provenance, swh_storage_with_objects):
+    """Test the origin_add() function"""
+    for origin in TEST_OBJECTS["origin"]:
+        entry = OriginEntry(url=origin.url, revisions=[])
+        origin_add(provenance, entry)
+    # TODO: check some facts here
+
+
+@pytest.fixture
+def archive_pg(swh_storage_with_objects):
+    # This is a workaround to prevent tests from hanging because of an unclosed
+    # transaction.
+    # TODO: refactor the ArchivePostgreSQL class to properly deal with
+    # transactions and get rid of this fixture
+    archive = ArchivePostgreSQL(conn=swh_storage_with_objects.get_db().conn)
+    yield archive
+    archive.conn.rollback()
+
+
+def test_provenance_revision_add(provenance, swh_storage_with_objects, archive_pg):
+    """Test the revision_add() function"""
+
+    for revision in TEST_OBJECTS["revision"]:
+        entry = RevisionEntry(
+            archive_pg,
+            id=revision.id,
+            date=ts2dt(revision.date),
+            root=revision.directory,
+            parents=revision.parents,
+        )
+        revision_add(provenance, archive_pg, entry)
+
+    # there should be only one 'location' for the empty path
+    provenance.cursor.execute("SELECT count(*) FROM location WHERE path=''")
+    assert provenance.cursor.fetchone()[0] == 1
+
+    # there should be as many entries in 'revision' as revisions in the test dataset
+    provenance.cursor.execute("SELECT count(*) FROM revision")
+    assert provenance.cursor.fetchone()[0] == len(TEST_OBJECTS["revision"])
+
+    # there should be as many entries in 'directory' as revisions in the test dataset
+    # WARNING: this holds because the inserted revisions have no subdirectories,
+    # so it can fail if the test dataset is improved
+    provenance.cursor.execute("SELECT count(*) FROM directory")
+    assert provenance.cursor.fetchone()[0] == len(TEST_OBJECTS["revision"])
+
+    # there should be as many entries in 'directory_in_rev' as revs in the test dataset
+    # WARNING: same as above
+    provenance.cursor.execute("SELECT count(*) FROM directory_in_rev")
+    assert provenance.cursor.fetchone()[0] == len(TEST_OBJECTS["revision"])
+
+    # there should be no content (directory sha1s are fake)
+    provenance.cursor.execute("SELECT count(*) FROM content")
+    assert provenance.cursor.fetchone()[0] == 0
+    provenance.cursor.execute("SELECT count(*) FROM content_in_dir")
+    assert provenance.cursor.fetchone()[0] == 0
+    provenance.cursor.execute("SELECT count(*) FROM content_early_in_rev")
+    assert provenance.cursor.fetchone()[0] == 0
diff --git a/swh/provenance/tests/test_revision_iterator.py b/swh/provenance/tests/test_revision_iterator.py
--- a/swh/provenance/tests/test_revision_iterator.py
+++ b/swh/provenance/tests/test_revision_iterator.py
@@ -2,25 +2,16 @@
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
-import datetime

-from swh.model.model import TimestampWithTimezone
 from swh.model.tests.swh_model_data import TEST_OBJECTS
 from swh.provenance.revision import CSVRevisionIterator
-
-
-def ts_to_dt(ts_with_tz: TimestampWithTimezone) -> datetime.datetime:
-    """converts a TimestampWithTimezone into a datetime"""
-    ts = ts_with_tz.timestamp
-    timestamp = datetime.datetime.fromtimestamp(ts.seconds, datetime.timezone.utc)
-    timestamp = timestamp.replace(microsecond=ts.microseconds)
-    return timestamp
+from swh.provenance.tests.test_provenance_db import ts2dt


 def test_archive_direct_revision_iterator(swh_storage_with_objects, archive_direct):
     """Test CSVRevisionIterator"""
     revisions_csv = [
-        (rev.id, ts_to_dt(rev.date).isoformat(), rev.directory)
+        (rev.id, ts2dt(rev.date).isoformat(), rev.directory)
         for rev in TEST_OBJECTS["revision"]
     ]
     revisions = list(CSVRevisionIterator(revisions_csv, archive_direct))
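
Note (reviewer sketch, not part of the diff): the TODO added in insert_location() says the cache flush "might be doable in a couple of SQL queries". Below is a rough, untested sketch of that idea. It assumes psycopg2's execute_values(), plus the table layout this diff relies on (a unique constraint on location.path, and sha1/id columns on the two lookup tables); the function name is hypothetical and the joins may need explicit casts depending on the actual column types.

import psycopg2.extras


def insert_location_two_queries(cursor, src0_table, src1_table, dst_table, rows):
    # rows is the content of insert_cache[dst_table]: (src0_sha1, src1_sha1, path)
    # 1) ensure every path exists in 'location' (same ON CONFLICT as the diff)
    psycopg2.extras.execute_values(
        cursor,
        "INSERT INTO location(path) VALUES %s ON CONFLICT (path) DO NOTHING",
        [(path,) for path in {p for (_, _, p) in rows}],
    )
    # 2) resolve all three ids and fill dst_table in a single statement by
    #    joining the cached triples against the lookup tables
    psycopg2.extras.execute_values(
        cursor,
        f"""INSERT INTO {dst_table}
            SELECT s0.id, s1.id, loc.id
            FROM (VALUES %s) AS v(sha1_src, sha1_dst, path)
            INNER JOIN {src0_table} AS s0 ON s0.sha1=v.sha1_src
            INNER JOIN {src1_table} AS s1 ON s1.sha1=v.sha1_dst
            INNER JOIN location AS loc ON loc.path=v.path""",
        rows,
    )

Compared with the current version, this would replace three id-resolution round trips and the Python-side dict lookups with one INSERT ... SELECT, at the cost of shipping the cached triples to the server twice.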