Final_Assignment_Template/tools/multimodal_tools.py
import base64
import os

from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI

api_key = os.getenv("GEMINI_API_KEY")

# Create the shared vision-capable LLM instance used by all tools below
vision_llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash-preview-05-20",
    temperature=0,
    max_retries=2,
    google_api_key=api_key,
)
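
# A minimal sanity check (a sketch added for illustration; it assumes a
# missing GEMINI_API_KEY should be surfaced at import time rather than on
# the first tool call):
if not api_key:
    import warnings
    warnings.warn("GEMINI_API_KEY is not set; the tools below will fail when invoked.")
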
@tool("extract_text_tool", parse_docstring=True)
def extract_text(img_path: str) -> str:
"""Extract text from an image file using a multimodal model.
Args:
img_path (str): The path to the image file from which to extract text.
Returns:
str: The extracted text from the image, or an empty string if an error occurs.
"""
all_text = ""
try:
# Read image and encode as base64
with open(img_path, "rb") as image_file:
image_bytes = image_file.read()
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
# Prepare the prompt including the base64 image data
message = [
HumanMessage(
content=[
{
"type": "text",
"text": (
"Extract all the text from this image. "
"Return only the extracted text, no explanations."
),
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
},
},
]
)
]
# Call the vision-capable model
response = vision_llm.invoke(message)
# Append extracted text
all_text += response.content + "\n\n"
return all_text.strip()
except Exception as e:
# A butler should handle errors gracefully
error_msg = f"Error extracting text: {str(e)}"
print(error_msg)
return ""
@tool("analyze_image_tool", parse_docstring=True)
def analyze_image_tool(user_query: str, img_path: str) -> str:
"""Answer the question reasoning on the image.
Args:
user_query (str): The question to be answered based on the image.
img_path (str): Path to the image file to be analyzed.
Returns:
str: The answer to the query based on image content, or an empty string if an error occurs.
"""
all_text = ""
try:
# Read image and encode as base64
with open(img_path, "rb") as image_file:
image_bytes = image_file.read()
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
# Prepare the prompt including the base64 image data
message = [
HumanMessage(
content=[
{
"type": "text",
"text": (
f"User query: {user_query}"
),
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
},
},
]
)
]
# Call the vision-capable model
response = vision_llm.invoke(message)
# Append extracted text
all_text += response.content + "\n\n"
return all_text.strip()
except Exception as e:
# A butler should handle errors gracefully
error_msg = f"Error analyzing image: {str(e)}"
print(error_msg)
return ""
@tool("analyze_audio_tool", parse_docstring=True)
def analyze_audio_tool(user_query: str, audio_path: str) -> str:
"""Answer the question by reasoning on the provided audio file.
Args:
user_query (str): The question to be answered based on the audio content.
audio_path (str): Path to the audio file (e.g., .mp3, .wav, .flac, .aac, .ogg).
Returns:
str: The answer to the query based on audio content, or an error message/empty string if an error occurs.
"""
try:
# Determine MIME type from file extension
_filename, file_extension = os.path.splitext(audio_path)
file_extension = file_extension.lower()
supported_formats = {
".mp3": "audio/mp3", ".wav": "audio/wav", ".flac": "audio/flac",
".aac": "audio/aac", ".ogg": "audio/ogg"
}
if file_extension not in supported_formats:
return (f"Error: Unsupported audio file format '{file_extension}'. "
f"Supported extensions: {', '.join(supported_formats.keys())}.")
mime_type = supported_formats[file_extension]
# Read audio file and encode as base64
with open(audio_path, "rb") as audio_file:
audio_bytes = audio_file.read()
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
# Prepare the prompt including the base64 audio data
message = [
HumanMessage(
content=[
{
"type": "text",
"text": f"User query: {user_query}",
},
{
"type": "audio",
"source_type": "base64",
"mime_type": mime_type,
"data": audio_base64
},
]
)
]
# Call the vision-capable model
response = vision_llm.invoke(message)
return response.content.strip()
except Exception as e:
error_msg = f"Error analyzing audio: {str(e)}"
print(error_msg)
return ""
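

# Minimal usage sketch (the file paths below are hypothetical placeholders).
# Tools created with @tool are invoked with a dict mapping argument names to
# values, so the same call shape works when an agent framework drives them.
if __name__ == "__main__":
    print(extract_text.invoke({"img_path": "sample.png"}))
    print(analyze_image_tool.invoke({
        "user_query": "What objects appear in this image?",
        "img_path": "sample.png",
    }))
    print(analyze_audio_tool.invoke({
        "user_query": "Summarize what is said in this recording.",
        "audio_path": "sample.mp3",
    }))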