|
import io
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    BartTokenizer,
    GenerationConfig,
    AutoTokenizer,
)
|
|
|
|
|
# Fine-tuned BART summarization checkpoint used to produce the article summary.
model_name = "lacos03/bart-base-finetuned-xsum"

# BartTokenizer is already the slow (pure-Python) tokenizer class, so the
# `use_fast` switch understood by AutoTokenizer has nothing to select here.
tokenizer = BartTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Start from the checkpoint's own generation settings and only force early
# stopping on top of them (per the transformers docs this affects beam search).
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.early_stopping = True

summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config,
    # pipeline() takes a CUDA ordinal or -1 for CPU. Use the GPU when one is
    # available instead of always summarizing on the CPU.
    device=0 if torch.cuda.is_available() else -1,
)
|
|
|
|
|
# Causal LM used by summarize_and_generate_image to expand a short title into a
# text-to-image prompt. Place it on the same device the image pipeline uses
# (CUDA when available); otherwise it stays on CPU while its inputs may be sent
# to the GPU.
promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")
|
|
|
|
|
# Stable Diffusion 1.5 base weights. The original "runwayml/stable-diffusion-v1-5"
# repository was taken down from the Hugging Face Hub; this is the maintained
# mirror of the same checkpoint.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Shared device choice for the diffusion pipeline (also read by
# summarize_and_generate_image).
device = "cuda" if torch.cuda.is_available() else "cpu"

image_generator = StableDiffusionPipeline.from_pretrained(
    model_id,
    # Half precision only on CUDA; keep float32 on CPU, where PyTorch's
    # half-precision support is limited.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
).to(device)

# Style LoRA applied on top of the base SD 1.5 weights.
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)
|
|
|
def summarize_and_generate_image(article_text):
    """Summarize an article, turn its title into an image prompt, and render it.

    Steps:
      1. Summarize ``article_text`` with the module-level ``summarizer``.
      2. Take the first sentence of the summary as the title.
      3. Expand the title into a text-to-image prompt with Promptist.
      4. Render that prompt with ``image_generator``.

    Args:
        article_text: Raw article text from the UI textbox.

    Returns:
        A 4-tuple ``(title, generated_prompt, image, image_path)``:
        ``image`` is the first image from the diffusion pipeline.
        ``image_path`` is a PNG copy written to a temporary file, because a
        ``gr.File`` output needs a filesystem path rather than an in-memory
        buffer.

    Raises:
        gr.Error: If the article is empty or whitespace only.
    """
    if not article_text or not article_text.strip():
        raise gr.Error("Please paste an article before generating.")

    summary = summarizer(
        article_text, max_length=100, min_length=30, do_sample=False
    )[0]["summary_text"]

    # First sentence becomes the title. Fall back to the whole summary if that
    # split is empty (e.g. the summary starts with a period).
    title = summary.split(".")[0].strip() or summary.strip()

    # Promptist's published usage feeds "<text> Rephrase:" and reads the
    # continuation as the optimized prompt.
    promptist_input = f"{title} Rephrase:"
    # Send the inputs to wherever the Promptist weights actually live, so this
    # works whether or not the model was moved to the GPU.
    inputs = promptist_tokenizer(promptist_input, return_tensors="pt").to(
        promptist_model.device
    )
    eos_id = promptist_tokenizer.eos_token_id
    prompt_output = promptist_model.generate(
        **inputs,
        # Bound only the generated part; max_length also counted the input,
        # which could leave no room for a prompt after a long title.
        max_new_tokens=75,
        num_return_sequences=1,
        eos_token_id=eos_id,
        # The tokenizer defines no pad token; reuse EOS, as Promptist's demo does.
        pad_token_id=eos_id,
    )
    # A causal LM's generate() output starts with the input ids; decode only
    # the newly generated tokens so the "Rephrase:" scaffold is not echoed.
    new_tokens = prompt_output[0][inputs["input_ids"].shape[-1]:]
    generated_prompt = promptist_tokenizer.decode(
        new_tokens, skip_special_tokens=True
    ).strip()
    if not generated_prompt:
        # Degenerate generation: still give the image model something to draw.
        generated_prompt = title

    image = image_generator(
        generated_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images[0]

    # delete=False so the file outlives this function and Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
        image_path = tmp.name

    return title, generated_prompt, image, image_path
|
|
|
|
|
def create_app():
    """Build the Gradio Blocks UI for the summarize-and-illustrate demo.

    Wires the article textbox to ``summarize_and_generate_image``. Adds a
    "Copy Title" button that writes the generated title to the browser
    clipboard.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Article Summarizer and Image Generator")
        gr.Markdown("Enter an article, get a summarized title, and generate an illustrative image.")

        article_input = gr.Textbox(label="Input Article", lines=10, placeholder="Paste your article here...")

        title_output = gr.Textbox(label="Generated Title")
        prompt_output = gr.Textbox(label="Generated Prompt")
        image_output = gr.Image(label="Generated Image")

        submit_button = gr.Button("Generate")
        copy_button = gr.Button("Copy Title")
        # Receives the PNG path returned as the 4th output of the handler.
        download_button = gr.File(label="Download Image")

        submit_button.click(
            fn=summarize_and_generate_image,
            inputs=article_input,
            outputs=[title_output, prompt_output, image_output, download_button],
        )

        # Copying is purely client-side, so no Python callback or output
        # component is needed. The JS returns [] to match the zero outputs.
        copy_js = "(title) => { navigator.clipboard.writeText(title); return []; }"
        try:
            # Gradio 4+ names the client-side hook `js`.
            copy_button.click(fn=None, inputs=title_output, outputs=None, js=copy_js)
        except TypeError:
            # Gradio 3.x only accepts the older `_js` keyword.
            copy_button.click(fn=None, inputs=title_output, outputs=None, _js=copy_js)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    app = create_app()
    # A 50-step diffusion run can take minutes (especially on CPU). Route
    # requests through Gradio's queue so they are not cut off by plain request
    # timeouts. Gradio 4+ queues by default, so this is harmless there.
    app.queue().launch()