bibibi12345 committed on
Commit
cfdf66d
·
1 Parent(s): d9e170e

added search models

Browse files
src/config.py CHANGED
@@ -35,8 +35,8 @@ DEFAULT_SAFETY_SETTINGS = [
35
  {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
36
  ]
37
 
38
- # Supported Models (for /v1beta/models endpoint)
39
- SUPPORTED_MODELS = [
40
  {
41
  "name": "models/gemini-2.5-pro-preview-05-06",
42
  "version": "001",
@@ -93,7 +93,7 @@ SUPPORTED_MODELS = [
93
  "name": "models/gemini-2.5-flash-preview-04-17",
94
  "version": "001",
95
  "displayName": "Gemini 2.5 Flash Preview 04-17",
96
- "description": "Preview version of Gemini 2.5 Flash from May 20th",
97
  "inputTokenLimit": 1048576,
98
  "outputTokenLimit": 65535,
99
  "supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
@@ -115,4 +115,33 @@ SUPPORTED_MODELS = [
115
  "topP": 0.95,
116
  "topK": 64
117
  }
118
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}
36
  ]
37
 
38
+ # Base Models (without search variants)
39
+ BASE_MODELS = [
40
  {
41
  "name": "models/gemini-2.5-pro-preview-05-06",
42
  "version": "001",
 
93
  "name": "models/gemini-2.5-flash-preview-04-17",
94
  "version": "001",
95
  "displayName": "Gemini 2.5 Flash Preview 04-17",
96
+ "description": "Preview version of Gemini 2.5 Flash from April 17th",
97
  "inputTokenLimit": 1048576,
98
  "outputTokenLimit": 65535,
99
  "supportedGenerationMethods": ["generateContent", "streamGenerateContent"],
 
115
  "topP": 0.95,
116
  "topK": 64
117
  }
118
+ ]
119
+
120
# Derive "-search" variants for applicable models
def _generate_search_variants():
    """Build a "-search" variant for every base model that can generate content.

    Returns a new list of model dicts; BASE_MODELS itself is not modified.
    Each variant shares the base model's fields but overrides name,
    displayName, and description to mark Google Search grounding.
    """
    variants = []
    for base in BASE_MODELS:
        # Search grounding is only meaningful where generateContent is supported.
        if "generateContent" not in base["supportedGenerationMethods"]:
            continue
        variants.append({
            **base,
            "name": base["name"] + "-search",
            "displayName": base["displayName"] + " with Google Search",
            "description": base["description"] + " (includes Google Search grounding)",
        })
    return variants
133
+
134
# Supported Models (base models followed by their Google Search variants)
SUPPORTED_MODELS = [*BASE_MODELS, *_generate_search_variants()]
136
+
137
# Helper function to get base model name from search variant
def get_base_model_name(model_name):
    """Map a "-search" variant name to its underlying base model name.

    Names without the suffix pass through unchanged, so this is safe to
    call on every model name before issuing the upstream API request.
    """
    suffix = "-search"
    if model_name.endswith(suffix):
        return model_name[: -len(suffix)]
    return model_name
143
+
144
# Helper function to check if model uses search grounding
def is_search_model(model_name):
    """Return True when *model_name* denotes a Google-Search-grounded variant."""
    search_suffix = "-search"
    return model_name.endswith(search_suffix)
src/google_api_client.py CHANGED
@@ -11,7 +11,7 @@ from google.auth.transport.requests import Request as GoogleAuthRequest
11
 
12
  from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user
13
  from .utils import get_user_agent
14
- from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS
15
  import asyncio
16
 
17
 
@@ -310,7 +310,15 @@ def build_gemini_payload_from_native(native_request: dict, model_from_path: str)
310
  native_request["generationConfig"]["thinkingConfig"]["includeThoughts"] = True
311
  native_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] = -1
312
 
 
 
 
 
 
 
 
 
313
  return {
314
- "model": model_from_path,
315
  "request": native_request
316
  }
 
11
 
12
  from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user
13
  from .utils import get_user_agent
14
+ from .config import CODE_ASSIST_ENDPOINT, DEFAULT_SAFETY_SETTINGS, get_base_model_name, is_search_model
15
  import asyncio
16
 
17
 
 
310
  native_request["generationConfig"]["thinkingConfig"]["includeThoughts"] = True
311
  native_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] = -1
312
 
313
+ # Add Google Search grounding for search models
314
+ if is_search_model(model_from_path):
315
+ if "tools" not in native_request:
316
+ native_request["tools"] = []
317
+ # Add googleSearch tool if not already present
318
+ if not any(tool.get("googleSearch") for tool in native_request["tools"]):
319
+ native_request["tools"].append({"googleSearch": {}})
320
+
321
  return {
322
+ "model": get_base_model_name(model_from_path), # Use base model name for API call
323
  "request": native_request
324
  }
src/openai_transformers.py CHANGED
@@ -8,7 +8,7 @@ import uuid
8
  from typing import Dict, Any
9
 
10
  from .models import OpenAIChatCompletionRequest, OpenAIChatCompletionResponse
11
- from .config import DEFAULT_SAFETY_SETTINGS
12
 
13
 
14
  def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
@@ -91,12 +91,19 @@ def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dic
91
  if openai_request.response_format.get("type") == "json_object":
92
  generation_config["responseMimeType"] = "application/json"
93
 
94
- return {
 
95
  "contents": contents,
96
  "generationConfig": generation_config,
97
  "safetySettings": DEFAULT_SAFETY_SETTINGS,
98
- "model": openai_request.model
99
  }
 
 
 
 
 
 
100
 
101
 
102
  def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]:
 
8
  from typing import Dict, Any
9
 
10
  from .models import OpenAIChatCompletionRequest, OpenAIChatCompletionResponse
11
+ from .config import DEFAULT_SAFETY_SETTINGS, is_search_model, get_base_model_name
12
 
13
 
14
  def openai_request_to_gemini(openai_request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
 
91
  if openai_request.response_format.get("type") == "json_object":
92
  generation_config["responseMimeType"] = "application/json"
93
 
94
+ # Build the request payload
95
+ request_payload = {
96
  "contents": contents,
97
  "generationConfig": generation_config,
98
  "safetySettings": DEFAULT_SAFETY_SETTINGS,
99
+ "model": get_base_model_name(openai_request.model) # Use base model name for API call
100
  }
101
+
102
+ # Add Google Search grounding for search models
103
+ if is_search_model(openai_request.model):
104
+ request_payload["tools"] = [{"googleSearch": {}}]
105
+
106
+ return request_payload
107
 
108
 
109
  def gemini_response_to_openai(gemini_response: Dict[str, Any], model: str) -> Dict[str, Any]: