import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import datetime

# Set Streamlit page configuration
st.set_page_config(
    page_title="Qwen2.5-Coder Chat",
    page_icon="💬",
    layout="wide",
)

# Title of the app
st.title("💬 Qwen2.5-Coder Chat Interface")

# Initialize session state for messages (store conversation history)
st.session_state.setdefault('messages', [])
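
# Note: st.session_state persists across Streamlit reruns, so the message list
# above survives every widget interaction within a browser session.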

# Load the model and tokenizer (cached so this only runs once per session)
@st.cache_resource
def load_model():
    model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"

    # BitsAndBytesConfig for 8-bit quantization, to fit the 32B model in less
    # GPU memory; layers that do not fit are offloaded to the CPU in fp32.
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tokenizer, model
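
# Optional alternative (a sketch, not part of the original app): if 8-bit
# weights still do not fit, bitsandbytes 4-bit NF4 quantization roughly halves
# memory again, at some quality cost:
#
#     quantization_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.float16,
#     )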

# Load tokenizer and model with error handling
try:
    with st.spinner("Loading model... This may take a while..."):
        tokenizer, model = load_model()
except Exception as e:
    st.error(f"Error loading model: {e}")
    st.stop()

# Function to generate a model response from the conversation history
def generate_response(messages, tokenizer, model, max_tokens=150, temperature=0.7, top_p=0.9):
    """
    Generates a response from the model based on the conversation history.

    Args:
        messages (list): List of message dictionaries containing 'role' and 'content'.
        tokenizer: The tokenizer instance.
        model: The language model instance.
        max_tokens (int): Maximum number of new tokens for the response.
        temperature (float): Sampling temperature.
        top_p (float): Nucleus sampling probability.

    Returns:
        str: The generated response text.
    """
    # Build a plain-text transcript of the whole history, then cue the model.
    # The history already contains the latest user message, so it must not be
    # appended a second time.
    conversation = ""
    for message in messages:
        role = "You" if message['role'] == 'user' else "Qwen2.5-Coder"
        conversation += f"{role}: {message['content']}\n"
    conversation += "Qwen2.5-Coder:"

    # Tokenize the conversation, keeping the attention mask for generate()
    inputs = tokenizer(conversation, return_tensors="pt").to(model.device)

    # Generate a response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
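
# Note: Qwen2.5-Coder-Instruct ships a chat template, so the prompt could also
# be built with it instead of the plain transcript above. A sketch (assumes the
# extra 'timestamp' keys are ignored by the template), not used by this app:
#
#     text = tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True,
#     )
#     inputs = tokenizer(text, return_tensors="pt").to(model.device)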

# Sidebar: generation settings. These are defined before the chat handler
# below so that max_tokens, temperature, and top_p are bound when a message
# is submitted.
st.sidebar.header("Settings")
max_tokens = st.sidebar.slider(
    "Maximum Tokens",
    min_value=50,
    max_value=4096,
    value=512,
    step=256,
    help="Set the maximum number of new tokens in the model's response."
)
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.1,
    max_value=1.0,
    value=0.7,
    step=0.1,
    help="Controls the randomness of the model's output."
)
top_p = st.sidebar.slider(
    "Top-p (Nucleus Sampling)",
    min_value=0.1,
    max_value=1.0,
    value=0.9,
    step=0.1,
    help="Controls the diversity of the model's output."
)
if st.sidebar.button("Clear Chat"):
    st.session_state['messages'] = []
    st.sidebar.success("Chat history cleared.")

# Main chat area
st.markdown("### Chat")

# Render the conversation history
chat_container = st.container()
with chat_container:
    for message in st.session_state['messages']:
        time = message.get('timestamp', '')
        if message['role'] == 'user':
            st.markdown(f"**You:** {message['content']} _({time})_")
        else:
            st.markdown(f"**Qwen2.5-Coder:** {message['content']} _({time})_")

# Input area for user message
with st.form(key='chat_form', clear_on_submit=True):
    user_input = st.text_area("You:", height=100)
    submit_button = st.form_submit_button(label='Send')

if submit_button and user_input.strip():
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Append the user's message to the chat history
    st.session_state['messages'].append({'role': 'user', 'content': user_input, 'timestamp': timestamp})

    # Generate and append the model's response
    try:
        with st.spinner("Qwen2.5-Coder is typing..."):
            response = generate_response(
                st.session_state['messages'],
                tokenizer,
                model,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        st.session_state['messages'].append({'role': 'assistant', 'content': response, 'timestamp': timestamp})
    except Exception as e:
        st.error(f"Error generating response: {e}")
    else:
        # Rerun so the freshly appended messages show up in the history above
        st.rerun()
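
# To run the app locally (assuming streamlit, transformers, torch, and
# bitsandbytes are installed):
#     streamlit run app.py
# Note that the 32B model needs well over 30 GB of GPU memory even in 8-bit; a
# smaller variant such as Qwen/Qwen2.5-Coder-7B-Instruct may be a more
# practical choice on modest hardware.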