File size: 4,323 Bytes
c687e78
a5ea6e6
 
 
6421d05
 
 
 
 
 
a5ea6e6
 
 
e83210b
a5ea6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6421d05
 
a5ea6e6
 
 
 
 
 
 
 
 
33d5962
a5ea6e6
 
6421d05
 
 
 
 
 
 
33d5962
6421d05
 
 
 
33d5962
 
 
 
a5ea6e6
 
33d5962
a5ea6e6
33d5962
 
a5ea6e6
 
 
 
 
 
 
 
 
33d5962
 
a5ea6e6
 
 
33d5962
 
a5ea6e6
e83210b
a5ea6e6
6421d05
 
 
 
 
 
a5ea6e6
 
e83210b
 
6421d05
a5ea6e6
 
 
6421d05
a5ea6e6
33d5962
a5ea6e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import os
import gradio as gr
import transformers

print(f"Gradio version: {gr.__version__}") # Print Gradio version
print(f"Transformers version: {transformers.__version__}") # Print Transformers version


# --- Configuration (Read from Environment Variables) ---

# Get the model path from an environment variable.
model_path = os.environ.get("MODEL_PATH", "Athspi/Athspiv2new")
deepseek_tokenizer_path = os.environ.get("TOKENIZER_PATH", "deepseek-ai/DeepSeek-R1")
# Get the Hugging Face token from an environment variable (for gated models).
hf_token = os.environ.get("HF_TOKEN", None)  # Default to None if not set


# --- Model and Tokenizer Loading ---
# Use try-except for robust error handling
try:
    # Load the model.  Assume a merged model.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",  # Use GPU if available, otherwise CPU
        torch_dtype=torch.float16,  # Use float16 if supported
        token=hf_token  # Use the token from the environment variable
    )

     # Load the DeepSeek tokenizer
    tokenizer = AutoTokenizer.from_pretrained(deepseek_tokenizer_path, token=hf_token)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print("Model and tokenizer loaded successfully!") # Success message

except OSError as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Ensure MODEL_PATH and TOKENIZER_PATH environment variables are set correctly.")
    print("If using a gated model, ensure HF_TOKEN is set correctly.")
    exit()  # Terminate the script if loading fails


# --- Chat Function ---
def chat_with_llm(prompt, history):
    """Generates a response from the LLM, handling history correctly."""
    formatted_prompt = ""
    if history:
        print("DEBUG: History variable type:", type(history))
        if history:
            print("DEBUG: Example history item:", history[0]) # Print first history item
        else:
            print("DEBUG: History is empty but should not be in chat turn > 1")


        for item in history:
            if not isinstance(item, dict) or "role" not in item or "content" not in item: # Check item structure
                print("DEBUG: Invalid history item format:", item) # Debug invalid item
                continue # Skip invalid items instead of crashing

            if item["role"] == "user":
                formatted_prompt += f"{tokenizer.bos_token}{item['content']}{tokenizer.eos_token}"
            elif item["role"] == "assistant":
                formatted_prompt += f"{item['content']}{tokenizer.eos_token}"

    formatted_prompt += f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"

    try:
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
        result = pipe(
            formatted_prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            top_k=50,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = result[0]['generated_text'].strip()
        return response
    except Exception as e:
        return f"Error during generation: {e}"



# --- Gradio Interface ---
# Use the 'messages' format for chatbot
def predict(message, history):
    history_messages = history or [] # Rename to avoid shadowing
    response = chat_with_llm(message, history_messages)
    history_messages.append({"role": "user", "content": message})
    history_messages.append({"role": "assistant", "content": response})
    return "", history_messages  # Return the updated history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Athspi Chat", height=500, show_label=True,
                         value=[{"role": "assistant", "content": "Hi! I'm Athspi. How can I help you today?"}],
                         type="messages")  # Ensure type is "messages"
    msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    clear = gr.Button("Clear")


    msg.submit(predict, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], [], chatbot, queue=False)

demo.launch(share=True)