diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -24,6 +24,8 @@
 
     """
 
+    current_version = 159
+
     def mktemp_dir_entry(self, entry_type, cur=None):
         self._cursor(cur).execute(
             "SELECT swh_mktemp_dir_entry(%s)", (("directory_entry_%s" % entry_type),)
@@ -1328,3 +1330,30 @@
         row = cur.fetchone()
         if row:
             return row[0]
+
+    dbversion_cols = ["version", "release", "description"]
+
+    def dbversion(self):
+        """Return the highest-version row of the ``dbversion`` table.
+
+        The row is returned as a dict keyed by :attr:`dbversion_cols`, or
+        None if the table is empty. The table may hold several rows, so the
+        query orders by version instead of relying on whichever row an
+        unordered ``fetchone()`` happens to return.
+        """
+        with self.transaction() as cur:
+            cur.execute(
+                f"SELECT {', '.join(self.dbversion_cols)} "
+                "FROM dbversion ORDER BY version DESC LIMIT 1"
+            )
+            row = cur.fetchone()
+        if row is None:
+            return None
+        return dict(zip(self.dbversion_cols, row))
+
+    def check_dbversion(self):
+        """Return True if the database schema version matches
+        :attr:`current_version`, False otherwise (including when no
+        version is recorded)."""
+        version = self.dbversion()
+        return version is not None and version["version"] == self.current_version
diff --git a/swh/storage/pytest_plugin.py b/swh/storage/pytest_plugin.py
--- a/swh/storage/pytest_plugin.py
+++ b/swh/storage/pytest_plugin.py
@@ -139,7 +139,7 @@
             "WHERE table_schema = %s",
             ("public",),
         )
-        tables = set(table for (table,) in cur.fetchall())
+        tables = set(table for (table,) in cur.fetchall()) - {"dbversion"}
         for table in tables:
             cur.execute("truncate table %s cascade" % table)
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -27,6 +27,7 @@
 import psycopg2.errors
 
 from swh.core.api.serializers import msgpack_loads, msgpack_dumps
+from swh.core.db.common import db_transaction_generator, db_transaction
 from swh.model.identifiers import SWHID
 from swh.model.model import (
     Content,
@@ -60,7 +61,6 @@
 from swh.storage.utils import now
 
 from . import converters
-from .common import db_transaction_generator, db_transaction
 from .db import Db
 from .exc import StorageArgumentException, StorageDBError, HashCollision
 from .algos import diff
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -41,6 +41,7 @@
 )
 from swh.model.hypothesis_strategies import objects
 from swh.storage import get_storage
+from swh.storage.db import Db
 from swh.storage.converters import origin_url_to_sha1 as sha1
 from swh.storage.exc import HashCollision, StorageArgumentException
 from swh.storage.interface import ListOrder, PagedResult, StorageInterface
@@ -4121,3 +4122,12 @@
         """
 
         assert swh_storage.flush() == {}
+
+    def test_dbversion(self, swh_storage):
+        with swh_storage.db() as db:
+            assert db.check_dbversion()
+
+    def test_dbversion_mismatch(self, swh_storage, monkeypatch):
+        monkeypatch.setattr(Db, "current_version", -1)
+        with swh_storage.db() as db:
+            assert db.check_dbversion() is False