"""
Fun little experiment.
"""
import gradio as gr
import torch
import concurrent.futures
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
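
# Optional sketch (an assumption, not part of the original app): on a GPU-backed
# Space the model could be moved to CUDA for faster generation; every input
# tensor created below would then need a matching .to(device) call.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = model.to(device)
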
def min_p_sampling(logits, pbase=0.1):
    """
    Perform min-p sampling on the logits.

    Args:
        logits (torch.Tensor): Logits for the next token, shape (1, vocab_size) or (vocab_size,).
        pbase (float): Base probability used to scale pmax into the dynamic threshold.

    Returns:
        int: The sampled token index.
    """
    # Convert logits to probabilities.
    probs = torch.softmax(logits, dim=-1)
    # 1. Find the maximum probability.
    pmax = probs.max()
    # 2. Compute the dynamic threshold.
    pscaled = pbase * pmax
    # 3. Keep only tokens with probability >= pscaled.
    mask = probs >= pscaled
    # In the unlikely event that no token meets the threshold, fall back to the full distribution.
    if mask.sum() == 0:
        mask = torch.ones_like(probs, dtype=torch.bool)
    # Zero out probabilities below the threshold.
    probs_filtered = probs * mask.float()
    # 4. Renormalize and sample.
    probs_normalized = probs_filtered / probs_filtered.sum()
    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
    return sampled_index.item()
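
# Illustrative sanity check (an addition, not used by the app): with the toy
# distribution [0.6, 0.3, 0.08, 0.02] and pbase=0.1, the min-p threshold is
# 0.1 * 0.6 = 0.06, so the 0.02 token is filtered out and can never be sampled.
def _min_p_demo():
    toy_logits = torch.log(torch.tensor([[0.6, 0.3, 0.08, 0.02]]))
    assert min_p_sampling(toy_logits, pbase=0.1) in (0, 1, 2)
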
def generate_completion(prompt, strategy, params):
    """
    Generate a complete answer with model.generate using the given strategy's
    parameters. The strategy name is accepted for readability but not used here.
    """
    # GPT-2 has no pad token, so reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token
    # Encode the prompt and get the attention mask.
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    # Generate the output.
    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=50, **params)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
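
# Illustrative usage (an addition, not part of the app flow): the same helper
# drives every built-in strategy; these parameters mirror the "Top-k Sampling"
# entry defined in generate_all() below.
#
#     generate_completion("Once upon a time", "Top-k Sampling",
#                         {"do_sample": True, "top_k": 50})
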
def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
    """
    Generate a complete answer with a token-by-token loop that uses min-p sampling.
    Note that max_length counts the prompt tokens as well as the generated ones.
    """
    # Encode the prompt.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Generate until max_length tokens (prompt included) or EOS.
    for _ in range(max_length - input_ids.size(1)):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]  # Logits for the last position.
        next_token = min_p_sampling(logits, pbase=pbase)
        # Append the new token.
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
        # Stop if the end-of-sequence token is generated.
        if next_token == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
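
# Illustrative usage (an addition, not part of the app flow): a larger pbase
# keeps only tokens whose probability is close to the top token's, so the
# continuation is more conservative than with the default pbase of 0.1.
#
#     generate_min_p_completion("The meaning of life is", pbase=0.3)
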
def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each completes.
    """
    # Define each decoding strategy and its parameters.
    # The default strategies go through model.generate; "Min-p Sampling" uses the custom loop above.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
    }
    # Define the order for display.
    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
    results = {method: None for method in methods}

    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)

    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            if info["type"] == "default":
                future = executor.submit(generate_completion, prompt, method, info["params"])
            elif info["type"] == "min_p":
                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
            future_to_method[future] = method

        # As each future completes, record its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                result = f"Error: {exc}"
            results[method] = result
            # Yield the results in the predefined order; pending methods show "Processing...".
            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
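
# Illustrative usage (an addition, not part of the app flow): consumed outside
# Gradio, the generator yields a fresh five-element snapshot, in the order
# Greedy, Top-k, Top-p, Beam Search, Min-p, each time a strategy finishes.
#
#     for snapshot in generate_all("The quick brown fox"):
#         print(snapshot)
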
# Create the Gradio interface.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
        gr.Textbox(label="Min-p Sampling"),
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's final answer is shown as soon as it is done. This demo uses GPT-2.",
)

if __name__ == "__main__":
    interface.launch()