File size: 3,887 Bytes
7742b37
016478f
7742b37
 
 
 
 
016478f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7742b37
 
 
 
 
 
 
 
 
 
 
 
 
016478f
7742b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import io
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GenerationConfig,
    pipeline,
)

# ---------------------------------------------------------------------------
# Model setup (runs once at import time; downloads weights on first use).
# ---------------------------------------------------------------------------

# Summarization model. The model and tokenizer are loaded explicitly (instead
# of passing the name to `pipeline`) so the generation config can be
# overridden before the pipeline is built.
model_name = "lacos03/bart-base-finetuned-xsum"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Start from the checkpoint's own generation settings and force
# `early_stopping` to True, which works around an error raised with the
# value shipped in the checkpoint.
generation_config = GenerationConfig.from_pretrained(model_name)
generation_config.early_stopping = True

summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config,
)

# Device shared by the prompt model and the image generator. It is chosen
# before either is loaded so both end up on the same device as the tensors
# that are later sent to them.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt generation model (microsoft/Promptist). It is moved to `device`
# because its tokenized inputs are moved there as well; leaving the model on
# the CPU would raise a device-mismatch error on GPU machines.
promptist_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist").to(device)
promptist_tokenizer = AutoTokenizer.from_pretrained("microsoft/Promptist")

# Base Stable Diffusion checkpoint. Half precision is used only on CUDA;
# float32 is used on CPU.
model_id = "runwayml/stable-diffusion-v1-5"
image_generator = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
).to(device)

# Apply the LoRA weights on top of the base pipeline.
lora_weights = "lacos03/std-1.5-lora-midjourney-1.0"
image_generator.load_lora_weights(lora_weights)

def summarize_and_generate_image(article_text):
    """Summarize an article, turn the title into an image prompt, and render it.

    Args:
        article_text: Article text entered by the user.

    Returns:
        A 4-tuple ``(title, generated_prompt, image, image_path)``:
        ``title`` is the first sentence of the summary, ``generated_prompt``
        is the text produced by the Promptist model (or ``title`` if it
        produced nothing), ``image`` is the first image returned by the
        Stable Diffusion pipeline, and ``image_path`` is the path of a PNG
        copy of that image on disk, suitable for a ``gr.File`` output.

    Raises:
        gr.Error: If ``article_text`` is empty or only whitespace.
    """
    if not article_text or not article_text.strip():
        raise gr.Error("Please paste an article before generating.")

    # Step 1: Summarize the article and use the first sentence as the title.
    # `truncation=True` keeps over-long articles within the model's maximum
    # input length instead of failing inside the model.
    summary = summarizer(
        article_text,
        max_length=100,
        min_length=30,
        do_sample=False,
        truncation=True,
    )[0]["summary_text"]
    title = summary.split(".")[0].strip() or summary.strip()

    # Step 2: Generate a prompt from the title using Promptist.
    # NOTE(review): Promptist's published usage example appends " Rephrase:"
    # to the plain prompt and strips it from the output -- confirm against
    # the upstream README.
    rephrase_input = f"{title} Rephrase:"
    # Send the inputs to wherever the model actually lives, so this works
    # whether or not the model was moved off the CPU.
    inputs = promptist_tokenizer(rephrase_input, return_tensors="pt").to(
        promptist_model.device
    )
    input_length = inputs["input_ids"].shape[1]
    prompt_output = promptist_model.generate(
        **inputs,
        # `max_new_tokens` bounds only the continuation; `max_length` would
        # also count the input tokens and leave no budget for long titles.
        max_new_tokens=50,
        num_return_sequences=1,
        pad_token_id=promptist_tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the input text is not echoed.
    generated_prompt = promptist_tokenizer.decode(
        prompt_output[0][input_length:], skip_special_tokens=True
    ).strip()
    if not generated_prompt:
        generated_prompt = title

    # Step 3: Generate an image from the prompt.
    image = image_generator(
        generated_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images[0]

    # Step 4: Save the image to disk for download. A file path is returned
    # rather than an in-memory buffer because the download widget is a
    # `gr.File`. `delete=False` keeps the file after the handle is closed so
    # it can still be read afterwards; the file is not cleaned up here.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        image.save(tmp_file, format="PNG")
        image_path = tmp_file.name

    return title, generated_prompt, image, image_path

# Define the Gradio interface
def create_app():
    """Build the Gradio Blocks UI for the summarize-and-illustrate workflow.

    The layout has one article input, three outputs (title, prompt, image),
    a file output for downloading the image, a "Generate" button wired to
    ``summarize_and_generate_image``, and a "Copy Title" button that copies
    the title to the clipboard in the browser.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    # Gradio 4 renamed the event-listener keyword `_js` to `js`; choose the
    # spelling that matches the installed major version.
    js_kwarg = "js" if int(gr.__version__.split(".")[0]) >= 4 else "_js"
    copy_title_js = "function(x) { navigator.clipboard.writeText(x); return x; }"

    with gr.Blocks() as demo:
        gr.Markdown("# Article Summarizer and Image Generator")
        gr.Markdown("Enter an article, get a summarized title, and generate an illustrative image.")

        # Input for article text
        article_input = gr.Textbox(label="Input Article", lines=10, placeholder="Paste your article here...")

        # Output components
        title_output = gr.Textbox(label="Generated Title")
        prompt_output = gr.Textbox(label="Generated Prompt")
        image_output = gr.Image(label="Generated Image")

        # Buttons
        submit_button = gr.Button("Generate")
        copy_button = gr.Button("Copy Title")
        download_button = gr.File(label="Download Image")

        # Define actions
        submit_button.click(
            fn=summarize_and_generate_image,
            inputs=article_input,
            outputs=[title_output, prompt_output, image_output, download_button],
        )
        # Copying happens entirely in the browser, so no Python callback and
        # no output component are needed; the previous throwaway
        # `gr.State(value=title_output)` output is dropped.
        copy_button.click(
            fn=None,
            inputs=title_output,
            outputs=None,
            **{js_kwarg: copy_title_js},
        )

    return demo

# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    create_app().launch()