import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
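
# The app is launched with Streamlit's CLI, e.g. `streamlit run app.py`
# (replace app.py with whatever this script is saved as).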

# Set Streamlit page configuration
st.set_page_config(
    page_title="Qwen2.5-Coder Chat",
    page_icon="πŸ’¬",
    layout="wide",
)

# Title of the app
st.title("πŸ’¬ Qwen2.5-Coder Chat Interface")

# Initialize session state for messages
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
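# (Streamlit re-executes this script from top to bottom on every interaction, so
# the conversation history has to live in st.session_state to persist across reruns.)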

# Function to load the model
@st.cache_resource
def load_model():
    model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"  # Replace with your model path or name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # Use appropriate dtype
        device_map='auto'           # Automatically choose device (GPU/CPU)
    )
    return tokenizer, model
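
# Note: the 32B model needs roughly 65 GB of GPU memory in float16, so on smaller
# GPUs a lighter variant (e.g. "Qwen/Qwen2.5-Coder-7B-Instruct") or quantized
# loading may be more practical. A possible 4-bit version of the loader above,
# using transformers' BitsAndBytesConfig (requires the bitsandbytes package), is
# sketched here as an assumption rather than part of the original app:
#
#     from transformers import BitsAndBytesConfig
#     quant_config = BitsAndBytesConfig(load_in_4bit=True)
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         quantization_config=quant_config,
#         device_map='auto'
#     )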

# Load tokenizer and model
with st.spinner("Loading model... This may take a while..."):
    tokenizer, model = load_model()

# Function to generate a model response
def generate_response(prompt, max_new_tokens=2048, temperature=0.7, top_p=0.9):
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,  # Cap on newly generated tokens (not prompt + response)
            temperature=temperature,        # Adjust for creativity
            top_p=top_p,                    # Nucleus sampling
            do_sample=True,                 # Enable sampling
            num_return_sequences=1
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return response.strip()
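
# Note: Qwen2.5-Coder-32B-Instruct is a chat-tuned model, so in practice the
# conversation is usually formatted with the tokenizer's chat template rather
# than passed in as raw text. The helper below is an optional sketch of that
# approach (it is not wired into the app, and the generation settings are the
# same assumptions as above); it reuses the tokenizer and model loaded earlier.
def generate_chat_response(messages, max_new_tokens=2048, temperature=0.7, top_p=0.9):
    # `messages` is a list of {'role': ..., 'content': ...} dicts, the same
    # structure kept in st.session_state['messages'].
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # Append the assistant-turn marker before generating
        return_tensors='pt'
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()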

# Sidebar: generation settings. Defined before the chat area so the selected
# values are available when a message is submitted.
st.sidebar.header("Settings")
max_tokens = st.sidebar.slider(
    "Maximum Tokens",
    min_value=512,
    max_value=4096,
    value=2048,
    step=256,
    help="Set the maximum number of tokens for the model's response."
)

temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.1,
    max_value=1.0,
    value=0.7,
    step=0.1,
    help="Controls the randomness of the model's output."
)

top_p = st.sidebar.slider(
    "Top-p (Nucleus Sampling)",
    min_value=0.1,
    max_value=1.0,
    value=0.9,
    step=0.1,
    help="Controls the diversity of the model's output."
)

if st.sidebar.button("Clear Chat"):
    st.session_state['messages'] = []
    st.rerun()

# Main chat area: conversation history and input form
with st.container():
    # Display chat messages
    for message in st.session_state['messages']:
        if message['role'] == 'user':
            st.markdown(f"**You:** {message['content']}")
        else:
            st.markdown(f"**Qwen2.5-Coder:** {message['content']}")
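    # (Newer Streamlit releases also provide st.chat_message / st.chat_input for a
    # native chat UI, which could replace the markdown rendering above.)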

    # Input area for user
    with st.form(key='chat_form', clear_on_submit=True):
        user_input = st.text_area("You:", height=100)
        submit_button = st.form_submit_button(label='Send')

    if submit_button and user_input:
        # Append user message
        st.session_state['messages'].append({'role': 'user', 'content': user_input})
        
        # Generate and append model response
        with st.spinner("Qwen2.5-Coder is typing..."):
            response = generate_response(
                user_input,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            st.session_state['messages'].append({'role': 'assistant', 'content': response})

        # Rerun so the new messages are displayed immediately
        st.rerun()  # st.experimental_rerun() is deprecated in newer Streamlit releases
