MedQA / tools /gemini_tool.py
mgbam's picture
Update tools/gemini_tool.py
8bfd9b5 verified
raw
history blame
2.36 kB
from langchain.tools import BaseTool
from typing import Type, Optional, Any
from pydantic import BaseModel, Field
import google.generativeai as genai
from config.settings import settings
from services.logger import app_logger
# Argument schema for the Gemini tool: a single free-text prompt.
class GeminiInput(BaseModel):
    # The text forwarded unchanged to the Gemini model.
    query: str = Field(
        description="The query or prompt to send to Google Gemini.",
    )
class GeminiTool(BaseTool):
    """LangChain tool that sends a prompt to Google Gemini and returns the reply text.

    Both the synchronous (`_run`) and asynchronous (`_arun`) entry points
    never raise: a missing API key or any exception from the Gemini client
    is logged through `app_logger` and reported back as an ``"Error..."``
    string.
    """

    name: str = "google_gemini_chat"
    description: str = (
        "Useful for when you need to answer questions or generate text using Google Gemini. "
        "Use this for general knowledge, creative text generation, or complex reasoning tasks "
        "that might benefit from a powerful large language model."
    )
    args_schema: Type[BaseModel] = GeminiInput
    # Identifier passed to `genai.GenerativeModel`. Kept as a field so a
    # different Gemini model can be selected per instance without editing the
    # tool; the default matches the previously hard-coded value.
    gemini_model: str = "gemini-pro"
    # return_direct: bool = True  # If you want the agent to return Gemini's output directly

    def _missing_key_error(self) -> Optional[str]:
        """Return the error message if no API key is configured, else None."""
        if settings.GEMINI_API_KEY:
            return None
        app_logger.error("GEMINI_API_KEY not configured.")
        return "Error: Gemini API key not configured."

    def _build_model(self) -> Any:
        """Configure the genai client and return the generative model.

        Must be called inside the caller's ``try`` block so that any failure
        here is reported the same way as a failed generation call.
        """
        genai.configure(api_key=settings.GEMINI_API_KEY)
        return genai.GenerativeModel(self.gemini_model)

    def _failure(self, log_prefix: str, exc: Exception) -> str:
        """Log *exc* under *log_prefix* and return the error string for the caller."""
        app_logger.error(f"{log_prefix}: {exc}")
        return f"Error interacting with Gemini: {str(exc)}"

    def _run(self, query: str) -> str:
        """Send *query* to Gemini synchronously.

        Args:
            query: The prompt text to send.

        Returns:
            The response text, or an error message string if the API key is
            missing or the call fails.
        """
        key_error = self._missing_key_error()
        if key_error is not None:
            return key_error
        try:
            response = self._build_model().generate_content(query)
            return response.text
        except Exception as e:
            return self._failure("Error calling Gemini API", e)

    async def _arun(self, query: str) -> str:
        """Send *query* to Gemini asynchronously.

        Uses the model's `generate_content_async` coroutine rather than
        delegating to the synchronous `_run`.

        Args:
            query: The prompt text to send.

        Returns:
            The response text, or an error message string if the API key is
            missing or the call fails.
        """
        key_error = self._missing_key_error()
        if key_error is not None:
            return key_error
        try:
            response = await self._build_model().generate_content_async(query)
            return response.text
        except Exception as e:
            return self._failure("Error calling Gemini API asynchronously", e)