Commit cfdf66d
Parent(s): d9e170e

added search models

Files changed:
- src/config.py +33 -4
- src/google_api_client.py +10 -2
- src/openai_transformers.py +10 -3
src/config.py CHANGED

@@ -35,8 +35,8 @@ DEFAULT_SAFETY_SETTINGS = [
     {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
 ]
 
-# 
-SUPPORTED_MODELS = [
+# Base Models (without search variants)
+BASE_MODELS = [
     {
         "name": "models/gemini-2.5-pro-preview-05-06",
         "version": "001",
@@ -93,7 +93,7 @@ SUPPORTED_MODELS = [
         "name": "models/gemini-2.5-flash-preview-04-17",
         "version": "001",
         "displayName": "Gemini 2.5 Flash Preview 04-17",
-        "description": "Preview version of Gemini 2.5 Flash from
+        "description": "Preview version of Gemini 2.5 Flash from April 17th",
         "inputTokenLimit": 1048576,
         "outputTokenLimit": 65535,
         "supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
@@ -115,4 +115,33 @@ SUPPORTED_MODELS = [
         "topP": 0.95,
         "topK": 64
     }
-]
+]
+
+# Generate search variants for applicable models
+def _generate_search_variants():
+    """Generate search variants for models that support content generation."""
+    search_models = []
+    for model in BASE_MODELS:
+        # Only add search variants for models that support content generation
+        if "generateContent" in model["supportedGenerationMethods"]:
+            search_variant = model.copy()
+            search_variant["name"] = model["name"] + "-search"
+            search_variant["displayName"] = model["displayName"] + " with Google Search"
+            search_variant["description"] = model["description"] + " (includes Google Search grounding)"
+            search_models.append(search_variant)
+    return search_models
+
+# Supported Models (includes both base models and search variants)
+SUPPORTED_MODELS = BASE_MODELS + _generate_search_variants()
+
+# Helper function to get base model name from search variant
+def get_base_model_name(model_name):
+    """Convert search variant model name to base model name."""
+    if model_name.endswith("-search"):
+        return model_name[:-7]  # Remove "-search" suffix
+    return model_name
+
+# Helper function to check if model uses search grounding
+def is_search_model(model_name):
+    """Check if model name indicates search grounding should be enabled."""
+    return model_name.endswith("-search")
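For orientation, a brief usage sketch of the new config helpers. It assumes the package is importable as src (an assumption about the repo layout); the assertions mirror what the code above is designed to guarantee, not output captured from a run.

# Illustrative sketch only: expected behavior of the helpers added in src/config.py.
from src.config import SUPPORTED_MODELS, get_base_model_name, is_search_model

# Every base model that supports generateContent gains a "-search" twin in SUPPORTED_MODELS.
names = [m["name"] for m in SUPPORTED_MODELS]
assert "models/gemini-2.5-flash-preview-04-17" in names
assert "models/gemini-2.5-flash-preview-04-17-search" in names

# The helpers map a search variant back to the name the upstream API expects.
assert is_search_model("models/gemini-2.5-pro-preview-05-06-search")
assert not is_search_model("models/gemini-2.5-pro-preview-05-06")
assert get_base_model_name("models/gemini-2.5-pro-preview-05-06-search") == "models/gemini-2.5-pro-preview-05-06"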
src/google_api_client.py CHANGED

@@ -11,7 +11,7 @@ from google.auth.transport.requests import Request as GoogleAuthRequest
 
 from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user
 from .utils import get_user_agent
-from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS
+from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS, get_base_model_name, is_search_model
 import asyncio
 
 
@@ -310,7 +310,15 @@ def build_gemini_payload_from_native(native_request: dict, model_from_path: str)
     native_request["generationConfig"]["thinkingConfig"]["includeThoughts"] = True
     native_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] = -1
 
+    # Add Google Search grounding for search models
+    if is_search_model(model_from_path):
+        if "tools" not in native_request:
+            native_request["tools"] = []
+        # Add googleSearch tool if not already present
+        if not any(tool.get("googleSearch") for tool in native_request["tools"]):
+            native_request["tools"].append({"googleSearch": {}})
+
     return {
-        "model": model_from_path,
+        "model": get_base_model_name(model_from_path),  # Use base model name for API call
         "request": native_request
     }
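The grounding-injection step above can be read in isolation. The sketch below mirrors it under a made-up wrapper name (add_search_grounding) with an invented request body, reusing the real helpers from src/config.py; it is not the function as it appears in the file.

# Illustrative sketch only: mirrors the new block in build_gemini_payload_from_native.
from src.config import get_base_model_name, is_search_model

def add_search_grounding(native_request: dict, model_from_path: str) -> dict:
    if is_search_model(model_from_path):
        tools = native_request.setdefault("tools", [])
        # Only append the googleSearch tool if the caller has not already supplied one.
        if not any(tool.get("googleSearch") for tool in tools):
            tools.append({"googleSearch": {}})
    return {"model": get_base_model_name(model_from_path), "request": native_request}

payload = add_search_grounding(
    {"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
    "models/gemini-2.5-flash-preview-04-17-search",
)
assert payload["model"] == "models/gemini-2.5-flash-preview-04-17"
assert payload["request"]["tools"] == [{"googleSearch": {}}]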
src/openai_transformers.py CHANGED

@@ -8,7 +8,7 @@ import uuid
 from typing import Dict, Any
 
 from .models import OpenAIChatCompletionRequest, OpenAIChatCompletionResponse
-from .config import DEFAULT_SAFETY_SETTINGS
+from .config import DEFAULT_SAFETY_SETTINGS, is_search_model, get_base_model_name
 
 
 def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
@@ -91,12 +91,19 @@ def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     if openai_request.response_format.get("type") == "json_object":
         generation_config["responseMimeType"] = "application/json"
 
-    return {
+    # Build the request payload
+    request_payload = {
         "contents": contents,
         "generationConfig": generation_config,
         "safetySettings": DEFAULT_SAFETY_SETTINGS,
-        "model": openai_request.model
+        "model": get_base_model_name(openai_request.model)  # Use base model name for API call
     }
+
+    # Add Google Search grounding for search models
+    if is_search_model(openai_request.model):
+        request_payload["tools"] = [{"googleSearch": {}}]
+
+    return request_payload
 
 
 def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]:
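As a final illustration, this is the payload shape openai_request_to_gemini is now expected to produce for a search-variant model. The field values are invented for the example; the contents and generationConfig translation earlier in the function is unchanged by this commit.

# Illustrative sketch only: expected payload shape for a "-search" model request.
expected_payload = {
    "contents": [{"role": "user", "parts": [{"text": "What changed in Gemini this week?"}]}],
    "generationConfig": {"temperature": 0.7, "topP": 0.95},
    "safetySettings": "DEFAULT_SAFETY_SETTINGS",  # placeholder for the real safety settings list
    "model": "models/gemini-2.5-flash-preview-04-17",  # "-search" suffix stripped via get_base_model_name
    "tools": [{"googleSearch": {}}],  # present only when is_search_model(...) is True
}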