File size: 3,958 Bytes
7742b37
7ee0475
7742b37
 
 
 
 
016478f
 
1c20d3f
016478f
 
 
 
1c20d3f
016478f
 
 
 
 
 
 
 
7742b37
 
 
 
 
 
 
 
 
 
 
 
 
016478f
7742b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import io
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSeq2SeqLM, BartTokenizer, GenerationConfig, AutoTokenizer

# --- Model setup (runs at import time; downloads weights on first use) ---

# Summarization model (BART fine-tuned on XSum), loaded manually so a patched
# GenerationConfig can be attached below.
model_name = "lacos03/bart-base-finetuned-xsum"
tokenizer = BartTokenizer.from_pretrained(model_name, use_fast=False)  # slow tokenizer avoids fast-tokenizer conversion issues
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# The checkpoint's generation config trips validation around early_stopping;
# force it on to silence the error (comment inherited from original intent).
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.early_stopping = True  # Fix early_stopping error

# Summarization pipeline reusing the patched generation config.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config
)

# Pick the device ONCE, before any model that must live on it is loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt-expansion model (microsoft/Promptist).
# BUG FIX: the model must be moved to `device` — summarize_and_generate_image()
# sends its input tensors to `device`, and generate() raises a device-mismatch
# RuntimeError on GPU machines if the model stays on CPU.
promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist").to(device)
promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")

# Base Stable Diffusion pipeline: fp16 on GPU, fp32 on CPU.
model_id = "runwayml/stable-diffusion-v1-5"
image_generator = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True
).to(device)

# Apply LoRA weights fine-tuned toward a Midjourney-like style.
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)

def summarize_and_generate_image(article_text):
    """Summarize an article, derive a title, and generate an illustration.

    Parameters
    ----------
    article_text : str
        Raw article text pasted by the user.

    Returns
    -------
    tuple
        (title, generated_prompt, image, png_path) where `image` is a
        PIL.Image and `png_path` is a filesystem path to the saved PNG —
        a path (not a BytesIO) so Gradio's File component can serve it.
    """
    # Step 1: Summarize the article; use the first sentence as the title.
    summary = summarizer(article_text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
    title = summary.split(".")[0]

    # Step 2: Expand the title into an image prompt using Promptist.
    inputs = promptist_tokenizer(title, return_tensors="pt").to(device)
    prompt_output = promptist_model.generate(**inputs, max_length=50, num_return_sequences=1)
    generated_prompt = promptist_tokenizer.decode(prompt_output[0], skip_special_tokens=True)

    # Step 3: Generate an image from the prompt.
    image = image_generator(
        generated_prompt,
        num_inference_steps=50,
        guidance_scale=7.5
    ).images[0]

    # Step 4: BUG FIX — gr.File serves filesystem paths, not in-memory
    # BytesIO objects; write the PNG to a named temp file and return its path.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        image.save(tmp.name, format="PNG")
    finally:
        tmp.close()

    return title, generated_prompt, image, tmp.name

# Define the Gradio interface
def create_app():
    """Build and return the Gradio Blocks UI wiring the pipeline together."""
    with gr.Blocks() as demo:
        gr.Markdown("# Article Summarizer and Image Generator")
        gr.Markdown("Enter an article, get a summarized title, and generate an illustrative image.")

        # Input for article text
        article_input = gr.Textbox(label="Input Article", lines=10, placeholder="Paste your article here...")

        # Output components
        title_output = gr.Textbox(label="Generated Title")
        prompt_output = gr.Textbox(label="Generated Prompt")
        image_output = gr.Image(label="Generated Image")

        # Buttons
        submit_button = gr.Button("Generate")
        copy_button = gr.Button("Copy Title")
        download_button = gr.File(label="Download Image")

        # Main action: summarize, prompt, render, and expose the PNG download.
        submit_button.click(
            fn=summarize_and_generate_image,
            inputs=article_input,
            outputs=[title_output, prompt_output, image_output, download_button]
        )
        # BUG FIX: Gradio 4 renamed `_js` to `js`, and a component wrapped in
        # gr.State is not a valid output target. The JS side copies the title
        # to the clipboard; the Python side just echoes it back unchanged.
        copy_button.click(
            fn=lambda x: x,
            inputs=title_output,
            outputs=title_output,
            js="function(x) { navigator.clipboard.writeText(x); return x; }"
        )

    return demo

# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    create_app().launch()