|
import streamlit as st |
|
import requests |
|
import io |
|
import random |
|
from PIL import Image |
|
import os |
|
import json |
|
import time |
|
|
|
|
|
# Uptime guard: the app terminates its own process once it has been up for
# more than MAX_UPTIME_SECONDS. The first run records the start time in
# ``start_time_file``; later executions of this script compare against it.
start_time_file = "start_time.txt"
MAX_UPTIME_SECONDS = 48 * 3600


def _record_start_time():
    """Write the current time to ``start_time_file``, replacing any old value."""
    with open(start_time_file, "w", encoding="utf-8") as f:
        f.write(str(time.time()))


if not os.path.exists(start_time_file):
    _record_start_time()
else:
    try:
        with open(start_time_file, "r", encoding="utf-8") as f:
            start_time = float(f.read())
    except (OSError, ValueError):
        # Unreadable or corrupt timestamp: start the uptime clock over.
        _record_start_time()
    else:
        if time.time() - start_time > MAX_UPTIME_SECONDS:
            st.write("Restarting the space after 48 hours of uptime...")
            # Drop the stale timestamp before exiting. Otherwise, if the file
            # survives the restart, every new launch would read the old time
            # and exit again immediately (restart loop).
            try:
                os.remove(start_time_file)
            except OSError:
                # Best effort only; the exit below still happens.
                pass
            os._exit(0)
|
|
|
|
|
# Hugging Face Inference API endpoint for the FLUX.1-dev text-to-image model.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"

# The access token is read from the ``HF`` environment variable. When it is
# unset, no Authorization header is sent rather than the literal string
# "Bearer None", which can never be a valid token.
_hf_token = os.getenv('HF')
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}
|
|
|
def query(payload, timeout=120):
    """POST *payload* to ``API_URL`` and return the outcome as a dict.

    Args:
        payload: JSON-serializable request body.
        timeout: Seconds to wait for the server, forwarded to ``requests.post``
            so a stalled request cannot block the app indefinitely.

    Returns:
        ``{"image": <bytes>}`` when the API answers 200 with an image
        Content-Type; otherwise ``{"error": <message>}``. Network failures
        are reported in the same ``{"error": ...}`` form instead of raising.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        return {"error": f"Request failed: {exc}"}

    if response.status_code != 200:
        try:
            error = response.json()
        except ValueError:
            # Body is not JSON (requests' JSONDecodeError subclasses ValueError).
            error = None
        if isinstance(error, dict):
            return {"error": error.get("error", f"API Error: {response.status_code}")}
        # Non-JSON or non-object JSON body: fall back to the raw text.
        return {"error": f"API Error: {response.status_code} - {response.text}"}

    # A 200 response can still carry a non-image body; reject it explicitly.
    if 'image' not in response.headers.get('Content-Type', ''):
        return {"error": "Unexpected non-image response from API"}

    return {"image": response.content}
|
|
|
def generate_image(prompt, seed=None):
    """Generate an image for *prompt* through ``query``.

    Args:
        prompt: Text prompt sent as the request's ``inputs``.
        seed: Optional seed forwarded in ``parameters``. When ``None`` (the
            default), a random seed in ``[0, 2**32 - 1]`` is drawn, so
            repeated calls give different images.

    Returns:
        A fully decoded ``PIL.Image.Image``, or ``None`` after the problem has
        been shown to the user with ``st.error``.
    """
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    payload = {"inputs": prompt, "parameters": {"seed": seed}}

    result = query(payload)
    if "error" in result:
        st.error(f"API Error: {result['error']}")
        return None

    try:
        image = Image.open(io.BytesIO(result["image"]))
        # Image.open only parses the header. Decode the pixel data now so a
        # corrupt or truncated payload is reported here, not raised later by
        # st.image / Image.save outside this handler.
        image.load()
    except Exception as e:
        st.error(f"Failed to process image: {str(e)}")
        return None
    return image
|
|
|
|
|
# Page body: generate an image from the ``?text=...`` URL parameter and offer
# it for download as PNG.
query_params = st.query_params
prompt_from_url = query_params.get('text')

if prompt_from_url:
    # Streamlit re-executes this script on widget interaction, and clicking the
    # download button triggers such a rerun. Reuse the image already generated
    # for this prompt in this session so a rerun does not call the API again
    # and replace the image the user is looking at.
    cached = st.session_state.get("generated_image")
    if cached is not None and cached[0] == prompt_from_url:
        image = cached[1]
    else:
        image = generate_image(prompt_from_url)
        if image:
            st.session_state["generated_image"] = (prompt_from_url, image)

    if image:
        st.image(image, caption="Generated Image", use_container_width=True)

        # Re-encode as PNG so the download format matches the advertised MIME type.
        img_buffer = io.BytesIO()
        image.save(img_buffer, format="PNG")
        img_buffer.seek(0)
        st.download_button(
            label="Download Image 📥",
            data=img_buffer,
            file_name="generated_image.png",
            mime="image/png",
        )
else:
    st.info("Add a 'text' parameter to the URL to generate an image. Example: ?text=astronaut riding a horse")