Update app/main.py
app/main.py (+12 -11)
@@ -1175,8 +1175,7 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
             print("Prompt structure: Unknown format")


-
-
+        # Use the client.models object as in the original synchronous code
         if request.stream:
             # Streaming call (Async)
             response_id = f"chatcmpl-{int(time.time())}"
@@ -1188,9 +1187,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
            try:
                # No need to loop candidate_index here, the stream handles multiple candidates if config asks for it
                print(f"Sending async streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-
+                # Call async stream method on client.models
+                async_responses = await client.models.generate_content_stream_async(
+                    model=model_name, # Pass model name here
                    contents=prompt,
-                    generation_config=current_gen_config,
+                    generation_config=current_gen_config,
                    # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
                )

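The core of this change is visible in the hunk above: instead of calling methods on a separate model instance, the streaming request is issued on client.models and the model name is passed with each call. A minimal sketch of that call shape, assuming client, prompt, and current_gen_config are built earlier in chat_completions() (outside this diff) and that the installed SDK actually exposes this async method on client.models:

async def open_gemini_stream(client, model_name, prompt, current_gen_config):
    # Streaming request on client.models; the model is selected per call rather than
    # being bound to a separate model object up front.
    return await client.models.generate_content_stream_async(
        model=model_name,
        contents=prompt,
        generation_config=current_gen_config,
    )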
@@ -1198,12 +1199,10 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                async for chunk in async_responses: # Use async for
                    first_chunk_received = True
                    # Determine candidate_index based on the chunk itself if possible, fallback to 0
-                    # Note: Adjust this if the async stream chunk structure provides candidate index differently
                    candidate_index = 0 # Assuming default index for now
-                    if hasattr(chunk, '_candidate_index'): # Check for potential internal attribute
+                    if hasattr(chunk, '_candidate_index'): # Check for potential internal attribute
                        candidate_index = chunk._candidate_index
                    elif hasattr(chunk, 'candidates') and chunk.candidates and hasattr(chunk.candidates[0], 'index'):
-                        # Or check standard candidate structure if available on chunk
                        candidate_index = chunk.candidates[0].index

                    if hasattr(chunk, 'text') and chunk.text:
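Each streamed chunk is probed for a candidate index and falls back to 0 when nothing is found. The same fallback chain, restated as a small helper (a hypothetical factoring for illustration, not a function that exists in app/main.py):

def get_candidate_index(chunk) -> int:
    # Internal attribute some SDK chunk objects carry
    if hasattr(chunk, '_candidate_index'):
        return chunk._candidate_index
    # Standard candidates[0].index structure, when the chunk exposes it
    if hasattr(chunk, 'candidates') and chunk.candidates and hasattr(chunk.candidates[0], 'index'):
        return chunk.candidates[0].index
    # Fallback: assume the default candidate
    return 0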
@@ -1218,8 +1217,8 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
                yield "data: [DONE]\n\n"

                # Return status based on content received
-                if all_chunks_empty and first_chunk_received:
-                    raise ValueError("Streamed response contained only empty chunks")
+                if all_chunks_empty and first_chunk_received:
+                    raise ValueError("Streamed response contained only empty chunks")

            except Exception as stream_error:
                error_msg = f"Error during async streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
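The re-indented guard keeps the empty-stream check inside the streaming path: after the final "data: [DONE]" event, a stream that produced chunks but never any text raises, so the surrounding error handling can react. A sketch of how the two flags are assumed to interact (their initialisation and the per-chunk conversion live outside this hunk):

async def sse_events(async_responses):
    first_chunk_received = False
    all_chunks_empty = True
    async for chunk in async_responses:
        first_chunk_received = True
        if hasattr(chunk, 'text') and chunk.text:
            all_chunks_empty = False
            yield f"data: {chunk.text}\n\n"  # the real code emits an OpenAI-style JSON chunk here
    yield "data: [DONE]\n\n"
    # Chunks arrived but none carried any content: surface it as an error
    if all_chunks_empty and first_chunk_received:
        raise ValueError("Streamed response contained only empty chunks")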
@@ -1236,9 +1235,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
            # Non-streaming call (Async)
            try:
                print(f"Sending async request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-
+                # Call async method on client.models
+                response = await client.models.generate_content_async(
+                    model=model_name, # Pass model name here
                    contents=prompt,
-                    generation_config=current_gen_config,
+                    generation_config=current_gen_config,
                    # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
                )
                if not is_response_valid(response):
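The non-streaming branch mirrors the streaming one: a single awaited call on client.models with the model name passed explicitly, followed by the existing validity check. A sketch under the same assumptions; is_response_valid is app/main.py's own helper, injected here only to keep the snippet self-contained, and the invalid-response handling continues past this hunk:

async def generate_once(client, model_name, prompt, current_gen_config, is_response_valid):
    # Non-streaming variant of the same call shape
    response = await client.models.generate_content_async(
        model=model_name,
        contents=prompt,
        generation_config=current_gen_config,
    )
    if not is_response_valid(response):
        # Placeholder for the error branch that follows this hunk in app/main.py
        raise ValueError("Invalid response received from Gemini API")
    return response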