import os
import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
load_dotenv()
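# Upper bounds for the seed and image-size sliders defined below.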
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
TOKEN = None
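# Utility for saving a remote image to disk (handy when a client, e.g. an MCP
# consumer, receives an image URL rather than the raw bytes).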
def download_image_locally(image_url: str, local_path: str):
"""
Download an image from a URL to a local path.
Args:
image_url (str):
The URL of the image to download.
local_path (str):
The path to save the downloaded image.
"""
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    with open(local_path, "wb") as f:
        f.write(response.content)
return local_path
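# Resolve the token used for inference: prefer the signed-in user's OAuth
# token, falling back to the HF_TOKEN environment variable.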
def get_token(oauth_token: gr.OAuthToken | None):
global TOKEN
if oauth_token and oauth_token.token:
print("Received OAuth token, logging in...")
TOKEN = oauth_token.token
else:
print("No OAuth token provided, using environment variable HF_TOKEN.")
TOKEN = os.environ.get("HF_TOKEN")
def generate(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024, num_inference_steps: int = 25):
"""
Generate an image from a prompt.
Args:
prompt (str):
The prompt to generate an image from.
        seed (int, default=42):
            Seed for the random number generator.
        width (int, default=1024):
            The width in pixels of the output image.
        height (int, default=1024):
            The height in pixels of the output image.
        num_inference_steps (int, default=25):
            The number of denoising steps. More denoising steps usually lead to a
            higher-quality image at the expense of slower inference.

    Returns:
        A (PIL.Image.Image, int) tuple containing the generated image and the seed used.
    """
client = InferenceClient(provider="fal-ai", token=TOKEN)
image = client.text_to_image(
prompt=prompt,
width=width,
height=height,
num_inference_steps=num_inference_steps,
seed=seed,
model="black-forest-labs/FLUX.1-dev"
)
return image, seed
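# A minimal usage sketch (assumes a valid token is available; text_to_image
# returns a PIL.Image):
#   image, used_seed = generate("a red panda reading a book", seed=0)
#   image.save("out.png")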
examples = [
"a tiny astronaut hatching from an egg on the moon",
"a cat holding a sign that says hello world",
"an anime illustration of a wiener schnitzel",
]
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""
with gr.Blocks(css=css) as demo:
demo.load(get_token, inputs=None, outputs=None)
with gr.Sidebar():
gr.Markdown("# Inference Provider")
gr.Markdown("This Space showcases the black-forest-labs/FLUX.1-dev model, served by the nebius API. Sign in with your Hugging Face account to use this API.")
button = gr.LoginButton("Sign in")
button.click(fn=get_token, inputs=[], outputs=[])
with gr.Column(elem_id="col-container"):
gr.Markdown(f"""# FLUX.1 [schnell] with fal-ai through HF Inference Providers ⚡
learn more about HF Inference Providers [here](https://huggingface.co/docs/inference-providers/index)""")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0)
result = gr.Image(label="Result", show_label=False, format="png")
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=25,
)
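        # With cache_examples="lazy", example outputs are generated the first
        # time a user requests them rather than at startup.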
gr.Examples(
            examples=examples,
            fn=generate,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples="lazy",
)
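    # Run generation when the button is clicked or the prompt is submitted
    # with Enter.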
gr.on(
triggers=[run_button.click, prompt.submit],
        fn=generate,
        inputs=[prompt, seed, width, height, num_inference_steps],
        outputs=[result, seed],
)
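# mcp_server=True additionally exposes this app's endpoints as MCP tools.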
demo.launch(mcp_server=True)