import tempfile

import gradio as gr
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BartTokenizer, GenerationConfig, AutoTokenizer
from diffusers import StableDiffusionPipeline
# Load the summarization model (lacos03/bart-base-finetuned-xsum) manually
model_name = "lacos03/bart-base-finetuned-xsum"
tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False) # Use slow tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Build a GenerationConfig explicitly so early_stopping is a proper bool
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.early_stopping = True  # Avoids the early_stopping validation error
# Initialize the summarization pipeline
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config
)
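
# Example usage (illustrative only; `long_article` stands in for real input text):
#   summarizer(long_article, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]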
# Load the prompt generation model (microsoft/Promptist)
promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")
# Load the base Stable Diffusion model
model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"
promptist_model.to(device)  # Keep Promptist on the same device as its inputs
# Initialize the Stable Diffusion pipeline
image_generator = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True
).to(device)
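# Optional: on low-VRAM GPUs, attention slicing trades some speed for memory
# image_generator.enable_attention_slicing()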
# Load LoRA weights
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)
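# The LoRA influence can be tuned per call on recent diffusers versions, e.g.:
#   image_generator(prompt, cross_attention_kwargs={"scale": 0.8})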
def summarize_and_generate_image(article_text):
    # Step 1: Summarize the article and extract a title
    summary = summarizer(article_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
    title = summary.split(".")[0]  # Use the first sentence as the title

    # Step 2: Generate a prompt from the title using Promptist
    inputs = promptist_tokenizer(title, return_tensors="pt").to(device)
    prompt_output = promptist_model.generate(**inputs, max_length=50, num_return_sequences=1)
    generated_prompt = promptist_tokenizer.decode(prompt_output[0], skip_special_tokens=True)

    # Step 3: Generate an image from the prompt
    image = image_generator(
        generated_prompt,
        num_inference_steps=50,
        guidance_scale=7.5
    ).images[0]

    # Step 4: Write the image to a temporary PNG file (gr.File expects a filepath, not a BytesIO)
    tmp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image.save(tmp_file, format="PNG")
    tmp_file.close()
    return title, generated_prompt, image, tmp_file.name
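
# Example (assuming `article` holds a few paragraphs of text):
#   title, prompt, image, png_path = summarize_and_generate_image(article)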
# Define the Gradio interface
def create_app():
    with gr.Blocks() as demo:
        gr.Markdown("# Article Summarizer and Image Generator")
        gr.Markdown("Enter an article, get a summarized title, and generate an illustrative image.")

        # Input for article text
        article_input = gr.Textbox(label="Input Article", lines=10, placeholder="Paste your article here...")

        # Output components
        title_output = gr.Textbox(label="Generated Title")
        prompt_output = gr.Textbox(label="Generated Prompt")
        image_output = gr.Image(label="Generated Image")

        # Buttons
        submit_button = gr.Button("Generate")
        copy_button = gr.Button("Copy Title")
        download_button = gr.File(label="Download Image")

        # Define actions
        submit_button.click(
            fn=summarize_and_generate_image,
            inputs=article_input,
            outputs=[title_output, prompt_output, image_output, download_button]
        )
        # Copy-to-clipboard runs entirely in the browser, so no Python fn is needed
        # (on Gradio 3.x the keyword is `_js` instead of `js`)
        copy_button.click(
            fn=None,
            inputs=title_output,
            outputs=None,
            js="(text) => { navigator.clipboard.writeText(text); }"
        )
    return demo
# Launch the app
if __name__ == "__main__":
    app = create_app()
    app.launch()