File size: 6,227 Bytes
d12a6b6
 
 
 
 
 
8061397
d12a6b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8061397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d12a6b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8061397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d12a6b6
 
 
8061397
 
d12a6b6
 
8061397
d12a6b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Gemini API Routes - Handles native Gemini API endpoints.
This module provides native Gemini API endpoints that proxy directly to Google's API
without any format transformations.
"""
import json
import logging
from typing import Optional

from fastapi import APIRouter, Request, Response, Depends

from .auth import authenticate_user
from .google_api_client import send_gemini_request, build_gemini_payload_from_native
from .config import SUPPORTED_MODELS

# Router that the endpoint decorators in this module register their routes on.
router = APIRouter()


# NOTE: the "/v1/models" alias is registered here, on the same handler, because
# routes are matched in registration order and the catch-all
# "/{full_path:path}" proxy route declared later in this module also accepts
# GET. An alias registered only after that catch-all is never reached.
@router.get("/v1beta/models")
@router.get("/v1/models")
async def gemini_list_models(request: Request, username: str = Depends(authenticate_user)):
    """
    Native Gemini models endpoint.

    Serializes ``SUPPORTED_MODELS`` under a top-level ``"models"`` key and
    returns it as JSON.

    Args:
        request: Incoming request (not inspected by this handler).
        username: Result of the ``authenticate_user`` dependency; declaring it
            makes authentication run before the handler.

    Returns:
        A 200 ``Response`` with the JSON model list, or a 500 ``Response``
        with an ``{"error": {"message", "code"}}`` body if building the
        response fails.
    """
    try:
        logging.info("Gemini models list requested")

        body = json.dumps({"models": SUPPORTED_MODELS})

        logging.info(f"Returning {len(SUPPORTED_MODELS)} Gemini models")
        return Response(
            content=body,
            status_code=200,
            media_type="application/json; charset=utf-8"
        )
    except Exception as e:
        # logging.exception keeps the traceback, which the plain error log lost.
        logging.exception(f"Failed to list Gemini models: {str(e)}")
        return Response(
            content=json.dumps({
                "error": {
                    "message": f"Failed to list models: {str(e)}",
                    "code": 500
                }
            }),
            status_code=500,
            media_type="application/json"
        )


def _gemini_error_response(message: str, status_code: int) -> Response:
    """
    Build the JSON error response used by the proxy endpoint.

    Args:
        message: Human-readable description stored under ``error.message``.
        status_code: HTTP status code; also echoed under ``error.code``.

    Returns:
        A ``Response`` with media type ``application/json``.
    """
    return Response(
        content=json.dumps({
            "error": {
                "message": message,
                "code": status_code
            }
        }),
        status_code=status_code,
        media_type="application/json"
    )


@router.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def gemini_proxy(request: Request, full_path: str, username: str = Depends(authenticate_user)):
    """
    Native Gemini API proxy endpoint (catch-all route).

    Reads the request body, derives the model name from the path, builds a
    payload with ``build_gemini_payload_from_native`` and hands it to
    ``send_gemini_request``.

    This endpoint handles paths like:
    - /v1beta/models/{model}/generateContent
    - /v1beta/models/{model}/streamGenerateContent
    - /v1/models/{model}/generateContent
    - etc.

    Args:
        request: Incoming request; its raw body is parsed as JSON. An empty
            body is treated as ``{}``.
        full_path: Everything captured by the catch-all path parameter.
        username: Result of the ``authenticate_user`` dependency; declaring it
            makes authentication run before the handler.

    Returns:
        Whatever ``send_gemini_request`` returns on success. Otherwise a JSON
        error ``Response``: 400 when no model name can be extracted from the
        path, when the body is not decodable JSON, or when the JSON value is
        not an object; 500 for any other exception.
    """
    try:
        post_data = await request.body()

        # A request counts as streaming when "stream" appears anywhere in the path.
        is_streaming = "stream" in full_path.lower()

        # Paths typically look like: v1beta/models/gemini-1.5-pro/generateContent
        model_name = _extract_model_from_path(full_path)

        logging.info(f"Gemini proxy request: path={full_path}, model={model_name}, stream={is_streaming}")

        if not model_name:
            logging.error(f"Could not extract model name from path: {full_path}")
            return _gemini_error_response(
                f"Could not extract model name from path: {full_path}", 400
            )

        try:
            incoming_request = json.loads(post_data) if post_data else {}
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # json.loads on bytes raises UnicodeDecodeError for undecodable
            # bodies; that is a client error just like malformed JSON.
            logging.error(f"Invalid JSON in request body: {str(e)}")
            return _gemini_error_response("Invalid JSON in request body", 400)

        # Native Gemini request bodies are JSON objects; reject lists, strings
        # and numbers here instead of letting the payload builder fail.
        if not isinstance(incoming_request, dict):
            logging.error(
                f"Request body is not a JSON object: {type(incoming_request).__name__}"
            )
            return _gemini_error_response("Request body must be a JSON object", 400)

        gemini_payload = build_gemini_payload_from_native(incoming_request, model_name)

        response = send_gemini_request(gemini_payload, is_streaming=is_streaming)

        # Not every returned object is guaranteed to expose status_code, so
        # only log the outcome when it does.
        if hasattr(response, 'status_code'):
            if response.status_code != 200:
                logging.error(f"Gemini API returned error: status={response.status_code}")
            else:
                logging.info(f"Successfully processed Gemini request for model: {model_name}")

        return response

    except Exception as e:
        # Top-level boundary: keep the traceback in the log, answer with a 500.
        logging.exception(f"Gemini proxy error: {str(e)}")
        return _gemini_error_response(f"Proxy error: {str(e)}", 500)


def _extract_model_from_path(path: str) -> str:
    """
    Extract the model name from a Gemini API path.
    
    Examples:
    - "v1beta/models/gemini-1.5-pro/generateContent" -> "gemini-1.5-pro"
    - "v1/models/gemini-2.0-flash/streamGenerateContent" -> "gemini-2.0-flash"
    
    Args:
        path: The API path
        
    Returns:
        Model name (just the model name, not prefixed with "models/") or None if not found
    """
    parts = path.split('/')
    
    # Look for the pattern: .../models/{model_name}/...
    try:
        models_index = parts.index('models')
        if models_index + 1 < len(parts):
            model_name = parts[models_index + 1]
            # Remove any action suffix like ":streamGenerateContent" or ":generateContent"
            if ':' in model_name:
                model_name = model_name.split(':')[0]
            # Return just the model name without "models/" prefix
            return model_name
    except ValueError:
        pass
    
    # If we can't find the pattern, return None
    return None


@router.get("/v1/models")
async def gemini_list_models_v1(request: Request, username: str = Depends(authenticate_user)):
    """
    v1 alias of the models listing endpoint.

    Delegates to ``gemini_list_models`` with the same request and username and
    returns its response unchanged, for clients that call /v1/models instead
    of /v1beta/models.

    NOTE(review): this route is declared after the catch-all
    "/{full_path:path}" route in this module, which also accepts GET. Routes
    are matched in registration order, so GET /v1/models is presumably
    answered before this registration is consulted. Confirm, and move this
    above the catch-all if so.
    """
    models_response = await gemini_list_models(request, username)
    return models_response


# Liveness probe endpoint
@router.get("/health")
async def health_check():
    """
    Report that the service is up.

    Returns a fixed dict with the status and the service name. No
    authentication dependency is declared on this route.

    NOTE(review): this route is declared after the catch-all
    "/{full_path:path}" route in this module, which also accepts GET. Routes
    are matched in registration order, so GET /health on this router is
    presumably handled by the proxy instead. Confirm, and move this above the
    catch-all if so.
    """
    payload = {"status": "healthy", "service": "geminicli2api"}
    return payload