Priyanshukr-1 committed
Commit 01e79df · verified · 1 Parent(s): d9ba98f

Update app.py

Files changed (1):
  1. app.py +83 -17
app.py CHANGED
@@ -7,11 +7,11 @@ import psutil
 import multiprocessing
 import time
 import uuid  # For generating unique session IDs
+import tiktoken  # For estimating token count
 
 app = FastAPI()
 
 # === Model Config ===
-# Corrected REPO_ID to use TheBloke's GGUF version of TinyLlama
 REPO_ID = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
 FILENAME = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"  # Q4_K_M is a good balance of size and quality
 MODEL_DIR = "models"
@@ -38,14 +38,8 @@ else:
     model_path = MODEL_PATH
 
 # === Optimal thread usage ===
-# psutil.cpu_count(logical=True) gives the number of logical cores (threads)
-# psutil.cpu_count(logical=False) gives the number of physical cores
-# For llama.cpp, n_threads often performs best when set to the number of physical cores,
-# or slightly more, but not exceeding logical cores. Experimentation is key.
 logical_cores = psutil.cpu_count(logical=True)
 physical_cores = psutil.cpu_count(logical=False)
-# A common recommendation is to use physical cores or physical_cores * 2
-# Let's try physical_cores for a start, or a fixed value if physical_cores is too low.
 recommended_threads = max(1, physical_cores)  # Ensure at least 1 thread
 
 print(f"Detected physical cores: {physical_cores}, logical cores: {logical_cores}")
@@ -55,10 +49,10 @@ print(f"Using n_threads: {recommended_threads}")
 try:
     llm = Llama(
         model_path=model_path,
-        n_ctx=1024,  # Reduced context for TinyLlama, can increase if memory allows and context is critical
+        n_ctx=1024,  # Context window size for the model
         n_threads=recommended_threads,
-        use_mlock=True,  # Lock model in RAM for faster access (good for stability on CPU)
-        n_gpu_layers=0,  # CPU only, keep at 0 for Hugging Face free tier
+        use_mlock=True,  # Lock model in RAM for faster access
+        n_gpu_layers=0,  # CPU only
         chat_format="chatml",  # TinyLlama Chat uses ChatML format
         verbose=False
     )
@@ -67,11 +61,65 @@ except Exception as e:
     print(f"❌ Error loading Llama model: {e}")
     exit(1)
 
+# Initialize tiktoken encoder for token counting (approximate for GGUF models, but good enough)
+# For TinyLlama, we'll use a generic encoder or one that's close enough.
+# 'cl100k_base' is common for OpenAI models, but a good approximation for many others.
+# For more precise counts for GGUF, you might need to use the model's tokenizer if available
+# or rely on llama.cpp's internal tokenization (which is harder to access directly).
+# For simplicity and general estimation, cl100k_base is often used.
+try:
+    encoding = tiktoken.get_encoding("cl100k_base")
+except Exception:
+    print("⚠️ Could not load tiktoken 'cl100k_base' encoding. Using basic len() for token estimation.")
+    encoding = None
+
 # === Global dictionary to store chat histories per session ===
-# In a production environment, this should be replaced with a persistent storage
-# like Redis, a database, or a dedicated session management system.
 chat_histories = {}
 
+# === Context Truncation Settings ===
+# Max tokens for the entire conversation history (input to the model).
+# This should be less than n_ctx to leave room for the new prompt and generated response.
+MAX_CONTEXT_TOKENS = 800  # Keep total input context below this, leaving 224 tokens for new prompt + response
+
+def count_tokens_in_message(message):
+    """Estimates tokens in a single message using tiktoken or simple char count."""
+    if encoding:
+        return len(encoding.encode(message.get("content", "")))
+    else:
+        # Fallback for when tiktoken isn't available or for simple estimation
+        return len(message.get("content", "")) // 4  # Rough estimate: 1 token ~ 4 characters
+
+def get_message_token_length(messages):
+    """Calculates total tokens for a list of messages."""
+    total_tokens = 0
+    for message in messages:
+        total_tokens += count_tokens_in_message(message)
+    return total_tokens
+
+def truncate_history(history, max_tokens):
+    """
+    Truncates the chat history to fit within max_tokens.
+    Keeps the system message and recent messages.
+    """
+    if not history:
+        return []
+
+    # Always keep the system message
+    system_message = history[0]
+    truncated_history = [system_message]
+    current_tokens = count_tokens_in_message(system_message)
+
+    # Add messages from most recent, until max_tokens is reached
+    for message in reversed(history[1:]):  # Iterate from the most recent message backwards
+        message_tokens = count_tokens_in_message(message)
+        if current_tokens + message_tokens <= max_tokens:
+            truncated_history.insert(1, message)  # Insert after system message
+            current_tokens += message_tokens
+        else:
+            break  # Stop adding if next message exceeds limit
+
+    return truncated_history
+
 @app.get("/")
 def root():
     return {"message": "✅ Data Analysis AI API is live and optimized!"}
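To see the truncation strategy from this hunk in isolation, here is a minimal standalone sketch. It uses the character-based fallback counter (1 token ≈ 4 characters) rather than tiktoken, and the helper names (rough_tokens, truncate) and the sample history are illustrative only, not part of the commit.

def rough_tokens(message):
    # Same fallback heuristic as count_tokens_in_message: roughly 4 characters per token
    return len(message.get("content", "")) // 4

def truncate(history, max_tokens):
    # Keep the system message, then add messages newest-first until the budget is spent
    if not history:
        return []
    kept = [history[0]]
    used = rough_tokens(history[0])
    for message in reversed(history[1:]):
        cost = rough_tokens(message)
        if used + cost > max_tokens:
            break
        kept.insert(1, message)  # insert just after the system message, preserving order
        used += cost
    return kept

history = [
    {"role": "system", "content": "You are a data analysis assistant."},
    {"role": "user", "content": "Please summarise last quarter's sales figures in detail. " * 5},
    {"role": "assistant", "content": "Here is a very long summary of the sales figures... " * 20},
    {"role": "user", "content": "Now plot revenue by region."},
]
print([m["role"] for m in truncate(history, max_tokens=60)])
# -> ['system', 'user']: the long middle turns are dropped, the newest prompt survives

Because iteration stops at the first message that does not fit, one very long middle turn also evicts everything older than it, which matches the behaviour of truncate_history above.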
@@ -159,14 +207,31 @@ async def generate(request: Request):
 
     print(f"🧾 Prompt received for session {session_id}: {prompt}")
 
-    # Add the user's new message to the history for this session
-    chat_histories[session_id].append({"role": "user", "content": prompt})
+    # Add the user's new message to a temporary list to check total length
+    current_messages = list(chat_histories[session_id])  # Create a copy
+    current_messages.append({"role": "user", "content": prompt})
+
+    # Truncate history if it exceeds the max context tokens.
+    # We subtract a buffer for the new prompt itself and the expected response:
+    # the new prompt's actual token count plus roughly 200 tokens reserved for the reply,
+    # so the history gets MAX_CONTEXT_TOKENS minus that buffer.
+    effective_max_history_tokens = MAX_CONTEXT_TOKENS - count_tokens_in_message({"role": "user", "content": prompt}) - 200  # Buffer for response
+
+    if get_message_token_length(current_messages) > MAX_CONTEXT_TOKENS:
+        print(f"✂️ Truncating history for session {session_id}. Current tokens: {get_message_token_length(current_messages)}")
+        chat_histories[session_id] = truncate_history(current_messages, effective_max_history_tokens)
+        # Re-add the current user prompt after truncation
+        if chat_histories[session_id][-1]["role"] != "user" or chat_histories[session_id][-1]["content"] != prompt:
+            chat_histories[session_id].append({"role": "user", "content": prompt})
+        print(f"✅ History truncated. New tokens: {get_message_token_length(chat_histories[session_id])}")
+    else:
+        chat_histories[session_id] = current_messages  # If not truncated, just update with the new message
 
     try:
-        # Pass the entire chat history for context
+        # Pass the (potentially truncated) chat history for context
         response = llm.create_chat_completion(
             messages=chat_histories[session_id],
-            max_tokens=512,  # Limit response length for faster generation
+            max_tokens=256,  # Further limit response length for faster generation
             temperature=0.7,  # Adjust temperature for creativity vs. coherence (0.0-1.0)
             stop=["</s>"]  # Stop sequence for TinyLlama Chat
         )
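The comments in this commit note that tiktoken's cl100k_base is only an approximation for a GGUF model. If exact counts ever matter, llama-cpp-python exposes the model's own tokenizer on the loaded Llama object; the sketch below shows one way to use it (count_exact_tokens is a hypothetical helper, not part of the commit, and it assumes the llm instance created earlier in app.py).

def count_exact_tokens(llm, text: str) -> int:
    # Llama.tokenize takes UTF-8 bytes; add_bos=False so only the text itself is counted
    return len(llm.tokenize(text.encode("utf-8"), add_bos=False))

# Example: count_exact_tokens(llm, "Summarise last quarter's sales figures")
# could replace the tiktoken/char-count estimate inside count_tokens_in_message.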
@@ -178,7 +243,8 @@ async def generate(request: Request):
 
         return {
             "response": ai_response_content,
-            "session_id": session_id  # Return the session_id so the client can use it for subsequent requests
+            "session_id": session_id,  # Return the session_id so the client can use it for subsequent requests
+            "current_context_tokens": get_message_token_length(chat_histories[session_id])
         }
     except Exception as e:
         print(f"❌ Error during generation for session {session_id}: {e}")
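A sketch of how a client might exercise the updated endpoint and the new current_context_tokens field. It assumes the handler is mounted at /generate, reads prompt and an optional session_id from the JSON body, and that the app is reachable at BASE_URL; none of those details appear in this diff, so treat them as assumptions.

import requests

BASE_URL = "http://localhost:7860"  # assumed local/Space URL, adjust as needed

# First turn: no session_id yet, so the server creates one and returns it.
first = requests.post(f"{BASE_URL}/generate",
                      json={"prompt": "Hello! What can you help me with?"}).json()
session_id = first["session_id"]
print(first["response"])
print("context tokens:", first["current_context_tokens"])

# Later turns: send the same session_id so the (possibly truncated) history is reused.
followup = requests.post(f"{BASE_URL}/generate",
                         json={"prompt": "Summarise our conversation so far.",
                               "session_id": session_id}).json()
print("context tokens:", followup["current_context_tokens"])  # should stay near or below MAX_CONTEXT_TOKENS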
 