"""
Fun little experiment: compare GPT-2 decoding strategies (greedy, top-k,
top-p, beam search, and min-p sampling) side by side in a Gradio app.
"""


import gradio as gr
import torch
import concurrent.futures
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def min_p_sampling(logits, pbase=0.1):
    """
    Perform min-p sampling on the logits.
    
    Args:
        logits (torch.Tensor): Logits for the next token, shaped [vocab_size] or
            [1, vocab_size] (the shape produced by the generation loop below).
        pbase (float): Base probability; the dynamic threshold is pbase * pmax.
    
    Returns:
        int: The sampled token index.
    """
    # Convert logits to probabilities.
    probs = torch.softmax(logits, dim=-1)
    
    # 1. Find maximum probability.
    pmax = probs.max()
    
    # 2. Compute the dynamic threshold.
    pscaled = pbase * pmax
    
    # 3. Create a mask of tokens with probability >= pscaled.
    mask = probs >= pscaled
    # Defensive fallback: with pbase <= 1 the argmax always passes the threshold,
    # but fall back to the full distribution just in case.
    if mask.sum() == 0:
        mask = torch.ones_like(probs, dtype=torch.bool)
    
    # Zero out probabilities not meeting the threshold.
    probs_filtered = probs * mask.float()
    
    # 4. Normalize and sample.
    probs_normalized = probs_filtered / probs_filtered.sum()
    sampled_index = torch.multinomial(probs_normalized, num_samples=1)
    
    return sampled_index.item()
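
# A minimal sanity check for min_p_sampling (illustrative helper, not used by the
# app; the logits are a toy example, not model output). softmax([2.0, 1.0, 0.0, -1.0])
# is roughly [0.64, 0.24, 0.09, 0.03], so pmax ~ 0.64 and pbase=0.1 yields a threshold
# of ~0.064: index 3 falls below it and should never be sampled.
def _min_p_sanity_check():
    toy_logits = torch.tensor([2.0, 1.0, 0.0, -1.0])
    sampled = min_p_sampling(toy_logits, pbase=0.1)
    assert sampled in (0, 1, 2), "a token below the min-p threshold was sampled"
    return sampled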

def generate_completion(prompt, strategy, params):
    """
    Generate a complete answer using model.generate with the given parameters.

    The strategy argument is only a display label from generate_all's dispatch;
    the decoding behavior is determined entirely by params.
    """
    # Encode the prompt and get the attention mask.
    tokenizer.pad_token = tokenizer.eos_token
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    
    # Generate the output.
    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=50,
                                pad_token_id=tokenizer.eos_token_id, **params)  # explicit pad id avoids a warning
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
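
# Example call (illustrative, assuming the GPT-2 checkpoint loaded above):
#   generate_completion("The capital of France is", "Greedy", {"do_sample": False})
# returns the prompt followed by up to 50 greedily decoded tokens.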

def generate_min_p_completion(prompt, pbase=0.1, max_length=50):
    """
    Generate a complete answer using a token-by-token loop with min-p sampling.

    Note: max_length caps the total sequence length (prompt included), unlike
    generate_completion, which always generates up to 50 new tokens.
    """
    # Encode the prompt.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    # Generate new tokens until the total sequence length reaches max_length.
    for _ in range(max_length - input_ids.size(1)):
        # Run the model without tracking gradients (inference only).
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]  # Get logits for the last token.
        next_token = min_p_sampling(logits, pbase=pbase)
        
        # Append the new token.
        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1)
        
        # Stop if the end-of-sequence token is generated.
        if next_token == tokenizer.eos_token_id:
            break
            
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
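
# Example call (illustrative; output is stochastic and uses the same GPT-2 model):
#   generate_min_p_completion("Once upon a time", pbase=0.1, max_length=50)
# returns the prompt plus sampled tokens, stopping early if the EOS token is drawn.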

def generate_all(prompt):
    """
    Run multiple decoding strategies concurrently and yield updates as each completes.
    """
    # Define each decoding strategy and its parameters.
    # For the default strategies, we use model.generate; for "Min-p Sampling" we use our custom function.
    methods = {
        "Greedy": {"type": "default", "params": {"do_sample": False}},
        "Top-k Sampling": {"type": "default", "params": {"do_sample": True, "top_k": 50}},
        "Top-p Sampling": {"type": "default", "params": {"do_sample": True, "top_p": 0.95}},
        "Beam Search": {"type": "default", "params": {"num_beams": 5, "early_stopping": True}},
        "Min-p Sampling": {"type": "min_p", "pbase": 0.1},
    }
    
    # Define the order for display.
    method_order = ["Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling"]
    results = {method: None for method in methods}
    
    # Yield an initial placeholder state.
    yield tuple("Processing..." for _ in method_order)
    
    # Use a thread pool to run each generation concurrently.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_method = {}
        for method, info in methods.items():
            if info["type"] == "default":
                future = executor.submit(generate_completion, prompt, method, info["params"])
            elif info["type"] == "min_p":
                future = executor.submit(generate_min_p_completion, prompt, info["pbase"])
            future_to_method[future] = method
        
        # As each future completes, update its result and yield the current state.
        for future in concurrent.futures.as_completed(future_to_method):
            method = future_to_method[future]
            try:
                result = future.result()
            except Exception as exc:
                result = f"Error: {exc}"
            results[method] = result
            # Yield the results in the pre-defined order; pending methods show "Processing..."
            yield tuple(results[m] if results[m] is not None else "Processing..." for m in method_order)
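
# Minimal console usage of the streaming generator (illustrative, outside Gradio):
#   for state in generate_all("Once upon a time"):
#       print(state[-1])  # the "Min-p Sampling" slot, per method_order above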

# Create the Gradio interface.
interface = gr.Interface(
    fn=generate_all,
    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"),
    outputs=[
        gr.Textbox(label="Greedy"),
        gr.Textbox(label="Top-k Sampling"),
        gr.Textbox(label="Top-p Sampling"),
        gr.Textbox(label="Beam Search"),
        gr.Textbox(label="Min-p Sampling"),
    ],
    title="Decoding Methods Comparison",
    description="Each decoding method's final answer appears as soon as it finishes. Powered by GPT-2."
)
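
# Launch note (assumptions: the script is saved as app.py and gradio, torch, and
# transformers are installed): run `python app.py`, then open the local URL Gradio prints.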

if __name__ == "__main__":
    interface.launch()