diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py
--- a/swh/storage/api/client.py
+++ b/swh/storage/api/client.py
@@ -142,6 +142,11 @@
                                            'regexp': regexp,
                                            'with_visit': with_visit})
 
+    def origin_count(self, url_pattern, regexp=False, with_visit=False):
+        return self.post('origin/count', {'url_pattern': url_pattern,
+                                          'regexp': regexp,
+                                          'with_visit': with_visit})
+
     def origin_get_range(self, origin_from=1, origin_count=100):
         return self.post('origin/get_range', {'origin_from': origin_from,
                                               'origin_count': origin_count})
diff --git a/swh/storage/api/server.py b/swh/storage/api/server.py
--- a/swh/storage/api/server.py
+++ b/swh/storage/api/server.py
@@ -236,6 +236,11 @@
     return encode_data(get_storage().origin_search(**decode_request(request)))
 
 
+@app.route('/origin/count', methods=['POST'])
+def origin_count():
+    return encode_data(get_storage().origin_count(**decode_request(request)))
+
+
 @app.route('/origin/add_multi', methods=['POST'])
 def origin_add():
     return encode_data(get_storage().origin_add(**decode_request(request)))
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -621,24 +621,35 @@
         yield from execute_values_generator(
             cur, query, ((id,) for id in ids))
 
-    def origin_search(self, url_pattern, offset=0, limit=50,
+    def _origin_query(self, url_pattern, count=False, offset=0, limit=50,
                       regexp=False, with_visit=False, cur=None):
-        """Search for origins whose urls contain a provided string pattern
-        or match a provided regular expression.
-        The search is performed in a case insensitive way.
-
-        Args:
-            url_pattern (str): the string pattern to search for in origin urls
-            offset (int): number of found origins to skip before returning
-                results
-            limit (int): the maximum number of found origins to return
-            regexp (bool): if True, consider the provided pattern as a regular
-                expression and returns origins whose urls match it
-            with_visit (bool): if True, filter out origins with no visit
-
+        """Build and execute the SQL query shared by origin_search and
+        origin_count.
+
+        Args:
+            url_pattern (str): the string pattern to search for in origin urls
+            count (bool): if True, select the number of matching origins
+                instead of their columns, without ordering nor pagination
+            offset (int): number of found origins to skip before returning
+                results (unused when count is True)
+            limit (int): the maximum number of found origins to return
+                (unused when count is True)
+            regexp (bool): if True, consider the provided pattern as a regular
+                expression and match origin urls against it
+            with_visit (bool): if True, filter out origins with no visit
+            cur: optional cursor, resolved through self._cursor
+
+        Returns:
+            the cursor the query was executed on; callers must use this
+            returned cursor to fetch the results, since ``cur`` may be None
         """
         cur = self._cursor(cur)
-        origin_cols = ','.join(self.origin_cols)
+
+        if count:
+            origin_cols = 'COUNT(*)'
+        else:
+            origin_cols = ','.join(self.origin_cols)
+
         query = """SELECT %s
                    FROM origin
                    WHERE """
@@ -646,10 +657,9 @@
             query += """
                    EXISTS (SELECT 1 from origin_visit WHERE origin=origin.id)
                    AND """
-        query += """
-                   url %s %%s
-                   ORDER BY id
-                   OFFSET %%s LIMIT %%s"""
+        query += 'url %s %%s '
+        if not count:
+            query += 'ORDER BY id OFFSET %%s LIMIT %%s'
 
         if not regexp:
             query = query % (origin_cols, 'ILIKE')
@@ -658,9 +668,52 @@
             query = query % (origin_cols, '~*')
             query_params = (url_pattern, offset, limit)
 
+        if count:
+            query_params = (query_params[0],)
+
         cur.execute(query, query_params)
+        return cur
 
+    def origin_search(self, url_pattern, offset=0, limit=50,
+                      regexp=False, with_visit=False, cur=None):
+        """Search for origins whose urls contain a provided string pattern
+        or match a provided regular expression.
+        The search is performed in a case insensitive way.
+
+        Args:
+            url_pattern (str): the string pattern to search for in origin urls
+            offset (int): number of found origins to skip before returning
+                results
+            limit (int): the maximum number of found origins to return
+            regexp (bool): if True, consider the provided pattern as a regular
+                expression and returns origins whose urls match it
+            with_visit (bool): if True, filter out origins with no visit
+
+        """
+        cur = self._origin_query(url_pattern, offset=offset, limit=limit,
+                                 regexp=regexp, with_visit=with_visit,
+                                 cur=cur)
         yield from cur
 
+    def origin_count(self, url_pattern, regexp=False,
+                     with_visit=False, cur=None):
+        """Count origins whose urls contain a provided string pattern
+        or match a provided regular expression.
+        The pattern search in origin urls is performed in a case insensitive
+        way.
+
+        Args:
+            url_pattern (str): the string pattern to search for in origin urls
+            regexp (bool): if True, consider the provided pattern as a regular
+                expression and returns origins whose urls match it
+            with_visit (bool): if True, filter out origins with no visit
+
+        Returns:
+            int: The number of origins matching the search criterion.
+        """
+        cur = self._origin_query(url_pattern, count=True, regexp=regexp,
+                                 with_visit=with_visit, cur=cur)
+        return cur.fetchone()[0]
+
     person_cols = ['fullname', 'name', 'email']
     person_get_cols = person_cols + ['id']
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -896,9 +896,30 @@
         if with_visit:
             origins = [orig for orig in origins
                        if len(self._origin_visits[orig['id']-1]) > 0]
+
         origins = copy.deepcopy(origins[offset:offset+limit])
         return origins
 
+    def origin_count(self, url_pattern, regexp=False, with_visit=False,
+                     db=None, cur=None):
+        """Count origins whose urls contain a provided string pattern
+        or match a provided regular expression.
+        The pattern search in origin urls is performed in a case insensitive
+        way.
+
+        Args:
+            url_pattern (str): the string pattern to search for in origin urls
+            regexp (bool): if True, consider the provided pattern as a regular
+                expression and return origins whose urls match it
+            with_visit (bool): if True, filter out origins with no visit
+
+        Returns:
+            int: The number of origins matching the search criterion.
+        """
+        return len(self.origin_search(url_pattern, regexp=regexp,
+                                      with_visit=with_visit,
+                                      limit=len(self._origins)))
+
     def origin_add(self, origins):
         """Add origins to the storage
 
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -1173,6 +1173,25 @@
                                        regexp, with_visit, cur):
             yield dict(zip(self.origin_keys, origin))
 
+    @db_transaction()
+    def origin_count(self, url_pattern, regexp=False,
+                     with_visit=False, db=None, cur=None):
+        """Count origins whose urls contain a provided string pattern
+        or match a provided regular expression.
+        The pattern search in origin urls is performed in a case insensitive
+        way.
+
+        Args:
+            url_pattern (str): the string pattern to search for in origin urls
+            regexp (bool): if True, consider the provided pattern as a regular
+                expression and return origins whose urls match it
+            with_visit (bool): if True, filter out origins with no visit
+
+        Returns:
+            int: The number of origins matching the search criterion.
+        """
+        return db.origin_count(url_pattern, regexp, with_visit, cur)
+
     @db_transaction_generator()
     def origin_get_range(self, origin_from=1, origin_count=100,
                          db=None, cur=None):
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -2220,6 +2220,44 @@
             origin_count=origin_count))
         self.assertEqual(len(origins), 0)
 
+    def test_origin_count(self):
+
+        new_origins = [
+            {
+                'type': 'git',
+                'url': 'https://github.com/user1/repo1'
+            },
+            {
+                'type': 'git',
+                'url': 'https://github.com/user2/repo1'
+            },
+            {
+                'type': 'git',
+                'url': 'https://github.com/user3/repo1'
+            },
+            {
+                'type': 'git',
+                'url': 'https://gitlab.com/user1/repo1'
+            },
+            {
+                'type': 'git',
+                'url': 'https://gitlab.com/user2/repo1'
+            }
+        ]
+
+        self.storage.origin_add(new_origins)
+
+        self.assertEqual(self.storage.origin_count('github'), 3)
+        self.assertEqual(self.storage.origin_count('gitlab'), 2)
+        self.assertEqual(
+            self.storage.origin_count('.*user.*', regexp=True), 5)
+        self.assertEqual(
+            self.storage.origin_count('.*user.*', regexp=False), 0)
+        self.assertEqual(
+            self.storage.origin_count('.*user1.*', regexp=True), 2)
+        self.assertEqual(
+            self.storage.origin_count('.*user1.*', regexp=False), 0)
+
 
 @pytest.mark.db
 class TestLocalStorage(CommonTestStorage, StorageTestDbFixture,