# src/app.py
import logging

import gradio as gr

from src.models import ModelManager
from src.chat_logic import ChatProcessor
from src.vector_db import VectorDBHandler

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize components
model_manager = ModelManager()
vector_db = VectorDBHandler()
chat_processor = ChatProcessor(model_manager, vector_db)
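
# On Hugging Face ZeroGPU Spaces the `spaces` package is available and the
# @spaces.GPU decorator requests a GPU for up to 60 seconds per call; outside
# Spaces (e.g. local runs) we fall back to a plain pass-through wrapper.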
try:
    import spaces

    @spaces.GPU(duration=60)
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)
except ImportError:
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)

def respond(
    message,
    history: list[tuple[str, str]],
    model_name: str,
    system_message: str = "You are a Qwen3 assistant.",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    use_direct_pipeline: bool = False,
):
    """
    Process chat using the ChatProcessor with streaming support.

    Args:
        message: The user message
        history: Chat history as a list of (user, assistant) message pairs
        model_name: Name of the model to use
        system_message: System prompt to guide the model's behavior
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for token repetition
        use_direct_pipeline: Whether to use the direct (non-streaming) pipeline method

    Yields:
        The accumulated response text for the streaming UI
    """
print(f"Running respond with use_direct_pipeline: {use_direct_pipeline}")
try:
if use_direct_pipeline:
# Use the direct pipeline method (non-streaming)
generation_config = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"do_sample": True
}
response = chat_processor.generate_with_pipeline(
message=message,
history=history,
model_name=model_name,
generation_config=generation_config,
system_prompt=system_message
)
yield response
else:
# Use the streaming method
response_generator = chat_processor.process_chat(
message=message,
history=history,
model_name=model_name,
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
system_prompt=system_message
)
# Stream response tokens
response = ""
for history, dbg in response_generator:
response = history[-1]['content']
yield response # Yield the accumulated response for streaming UI
except Exception as e:
logger.error(f"Chat response error: {str(e)}")
yield f"Error: {str(e)}"

# Create Gradio interface
demo = gr.ChatInterface(
    run_respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["Qwen3-14B", "Qwen3-8B", "Qwen3-0.6B"],
            value="Qwen3-0.6B",
            label="Model Selection",
        ),
        gr.Textbox(value="You are a Qwen3 assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)"),
    ],
)
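
# The additional_inputs above are passed to run_respond positionally after
# (message, history), so their order must match respond's parameters from
# model_name through use_direct_pipeline.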

if __name__ == "__main__":
    demo.launch()
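
# If streamed output appears to arrive in large chunks rather than token by
# token, enabling the request queue before launching is one option (a sketch,
# assuming default queue settings are acceptable here):
#
#     demo.queue().launch()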