bibibi12345 committed
Commit 806cf01 · verified · Parent: 311bcf7

Update app/main.py

Files changed (1): app/main.py (+12 -11)
app/main.py CHANGED
@@ -1175,8 +1175,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
         print("Prompt structure: Unknown format")
 
 
-    model_instance = genai.GenerativeModel(model_name=model_name) # Use genai.GenerativeModel
-
+    # Use the client.models object as in the original synchronous code
     if request.stream:
         # Streaming call (Async)
         response_id = f"chatcmpl-{int(time.time())}"
@@ -1188,9 +1187,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
             try:
                 # No need to loop candidate_index here, the stream handles multiple candidates if config asks for it
                 print(f"Sending async streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                async_responses = await model_instance.generate_content_stream_async( # Use await and async method
+                # Call async stream method on client.models
+                async_responses = await client.models.generate_content_stream_async(
+                    model=model_name, # Pass model name here
                     contents=prompt,
-                    generation_config=current_gen_config, # Use generation_config parameter
+                    generation_config=current_gen_config,
                     # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
                 )
 
@@ -1198,12 +1199,10 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                 async for chunk in async_responses: # Use async for
                     first_chunk_received = True
                     # Determine candidate_index based on the chunk itself if possible, fallback to 0
-                    # Note: Adjust this if the async stream chunk structure provides candidate index differently
                     candidate_index = 0 # Assuming default index for now
-                    if hasattr(chunk, '_candidate_index'): # Check for potential internal attribute (may change)
+                    if hasattr(chunk, '_candidate_index'): # Check for potential internal attribute
                         candidate_index = chunk._candidate_index
                     elif hasattr(chunk, 'candidates') and chunk.candidates and hasattr(chunk.candidates[0], 'index'):
-                        # Or check standard candidate structure if available on chunk
                         candidate_index = chunk.candidates[0].index
 
                     if hasattr(chunk, 'text') and chunk.text:
@@ -1218,8 +1217,8 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                 yield "data: [DONE]\n\n"
 
                 # Return status based on content received
-                if all_chunks_empty and first_chunk_received: # Check if we got chunks but they were all empty
-                    raise ValueError("Streamed response contained only empty chunks") # Treat empty stream as failure for retry
+                if all_chunks_empty and first_chunk_received:
+                    raise ValueError("Streamed response contained only empty chunks")
 
             except Exception as stream_error:
                 error_msg = f"Error during async streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
@@ -1236,9 +1235,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
         # Non-streaming call (Async)
         try:
             print(f"Sending async request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-            response = await model_instance.generate_content_async( # Use await and async method
+            # Call async method on client.models
+            response = await client.models.generate_content_async(
+                model=model_name, # Pass model name here
                 contents=prompt,
-                generation_config=current_gen_config, # Use generation_config parameter
+                generation_config=current_gen_config,
                 # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
             )
             if not is_response_valid(response):
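
For reference, the commit swaps the per-request genai.GenerativeModel instance for calls on a shared client.models object, passing the model name with each call. Below is a minimal sketch of the new call pattern, assuming a `client` object created elsewhere in app/main.py and using only the method names and keyword arguments that appear in the diff above; it is an illustration, not the full endpoint logic.

# Sketch only: `client`, `model_name`, `prompt`, and `current_gen_config`
# are assumed to be defined elsewhere in app/main.py, as in the diff above.

async def stream_chunks(client, model_name, prompt, current_gen_config):
    # Streaming path: the model is selected per call instead of being
    # bound to a GenerativeModel instance up front.
    async_responses = await client.models.generate_content_stream_async(
        model=model_name,
        contents=prompt,
        generation_config=current_gen_config,
    )
    async for chunk in async_responses:
        if hasattr(chunk, 'text') and chunk.text:
            yield chunk.text

async def single_response(client, model_name, prompt, current_gen_config):
    # Non-streaming path: one awaited call with the same keyword arguments;
    # the caller validates the result (e.g. via is_response_valid).
    response = await client.models.generate_content_async(
        model=model_name,
        contents=prompt,
        generation_config=current_gen_config,
    )
    return response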