"""
Fun little experiment.
"""

import concurrent.futures

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def min_p_sampling(logits, pbase=0.1):
    """
    Perform min-p sampling on the logits.

    Args:
        logits (torch.Tensor): Logits for the next token, shape [1, vocab_size].
        pbase (float): Base probability used to scale pmax into the threshold.

    Returns:
        int: The sampled token index.
    """
    # Convert logits to probabilities.
    probs = torch.softmax(logits, dim=-1)
    # 1. Find the maximum probability.
    pmax = probs.max()
    # 2. Compute the dynamic threshold.
    pscaled = pbase * pmax
    # 3. Build a mask of tokens with probability >= pscaled.
    mask = probs >= pscaled
    # In the unlikely event that no token meets the threshold, fall back to the full distribution.
    if mask.sum() == 0:
        mask = torch.ones_like(probs, dtype=torch.bool)
    # Zero out probabilities below the threshold.
    probs_filtered = probs * mask.float()
    # 4. Renormalize and sample.
    probs_normalized = probs_filtered / probs_filtered.sum()
    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
    return sampled_index.item()
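

# Illustrative sketch (not part of the app flow): a worked min-p example. With
# probabilities [0.5, 0.3, 0.15, 0.05] and pbase = 0.3, the threshold is
# pscaled = 0.3 * 0.5 = 0.15, so only the first three tokens survive before
# renormalization and sampling. The helper below is hypothetical and is never
# called by the interface; it exists only to demonstrate min_p_sampling.
def _min_p_sampling_example():
    example_probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    example_logits = torch.log(example_probs)  # softmax(log(p)) recovers p
    return min_p_sampling(example_logits, pbase=0.3)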


def generate_completion(prompt, strategy, params):
    """
    Generate a complete answer using model.generate with the given parameters.
    `strategy` is just the display name supplied by the caller; it is not used here.
    """
    # Make sure a pad token exists, then encode the prompt with its attention mask.
    tokenizer.pad_token = tokenizer.eos_token
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Generate the output.
    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=50, **params)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
    """
    Generate a complete answer with a token-by-token loop using min-p sampling.
    Note that max_length is the total sequence length (prompt included), so a
    long prompt leaves fewer tokens for generation.
    """
    # Encode the prompt.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Generate until the sequence reaches max_length tokens.
    for _ in range(max_length - input_ids.size(1)):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]  # Logits for the last position.
        next_token = min_p_sampling(logits, pbase=pbase)
        # Append the new token.
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
        # Stop if the end-of-sequence token is generated.
        if next_token == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
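

# Hedged alternative (assumption about the installed transformers version): newer
# transformers releases accept a `min_p` argument in model.generate, which would
# replace the manual loop above. This helper is hypothetical and is not wired
# into the interface.
def _generate_min_p_builtin(prompt, pbase=0.1, max_new_tokens=50):
    encoded = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        min_p=pbase,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)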


def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each completes.
    """
    # Define each decoding strategy and its parameters.
    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
    }
    # Define the order for display.
    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
    results = {method: None for method in methods}

    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)

    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            if info["type"] == "default":
                future = executor.submit(generate_completion, prompt, method, info["params"])
            elif info["type"] == "min_p":
                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
            future_to_method[future] = method

        # As each future completes, record its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                result = f"Error: {exc}"
            results[method] = result
            # Yield the results in the pre-defined order; pending methods show "Processing...".
            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
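

# Each yielded 5-tuple maps positionally onto the five output textboxes declared
# below, so methods that have not finished yet keep showing "Processing...".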

# Create the Gradio interface.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
        gr.Textbox(label="Min-p Sampling"),
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's final answer appears as soon as it finishes. All methods use GPT-2.",
)
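
# Note (assumption about the Gradio version in use): generator functions stream
# intermediate yields only when queuing is active. Older Gradio releases require
# enabling it explicitly, e.g. interface.queue(); newer releases queue by default.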

if __name__ == "__main__":
    interface.launch()