Luigi's picture
reduce repetition penalty to 1.2
079e166
import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime
import re # for parsing <think> blocks
import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from transformers import AutoTokenizer
from duckduckgo_search import DDGS
import spaces # Import spaces early to enable ZeroGPU support
# Optional: Disable GPU visibility if you wish to force CPU usage
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# ------------------------------
# Global Cancellation Event
# ------------------------------
# Set by cancel_generation(); chat_response() clears it at the start of every
# request and checks it while streaming tokens.
cancel_event = threading.Event()
# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
# Display name -> {"repo_id": Hugging Face Hub repo, "description": free text}.
# Insertion order is the order of the model dropdown; the first entry is the
# default selection. Every entry must use exactly these two keys.
MODELS = {
    "Taiwan-ELM-1_1B-Instruct": {"repo_id": "liswei/Taiwan-ELM-1_1B-Instruct", "description": "Taiwan-ELM-1_1B-Instruct"},
    "Taiwan-ELM-270M-Instruct": {"repo_id": "liswei/Taiwan-ELM-270M-Instruct", "description": "Taiwan-ELM-270M-Instruct"},
    # "Granite-4.0-Tiny-Preview": {"repo_id": "ibm-granite/granite-4.0-tiny-preview", "description": "Granite-4.0-Tiny-Preview"},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Dense causal language model with 0.6 B total parameters (0.44 B non-embedding), 28 transformer layers, 16 query heads & 8 KV heads, native 32 768-token context window, dual-mode generation, full multilingual & agentic capabilities."},
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Dense causal language model with 1.7 B total parameters (1.4 B non-embedding), 28 layers, 16 query heads & 8 KV heads, 32 768-token context, stronger reasoning vs. 0.6 B variant, dual-mode inference, instruction following across 100+ languages."},
    "Qwen3-4B": {"repo_id": "Qwen/Qwen3-4B", "description": "Dense causal language model with 4.0 B total parameters (3.6 B non-embedding), 36 layers, 32 query heads & 8 KV heads, native 32 768-token context (extendable to 131 072 via YaRN), balanced mid-range capacity & long-context reasoning."},
    "Qwen3-8B": {"repo_id": "Qwen/Qwen3-8B", "description": "Dense causal language model with 8.2 B total parameters (6.95 B non-embedding), 36 layers, 32 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), excels at multilingual instruction following & zero-shot tasks."},
    "Qwen3-14B": {"repo_id": "Qwen/Qwen3-14B", "description": "Dense causal language model with 14.8 B total parameters (13.2 B non-embedding), 40 layers, 40 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), enhanced human preference alignment & advanced agent integration."},
    # "Qwen3-32B": {"repo_id":"Qwen/Qwen3-32B","description":"Dense causal language model with 32.8 B total parameters (31.2 B non-embedding), 64 layers, 64 query heads & 8 KV heads, 32 768-token context (131 072 via YaRN), flagship variant delivering state-of-the-art reasoning & instruction following."},
    # "Qwen3-30B-A3B": {"repo_id":"Qwen/Qwen3-30B-A3B","description":"Mixture-of-Experts model with 30.5 B total parameters (29.9 B non-embedding, 3.3 B activated per token), 48 layers, 128 experts (8 activated per token), 32 query heads & 4 KV heads, 32 768-token context (131 072 via YaRN), MoE routing for scalable specialized reasoning."},
    # "Qwen3-235B-A22B":{"repo_id":"Qwen/Qwen3-235B-A22B","description":"Mixture-of-Experts model with 235 B total parameters (234 B non-embedding, 22 B activated per token), 94 layers, 128 experts (8 activated per token), 64 query heads & 4 KV heads, 32 768-token context (131 072 via YaRN), ultra-scale reasoning & agentic workflows."},
    "Gemma-3-4B-IT": {"repo_id": "unsloth/gemma-3-4b-it", "description": "Gemma-3-4B-IT"},
    # Key was previously misspelled "desscription".
    "SmolLM2_135M_Grpo_Gsm8k": {"repo_id": "prithivMLmods/SmolLM2_135M_Grpo_Gsm8k", "description": "SmolLM2_135M_Grpo_Gsm8k"},
    "SmolLM2-135M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-135M-Instruct-TaiwanChat", "description": "SmolLM2‑135M Instruct fine-tuned on TaiwanChat"},
    "SmolLM2-135M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-135M-Instruct", "description": "Original SmolLM2‑135M Instruct"},
    "SmolLM2-360M-Instruct-TaiwanChat": {"repo_id": "Luigi/SmolLM2-360M-Instruct-TaiwanChat", "description": "SmolLM2‑360M Instruct fine-tuned on TaiwanChat"},
    "SmolLM2-360M-Instruct": {"repo_id": "HuggingFaceTB/SmolLM2-360M-Instruct", "description": "Original SmolLM2‑360M Instruct"},
    "Llama-3.2-Taiwan-3B-Instruct": {"repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct", "description": "Llama-3.2-Taiwan-3B-Instruct"},
    "MiniCPM3-4B": {"repo_id": "openbmb/MiniCPM3-4B", "description": "MiniCPM3-4B"},
    "Qwen2.5-3B-Instruct": {"repo_id": "Qwen/Qwen2.5-3B-Instruct", "description": "Qwen2.5-3B-Instruct"},
    "Qwen2.5-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-7B-Instruct", "description": "Qwen2.5-7B-Instruct"},
    "Phi-4-mini-Reasoning": {"repo_id": "microsoft/Phi-4-mini-reasoning", "description": "Phi-4-mini-Reasoning"},
    # "Phi-4-Reasoning": {"repo_id": "microsoft/Phi-4-reasoning", "description": "Phi-4-Reasoning"},
    "Phi-4-mini-Instruct": {"repo_id": "microsoft/Phi-4-mini-instruct", "description": "Phi-4-mini-Instruct"},
    "Meta-Llama-3.1-8B-Instruct": {"repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct", "description": "Meta-Llama-3.1-8B-Instruct"},
    "DeepSeek-R1-Distill-Llama-8B": {"repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B", "description": "DeepSeek-R1-Distill-Llama-8B"},
    "Mistral-7B-Instruct-v0.3": {"repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3", "description": "Mistral-7B-Instruct-v0.3"},
    "Qwen2.5-Coder-7B-Instruct": {"repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct", "description": "Qwen2.5-Coder-7B-Instruct"},
    "Qwen2.5-Omni-3B": {"repo_id": "Qwen/Qwen2.5-Omni-3B", "description": "Qwen2.5-Omni-3B"},
    "MiMo-7B-RL": {"repo_id": "XiaomiMiMo/MiMo-7B-RL", "description": "MiMo-7B-RL"},
}
# Global cache for pipelines to avoid re-loading.
# Maps a MODELS display name -> its loaded pipeline; filled by load_pipeline().
PIPELINES = {}
def load_pipeline(model_name):
    """
    Load and cache a transformers text-generation pipeline for ``model_name``.

    Tries bfloat16, then float16, then float32. If all three attempts raise,
    a final attempt is made without ``torch_dtype`` and its exception, if
    any, propagates. A successfully built pipeline is stored in ``PIPELINES``
    so later calls for the same model return the cached object.

    Args:
        model_name: A key of ``MODELS``.

    Returns:
        The text-generation pipeline.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODELS``.
        Exception: whatever the final, dtype-less ``pipeline(...)`` call raises.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    # The model is loaded with trust_remote_code=True; load the tokenizer the
    # same way so repos whose tokenizer ships custom code load consistently.
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    pipe_kwargs = {
        "task": "text-generation",
        "model": repo,
        "tokenizer": tokenizer,
        "trust_remote_code": True,
        "device_map": "auto",
    }
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(torch_dtype=dtype, **pipe_kwargs)
        except Exception:
            # Release whatever the failed attempt may have allocated before
            # trying the next dtype.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue
        PIPELINES[model_name] = pipe
        return pipe
    # Final fallback: let transformers choose the dtype; errors propagate.
    pipe = pipeline(**pipe_kwargs)
    PIPELINES[model_name] = pipe
    return pipe
def retrieve_context(query, max_results=6, max_chars=600):
    """
    Fetch up to ``max_results`` DuckDuckGo text-search hits for ``query``.

    Each hit becomes a string ``"<rank>. <title> - <body>"`` (rank starting at
    1, body truncated to ``max_chars`` characters). Any exception raised while
    searching or formatting is swallowed and an empty list is returned, so
    callers can treat the search as best-effort.
    """
    try:
        with DDGS() as client:
            hits = client.text(query, region="wt-wt", safesearch="off", timelimit="y")
            snippets = []
            for rank, hit in enumerate(islice(hits, max_results), start=1):
                title = hit.get('title', 'No Title')
                body = hit.get('body', '')[:max_chars]
                snippets.append(f"{rank}. {title} - {body}")
            return snippets
    except Exception:
        return []
def format_conversation(history, system_prompt, tokenizer):
    """
    Build the text prompt for ``tokenizer`` from a system prompt and chat history.

    ``history`` is a list of message dicts with ``role`` and ``content`` keys.
    Messages whose ``metadata`` carries a ``title`` (the "💭 Thought" entries
    that chat_response adds for the model's reasoning) are display-only and
    are left out of the prompt. The remaining messages are reduced to their
    ``role`` and ``content`` keys.

    If the tokenizer has a chat template, it is applied with a leading system
    message, ``add_generation_prompt=True`` and ``enable_thinking=True``.
    Otherwise a plain transcript is produced: the stripped system prompt, then
    one ``User: ...`` / ``Assistant: ...`` line per message, ending with an
    open ``Assistant: `` turn.

    Returns:
        The prompt as a string.
    """
    messages = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in history
        # Gradio may store metadata as None for ordinary messages.
        if not (msg.get("metadata") or {}).get("title")
    ]
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            [{"role": "system", "content": system_prompt.strip()}] + messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
    # Fallback for base LMs without chat template
    labels = {"user": "User: ", "assistant": "Assistant: "}
    lines = [system_prompt.strip()]
    for msg in messages:
        label = labels.get(msg["role"])
        if label is not None:
            lines.append(label + msg["content"].strip())
    prompt = "\n".join(lines) + "\n"
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt
def _enrich_system_prompt(system_prompt, results):
    """Return ``system_prompt`` with ``results`` appended under a "Relevant context" heading.

    With no results the prompt is returned unchanged (not stripped).
    """
    if not results:
        return system_prompt
    return system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(results)


def _search_debug_markdown(results):
    """Render the outcome of the web search as Markdown for the debug panel."""
    if not results:
        return "*No web search results found.*"
    return "### Search results merged into prompt\n\n" + "\n".join(f"- {r}" for r in results)


def _cancel_stopping_criteria():
    """Build stopping criteria that end generation once ``cancel_event`` is set.

    Without this, cancelling only stops *displaying* tokens: generation keeps
    running until ``max_new_tokens`` and chat_response blocks on
    ``gen_thread.join()`` until it finishes.

    NOTE(review): if @spaces.GPU runs chat_response in a separate process,
    an event set by the UI process may not be visible here — confirm.
    """
    # Local import: only this helper needs these names.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _CancelRequested(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One "done" flag per sequence in the batch.
            return torch.full(
                (input_ids.shape[0],),
                cancel_event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_CancelRequested()])


@spaces.GPU(duration=60)
def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Stream a chat reply, optionally grounded with a background web search.

    Yields ``(history, debug_markdown)`` tuples. ``history`` is a copy of
    ``chat_history`` plus the new user message and the assistant output.
    Text the model emits between ``<think>`` and ``</think>`` goes into a
    separate assistant message titled "💭 Thought"; the rest goes into one
    answer message. The last yield appends a preview of the full prompt to
    the debug text. Any error, including one raised inside the generation
    thread, is reported as a final "Error: ..." assistant message.
    """
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    # Launch web search in the background so the timeout below can bound it.
    search_results = []
    thread_search = None
    if enable_search:
        debug = 'Search task started.'
        thread_search = threading.Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            ),
            daemon=True,
        )
        thread_search.start()
    else:
        debug = 'Web search disabled.'

    try:
        results = []
        if thread_search is not None:
            # Wait at most `search_timeout` seconds for snippets.
            thread_search.join(timeout=float(search_timeout))
            # A timed-out search thread keeps running and may append later;
            # snapshot once so the debug panel and the prompt agree.
            results = list(search_results)
            debug = _search_debug_markdown(results)
        enriched = _enrich_system_prompt(system_prompt, results)

        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"

        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_kwargs = {
            # UI sliders may deliver floats; token counts must be integers.
            'max_new_tokens': int(max_tokens),
            'temperature': temperature,
            'top_k': int(top_k),
            'top_p': top_p,
            'repetition_penalty': repeat_penalty,
            'streamer': streamer,
            'return_full_text': False,
            'stopping_criteria': _cancel_stopping_criteria(),
        }
        gen_errors = []

        def _generate():
            try:
                pipe(prompt, **gen_kwargs)
            except Exception as exc:
                # Record the failure and close the stream; otherwise the
                # consumer loop below would wait on the streamer forever.
                gen_errors.append(exc)
                streamer.end()

        gen_thread = threading.Thread(target=_generate)
        gen_thread.start()

        thought_buf = ''
        answer_buf = ''
        in_thought = False
        answer_open = False  # True while history[-1] is the answer message being streamed
        for chunk in streamer:
            if cancel_event.is_set():
                break
            if not chunk:
                # The streamer can yield empty strings (e.g. while it buffers
                # a partial word); they carry no text.
                continue
            if not in_thought and '<think>' in chunk:
                in_thought = True
                answer_open = False
                thought_buf = ''
                history.append({
                    'role': 'assistant',
                    'content': '',
                    'metadata': {'title': '💭 Thought'}
                })
                # Keep only the text after the opening tag.
                chunk = chunk.split('<think>', 1)[1]
            if in_thought:
                thought_buf += chunk
                if '</think>' in thought_buf:
                    thought, answer_buf = thought_buf.split('</think>', 1)
                    history[-1]['content'] = thought.strip()
                    in_thought = False
                    history.append({'role': 'assistant', 'content': answer_buf})
                    answer_open = True
                else:
                    history[-1]['content'] = thought_buf
            else:
                if not answer_open:
                    history.append({'role': 'assistant', 'content': ''})
                    answer_open = True
                answer_buf += chunk
                history[-1]['content'] = answer_buf
            yield history, debug

        gen_thread.join()
        if gen_errors:
            raise gen_errors[0]
        yield history, debug + prompt_debug
    except Exception as e:
        history.append({'role': 'assistant', 'content': f"Error: {e}"})
        yield history, debug
    finally:
        gc.collect()
def cancel_generation():
    """
    Request that the in-flight generation stop.

    Sets the shared ``cancel_event`` flag, which chat_response checks while
    streaming tokens, and returns a status line for the debug panel.
    """
    cancel_event.set()
    status = 'Generation cancelled.'
    return status
def update_default_prompt(enable_search):
    """
    Build the default system prompt, stamped with today's local date.

    ``enable_search`` is accepted because this function is also the search
    checkbox's change handler; the returned text does not depend on it.
    """
    stamp = datetime.now().date().isoformat()
    return f"You are a helpful assistant. Today is {stamp}."
# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
    gr.Markdown("Interact with the model. Select parameters and chat below.")
    with gr.Row():
        # Left-hand settings panel.
        with gr.Column(scale=3):
            model_names = list(MODELS)
            model_dd = gr.Dropdown(
                label="Select Model",
                choices=model_names,
                value=model_names[0],
            )
            search_chk = gr.Checkbox(label="Enable Web Search", value=True)
            sys_prompt = gr.Textbox(
                label="System Prompt",
                lines=3,
                value=update_default_prompt(search_chk.value),
            )

            gr.Markdown("### Generation Parameters")
            max_tok = gr.Slider(minimum=64, maximum=16384, value=2048, step=32, label="Max Tokens")
            temp = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K")
            p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P")
            rp = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")

            gr.Markdown("### Web Search Settings")
            mr = gr.Number(value=6, precision=0, label="Max Results")
            mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
            st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0, label="Search Timeout (s)")

            clr = gr.Button("Clear Chat")
            cnl = gr.Button("Cancel Generation")
        # Right-hand chat panel.
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages")
            txt = gr.Textbox(placeholder="Type your message and press Enter...")
            dbg = gr.Markdown()

    # Event wiring. The order of chat_inputs must match chat_response's parameters.
    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc,
                   model_dd, max_tok, temp, k, p, rp, st]
    txt.submit(fn=chat_response, inputs=chat_inputs, outputs=[chat, dbg])
demo.launch()