Update app.py
Browse files
app.py
CHANGED
|
@@ -65,6 +65,11 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
|
|
| 65 |
return [progress_dummy_output, hidden_states, *token_btns]
|
| 66 |
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 69 |
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
| 70 |
num_beams=1):
|
|
@@ -89,7 +94,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
| 89 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
| 90 |
|
| 91 |
# generate the interpretations
|
| 92 |
-
generate =
|
| 93 |
generated = generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
| 94 |
generation_texts = tokenizer.batch_decode(generated)
|
| 95 |
progress_dummy_output = ''
|
|
|
|
| 65 |
return [progress_dummy_output, hidden_states, *token_btns]
|
| 66 |
|
| 67 |
|
| 68 |
+
@spaces.GPU
|
| 69 |
+
def generate_interpretation_gpu(interpret_prompt, **kwargs):
|
| 70 |
+
return interpret_prompt.generate(**kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
| 74 |
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
| 75 |
num_beams=1):
|
|
|
|
| 94 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
| 95 |
|
| 96 |
# generate the interpretations
|
| 97 |
+
generate = generate_interpretation_gpu if use_gpu else lambda lambda interpretation_prompt, **kwargs: interpretation_prompt.generate(**kwargs)
|
| 98 |
generated = generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
| 99 |
generation_texts = tokenizer.batch_decode(generated)
|
| 100 |
progress_dummy_output = ''
|