import logging

import gradio as gr

from src.chat_logic import ChatProcessor
from src.models import ModelManager
from src.vector_db import VectorDBHandler

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize components
model_manager = ModelManager()
vector_db = VectorDBHandler()
chat_processor = ChatProcessor(model_manager, vector_db)
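
# For reference, the ChatProcessor interface this app assumes (inferred from the
# calls inside respond() below; the actual implementation lives in src/chat_logic.py):
#
#   chat_processor.generate_with_pipeline(message, history, model_name,
#       generation_config, system_prompt) -> str            # one-shot, non-streaming
#   chat_processor.process_chat(message, history, model_name, temperature,
#       max_new_tokens, top_p, top_k, repetition_penalty, system_prompt)
#       -> iterator of (updated_history, debug_info) pairs  # streaming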


# On Hugging Face ZeroGPU Spaces the `spaces` package is available: wrap generation
# in @spaces.GPU so a GPU is attached for the duration of each call. When running
# elsewhere (no `spaces` package), fall back to a plain pass-through wrapper.
try:
    import spaces

    @spaces.GPU(duration=60)
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)
except ImportError:
    def run_respond(*args, **kwargs):
        yield from respond(*args, **kwargs)

def respond(
    message,
    history: list[tuple[str, str]],
    model_name: str,
    system_message: str = "You are a Qwen3 assistant.",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    use_direct_pipeline: bool = False
):
    """
    Process chat using the ChatProcessor with streaming support.
    
    Args:
        message: The user message
        history: Chat history as list of (user, assistant) message pairs
        model_name: Name of the model to use
        system_message: System prompt to guide the model's behavior
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        repetition_penalty: Penalty for token repetition
        use_direct_pipeline: Whether to use the direct pipeline method
    
    Yields:
        The accumulated response text so far, for the streaming UI
    """

    print(f"Running respond with use_direct_pipeline: {use_direct_pipeline}")
    try:
        if use_direct_pipeline:
            # Use the direct pipeline method (non-streaming)
            generation_config = {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                "do_sample": True
            }
            
            response = chat_processor.generate_with_pipeline(
                message=message,
                history=history,
                model_name=model_name,
                generation_config=generation_config,
                system_prompt=system_message
            )
            
            yield response
        else:
            # Use the streaming method
            response_generator = chat_processor.process_chat(
                message=message,
                history=history,
                model_name=model_name,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                system_prompt=system_message
            )
            
            # Stream the reply as it accumulates; process_chat yields the updated
            # history (messages format) plus debug info on each step
            response = ""
            for updated_history, _debug in response_generator:
                response = updated_history[-1]["content"]
                yield response  # yield the accumulated response for the streaming UI
            
    except Exception as e:
        logger.error(f"Chat response error: {str(e)}")
        yield f"Error: {str(e)}"


# Create the Gradio chat interface; additional_inputs are passed to run_respond
# in order, after (message, history), so their order must match respond's signature
demo = gr.ChatInterface(
    run_respond,
    additional_inputs=[
        gr.Dropdown(
            choices=["Qwen3-14B", "Qwen3-8B", "Qwen3-0.6B"],
            value="Qwen3-0.6B",
            label="Model Selection"
        ),
        gr.Textbox(value="You are a Qwen3 assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)")
    ]
)


if __name__ == "__main__":
    demo.launch()