File size: 4,516 Bytes
13d3ba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
  This generates .pyi stubs for the cffi Python bindings generated by regenerate.py
"""
import sys, re, itertools
sys.path.extend(['.', '..']) # for pycparser

from pycparser import c_ast, parse_file, CParser
import pycparser.plyparser
from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef
from typing import Tuple

# Lookup table: C type name (as spelled in the header) -> annotation text
# written into the generated stub. format_type() falls back to 'ffi.CData'
# for any name that is not listed here.
__c_type_to_python_type = {
    # void / boolean
    'void': 'None',
    '_Bool': 'bool',
    # plain C integer types
    'char': 'int',
    'short': 'int',
    'int': 'int',
    'long': 'int',
    # size / pointer-difference types
    'ptrdiff_t': 'int',
    'size_t': 'int',
    # fixed-width integer types
    'int8_t': 'int',
    'uint8_t': 'int',
    'int16_t': 'int',
    'uint16_t': 'int',
    'int32_t': 'int',
    'uint32_t': 'int',
    'int64_t': 'int',
    'uint64_t': 'int',
    # floating point
    'float': 'float',
    'double': 'float',
    # half-precision typedef, annotated with the numpy scalar type
    'ggml_fp16_t': 'np.float16',
}

def format_type(t: TypeDecl):
    """Return the annotation text used in the generated stub for a C type node.

    Mapping rules, checked in order:

    * pointer and struct nodes          -> ``'ffi.CData'``
    * enum nodes                        -> ``'int'``
    * ``TypeDecl``                      -> whatever its wrapped type maps to
    * ``IdentifierType`` with one name  -> table lookup, ``'ffi.CData'`` if
      the name is unknown
    * ``IdentifierType`` with several names (e.g. ``unsigned int``,
      ``long long``) -> ``'int'`` when every word is an integer specifier,
      otherwise ``'ffi.CData'``
    * anything else                     -> the node's ``name`` if it has a
      non-empty one, otherwise ``'ffi.CData'``

    Args:
        t: a pycparser AST node describing a type.

    Returns:
        The annotation as a string.
    """
    if isinstance(t, (PtrDecl, Struct)):
        return 'ffi.CData'
    if isinstance(t, Enum):
        return 'int'
    if isinstance(t, TypeDecl):
        # TypeDecl only carries the declared name; the real type is nested.
        return format_type(t.type)
    if isinstance(t, IdentifierType):
        names = t.names
        if len(names) == 1:
            return __c_type_to_python_type.get(names[0]) or 'ffi.CData'
        # Multi-word builtin types used to trip an assertion and abort the
        # whole generation run. Any combination made purely of integer
        # specifiers is an integer; everything else (e.g. 'long double')
        # is treated as an opaque cdata object.
        integer_words = ('signed', 'unsigned', 'char', 'short', 'int', 'long')
        if names and all(word in integer_words for word in names):
            return 'int'
        return 'ffi.CData'
    # Nodes without a usable name (no attribute, or anonymous) previously
    # raised AttributeError or leaked None into the stub text.
    return getattr(t, 'name', None) or 'ffi.CData'

class PythonStubFuncDeclVisitor(c_ast.NodeVisitor):
    """AST visitor that collects stub text for C functions and enumerators.

    After visiting a parsed header, ``sigs`` holds one ready-to-emit stub
    snippet per function / enumerator name.
    """

    def __init__(self):
        # Stub source text keyed by function or enumerator name.
        self.sigs = {}
        # Cache of source files already read: path -> list of lines.
        self.sources = {}

    def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:
        """Fetch the comment and declaration text surrounding a coordinate.

        Args:
            coord: location of a declaration; ``coord.file`` is opened and
                read (once, then cached) and ``coord.line`` is 1-based.

        Returns:
            ``(comment_lines, decl_lines)`` where ``comment_lines`` are the
            stripped lines directly above the declaration that start with
            ``//`` or ``/*``, and ``decl_lines`` are the right-stripped lines
            from the declaration up to and including the first line that
            contains ``;`` or ``{``.
        """
        if coord.file not in self.sources:
            with open(coord.file, 'rt') as f:
                self.sources[coord.file] = f.readlines()
        source_lines = self.sources[coord.file]

        decl_index = coord.line - 1  # 0-based index of the declaration line
        # Walk upwards while the preceding line opens a comment.
        comment_start = decl_index
        while comment_start > 0 and re.search(r'^\s*(//|/\*)', source_lines[comment_start - 1]):
            comment_start -= 1
        comment_lines = [l.strip() for l in source_lines[comment_start:decl_index]]

        decl_lines = []
        for line in source_lines[decl_index:]:
            decl_lines.append(line.rstrip())
            if (';' in line) or ('{' in line):
                break
        return (comment_lines, decl_lines)

    def visit_Enum(self, node: Enum):
        """Record every enumerator of ``node`` as a read-only ``int`` property."""
        # Enums that are only referenced (no body) have no value list.
        if node.values is not None:
            for e in node.values.enumerators:
                self.sigs[e.name] = f'  @property\n  def {e.name}(self) -> int: ...'

    def visit_Typedef(self, node: Typedef):
        """Ignore typedefs; their subtrees are not traversed."""
        pass

    def visit_FuncDecl(self, node: FuncDecl):
        """Record a stub for the function declared by ``node``.

        The stub signature is derived from the parameter list and return
        type, and its docstring is built from the C comment and declaration
        text found in the source file. Declarations without a name and names
        starting with ``__`` are skipped.
        """
        ret_type = node.type
        is_ptr = False
        # The function name lives on the innermost TypeDecl, below any
        # pointer levels of the return type.
        while isinstance(ret_type, PtrDecl):
            ret_type = ret_type.type
            is_ptr = True

        fun_name = ret_type.declname
        # Abstract declarators (e.g. function-pointer casts) have no name.
        if fun_name is None or fun_name.startswith('__'):
            return

        args = []
        argnames = []

        def gen_name(stem):
            """Return ``stem``, or ``stem<N>`` for the first unused N >= 2."""
            i = 1
            while True:
                new_name = stem if i == 1 else f'{stem}{i}'
                if new_name not in argnames:
                    return new_name
                i += 1

        # A declaration written as `f()` has no parameter list at all.
        params = node.args.params if node.args is not None else []
        for a in params:
            if isinstance(a, EllipsisParam):
                arg_name = gen_name('args')
                argnames.append(arg_name)
                # Reuse the reserved name; asking gen_name() again after the
                # append would yield 'args2' instead of 'args'.
                args.append('*' + arg_name)
            elif format_type(a.type) == 'None':
                # `f(void)`: the lone void parameter is not a real argument.
                continue
            else:
                arg_name = a.name or gen_name('arg')
                argnames.append(arg_name)
                args.append(f'{arg_name}: {format_type(a.type)}')

        # For pointer returns, format the pointer node itself.
        ret = format_type(ret_type if not is_ptr else node.type)

        comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)

        lines = [f'  def {fun_name}({", ".join(args)}) -> {ret}:']
        if not comment_lines and len(decl_lines) == 1:
            lines += [f'    """{decl_lines[0]}"""']
        else:
            lines += ['    """']
            # Drop the leading comment markers from each comment line.
            lines += [f'    {c.lstrip("/* ")}' for c in comment_lines]
            if comment_lines:
                lines += ['']
            lines += [f'    {d}' for d in decl_lines]
            lines += ['    """']
        lines += ['    ...']
        self.sigs[fun_name] = '\n'.join(lines)

def generate_stubs(header: str):
    """
      Build the text of a .pyi stub file for the GGML API.

      `header` is the C header source to parse; the returned string holds a
      fixed preamble followed by one stub per collected name, in sorted
      name order.
    """
    visitor = PythonStubFuncDeclVisitor()
    ast = CParser().parse(header, "<input>")
    visitor.visit(ast)

    preamble = [
        '# auto-generated file',
        'import ggml.ffi as ffi',
        'import numpy as np',
        'class lib:',
    ]
    stubs = [visitor.sigs[name] for name in sorted(visitor.sigs)]

    return '\n'.join(preamble + stubs)