Tim Luka Horstmann committed
Commit 54039cd · 1 Parent(s): b77d28c

Bigger model

Files changed (1):
1. app.py +22 -30
app.py CHANGED
@@ -20,17 +20,18 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
-# Authenticate with Hugging Fac
+# Authenticate with Hugging Face
 hf_token = os.getenv("HF_TOKEN")
 if not hf_token:
     logger.error("HF_TOKEN environment variable not set.")
     raise ValueError("HF_TOKEN not set")
 login(token=hf_token)
 
-# Models
+# Models Configuration
 sentence_transformer_model = "all-MiniLM-L6-v2"
-repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-3B-GGUF"
-filename = "deepcogito_cogito-v1-preview-llama-3B-Q4_K_M.gguf"  # Updated to Cogito Q4_K_M
+# Upgrade to the 8B model and choose Q4_K_M quantization for a good balance of performance and resource usage.
+repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-8B-GGUF"
+filename = "deepcogito_cogito-v1-preview-llama-8B-Q4_K_M.gguf"  # New 8B model with Q4_K_M quantization
 
 # Define FAQs (unchanged)
 faqs = [
@@ -55,17 +56,17 @@ try:
     faiss_index.add(cv_embeddings)
     logger.info("FAISS index built successfully")
 
-    # Load embedding model (unchanged)
+    # Load embedding model
     logger.info("Loading SentenceTransformer model")
     embedder = SentenceTransformer(sentence_transformer_model, device="cpu")
     logger.info("SentenceTransformer model loaded")
 
-    # Compute FAQ embeddings (unchanged)
+    # Compute FAQ embeddings
     faq_questions = [faq["question"] for faq in faqs]
     faq_embeddings = embedder.encode(faq_questions, convert_to_numpy=True).astype("float32")
     faiss.normalize_L2(faq_embeddings)
 
-    # Load Cogito model
+    # Load the 8B Cogito model
     logger.info(f"Loading {filename} model")
     model_path = hf_hub_download(
         repo_id=repo_id,
@@ -73,11 +74,12 @@ try:
         local_dir="/app/cache" if os.getenv("HF_HOME") else None,
         token=hf_token,
     )
+    # Lower n_batch for more frequent token streaming.
     generator = Llama(
         model_path=model_path,
         n_ctx=2048,
         n_threads=2,
-        n_batch=512,
+        n_batch=128,  # Adjusted for lower latency on streaming responses
         n_gpu_layers=0,
         verbose=True,
     )
@@ -97,6 +99,7 @@ def retrieve_context(query, top_k=2):
     except Exception as e:
         logger.error(f"Error in retrieve_context: {str(e)}")
         raise
+
 # Load the full CV at startup
 with open("cv_text.txt", "r", encoding="utf-8") as f:
     full_cv_text = f.read()
@@ -136,7 +139,7 @@ def stream_response(query):
         {"role": "user", "content": query}
     ]
 
-    buffer = ""
+    # Stream tokens immediately as they are generated, avoiding additional buffering.
     for chunk in generator.create_chat_completion(
         messages=messages,
         max_tokens=512,
@@ -145,22 +148,14 @@ def stream_response(query):
         top_p=0.7,
         repeat_penalty=1.2
     ):
-        text = chunk['choices'][0]['delta'].get('content', '')
-        if text:
-            buffer += text
-            if not first_token_logged and time.time() - start_time > 0:
+        token = chunk['choices'][0]['delta'].get('content', '')
+        if token:
+            if not first_token_logged:
                 logger.info(f"First token time: {time.time() - start_time:.2f}s")
                 first_token_logged = True
-
-            # More natural chunking - yield complete sentences when possible
-            if any(buffer.endswith(char) for char in [".", "!", "?"]) or len(buffer) > 30:
-                yield f"data: {buffer}\n\n"
-                buffer = ""
-    if buffer:  # Flush remaining buffer
-        yield f"data: {buffer}\n\n"
+            yield f"data: {token}\n\n"
     yield "data: [DONE]\n\n"
 
-
 class QueryRequest(BaseModel):
     data: list
 
@@ -178,21 +173,18 @@ async def health_check():
 @app.get("/model_info")
 async def model_info():
     return {
-        "model_name": "deepcogito_cogito-v1-preview-llama-3B-GGUF",
-        "model_size": "3B",
-        "quantization": "Q4_K_M",
+        "model_name": "deepcogito_cogito-v1-preview-llama-8B-GGUF",
+        "model_size": "8B",
+        "quantization": "Q4_K_M",
         "embedding_model": sentence_transformer_model,
         "faiss_index_size": len(cv_chunks),
         "faiss_index_dim": cv_embeddings.shape[1],
     }
 
-# Optimize the model loading process
-
-# Use a smaller warm-up query
+# Use a smaller warm-up query to prime the model without extensive delay.
@app.on_event("startup")
 async def warm_up_model():
     logger.info("Warming up the model...")
-    dummy_query = "Hello"  # Shorter query
-    # Just execute once to prime the model without waiting for completion
+    dummy_query = "Hello"
     next(stream_response(dummy_query))
+    logger.info("Model warm-up initiated.")
 