Commit 5d7dc12
Parent(s): cdf27f4

added reasoning support

app/routes/chat_api.py  CHANGED  (+85 -31)
@@ -1,4 +1,5 @@
 import asyncio
+import base64 # Ensure base64 is imported
 import json # Needed for error streaming
 import random
 from fastapi import APIRouter, Depends, Request
@@ -7,6 +8,7 @@ from typing import List, Dict, Any
 
 # Google and OpenAI specific imports
 from google.genai import types
+from google.genai.types import HttpOptions # Added for compute_tokens
 from google import genai
 import openai
 from credentials_manager import _refresh_auth
@@ -229,7 +231,6 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
             async for chunk in stream_response:
                 try:
                     chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)
-                    print(chunk_as_dict)
 
                     # Safely navigate and check for thought flag
                     choices = chunk_as_dict.get('choices')
@@ -253,6 +254,7 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
                                 del delta['extra_content']
 
                     # Yield the (potentially modified) dictionary as JSON
+                    print(chunk_as_dict)
                     yield f"data: {json.dumps(chunk_as_dict)}\n\n"
 
                 except Exception as chunk_processing_error: # Catch errors from dict manipulation or json.dumps
@@ -290,39 +292,91 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
                    extra_body=openai_extra_body
                )
                response_dict = response.model_dump(exclude_unset=True, exclude_none=True)
-
-               # Process reasoning_tokens for non-streaming response
+
                try:
                    usage = response_dict.get('usage')
+                   vertex_completion_tokens = 0
+
                    if usage and isinstance(usage, dict):
-                       completion_details = usage.get('completion_tokens_details')
-                       if completion_details and isinstance(completion_details, dict):
-                           num_reasoning_tokens = completion_details.get('reasoning_tokens')
-
-                           if isinstance(num_reasoning_tokens, int) and num_reasoning_tokens > 0:
-                               choices = response_dict.get('choices')
-                               if choices and isinstance(choices, list) and len(choices) > 0:
-                                   # Ensure choices[0] and message are dicts, model_dump makes them so
-                                   message_dict = choices[0].get('message')
-                                   if message_dict and isinstance(message_dict, dict):
-                                       full_content = message_dict.get('content')
-                                       if isinstance(full_content, str): # Ensure content is a string
-                                           reasoning_text = full_content[:num_reasoning_tokens]
-                                           actual_content = full_content[num_reasoning_tokens:]
-
-                                           message_dict['reasoning_content'] = reasoning_text
-                                           message_dict['content'] = actual_content
-
-                                           # Clean up Vertex-specific field
-                                           del completion_details['reasoning_tokens']
-                                           if not completion_details: # If dict is now empty
-                                               del usage['completion_tokens_details']
-                                           if not usage: # If dict is now empty
-                                               del response_dict['usage']
-               except Exception as e_non_stream_reasoning:
-                   print(f"WARNING: Could not process non-streaming reasoning tokens for model {request.model}: {e_non_stream_reasoning}. Response will be returned as is from Vertex.")
-                   # Fallthrough to return response_dict as is if processing fails
+                       vertex_completion_tokens = usage.get('completion_tokens')
 
+                   choices = response_dict.get('choices')
+                   if choices and isinstance(choices, list) and len(choices) > 0:
+                       message_dict = choices[0].get('message')
+                       if message_dict and isinstance(message_dict, dict):
+                           # Always remove extra_content from the message if it exists, before any splitting
+                           if 'extra_content' in message_dict:
+                               del message_dict['extra_content']
+                               print("DEBUG: Removed 'extra_content' from response message.")
+
+                           if isinstance(vertex_completion_tokens, int) and vertex_completion_tokens > 0:
+                               full_content = message_dict.get('content')
+                               if isinstance(full_content, str) and full_content:
+
+                                   def _get_token_strings_and_split_texts_sync(creds, proj_id, loc, model_id_for_tokenizer, text_to_tokenize, num_completion_tokens_from_usage):
+                                       sync_tokenizer_client = genai.Client(
+                                           vertexai=True, credentials=creds, project=proj_id, location=loc,
+                                           http_options=HttpOptions(api_version="v1")
+                                       )
+                                       if not text_to_tokenize: return "", text_to_tokenize, [] # No reasoning, original content, empty token list
+
+                                       token_compute_response = sync_tokenizer_client.models.compute_tokens(
+                                           model=model_id_for_tokenizer, contents=text_to_tokenize
+                                       )
+
+                                       all_final_token_strings = []
+                                       if token_compute_response.tokens_info:
+                                           for token_info_item in token_compute_response.tokens_info:
+                                               for api_token_bytes in token_info_item.tokens:
+                                                   intermediate_str = api_token_bytes.decode('utf-8', errors='replace')
+                                                   final_token_text = ""
+                                                   try:
+                                                       b64_decoded_bytes = base64.b64decode(intermediate_str)
+                                                       final_token_text = b64_decoded_bytes.decode('utf-8', errors='replace')
+                                                   except Exception:
+                                                       final_token_text = intermediate_str
+                                                   all_final_token_strings.append(final_token_text)
+
+                                       if not all_final_token_strings: # Should not happen if text_to_tokenize is not empty
+                                           return "", text_to_tokenize, []
+
+                                       if not (0 < num_completion_tokens_from_usage <= len(all_final_token_strings)):
+                                           print(f"WARNING_TOKEN_SPLIT: num_completion_tokens_from_usage ({num_completion_tokens_from_usage}) is invalid for total client-tokenized tokens ({len(all_final_token_strings)}). Returning full content as 'content'.")
+                                           return "", "".join(all_final_token_strings), all_final_token_strings
+
+                                       completion_part_tokens = all_final_token_strings[-num_completion_tokens_from_usage:]
+                                       reasoning_part_tokens = all_final_token_strings[:-num_completion_tokens_from_usage]
+
+                                       reasoning_output_str = "".join(reasoning_part_tokens)
+                                       completion_output_str = "".join(completion_part_tokens)
+
+                                       return reasoning_output_str, completion_output_str, all_final_token_strings
+
+                                   model_id_for_tokenizer = base_model_name
+
+                                   reasoning_text, actual_content, dbg_all_tokens = await asyncio.to_thread(
+                                       _get_token_strings_and_split_texts_sync,
+                                       rotated_credentials, PROJECT_ID, LOCATION,
+                                       model_id_for_tokenizer, full_content, vertex_completion_tokens
+                                   )
+
+                                   message_dict['content'] = actual_content # Set the new content (potentially from joined tokens)
+                                   if reasoning_text: # Only add reasoning_content if it's not empty
+                                       message_dict['reasoning_content'] = reasoning_text
+                                       print(f"DEBUG_REASONING_SPLIT_DIRECT_JOIN: Successful. Reasoning len: {len(reasoning_text)}. Content len: {len(actual_content)}")
+                                       print(f" Vertex completion_tokens: {vertex_completion_tokens}. Our tokenizer total tokens: {len(dbg_all_tokens)}")
+                                   elif "".join(dbg_all_tokens) != full_content : # Content was re-joined from tokens but no reasoning
+                                       print(f"INFO: Content reconstructed from tokens. Original len: {len(full_content)}, Reconstructed len: {len(actual_content)}")
+                                   # else: No reasoning, and content is original full_content because num_completion_tokens was invalid or zero.
+
+                               else:
+                                   print(f"WARNING: Full content is not a string or is empty. Cannot perform split. Content: {full_content}")
+                           else:
+                               print(f"INFO: No positive vertex_completion_tokens ({vertex_completion_tokens}) found in usage, or no message content. No split performed.")
+
+               except Exception as e_reasoning_processing:
+                   print(f"WARNING: Error during non-streaming reasoning token processing for model {request.model} due to: {e_reasoning_processing}.")
+
                return JSONResponse(content=response_dict)
            except Exception as generate_error:
                error_msg_generate = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
@@ -396,4 +450,4 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
    except Exception as e:
        error_msg = f"Unexpected error in chat_completions endpoint: {str(e)}"
        print(error_msg)
-       return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
+       return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
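The core idea of the change above: instead of slicing the content string by a character count taken from reasoning_tokens, the new non-streaming path re-tokenizes the combined text with compute_tokens and treats the last usage.completion_tokens tokens as the visible answer, with everything before them as reasoning. A minimal sketch of that slice logic follows; it uses a hard-coded token list in place of a real compute_tokens call, so the function name and sample tokens are illustrative only.

# Sketch of the token-boundary split used in the diff; in the real code the
# token strings come from compute_tokens, here they are hard-coded.
def split_reasoning(token_strings, completion_tokens_from_usage):
    # Mirror the WARNING_TOKEN_SPLIT guard: an unusable count means no split.
    if not (0 < completion_tokens_from_usage <= len(token_strings)):
        return "", "".join(token_strings)
    reasoning = "".join(token_strings[:-completion_tokens_from_usage])
    answer = "".join(token_strings[-completion_tokens_from_usage:])
    return reasoning, answer

# usage.completion_tokens = 2 -> the last two tokens are the visible answer.
tokens = ["Think", " step", " by", " step.", " The", " answer is 4."]
print(split_reasoning(tokens, 2))  # ('Think step by step.', ' The answer is 4.')

Joining token strings back together is what lets the split land on a token boundary rather than an arbitrary character offset, which is why the commit pays the extra cost of a client-side compute_tokens call.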
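For a caller of this OpenAI-compatible endpoint, the visible effect is an extra reasoning_content field next to content in the non-streaming message. The sketch below is an assumption-heavy illustration: the base URL, port, route path, API key placeholder, and model name are guesses about a typical deployment; only the response field names come from the diff.

import requests  # any HTTP client works; requests is used here for brevity

# Assumed local deployment of this FastAPI proxy; adjust URL, key, and model.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer <api-key>"},
    json={"model": "gemini-2.5-pro", "messages": [{"role": "user", "content": "2+2?"}]},
).json()

message = resp["choices"][0]["message"]
print(message.get("reasoning_content", ""))  # model thoughts, present only when a split occurred
print(message["content"])                    # the visible answer with reasoning removed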