# Athspiv2 / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import os
# --- Configuration (Read from Environment Variables) ---
# Get the model path from an environment variable. Default to a placeholder
# if the environment variable is not set. This is important for deployment.
model_path = os.environ.get("MODEL_PATH", "Athspi/Athspiv2new")
deepseek_tokenizer_path = os.environ.get("TOKENIZER_PATH", "deepseek-ai/DeepSeek-R1")
# Get the Hugging Face token from an environment variable (for gated models).
hf_token = os.environ.get("HF_TOKEN", None) # Default to None if not set
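
# Illustrative only: these variables can be set as Space secrets/variables or exported in the
# shell before launching. The values below are placeholders, not requirements:
#   export MODEL_PATH="Athspi/Athspiv2new"
#   export TOKENIZER_PATH="deepseek-ai/DeepSeek-R1"
#   export HF_TOKEN="hf_..."   # only needed for gated or private repositories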

# --- Model and Tokenizer Loading ---
# Use try/except for robust error handling.
try:
    # Load the model. Assumes a merged checkpoint (no separate adapter weights).
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",          # Use GPU(s) if available, otherwise CPU
        torch_dtype=torch.float16,  # Use float16 if supported
        token=hf_token,             # Token from the environment variable (for gated repos)
    )

    # Load the DeepSeek tokenizer and make sure it has a pad token.
    tokenizer = AutoTokenizer.from_pretrained(deepseek_tokenizer_path, token=hf_token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
except OSError as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Ensure MODEL_PATH and TOKENIZER_PATH environment variables are set correctly.")
    print("If using a gated model, ensure HF_TOKEN is set correctly.")
    raise SystemExit(1)  # Terminate the script if loading fails

# --- Chat Function ---
def chat_with_llm(prompt, history):
    """Generates a response from the LLM, using the tokenizer's BOS/EOS tokens as turn delimiters."""
    formatted_prompt = ""
    if history:
        for user_msg, ai_msg in history:
            # Skip placeholder turns (e.g. the initial greeting has no user message).
            if user_msg is not None:
                formatted_prompt += f"{tokenizer.bos_token}{user_msg}{tokenizer.eos_token}"
            if ai_msg is not None:
                formatted_prompt += f"{ai_msg}{tokenizer.eos_token}"
    formatted_prompt += f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"

    try:
        # The model is already placed on the right device(s) by device_map="auto" above,
        # so the pipeline simply wraps the loaded model and tokenizer.
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        result = pipe(
            formatted_prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = result[0]["generated_text"].strip()
        return response
    except Exception as e:
        return f"Error during generation: {e}"

# --- Gradio Interface ---
def predict(message, history):
    history = history or []
    response = chat_with_llm(message, history)
    history.append((message, response))
    return "", history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Athspi Chat", height=500, show_label=True, value=[[None, "Hi! I'm Athspi. How can I help you today?"]])
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear")

    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch(share=True)
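
# To run locally, something like the following should suffice (package list is an assumption;
# versions are not pinned here):
#   pip install gradio transformers torch accelerate
#   python app.py
# `accelerate` is needed for device_map="auto". On Hugging Face Spaces the app is launched
# automatically and share=True typically has no effect there.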