Update app.py
app.py CHANGED
@@ -7,11 +7,11 @@ import psutil
 import multiprocessing
 import time
 import uuid # For generating unique session IDs
+import tiktoken # For estimating token count
 
 app = FastAPI()
 
 # === Model Config ===
-# Corrected REPO_ID to use TheBloke's GGUF version of TinyLlama
 REPO_ID = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
 FILENAME = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" # Q4_K_M is a good balance of size and quality
 MODEL_DIR = "models"
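Note: the added import tiktoken line pulls in a package that ships with neither FastAPI nor llama-cpp-python, so the environment needs it installed before this revision will start (the code later falls back to a character-count estimate if the encoding lookup fails, but the import itself must succeed). A minimal sketch, assuming dependencies are tracked in a requirements.txt that is not part of this diff:

# requirements.txt (assumed file; only the line relevant to this change is shown)
tiktoken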
@@ -38,14 +38,8 @@ else:
     model_path = MODEL_PATH
 
 # === Optimal thread usage ===
-# psutil.cpu_count(logical=True) gives the number of logical cores (threads)
-# psutil.cpu_count(logical=False) gives the number of physical cores
-# For llama.cpp, n_threads often performs best when set to the number of physical cores,
-# or slightly more, but not exceeding logical cores. Experimentation is key.
 logical_cores = psutil.cpu_count(logical=True)
 physical_cores = psutil.cpu_count(logical=False)
-# A common recommendation is to use physical cores or physical_cores * 2
-# Let's try physical_cores for a start, or a fixed value if physical_cores is too low.
 recommended_threads = max(1, physical_cores) # Ensure at least 1 thread
 
 print(f"Detected physical cores: {physical_cores}, logical cores: {logical_cores}")
@@ -55,10 +49,10 @@ print(f"Using n_threads: {recommended_threads}")
 try:
     llm = Llama(
         model_path=model_path,
-        n_ctx=1024, #
+        n_ctx=1024, # Context window size for the model
         n_threads=recommended_threads,
-        use_mlock=True, # Lock model in RAM for faster access
-        n_gpu_layers=0, # CPU only
+        use_mlock=True, # Lock model in RAM for faster access
+        n_gpu_layers=0, # CPU only
         chat_format="chatml", # TinyLlama Chat uses ChatML format
         verbose=False
     )
@@ -67,11 +61,65 @@ except Exception as e:
     print(f"❌ Error loading Llama model: {e}")
     exit(1)
 
+# Initialize tiktoken encoder for token counting (approximate for GGUF models, but good enough)
+# For TinyLlama, we'll use a generic encoder or one that's close enough.
+# 'cl100k_base' is common for OpenAI models, but a good approximation for many others.
+# For more precise counts for GGUF, you might need to use the model's tokenizer if available
+# or rely on llama.cpp's internal tokenization (which is harder to access directly).
+# For simplicity and general estimation, cl100k_base is often used.
+try:
+    encoding = tiktoken.get_encoding("cl100k_base")
+except Exception:
+    print("⚠️ Could not load tiktoken 'cl100k_base' encoding. Using basic len() for token estimation.")
+    encoding = None
+
 # === Global dictionary to store chat histories per session ===
-# In a production environment, this should be replaced with a persistent storage
-# like Redis, a database, or a dedicated session management system.
 chat_histories = {}
 
+# === Context Truncation Settings ===
+# Max tokens for the entire conversation history (input to the model)
+# This should be less than n_ctx to leave room for the new prompt and generated response.
+MAX_CONTEXT_TOKENS = 800 # Keep total input context below this, leaving 224 tokens for new prompt + response
+
+def count_tokens_in_message(message):
+    """Estimates tokens in a single message using tiktoken or simple char count."""
+    if encoding:
+        return len(encoding.encode(message.get("content", "")))
+    else:
+        # Fallback for when tiktoken isn't available or for simple estimation
+        return len(message.get("content", "")) // 4 # Rough estimate: 1 token ~ 4 characters
+
+def get_message_token_length(messages):
+    """Calculates total tokens for a list of messages."""
+    total_tokens = 0
+    for message in messages:
+        total_tokens += count_tokens_in_message(message)
+    return total_tokens
+
+def truncate_history(history, max_tokens):
+    """
+    Truncates the chat history to fit within max_tokens.
+    Keeps the system message and recent messages.
+    """
+    if not history:
+        return []
+
+    # Always keep the system message
+    system_message = history[0]
+    truncated_history = [system_message]
+    current_tokens = count_tokens_in_message(system_message)
+
+    # Add messages from most recent, until max_tokens is reached
+    for message in reversed(history[1:]): # Iterate from second-to-last to first user/assistant message
+        message_tokens = count_tokens_in_message(message)
+        if current_tokens + message_tokens <= max_tokens:
+            truncated_history.insert(1, message) # Insert after system message
+            current_tokens += message_tokens
+        else:
+            break # Stop adding if next message exceeds limit
+
+    return truncated_history
+
 @app.get("/")
 def root():
     return {"message": "✅ Data Analysis AI API is live and optimized!"}
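The helpers added above implement a sliding-window history: the system message is always kept, and earlier turns are re-added from the most recent backwards until the token budget runs out. MAX_CONTEXT_TOKENS = 800 plus the roughly 224 tokens left for the new prompt and reply adds up to the n_ctx=1024 set when loading the model. Below is a minimal standalone sketch of that behaviour, using the same 1-token-per-4-characters fallback as count_tokens_in_message; the sample messages and names are illustrative, not part of app.py.

# Simplified, self-contained sketch of the truncation strategy above.
# Uses the "1 token ~ 4 characters" fallback estimate; sample data is made up.
def estimate_tokens(message):
    return len(message.get("content", "")) // 4

def truncate(history, max_tokens):
    if not history:
        return []
    kept = [history[0]]                    # the system message is always kept
    used = estimate_tokens(history[0])
    for message in reversed(history[1:]):  # walk from the most recent turn backwards
        cost = estimate_tokens(message)
        if used + cost > max_tokens:
            break                          # stop once the budget is exhausted
        kept.insert(1, message)            # insert after the system message, preserving order
        used += cost
    return kept

history = [
    {"role": "system", "content": "You are a data analysis assistant."},
    {"role": "user", "content": "Summarise this CSV... " * 50},       # old, large turn
    {"role": "assistant", "content": "Here is the summary... " * 50},
    {"role": "user", "content": "Now plot the trend."},               # recent, small turn
]
print([m["role"] for m in truncate(history, max_tokens=60)])
# -> ['system', 'user'] : the recent small turn fits, the older large turns do not

Because the loop breaks at the first turn that no longer fits, everything older than that turn is dropped as well, even if some of it would still have fit on its own.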
@@ -159,14 +207,31 @@ async def generate(request: Request):
 
     print(f"🧾 Prompt received for session {session_id}: {prompt}")
 
-    # Add the user's new message to
-    chat_histories[session_id]
+    # Add the user's new message to a temporary list to check total length
+    current_messages = list(chat_histories[session_id]) # Create a copy
+    current_messages.append({"role": "user", "content": prompt})
+
+    # Truncate history if it exceeds the max context tokens
+    # We subtract a buffer for the new prompt itself and the expected response
+    # A rough estimate for prompt + response: 100 tokens (prompt) + 200 tokens (response) = 300 tokens
+    # So, MAX_CONTEXT_TOKENS - 300 for the actual history
+    effective_max_history_tokens = MAX_CONTEXT_TOKENS - count_tokens_in_message({"role": "user", "content": prompt}) - 200 # Buffer for response
+
+    if get_message_token_length(current_messages) > MAX_CONTEXT_TOKENS:
+        print(f"✂️ Truncating history for session {session_id}. Current tokens: {get_message_token_length(current_messages)}")
+        chat_histories[session_id] = truncate_history(current_messages, effective_max_history_tokens)
+        # Re-add the current user prompt after truncation
+        if chat_histories[session_id][-1]["role"] != "user" or chat_histories[session_id][-1]["content"] != prompt:
+            chat_histories[session_id].append({"role": "user", "content": prompt})
+        print(f"✅ History truncated. New tokens: {get_message_token_length(chat_histories[session_id])}")
+    else:
+        chat_histories[session_id] = current_messages # If not truncated, just update with the new message
 
     try:
-        # Pass the
+        # Pass the (potentially truncated) chat history for context
        response = llm.create_chat_completion(
             messages=chat_histories[session_id],
-            max_tokens=
+            max_tokens=256, # Further limit response length for faster generation
             temperature=0.7, # Adjust temperature for creativity vs. coherence (0.0-1.0)
             stop=["</s>"] # Stop sequence for TinyLlama Chat
         )
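For concreteness, a worked example of the budget computed in the handler above, with an assumed prompt of 50 tokens (the 100-token figure in the comment is only an estimate; the code measures the actual prompt):

# Illustrative numbers only; the real values depend on the incoming prompt.
MAX_CONTEXT_TOKENS = 800
prompt_tokens = 50       # what count_tokens_in_message({"role": "user", "content": prompt}) might return
response_buffer = 200    # tokens reserved for the model's reply
effective_max_history_tokens = MAX_CONTEXT_TOKENS - prompt_tokens - response_buffer
print(effective_max_history_tokens)  # 550 tokens available for the retained history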
@@ -178,7 +243,8 @@ async def generate(request: Request):
 
         return {
             "response": ai_response_content,
-            "session_id": session_id # Return the session_id so the client can use it for subsequent requests
+            "session_id": session_id, # Return the session_id so the client can use it for subsequent requests
+            "current_context_tokens": get_message_token_length(chat_histories[session_id])
         }
     except Exception as e:
         print(f"❌ Error during generation for session {session_id}: {e}")
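For reference, a possible client loop against this endpoint. The base URL, port and request-body field names are assumptions (the route decorator and request parsing are outside the hunks shown); the response fields response, session_id and current_context_tokens are the ones returned above.

# Hypothetical client; adjust the base URL and payload fields to the actual route.
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port
session_id = None

for prompt in ["Summarise the dataset.", "Which column has the most missing values?"]:
    payload = {"prompt": prompt}
    if session_id:
        payload["session_id"] = session_id  # reuse the id so the server keeps the history
    reply = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120).json()
    session_id = reply["session_id"]        # returned by the server on every call
    print(reply["response"])
    print("context tokens now:", reply["current_context_tokens"])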