Spaces: Running on Zero
import os
import threading
import time
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

import spaces
MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")

# -------- Load model & tokenizer --------
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()

# Ensure a pad token to avoid warnings on some base models
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
TITLE = "SoftwareArchitecture-Instruct v1 — Chat"
DESCRIPTION = (
    "An instruction-tuned LLM for **software architecture**. "
    "Built on LiquidAI/LFM2-1.2B, fine-tuned with the Software-Architecture dataset. "
    "Designed for technical professionals: accurate, detailed, and on-topic answers."
)
SAMPLES = [
    "Explain the API Gateway pattern and when to use it.",
    "CQRS vs Event Sourcing — how do they relate, and when would you combine them?",
    "Design a resilient payment workflow with retries, idempotency keys, and DLQ.",
    "Rate limiting strategies for a public REST API: token bucket vs sliding window.",
    "Multi-tenant SaaS: compare shared DB, schema, and dedicated DB for isolation.",
    "Blue/green vs canary deployments — trade-offs and where each fits best.",
]
def format_history_as_messages(history):
    """
    Convert Gradio chat history into OpenAI-style messages for apply_chat_template.

    history: list of (user, assistant) tuples.
    """
    messages = []
    for (u, a) in history:
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    return messages
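
# Hedged, illustrative example of the conversion above (not executed at runtime;
# the sample strings are made up for illustration): a history such as
#   [("What is CQRS?", "CQRS separates reads from writes..."), ("And Event Sourcing?", None)]
# becomes
#   [{"role": "user", "content": "What is CQRS?"},
#    {"role": "assistant", "content": "CQRS separates reads from writes..."},
#    {"role": "user", "content": "And Event Sourcing?"}]
# which is the message shape apply_chat_template expects.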
@spaces.GPU  # request ZeroGPU hardware for the duration of the generation call
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )
    # Keep only what the model expects (no token_type_ids for causal LMs)
    allowed = {"input_ids", "attention_mask"}
    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=temperature > 0,
        use_cache=True,
        streamer=streamer,
    )

    # Run generation in a background thread so tokens can be streamed as they arrive
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
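
# Minimal local smoke test (hedged sketch, run manually rather than on import;
# the prompt string is invented for illustration):
#   for text in stream_generate(
#       [{"role": "user", "content": "Explain the API Gateway pattern."}],
#       max_new_tokens=128, temperature=0.3, top_p=0.9, repetition_penalty=1.05,
#   ):
#       print(text)
# Note that each yielded value is the cumulative decoded text, not a delta.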
# -------- Gradio callbacks --------
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
    if not user_msg or not user_msg.strip():
        # Nothing to send: leave the chatbot and the textbox unchanged.
        # (A plain `return value` inside a generator would not reach Gradio.)
        yield gr.update(), gr.update()
        return

    # Add user turn
    chat_history = chat_history + [(user_msg, None)]

    # Build messages from full history
    messages = format_history_as_messages(chat_history)

    # Stream assistant output
    stream = stream_generate(
        messages=messages,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        seed=int(seed) if seed is not None else None,
    )

    # Yield progressive updates for the last assistant turn
    final_assistant_text = ""
    for chunk in stream:
        final_assistant_text = chunk
        yield gr.update(value=chat_history[:-1] + [(user_msg, final_assistant_text)]), ""

    # Ensure the final state is emitted
    chat_history[-1] = (user_msg, final_assistant_text)
    yield gr.update(value=chat_history), ""
def use_sample(sample, chat_history):
    return sample, chat_history

def clear_chat():
    return []
# -------- UI --------
CUSTOM_CSS = """
:root {
  --brand: #0ea5e9; /* cyan-500 */
  --ink: #0b1220;
}
.gradio-container {
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto, Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
}
#title h1 {
  font-weight: 700;
  letter-spacing: -0.02em;
}
#desc {
  opacity: 0.9;
}
footer {visibility: hidden}
"""
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(f"<div id='title'><h1>{TITLE}</h1></div>")
            gr.Markdown(f"<div id='desc'>{DESCRIPTION}</div>", elem_id="desc")

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="SoftwareArchitecture-Instruct v1",
                avatar_images=(None, None),
                height=480,
                bubble_full_width=False,
                sanitize_html=False,
            )
            with gr.Row():
                user_box = gr.Textbox(
                    placeholder="Ask about software architecture…",
                    show_label=False,
                    lines=3,
                    autofocus=True,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Accordion("Generation Settings", open=False):
                max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                repetition_penalty = gr.Slider(1.0, 1.5, value=1.05, step=0.01, label="Repetition penalty")
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                # Sample prompts
                sample_dropdown = gr.Dropdown(choices=SAMPLES, label="Samples", value=None)
                use_sample_btn = gr.Button("Use Sample")
        with gr.Column(scale=2):
            gr.Markdown("### Samples")
            gr.Markdown("\n".join([f"• {s}" for s in SAMPLES]))
            gr.Markdown("—\n**Tip:** Increase *Max new tokens* for longer, more complete answers.")

    # Events
    send_btn.click(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    user_box.submit(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    clear_btn.click(fn=clear_chat, outputs=chat)
    use_sample_btn.click(use_sample, inputs=[sample_dropdown, chat], outputs=[user_box, chat])

    gr.Markdown(
        "—\nBuilt for engineers and architects. Base model: **LiquidAI/LFM2-1.2B** · Fine-tuned: **Software-Architecture** dataset."
    )
if __name__ == "__main__":
    demo.queue().launch()
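
# Hedged sketch of the Space's requirements.txt, inferred from the imports above
# (package names only; any version pins would be assumptions, not taken from the repo):
#   torch
#   transformers
#   accelerate   # needed for device_map="auto"
#   gradio
#   spaces       # ZeroGPU support (@spaces.GPU)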