import os

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

print(f"Gradio version: {gr.__version__}")
print(f"Transformers version: {transformers.__version__}")
# Model and tokenizer locations are configurable via environment variables;
# HF_TOKEN is only needed for gated or private model repositories.
model_path = os.environ.get("MODEL_PATH", "Athspi/Athspiv2new")
deepseek_tokenizer_path = os.environ.get("TOKENIZER_PATH", "deepseek-ai/DeepSeek-R1")

hf_token = os.environ.get("HF_TOKEN", None)
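# Load the model and tokenizer once at startup; any failure here is fatal.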
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16,
        token=hf_token,
    )

    tokenizer = AutoTokenizer.from_pretrained(deepseek_tokenizer_path, token=hf_token)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print("Model and tokenizer loaded successfully!")
except OSError as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Ensure MODEL_PATH and TOKENIZER_PATH environment variables are set correctly.")
    print("If using a gated model, ensure HF_TOKEN is set correctly.")
    exit(1)

# Build the text-generation pipeline once at startup. The model is already
# dispatched by device_map="auto", so the same pipeline is reused for every request.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)


def chat_with_llm(prompt, history):
    """Generates a response from the LLM, handling history correctly."""
    formatted_prompt = ""

    if history:
        print("DEBUG: History variable type:", type(history))
        print("DEBUG: Example history item:", history[0])

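    # Rebuild the conversation as a single prompt string. History items arrive in
    # Gradio's "messages" format: dicts with "role" and "content" keys.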
    for item in history:
        if not isinstance(item, dict) or "role" not in item or "content" not in item:
            print("DEBUG: Invalid history item format:", item)
            continue

        if item["role"] == "user":
            formatted_prompt += f"{tokenizer.bos_token}{item['content']}{tokenizer.eos_token}"
        elif item["role"] == "assistant":
            formatted_prompt += f"{item['content']}{tokenizer.eos_token}"

    formatted_prompt += f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"

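    # Sampling settings (temperature, top_p, top_k) trade determinism for variety;
    # max_new_tokens caps the length of each reply.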
    try:
        result = pipe(
            formatted_prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = result[0]['generated_text'].strip()
        return response
    except Exception as e:
        return f"Error during generation: {e}"


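# Gradio submit callback: returns an empty string to clear the textbox and the
# updated message history to refresh the chatbot.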
def predict(message, history):
    history_messages = history or []
    response = chat_with_llm(message, history_messages)
    history_messages.append({"role": "user", "content": message})
    history_messages.append({"role": "assistant", "content": response})
    return "", history_messages


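# Assemble the chat UI: a messages-style chatbot, an input textbox, and a clear button.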
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        label="Athspi Chat",
        height=500,
        show_label=True,
        value=[{"role": "assistant", "content": "Hi! I'm Athspi. How can I help you today?"}],
        type="messages",
    )
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear")

    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], [], chatbot, queue=False)

demo.launch(share=True)