Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,43 +1,66 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
|
4 |
-
# Load your model (using GPT-2 as an example)
|
5 |
model_name = "gpt2"
|
6 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
7 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
8 |
|
9 |
-
def
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
}
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
output_ids = model.generate(input_ids, max_length=50, **params)
|
24 |
-
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
25 |
-
results[strategy] = output_text
|
26 |
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
#
|
30 |
interface = gr.Interface(
|
31 |
-
fn=
|
32 |
inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
|
33 |
outputs=[
|
34 |
gr.Textbox(label="Greedy"),
|
35 |
-
gr.Textbox(label="Beam Search"),
|
36 |
gr.Textbox(label="Top-k Sampling"),
|
37 |
gr.Textbox(label="Top-p Sampling"),
|
|
|
38 |
],
|
39 |
-
title="
|
40 |
-
description="
|
41 |
)
|
42 |
|
43 |
if __name__ == "__main__":
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import concurrent.futures
|
4 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
|
6 |
+
# Load your model and tokenizer (using GPT-2 as an example).
# Done once at import time so every Gradio request reuses the same
# in-memory instances instead of reloading the weights per call.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
10 |
|
11 |
+
def generate_completion(prompt, strategy, params):
    """Generate a complete answer using the specified decoding strategy.

    Args:
        prompt: Text prompt to complete.
        strategy: Display name of the decoding strategy. Not used by the
            generation itself; the caller passes it so each submitted job
            is self-describing.
        params: Keyword arguments forwarded to ``model.generate`` that
            select the decoding behaviour (e.g. ``do_sample``, ``top_k``,
            ``top_p``, ``num_beams``).

    Returns:
        The decoded completion string with special tokens stripped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Inference only: disabling gradient tracking avoids building the
    # autograd graph, reducing memory use during generation.
    with torch.no_grad():
        # max_length=50 caps prompt + completion tokens; adjust as needed.
        output_ids = model.generate(input_ids, max_length=50, **params)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
17 |
+
|
18 |
+
def generate_all(prompt):
    """Stream completions from several decoding strategies.

    Yields a 4-tuple (one entry per output textbox) each time a strategy
    finishes, so the UI can show answers as soon as they are ready.
    Strategies still running are shown as a placeholder.
    """
    # Decoding configurations, keyed by each strategy's display name.
    strategies = {
        "Greedy": {"params": {"do_sample": False}},
        "Top-k Sampling": {"params": {"do_sample": True, "top_k": 50}},
        "Top-p Sampling": {"params": {"do_sample": True, "top_p": 0.95}},
        "Beam Search": {"params": {"num_beams": 5, "early_stopping": True}},
    }
    # Fixed order in which results map onto the four output textboxes.
    display_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search"]
    # Finished answers; None means the strategy has not completed yet.
    completed = {name: None for name in strategies}

    def snapshot():
        # Current state of every textbox, placeholder text while pending.
        return tuple(
            "Processing..." if completed[name] is None else completed[name]
            for name in display_order
        )

    # Initial yield paints the placeholders before any work finishes.
    yield snapshot()

    # Fan the generations out onto worker threads and surface each result
    # the moment its future resolves.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = {}
        for name in strategies:
            fut = pool.submit(
                generate_completion, prompt, name, strategies[name]["params"]
            )
            pending[fut] = name
        for fut in concurrent.futures.as_completed(pending):
            name = pending[fut]
            try:
                completed[name] = fut.result()
            except Exception as exc:
                # Report the failure in that strategy's textbox rather
                # than killing the whole stream.
                completed[name] = f"Error: {exc}"
            yield snapshot()
|
51 |
|
52 |
+
# Create a Gradio interface that uses the generator function: because
# generate_all yields tuples, Gradio re-renders the outputs on every yield.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    # NOTE: textbox order must match the tuple order yielded by
    # generate_all (its method_order list).
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
    ],
    title="Decoding Methods Results",
    description="Each decoding method's complete answer is printed as soon as it's done."
)
|
65 |
|
66 |
if __name__ == "__main__":
|