import os
import threading
import time
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import spaces
MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")
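# The model can be swapped without code changes, e.g. (illustrative invocation;
# any causal LM with a chat template should work here):
#   MODEL_ID=your-org/your-model python app.py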
# -------- Load model & tokenizer --------
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
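# device_map="auto" lets accelerate place the weights on whatever GPU/CPU is
# available, and torch_dtype="auto" reuses the dtype recorded in the checkpoint.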
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()
# Ensure a pad token to avoid warnings on some bases
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
TITLE = "SoftwareArchitecture-Instruct v1 — Chat"
DESCRIPTION = (
    "An instruction-tuned LLM for **software architecture**. "
    "Built on LiquidAI/LFM2-1.2B and fine-tuned on the Software-Architecture dataset. "
    "Designed for technical professionals: accurate, detailed, and on-topic answers."
)
SAMPLES = [
"Explain the API Gateway pattern and when to use it.",
"CQRS vs Event Sourcing — how do they relate, and when would you combine them?",
"Design a resilient payment workflow with retries, idempotency keys, and DLQ.",
"Rate limiting strategies for a public REST API: token bucket vs sliding window.",
"Multi-tenant SaaS: compare shared DB, schema, and dedicated DB for isolation.",
"Blue/green vs canary deployments — trade-offs and where each fits best.",
]
def format_history_as_messages(history):
"""
Convert Gradio chat history into OpenAI-style messages for apply_chat_template.
history: list of tuples (user, assistant)
"""
messages = []
for (u, a) in history:
if u:
messages.append({"role": "user", "content": u})
if a:
messages.append({"role": "assistant", "content": a})
return messages
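# For example, one completed exchange converts as:
#   format_history_as_messages([("What is CQRS?", "CQRS separates reads from writes ...")])
#   -> [{"role": "user", "content": "What is CQRS?"},
#       {"role": "assistant", "content": "CQRS separates reads from writes ..."}]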
@spaces.GPU
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )
    # Keep only what the model expects
    allowed = {"input_ids", "attention_mask"}  # no token_type_ids for causal LMs
    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    do_sample = float(temperature) > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        do_sample=do_sample,
        use_cache=True,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs are only valid when sampling; passing temperature=0.0
        # alongside do_sample=False trips transformers' generation-config validation.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
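    # model.generate() blocks until the full sequence is produced, so it runs on
    # a worker thread while this thread drains decoded text from the streamer.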
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
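# Illustrative only (not called by the app): stream_generate can be driven
# directly, e.g. for a quick smoke test:
#   msgs = [{"role": "user", "content": "Explain the API Gateway pattern."}]
#   for partial in stream_generate(msgs, max_new_tokens=64, temperature=0.3,
#                                  top_p=0.9, repetition_penalty=1.05):
#       print(partial)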
# -------- Gradio callbacks --------
@spaces.GPU
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
    if not user_msg or not user_msg.strip():
        # Generator callbacks must yield their outputs; Gradio discards a plain
        # return value, and the textbox should be left untouched here.
        yield gr.update(), gr.update()
        return
    # Add user turn
    chat_history = chat_history + [(user_msg, None)]
    # Build messages from full history
    messages = format_history_as_messages(chat_history)
    # Stream assistant output
    stream = stream_generate(
        messages=messages,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        seed=int(seed) if seed is not None else None,
    )
    # Yield progressive updates for the last assistant turn
    final_assistant_text = ""
    for chunk in stream:
        final_assistant_text = chunk
        yield gr.update(value=chat_history[:-1] + [(user_msg, final_assistant_text)]), ""
    # Ensure final state returned
    chat_history[-1] = (user_msg, final_assistant_text)
    yield gr.update(value=chat_history), ""
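# Each yield above maps positionally onto outputs=[chat, user_box] in the event
# wiring below, so the second element ("") clears the input box while streaming.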
def use_sample(sample, chat_history):
    return sample, chat_history

def clear_chat():
    return []
# -------- UI --------
CUSTOM_CSS = """
:root {
  --brand: #0ea5e9; /* cyan-500 */
  --ink: #0b1220;
}
.gradio-container {
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto,
    Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
}
#title h1 {
  font-weight: 700;
  letter-spacing: -0.02em;
}
#desc {
  opacity: 0.9;
}
footer {visibility: hidden}
"""
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(f"<div id='title'><h1>{TITLE}</h1></div>")
            gr.Markdown(f"<div id='desc'>{DESCRIPTION}</div>", elem_id="desc")
    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="SoftwareArchitecture-Instruct v1",
                avatar_images=(None, None),
                height=480,
                bubble_full_width=False,
                sanitize_html=False,
            )
            with gr.Row():
                user_box = gr.Textbox(
                    placeholder="Ask about software architecture…",
                    show_label=False,
                    lines=3,
                    autofocus=True,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Accordion("Generation Settings", open=False):
                max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                repetition_penalty = gr.Slider(1.0, 1.5, value=1.05, step=0.01, label="Repetition penalty")
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")
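                # stream_generate only reseeds torch when seed >= 0, so the
                # default of -1 leaves generation nondeterministic.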
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                # sample buttons
                sample_dropdown = gr.Dropdown(choices=SAMPLES, label="Samples", value=None)
                use_sample_btn = gr.Button("Use Sample")
        with gr.Column(scale=2):
            gr.Markdown("### Samples")
            # "-" markers render as a proper Markdown list; lines joined by a
            # single "\n" otherwise merge into one paragraph.
            gr.Markdown("\n".join([f"- {s}" for s in SAMPLES]))
            gr.Markdown("—\n**Tip:** Increase *Max new tokens* for longer, more complete answers.")
    # Events
    send_btn.click(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    user_box.submit(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    clear_btn.click(fn=clear_chat, outputs=chat)
    use_sample_btn.click(use_sample, inputs=[sample_dropdown, chat], outputs=[user_box, chat])
    gr.Markdown(
        "—\nBuilt for engineers and architects. Base model: **LiquidAI/LFM2-1.2B** · Fine-tuned on the **Software-Architecture** dataset."
    )
if __name__ == "__main__":
    demo.queue().launch()
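    # On Hugging Face Spaces the host and port are managed by the platform; when
    # self-hosting, launch(server_name="0.0.0.0", server_port=7860) (standard
    # Gradio launch() kwargs) is a common alternative.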