kgourgou committed
Commit e1724e4 · verified · 1 Parent(s): bf63770

Update app.py

Files changed (1)
  1. app.py +44 -21
app.py CHANGED
@@ -1,43 +1,66 @@
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM

- # Load your model (using GPT-2 as an example)
  model_name = "gpt2"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(model_name)

- def generate_completions(prompt):
-     # Define decoding strategies with corresponding parameters
-     strategies = {
-         "Greedy": {"do_sample": False},
-         "Beam Search": {"num_beams": 5, "early_stopping": True},
-         "Top-k Sampling": {"do_sample": True, "top_k": 50},
-         "Top-p Sampling": {"do_sample": True, "top_p": 0.95}
      }

-     results = {}
-     input_ids = tokenizer.encode(prompt, return_tensors="pt")

-     for strategy, params in strategies.items():
-         # Generate output using the specific strategy
-         output_ids = model.generate(input_ids, max_length=50, **params)
-         output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-         results[strategy] = output_text

-     return results["Greedy"], results["Beam Search"], results["Top-k Sampling"], results["Top-p Sampling"]

- # Define the Gradio interface using the updated API
  interface = gr.Interface(
-     fn=generate_completions,
      inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
      outputs=[
          gr.Textbox(label="Greedy"),
-         gr.Textbox(label="Beam Search"),
          gr.Textbox(label="Top-k Sampling"),
          gr.Textbox(label="Top-p Sampling"),
      ],
-     title="LLM Decoding Strategies Comparison",
-     description="Enter a prompt to see how different decoding strategies affect the output of a language model."
  )

  if __name__ == "__main__":
 
  import gradio as gr
+ import torch
+ import concurrent.futures
  from transformers import AutoTokenizer, AutoModelForCausalLM

+ # Load your model and tokenizer (using GPT-2 as an example)
  model_name = "gpt2"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(model_name)

+ def generate_completion(prompt, strategy, params):
+     """Generate a complete answer using the specified decoding strategy."""
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+     # Adjust generation parameters as needed.
+     output_ids = model.generate(input_ids, max_length=50, **params)
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+ def generate_all(prompt):
+     # Define decoding strategies and their corresponding parameters.
+     methods = {
+         "Greedy": {"params": {"do_sample": False}},
+         "Top-k Sampling": {"params": {"do_sample": True, "top_k": 50}},
+         "Top-p Sampling": {"params": {"do_sample": True, "top_p": 0.95}},
+         "Beam Search": {"params": {"num_beams": 5, "early_stopping": True}},
      }
+     # This list defines the order in which results are displayed.
+     method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search"]

+     # Dictionary to store the final answer for each method (initially None).
+     results = {method: None for method in methods}

+     # Yield an initial state so the UI shows placeholders.
+     yield tuple("Processing..." for _ in method_order)

+     # Use ThreadPoolExecutor to run each generation concurrently.
+     with concurrent.futures.ThreadPoolExecutor() as executor:
+         future_to_method = {
+             executor.submit(generate_completion, prompt, method, methods[method]["params"]): method
+             for method in methods
+         }
+         # As soon as a method finishes, update its result and yield the current state.
+         for future in concurrent.futures.as_completed(future_to_method):
+             method = future_to_method[future]
+             try:
+                 result = future.result()
+             except Exception as exc:
+                 result = f"Error: {exc}"
+             results[method] = result
+             # Yield the results in the specified order; methods still processing show "Processing...".
+             yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)

+ # Create a Gradio interface that uses the generator function.
  interface = gr.Interface(
+     fn=generate_all,
      inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
      outputs=[
          gr.Textbox(label="Greedy"),
          gr.Textbox(label="Top-k Sampling"),
          gr.Textbox(label="Top-p Sampling"),
+         gr.Textbox(label="Beam Search"),
      ],
+     title="Decoding Methods Results",
+     description="Each decoding method's complete answer is printed as soon as it's done."
  )

  if __name__ == "__main__":
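
The core of this change is the streaming pattern: Gradio treats a generator fn as a streaming function and re-renders all outputs on every yield, while ThreadPoolExecutor plus as_completed delivers each decoding result as soon as it finishes. Below is a minimal, self-contained sketch of that pattern using only the standard library; slow_task and its timings are hypothetical stand-ins for model.generate and are not part of the committed app.

import time
import concurrent.futures

ORDER = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search"]

def slow_task(method):
    # Hypothetical stand-in for a model.generate call of varying duration.
    time.sleep({"Greedy": 0.1, "Top-k Sampling": 0.3,
                "Top-p Sampling": 0.2, "Beam Search": 0.4}[method])
    return f"{method} done"

def stream_results():
    results = {m: None for m in ORDER}
    # Initial state: every output shows a placeholder.
    yield tuple("Processing..." for _ in ORDER)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(slow_task, m): m for m in ORDER}
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
            # Each yield is one UI update: finished methods show their
            # result, unfinished ones still show the placeholder.
            yield tuple(results[m] if results[m] is not None else "Processing..."
                        for m in ORDER)

for state in stream_results():
    print(state)

Running the generations in threads works here because each call is independent; with a single model on limited hardware you might prefer sequential generation, but the yield-per-completion structure stays the same.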