import logging
import os

import frontmatter
import gradio as gr
import huggingface_hub
import spaces
import torch
from gradio_client import Client
from gradio_client.exceptions import AppError
from transformers import AutoTokenizer, AutoModelForCausalLM

import prep_decompiled

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Model configuration constants
MAX_CONTEXT_LENGTH = 8192
MAX_NEW_TOKENS = 1024
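# MAX_CONTEXT_LENGTH is the model's context window; MAX_NEW_TOKENS is reserved
# for generation, so prompts are truncated to the difference (see infer below).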

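# Log in so the model repos can be downloaded even if they are gated or
# private. Assumes the HF_TOKEN secret is configured for this Space.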
hf_key = os.environ["HF_TOKEN"]
huggingface_hub.login(token=hf_key)

# Both decoder models are served with the StarCoderBase-3B tokenizer, which
# appears to be the base model they were fine-tuned from.
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoderbase-3b")

logger.info("Loading vardecoder model...")
vardecoder_model = AutoModelForCausalLM.from_pretrained(
    "ejschwartz/resym-vardecoder",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_safetensors=False,
)
logger.info("Successfully loaded vardecoder model")

logger.info("Loading fielddecoder model...")

fielddecoder_model = None
fielddecoder_model = AutoModelForCausalLM.from_pretrained(
   "ejschwartz/resym-fielddecoder", 
   torch_dtype=torch.bfloat16,
   use_safetensors=False
)
logger.info("Successfully loaded fielddecoder model")

def make_gradio_client():
    """Create a fresh client for the ReSym field helper space."""
    return Client("https://ejschwartz-resym-field-helper.hf.space/")

# examples.txt stores one example per line, with newlines encoded as literal
# "\n" escapes; decode them back into real newlines.
with open("examples.txt", "r") as f:
    examples = [ex.encode().decode("unicode_escape") for ex in f]


# Example prompt
#   "input": "```\n_BOOL8 __fastcall sub_409B9A(_QWORD *a1, _QWORD *a2)\n{\nreturn *a1 < *a2 || *a1 == *a2 && a1[1] < a2[1];\n}\n```\nWhat are the variable name and type for the following memory accesses:a1, a1[1], a2, a2[1]?\n",
#  "output": "a1: a, os_reltime* -> sec, os_time_t\na1[1]: a, os_reltime* -> usec, os_time_t\na2: b, os_reltime* -> sec, os_time_t\na2[1]: b, os_reltime* -> usec, os_time_t",
def field_prompt(code):
    try:
        field_helper_result = make_gradio_client().predict(
            decompiled_code=code,
            api_name="/predict",
        )
    except AppError as e:
        logger.error(f"AppError from the field helper space: {e}")
        return None, [], None

    logger.info(f"field helper result: {field_helper_result}")

    fields = sorted({e["expr"] for e in field_helper_result[0] if e["expr"] != ""})
    logger.info(f"fields: {fields}")

    # Note: no space after "accesses:", matching the example prompt above.
    prompt = f"```\n{code}\n```\nWhat are the variable name and type for the following memory accesses:{', '.join(fields)}?\n"
    if len(fields) > 0:
        prompt += f"{fields[0]}:"

    logger.info(f"field prompt: {repr(prompt)}")

    return prompt, fields, field_helper_result

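# On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each call.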
@spaces.GPU
def infer(code):
    splitcode = code.splitlines()
    #splitcode = [s.strip() for s in code.splitlines()]
    #code = "\n".join(splitcode)

    bodyvars = [
        v["name"] for v in prep_decompiled.extract_comments(splitcode) if "name" in v
    ]
    argvars = [
        v["name"] for v in prep_decompiled.parse_signature(splitcode) if "name" in v
    ]
    # Argument variables first, then body variables. Named `variables` so as
    # not to shadow the built-in vars().
    variables = argvars + bodyvars
    # comments = prep_decompiled.extract_comments(splitcode)
    # sig = prep_decompiled.parse_signature(splitcode)

    varstring = ", ".join(f"`{v}`" for v in variables)

    # Bail out early if no variables could be parsed from the decompilation;
    # indexing an empty list would raise an IndexError.
    if not variables:
        return "No variables found", "", "", ""
    first_var = variables[0]

    # ejs: Yeah, this var_name thing is really bizarre. But look at https://github.com/lt-asset/resym/blob/main/training_src/fielddecoder_inf.py
    var_prompt = f"What are the original name and data types of variables {varstring}?\n```\n{code}\n```{first_var}:"

    logger.info(f"Prompt:\n{repr(var_prompt)}")

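    # Truncate the prompt so that prompt + MAX_NEW_TOKENS generated tokens fit
    # within the model's context window.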
    var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").to(
        vardecoder_model.device
    )[:, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS]
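    # Deterministic beam search. The pad/eos token ids of 0 appear to follow
    # the upstream ReSym inference scripts (see the link above).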
    var_output = vardecoder_model.generate(
        input_ids=var_input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        num_beams=4,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
        pad_token_id=0,
        eos_token_id=0,
    )
    print(f"Pre Var output: {var_output}")
    var_output = var_output[0]
    var_output = tokenizer.decode(
        var_output[var_input_ids.size(1) :],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    print(f"Var output: {repr(var_output)}")

    field_prompt_result, fields, field_helper_result = field_prompt(code)
    if len(fields) == 0:
        field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
    elif fielddecoder_model is None:
        field_output = "TEMPORARILY DISABLED"
    else:
        # Only build the inputs once we know the model loaded; accessing
        # fielddecoder_model.device on None would raise an AttributeError.
        field_input_ids = tokenizer.encode(
            field_prompt_result, return_tensors="pt"
        ).to(fielddecoder_model.device)[:, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS]

        field_output = fielddecoder_model.generate(
            input_ids=field_input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=4,
            num_return_sequences=1,
            do_sample=False,
            early_stopping=False,
            pad_token_id=0,
            eos_token_id=0,
        )[0]
        field_output = tokenizer.decode(
            field_output[field_input_ids.size(1) :],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        field_output = fields[0] + ":" + field_output
    var_output = first_var + ":" + var_output
    fieldstring = ", ".join(fields)
    return var_output, field_output, varstring, fieldstring


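# Gradio UI: one decompilation input box, four text outputs.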
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(lines=10, value=examples[0], label="Hex-Rays Decompilation"),
    ],
    outputs=[
        gr.Text(label="Var Decoder Output"),
        gr.Text(label="Field Decoder Output"),
        gr.Text(label="Generated Variable List"),
        gr.Text(label="Generated Field Access List"),
    ],
    # description=frontmatter.load("README.md").content,
    description="""This is a test space of the models from the [ReSym
artifacts](https://github.com/lt-asset/resym).  For more information, please see
[the
README](https://huggingface.co/spaces/ejschwartz/resym/blob/main/README.md).  If
you get an error, please make sure the [ReSym field helper
space](https://huggingface.co/spaces/ejschwartz/resym-field-helper) is
running.

The field decoder model is currently **not working** due to a problem with the Hugging Face `accelerate` library. I am investigating the issue.
""",
    examples=examples,
)
demo.launch()