import gradio as gr
import os
import base64
import pandas as pd
from PIL import Image
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, OpenAIServerModel, tool
from typing import Optional
import requests
from io import BytesIO
import re
from pathlib import Path
import openai
from openai import OpenAI
import pdfplumber


## utility functions

# Extensions the UI treats as images (see is_image_extension).
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp', '.svg'}
# Subset that load_file decodes with Pillow; SVG is a vector format Pillow has no decoder for.
_PIL_IMAGE_EXTS = _IMAGE_EXTS - {'.svg'}


def is_image_extension(filename: str) -> bool:
    """Return True if *filename* ends in one of the extensions in ``_IMAGE_EXTS``.

    The comparison is case-insensitive. ``os.path.splitext`` returns
    ``(root, ext)`` where ``ext`` is the last ``.xxx`` part, or ``""``.
    """
    ext = os.path.splitext(filename)[1].lower()
    return ext in _IMAGE_EXTS


def load_file(path: str) -> list | dict:
    """Load *path* into an object the agent can consume, chosen by file extension.

    Returns:
        - ``[PIL.Image.Image]`` (converted to RGB) for Pillow-readable images.
        - Otherwise a dict whose ``"raw document text"`` is a DataFrame for
          .xlsx/.xls/.csv, a str for .pdf/.py/.txt, or None for anything else,
          plus ``"audio path"`` for .mp3/.wav or ``"file path"`` for all other files.
    """
    ext = Path(path).suffix.lower()

    if ext in _PIL_IMAGE_EXTS:
        # Context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy (also normalising palette/alpha modes).
        with Image.open(path) as img:
            return [img.convert("RGB")]

    text = None
    if ext in (".xlsx", ".xls"):
        text = pd.read_excel(path)  # DataFrame
    elif ext == ".csv":
        text = pd.read_csv(path)  # DataFrame
    elif ext == ".pdf":
        with pdfplumber.open(path) as pdf:
            # Call extract_text() once per page; pages yielding no text are skipped.
            page_texts = (page.extract_text() for page in pdf.pages)
            text = "\n".join(t for t in page_texts if t)
    elif ext in (".py", ".txt"):
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()  # plain text str

    if ext in (".mp3", ".wav"):
        return {"raw document text": text, "audio path": path}
    return {"raw document text": text, "file path": path}


def _flatten(items: list):
    """Yield the non-list leaves of an arbitrarily nested list, in order."""
    for item in items:
        if isinstance(item, list):
            yield from _flatten(item)
        else:
            yield item


def check_format(answer: str | list, *args, **kwargs) -> list | None:
    """Final-answer check registered on the CodeAgent: the answer must be a list.

    Extra positional/keyword arguments passed by the caller are accepted and ignored.

    Returns:
        The flattened answer when *answer* is a list (an empty list stays empty,
        i.e. falsy); ``None`` for any other non-dict value.
        NOTE(review): assumes smolagents treats a falsy return or an exception
        from a final-answer check as a failed check — confirm for the installed version.

    Raises:
        TypeError: if *answer* is a dict.
    """
    print("Checking format of the answer:", answer)
    # Dicts get an explicit, descriptive rejection rather than just a falsy result.
    if isinstance(answer, dict):
        raise TypeError("Final answer must be a list, not a dict. Please check the answer format.")
    if not isinstance(answer, list):
        return None
    return list(_flatten(answer))


## tools definition
# NOTE: the docstrings of @tool functions are parsed by smolagents into the tool
# description/argument schema shown to the model, so they are kept verbatim.

@tool
def download_images(image_urls: str) -> list:
    """
    Download web images from the given comma‐separated URLs and return them in a list of PIL Images.
    Args:
        image_urls: comma‐separated list of URLs to download
    Returns:
        List of PIL.Image.Image objects
    """
    urls = [u.strip() for u in image_urls.split(",") if u.strip()]  # drop blanks / surrounding whitespace
    images = []
    for url in urls:
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            img = Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception as e:
            # Best-effort: a failing URL is reported and skipped; the others are still returned.
            print(f"Failed to download from {url}: {e}")
            continue
        images.append(img)
    return images


@tool  # since they gave us OpenAI API credits, we can keep using it
def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe audio file using OpenAI Whisper API.
    Args:
        audio_path: path to the audio file to be transcribed.
    Returns:
        str : Transcription of the audio.
    """
    client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
    try:
        # TODO: path may need adjusting because it arrives from gradio
        with open(audio_path, "rb") as audio:
            transcript = client.audio.transcriptions.create(
                file=audio,
                model="whisper-1",
                response_format="text",
            )
    except Exception as e:
        # Log for the console, then let the agent see the failure.
        print(f"Error transcribing audio: {e}")
        raise
    print(transcript)
    return transcript


@tool
def generate_image(prompt: str, neg_prompt: str) -> Image.Image:
    """
    Generate an image based on a text prompt using Flux Dev.
    Args:
        prompt: The text prompt to generate the image from.
        neg_prompt: The negative prompt to avoid certain elements in the image.
    Returns:
        Image.Image: The generated image as a PIL Image object.
    """
    client = OpenAI(
        base_url="https://api.studio.nebius.com/v1",
        api_key=os.environ.get("NEBIUS_API_KEY"),
    )
    completion = client.images.generate(
        model="black-forest-labs/flux-dev",
        prompt=prompt,
        response_format="b64_json",
        extra_body={
            "response_extension": "png",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 30,
            "seed": -1,
            "negative_prompt": neg_prompt,
        },
    )
    image_data = base64.b64decode(completion.to_dict()['data'][0]['b64_json'])
    image = Image.open(BytesIO(image_data)).convert("RGB")
    # NOTE(review): the annotation/docstring promise a PIL image, but a gr.Image
    # component is returned (presumably so the chat UI can render it) — reconcile.
    return gr.Image(value=image, label="Generated Image")


## agent definition

class Agent:
    """Callable wrapper around a smolagents CodeAgent configured for this app."""

    def __init__(self):
        model = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
        self.agent = CodeAgent(
            model=model,
            tools=[
                DuckDuckGoSearchTool(max_results=5),
                VisitWebpageTool(max_output_length=20000),
                generate_image,
                download_images,
                transcribe_audio,
            ],
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=1,
            max_steps=5,
            stream_outputs=False,
            final_answer_checks=[check_format],
        )
        # The system prompt is loaded from a file in the working directory.
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            self.agent.prompt_templates["system_prompt"] = f.read()

    def __call__(self, message: str, images: Optional[list[Image.Image]] = None,
                 files: Optional[dict] = None, conversation_history: Optional[list] = None) -> str:
        """Run the agent on *message*.

        Args:
            message: user prompt.
            images: PIL images passed to the agent run.
            files: loaded non-image file (the dict produced by load_file), exposed
                to the agent's code as ``files``.
            conversation_history: prior chat turns, exposed as ``conversation_history``.

        Returns:
            Whatever ``CodeAgent.run`` returns as the final answer.
        """
        return self.agent.run(
            message,
            images=images,
            additional_args={"files": files, "conversation_history": conversation_history},
        )


## gradio functions

def respond(message: dict, history: list, web_search: bool = False):
    """Gradio ChatInterface callback: forward one multimodal user turn to the agent.

    Args:
        message: multimodal chat message with a ``"text"`` entry and an optional
            ``"files"`` list (only the first file is used).
        history: prior conversation, forwarded to the agent unchanged.
        web_search: value of the "Web Search" checkbox; when falsy, the prompt
            instructs the agent not to use web search (with or without files).

    Returns:
        The agent's answer.
    """
    print("history:", history)
    text = message.get("text", "")
    files = message.get("files") or []

    if not web_search:
        text = text + "\nADDITIONAL CONTRAINT: Don't use web search"

    if not files:
        print("No files received + web search enabled." if web_search else "No files received.")
        answer = agent(text, conversation_history=history)
    else:
        print(f"files received: {files}")
        # Only the first uploaded file is used.
        loaded = load_file(files[0])
        # Dispatch on what load_file produced: a list means decoded images.
        if isinstance(loaded, list):
            answer = agent(text, images=loaded, conversation_history=history)
        else:
            answer = agent(text, files=loaded, conversation_history=history)

    print("Agent response:", answer)
    return answer


def initialize_agent():
    """Construct the app's Agent and log that it is ready."""
    agent = Agent()
    print("Agent initialized.")
    return agent


## gradio interface

# Module-level agent used by respond(); built once when the module is loaded.
agent = initialize_agent()

with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=respond,
        type='messages',
        multimodal=True,
        title='MultiAgent System for Screenplay Creation and Editing',
        show_progress='full',
        fill_height=True,
        fill_width=False,
        save_history=True,
        additional_inputs=[
            gr.Checkbox(value=False, label="Web Search",
                        info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
                        render=False),
        ],
    )

if __name__ == "__main__":
    demo.launch()