"""Memero: a Gradio demo that turns an uploaded image into a Spanish meme.

Pipeline: BLIP-2 (via LAVIS) captions the image, GPT-4 writes a short joke
from that caption, and Pillow draws the joke across the top of the image.
"""

import os
import textwrap

from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai

# Run on the GPU when one is visible, otherwise on the CPU. Always a
# torch.device (the original mixed a torch.device with a plain "cpu" str).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BLIP-2 with an OPT-2.7B language model, loaded once at startup in eval mode.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)

# Raises KeyError at startup if the key is not configured.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Instructions that tell GPT-4 how to turn an image caption into meme text.
MEME_SYSTEM_PROMPT = (
    "Escribe un meme chistoso en español a partir de la descripción de una "
    "imagen dada por el usuario. No uses emojis, ni comillas, ni saltos de "
    "línea. No es necesario que empieces con 'cuando'. El output del "
    "asistente solo debe ser el texto del meme. Debe ser corto pero chistoso."
)

# Text rendering settings (font file is looked up relative to the CWD).
FONT_PATH = "impact.ttf"
FONT_SIZE = 60
TOP_PADDING = 10  # pixels between the image's top edge and the first line
LINE_SPACING_RATIO = 0.1  # extra gap between lines, as a fraction of line height
TEXT_COLOR = (255, 255, 255)
# Wide capitals, so the estimated characters-per-line errs on the short side.
WIDTH_SAMPLE_CHARS = "ABCEMOPQRSTWXZ"


def _caption_image(image: Image.Image) -> str:
    """Return the BLIP-2 caption(s) for *image*, joined with newlines."""
    tensor = vis_processors["eval"](image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only; don't keep autograd state
        captions = model.generate({"image": tensor})
    return "\n".join(captions)


def _write_meme_text(caption: str) -> str:
    """Ask GPT-4 for a short Spanish meme line based on *caption*."""
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": MEME_SYSTEM_PROMPT},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    # Guard against a missing/None content field so wrapping never crashes.
    return (response.choices[0].message.content or "").strip()


def _draw_meme_text(pil_image: Image.Image, text: str) -> None:
    """Draw *text* wrapped and horizontally centred along the top of *pil_image*.

    The image is modified in place.
    """
    draw = ImageDraw.Draw(pil_image)
    font = ImageFont.truetype(FONT_PATH, FONT_SIZE)

    average_char_width = sum(
        draw.textlength(ch, font=font) for ch in WIDTH_SAMPLE_CHARS
    ) / len(WIDTH_SAMPLE_CHARS)
    # textwrap rejects width < 1, which a very narrow image would otherwise give.
    chars_per_line = max(1, int(pil_image.width / average_char_width))

    # Use the font's fixed ascent + descent so every line advances by the same
    # amount, instead of each line's ink height (which varies with the glyphs).
    ascent, descent = font.getmetrics()
    line_height = ascent + descent
    line_step = line_height + int(line_height * LINE_SPACING_RATIO)

    y = TOP_PADDING
    for line in textwrap.wrap(text, width=chars_per_line):
        line_width = draw.textlength(line, font=font)
        x = (pil_image.width - line_width) // 2
        draw.text((x, y), line, fill=TEXT_COLOR, font=font)
        y += line_step


def generate_caption(image):
    """Turn an uploaded PIL image into a meme image.

    Captions the image with BLIP-2, asks GPT-4 for a Spanish joke about that
    caption, and writes the joke across the top of a copy of the image.

    Args:
        image: The uploaded ``PIL.Image.Image`` (or ``None`` if nothing was
            uploaded).

    Returns:
        A new RGB ``PIL.Image.Image`` with the meme text drawn on it.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Sube una imagen primero.")

    # convert() returns a new image, so the caller's upload is untouched. RGB
    # also gives the BLIP processor 3 channels and makes the RGB fill valid.
    pil_image = image.convert("RGB")

    caption = _caption_image(pil_image)
    meme_text = _write_meme_text(caption)
    print(meme_text)

    _draw_meme_text(pil_image, meme_text)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pil_image


with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Genera un meme en español a partir de una imagen.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
        # gr.Image has no `lines` option (that belongs to gr.Textbox).
        output_image = gr.Image(label="Meme")
    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()