diff --git a/swh/graphql/app.py b/swh/graphql/app.py --- a/swh/graphql/app.py +++ b/swh/graphql/app.py @@ -33,8 +33,9 @@ resolvers.branch_target, resolvers.release_target, resolvers.directory_entry_target, + resolvers.binary_string, scalars.id_scalar, - scalars.string_scalar, + # scalars.string_scalar, scalars.datetime_scalar, scalars.swhid_scalar, ) diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -15,6 +15,7 @@ from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs +from swh.graphql.utils import utils from .resolver_factory import get_connection_resolver, get_node_resolver @@ -29,6 +30,8 @@ directory: ObjectType = ObjectType("Directory") directory_entry: ObjectType = ObjectType("DirectoryEntry") +binary_string: ObjectType = ObjectType("BinaryString") + branch_target: UnionType = UnionType("BranchTarget") release_target: UnionType = UnionType("ReleaseTarget") directory_entry_target: UnionType = UnionType("DirectoryEntryTarget") @@ -243,3 +246,13 @@ Generic resolver for all the union types """ return obj.is_type_of() + + +@binary_string.field("text") +def binary_string_text_resolver(obj, *args, **kw): + return obj.decode(utils.ENCODING, "ignore") + + +@binary_string.field("base64") +def binary_string_base64_resolver(obj, *args, **kw): + return utils.get_b64_string(obj) diff --git a/swh/graphql/resolvers/scalars.py b/swh/graphql/resolvers/scalars.py --- a/swh/graphql/resolvers/scalars.py +++ b/swh/graphql/resolvers/scalars.py @@ -14,7 +14,7 @@ datetime_scalar = ScalarType("DateTime") swhid_scalar = ScalarType("SWHID") id_scalar = ScalarType("ID") -string_scalar = ScalarType("String") +# string_scalar = ScalarType("String") @id_scalar.serializer @@ -24,11 +24,11 @@ return value -@string_scalar.serializer -def serialize_string(value): - if type(value) is bytes: - return value.decode("utf-8") - return value +# @string_scalar.serializer +# def serialize_string(value): +# if type(value) is bytes: +# return value.decode("utf-8") +# return value @datetime_scalar.serializer diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py --- a/swh/graphql/resolvers/visit.py +++ b/swh/graphql/resolvers/visit.py @@ -14,7 +14,7 @@ @property def id(self): # FIXME, use a better id - return utils.b64encode(f"{self.origin}-{str(self.visit)}") + return utils.get_b64_string(f"{self.origin}-{str(self.visit)}") @property def visitId(self): # To support the schema naming convention diff --git a/swh/graphql/schema/schema.graphql b/swh/graphql/schema/schema.graphql --- a/swh/graphql/schema/schema.graphql +++ b/swh/graphql/schema/schema.graphql @@ -44,6 +44,22 @@ hasNextPage: Boolean! } +""" +Binary strings; different encodings +""" +type BinaryString { + """ + Utf-8 encoded value, any non Utf char will be ignored + """ + text: String + + """ + base64 encoded value + """ + base64: String +} + + """ Connection to origins """ @@ -411,17 +427,17 @@ """ User's email address """ - email: String + email: BinaryString """ User's name """ - name: String + name: BinaryString """ User's full name """ - fullname: String + fullname: BinaryString } """ @@ -448,7 +464,7 @@ """ Branch name """ - name: String + name: BinaryString """ Type of Branch target @@ -518,7 +534,7 @@ """ Message associated to the revision """ - message: String + message: BinaryString """ """ @@ -607,12 +623,12 @@ """ The name of the release """ - name: String + name: BinaryString """ The message associated to the release """ - message: String + message: BinaryString """ """ @@ -695,7 +711,7 @@ """ The directory entry name """ - name: String + name: BinaryString """ Directory entry object type; can be file, dir or rev diff --git a/swh/graphql/tests/unit/resolvers/test_resolvers.py b/swh/graphql/tests/unit/resolvers/test_resolvers.py --- a/swh/graphql/tests/unit/resolvers/test_resolvers.py +++ b/swh/graphql/tests/unit/resolvers/test_resolvers.py @@ -118,3 +118,11 @@ obj = mocker.Mock() obj.is_type_of.return_value = "test" assert rs.union_resolver(obj) == "test" + + def test_binary_string_text_resolver(self): + text = rs.binary_string_text_resolver(b"test", None) + assert text == "test" + + def test_binary_string_base64_resolver(self): + b64string = rs.binary_string_base64_resolver(b"test", None) + assert b64string == "dGVzdA==" diff --git a/swh/graphql/tests/unit/utils/test_utils.py b/swh/graphql/tests/unit/utils/test_utils.py --- a/swh/graphql/tests/unit/utils/test_utils.py +++ b/swh/graphql/tests/unit/utils/test_utils.py @@ -9,8 +9,11 @@ class TestUtils: - def test_b64encode(self): - assert utils.b64encode("testing") == "dGVzdGluZw==" + def test_get_b64_string(self): + assert utils.get_b64_string("testing") == "dGVzdGluZw==" + + def test_get_b64_string_binary(self): + assert utils.get_b64_string(b"testing") == "dGVzdGluZw==" def test_get_encoded_cursor_is_none(self): assert utils.get_encoded_cursor(None) is None diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -9,21 +9,25 @@ from swh.storage.interface import PagedResult +ENCODING = "utf-8" -def b64encode(text: str) -> str: - return base64.b64encode(bytes(text, "utf-8")).decode("utf-8") + +def get_b64_string(source) -> str: + if type(source) is str: + source = source.encode(ENCODING) + return base64.b64encode(source).decode(ENCODING) def get_encoded_cursor(cursor: str) -> str: if cursor is None: return None - return b64encode(cursor) + return get_b64_string(cursor) def get_decoded_cursor(cursor: str) -> str: if cursor is None: return None - return base64.b64decode(cursor).decode("utf-8") + return base64.b64decode(cursor).decode(ENCODING) def str_to_sha1(sha1: str) -> bytearray: