"""
Fun little experiment.
"""
import concurrent.futures

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def min_p_sampling(logits, pbase=0.1):
    """
    Perform min-p sampling on the logits.

    Args:
        logits (torch.Tensor): Logits for the next token, of shape
            [vocab_size] or [1, vocab_size].
        pbase (float): Base probability used to scale pmax into the dynamic threshold.

    Returns:
        int: The sampled token index.
    """
    # Convert logits to probabilities.
    probs = torch.softmax(logits, dim=-1)
    # 1. Find the maximum probability.
    pmax = probs.max()
    # 2. Compute the dynamic threshold.
    pscaled = pbase * pmax
    # 3. Create a mask of tokens with probability >= pscaled.
    mask = probs >= pscaled
    # In the unlikely event that no token meets the threshold, fall back to the full distribution.
    if mask.sum() == 0:
        mask = torch.ones_like(probs, dtype=torch.bool)
    # Zero out probabilities below the threshold.
    probs_filtered = probs * mask.float()
    # 4. Renormalize and sample.
    probs_normalized = probs_filtered / probs_filtered.sum()
    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
    return sampled_index.item()
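
# Illustrative numbers (not from any reference): with pbase = 0.1 and a peaked
# distribution where pmax = 0.60, the threshold is pscaled = 0.06, so only
# tokens with probability >= 6% survive; with a flat distribution (pmax = 0.10)
# the threshold drops to 1% and many more candidates are kept. For example:
#
#     probs = torch.tensor([0.60, 0.25, 0.10, 0.05])
#     min_p_sampling(torch.log(probs), pbase=0.1)  # samples from the first three tokens only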


def generate_completion(prompt, strategy, params):
    """
    Generate a complete answer using model.generate with the given parameters.
    The strategy name is only used as a label by the caller; it is not needed
    inside this function.
    """
    # GPT-2 has no pad token, so reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token
    # Encode the prompt and get the attention mask.
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Generate the output.
    output_ids = model.generate(
        input_ids, attention_mask=attention_mask, max_new_tokens=50, **params
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
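
# For example, a single greedy completion could be produced directly with
# (the prompt string here is just a placeholder):
#
#     generate_completion("Once upon a time", "Greedy", {"do_sample": False})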


def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
    """
    Generate a complete answer with a token-by-token loop and min-p sampling.
    Note that max_length caps the total sequence length, prompt included
    (unlike max_new_tokens in generate_completion).
    """
    # Encode the prompt.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Generate tokens until the sequence reaches max_length.
    for _ in range(max_length - input_ids.size(1)):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]  # Logits for the last position.
        next_token = min_p_sampling(logits, pbase=pbase)
        # Append the new token.
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
        # Stop if the end-of-sequence token is generated.
        if next_token == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
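
# Note: recent versions of Hugging Face transformers also accept a `min_p`
# generation parameter in model.generate, which could replace this manual loop
# (depending on the installed version); the explicit loop is kept here so the
# sampling step stays easy to inspect.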


def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each one completes.
    """
    # Define each decoding strategy and its parameters. The default strategies go
    # through model.generate; "Min-p Sampling" uses the custom sampling loop above.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
    }
    # Define the order for display.
    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
    results = {method: None for method in methods}

    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)

    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            if info["type"] == "default":
                future = executor.submit(generate_completion, prompt, method, info["params"])
            elif info["type"] == "min_p":
                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
            future_to_method[future] = method

        # As each future completes, record its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                result = f"Error: {exc}"
            results[method] = result
            # Yield results in the pre-defined order; pending methods show "Processing...".
            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
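
# Because generate_all is a generator, Gradio streams each yielded tuple to the
# output boxes as it is produced. Threads give useful concurrency here since
# PyTorch releases the GIL inside its heavy C++ kernels.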


# Create the Gradio interface.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
        gr.Textbox(label="Min-p Sampling"),
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's final answer is shown as soon as it is done. This uses GPT-2.",
)


if __name__ == "__main__":
    interface.launch()