import base64
import os
import re
from io import BytesIO
from pathlib import Path
from typing import Optional

import gradio as gr
import openai
import pandas as pd
import requests
from PIL import Image
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    HfApiModel,
    OpenAIServerModel,
    VisitWebpageTool,
    tool,
)

## utility functions

# File extensions (lower-case, with leading dot) that count as images.
_IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp", ".svg"})
# Pillow cannot decode SVG (vector) files, so only these are opened as PIL images;
# SVG uploads fall through to the generic file-path handling in `load_file`.
_RASTER_IMAGE_EXTS = _IMAGE_EXTS - {".svg"}


def is_image_extension(filename: str) -> bool:
    """Return True if *filename* ends with an image extension (case-insensitive)."""
    # os.path.splitext(path) returns (root, ext); ext keeps the leading dot.
    return os.path.splitext(filename)[1].lower() in _IMAGE_EXTS


def load_file(path: str) -> list | dict:
    """Based on the file extension, load the file into a suitable object.

    Args:
        path: local path of the uploaded file.

    Returns:
        For raster image files, a one-element list holding the image as an RGB
        PIL image. For anything else, a dict with keys:
          - "excel": DataFrame for .xlsx/.xls files, else None
          - "csv": DataFrame for .csv files, else None
          - "raw text": file contents for .py/.txt files, else None
          - "audio path": always the input path, so tools that work from a
            path (presumably the audio transcriber) can still use the file.
    """
    ext = Path(path).suffix.lower()

    if ext in _RASTER_IMAGE_EXTS:
        return [Image.open(path).convert("RGB")]

    excel = None
    csv = None
    text = None
    if ext in (".xlsx", ".xls"):
        excel = pd.read_excel(path)
    elif ext == ".csv":
        csv = pd.read_csv(path)
    elif ext in (".py", ".txt"):
        # Explicit encoding so the result does not depend on the host locale.
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()

    return {"excel": excel, "csv": csv, "raw text": text, "audio path": path}


## tools definition

@tool
def download_images(image_urls: str) -> list:
    """
    Download web images from the given comma-separated URLs and return them in a list of PIL Images.

    Args:
        image_urls: comma-separated list of URLs to download

    Returns:
        List of PIL.Image.Image objects
    """
    # Split on commas, trim whitespace and drop empty entries.
    urls = [u.strip() for u in image_urls.split(",") if u.strip()]
    images = []
    for url in urls:
        # Best effort: a single bad URL is reported and skipped rather than
        # aborting the whole batch (any failure: network, HTTP status, decoding).
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            images.append(Image.open(BytesIO(resp.content)).convert("RGB"))
        except Exception as e:
            print(f"Failed to download from {url}: {e}")
    return images


@tool
def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe audio file using OpenAI Whisper API.

    Args:
        audio_path: path to the audio file to be transcribed.

    Returns:
        str : Transcription of the audio.
    """
    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    try:
        with open(audio_path, "rb") as audio:
            transcript = client.audio.transcriptions.create(
                file=audio,
                model="whisper-1",
                response_format="text",
            )
    except (OSError, openai.OpenAIError) as e:
        # Report the failure to the agent as text instead of returning None,
        # which would violate the declared `str` return type.
        print(f"Error transcribing audio: {e}")
        return f"Error transcribing audio: {e}"
    print(transcript)
    return transcript


@tool
def generate_image(prompt: str, neg_prompt: str) -> Image.Image:
    """
    Generate an image based on a text prompt using Flux Dev.

    Args:
        prompt: The text prompt to generate the image from.
        neg_prompt: The negative prompt to avoid certain elements in the image.

    Returns:
        Image.Image: The generated image as a PIL Image object.
    """
    # Nebius exposes an OpenAI-compatible API, so the OpenAI client is reused
    # with a different base URL. Only `openai` (the module) is imported.
    client = openai.OpenAI(
        base_url="https://api.studio.nebius.com/v1",
        api_key=os.environ.get("NEBIUS_API_KEY"),
    )
    completion = client.images.generate(
        model="black-forest-labs/flux-dev",
        prompt=prompt,
        response_format="b64_json",
        extra_body={
            "response_extension": "png",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 30,
            "seed": -1,
            "negative_prompt": neg_prompt,
        },
    )
    image_data = base64.b64decode(completion.to_dict()["data"][0]["b64_json"])
    return Image.open(BytesIO(image_data))


## agent definition

class Agent:
    """Callable wrapper around a smolagents CodeAgent backed by Gemma 3 on Nebius."""

    def __init__(self):
        model = HfApiModel(
            "google/gemma-3-27b-it",
            provider="nebius",
            api_key=os.getenv("NEBIUS_API_KEY"),
        )
        self.agent = CodeAgent(
            model=model,
            tools=[
                DuckDuckGoSearchTool(max_results=5),
                VisitWebpageTool(max_output_length=20000),
                generate_image,
                download_images,
                transcribe_audio,
            ],
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=1,
            max_steps=5,
        )

    def __call__(
        self,
        message: str,
        images: Optional[list[Image.Image]] = None,
        files: Optional[dict] = None,
    ) -> str:
        """Run the agent on *message*.

        Args:
            message: the user's text request.
            images: PIL images to attach to the task, if any.
            files: the dict produced by `load_file` for non-image uploads; it is
                exposed to the agent's code as the `files` variable.

        Returns:
            Whatever the agent produced as its final answer.
            NOTE(review): may not be a str if the agent answers with an image.
        """
        return self.agent.run(message, images=images, additional_args={"files": files})


## gradio functions

def respond(message: dict, history: list):
    """Gradio ChatInterface callback: forward the text and optional upload to the agent.

    Args:
        message: multimodal message from gradio with "text" and "files" keys.
        history: previous chat messages (unused).

    Returns:
        The agent's answer for this turn.
    """
    text = message.get("text", "")
    files = message.get("files")
    if not files:
        print("No files received.")
        return agent(text)

    print(f"files received: {files}")
    # Assuming only one file is uploaded at a time (gradio default behavior).
    loaded = load_file(files[0])
    # Route on what load_file produced so image detection lives in one place:
    # a list means decoded images, a dict means a generic file bundle.
    if isinstance(loaded, list):
        return agent(text, images=loaded)
    return agent(text, files=loaded)


def initialize_agent():
    """Create the shared Agent instance and log that it is ready."""
    agent = Agent()
    print("Agent initialized.")
    return agent


## gradio interface

# Module-level singleton read by `respond`.
agent = initialize_agent()

with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=respond,
        type="messages",
        multimodal=True,
        title="MultiAgent System for Screenplay Creation and Editing",
        show_progress="full",
    )

if __name__ == "__main__":
    demo.launch()