diff --git a/swh/provenance/postgresql/archive.py b/swh/provenance/postgresql/archive.py --- a/swh/provenance/postgresql/archive.py +++ b/swh/provenance/postgresql/archive.py @@ -63,8 +63,7 @@ (id,), ) return [ - {"type": row[0], "target": row[1], "name": row[2]} - for row in cursor.fetchall() + {"type": row[0], "target": row[1], "name": row[2]} for row in cursor ] def revision_get_parents(self, id: Sha1Git) -> Iterable[Sha1Git]: @@ -79,7 +78,7 @@ (id,), ) # There should be at most one row anyway - yield from (row[0] for row in cursor.fetchall()) + yield from (row[0] for row in cursor) def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]: with self.conn.cursor() as cursor: @@ -114,4 +113,4 @@ """, (id,), ) - yield from (row[0] for row in cursor.fetchall()) + yield from (row[0] for row in cursor) diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from contextlib import contextmanager from datetime import datetime import itertools import logging @@ -31,22 +32,27 @@ self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False ) -> None: BaseDb.adapt_conn(conn) - conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) - conn.set_session(autocommit=True) self.conn = conn - self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) - # XXX: not sure this is the best place to do it! - sql = "SET timezone TO 'UTC'" - self.cursor.execute(sql) + with self.transaction() as cursor: + cursor.execute("SET timezone TO 'UTC'") self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit + @contextmanager + def transaction( + self, readonly: bool = False + ) -> Generator[psycopg2.extensions.cursor, None, None]: + self.conn.set_session(readonly=readonly) + with self.conn: + with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + yield cur + @property def flavor(self) -> str: if self._flavor is None: - sql = "SELECT swh_get_dbflavor() AS flavor" - self.cursor.execute(sql) - self._flavor = self.cursor.fetchone()["flavor"] + with self.transaction(readonly=True) as cursor: + cursor.execute("SELECT swh_get_dbflavor() AS flavor") + self._flavor = cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @@ -56,16 +62,18 @@ def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" - self.cursor.execute(sql, (id,)) - row = self.cursor.fetchone() + with self.transaction(readonly=True) as cursor: + cursor.execute(query=sql, vars=(id,)) + row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" - self.cursor.execute(sql, (id, limit)) - yield from (ProvenanceResult(**row) for row in self.cursor.fetchall()) + with self.transaction(readonly=True) as cursor: + cursor.execute(query=sql, vars=(id, limit)) + yield from (ProvenanceResult(**row) for row in cursor) def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: return self._entity_set_date("content", dates) @@ -80,24 +88,26 @@ return self._entity_get_date("directory", ids) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: - sql = f"SELECT sha1 FROM {entity.value}" - self.cursor.execute(sql) - return {row["sha1"] for row in self.cursor.fetchall()} + with self.transaction(readonly=True) as cursor: + cursor.execute(f"SELECT sha1 FROM {entity.value}") + return {row["sha1"] for row in cursor} def location_get(self) -> Set[bytes]: - sql = "SELECT location.path AS path FROM location" - self.cursor.execute(sql) - return {row["path"] for row in self.cursor.fetchall()} + with self.transaction(readonly=True) as cursor: + cursor.execute("SELECT location.path AS path FROM location") + return {row["path"] for row in cursor} def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: try: if urls: sql = """ - LOCK TABLE ONLY origin; INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ - psycopg2.extras.execute_values(self.cursor, sql, urls.items()) + with self.transaction() as cursor: + psycopg2.extras.execute_values( + cur=cursor, sql=sql, argslist=urls.items() + ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message @@ -117,8 +127,9 @@ FROM origin WHERE sha1 IN ({values}) """ - self.cursor.execute(sql, sha1s) - urls.update((row["sha1"], row["url"]) for row in self.cursor.fetchall()) + with self.transaction(readonly=True) as cursor: + cursor.execute(query=sql, vars=sha1s) + urls.update((row["sha1"], row["url"]) for row in cursor) return urls def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: @@ -128,7 +139,6 @@ try: if origins: sql = """ - LOCK TABLE ONLY revision; INSERT INTO revision(sha1, origin) (SELECT V.rev AS sha1, O.id AS origin FROM (VALUES %s) AS V(rev, org) @@ -136,7 +146,10 @@ ON CONFLICT (sha1) DO UPDATE SET origin=EXCLUDED.origin """ - psycopg2.extras.execute_values(self.cursor, sql, origins.items()) + with self.transaction() as cursor: + psycopg2.extras.execute_values( + cur=cursor, sql=sql, argslist=origins.items() + ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message @@ -157,11 +170,12 @@ LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) """ - self.cursor.execute(sql, sha1s) - result.update( - (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) - for row in self.cursor.fetchall() - ) + with self.transaction(readonly=True) as cursor: + cursor.execute(query=sql, vars=sha1s) + result.update( + (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) + for row in cursor + ) return result def relation_add( @@ -178,38 +192,38 @@ # non-null information srcs = tuple(set((sha1,) for (sha1, _, _) in rows)) sql = f""" - LOCK TABLE ONLY {src_table}; INSERT INTO {src_table}(sha1) VALUES %s ON CONFLICT DO NOTHING """ - psycopg2.extras.execute_values(self.cursor, sql, srcs) + with self.transaction() as cursor: + psycopg2.extras.execute_values( + cur=cursor, sql=sql, argslist=srcs + ) if dst_table != "origin": # Origin entries should be inserted previously as they require extra # non-null information dsts = tuple(set((sha1,) for (_, sha1, _) in rows)) sql = f""" - LOCK TABLE ONLY {dst_table}; INSERT INTO {dst_table}(sha1) VALUES %s ON CONFLICT DO NOTHING """ - psycopg2.extras.execute_values(self.cursor, sql, dsts) + with self.transaction() as cursor: + psycopg2.extras.execute_values( + cur=cursor, sql=sql, argslist=dsts + ) # Put the next three queries in a manual single transaction: # they use the same temp table - with self.conn: - with self.conn.cursor() as cur: - cur.execute("SELECT swh_mktemp_relation_add()") - psycopg2.extras.execute_values( - cur, - sql=( - "INSERT INTO tmp_relation_add (src, dst, path) " - "VALUES %s" - ), - argslist=rows, - ) - sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" - cur.execute(sql, (rel_table, src_table, dst_table)) + with self.transaction() as cursor: + cursor.execute("SELECT swh_mktemp_relation_add()") + psycopg2.extras.execute_values( + cur=cursor, + sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", + argslist=rows, + ) + sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" + cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message @@ -242,8 +256,9 @@ WHERE sha1 IN ({values}) AND date IS NOT NULL """ - self.cursor.execute(sql, sha1s) - dates.update((row["sha1"], row["date"]) for row in self.cursor.fetchall()) + with self.transaction(readonly=True) as cursor: + cursor.execute(query=sql, vars=sha1s) + dates.update((row["sha1"], row["date"]) for row in cursor) return dates def _entity_set_date( @@ -254,12 +269,12 @@ try: if data: sql = f""" - LOCK TABLE ONLY {entity}; INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ - psycopg2.extras.execute_values(self.cursor, sql, data.items()) + with self.transaction() as cursor: + psycopg2.extras.execute_values(cursor, sql, argslist=data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message @@ -289,8 +304,11 @@ src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" - self.cursor.execute(sql, (rel_table, src_table, dst_table, filter, sha1s)) - result.update(RelationData(**row) for row in self.cursor.fetchall()) + with self.transaction(readonly=True) as cursor: + cursor.execute( + query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) + ) + result.update(RelationData(**row) for row in cursor) return result def with_path(self) -> bool: diff --git a/swh/provenance/sql/40-funcs.sql b/swh/provenance/sql/40-funcs.sql --- a/swh/provenance/sql/40-funcs.sql +++ b/swh/provenance/sql/40-funcs.sql @@ -99,7 +99,6 @@ join_location text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - lock table only location; insert into location(path) select V.path from tmp_relation_add as V @@ -113,15 +112,14 @@ end if; execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || join_location || ' on conflict do nothing', - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; @@ -254,14 +252,13 @@ as $$ begin execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, D.id from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) on conflict do nothing', - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; @@ -422,7 +419,6 @@ on_conflict text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - lock table only location; insert into location(path) select V.path from tmp_relation_add as V @@ -448,8 +444,7 @@ end if; execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) @@ -457,7 +452,7 @@ ' || join_location || ' ' || group_entries || ' on conflict ' || on_conflict, - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$; @@ -641,15 +636,14 @@ end if; execute format( - 'lock table only %s; - insert into %s + 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || group_entries || ' on conflict ' || on_conflict, - rel_table, rel_table, src_table, dst_table + rel_table, src_table, dst_table ); end; $$;