Athspi committed on
Commit a5ea6e6 · verified · 1 Parent(s): ff3b5db

Create app.py

Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import torch
+ import os
+
+ # --- Configuration (Read from Environment Variables) ---
+
+ # Get the model path from an environment variable. Default to a placeholder
+ # if the environment variable is not set. This is important for deployment.
+ model_path = os.environ.get("MODEL_PATH", "Athspi/Athspiv2new")
+ deepseek_tokenizer_path = os.environ.get("TOKENIZER_PATH", "deepseek-ai/DeepSeek-R1")
+ # Get the Hugging Face token from an environment variable (for gated models).
+ hf_token = os.environ.get("HF_TOKEN", None)  # Default to None if not set
+
+
+ # --- Model and Tokenizer Loading ---
+ # Use try-except for robust error handling
+ try:
+     # Load the model. Assume a merged model.
+     model = AutoModelForCausalLM.from_pretrained(
+         model_path,
+         device_map="auto",  # Use GPU if available, otherwise CPU
+         torch_dtype=torch.float16,  # Use float16 if supported
+         token=hf_token  # Use the token from the environment variable
+     )
+
+     # Load the DeepSeek tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(deepseek_tokenizer_path, token=hf_token)
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "right"
+
+ except OSError as e:
+     print(f"Error loading model or tokenizer: {e}")
+     print("Ensure MODEL_PATH and TOKENIZER_PATH environment variables are set correctly.")
+     print("If using a gated model, ensure HF_TOKEN is set correctly.")
+     exit()  # Terminate the script if loading fails
+
+
+ # --- Chat Function ---
+
+ def chat_with_llm(prompt, history):
+     """Generates a response from the LLM."""
+
+     formatted_prompt = ""
+     if history:
+         for user_msg, ai_msg in history:
+             formatted_prompt += f"{tokenizer.bos_token}{user_msg}{tokenizer.eos_token}"
+             formatted_prompt += f"{ai_msg}{tokenizer.eos_token}"
+
+     formatted_prompt += f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"
+     try:
+         pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
+         result = pipe(
+             formatted_prompt,
+             max_new_tokens=256,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.95,
+             top_k=50,
+             return_full_text=False,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+         response = result[0]['generated_text'].strip()
+         return response
+     except Exception as e:
+         return f"Error during generation: {e}"
+
+
+ # --- Gradio Interface ---
+
+ def predict(message, history):
+     history = history or []
+     response = chat_with_llm(message, history)
+     history.append((message, response))
+     return "", history
+
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot(label="Athspi Chat", height=500, show_label=True, value=[[None, "Hi! I'm Athspi. How can I help you today?"]])
+     msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
+     clear = gr.Button("Clear")
+
+     msg.submit(predict, [msg, chatbot], [msg, chatbot])
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.launch(share=True)
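
For a quick sanity check of the generation path without going through the browser, chat_with_llm can be called directly in the same Python session once the model and tokenizer have loaded (for example, pasted temporarily just above demo.launch). This is a minimal sketch, not part of the committed file; the prompts are placeholders.

# Hypothetical smoke test (not part of app.py): exercise chat_with_llm directly.
first_reply = chat_with_llm("Hello, who are you?", [])
print(first_reply)

# Second turn: pass the first exchange back as history so the
# history-formatting loop in chat_with_llm is exercised as well.
print(chat_with_llm("What can you help me with?", [("Hello, who are you?", first_reply)]))

Note that chat_with_llm constructs a new text-generation pipeline on every call, so each invocation repeats that setup; the responses themselves come from the model loaded once at startup.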