diff --git a/swh/provenance/postgresql/provenancedb_with_path.py b/swh/provenance/postgresql/provenancedb_with_path.py --- a/swh/provenance/postgresql/provenancedb_with_path.py +++ b/swh/provenance/postgresql/provenancedb_with_path.py @@ -1,6 +1,4 @@ from datetime import datetime -import itertools -import operator import os from typing import Generator, Optional, Tuple @@ -139,50 +137,48 @@ ) def insert_location(self, src0_table, src1_table, dst_table): + """Insert location entries in `dst_table` from the insert_cache + + Also insert missing location entries in the 'location' table. + """ + # TODO: find a better way of doing this; might be doable in a couple of + # SQL queries (one to insert missing entries in the 'location' table, + # one to insert entries in the dst_table) + # Resolve src0 ids - src0_values = dict().fromkeys( - map(operator.itemgetter(0), self.insert_cache[dst_table]) - ) - values = ", ".join(itertools.repeat("%s", len(src0_values))) + src0_sha1s = tuple(set(sha1 for (sha1, _, _) in self.insert_cache[dst_table])) self.cursor.execute( - f"""SELECT sha1, id FROM {src0_table} WHERE sha1 IN ({values})""", - tuple(src0_values), + f"""SELECT sha1, id FROM {src0_table} WHERE sha1 IN (%s)""", src0_sha1s, ) - src0_values = dict(self.cursor.fetchall()) + src0_values = {bytes(sha1): id for (sha1, id) in self.cursor.fetchall()} # Resolve src1 ids - src1_values = dict().fromkeys( - map(operator.itemgetter(1), self.insert_cache[dst_table]) - ) - values = ", ".join(itertools.repeat("%s", len(src1_values))) + src1_sha1s = tuple(set(sha1 for (_, sha1, _) in self.insert_cache[dst_table])) self.cursor.execute( - f"""SELECT sha1, id FROM {src1_table} WHERE sha1 IN ({values})""", - tuple(src1_values), + f"""SELECT sha1, id FROM {src1_table} WHERE sha1 IN (%s)""", src1_sha1s, ) - src1_values = dict(self.cursor.fetchall()) + src1_values = {bytes(sha1): id for (sha1, id) in self.cursor.fetchall()} - # Resolve location ids - location = dict().fromkeys( -
map(operator.itemgetter(2), self.insert_cache[dst_table]) + # insert missing locations + locations = tuple(set((loc,) for (_, _, loc) in self.insert_cache[dst_table])) + self.cursor.execute( + """ + INSERT INTO location(path) VALUES %s + ON CONFLICT (path) DO NOTHING + """, + locations, ) - location = dict( - psycopg2.extras.execute_values( - self.cursor, - """LOCK TABLE ONLY location; - INSERT INTO location(path) VALUES %s - ON CONFLICT (path) DO - UPDATE SET path=EXCLUDED.path - RETURNING path, id""", - map(lambda path: (path,), location.keys()), - fetch=True, - ) + # fetch location ids + self.cursor.execute( + "SELECT path, id FROM location WHERE path IN (%s)", locations, ) + loc_ids = {bytes(path): id for (path, id) in self.cursor.fetchall()} # Insert values in dst_table - rows = map( - lambda row: (src0_values[row[0]], src1_values[row[1]], location[row[2]]), - self.insert_cache[dst_table], - ) + rows = [ + (src0_values[sha1_src], src1_values[sha1_dst], loc_ids[loc]) + for (sha1_src, sha1_dst, loc) in self.insert_cache[dst_table] + ] psycopg2.extras.execute_values( self.cursor, f"""INSERT INTO {dst_table} VALUES %s diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone import os from typing import Dict, Generator, List, Optional, Tuple @@ -9,6 +9,8 @@ from .origin import OriginEntry from .revision import RevisionEntry +UTCMIN = datetime.min.replace(tzinfo=timezone.utc) + # TODO: consider moving to path utils file together with normalize. def is_child(path: bytes, prefix: bytes) -> bool: @@ -277,11 +279,12 @@ # Recursively analyse directory nodes. 
stack.append(child) else: - maxdates = [] - for child in current.children: - assert child.maxdate is not None - maxdates.append(child.maxdate) - current.maxdate = max(maxdates) if maxdates else datetime.min + maxdates = [ + child.maxdate + for child in current.children + if child.maxdate is not None # mostly to please mypy + ] + current.maxdate = max(maxdates) if maxdates else UTCMIN else: # Directory node in the frontier, just use its known date. current.maxdate = current.date @@ -314,9 +317,7 @@ ) provenance.directory_add_to_revision(revision, current.entry, path) directory_process_content( - provenance, - directory=current.entry, - relative=current.entry, + provenance, directory=current.entry, relative=current.entry, ) else: # No point moving the frontier here. Either there are no files or they diff --git a/swh/provenance/revision.py b/swh/provenance/revision.py --- a/swh/provenance/revision.py +++ b/swh/provenance/revision.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from itertools import islice import threading from typing import Iterable, Iterator, Optional, Tuple @@ -20,6 +20,7 @@ self.archive = archive self.id = id self.date = date + assert self.date is None or self.date.tzinfo is not None self.parents = parents self.root = root @@ -78,11 +79,11 @@ def __next__(self): with self.mutex: id, date, root = next(self.revisions) + date = datetime.fromisoformat(date) + if date.tzinfo is None: + date = date.replace(tzinfo=timezone.utc) return RevisionEntry( - self.archive, - hash_to_bytes(id), - date=datetime.fromisoformat(date), - root=hash_to_bytes(root), + self.archive, hash_to_bytes(id), date=date, root=hash_to_bytes(root), ) diff --git a/swh/provenance/tests/test_provenance_db.py b/swh/provenance/tests/test_provenance_db.py new file mode 100644 --- /dev/null +++ b/swh/provenance/tests/test_provenance_db.py @@ -0,0 +1,71 @@ +# Copyright (C) 2021 The Software Heritage developers +# See the AUTHORS file at the 
top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import datetime + +from swh.model.model import TimestampWithTimezone +from swh.model.tests.swh_model_data import TEST_OBJECTS +from swh.provenance.origin import OriginEntry +from swh.provenance.postgresql.archive import ArchivePostgreSQL +from swh.provenance.provenance import origin_add, revision_add +from swh.provenance.revision import RevisionEntry + + +def ts2dt(ts: TimestampWithTimezone) -> datetime.datetime: + timestamp = datetime.datetime.fromtimestamp( + ts.timestamp.seconds, datetime.timezone.utc + ) + return timestamp.replace(microsecond=ts.timestamp.microseconds) + + +def test_provenance_origin_add(provenance, swh_storage_with_objects): + """Test the ProvenanceDB.origin_add() method""" + for origin in TEST_OBJECTS["origin"]: + entry = OriginEntry(url=origin.url, revisions=[]) + origin_add(provenance, entry) + # TODO: check some facts here + + +def test_provenance_revision_add(provenance, swh_storage_with_objects): + """Test the ProvenanceDB.revision_add() method""" + + archive = ArchivePostgreSQL(conn=swh_storage_with_objects.get_db().conn) + + for revision in TEST_OBJECTS["revision"]: + entry = RevisionEntry( + archive, + id=revision.id, + date=ts2dt(revision.date), + root=revision.directory, + parents=revision.parents, + ) + revision_add(provenance, archive, entry) + + # there should be only one 'location' for the empty path + provenance.cursor.execute("SELECT count(*) FROM location WHERE path=''") + assert provenance.cursor.fetchone()[0] == 1 + + # there should be as many entries in 'revision' as revisions from the test dataset + provenance.cursor.execute("SELECT count(*) FROM revision") + assert provenance.cursor.fetchone()[0] == len(TEST_OBJECTS["revision"]) + + # there should be as many entries in 'directory' as revisions from the test dataset + # WARNING: this results from the inserted 
revisions not having subdirectories, + # thus can fail if the test dataset is improved + provenance.cursor.execute("SELECT count(*) FROM directory") + assert provenance.cursor.fetchone()[0] == len(TEST_OBJECTS["revision"]) + + # there should be as many entries in 'directory_in_rev' as revs in the test dataset + # WARNING: same as above + provenance.cursor.execute("SELECT count(*) FROM directory_in_rev") + assert provenance.cursor.fetchone()[0] == len(TEST_OBJECTS["revision"]) + + # there should be no content (directory sha1s are fake) + provenance.cursor.execute("SELECT count(*) FROM content") + assert provenance.cursor.fetchone()[0] == 0 + provenance.cursor.execute("SELECT count(*) FROM content_in_dir") + assert provenance.cursor.fetchone()[0] == 0 + provenance.cursor.execute("SELECT count(*) FROM content_early_in_rev") + assert provenance.cursor.fetchone()[0] == 0