Created
April 17, 2026 01:59
-
-
Save eliaskanelis/7a15196f28ecc0a57d712062381a9595 to your computer and use it in GitHub Desktop.
tree-sitter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| from tree_sitter import Language, Parser | |
| import tree_sitter_c | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional, Dict, Any | |
| import sys | |
| from rich.pretty import pprint | |
| # ------------------------------------------------------------------------------ | |
| @dataclass | |
| class Primitive: | |
| primitive: str = "" | |
| name: str = "" | |
| comment: Optional[str] = None | |
| @dataclass | |
| class Argument: | |
| primitive: str = "" | |
| name: str = "" | |
| direction: str = "" # "in" | "out" | "inout" | "" | |
| comment: Optional[str] = None | |
@dataclass
class Function:
    """A C function definition: return type, name and parameters."""

    # Spelling of the return type.
    primitive: str = ""
    # Function name.
    name: str = ""
    # Parameters in declaration order; a fresh list per instance.
    args: list[Argument] = field(default_factory=list)
    # Associated source comment, if any was captured.
    comment: str | None = None
| @dataclass | |
| class Enumerator: | |
| name: str = "" | |
| value: Optional[int] = None | |
| comment: Optional[str] = None | |
| @dataclass | |
| class StructElement: | |
| primitive: str = "" | |
| name: str = "" | |
| multiplicity: str = "1" | |
| comment: Optional[str] = None | |
@dataclass
class Struct:
    """A C struct: its name and its fields."""

    # Struct tag, or the typedef name when built from a typedef.
    name: str = ""
    # Fields in declaration order; a fresh list per instance.
    element: list[StructElement] = field(default_factory=list)
    # Associated source comment, if any was captured.
    comment: str | None = None
@dataclass
class Enum:
    """A C enum: its name and its enumerators."""

    # Enum tag, or the typedef name when built from a typedef.
    name: str = ""
    # Enumerators in declaration order; a fresh list per instance.
    enumerator: list[Enumerator] = field(default_factory=list)
    # Associated source comment, if any was captured.
    comment: str | None = None
| # ------------------------------------------------------------------------------ | |
def node_text(node, code: bytes) -> str:
    """Return the source text covered by *node*.

    Args:
        node: Object exposing ``start_byte`` and ``end_byte`` offsets into
            *code*.
        code: The complete source buffer the node was parsed from.

    Returns:
        The bytes ``code[start_byte:end_byte]`` decoded as UTF-8.

    Raises:
        UnicodeDecodeError: If the selected bytes are not valid UTF-8.
    """
    span = slice(node.start_byte, node.end_byte)
    return str(code[span], "utf-8")
| # ------------------------------------------------------------------------------ | |
| def parse_enum(node, code: bytes) -> Enum: | |
| # enum_specifier | |
| enum = Enum() | |
| for child in node.children: | |
| match child.type: | |
| case "enum": | |
| pass | |
| case "type_identifier": | |
| enum.name = node_text(child, code) | |
| case "enumerator_list": | |
| for child2 in child.children: | |
| match child2.type: | |
| case "enumerator": | |
| enumerator = Enumerator() | |
| for child3 in child2.children: | |
| match child3.type: | |
| case "identifier": | |
| enumerator.name = node_text(child3, code) | |
| case "number_literal": | |
| enumerator.value = node_text(child3, code) | |
| case "=": | |
| pass | |
| case _: | |
| raise NotImplementedError( | |
| f"Enum 3: {child3.type:<25} {child3.text}" | |
| ) | |
| enum.enumerator.append(enumerator) | |
| case "{": | |
| pass | |
| case "}": | |
| pass | |
| case ",": | |
| pass | |
| case "comment": | |
| pass | |
| case _: | |
| raise NotImplementedError( | |
| f"Enum 2: {child2.type:<25} {child2.text}" | |
| ) | |
| case _: | |
| raise NotImplementedError(f"Enum 1: {child.type:<25} {child.text}") | |
| return enum | |
| def parse_struct(node, code: bytes) -> Struct: | |
| # struct_specifier | |
| struct = Struct() | |
| for child in node.children: | |
| match child.type: | |
| case "struct": | |
| pass | |
| case "type_identifier": | |
| struct.name = node_text(child, code) | |
| case "field_declaration_list": | |
| for child2 in child.children: | |
| match child2.type: | |
| case "field_declaration": | |
| element = StructElement() | |
| for child3 in child2.children: | |
| match child3.type: | |
| case "primitive_type": | |
| element.primitive = node_text(child3, code) | |
| case "pointer_declarator": | |
| element.name = node_text(child3, code).strip( | |
| "*" | |
| ) | |
| case "type_identifier": | |
| element.primitive = ( | |
| node_text(child3, code) + "*" | |
| ) | |
| case "array_declarator": | |
| for child4 in child3.children: | |
| match child4.type: | |
| case "field_identifier": | |
| element.name = node_text( | |
| child4, code | |
| ) | |
| # pass | |
| case "identifier": | |
| element.multiplicity = node_text( | |
| child4, code | |
| ) | |
| case "[": | |
| pass | |
| case "]": | |
| pass | |
| case _: | |
| raise NotImplementedError( | |
| f"Struct 4: {child4.type:<25} {child4.text}" | |
| ) | |
| case "field_identifier": | |
| element.name = node_text(child3, code) | |
| case ";": | |
| pass | |
| case _: | |
| raise NotImplementedError( | |
| f"Struct 3: {child3.type:<25} {child3.text}" | |
| ) | |
| struct.element.append(element) | |
| case "{": | |
| pass | |
| case "}": | |
| pass | |
| case _: | |
| raise NotImplementedError( | |
| f"Struct 2: {child2.type:<25} {child2.text}" | |
| ) | |
| case _: | |
| raise NotImplementedError(f"Struct 1: {child.type:<25} {child.text}") | |
| return struct | |
| def parse_typedef(node, code: bytes) -> Struct | Enum: | |
| # type_definition | |
| c_object = None | |
| for child in node.children: | |
| match child.type: | |
| case "struct_specifier": | |
| c_object = Struct() | |
| case "enum_specifier": | |
| c_object = Enum() | |
| case "primitive_type": | |
| c_object = Primitive() | |
| case _: | |
| pass | |
| if c_object is None: | |
| raise ValueError() | |
| for child in node.children: | |
| match child.type: | |
| case "typedef": | |
| pass | |
| case "struct_specifier": | |
| struct = parse_struct(child, code) | |
| c_object.element = struct.element | |
| c_object.comment = struct.comment | |
| case "type_identifier": | |
| c_object.name = node_text(child, code) | |
| case "enum_specifier": | |
| enum = parse_enum(child, code) | |
| c_object.enumerator = enum.enumerator | |
| c_object.comment = enum.comment | |
| pass | |
| case "primitive_type": | |
| # TODO: Implement | |
| pass | |
| case ";": | |
| pass | |
| case _: | |
| # print(f"{child.type:<25} {child.text}") | |
| raise NotImplementedError(f"Struct 1: {child.type:<25} {child.text}") | |
| pass | |
| return c_object | |
| def parse_function(node, code: bytes) -> Function: | |
| # function_definition | |
| function = Function() | |
| for child in node.children: | |
| match child.type: | |
| case "primitive_type": | |
| function.primitive = node_text(child, code) | |
| case "compound_statement": | |
| pass | |
| case "function_declarator": | |
| for child2 in child.children: | |
| match child2.type: | |
| case "identifier": | |
| function.name = node_text(child2, code) | |
| case "parameter_list": | |
| for child3 in child2.children: | |
| match child3.type: | |
| case "parameter_declaration": | |
| argument = Argument() | |
| for child4 in child3.children: | |
| match child4.type: | |
| case "primitive_type": | |
| argument.primitive = node_text( | |
| child4, code | |
| ) | |
| case "identifier": | |
| argument.name = node_text( | |
| child4, code | |
| ) | |
| case "type_identifier": | |
| argument.primitive = ( | |
| node_text(child4, code) + "*" | |
| ) | |
| case "pointer_declarator": | |
| argument.name = node_text( | |
| child4, code | |
| ).strip("*") | |
| case _: | |
| raise NotImplementedError( | |
| f"Function 4: {child4.type:<25} {child4.text}" | |
| ) | |
| function.args.append(argument) | |
| case "(": | |
| pass | |
| case ")": | |
| pass | |
| case ",": | |
| pass | |
| case _: | |
| raise NotImplementedError( | |
| f"Function 3: {child3.type:<25} {child3.text}" | |
| ) | |
| case _: | |
| raise NotImplementedError( | |
| f"Function 2: {child2.type:<25} {child2.text}" | |
| ) | |
| pass | |
| case _: | |
| raise NotImplementedError(f"Function 1: {child.type:<25} {child.text}") | |
| return function | |
def parse_node(node, code: bytes) -> None:
    """Pretty-print the record for *node* if it is a typedef or a function.

    ``type_definition`` nodes are converted with ``parse_typedef`` and
    ``function_definition`` nodes with ``parse_function``; the result is
    printed with ``pprint``. Every other node type is ignored.

    Args:
        node: Any syntax node of the tree.
        code: Source buffer the tree was parsed from.
    """
    kind = node.type
    if kind == "type_definition":
        pprint(parse_typedef(node, code))
    elif kind == "function_definition":
        pprint(parse_function(node, code))
    # All remaining node types (translation unit, preprocessor nodes,
    # identifiers, comments, declarations, ...) are deliberately skipped.
    # TODO: bare enum_specifier / struct_specifier nodes are not reported
    # here; only their typedef'd forms are handled.
def walk(node, code: bytes, depth: int = 0) -> None:
    """Visit *node* and all of its descendants in pre-order.

    ``parse_node`` is called on every node, parents before children and
    siblings in source order. The traversal uses an explicit stack instead
    of recursion so deep syntax trees cannot exceed Python's recursion limit.

    Args:
        node: Root of the subtree to visit.
        code: Source buffer the tree was parsed from.
        depth: Unused; kept so existing callers that pass it still work.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        parse_node(current, code)
        # Push children reversed so they are popped in source order.
        pending.extend(reversed(current.children))
def parse_c_file(path: str) -> None:
    """Parse the C file at *path* and print the records found in it.

    The file is read as bytes, parsed with the tree-sitter C grammar and the
    resulting tree is handed to ``walk``. Nothing is returned.

    Args:
        path: Path of the C source or header file to read.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(path, "rb") as f:
        code = f.read()
    parser = Parser()
    parser.language = Language(tree_sitter_c.language())
    tree = parser.parse(code)
    walk(tree.root_node, code)
if __name__ == "__main__":
    # Exit with a usage message rather than a bare IndexError when the input
    # path is missing. Extra arguments are still ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <file.c>")
    parse_c_file(sys.argv[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment