# Athspiv2 / app.py
import os

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

print(f"Gradio version: {gr.__version__}")                    # Print Gradio version
print(f"Transformers version: {transformers.__version__}")    # Print Transformers version
# --- Configuration (Read from Environment Variables) ---
# Get the model path from an environment variable.
model_path = os.environ.get("MODEL_PATH", "Athspi/Athspiv2new")
deepseek_tokenizer_path = os.environ.get("TOKENIZER_PATH", "deepseek-ai/DeepSeek-R1")
# Get the Hugging Face token from an environment variable (for gated models).
hf_token = os.environ.get("HF_TOKEN", None) # Default to None if not set
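
# Example (hypothetical values) of setting these variables for the Space or a local run:
#   export MODEL_PATH="Athspi/Athspiv2new"
#   export TOKENIZER_PATH="deepseek-ai/DeepSeek-R1"
#   export HF_TOKEN="hf_..."   # only needed for gated or private repositories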

# --- Model and Tokenizer Loading ---
# Use try/except so a bad path or missing token fails with a clear message.
try:
    # Load the model (assumed to be a fully merged model, not a PEFT adapter).
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",          # Use GPU if available, otherwise CPU
        torch_dtype=torch.float16,  # Use float16 where supported
        token=hf_token,             # Token from the environment (needed for gated models)
    )
    # Load the DeepSeek tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(deepseek_tokenizer_path, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    print("Model and tokenizer loaded successfully!")
except OSError as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Ensure MODEL_PATH and TOKENIZER_PATH environment variables are set correctly.")
    print("If using a gated model, ensure HF_TOKEN is set correctly.")
    raise SystemExit(1)  # Terminate the script if loading fails
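
# Optional sketch: the DeepSeek-R1 tokenizer is assumed to ship a chat template. If it
# does, tokenizer.apply_chat_template could replace the manual bos/eos prompt building
# used in chat_with_llm below. This block only logs what the template would produce.
if getattr(tokenizer, "chat_template", None):
    _example_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello!"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    print("DEBUG: chat-template-rendered example prompt:", _example_prompt)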

# --- Chat Function ---
def chat_with_llm(prompt, history):
    """Generate a response from the LLM, folding the chat history into the prompt."""
    formatted_prompt = ""
    if history:
        print("DEBUG: History type:", type(history), "- first item:", history[0])
    for item in history:
        # Each history item should be a dict with "role" and "content" keys.
        if not isinstance(item, dict) or "role" not in item or "content" not in item:
            print("DEBUG: Invalid history item format:", item)
            continue  # Skip malformed items instead of crashing
        if item["role"] == "user":
            formatted_prompt += f"{tokenizer.bos_token}{item['content']}{tokenizer.eos_token}"
        elif item["role"] == "assistant":
            formatted_prompt += f"{item['content']}{tokenizer.eos_token}"
    formatted_prompt += f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"
    try:
        # The model was already placed on devices by device_map="auto" above, so the
        # pipeline needs no device_map of its own. (Building the pipeline once at module
        # level would also avoid re-creating it on every turn.)
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        result = pipe(
            formatted_prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = result[0]["generated_text"].strip()
        return response
    except Exception as e:
        return f"Error during generation: {e}"

# --- Gradio Interface ---
# The chatbot uses the "messages" format: a list of {"role": ..., "content": ...} dicts.
def predict(message, history):
    history_messages = history or []  # Guard against a None history on the first turn
    response = chat_with_llm(message, history_messages)
    history_messages.append({"role": "user", "content": message})
    history_messages.append({"role": "assistant", "content": response})
    return "", history_messages  # Clear the textbox and return the updated history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        label="Athspi Chat",
        height=500,
        show_label=True,
        value=[{"role": "assistant", "content": "Hi! I'm Athspi. How can I help you today?"}],
        type="messages",  # Message dicts, matching what predict() returns
    )
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear")

    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], [], chatbot, queue=False)

demo.launch(share=True)
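# Note: on Hugging Face Spaces the app already gets a public URL, so share=True mainly
# matters when running this script locally.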