diff --git a/setup.py b/setup.py index 23bcb26..90140da 100644 --- a/setup.py +++ b/setup.py @@ -1,48 +1,42 @@ """ Py-Tree-sitter """ -import platform -from os import path -from setuptools import Extension -from setuptools import setup +from os import path +from platform import system +from setuptools import Extension, setup with open(path.join(path.dirname(__file__), "README.md")) as f: LONG_DESCRIPTION = f.read() setup( name="tree_sitter", version="0.0.8", maintainer="Max Brunsfeld", maintainer_email="maxbrunsfeld@gmail.com", author="Max Brunsfeld", author_email="maxbrunsfeld@gmail.com", url="https://github.com/tree-sitter/py-tree-sitter", license="MIT", platforms=["any"], python_requires=">=3.3", description="Python bindings to the Tree-sitter parsing library", long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown", classifiers=[ "License :: OSI Approved :: MIT License", "Topic :: Software Development :: Compilers", "Topic :: Text Processing :: Linguistic", ], packages=["tree_sitter"], ext_modules=[ Extension( "tree_sitter.binding", ["tree_sitter/core/lib/src/lib.c", "tree_sitter/binding.c"], - include_dirs=[ - "tree_sitter/core/lib/include", - "tree_sitter/core/lib/src", - ], - extra_compile_args=( - ["-std=c99"] if platform.system() != "Windows" else None - ), + include_dirs=["tree_sitter/core/lib/include", "tree_sitter/core/lib/src"], + extra_compile_args=(["-std=c99"] if system() != "Windows" else None), ) ], project_urls={"Source": "https://github.com/tree-sitter/py-tree-sitter"}, ) diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py index 0befe48..0aa1a6e 100644 --- a/tests/test_tree_sitter.py +++ b/tests/test_tree_sitter.py @@ -1,278 +1,269 @@ # pylint: disable=missing-docstring import re -import unittest +from unittest import TestCase from os import path - -from tree_sitter import Language -from tree_sitter import Parser - +from tree_sitter import Language, Parser LIB_PATH = path.join("build", "languages.so") Language.build_library( LIB_PATH, [ path.join("tests", "fixtures", "tree-sitter-python"), path.join("tests", "fixtures", "tree-sitter-javascript"), ], ) PYTHON = Language(LIB_PATH, "python") JAVASCRIPT = Language(LIB_PATH, "javascript") -def _collapse_ws(string): - return re.sub(r"\s+", " ", string).strip() - - -class TestParser(unittest.TestCase): +class TestParser(TestCase): def test_set_language(self): parser = Parser() parser.set_language(PYTHON) tree = parser.parse(b"def foo():\n bar()") self.assertEqual( tree.root_node.sexp(), - _collapse_ws( + trim( """(module (function_definition name: (identifier) parameters: (parameters) body: (block (expression_statement (call function: (identifier) arguments: (argument_list))))))""" ), ) parser.set_language(JAVASCRIPT) tree = parser.parse(b"function foo() {\n bar();\n}") self.assertEqual( tree.root_node.sexp(), - _collapse_ws( + trim( """(program (function_declaration name: (identifier) parameters: (formal_parameters) body: (statement_block (expression_statement (call_expression function: (identifier) arguments: (arguments))))))""" ), ) def test_multibyte_characters(self): parser = Parser() parser.set_language(JAVASCRIPT) source_code = bytes("'😎' && '🐍'", "utf8") tree = parser.parse(source_code) root_node = tree.root_node statement_node = root_node.children[0] binary_node = statement_node.children[0] snake_node = binary_node.children[2] self.assertEqual(binary_node.type, "binary_expression") self.assertEqual(snake_node.type, "string") self.assertEqual( - source_code[snake_node.start_byte : snake_node.end_byte].decode( - "utf8" - ), + source_code[snake_node.start_byte:snake_node.end_byte].decode("utf8"), "'🐍'", ) -class TestNode(unittest.TestCase): +class TestNode(TestCase): def test_child_by_field_id(self): parser = Parser() parser.set_language(PYTHON) tree = parser.parse(b"def foo():\n bar()") root_node = tree.root_node fn_node = tree.root_node.children[0] self.assertEqual(PYTHON.field_id_for_name("nameasdf"), None) name_field = PYTHON.field_id_for_name("name") alias_field = PYTHON.field_id_for_name("alias") self.assertIsInstance(alias_field, int) self.assertIsInstance(name_field, int) self.assertRaises(TypeError, root_node.child_by_field_id, "") self.assertEqual(root_node.child_by_field_id(alias_field), None) self.assertEqual(root_node.child_by_field_id(name_field), None) self.assertEqual(fn_node.child_by_field_id(alias_field), None) - self.assertEqual( - fn_node.child_by_field_id(name_field).type, "identifier" - ) + self.assertEqual(fn_node.child_by_field_id(name_field).type, "identifier") self.assertRaises(TypeError, root_node.child_by_field_name, True) self.assertRaises(TypeError, root_node.child_by_field_name, 1) - self.assertEqual( - fn_node.child_by_field_name("name").type, "identifier" - ) + self.assertEqual(fn_node.child_by_field_name("name").type, "identifier") self.assertEqual(fn_node.child_by_field_name("asdfasdfname"), None) def test_children(self): parser = Parser() parser.set_language(PYTHON) tree = parser.parse(b"def foo():\n bar()") root_node = tree.root_node self.assertEqual(root_node.type, "module") self.assertEqual(root_node.start_byte, 0) self.assertEqual(root_node.end_byte, 18) self.assertEqual(root_node.start_point, (0, 0)) self.assertEqual(root_node.end_point, (1, 7)) # List object is reused self.assertIs(root_node.children, root_node.children) fn_node = root_node.children[0] self.assertEqual(fn_node.type, "function_definition") self.assertEqual(fn_node.start_byte, 0) self.assertEqual(fn_node.end_byte, 18) self.assertEqual(fn_node.start_point, (0, 0)) self.assertEqual(fn_node.end_point, (1, 7)) def_node = fn_node.children[0] self.assertEqual(def_node.type, "def") self.assertEqual(def_node.is_named, False) id_node = fn_node.children[1] self.assertEqual(id_node.type, "identifier") self.assertEqual(id_node.is_named, True) self.assertEqual(len(id_node.children), 0) params_node = fn_node.children[2] self.assertEqual(params_node.type, "parameters") self.assertEqual(params_node.is_named, True) colon_node = fn_node.children[3] self.assertEqual(colon_node.type, ":") self.assertEqual(colon_node.is_named, False) statement_node = fn_node.children[4] self.assertEqual(statement_node.type, "block") self.assertEqual(statement_node.is_named, True) -class TestTree(unittest.TestCase): +class TestTree(TestCase): def test_walk(self): parser = Parser() parser.set_language(PYTHON) tree = parser.parse(b"def foo():\n bar()") cursor = tree.walk() # Node always returns the same instance self.assertIs(cursor.node, cursor.node) self.assertEqual(cursor.node.type, "module") self.assertEqual(cursor.node.start_byte, 0) self.assertEqual(cursor.node.end_byte, 18) self.assertEqual(cursor.node.start_point, (0, 0)) self.assertEqual(cursor.node.end_point, (1, 7)) self.assertTrue(cursor.goto_first_child()) self.assertEqual(cursor.node.type, "function_definition") self.assertEqual(cursor.node.start_byte, 0) self.assertEqual(cursor.node.end_byte, 18) self.assertEqual(cursor.node.start_point, (0, 0)) self.assertEqual(cursor.node.end_point, (1, 7)) self.assertTrue(cursor.goto_first_child()) self.assertEqual(cursor.node.type, "def") self.assertEqual(cursor.node.is_named, False) self.assertEqual(cursor.node.sexp(), '("def")') def_node = cursor.node # Node remains cached after a failure to move self.assertFalse(cursor.goto_first_child()) self.assertIs(cursor.node, def_node) self.assertTrue(cursor.goto_next_sibling()) self.assertEqual(cursor.node.type, "identifier") self.assertEqual(cursor.node.is_named, True) self.assertFalse(cursor.goto_first_child()) self.assertTrue(cursor.goto_next_sibling()) self.assertEqual(cursor.node.type, "parameters") self.assertEqual(cursor.node.is_named, True) def test_edit(self): parser = Parser() parser.set_language(PYTHON) tree = parser.parse(b"def foo():\n bar()") edit_offset = len(b"def foo(") tree.edit( start_byte=edit_offset, old_end_byte=edit_offset, new_end_byte=edit_offset + 2, start_point=(0, edit_offset), old_end_point=(0, edit_offset), new_end_point=(0, edit_offset + 2), ) fn_node = tree.root_node.children[0] self.assertEqual(fn_node.type, "function_definition") self.assertTrue(fn_node.has_changes) self.assertFalse(fn_node.children[0].has_changes) self.assertFalse(fn_node.children[1].has_changes) self.assertFalse(fn_node.children[3].has_changes) params_node = fn_node.children[2] self.assertEqual(params_node.type, "parameters") self.assertTrue(params_node.has_changes) self.assertEqual(params_node.start_point, (0, edit_offset - 1)) self.assertEqual(params_node.end_point, (0, edit_offset + 3)) new_tree = parser.parse(b"def foo(ab):\n bar()", tree) self.assertEqual( new_tree.root_node.sexp(), - _collapse_ws( + trim( """(module (function_definition name: (identifier) parameters: (parameters (identifier)) body: (block (expression_statement (call function: (identifier) arguments: (argument_list))))))""" ), ) -class TestQuery(unittest.TestCase): +class TestQuery(TestCase): def test_errors(self): with self.assertRaisesRegex(NameError, "Invalid node type foo"): PYTHON.query("(list (foo))") with self.assertRaisesRegex(NameError, "Invalid field name buzz"): PYTHON.query("(function_definition buzz: (identifier))") with self.assertRaisesRegex(NameError, "Invalid capture name garbage"): PYTHON.query("((function_definition) (eq? @garbage foo))") with self.assertRaisesRegex(SyntaxError, "Invalid syntax at offset 6"): PYTHON.query("(list))") PYTHON.query("(function_definition)") def test_captures(self): parser = Parser() parser.set_language(PYTHON) source = b"def foo():\n bar()\ndef baz():\n quux()\n" tree = parser.parse(source) query = PYTHON.query( """ (function_definition name: (identifier) @func-def) (call function: (identifier) @func-call) """ ) captures = query.captures(tree.root_node) captures = query.captures(tree.root_node) captures = query.captures(tree.root_node) captures = query.captures(tree.root_node) self.assertEqual(captures[0][0].start_point, (0, 4)) self.assertEqual(captures[0][0].end_point, (0, 7)) self.assertEqual(captures[0][1], "func-def") self.assertEqual(captures[1][0].start_point, (1, 2)) self.assertEqual(captures[1][0].end_point, (1, 5)) self.assertEqual(captures[1][1], "func-call") self.assertEqual(captures[2][0].start_point, (2, 4)) self.assertEqual(captures[2][0].end_point, (2, 7)) self.assertEqual(captures[2][1], "func-def") self.assertEqual(captures[3][0].start_point, (3, 2)) self.assertEqual(captures[3][0].end_point, (3, 6)) self.assertEqual(captures[3][1], "func-call") + + +def trim(string): + return re.sub(r"\s+", " ", string).strip() diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py index 66fc989..cc94ac9 100644 --- a/tree_sitter/__init__.py +++ b/tree_sitter/__init__.py @@ -1,100 +1,92 @@ """Python bindings for tree-sitter.""" -import platform -from ctypes import c_void_p -from ctypes import cdll +from ctypes import cdll, c_void_p from ctypes.util import find_library from distutils.ccompiler import new_compiler from os import path +from platform import system from tempfile import TemporaryDirectory - -# pylint: disable=no-name-in-module,import-error from tree_sitter.binding import _language_field_id_for_name, _language_query -from tree_sitter.binding import Node -from tree_sitter.binding import Parser -from tree_sitter.binding import Tree -from tree_sitter.binding import TreeCursor +from tree_sitter.binding import Node, Parser, Tree, TreeCursor # noqa: F401 class Language: """A tree-sitter language""" @staticmethod def build_library(output_path, repo_paths): """ Build a dynamic library at the given path, based on the parser repositories at the given paths. Returns `True` if the dynamic library was compiled and `False` if the library already existed and was modified more recently than any of the source files. """ - output_mtime = ( - path.getmtime(output_path) if path.exists(output_path) else 0 - ) + output_mtime = path.getmtime(output_path) if path.exists(output_path) else 0 if not repo_paths: raise ValueError("Must provide at least one language folder") cpp = False source_paths = [] for repo_path in repo_paths: src_path = path.join(repo_path, "src") source_paths.append(path.join(src_path, "parser.c")) if path.exists(path.join(src_path, "scanner.cc")): cpp = True source_paths.append(path.join(src_path, "scanner.cc")) elif path.exists(path.join(src_path, "scanner.c")): source_paths.append(path.join(src_path, "scanner.c")) source_mtimes = [path.getmtime(__file__)] + [ path.getmtime(path_) for path_ in source_paths ] compiler = new_compiler() if cpp: if find_library("c++"): compiler.add_library("c++") elif find_library("stdc++"): compiler.add_library("stdc++") if max(source_mtimes) <= output_mtime: return False with TemporaryDirectory(suffix="tree_sitter_language") as out_dir: object_paths = [] for source_path in source_paths: - if platform.system() == "Windows": + if system() == "Windows": flags = None else: flags = ["-fPIC"] if source_path.endswith(".c"): flags.append("-std=c99") object_paths.append( compiler.compile( [source_path], output_dir=out_dir, include_dirs=[path.dirname(source_path)], extra_preargs=flags, )[0] ) compiler.link_shared_object(object_paths, output_path) return True def __init__(self, library_path, name): """ Load the language with the given name from the dynamic library at the given path. """ self.name = name self.lib = cdll.LoadLibrary(library_path) language_function = getattr(self.lib, "tree_sitter_%s" % name) language_function.restype = c_void_p self.language_id = language_function() def field_id_for_name(self, name): """Return the field id for a field name.""" return _language_field_id_for_name(self.language_id, name) def query(self, source): """Create a Query with the given source code.""" return _language_query(self.language_id, source)