File size: 3,958 Bytes
7742b37 7ee0475 7742b37 016478f 1c20d3f 016478f 1c20d3f 016478f 7742b37 016478f 7742b37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BartTokenizer, GenerationConfig, AutoTokenizer
from diffusers import StableDiffusionPipeline
import torch
import io
from PIL import Image
# Load the summarization model (lacos03/bart-base-finetuned-xsum) manually.
model_name = "lacos03/bart-base-finetuned-xsum"
# BartTokenizer is already the slow (pure-Python) BART tokenizer class, so no
# use_fast flag is needed to select it.
tokenizer = BartTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Start from the checkpoint's own generation config and force early_stopping=True;
# per the original author this works around an early_stopping error from the
# stored config.
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.early_stopping = True
# Initialize the summarization pipeline. Run it on the first CUDA device when one
# is available (pipeline convention: 0 = cuda:0, -1 = CPU) so it uses the same
# hardware as the Stable Diffusion pipeline instead of always falling back to CPU.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config,
    device=0 if torch.cuda.is_available() else -1,
)
# Load the prompt generation model (microsoft/Promptist), used below to turn the
# article title into a text prompt for Stable Diffusion.
promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
# Put the model on the GPU when available. The generation code moves its input
# tensors to a CUDA/CPU device, and model weights and inputs must share a device.
promptist_model = promptist_model.to("cuda" if torch.cuda.is_available() else "cpu")
promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")
# Base Stable Diffusion checkpoint that the LoRA weights below are layered onto.
# NOTE(review): the runwayml/stable-diffusion-v1-5 repo has reportedly been taken
# down from the Hugging Face Hub (a mirror exists at
# stable-diffusion-v1-5/stable-diffusion-v1-5) -- confirm this id still resolves.
model_id = "runwayml/stable-diffusion-v1-5"
# Prefer the GPU when CUDA is present, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on CUDA, full precision otherwise; use_safetensors=True requests
# .safetensors weight files.
image_generator = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_safetensors=True,
    torch_dtype=(torch.float16 if device == "cuda" else torch.float32),
)
image_generator = image_generator.to(device)
# Apply the LoRA adapter weights on top of the base pipeline.
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)
def _extract_title(summary):
    """Return the first sentence of *summary*, or the whole summary if that is empty."""
    first_sentence = summary.split(".")[0].strip()
    # A summary that starts with "." would otherwise yield an empty title,
    # which gives Promptist nothing to generate from.
    return first_sentence or summary.strip()


def _generate_prompt(title):
    """Rewrite *title* into an image prompt with the Promptist model.

    The input is suffixed with " Rephrase:" as in the Promptist model card, and
    only the newly generated tokens are decoded. A causal LM's output otherwise
    starts with an echo of the input text.
    """
    # Send inputs to wherever the model actually lives instead of assuming it
    # matches the module-level `device`.
    inputs = promptist_tokenizer(f"{title} Rephrase:", return_tensors="pt").to(promptist_model.device)
    prompt_ids = promptist_model.generate(
        **inputs,
        # Budget counts generated tokens only, so a long title cannot use it up.
        max_new_tokens=50,
        num_return_sequences=1,
        # The tokenizer may define no pad token; reuse EOS to keep generate() quiet.
        pad_token_id=promptist_tokenizer.eos_token_id,
    )
    input_len = inputs["input_ids"].shape[-1]
    generated = promptist_tokenizer.decode(
        prompt_ids[0][input_len:], skip_special_tokens=True
    ).strip()
    # Fall back to the title itself if Promptist produced nothing usable.
    return generated or title


def _save_png(image):
    """Write *image* to a new temporary PNG file and return the file's path."""
    import tempfile  # local import: only this helper needs it

    # gr.File outputs expect a file path, not an in-memory buffer. delete=False
    # keeps the file after closing so Gradio can serve it. Files are not
    # cleaned up here.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name


def summarize_and_generate_image(article_text):
    """Summarize an article, derive a title from it, and illustrate the title.

    Args:
        article_text: The article text entered by the user.

    Returns:
        A 4-tuple ``(title, generated_prompt, image, png_path)``, in the order of
        the Gradio outputs:
          - the title (first sentence of the summary);
          - the Promptist prompt;
          - the first image returned by the Stable Diffusion pipeline;
          - the path of a PNG copy of that image, for the download component.

    Raises:
        gr.Error: If the article text is empty or whitespace-only.
    """
    if not article_text or not article_text.strip():
        raise gr.Error("Please enter an article to summarize.")
    # Step 1: Summarize the article and extract a title. truncation=True clips
    # over-long articles to the model's maximum input length instead of failing.
    summary = summarizer(
        article_text, max_length=100, min_length=30, do_sample=False, truncation=True
    )[0]["summary_text"]
    title = _extract_title(summary)
    # Step 2: Generate a prompt from the title using Promptist.
    generated_prompt = _generate_prompt(title)
    # Step 3: Generate an image from the prompt.
    image = image_generator(
        generated_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images[0]
    # Step 4: Save the image to a file so the gr.File component can offer it.
    return title, generated_prompt, image, _save_png(image)
# Define the Gradio interface
def create_app():
    """Build the Gradio Blocks UI for the summarize-and-illustrate workflow.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve it.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Article Summarizer and Image Generator")
        gr.Markdown("Enter an article, get a summarized title, and generate an illustrative image.")
        # Input for article text
        article_input = gr.Textbox(label="Input Article", lines=10, placeholder="Paste your article here...")
        # Output components
        title_output = gr.Textbox(label="Generated Title")
        prompt_output = gr.Textbox(label="Generated Prompt")
        image_output = gr.Image(label="Generated Image")
        # Buttons
        submit_button = gr.Button("Generate")
        copy_button = gr.Button("Copy Title")
        download_button = gr.File(label="Download Image")
        # Define actions
        submit_button.click(
            fn=summarize_and_generate_image,
            inputs=article_input,
            outputs=[title_output, prompt_output, image_output, download_button],
        )
        # Copying happens entirely in the browser. With fn=None only the JS runs,
        # and its return value (the unchanged title) is written back to
        # title_output. There is no server round trip and no bogus State output.
        # NOTE(review): `_js` is the Gradio 3.x keyword; Gradio 4.x renamed it
        # to `js` -- confirm against the installed Gradio version.
        copy_button.click(
            fn=None,
            inputs=title_output,
            outputs=title_output,
            _js="function(x) { navigator.clipboard.writeText(x); return x; }",
        )
    return demo
# Launch the app
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    app = create_app()
    app.launch()