Spaces:
Sleeping
Sleeping
File size: 5,251 Bytes
27f5740 71a1f99 b414b9f 27f5740 b414b9f 27f5740 b414b9f 27f5740 b414b9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
import time
import os
import base64
from io import BytesIO
# --- Hugging Face credentials and model configuration -----------------------

# Read the API token once at import time; may be None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# User-facing error text when no token is configured; None means a token exists.
HF_TOKEN_ERROR = (
    None
    if HF_TOKEN
    else "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
)

# Chat model used by improve_prompt() to rewrite the user's prompt.
PROMPT_IMPROVER_MODEL = "HuggingFaceH4/zephyr-7b-beta"

# Single client shared by the prompt-improver and text-to-image calls.
client = InferenceClient(token=HF_TOKEN)
def improve_prompt(original_prompt):
    """Rewrite *original_prompt* into a more detailed image-generation prompt.

    The prompt is wrapped in a Zephyr-style chat template and sent to
    ``PROMPT_IMPROVER_MODEL`` through the shared ``client``.

    Args:
        original_prompt: The user's prompt text.

    Returns:
        The improved prompt. Falls back to ``original_prompt`` unchanged when
        the input is blank, the model call fails, or the model returns only
        whitespace.

    Raises:
        gr.Error: If no Hugging Face token is configured.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)

    # Nothing to improve: skip the LLM call, which would otherwise invent a prompt.
    if not original_prompt or not original_prompt.strip():
        return original_prompt

    system_prompt = (
        "You are a helpful assistant that improves text prompts for image "
        "generation models. Make the prompt more descriptive, detailed, and "
        "artistic, while keeping the user's original intent."
    )
    # Zephyr chat template: every turn is "<|role|>\n<content></s>\n", and the
    # assistant turn is left open for the model to complete.
    prompt_for_llm = (
        f"<|system|>\n{system_prompt}</s>\n"
        f"<|user|>\nImprove this prompt: {original_prompt}</s>\n"
        "<|assistant|>\n"
    )

    try:
        completion = client.text_generation(
            prompt=prompt_for_llm,
            model=PROMPT_IMPROVER_MODEL,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            stop_sequences=["</s>"],
        )
    except Exception as e:
        # Best-effort step: an API failure here should not block image
        # generation, so fall back to the user's own prompt.
        print(f"Error improving prompt: {e}")
        return original_prompt

    improved_prompt = completion.strip()
    # NOTE(review): some text-generation backends include the matched stop
    # sequence in the returned text; drop it so it never reaches the image model.
    if improved_prompt.endswith("</s>"):
        improved_prompt = improved_prompt[: -len("</s>")].rstrip()
    # An empty completion would send an empty prompt to the image model.
    return improved_prompt or original_prompt
def generate_image(prompt, progress=gr.Progress()):
    """Improve *prompt* with the LLM, then render it with FLUX.1-schnell.

    Args:
        prompt: The user's prompt text.
        progress: Gradio progress tracker. Gradio only injects a live tracker
            into the registered event handler's ``gr.Progress`` parameter, so
            wrappers should declare one and forward it here.

    Returns:
        The generated ``PIL.Image.Image``.

    Raises:
        gr.Error: If no token is configured, the prompt is blank, the image
            request fails (with a dedicated message for rate limiting), or the
            API returns something other than a PIL image.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)

    progress(0.2, desc="Sending request to Hugging Face...")
    try:
        image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
    except Exception as e:
        message = str(e).lower()
        if "rate limit" in message or "too many requests" in message:
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        # Chain the cause so the original traceback survives in server logs.
        raise gr.Error(error_message) from e

    if not isinstance(image, Image.Image):
        raise gr.Error(f"An error occurred: Expected a PIL Image, but got: {type(image)}")

    progress(1.0, desc="Done!")
    return image
def pil_to_base64(img):
    """Serialize *img* as PNG and return it as a ``data:`` URI string.

    Args:
        img: An object exposing ``save(fp, format=...)`` — presumably a
            ``PIL.Image.Image``, per the function name.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    sink = BytesIO()
    img.save(sink, format="PNG")
    png_bytes = sink.getvalue()
    payload = base64.b64encode(png_bytes).decode()
    return "".join(("data:image/png;base64,", payload))
# Custom stylesheet passed to gr.Blocks(css=css). Class selectors here are
# matched against the elem_classes values given to components in the UI:
# .title, .input-section, .output-section and .submit-button are used there.
# NOTE(review): .error-message, .label and .download-link are not referenced
# by any elem_classes in this file's UI code — confirm before removing.
# NOTE(review): .container may also match Gradio's own wrapper element
# depending on the Gradio version — verify the intended target.
css = """
body {
background-color: #f4f4f4;
font-family: 'Arial', sans-serif;
}
.container {
max-width: 900px;
margin: auto;
padding: 30px;
border-radius: 10px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
background-color: white;
}
.title {
text-align: center;
font-size: 3em;
margin-bottom: 0.5em;
color: #3a3a3a;
}
.input-section {
background-color: #e3f7fc;
border-radius: 8px;
padding: 15px;
}
.output-section {
background-color: #f0f0f0;
border-radius: 8px;
padding: 15px;
}
.output-section img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
.submit-button {
background-color: #007BFF;
border: none;
border-radius: 5px;
color: white;
padding: 12px 20px;
cursor: pointer;
transition: background-color 0.3s ease, transform 0.2s ease;
}
.submit-button:hover {
background-color: #0056b3;
transform: scale(1.05);
}
.error-message {
color: red;
text-align: center;
font-weight: bold;
}
.label {
font-weight: bold;
}
.download-link {
color: #007BFF;
font-weight: bold;
text-decoration: none;
}
.download-link:hover {
text-decoration: underline;
}
"""
# Build the UI: prompt input on the left, generated image on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Xylaria Iris Image Generator", elem_classes="title")

    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section") as output_group:
                image_output = gr.Image(label="Generated Image", interactive=False)

    def on_generate_click(prompt, progress=gr.Progress()):
        """Event handler: generate and return an image for *prompt*.

        Gradio attaches a live progress tracker only to a ``gr.Progress``
        parameter declared on the registered event handler itself, so it is
        declared here and forwarded to ``generate_image``.
        """
        return generate_image(prompt, progress=progress)

    # Both the button and submitting the textbox trigger generation.
    generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)

    gr.Examples(
        [["A dog"], ["A house on a hill"], ["A spaceship"]],
        inputs=prompt_input,
    )
# Script entry point: enable Gradio's event queue, then start the server.
if __name__ == "__main__":
    demo.queue().launch()