import os

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

print(f"Gradio version: {gr.__version__}")  # Print Gradio version
print(f"Transformers version: {transformers.__version__}")  # Print Transformers version

# --- Configuration (Read from Environment Variables) ---
# Get the model path from an environment variable.
model_path = os.environ.get("MODEL_PATH", "Athspi/Athspiv2new")
deepseek_tokenizer_path = os.environ.get("TOKENIZER_PATH", "deepseek-ai/DeepSeek-R1")
# Get the Hugging Face token from an environment variable (for gated models).
hf_token = os.environ.get("HF_TOKEN", None) # Default to None if not set
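# Example (shell), overriding the defaults above before launch:
#   export MODEL_PATH="Athspi/Athspiv2new"
#   export TOKENIZER_PATH="deepseek-ai/DeepSeek-R1"
#   export HF_TOKEN="hf_..."   # only needed for gated/private repos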

# --- Model and Tokenizer Loading ---
# Use try-except for robust error handling
try:
    # Load the model. Assume a merged model.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",          # Use GPU if available, otherwise CPU
        torch_dtype=torch.float16,  # Use float16 if supported
        token=hf_token,             # Use the token from the environment variable
    )
    # Load the DeepSeek tokenizer
    tokenizer = AutoTokenizer.from_pretrained(deepseek_tokenizer_path, token=hf_token)
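    # Many causal-LM tokenizers ship without a dedicated pad token;
    # falling back to EOS keeps padding well-defined during generation.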
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Build the text-generation pipeline once here; the loaded model already
    # carries its device placement, so no per-call pipeline is needed.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    print("Model and tokenizer loaded successfully!")  # Success message
except OSError as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Ensure MODEL_PATH and TOKENIZER_PATH environment variables are set correctly.")
    print("If using a gated model, ensure HF_TOKEN is set correctly.")
    raise SystemExit(1)  # Terminate the script if loading fails


# --- Chat Function ---
def chat_with_llm(prompt, history):
    """Generates a response from the LLM, handling history correctly."""
    formatted_prompt = ""
    if history:
        print("DEBUG: History variable type:", type(history))
        print("DEBUG: Example history item:", history[0])  # Print first history item
    for item in history:
        if not isinstance(item, dict) or "role" not in item or "content" not in item:  # Check item structure
            print("DEBUG: Invalid history item format:", item)  # Debug invalid item
            continue  # Skip invalid items instead of crashing
        if item["role"] == "user":
            formatted_prompt += f"{tokenizer.bos_token}{item['content']}{tokenizer.eos_token}"
        elif item["role"] == "assistant":
            formatted_prompt += f"{item['content']}{tokenizer.eos_token}"
    formatted_prompt += f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"
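    # Resulting layout: <bos>user<eos>assistant<eos> ... <bos>current prompt<eos>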
    try:
        # Reuse the pipeline built once at load time instead of re-creating it per turn.
        result = pipe(
            formatted_prompt,
            max_new_tokens=256,      # Cap the length of the generated reply
            do_sample=True,          # Sample instead of greedy decoding
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            return_full_text=False,  # Return only the newly generated text
            pad_token_id=tokenizer.eos_token_id,
        )
        response = result[0]["generated_text"].strip()
        return response
    except Exception as e:
        return f"Error during generation: {e}"


# --- Gradio Interface ---
# Use the 'messages' format for chatbot
def predict(message, history):
    history_messages = history or []  # Rename to avoid shadowing
    response = chat_with_llm(message, history_messages)
    history_messages.append({"role": "user", "content": message})
    history_messages.append({"role": "assistant", "content": response})
    return "", history_messages  # Return the updated history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        label="Athspi Chat",
        height=500,
        show_label=True,
        value=[{"role": "assistant", "content": "Hi! I'm Athspi. How can I help you today?"}],
        type="messages",  # Ensure type is "messages"
    )
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear")
    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], [], chatbot, queue=False)
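
# share=True additionally exposes the app on a temporary public *.gradio.live URL;
# set share=False to keep it local-only.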
demo.launch(share=True)