diff --git a/swh/storage/postgresql/db.py b/swh/storage/postgresql/db.py
--- a/swh/storage/postgresql/db.py
+++ b/swh/storage/postgresql/db.py
@@ -24,6 +24,9 @@
 
     """
 
+    # Schema version this code expects; compared by check_dbversion().
+    current_version = 161
+
     def mktemp_dir_entry(self, entry_type, cur=None):
         self._cursor(cur).execute(
             "SELECT swh_mktemp_dir_entry(%s)", (("directory_entry_%s" % entry_type),)
@@ -1328,3 +1331,37 @@
         row = cur.fetchone()
         if row:
             return row[0]
+
+    dbversion_cols = ["version", "release", "description"]
+
+    def dbversion(self):
+        """Return the latest row of the ``dbversion`` table as a dict.
+
+        The table is not assumed to hold a single row (e.g. if schema
+        upgrades append one row each), so the row with the highest
+        ``version`` is selected explicitly: without ``ORDER BY``, the row
+        returned by ``fetchone()`` would be unspecified.
+
+        Returns:
+            a dict keyed by ``dbversion_cols``, or None if the table is empty
+        """
+        with self.transaction() as cur:
+            cur.execute(
+                f"SELECT {', '.join(self.dbversion_cols)} "
+                "FROM dbversion ORDER BY version DESC LIMIT 1"
+            )
+            row = cur.fetchone()
+            if row is None:
+                return None
+            return dict(zip(self.dbversion_cols, row))
+
+    def check_dbversion(self):
+        """Check that the database schema matches ``current_version``.
+
+        Returns:
+            True if the latest recorded schema version equals
+            ``current_version``, False otherwise (including when the
+            ``dbversion`` table is empty).
+        """
+        dbversion = self.dbversion()
+        return dbversion is not None and dbversion["version"] == self.current_version
diff --git a/swh/storage/pytest_plugin.py b/swh/storage/pytest_plugin.py
--- a/swh/storage/pytest_plugin.py
+++ b/swh/storage/pytest_plugin.py
@@ -139,7 +139,8 @@
             "WHERE table_schema = %s",
             ("public",),
         )
-        tables = set(table for (table,) in cur.fetchall())
+        # Keep dbversion: it records the schema version check_dbversion() reads.
+        tables = set(table for (table,) in cur.fetchall()) - {"dbversion"}
         for table in tables:
             cur.execute("truncate table %s cascade" % table)
 
diff --git a/swh/storage/tests/test_cli.py b/swh/storage/tests/test_cli.py
--- a/swh/storage/tests/test_cli.py
+++ b/swh/storage/tests/test_cli.py
@@ -47,11 +47,10 @@
     monkeypatch.setattr(obj_in_objstorage.retry, "sleep", lambda x: None)
 
 
-def invoke(*args, env=None, journal_config=None):
+def invoke(*args, env=None, **extra_config):
     config = copy.deepcopy(CLI_CONFIG)
-    if journal_config:
-        config["journal_client"] = journal_config.copy()
-        config["journal_client"]["cls"] = "kafka"
+    if extra_config:
+        config.update(extra_config)
 
     runner = CliRunner()
     with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd:
@@ -97,10 +96,11 @@
         "replay",
         "--stop-after-objects",
         "1",
-        journal_config={
+        journal_client={
             "brokers": [kafka_server],
             "group_id": kafka_consumer_group,
             "prefix": kafka_prefix,
+            "cls": "kafka",
         },
     )
 
diff --git a/swh/storage/tests/test_postgresql.py b/swh/storage/tests/test_postgresql.py
--- a/swh/storage/tests/test_postgresql.py
+++ b/swh/storage/tests/test_postgresql.py
@@ -11,6 +11,7 @@
 import attr
 import pytest
 
+from swh.storage.postgresql.db import Db
 from swh.storage.tests.storage_tests import TestStorage  # noqa
 from swh.storage.tests.storage_tests import TestStorageGeneratedData  # noqa
 from swh.storage.utils import now
@@ -254,3 +255,12 @@
 
         """
         assert swh_storage.flush() == {}
+
+    def test_dbversion(self, swh_storage):
+        with swh_storage.db() as db:
+            assert db.check_dbversion()
+
+    def test_dbversion_mismatch(self, swh_storage, monkeypatch):
+        monkeypatch.setattr(Db, "current_version", -1)
+        with swh_storage.db() as db:
+            assert db.check_dbversion() is False