bibibi12345 committed on
Commit 5d7dc12 · 1 Parent(s): cdf27f4

added reasoning support

Files changed (1)
  1. app/routes/chat_api.py +85 -31
app/routes/chat_api.py CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import base64 # Ensure base64 is imported
 import json # Needed for error streaming
 import random
 from fastapi import APIRouter, Depends, Request
@@ -7,6 +8,7 @@ from typing import List, Dict, Any
 
 # Google and OpenAI specific imports
 from google.genai import types
+from google.genai.types import HttpOptions # Added for compute_tokens
 from google import genai
 import openai
 from credentials_manager import _refresh_auth
@@ -229,7 +231,6 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
         async for chunk in stream_response:
             try:
                 chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)
-                print(chunk_as_dict)
 
                 # Safely navigate and check for thought flag
                 choices = chunk_as_dict.get('choices')
@@ -253,6 +254,7 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
                             del delta['extra_content']
 
                 # Yield the (potentially modified) dictionary as JSON
+                print(chunk_as_dict)
                 yield f"data: {json.dumps(chunk_as_dict)}\n\n"
 
             except Exception as chunk_processing_error: # Catch errors from dict manipulation or json.dumps
@@ -290,39 +292,91 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
                     extra_body=openai_extra_body
                 )
                 response_dict = response.model_dump(exclude_unset=True, exclude_none=True)
-
-                # Process reasoning_tokens for non-streaming response
+
                 try:
                     usage = response_dict.get('usage')
+                    vertex_completion_tokens = 0
+
                     if usage and isinstance(usage, dict):
-                        completion_details = usage.get('completion_tokens_details')
-                        if completion_details and isinstance(completion_details, dict):
-                            num_reasoning_tokens = completion_details.get('reasoning_tokens')
-
-                            if isinstance(num_reasoning_tokens, int) and num_reasoning_tokens > 0:
-                                choices = response_dict.get('choices')
-                                if choices and isinstance(choices, list) and len(choices) > 0:
-                                    # Ensure choices[0] and message are dicts, model_dump makes them so
-                                    message_dict = choices[0].get('message')
-                                    if message_dict and isinstance(message_dict, dict):
-                                        full_content = message_dict.get('content')
-                                        if isinstance(full_content, str): # Ensure content is a string
-                                            reasoning_text = full_content[:num_reasoning_tokens]
-                                            actual_content = full_content[num_reasoning_tokens:]
-
-                                            message_dict['reasoning_content'] = reasoning_text
-                                            message_dict['content'] = actual_content
-
-                                            # Clean up Vertex-specific field
-                                            del completion_details['reasoning_tokens']
-                                            if not completion_details: # If dict is now empty
-                                                del usage['completion_tokens_details']
-                                            if not usage: # If dict is now empty
-                                                del response_dict['usage']
-                except Exception as e_non_stream_reasoning:
-                    print(f"WARNING: Could not process non-streaming reasoning tokens for model {request.model}: {e_non_stream_reasoning}. Response will be returned as is from Vertex.")
-                    # Fallthrough to return response_dict as is if processing fails
+                        vertex_completion_tokens = usage.get('completion_tokens')
 
+                    choices = response_dict.get('choices')
+                    if choices and isinstance(choices, list) and len(choices) > 0:
+                        message_dict = choices[0].get('message')
+                        if message_dict and isinstance(message_dict, dict):
+                            # Always remove extra_content from the message if it exists, before any splitting
+                            if 'extra_content' in message_dict:
+                                del message_dict['extra_content']
+                                print("DEBUG: Removed 'extra_content' from response message.")
+
+                            if isinstance(vertex_completion_tokens, int) and vertex_completion_tokens > 0:
+                                full_content = message_dict.get('content')
+                                if isinstance(full_content, str) and full_content:
+
+                                    def _get_token_strings_and_split_texts_sync(creds, proj_id, loc, model_id_for_tokenizer, text_to_tokenize, num_completion_tokens_from_usage):
+                                        sync_tokenizer_client = genai.Client(
+                                            vertexai=True, credentials=creds, project=proj_id, location=loc,
+                                            http_options=HttpOptions(api_version="v1")
+                                        )
+                                        if not text_to_tokenize: return "", text_to_tokenize, [] # No reasoning, original content, empty token list
+
+                                        token_compute_response = sync_tokenizer_client.models.compute_tokens(
+                                            model=model_id_for_tokenizer, contents=text_to_tokenize
+                                        )
+
+                                        all_final_token_strings = []
+                                        if token_compute_response.tokens_info:
+                                            for token_info_item in token_compute_response.tokens_info:
+                                                for api_token_bytes in token_info_item.tokens:
+                                                    intermediate_str = api_token_bytes.decode('utf-8', errors='replace')
+                                                    final_token_text = ""
+                                                    try:
+                                                        b64_decoded_bytes = base64.b64decode(intermediate_str)
+                                                        final_token_text = b64_decoded_bytes.decode('utf-8', errors='replace')
+                                                    except Exception:
+                                                        final_token_text = intermediate_str
+                                                    all_final_token_strings.append(final_token_text)
+
+                                        if not all_final_token_strings: # Should not happen if text_to_tokenize is not empty
+                                            return "", text_to_tokenize, []
+
+                                        if not (0 < num_completion_tokens_from_usage <= len(all_final_token_strings)):
+                                            print(f"WARNING_TOKEN_SPLIT: num_completion_tokens_from_usage ({num_completion_tokens_from_usage}) is invalid for total client-tokenized tokens ({len(all_final_token_strings)}). Returning full content as 'content'.")
+                                            return "", "".join(all_final_token_strings), all_final_token_strings
+
+                                        completion_part_tokens = all_final_token_strings[-num_completion_tokens_from_usage:]
+                                        reasoning_part_tokens = all_final_token_strings[:-num_completion_tokens_from_usage]
+
+                                        reasoning_output_str = "".join(reasoning_part_tokens)
+                                        completion_output_str = "".join(completion_part_tokens)
+
+                                        return reasoning_output_str, completion_output_str, all_final_token_strings
+
+                                    model_id_for_tokenizer = base_model_name
+
+                                    reasoning_text, actual_content, dbg_all_tokens = await asyncio.to_thread(
+                                        _get_token_strings_and_split_texts_sync,
+                                        rotated_credentials, PROJECT_ID, LOCATION,
+                                        model_id_for_tokenizer, full_content, vertex_completion_tokens
+                                    )
+
+                                    message_dict['content'] = actual_content # Set the new content (potentially from joined tokens)
+                                    if reasoning_text: # Only add reasoning_content if it's not empty
+                                        message_dict['reasoning_content'] = reasoning_text
+                                        print(f"DEBUG_REASONING_SPLIT_DIRECT_JOIN: Successful. Reasoning len: {len(reasoning_text)}. Content len: {len(actual_content)}")
+                                        print(f"  Vertex completion_tokens: {vertex_completion_tokens}. Our tokenizer total tokens: {len(dbg_all_tokens)}")
+                                    elif "".join(dbg_all_tokens) != full_content: # Content was re-joined from tokens but no reasoning
+                                        print(f"INFO: Content reconstructed from tokens. Original len: {len(full_content)}, Reconstructed len: {len(actual_content)}")
+                                    # else: No reasoning, and content is original full_content because num_completion_tokens was invalid or zero.
+
+                                else:
+                                    print(f"WARNING: Full content is not a string or is empty. Cannot perform split. Content: {full_content}")
+                            else:
+                                print(f"INFO: No positive vertex_completion_tokens ({vertex_completion_tokens}) found in usage, or no message content. No split performed.")
+
+                except Exception as e_reasoning_processing:
+                    print(f"WARNING: Error during non-streaming reasoning token processing for model {request.model} due to: {e_reasoning_processing}.")
+
                 return JSONResponse(content=response_dict)
             except Exception as generate_error:
                 error_msg_generate = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
@@ -396,4 +450,4 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
     except Exception as e:
         error_msg = f"Unexpected error in chat_completions endpoint: {str(e)}"
         print(error_msg)
-        return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
+        return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
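
How the new non-streaming path works: instead of relying on a reasoning_tokens counter, the handler re-tokenizes the returned message with Vertex's compute_tokens, treats the last usage.completion_tokens tokens as the visible answer, and exposes the preceding prefix as reasoning_content. Below is a minimal standalone sketch of that splitting step, assuming the google-genai SDK; it mirrors _get_token_strings_and_split_texts_sync from this commit, but the creds, project_id, location and model arguments are placeholders supplied by the caller, not values from this repo.

    # Sketch only: split reasoning text from the visible answer by re-tokenizing
    # the combined output with compute_tokens, as the commit's helper does.
    import base64
    from google import genai
    from google.genai.types import HttpOptions

    def split_reasoning_from_answer(full_text, completion_tokens, creds, project_id, location, model):
        # creds/project_id/location/model are placeholders for the caller's Vertex setup.
        client = genai.Client(
            vertexai=True, credentials=creds, project=project_id, location=location,
            http_options=HttpOptions(api_version="v1"),
        )
        resp = client.models.compute_tokens(model=model, contents=full_text)

        tokens = []
        for info in resp.tokens_info or []:
            for raw in info.tokens:
                text = raw.decode("utf-8", errors="replace")
                try:
                    # Token bytes may arrive base64-encoded; fall back to the raw string otherwise.
                    text = base64.b64decode(text).decode("utf-8", errors="replace")
                except Exception:
                    pass
                tokens.append(text)

        if not (0 < completion_tokens <= len(tokens)):
            return "", full_text  # cannot split reliably; keep everything as answer content

        # The last `completion_tokens` tokens are the answer; the prefix is the reasoning.
        return "".join(tokens[:-completion_tokens]), "".join(tokens[-completion_tokens:])

The suffix split rests on the commit's assumption that Vertex's usage.completion_tokens counts only the final answer while message.content concatenates the model's thoughts and the answer; when that count does not fit the client-side tokenization, the helper falls back to returning everything as content, matching the diff's WARNING_TOKEN_SPLIT branch.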