Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save eliaskanelis/7a15196f28ecc0a57d712062381a9595 to your computer and use it in GitHub Desktop.

Select an option

Save eliaskanelis/7a15196f28ecc0a57d712062381a9595 to your computer and use it in GitHub Desktop.
tree-sitter
from __future__ import annotations
from tree_sitter import Language, Parser
import tree_sitter_c
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
import sys
from rich.pretty import pprint
# ------------------------------------------------------------------------------
@dataclass
class Primitive:
    """Record describing a named primitive type."""

    # Spelling of the underlying primitive type.
    primitive: str = ""
    # Name attached to the type.
    name: str = ""
    # Optional free-form comment associated with the record.
    comment: Optional[str] = None
@dataclass
class Argument:
    """Record describing one function argument."""

    # Spelling of the argument's type.
    primitive: str = ""
    # Argument name.
    name: str = ""
    # Data direction: "in" | "out" | "inout" | "" (empty when unspecified).
    direction: str = ""
    # Optional free-form comment associated with the argument.
    comment: Optional[str] = None
@dataclass
class Function:
    """Record describing a function: return type, name and arguments."""

    # Spelling of the return type.
    primitive: str = ""
    # Function name.
    name: str = ""
    # Arguments in declaration order; every instance gets its own list.
    args: List[Argument] = field(default_factory=list)
    # Optional free-form comment associated with the function.
    comment: Optional[str] = None
@dataclass
class Enumerator:
    """Record describing a single member of an enumeration."""

    # Enumerator name.
    name: str = ""
    # Explicit value of the enumerator; None when it has none.
    value: Optional[int] = None
    # Optional free-form comment associated with the enumerator.
    comment: Optional[str] = None
@dataclass
class StructElement:
    """Record describing one field of a struct."""

    # Spelling of the field's type.
    primitive: str = ""
    # Field name.
    name: str = ""
    # Number of elements, kept as text; "1" unless overridden.
    multiplicity: str = "1"
    # Optional free-form comment associated with the field.
    comment: Optional[str] = None
@dataclass
class Struct:
    """Record describing a struct and its fields."""

    # Struct name.
    name: str = ""
    # Fields in declaration order; every instance gets its own list.
    element: List[StructElement] = field(default_factory=list)
    # Optional free-form comment associated with the struct.
    comment: Optional[str] = None
@dataclass
class Enum:
    """Record describing an enumeration and its members."""

    # Enumeration name.
    name: str = ""
    # Members in declaration order; every instance gets its own list.
    enumerator: List[Enumerator] = field(default_factory=list)
    # Optional free-form comment associated with the enumeration.
    comment: Optional[str] = None
# ------------------------------------------------------------------------------
def node_text(node, code: bytes) -> str:
    """Return the source text covered by a syntax-tree node.

    Args:
        node: Object exposing ``start_byte`` and ``end_byte`` offsets into
            *code*.
        code: Raw bytes the offsets refer to.

    Returns:
        The bytes in ``[start_byte, end_byte)`` decoded as UTF-8.
    """
    first, last = node.start_byte, node.end_byte
    return str(code[first:last], "utf-8")
# ------------------------------------------------------------------------------
def parse_enum(node, code: bytes) -> Enum:
# enum_specifier
enum = Enum()
for child in node.children:
match child.type:
case "enum":
pass
case "type_identifier":
enum.name = node_text(child, code)
case "enumerator_list":
for child2 in child.children:
match child2.type:
case "enumerator":
enumerator = Enumerator()
for child3 in child2.children:
match child3.type:
case "identifier":
enumerator.name = node_text(child3, code)
case "number_literal":
enumerator.value = node_text(child3, code)
case "=":
pass
case _:
raise NotImplementedError(
f"Enum 3: {child3.type:<25} {child3.text}"
)
enum.enumerator.append(enumerator)
case "{":
pass
case "}":
pass
case ",":
pass
case "comment":
pass
case _:
raise NotImplementedError(
f"Enum 2: {child2.type:<25} {child2.text}"
)
case _:
raise NotImplementedError(f"Enum 1: {child.type:<25} {child.text}")
return enum
def parse_struct(node, code: bytes) -> Struct:
# struct_specifier
struct = Struct()
for child in node.children:
match child.type:
case "struct":
pass
case "type_identifier":
struct.name = node_text(child, code)
case "field_declaration_list":
for child2 in child.children:
match child2.type:
case "field_declaration":
element = StructElement()
for child3 in child2.children:
match child3.type:
case "primitive_type":
element.primitive = node_text(child3, code)
case "pointer_declarator":
element.name = node_text(child3, code).strip(
"*"
)
case "type_identifier":
element.primitive = (
node_text(child3, code) + "*"
)
case "array_declarator":
for child4 in child3.children:
match child4.type:
case "field_identifier":
element.name = node_text(
child4, code
)
# pass
case "identifier":
element.multiplicity = node_text(
child4, code
)
case "[":
pass
case "]":
pass
case _:
raise NotImplementedError(
f"Struct 4: {child4.type:<25} {child4.text}"
)
case "field_identifier":
element.name = node_text(child3, code)
case ";":
pass
case _:
raise NotImplementedError(
f"Struct 3: {child3.type:<25} {child3.text}"
)
struct.element.append(element)
case "{":
pass
case "}":
pass
case _:
raise NotImplementedError(
f"Struct 2: {child2.type:<25} {child2.text}"
)
case _:
raise NotImplementedError(f"Struct 1: {child.type:<25} {child.text}")
return struct
def parse_typedef(node, code: bytes) -> Struct | Enum:
# type_definition
c_object = None
for child in node.children:
match child.type:
case "struct_specifier":
c_object = Struct()
case "enum_specifier":
c_object = Enum()
case "primitive_type":
c_object = Primitive()
case _:
pass
if c_object is None:
raise ValueError()
for child in node.children:
match child.type:
case "typedef":
pass
case "struct_specifier":
struct = parse_struct(child, code)
c_object.element = struct.element
c_object.comment = struct.comment
case "type_identifier":
c_object.name = node_text(child, code)
case "enum_specifier":
enum = parse_enum(child, code)
c_object.enumerator = enum.enumerator
c_object.comment = enum.comment
pass
case "primitive_type":
# TODO: Implement
pass
case ";":
pass
case _:
# print(f"{child.type:<25} {child.text}")
raise NotImplementedError(f"Struct 1: {child.type:<25} {child.text}")
pass
return c_object
def parse_function(node, code: bytes) -> Function:
# function_definition
function = Function()
for child in node.children:
match child.type:
case "primitive_type":
function.primitive = node_text(child, code)
case "compound_statement":
pass
case "function_declarator":
for child2 in child.children:
match child2.type:
case "identifier":
function.name = node_text(child2, code)
case "parameter_list":
for child3 in child2.children:
match child3.type:
case "parameter_declaration":
argument = Argument()
for child4 in child3.children:
match child4.type:
case "primitive_type":
argument.primitive = node_text(
child4, code
)
case "identifier":
argument.name = node_text(
child4, code
)
case "type_identifier":
argument.primitive = (
node_text(child4, code) + "*"
)
case "pointer_declarator":
argument.name = node_text(
child4, code
).strip("*")
case _:
raise NotImplementedError(
f"Function 4: {child4.type:<25} {child4.text}"
)
function.args.append(argument)
case "(":
pass
case ")":
pass
case ",":
pass
case _:
raise NotImplementedError(
f"Function 3: {child3.type:<25} {child3.text}"
)
case _:
raise NotImplementedError(
f"Function 2: {child2.type:<25} {child2.text}"
)
pass
case _:
raise NotImplementedError(f"Function 1: {child.type:<25} {child.text}")
return function
def parse_node(node, code: bytes) -> None:
    """Pretty-print the record for a typedef or function-definition node.

    Only ``type_definition`` and ``function_definition`` nodes produce
    output. Every other node type is ignored, including stand-alone
    ``enum_specifier`` and ``struct_specifier`` nodes.

    Args:
        node: Syntax-tree node to inspect.
        code: Raw source bytes the node offsets refer to.
    """
    kind = node.type
    if kind == "type_definition":
        pprint(parse_typedef(node, code))
    elif kind == "function_definition":
        pprint(parse_function(node, code))
    # TODO: stand-alone (non-typedef) enum and struct specifiers are not
    # reported yet.
def walk(node, code: bytes, depth: int = 0) -> None:
    """Call ``parse_node`` on *node* and all its descendants, parents first.

    Args:
        node: Root of the subtree to traverse.
        code: Raw source bytes the node offsets refer to.
        depth: Nesting level of *node*; accepted but not used by the
            traversal.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        parse_node(current, code)
        # Push children reversed so they pop in their original order.
        pending.extend(reversed(current.children))
def parse_c_file(path: str) -> None:
    """Parse a C source file and walk its syntax tree.

    The file is read as raw bytes, parsed with the tree-sitter C grammar,
    and the root node is handed to ``walk``.

    Args:
        path: Path of the C source file to read.

    Returns:
        None. Nothing is collected or returned by this function.
    """
    with open(path, "rb") as f:
        code = f.read()
    parser = Parser()
    parser.language = Language(tree_sitter_c.language())
    tree = parser.parse(code)
    walk(tree.root_node, code)
if __name__ == "__main__":
    # Exit with a usage message rather than an IndexError when the path
    # argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <source.c>")
    parse_c_file(sys.argv[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment