File size: 4,876 Bytes
3e6cc30
 
b9ef8fe
3e6cc30
508fe78
16c0efc
3e6cc30
 
16c0efc
 
3e6cc30
 
f8caef0
3e6cc30
508fe78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8caef0
 
ef502da
16c0efc
ef502da
 
16c0efc
 
3e6cc30
16c0efc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8caef0
3e6cc30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9ef8fe
3e6cc30
 
 
 
 
 
 
 
16c0efc
c908ca0
 
 
41211aa
16c0efc
f8caef0
3e6cc30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16c0efc
3e6cc30
16c0efc
3e6cc30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16c0efc
3e6cc30
 
 
 
16c0efc
3e6cc30
 
 
 
 
 
 
16c0efc
 
3e6cc30
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os

import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from huggingface_hub import InferenceClient, login

load_dotenv()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
TOKEN = None


def download_image_locally(image_url: str, local_path: str, timeout: float = 30.0) -> str:
    """
    Download an image from a URL to a local path.

    Args:
        image_url (str):
            The URL of the image to download.
        local_path (str):
            The path to save the downloaded image.
        timeout (float, default=30.0):
            Seconds `requests` waits to connect and for each read from the server
            before giving up.

    Returns:
        str: The path the image was written to (the same value as ``local_path``).

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
    """
    # Without a timeout, requests can block indefinitely on an unresponsive host.
    with requests.get(image_url, timeout=timeout, stream=True) as response:
        # Fail loudly instead of saving an HTML/JSON error page under an image filename.
        response.raise_for_status()
        with open(local_path, "wb") as f:
            # Stream in chunks so a large image is never held entirely in memory.
            for chunk in response.iter_content(chunk_size=64 * 1024):
                f.write(chunk)
    return local_path

def get_token(oauth_token: gr.OAuthToken | None) -> None:
    """
    Set the module-level TOKEN for subsequent inference calls.

    Uses the visitor's OAuth token when one with a non-empty ``token`` is given,
    otherwise falls back to the ``HF_TOKEN`` environment variable (which may be unset,
    leaving TOKEN as None).

    NOTE(review): TOKEN is a process-wide global, so the value written here is visible
    to every session served by this process, not just the caller's.
    Despite the log message, no ``login()`` call is made here.

    Args:
        oauth_token (gr.OAuthToken | None):
            OAuth token for the signed-in user, or None.
    """
    global TOKEN
    user_token = oauth_token.token if oauth_token else None
    if not user_token:
        print("No OAuth token provided, using environment variable HF_TOKEN.")
        TOKEN = os.environ.get("HF_TOKEN")
        return
    print("Received OAuth token, logging in...")
    TOKEN = user_token

def generate(
    prompt: str,
    seed: int = 42,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 25,
    oauth_token: gr.OAuthToken | None = None,
):
    """
    Generate an image from a prompt.

    Args:
        prompt (str):
            The prompt to generate an image from.
        seed (int, default=42):
            Seed for the random number generator.
        width (int, default=1024):
            The width in pixels of the output image.
        height (int, default=1024):
            The height in pixels of the output image.
        num_inference_steps (int, default=25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        oauth_token (gr.OAuthToken | None, default=None):
            Filled in by Gradio for signed-in users (it is not a UI input). When absent,
            the HF_TOKEN environment variable is used instead.

    Returns:
        tuple: The image returned by ``InferenceClient.text_to_image`` and the integer
        seed that was used.
    """
    # Resolve the token per request. Reading the process-global TOKEN here would let one
    # signed-in visitor's token be used for every other visitor's (and MCP caller's) calls.
    if oauth_token is not None and oauth_token.token:
        token = oauth_token.token
    else:
        token = os.environ.get("HF_TOKEN")

    # Slider values may arrive as floats (e.g. 42.0); normalise them so the request
    # parameters and the returned seed are integral.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    client = InferenceClient(provider="fal-ai", token=token)
    image = client.text_to_image(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        seed=seed,
        model="black-forest-labs/FLUX.1-dev"
    )
    return image, seed
 
# Sample prompts offered to the user as clickable examples.
examples: list[str] = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]

# Custom CSS: centre the element with id "col-container" and cap its width at 520px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 520px;\n"
    "}\n"
)

with gr.Blocks(css=css) as demo:
    # Populate TOKEN on every page load (visitor's OAuth token, else HF_TOKEN).
    demo.load(get_token, inputs=None, outputs=None)
    with gr.Sidebar():
        gr.Markdown("# Inference Provider")
        # Keep this text in sync with the provider/model that `generate` actually calls.
        gr.Markdown("This Space showcases the black-forest-labs/FLUX.1-dev model, served by the fal-ai API. Sign in with your Hugging Face account to use this API.")
        button = gr.LoginButton("Sign in")
        button.click(fn=get_token, inputs=[], outputs=[])

    with gr.Column(elem_id="col-container"):
        # Plain string: the heading has no placeholders, so no f-string is needed.
        gr.Markdown("""# FLUX.1 [dev] with fal-ai through HF Inference Providers ⚡
learn more about HF Inference Providers [here](https://huggingface.co/docs/inference-providers/index)""")

        with gr.Row():

            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False, format="png")

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )

            with gr.Row():

                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=25,
                )

        # Examples only supply the prompt; the remaining generate() args use their defaults.
        gr.Examples(
            examples = examples,
            fn = generate,
            inputs = [prompt],
            outputs = [result, seed],
            cache_examples="lazy"
        )

    # Run generation from either the Run button or pressing Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = generate,
        inputs = [prompt, seed, width, height, num_inference_steps],
        outputs = [result, seed]
    )

# mcp_server=True additionally exposes the app's endpoints to MCP clients.
demo.launch(mcp_server=True)