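"""Gradio chat app for the Olubakka Space (src/app.py).

Wires ModelManager and VectorDBHandler into a ChatProcessor and exposes the
Qwen3 chat models through a streaming gr.ChatInterface.
"""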
import gradio as gr
from src.models import ModelManager
from src.chat_logic import ChatProcessor
from src.vector_db import VectorDBHandler
import logging
import spaces

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize components
model_manager = ModelManager()
vector_db = VectorDBHandler()
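# ChatProcessor is assumed to handle prompt construction and generation, with
# VectorDBHandler available for retrieval (inferred from the constructor
# arguments; both classes live under src/).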
chat_processor = ChatProcessor(model_manager, vector_db)


@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    model_name: str,
    system_message: str = "You are a Qwen3 assistant.",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    use_direct_pipeline: bool = False,
):
    """
    Process a chat turn with the ChatProcessor, with streaming support.

    Args:
        message: The user message
        history: Chat history as a list of (user, assistant) message pairs
        model_name: Name of the model to use
        system_message: System prompt to guide the model's behavior
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for token repetition
        use_direct_pipeline: Whether to use the direct (non-streaming) pipeline method

    Yields:
        The accumulated response text, updated as tokens arrive, for streaming in the UI
    """
    try:
        if use_direct_pipeline:
            # Use the direct pipeline method (non-streaming): build a single
            # generation config and return the full response in one yield.
            generation_config = {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                "do_sample": True,
            }
            response = chat_processor.generate_with_pipeline(
                message=message,
                history=history,
                model_name=model_name,
                generation_config=generation_config,
                system_prompt=system_message,
            )
            yield response
        else:
            # Use the streaming method
            response_generator = chat_processor.process_chat(
                message=message,
                history=history,
                model_name=model_name,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                system_prompt=system_message,
            )

            # Accumulate streamed tokens and yield the running response so the
            # UI updates incrementally.
            response = ""
            for token in response_generator:
                response += token
                yield response
    except Exception as e:
        logger.error(f"Chat response error: {str(e)}")
        yield f"Error: {str(e)}"

# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["Qwen3-14B", "Qwen3-8B"],
            value="Qwen3-8B",
            label="Model Selection",
        ),
        gr.Textbox(value="You are a Qwen3 assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)"),
    ],
)
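
# Note: depending on the installed Gradio version, generator-based streaming
# may require the request queue; if streamed output does not update
# incrementally, calling demo.queue() before launch is a reasonable first step
# (assumption about the Gradio version, not verified against this Space).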

if __name__ == "__main__":
    demo.launch()