diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -107,6 +107,16 @@
     def origin_get(self, origin):
         return self.post('origin/get', {'origin': origin})
 
+    def origin_search(self, url_pattern, offset=0, limit=50):
+        return self.post('origin/search', {'url_pattern': url_pattern,
+                                           'offset': offset,
+                                           'limit': limit})
+
+    def origin_regexp_search(self, url_regexp, offset=0, limit=50):
+        return self.post('origin/regexp_search', {'url_regexp': url_regexp,
+                                                  'offset': offset,
+                                                  'limit': limit})
+
     def origin_add(self, origins):
         return self.post('origin/add_multi', {'origins': origins})
 
diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py
--- a/swh/storage/api/server.py
+++ b/swh/storage/api/server.py
@@ -196,6 +196,17 @@
     return encode_data(g.storage.origin_get(**decode_request(request)))
 
 
+@app.route('/origin/search', methods=['POST'])
+def origin_search():
+    return encode_data(g.storage.origin_search(**decode_request(request)))
+
+
+@app.route('/origin/regexp_search', methods=['POST'])
+def origin_regexp_search():
+    return encode_data(
+        g.storage.origin_regexp_search(**decode_request(request)))
+
+
 @app.route('/origin/add_multi', methods=['POST'])
 def origin_add():
     return encode_data(g.storage.origin_add(**decode_request(request)))
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -639,13 +639,16 @@
         cur.execute(insert, (type, url))
         return cur.fetchone()[0]
 
+    origin_cols = ['id', 'type', 'url', 'lister', 'project']
+
     def origin_get_with(self, type, url, cur=None):
         """Retrieve the origin id from its type and url if found."""
         cur = self._cursor(cur)
 
-        query = """SELECT id, type, url, lister, project
+        query = """SELECT %s
                    FROM origin
-                   WHERE type=%s AND url=%s"""
+                   WHERE type=%%s AND url=%%s
+                """ % ','.join(self.origin_cols)
 
         cur.execute(query, (type, url))
         data = cur.fetchone()
@@ -659,7 +662,9 @@
         """
         cur = self._cursor(cur)
 
-        query = "SELECT id, type, url, lister, project FROM origin WHERE id=%s"
+        query = """SELECT %s
+                   FROM origin WHERE id=%%s
+                """ % ','.join(self.origin_cols)
 
         cur.execute(query, (id,))
         data = cur.fetchone()
@@ -667,6 +672,60 @@
             return line_to_bytes(data)
         return None
 
+    def origin_search(self, url_pattern, offset=0, limit=50, cur=None):
+        """Search for origins whose urls contain the provided string pattern.
+        The search is performed in a case insensitive way and the pattern
+        is matched literally (LIKE wildcards it contains are escaped).
+
+        Args:
+            url_pattern: the string pattern to search for in origin urls
+            offset: number of found origins to skip before returning results
+            limit: the maximum number of found origins to return
+
+        Yields:
+            the matching origin rows (see origin_cols), ordered by id
+
+        """
+        cur = self._cursor(cur)
+
+        query = """SELECT %s
+                   FROM origin WHERE url ILIKE %%s
+                   ORDER BY id
+                   OFFSET %%s LIMIT %%s
+                """ % ','.join(self.origin_cols)
+
+        # '\', '%' and '_' are special in LIKE patterns: escape them
+        # (backslash first) so that e.g. 'my_repo' does not match 'my-repo'
+        escaped_pattern = (url_pattern.replace('\\', '\\\\')
+                                      .replace('%', '\\%')
+                                      .replace('_', '\\_'))
+
+        cur.execute(query, ('%' + escaped_pattern + '%', offset, limit))
+        yield from cursor_to_bytes(cur)
+
+    def origin_regexp_search(self, url_regexp, offset=0, limit=50, cur=None):
+        """Search for origins whose urls match the provided regular expression.
+
+        Args:
+            url_regexp: the regular expression to match in origin urls
+            offset: number of found origins to skip before returning results
+            limit: the maximum number of found origins to return
+
+        Yields:
+            the matching origin rows (see origin_cols), ordered by id
+
+        """
+        cur = self._cursor(cur)
+
+        query = """SELECT %s
+                   FROM origin WHERE url ~ %%s
+                   ORDER BY id
+                   OFFSET %%s LIMIT %%s
+                """ % ','.join(self.origin_cols)
+
+        cur.execute(query, (url_regexp, offset, limit))
+        yield from cursor_to_bytes(cur)
+
     person_cols = ['fullname', 'name', 'email']
     person_get_cols = person_cols + ['id']
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -983,6 +983,8 @@
 
         return ret
 
+    origin_keys = ['id', 'type', 'url', 'lister', 'project']
+
     @db_transaction
     def origin_get(self, origin, cur=None):
         """Return the origin either identified by its id or its tuple
@@ -1014,8 +1016,6 @@
         """
         db = self.db
 
-        keys = ['id', 'type', 'url', 'lister', 'project']
-
         origin_id = origin.get('id')
         if origin_id:  # check lookup per id first
             ori = db.origin_get(origin_id, cur)
@@ -1025,9 +1025,46 @@
             raise ValueError('Origin must have either id or (type and url).')
 
         if ori:
-            return dict(zip(keys, ori))
+            return dict(zip(self.origin_keys, ori))
         return None
 
+    @db_transaction_generator
+    def origin_search(self, url_pattern, offset=0, limit=50, cur=None):
+        """Search for origins whose urls contain a provided string pattern.
+        The search is performed in a case insensitive way.
+
+        Args:
+            url_pattern: the string pattern to search for in origin urls
+            offset: number of found origins to skip before returning results
+            limit: the maximum number of found origins to return
+
+        Returns:
+            An iterable of dict containing origin information as returned
+            by :meth:`swh.storage.storage.Storage.origin_get`.
+        """
+        db = self.db
+
+        for origin in db.origin_search(url_pattern, offset, limit, cur):
+            yield dict(zip(self.origin_keys, origin))
+
+    @db_transaction_generator
+    def origin_regexp_search(self, url_regexp, offset=0, limit=50, cur=None):
+        """Search for origins whose urls match a provided regular expression.
+
+        Args:
+            url_regexp: the regular expression to match in origin urls
+            offset: number of found origins to skip before returning results
+            limit: the maximum number of found origins to return
+
+        Returns:
+            An iterable of dict containing origin information as returned
+            by :meth:`swh.storage.storage.Storage.origin_get`.
+        """
+        db = self.db
+
+        for origin in db.origin_regexp_search(url_regexp, offset, limit, cur):
+            yield dict(zip(self.origin_keys, origin))
+
     @db_transaction
     def _person_add(self, person, cur=None):
         """Add a person in storage.
diff --git a/swh/storage/tests/test_api_client.py b/swh/storage/tests/test_api_client.py
--- a/swh/storage/tests/test_api_client.py
+++ b/swh/storage/tests/test_api_client.py
@@ -18,7 +18,7 @@
 
     This class doesn't define any tests as we want identical
     functionality between local and remote storage. All the tests are
-    therefore defined in AbstractTestStorage.
+    therefore defined in CommonTestStorage.
     """
 
     def setUp(self):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -611,7 +611,7 @@
 class CommonTestStorage(BaseTestStorage):
     """Base class for Storage testing.
 
-    This class is used as-is to test local storage (see TestStorage
+    This class is used as-is to test local storage (see TestLocalStorage
     below) and remote storage (see TestRemoteStorage in
     test_remote_storage.py.
 
@@ -1304,6 +1304,97 @@
                                           'project': None})
 
     @istest
+    def origin_search(self):
+        found_origins = list(self.storage.origin_search(self.origin['url']))
+        self.assertEqual(len(found_origins), 0)
+
+        origin_id = self.storage.origin_add_one(self.origin)
+        origin_data = {'id': origin_id,
+                       'type': self.origin['type'],
+                       'url': self.origin['url'],
+                       'lister': None,
+                       'project': None}
+        found_origins = list(self.storage.origin_search(self.origin['url']))
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin_data)
+
+        origin2_id = self.storage.origin_add_one(self.origin2)
+        origin2_data = {'id': origin2_id,
+                        'type': self.origin2['type'],
+                        'url': self.origin2['url'],
+                        'lister': None,
+                        'project': None}
+        found_origins = list(self.storage.origin_search(self.origin2['url']))
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin2_data)
+
+        found_origins = list(self.storage.origin_search('/'))
+        self.assertEqual(len(found_origins), 2)
+
+        found_origins = list(self.storage.origin_search('/', offset=0, limit=1))  # noqa
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin_data)
+
+        found_origins = list(self.storage.origin_search('/', offset=1, limit=1))  # noqa
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin2_data)
+
+    @istest
+    def origin_search_wildcards_are_literal(self):
+        self.storage.origin_add_one({
+            'type': 'git',
+            'url': 'https://example.org/some-repo',
+        })
+
+        # sanity check: the origin is found by a plain substring
+        found_origins = list(self.storage.origin_search('some-repo'))
+        self.assertEqual(len(found_origins), 1)
+
+        # LIKE wildcards in the searched string must be matched literally
+        for pattern in ('some_repo', 'some%repo', '%', '_'):
+            found_origins = list(self.storage.origin_search(pattern))
+            self.assertEqual(len(found_origins), 0)
+
+    @istest
+    def origin_regexp_search(self):
+        found_origins = list(self.storage.origin_regexp_search(
+            self.origin['url']))
+        self.assertEqual(len(found_origins), 0)
+
+        origin_id = self.storage.origin_add_one(self.origin)
+        origin_data = {'id': origin_id,
+                       'type': self.origin['type'],
+                       'url': self.origin['url'],
+                       'lister': None,
+                       'project': None}
+        found_origins = list(self.storage.origin_regexp_search(
+            '.' + self.origin['url'][1:-1] + '.'))
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin_data)
+
+        origin2_id = self.storage.origin_add_one(self.origin2)
+        origin2_data = {'id': origin2_id,
+                        'type': self.origin2['type'],
+                        'url': self.origin2['url'],
+                        'lister': None,
+                        'project': None}
+        found_origins = list(self.storage.origin_regexp_search(
+            '.' + self.origin2['url'][1:-1] + '.'))
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin2_data)
+
+        found_origins = list(self.storage.origin_regexp_search('.*/.*'))
+        self.assertEqual(len(found_origins), 2)
+
+        found_origins = list(self.storage.origin_regexp_search('.*/.*', offset=0, limit=1))  # noqa
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin_data)
+
+        found_origins = list(self.storage.origin_regexp_search('.*/.*', offset=1, limit=1))  # noqa
+        self.assertEqual(len(found_origins), 1)
+        self.assertEqual(found_origins[0], origin2_data)
+
+    @istest
     def origin_visit_add(self):
         # given
         self.assertIsNone(self.storage.origin_get(self.origin2))