diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py
--- a/swh/graphql/backends/archive.py
+++ b/swh/graphql/backends/archive.py
@@ -3,15 +3,21 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from typing import List, Optional
+
 from swh.graphql import server
+from swh.model import model as swh_model
 
 
 class Archive:
     def __init__(self):
         self.storage = server.get_storage()
 
-    def get_origin(self, url):
-        return self.storage.origin_get([url])[0]
+    def get_origins_with_urls(
+        self, urls: List[str]
+    ) -> List[Optional[swh_model.Origin]]:
+        """Look up origins by URL; one entry per URL, None when not found."""
+        return self.storage.origin_get(urls)
 
     def get_origins(self, after=None, first=50, url_pattern=None):
         # STORAGE-TODO
diff --git a/swh/graphql/resolvers/origin.py b/swh/graphql/resolvers/origin.py
--- a/swh/graphql/resolvers/origin.py
+++ b/swh/graphql/resolvers/origin.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 from swh.graphql.backends import archive
+from swh.graphql.utils import utils
 
 from .base_connection import BaseConnection
 from .base_node import BaseSWHNode
@@ -14,8 +15,16 @@
     Node resolver for an origin requested directly with its URL
     """
 
+    def _get_analogous_origins(self, url):
+        """Return the URL variants (with/without trailing slash) to look up."""
+        return utils.get_analogues_urls(url)
+
     def _get_node_data(self):
-        return archive.Archive().get_origin(self.kwargs.get("url"))
+        urls = self._get_analogous_origins(self.kwargs.get("url"))
+        results = archive.Archive().get_origins_with_urls(urls)
+        # The lookup yields one entry per URL (None when absent), so the list is
+        # never empty; return the first URL variant that is actually archived.
+        return next((origin for origin in results if origin is not None), None)
 
 
 class OriginConnection(BaseConnection):
diff --git a/swh/graphql/tests/unit/utils/test_utils.py b/swh/graphql/tests/unit/utils/test_utils.py
--- a/swh/graphql/tests/unit/utils/test_utils.py
+++ b/swh/graphql/tests/unit/utils/test_utils.py
@@ -28,6 +28,25 @@
     def test_get_decoded_cursor(self):
         assert utils.get_decoded_cursor("dGVzdGluZw==") == "testing"
 
+    def test_get_analogues_urls(self):
+        url = "http://example.com"
+        assert utils.get_analogues_urls(url) == [
+            "http://example.com",
+            "http://example.com/",
+        ]
+
+        url = "http://example.com/"
+        assert utils.get_analogues_urls(url) == [
+            "http://example.com",
+            "http://example.com/",
+        ]
+
+        url = "http://example.com//"
+        assert utils.get_analogues_urls(url) == [
+            "http://example.com",
+            "http://example.com/",
+        ]
+
     def test_str_to_sha1(self):
         assert (
             utils.str_to_sha1("208f61cc7a5dbc9879ae6e5c2f95891e270f09ef")
diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py
--- a/swh/graphql/utils/utils.py
+++ b/swh/graphql/utils/utils.py
@@ -30,6 +30,16 @@
     return base64.b64decode(cursor).decode(ENCODING)
 
 
+def get_analogues_urls(url: str) -> List[str]:
+    """Return *url* without and with a trailing slash, in that order.
+
+    All trailing slashes are stripped first, so "http://example.com//"
+    yields ["http://example.com", "http://example.com/"].
+    """
+    url = url.rstrip("/")
+    return [url, f"{url}/"]
+
+
 def str_to_sha1(sha1: str) -> bytearray:
     # FIXME, use core function
     return bytearray.fromhex(sha1)