import logging

import gradio as gr

from src.models import ModelManager
from src.chat_logic import ChatProcessor
from src.vector_db import VectorDBHandler

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize components
model_manager = ModelManager()
vector_db = VectorDBHandler()
chat_processor = ChatProcessor(model_manager, vector_db)

try:
    # On Hugging Face Spaces with ZeroGPU, run generation inside a GPU-allocated task.
    import spaces

    @spaces.GPU(duration=60)
    def run_respond(*args, **kwargs):
        for token in respond(*args, **kwargs):
            yield token
except ImportError:
    # Fall back to running directly when the `spaces` package is unavailable (e.g. local runs).
    def run_respond(*args, **kwargs):
        for token in respond(*args, **kwargs):
            yield token


def respond(
    message,
    history: list[tuple[str, str]],
    model_name: str,
    system_message: str = "You are a Qwen3 assistant.",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    use_direct_pipeline: bool = False,
):
    """
    Process chat using the ChatProcessor with streaming support.

    Args:
        message: The user message
        history: Chat history as a list of (user, assistant) message pairs
        model_name: Name of the model to use
        system_message: System prompt to guide the model's behavior
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for token repetition
        use_direct_pipeline: Whether to use the direct pipeline method

    Yields:
        Generated response text for the streaming UI
    """
    logger.info(f"Running respond with use_direct_pipeline: {use_direct_pipeline}")
    try:
        if use_direct_pipeline:
            # Use the direct pipeline method (non-streaming)
            generation_config = {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                "do_sample": True,
            }
            response = chat_processor.generate_with_pipeline(
                message=message,
                history=history,
                model_name=model_name,
                generation_config=generation_config,
                system_prompt=system_message,
            )
            yield response
        else:
            # Use the streaming method
            response_generator = chat_processor.process_chat(
                message=message,
                history=history,
                model_name=model_name,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                system_prompt=system_message,
            )

            # Stream response tokens: each step yields the updated history plus debug info;
            # the latest assistant message holds the response accumulated so far.
            response = ""
            for updated_history, _dbg in response_generator:
                response = updated_history[-1]["content"]
                yield response  # Yield the accumulated response for the streaming UI
    except Exception as e:
        logger.error(f"Chat response error: {str(e)}")
        yield f"Error: {str(e)}"


# Create Gradio interface
demo = gr.ChatInterface(
    run_respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["Qwen3-14B", "Qwen3-8B", "Qwen3-0.6B"],
            value="Qwen3-0.6B",
            label="Model Selection",
        ),
        gr.Textbox(value="You are a Qwen3 assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)"),
    ],
)

if __name__ == "__main__":
    demo.launch()