# svjack's picture
# Upload folder using huggingface_hub
# 13d3ba0
"""
This generates .pyi stubs for the cffi Python bindings generated by regenerate.py
"""
import sys, re, itertools
sys.path.extend(['.', '..']) # for pycparser
from pycparser import c_ast, parse_file, CParser
import pycparser.plyparser
from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef
from typing import Tuple
# Lookup table from a C type name (as written in the header) to the Python
# annotation text emitted in the generated stub. Names missing from this table
# are rendered as 'ffi.CData' by format_type().
__c_type_to_python_type = {
    'void': 'None',
    '_Bool': 'bool',
    # Every integer-like C type is annotated as a plain Python int.
    **dict.fromkeys(
        (
            'char', 'short', 'int', 'long',
            'ptrdiff_t', 'size_t',
            'int8_t', 'uint8_t',
            'int16_t', 'uint16_t',
            'int32_t', 'uint32_t',
            'int64_t', 'uint64_t',
        ),
        'int',
    ),
    # Both C floating-point widths are annotated as a Python float.
    **dict.fromkeys(('float', 'double'), 'float'),
    'ggml_fp16_t': 'np.float16',
}
def format_type(t: TypeDecl):
    """
    Return the Python annotation text used in the stub for a C type node.

    Pointers and structs become 'ffi.CData', enums become 'int', and plain
    identifier types are looked up in __c_type_to_python_type (falling back to
    'ffi.CData' when the name is not listed). Any other node kind yields its
    `name` attribute.
    """
    # Strip TypeDecl wrappers until the underlying type node is reached.
    node = t
    while isinstance(node, TypeDecl):
        node = node.type

    if isinstance(node, (PtrDecl, Struct)):
        return 'ffi.CData'
    if isinstance(node, Enum):
        return 'int'
    if not isinstance(node, IdentifierType):
        return node.name

    # Multi-word types (more than one entry in `names`) are not supported here.
    assert len(node.names) == 1, f'Expected a single name, got {node.names}'
    mapped = __c_type_to_python_type.get(node.names[0])
    return mapped if mapped else 'ffi.CData'
class PythonStubFuncDeclVisitor(c_ast.NodeVisitor):
    """
    AST visitor that collects stub text for C enumerators and function declarations.

    After visiting a parsed AST, `sigs` maps each enumerator or function name to
    the text of one stub member. Members are indented by two spaces (and their
    bodies by four) so they can be placed directly under a `class` line.
    """

    # A source line counts as a comment line when it starts with `//` or `/*`.
    _COMMENT_START = r'^\s*(//|/\*)'

    def __init__(self):
        # name -> stub member text
        self.sigs = {}
        # file path -> list of its lines, so each source file is read only once
        self.sources = {}

    def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:
        """
        Return the comment lines and declaration lines found at `coord`.

        `coord.file` is opened (and cached) and `coord.line` is treated as a
        1-based line number.

        Returns a `(comment_lines, decl_lines)` tuple:
          * comment_lines: the contiguous run of lines directly above the
            declaration that start with `//` or `/*`, stripped of surrounding
            whitespace (interior lines of a multi-line block comment that do
            not start with `/*` end the run).
          * decl_lines: the declaration line and the lines after it, right-
            stripped, up to and including the first line containing `;` or `{`.
        """
        if coord.file not in self.sources:
            with open(coord.file, 'rt') as f:
                self.sources[coord.file] = f.readlines()
        source_lines = self.sources[coord.file]

        decl_index = coord.line - 1  # 0-based index of the declaration line

        # Walk upwards from the line above the declaration while lines look like comments.
        first_comment = decl_index
        while first_comment > 0 and re.search(self._COMMENT_START, source_lines[first_comment - 1]):
            first_comment -= 1
        comment_lines = [l.strip() for l in source_lines[first_comment:decl_index]]

        decl_lines = []
        for line in source_lines[decl_index:]:
            decl_lines.append(line.rstrip())
            if ';' in line or '{' in line:
                break
        return (comment_lines, decl_lines)

    def visit_Enum(self, node: Enum):
        """Record every enumerator of `node` as a read-only `int` property."""
        if node.values is None:
            # Enum referenced without a body: nothing to emit.
            return
        for e in node.values.enumerators:
            self.sigs[e.name] = f'  @property\n  def {e.name}(self) -> int: ...'

    def visit_Typedef(self, node: Typedef):
        """Intentionally a no-op, so nothing nested inside a typedef is visited."""
        pass

    def visit_FuncDecl(self, node: FuncDecl):
        """
        Record a stub for the function declared by `node`.

        The stub signature uses format_type() for each parameter and for the
        return type; the docstring is built from the C comment and declaration
        text found in the source file. Functions whose name starts with `__`
        are skipped.
        """
        # Peel pointer layers off the return type to reach the node carrying the function name.
        ret_type = node.type
        is_ptr = False
        while isinstance(ret_type, PtrDecl):
            ret_type = ret_type.type
            is_ptr = True

        fun_name = ret_type.declname
        if fun_name.startswith('__'):
            return

        # A declaration with an empty parameter list, e.g. `int f();`, has args=None.
        params = node.args.params if node.args is not None else []

        # Reserve all explicitly named parameters up front so that generated
        # placeholder names can never clash with a name that appears later.
        used_names = {p.name for p in params if getattr(p, 'name', None)}

        def gen_name(stem):
            """Return `stem`, or else `stem2`, `stem3`, ... - the first one not yet used."""
            i = 1
            while True:
                candidate = stem if i == 1 else f'{stem}{i}'
                if candidate not in used_names:
                    return candidate
                i += 1

        args = []
        for a in params:
            if isinstance(a, EllipsisParam):
                # Emit exactly the name that gets reserved (not a second, different one).
                arg_name = gen_name('args')
                used_names.add(arg_name)
                args.append('*' + arg_name)
                continue

            arg_type = format_type(a.type)
            if arg_type == 'None':
                # A lone `void` parameter means "no parameters".
                continue

            arg_name = a.name or gen_name('arg')
            used_names.add(arg_name)
            args.append(f'{arg_name}: {arg_type}')

        # For pointer returns, format the pointer node itself rather than the pointee.
        ret = format_type(node.type if is_ptr else ret_type)

        comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)

        # The body (docstring and `...`) must be indented deeper than the `def`
        # line, otherwise the generated stub is not valid Python.
        lines = [f'  def {fun_name}({", ".join(args)}) -> {ret}:']
        if not comment_lines and len(decl_lines) == 1:
            lines += [f'    """{decl_lines[0]}"""']
        else:
            lines += ['    """']
            lines += [f'    {c.lstrip("/* ")}' for c in comment_lines]
            if comment_lines:
                lines += ['']
            lines += [f'    {d}' for d in decl_lines]
            lines += ['    """']
        lines += ['    ...']
        self.sigs[fun_name] = '\n'.join(lines)
def generate_stubs(header: str):
    """
    Build the text of a .pyi stub module from C header source.

    `header` is the C source text handed to pycparser's CParser (presumably
    already preprocessed - confirm against the caller). The returned string
    starts with a fixed preamble ending in `class lib:` and is followed by one
    stub member per collected enumerator/function, ordered by name.
    """
    visitor = PythonStubFuncDeclVisitor()
    visitor.visit(CParser().parse(header, "<input>"))

    preamble = [
        '# auto-generated file',
        'import ggml.ffi as ffi',
        'import numpy as np',
        'class lib:',
    ]
    members = [visitor.sigs[name] for name in sorted(visitor.sigs)]
    return '\n'.join(preamble + members)