import gradio as gr
from src.models import ModelManager
from src.chat_logic import ChatProcessor
from src.vector_db import VectorDBHandler
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize components
model_manager = ModelManager()
vector_db = VectorDBHandler()
chat_processor = ChatProcessor(model_manager, vector_db)
try:
    # On Hugging Face ZeroGPU Spaces, wrap the handler so a GPU is allocated
    # for the duration of each request.
    import spaces

    @spaces.GPU(duration=60)
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)
except ImportError:
    # Fallback for local runs where the `spaces` package is unavailable.
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)
def respond(
    message,
    history: list[tuple[str, str]],
    model_name: str,
    system_message: str = "You are a Qwen3 assistant.",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    use_direct_pipeline: bool = False,
):
"""
Process chat using the ChatProcessor with streaming support.
Args:
message: The user message
history: Chat history as list of (user, assistant) message pairs
model_name: Name of the model to use
system_message: System prompt to guide the model's behavior
max_new_tokens: Maximum number of tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
repetition_penalty: Penalty for token repetition
use_direct_pipeline: Whether to use the direct pipeline method
Yields:
Generated response tokens for streaming UI
"""
print(f"Running respond with use_direct_pipeline: {use_direct_pipeline}")
try:
if use_direct_pipeline:
# Use the direct pipeline method (non-streaming)
generation_config = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"do_sample": True
}
response = chat_processor.generate_with_pipeline(
message=message,
history=history,
model_name=model_name,
generation_config=generation_config,
system_prompt=system_message
)
yield response
else:
# Use the streaming method
response_generator = chat_processor.process_chat(
message=message,
history=history,
model_name=model_name,
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
system_prompt=system_message
)
# Stream response tokens
response = ""
for history, dbg in response_generator:
response = history[-1]['content']
yield response # Yield the accumulated response for streaming UI
except Exception as e:
logger.error(f"Chat response error: {str(e)}")
yield f"Error: {str(e)}"
# Create Gradio interface
demo = gr.ChatInterface(
    run_respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["Qwen3-14B", "Qwen3-8B", "Qwen3-0.6B"],
            value="Qwen3-0.6B",
            label="Model Selection",
        ),
        gr.Textbox(value="You are a Qwen3 assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)"),
    ],
)
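# Note (assumption): recent Gradio versions queue generator-based events automatically;
# on older releases, streaming output may require enabling the queue explicitly,
# e.g. demo.queue() before launch().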
if __name__ == "__main__":
    demo.launch()