Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/storage/postgresql.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from __future__ import annotations | from __future__ import annotations | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from datetime import datetime | from datetime import datetime | ||||
from functools import wraps | from functools import wraps | ||||
from hashlib import sha1 | |||||
import itertools | import itertools | ||||
import logging | import logging | ||||
from types import TracebackType | from types import TracebackType | ||||
from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union | from typing import Dict, Generator, Iterable, List, Optional, Set, Type | ||||
import psycopg2.extensions | import psycopg2.extensions | ||||
import psycopg2.extras | import psycopg2.extras | ||||
from swh.core.db import BaseDb | from swh.core.db import BaseDb | ||||
from swh.core.statsd import statsd | from swh.core.statsd import statsd | ||||
from swh.model.model import Sha1Git | from swh.model.model import Sha1Git | ||||
from swh.provenance.storage.interface import ( | from swh.provenance.storage.interface import ( | ||||
▲ Show 20 Lines • Show All 191 Lines • ▼ Show 20 Lines | class ProvenanceStoragePostgreSql: | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) | ||||
def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: | def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(f"SELECT sha1 FROM {entity.value}") | cursor.execute(f"SELECT sha1 FROM {entity.value}") | ||||
return {row["sha1"] for row in cursor} | return {row["sha1"] for row in cursor} | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) | ||||
@handle_raise_on_commit | @handle_raise_on_commit | ||||
def location_add(self, paths: Iterable[bytes]) -> bool: | def location_add(self, paths: Dict[Sha1Git, bytes]) -> bool: | ||||
if self.with_path(): | if self.with_path(): | ||||
values = [(path,) for path in paths] | values = [(path,) for path in paths.values()] | ||||
if values: | if values: | ||||
sql = """ | sql = """ | ||||
INSERT INTO location(path) VALUES %s | INSERT INTO location(path) VALUES %s | ||||
ON CONFLICT DO NOTHING | ON CONFLICT DO NOTHING | ||||
""" | """ | ||||
page_size = self.page_size or len(values) | page_size = self.page_size or len(values) | ||||
with self.transaction() as cursor: | with self.transaction() as cursor: | ||||
psycopg2.extras.execute_values( | psycopg2.extras.execute_values( | ||||
cursor, sql, argslist=values, page_size=page_size | cursor, sql, argslist=values, page_size=page_size | ||||
) | ) | ||||
return True | return True | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) | ||||
def location_get_all(self) -> Set[bytes]: | def location_get_all(self) -> Dict[Sha1Git, bytes]: | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute("SELECT location.path AS path FROM location") | cursor.execute("SELECT location.path AS path FROM location") | ||||
return {row["path"] for row in cursor} | return {sha1(row["path"]).digest(): row["path"] for row in cursor} | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) | ||||
@handle_raise_on_commit | @handle_raise_on_commit | ||||
def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: | def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: | ||||
if orgs: | if orgs: | ||||
sql = """ | sql = """ | ||||
INSERT INTO origin(sha1, url) VALUES %s | INSERT INTO origin(sha1, url) VALUES %s | ||||
ON CONFLICT DO NOTHING | ON CONFLICT DO NOTHING | ||||
Show All 29 Lines | def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: | ||||
""" | """ | ||||
with self.transaction(readonly=True) as cursor: | with self.transaction(readonly=True) as cursor: | ||||
cursor.execute(query=sql, vars=sha1s) | cursor.execute(query=sql, vars=sha1s) | ||||
urls.update((row["sha1"], row["url"]) for row in cursor) | urls.update((row["sha1"], row["url"]) for row in cursor) | ||||
return urls | return urls | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) | ||||
@handle_raise_on_commit | @handle_raise_on_commit | ||||
def revision_add( | def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: | ||||
self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] | if revs: | ||||
) -> bool: | |||||
if isinstance(revs, dict): | |||||
data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] | data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] | ||||
else: | |||||
data = [(sha1, None, None) for sha1 in revs] | |||||
if data: | |||||
sql = """ | sql = """ | ||||
INSERT INTO revision(sha1, date, origin) | INSERT INTO revision(sha1, date, origin) | ||||
(SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin | (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin | ||||
FROM (VALUES %s) AS V(rev, date, org) | FROM (VALUES %s) AS V(rev, date, org) | ||||
LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) | LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) | ||||
ON CONFLICT (sha1) DO | ON CONFLICT (sha1) DO | ||||
UPDATE SET | UPDATE SET | ||||
date=LEAST(EXCLUDED.date, revision.date), | date=LEAST(EXCLUDED.date, revision.date), | ||||
▲ Show 20 Lines • Show All 100 Lines • Show Last 20 Lines |