|
""" |
|
This generates .pyi stubs for the cffi Python bindings generated by regenerate.py |
|
""" |
|
import sys, re, itertools |
|
sys.path.extend(['.', '..']) |
|
|
|
from pycparser import c_ast, parse_file, CParser |
|
import pycparser.plyparser |
|
from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef |
|
from typing import Tuple |
|
|
|
# Maps a C type name to the Python annotation written into the generated stub.
# Names missing from this table are treated as opaque cffi objects by
# format_type().
__c_type_to_python_type = {
    'void': 'None',
    '_Bool': 'bool',
    # Every C integer flavour is a plain Python int.
    **dict.fromkeys(
        (
            'char', 'short', 'int', 'long',
            'ptrdiff_t', 'size_t',
            'int8_t', 'uint8_t',
            'int16_t', 'uint16_t',
            'int32_t', 'uint32_t',
            'int64_t', 'uint64_t',
        ),
        'int',
    ),
    # Both C floating-point widths are a Python float.
    **dict.fromkeys(('float', 'double'), 'float'),
    'ggml_fp16_t': 'np.float16',
}
|
|
|
def format_type(t: TypeDecl):
    """Return the Python annotation (as a string) for the C type node *t*.

    Args:
        t: A pycparser type node (``TypeDecl``, ``PtrDecl``, ``Struct``,
            ``Enum``, ``IdentifierType`` or another declarator node).

    Returns:
        ``'ffi.CData'`` for pointers and structs, ``'int'`` for enums, and
        the entry from ``__c_type_to_python_type`` for identifier types
        (``'ffi.CData'`` when the name is not in the table). A ``TypeDecl``
        is unwrapped to the type it declares. Any other node yields its
        ``name`` attribute, or ``'ffi.CData'`` when it has none.
    """
    if isinstance(t, (PtrDecl, Struct)):
        return 'ffi.CData'
    if isinstance(t, Enum):
        return 'int'
    if isinstance(t, TypeDecl):
        return format_type(t.type)
    if isinstance(t, IdentifierType):
        # Multi-word types ('unsigned int', 'long long', ...) arrive as
        # several names. Signedness never changes the Python-side type, and a
        # bare 'signed'/'unsigned' means int.
        words = [n for n in t.names if n not in ('signed', 'unsigned')] or ['int']
        mapped = {__c_type_to_python_type.get(w) for w in words}
        if len(mapped) == 1:
            # Single name, or several names that agree (e.g. 'long long').
            return mapped.pop() or 'ffi.CData'
        if mapped == {'int', 'float'}:
            # 'long double': 'long' maps to int, 'double' maps to float.
            return 'float'
        return 'ffi.CData'
    # Nodes without a usable name (e.g. array declarators) are opaque cdata.
    return getattr(t, 'name', None) or 'ffi.CData'
|
|
|
class PythonStubFuncDeclVisitor(c_ast.NodeVisitor):
    """Collect .pyi stub text for the functions and enumerators of a C AST.

    Visiting a parsed translation unit fills ``sigs`` with one entry per
    function declaration and per enumerator, keyed by name. Each value is
    stub source text indented so that it can be placed inside a class body.
    """

    def __init__(self):
        # Member name -> stub source text for that member.
        self.sigs = {}
        # File path -> list of that file's lines, so each file is read once.
        self.sources = {}

    def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:
        """Return the comment and declaration lines found at *coord*.

        Args:
            coord: Location of a declaration. ``coord.file`` is opened and
                read; ``coord.line`` is used as a 1-based line number.

        Returns:
            ``(comment_lines, decl_lines)``. ``comment_lines`` is the
            unbroken run of lines directly above the declaration that begin
            with ``//`` or ``/*`` (whitespace-stripped). ``decl_lines`` runs
            from the declaration line through the first line containing
            ``;`` or ``{`` (trailing whitespace removed).

        Raises:
            OSError: If ``coord.file`` cannot be opened.
        """
        if coord.file not in self.sources:
            with open(coord.file, 'rt') as f:
                self.sources[coord.file] = f.readlines()
        source_lines = self.sources[coord.file]

        decl_index = coord.line - 1
        # Walk upwards for as long as the preceding line opens a comment.
        first_comment = decl_index
        while first_comment > 0 and re.search(r'^\s*(//|/\*)', source_lines[first_comment - 1]):
            first_comment -= 1
        comment_lines = [l.strip() for l in source_lines[first_comment:decl_index]]

        decl_lines = []
        for line in source_lines[decl_index:]:
            decl_lines.append(line.rstrip())
            if (';' in line) or ('{' in line):
                break
        return (comment_lines, decl_lines)

    def visit_Enum(self, node: Enum):
        """Record every enumerator of *node* as a read-only ``int`` property."""
        # An enum that is only referenced (no body) has no values to record.
        if node.values is not None:
            for e in node.values.enumerators:
                self.sigs[e.name] = f'  @property\n  def {e.name}(self) -> int: ...'

    def visit_Typedef(self, node: Typedef):
        """Ignore typedefs; nothing inside them is turned into a stub."""
        pass

    def visit_FuncDecl(self, node: FuncDecl):
        """Record a stub for the function declared by *node* in ``sigs``.

        The stub's docstring holds the comment lines found above the C
        declaration followed by the declaration text itself. Functions whose
        name starts with ``__`` are skipped.
        """
        # Peel pointer layers off the return type to reach the TypeDecl that
        # carries the function's name.
        ret_type = node.type
        is_ptr = False
        while isinstance(ret_type, PtrDecl):
            ret_type = ret_type.type
            is_ptr = True

        fun_name = ret_type.declname
        if fun_name.startswith('__'):
            return

        args = []
        argnames = []

        def gen_name(stem):
            """Return *stem*, or *stem* plus the first free numeric suffix."""
            i = 1
            while True:
                new_name = stem if i == 1 else f'{stem}{i}'
                if new_name not in argnames:
                    return new_name
                i += 1

        # A declaration written with empty parentheses has no parameter list.
        params = node.args.params if node.args is not None else []
        for a in params:
            if isinstance(a, EllipsisParam):
                arg_name = gen_name('args')
                argnames.append(arg_name)
                # Reuse the name just reserved; asking gen_name() again would
                # skip past it and yield 'args2'.
                args.append('*' + arg_name)
            elif format_type(a.type) == 'None':
                # A lone 'void' parameter means "takes no arguments".
                continue
            else:
                arg_name = a.name or gen_name('arg')
                argnames.append(arg_name)
                args.append(f'{arg_name}: {format_type(a.type)}')

        # For pointer returns, format the outer pointer node rather than the
        # pointee that was reached by peeling.
        ret = format_type(ret_type if not is_ptr else node.type)

        comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)

        # Members sit at two spaces and their bodies at four, so the docstring
        # and '...' are nested under the def line.
        lines = [f'  def {fun_name}({", ".join(args)}) -> {ret}:']
        if not comment_lines and len(decl_lines) == 1:
            lines += [f'    """{decl_lines[0]}"""']
        else:
            lines += ['    """']
            lines += [f'    {c.lstrip("/* ")}' for c in comment_lines]
            if comment_lines:
                # Blank separator between the comment and the declaration.
                lines += ['']
            lines += [f'    {d}' for d in decl_lines]
            lines += ['    """']
        lines += ['    ...']
        self.sigs[fun_name] = '\n'.join(lines)
|
|
|
def generate_stubs(header: str):
    """
    Build the text of a .pyi Python stub file for the GGML API.

    Args:
        header: C header source text; it is handed to pycparser's
            ``CParser.parse`` under the file name ``"<input>"``.

    Returns:
        The stub file contents as one string: a fixed preamble that opens
        ``class lib:``, followed by every collected member in name order.
    """
    visitor = PythonStubFuncDeclVisitor()
    tree = CParser().parse(header, "<input>")
    visitor.visit(tree)

    preamble = [
        '# auto-generated file',
        'import ggml.ffi as ffi',
        'import numpy as np',
        'class lib:',
    ]
    members = [visitor.sigs[name] for name in sorted(visitor.sigs)]
    return '\n'.join(preamble + members)
|
|