diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py --- a/swh/storage/__init__.py +++ b/swh/storage/__init__.py @@ -8,6 +8,10 @@ Storage = storage.Storage +class HashCollision(Exception): + pass + + def get_storage(cls, args): """ Get a storage object of class `storage_class` with arguments diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -938,7 +938,7 @@ return cur.fetchone()[0] - origin_metadata_get_cols = ['id', 'origin_id', 'discovery_date', + origin_metadata_get_cols = ['origin_id', 'discovery_date', 'tool_id', 'metadata', 'provider_id', 'provider_name', 'provider_type', 'provider_url'] diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ b/swh/storage/storage.py @@ -157,7 +157,11 @@ db.content_get_metadata_keys, cur) # move metadata in place - db.content_add_from_temp(cur) + try: + db.content_add_from_temp(cur) + except psycopg2.IntegrityError: + from . import HashCollision + raise HashCollision() if missing_skipped: missing_filtered = ( @@ -1200,6 +1204,15 @@ return {k: v for (k, v) in db.stat_counters()} @db_transaction() + def refresh_stat_counters(self, db=None, cur=None): + """Recomputes the statistics for `stat_counters`.""" + keys = ['content', 'directory', 'directory_entry_dir', + 'origin', 'person', 'revision'] + + for key in keys: + cur.execute('select * from swh_update_counter(%s)', (key,)) + + @db_transaction() def origin_metadata_add(self, origin_id, ts, provider, tool, metadata, db=None, cur=None): """ Add an origin_metadata for the origin at ts with provenance and diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py --- a/swh/storage/tests/test_api_client.py +++ b/swh/storage/tests/test_api_client.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import pytest import shutil import tempfile import unittest @@ -52,3 +53,7 @@ def tearDown(self): super().tearDown() shutil.rmtree(self.storage_base) + + @pytest.mark.skip('refresh_stat_counters not available in the remote api.') + def test_stat_counters(self): + pass diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -7,15 +7,14 @@ import datetime import unittest from collections import defaultdict -from operator import itemgetter from unittest.mock import Mock, patch -import psycopg2 import pytest from swh.model import from_disk, identifiers from swh.model.hashutil import hash_to_bytes from swh.storage.tests.storage_testing import StorageTestFixture +from swh.storage import HashCollision @pytest.mark.db @@ -526,7 +525,6 @@ class twice. """ - @staticmethod def normalize_entity(entity): entity = copy.deepcopy(entity) @@ -565,7 +563,7 @@ sha256_array[0] += 1 cont1b['sha256'] = bytes(sha256_array) - with self.assertRaises(psycopg2.IntegrityError): + with self.assertRaises(HashCollision): self.storage.content_add([cont1, cont1b]) def test_skipped_content_add(self): @@ -677,7 +675,7 @@ stored_data = list(self.storage.directory_ls(self.dir['id'])) data_to_store = [] - for ent in sorted(self.dir['entries'], key=itemgetter('name')): + for ent in self.dir['entries']: data_to_store.append({ 'dir_id': self.dir['id'], 'type': ent['type'], @@ -691,7 +689,7 @@ 'length': None, }) - self.assertEqual(data_to_store, stored_data) + self.assertCountEqual(data_to_store, stored_data) after_missing = list(self.storage.directory_missing([self.dir['id']])) self.assertEqual([], after_missing) @@ -880,7 +878,8 @@ # then for actual_release in actual_releases: - del actual_release['author']['id'] # hack: ids are generated + if 'id' in actual_release['author']: + del actual_release['author']['id'] # hack: ids are generated self.assertEqual([self.normalize_entity(self.release), self.normalize_entity(self.release2)], @@ -1011,7 +1010,6 @@ # then self.assertEqual(origin_visit1['origin'], origin_id) self.assertIsNotNone(origin_visit1['visit']) - self.assertTrue(origin_visit1['visit'] > 0) actual_origin_visits = list(self.storage.origin_visit_get(origin_id)) self.assertEqual(actual_origin_visits, @@ -1399,9 +1397,7 @@ expected_keys = ['content', 'directory', 'directory_entry_dir', 'origin', 'person', 'revision'] - for key in expected_keys: - self.cursor.execute('select * from swh_update_counter(%s)', (key,)) - self.conn.commit() + self.storage.refresh_stat_counters() counters = self.storage.stat_counters() @@ -1676,10 +1672,9 @@ 'provider_name': self.provider['name'], 'provider_url': self.provider['url'] }) - tool = self.storage.tool_get(self.metadata_tool) # when adding for the same origin 2 metadatas - o_m1 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata['discovery_date'], provider['id'], @@ -1687,7 +1682,6 @@ self.origin_metadata['metadata']) actual_om1 = list(self.storage.origin_metadata_get_by(origin_id)) # then - self.assertEqual(actual_om1[0]['id'], o_m1) self.assertEqual(len(actual_om1), 1) self.assertEqual(actual_om1[0]['origin_id'], origin_id) @@ -1704,21 +1698,21 @@ 'provider_name': self.provider['name'], 'provider_url': self.provider['url'] }) - tool = self.storage.tool_get(self.metadata_tool) + tool = list(self.storage.tool_add([self.metadata_tool]))[0] # when adding for the same origin 2 metadatas - o_m1 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata['discovery_date'], provider['id'], tool['id'], self.origin_metadata['metadata']) - o_m2 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id2, self.origin_metadata2['discovery_date'], provider['id'], tool['id'], self.origin_metadata2['metadata']) - o_m3 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata2['discovery_date'], provider['id'], @@ -1730,15 +1724,12 @@ expected_results = [{ 'origin_id': origin_id, 'discovery_date': datetime.datetime( - 2017, 1, 2, 0, 0, - tzinfo=psycopg2.tz.FixedOffsetTimezone( - offset=60, - name=None)), + 2017, 1, 1, 23, 0, + tzinfo=datetime.timezone.utc), 'metadata': { 'name': 'test_origin_metadata', 'version': '0.0.1' }, - 'id': o_m3, 'provider_id': provider['id'], 'provider_name': 'hal', 'provider_type': 'deposit-client', @@ -1747,15 +1738,12 @@ }, { 'origin_id': origin_id, 'discovery_date': datetime.datetime( - 2015, 1, 2, 0, 0, - tzinfo=psycopg2.tz.FixedOffsetTimezone( - offset=60, - name=None)), + 2015, 1, 1, 23, 0, + tzinfo=datetime.timezone.utc), 'metadata': { 'name': 'test_origin_metadata', 'version': '0.0.1' }, - 'id': o_m1, 'provider_id': provider['id'], 'provider_name': 'hal', 'provider_type': 'deposit-client', @@ -1766,8 +1754,7 @@ # then self.assertEqual(len(all_metadatas), 2) self.assertEqual(len(metadatas_for_origin2), 1) - self.assertEqual(metadatas_for_origin2[0]['id'], o_m2) - self.assertEqual(all_metadatas, expected_results) + self.assertCountEqual(all_metadatas, expected_results) def test_origin_metadata_get_by_provider_type(self): # given @@ -1796,16 +1783,16 @@ # using the only tool now inserted in the data.sql, but for this # provider should be a crawler tool (not yet implemented) - tool = self.storage.tool_get(self.metadata_tool) + tool = list(self.storage.tool_add([self.metadata_tool]))[0] # when adding for the same origin 2 metadatas - o_m1 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id, self.origin_metadata['discovery_date'], provider1['id'], tool['id'], self.origin_metadata['metadata']) - o_m2 = self.storage.origin_metadata_add( + self.storage.origin_metadata_add( origin_id2, self.origin_metadata2['discovery_date'], provider2['id'], @@ -1816,18 +1803,18 @@ origin_metadata_get_by( origin_id2, provider_type)) + for item in m_by_provider: + if 'id' in item: + del item['id'] expected_results = [{ 'origin_id': origin_id2, 'discovery_date': datetime.datetime( - 2017, 1, 2, 0, 0, - tzinfo=psycopg2.tz.FixedOffsetTimezone( - offset=60, - name=None)), + 2017, 1, 1, 23, 0, + tzinfo=datetime.timezone.utc), 'metadata': { 'name': 'test_origin_metadata', 'version': '0.0.1' }, - 'id': o_m2, 'provider_id': provider2['id'], 'provider_name': 'swMATH', 'provider_type': provider_type, @@ -1838,8 +1825,6 @@ self.assertEqual(len(m_by_provider), 1) self.assertEqual(m_by_provider, expected_results) - self.assertEqual(m_by_provider[0]['id'], o_m2) - self.assertIsNotNone(o_m1) class TestLocalStorage(CommonTestStorage, unittest.TestCase):