qwen2.5 / app.py
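# Streamlit chat interface for Qwen/Qwen2.5-Coder-3B-Instruct, configured for CPU-only inference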
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datetime
import gc
import os

# Limit CUDA allocator block splitting to reduce memory fragmentation (no effect on a CPU-only host)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# Set page configuration
st.set_page_config(
    page_title="Qwen2.5-Coder Chat",
    page_icon="💬",
    layout="wide",
)
# Initialize session state
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
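
# Cache the model and tokenizer across Streamlit reruns so they are loaded only once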
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    try:
        model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"

        with st.spinner("🔄 Loading tokenizer..."):
            # Load tokenizer first
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )

        with st.spinner("🔄 Loading model... (this may take a few minutes on CPU)"):
            # Load the model in full float32 precision on CPU
            # (bitsandbytes 8-bit quantization is skipped because it requires a CUDA GPU)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map={"": "cpu"},
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            )

        # Force CPU mode and eval mode
        model = model.to("cpu").eval()

        # Clear memory after loading
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        st.session_state.model_loaded = True
        return tokenizer, model

    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        return None, None
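
# Generate a reply for the given prompt on CPU and return only the newly generated text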
def generate_response(prompt, model, tokenizer, max_length=256):
    try:
        # Clear memory before generation
        gc.collect()

        # Tokenize with a shorter maximum length to keep the input small
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to("cpu")

        # Generate with minimal parameters for CPU
        with torch.no_grad(), st.spinner("🤔 Thinking... (please be patient)"):
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_beams=1  # no beam search
            )

        # Clear memory after generation
        gc.collect()

        # Decode only the newly generated tokens, skipping the prompt tokens
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        return response.strip()

    except torch.cuda.OutOfMemoryError:
        st.error("💾 Memory exceeded. Try reducing the maximum length.")
        return None
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
        return None
# Main UI
st.title("💬 Qwen2.5-Coder Chat")
# Sidebar with minimal settings
with st.sidebar:
    st.header("⚙️ Settings")

    max_length = st.slider(
        "Response Length 📝",
        min_value=64,
        max_value=512,
        value=256,
        step=64,
        help="Shorter lengths are recommended for CPU"
    )

    if st.button("🗑️ Clear Conversation"):
        st.session_state.messages = []
        st.rerun()
# Load model (cached); stop the app if loading failed
tokenizer, model = load_model_and_tokenizer()
if model is None:
    st.stop()
# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(f"{message['content']}\n\n_{message['timestamp']}_")
# Chat input
if prompt := st.chat_input("💭 Ask me anything about coding..."):
    # Add user message
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    st.session_state.messages.append({
        "role": "user",
        "content": prompt,
        "timestamp": timestamp
    })

    # Display user message
    with st.chat_message("user"):
        st.markdown(f"{prompt}\n\n_{timestamp}_")

    # Generate and display response
    with st.chat_message("assistant"):
        # Keep only last message for context to reduce memory usage
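        # Note: the prompt below uses a plain "Human:/Assistant:" format; the tokenizer's
        # built-in chat template (tokenizer.apply_chat_template) would be an alternative.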
        conversation = f"Human: {prompt}\nAssistant:"

        response = generate_response(
            conversation,
            model,
            tokenizer,
            max_length=max_length
        )

        if response:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            st.markdown(f"{response}\n\n_{timestamp}_")

            # Add response to chat history
            st.session_state.messages.append({
                "role": "assistant",
                "content": response,
                "timestamp": timestamp
            })
        else:
            st.error("❌ Failed to generate response. Please try again with a shorter length.")

    # Clear memory after response
    gc.collect()