import logging

import gradio as gr

from src.chat_logic import ChatProcessor
from src.models import ModelManager
from src.vector_db import VectorDBHandler

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared instances used by every request: the model manager, the vector store handler,
# and the chat processor that combines them.
model_manager = ModelManager()
vector_db = VectorDBHandler()
chat_processor = ChatProcessor(model_manager, vector_db)


# On a Hugging Face ZeroGPU Space, GPU work has to happen inside a function decorated
# with @spaces.GPU; fall back to an undecorated wrapper when the `spaces` package is
# unavailable (e.g. when running locally).
try:
    import spaces

    @spaces.GPU(duration=60)
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)

except ImportError:

    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)

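# A more compact variant of the pattern above (an illustrative sketch only, not used here)
# would pick the decorator once instead of duplicating the wrapper:
#
#     try:
#         import spaces
#         gpu_decorator = spaces.GPU(duration=60)
#     except ImportError:
#         gpu_decorator = lambda fn: fn  # no-op outside Spaces
#
#     @gpu_decorator
#     def run_respond(*args, **kwargs):
#         yield from respond(*args, **kwargs)
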
def respond(
    message,
    history: list[tuple[str, str]],
    model_name: str,
    system_message: str = "You are a Qwen3 assistant.",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    use_direct_pipeline: bool = False
):
    """
    Process chat using the ChatProcessor, with streaming support.

    Args:
        message: The user message.
        history: Chat history as a list of (user, assistant) message pairs.
        model_name: Name of the model to use.
        system_message: System prompt that guides the model's behavior.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.
        top_k: Top-k sampling parameter.
        repetition_penalty: Penalty applied to repeated tokens.
        use_direct_pipeline: Whether to use the direct (non-streaming) pipeline method.

    Yields:
        Partial response text for streaming in the UI.
    """
    logger.info(f"Running respond with use_direct_pipeline: {use_direct_pipeline}")

    try:
        if use_direct_pipeline:
            # Direct (non-streaming) path: build a single generation config and
            # yield the complete response at once.
            generation_config = {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                "do_sample": True
            }

            response = chat_processor.generate_with_pipeline(
                message=message,
                history=history,
                model_name=model_name,
                generation_config=generation_config,
                system_prompt=system_message
            )
            yield response
        else:
            # Streaming path: process_chat yields (updated_history, debug_info) pairs,
            # so relay the assistant's latest partial content to the UI as it grows.
            response_generator = chat_processor.process_chat(
                message=message,
                history=history,
                model_name=model_name,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                system_prompt=system_message
            )

            response = ""
            for updated_history, _debug in response_generator:
                response = updated_history[-1]["content"]
                yield response
    except Exception as e:
        logger.error(f"Chat response error: {e}")
        yield f"Error: {e}"

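# Illustrative usage (a sketch, not part of the app's runtime): because `respond` is a
# generator, it can also be consumed outside Gradio for quick local checks. The model name
# below is simply the dropdown default; adjust it to whatever ModelManager actually serves.
#
#     for partial in respond("Hello!", history=[], model_name="Qwen3-0.6B"):
#         print(partial)
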
demo = gr.ChatInterface(
    run_respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["Qwen3-14B", "Qwen3-8B", "Qwen3-0.6B"],
            value="Qwen3-0.6B",
            label="Model Selection"
        ),
        gr.Textbox(value="You are a Qwen3 assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)")
    ]
)

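# Deployment note (illustrative, not required by this app): `demo.launch()` accepts the
# standard Gradio options for binding inside a container or on a remote host, e.g.
#
#     demo.launch(server_name="0.0.0.0", server_port=7860)
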
if __name__ == "__main__":
    demo.launch()