Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/postgresql/provenancedb_base.py
from datetime import datetime | from datetime import datetime | ||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from typing import Any, Dict, Generator, List, Optional, Set, Tuple | from typing import Any, Dict, Generator, List, Optional, Set, Tuple | ||||
import psycopg2 | import psycopg2 | ||||
import psycopg2.extras | import psycopg2.extras | ||||
from swh.model.model import Sha1Git | |||||
class ProvenanceDBBase: | class ProvenanceDBBase: | ||||
def __init__(self, conn: psycopg2.extensions.connection): | def __init__(self, conn: psycopg2.extensions.connection): | ||||
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) | conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) | ||||
conn.set_session(autocommit=True) | conn.set_session(autocommit=True) | ||||
self.conn = conn | self.conn = conn | ||||
self.cursor = self.conn.cursor() | self.cursor = self.conn.cursor() | ||||
# XXX: not sure this is the best place to do it! | # XXX: not sure this is the best place to do it! | ||||
Show All 29 Lines | def commit(self, data: Dict[str, Any], raise_on_commit: bool = False) -> bool: | ||||
# Relations should come after ids for entities were resolved | # Relations should come after ids for entities were resolved | ||||
for relation in ( | for relation in ( | ||||
"content_in_revision", | "content_in_revision", | ||||
"content_in_directory", | "content_in_directory", | ||||
"directory_in_revision", | "directory_in_revision", | ||||
): | ): | ||||
self.insert_relation(relation, data[relation]) | self.insert_relation(relation, data[relation]) | ||||
# Insert origins | |||||
self.insert_origin( | |||||
{ | |||||
sha1: data["origin"]["data"][sha1] | |||||
for sha1 in data["origin"]["added"] | |||||
}, | |||||
) | |||||
data["origin"]["data"].clear() | |||||
data["origin"]["added"].clear() | |||||
# Insert relations from the origin-revision layer | # Insert relations from the origin-revision layer | ||||
self.insert_origin_head(data["revision_in_origin"]) | self.insert_origin_head(data["revision_in_origin"]) | ||||
self.insert_revision_history(data["revision_before_revision"]) | self.insert_revision_history(data["revision_before_revision"]) | ||||
# Update preferred origins | # Update preferred origins | ||||
self.update_preferred_origin( | self.update_preferred_origin( | ||||
{ | { | ||||
sha1: data["revision_preferred_origin"]["data"][sha1] | sha1: data["revision_origin"]["data"][sha1] | ||||
for sha1 in data["revision_preferred_origin"]["added"] | for sha1 in data["revision_origin"]["added"] | ||||
} | } | ||||
) | ) | ||||
data["revision_preferred_origin"]["data"].clear() | data["revision_origin"]["data"].clear() | ||||
data["revision_preferred_origin"]["added"].clear() | data["revision_origin"]["added"].clear() | ||||
return True | return True | ||||
except: # noqa: E722 | except: # noqa: E722 | ||||
# Unexpected error occurred, rollback all changes and log message | # Unexpected error occurred, rollback all changes and log message | ||||
logging.exception("Unexpected error") | logging.exception("Unexpected error") | ||||
if raise_on_commit: | if raise_on_commit: | ||||
raise | raise | ||||
return False | return False | ||||
def content_find_first( | def content_find_first( | ||||
self, blob: bytes | self, id: Sha1Git | ||||
) -> Optional[Tuple[bytes, bytes, datetime, bytes]]: | ) -> Optional[Tuple[Sha1Git, Sha1Git, datetime, bytes]]: | ||||
... | ... | ||||
def content_find_all( | def content_find_all( | ||||
self, blob: bytes, limit: Optional[int] = None | self, id: Sha1Git, limit: Optional[int] = None | ||||
) -> Generator[Tuple[bytes, bytes, datetime, bytes], None, None]: | ) -> Generator[Tuple[Sha1Git, Sha1Git, datetime, bytes], None, None]: | ||||
... | ... | ||||
def get_dates(self, entity: str, ids: List[bytes]) -> Dict[bytes, datetime]: | def get_dates(self, entity: str, ids: List[Sha1Git]) -> Dict[Sha1Git, datetime]: | ||||
dates = {} | dates = {} | ||||
if ids: | if ids: | ||||
values = ", ".join(itertools.repeat("%s", len(ids))) | values = ", ".join(itertools.repeat("%s", len(ids))) | ||||
self.cursor.execute( | self.cursor.execute( | ||||
f"""SELECT sha1, date FROM {entity} WHERE sha1 IN ({values})""", | f"""SELECT sha1, date FROM {entity} WHERE sha1 IN ({values})""", | ||||
tuple(ids), | tuple(ids), | ||||
) | ) | ||||
dates.update(self.cursor.fetchall()) | dates.update(self.cursor.fetchall()) | ||||
return dates | return dates | ||||
def insert_entity(self, entity: str, data: Dict[bytes, datetime]): | def insert_entity(self, entity: str, data: Dict[Sha1Git, datetime]): | ||||
if data: | if data: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
self.cursor, | self.cursor, | ||||
f""" | f""" | ||||
LOCK TABLE ONLY {entity}; | LOCK TABLE ONLY {entity}; | ||||
INSERT INTO {entity}(sha1, date) VALUES %s | INSERT INTO {entity}(sha1, date) VALUES %s | ||||
ON CONFLICT (sha1) DO | ON CONFLICT (sha1) DO | ||||
UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) | UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) | ||||
""", | """, | ||||
data.items(), | data.items(), | ||||
) | ) | ||||
# XXX: not sure if Python takes a reference or a copy. | # XXX: not sure if Python takes a reference or a copy. | ||||
# This might be useless! | # This might be useless! | ||||
data.clear() | data.clear() | ||||
def insert_origin_head(self, data: Set[Tuple[bytes, str]]): | def insert_origin(self, data: Dict[Sha1Git, str]): | ||||
if data: | |||||
psycopg2.extras.execute_values( | |||||
self.cursor, | |||||
""" | |||||
LOCK TABLE ONLY origin; | |||||
INSERT INTO origin(sha1, url) VALUES %s | |||||
ON CONFLICT DO NOTHING | |||||
""", | |||||
data.items(), | |||||
) | |||||
# XXX: not sure if Python takes a reference or a copy. | |||||
# This might be useless! | |||||
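vlorentz: It's a reference. Python passes the dict object itself rather than a copy, so `data.clear()` empties the very dict the caller passed in. | |||||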
aeviso (marked as done): Thanks for clarifying! I've removed this in a newer version that I'm about to submit as a diff anyway. | |||||
data.clear() | |||||
def insert_origin_head(self, data: Set[Tuple[Sha1Git, Sha1Git]]): | |||||
if data: | if data: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
self.cursor, | self.cursor, | ||||
# XXX: not clear how conflicts are handled here! | # XXX: not clear how conflicts are handled here! | ||||
""" | """ | ||||
LOCK TABLE ONLY revision_in_origin; | LOCK TABLE ONLY revision_in_origin; | ||||
INSERT INTO revision_in_origin | INSERT INTO revision_in_origin | ||||
SELECT R.id, O.id | SELECT R.id, O.id | ||||
FROM (VALUES %s) AS V(rev, org) | FROM (VALUES %s) AS V(rev, org) | ||||
INNER JOIN revision AS R on (R.sha1=V.rev) | INNER JOIN revision AS R on (R.sha1=V.rev) | ||||
INNER JOIN origin AS O on (O.url=V.org::unix_path) | INNER JOIN origin AS O on (O.sha1=V.org) | ||||
""", | """, | ||||
data, | data, | ||||
) | ) | ||||
data.clear() | data.clear() | ||||
def insert_relation(self, relation: str, data: Set[Tuple[bytes, bytes, bytes]]): | def insert_relation(self, relation: str, data: Set[Tuple[Sha1Git, Sha1Git, bytes]]): | ||||
... | ... | ||||
def insert_revision_history(self, data: Dict[bytes, bytes]): | def insert_revision_history(self, data: Dict[Sha1Git, Sha1Git]): | ||||
if data: | if data: | ||||
values = [[(prev, next) for next in data[prev]] for prev in data] | values = [[(prev, next) for next in data[prev]] for prev in data] | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
self.cursor, | self.cursor, | ||||
# XXX: not clear how conflicts are handled here! | # XXX: not clear how conflicts are handled here! | ||||
""" | """ | ||||
LOCK TABLE ONLY revision_before_revision; | LOCK TABLE ONLY revision_before_revision; | ||||
INSERT INTO revision_before_revision | INSERT INTO revision_before_revision | ||||
SELECT P.id, N.id | SELECT P.id, N.id | ||||
FROM (VALUES %s) AS V(prev, next) | FROM (VALUES %s) AS V(prev, next) | ||||
INNER JOIN revision AS P on (P.sha1=V.prev) | INNER JOIN revision AS P on (P.sha1=V.prev) | ||||
INNER JOIN revision AS N on (N.sha1=V.next) | INNER JOIN revision AS N on (N.sha1=V.next) | ||||
""", | """, | ||||
tuple(sum(values, [])), | tuple(sum(values, [])), | ||||
) | ) | ||||
data.clear() | data.clear() | ||||
def revision_get_preferred_origin(self, revision: bytes) -> Optional[str]: | def revision_get_preferred_origin(self, revision: Sha1Git) -> Optional[Sha1Git]: | ||||
self.cursor.execute( | self.cursor.execute( | ||||
""" | """ | ||||
SELECT O.url | SELECT O.sha1 | ||||
FROM revision AS R | FROM revision AS R | ||||
JOIN origin as O | JOIN origin as O | ||||
ON R.origin=O.id | ON R.origin=O.id | ||||
WHERE R.sha1=%s""", | WHERE R.sha1=%s""", | ||||
(revision,), | (revision,), | ||||
) | ) | ||||
row = self.cursor.fetchone() | row = self.cursor.fetchone() | ||||
return str(row[0], encoding="utf-8") if row is not None else None | return row[0] if row is not None else None | ||||
def revision_in_history(self, revision: bytes) -> bool: | def revision_in_history(self, revision: Sha1Git) -> bool: | ||||
self.cursor.execute( | self.cursor.execute( | ||||
""" | """ | ||||
SELECT 1 | SELECT 1 | ||||
FROM revision_before_revision | FROM revision_before_revision | ||||
JOIN revision | JOIN revision | ||||
ON revision.id=revision_before_revision.prev | ON revision.id=revision_before_revision.prev | ||||
WHERE revision.sha1=%s | WHERE revision.sha1=%s | ||||
""", | """, | ||||
(revision,), | (revision,), | ||||
) | ) | ||||
return self.cursor.fetchone() is not None | return self.cursor.fetchone() is not None | ||||
def revision_visited(self, revision: bytes) -> bool: | def revision_visited(self, revision: Sha1Git) -> bool: | ||||
self.cursor.execute( | self.cursor.execute( | ||||
""" | """ | ||||
SELECT 1 | SELECT 1 | ||||
FROM revision_in_origin | FROM revision_in_origin | ||||
JOIN revision | JOIN revision | ||||
ON revision.id=revision_in_origin.revision | ON revision.id=revision_in_origin.revision | ||||
WHERE revision.sha1=%s | WHERE revision.sha1=%s | ||||
""", | """, | ||||
(revision,), | (revision,), | ||||
) | ) | ||||
return self.cursor.fetchone() is not None | return self.cursor.fetchone() is not None | ||||
def update_preferred_origin(self, data: Dict[bytes, str]): | def update_preferred_origin(self, data: Dict[Sha1Git, Sha1Git]): | ||||
if data: | if data: | ||||
# XXX: this is assuming the revision already exists in the db! It should | # XXX: this is assuming the revision already exists in the db! It should | ||||
# be improved by allowing null dates in the revision table. | # be improved by allowing null dates in the revision table. | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
self.cursor, | self.cursor, | ||||
""" | """ | ||||
UPDATE revision | UPDATE revision R | ||||
SET origin=O.id | SET origin=O.id | ||||
FROM (VALUES %s) AS V(rev, org) | FROM (VALUES %s) AS V(rev, org) | ||||
INNER JOIN origin AS O on (O.url=V.org::unix_path) | INNER JOIN origin AS O on (O.sha1=V.org) | ||||
WHERE sha1=V.rev | WHERE R.sha1=V.rev | ||||
""", | """, | ||||
data.items(), | data.items(), | ||||
) | ) | ||||
data.clear() | data.clear() |
vlorentz: It's a reference (Python passes the object itself, not a copy).