import base64
import os
import time
from io import BytesIO

import modal
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Modal setup: CUDA development base image with Python 3.11
cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])

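# Pin diffusers to a known commit so the FLUX pipeline API stays reproducible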
diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"

flux_image = (
    cuda_dev_image.apt_install(
        "git",
        "libglib2.0-0",
        "libsm6",
        "libxrender1",
        "libxext6",
        "ffmpeg",
        "libgl1",
    )
    .pip_install(
        "invisible_watermark==0.2.0",
        "peft==0.10.0",
        "transformers==4.44.0",
        "huggingface_hub[hf_transfer]==0.26.2",
        "accelerate==0.33.0",
        "safetensors==0.4.4",
        "sentencepiece==0.2.0",
        "torch==2.5.0",
        f"git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha}",
        "numpy<2",
        "fastapi==0.104.1",
        "uvicorn==0.24.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache"})
)

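# Point torch.compile (TorchInductor) caches at a path that is mounted as a
# Modal Volume below, so compilation artifacts survive container restarts.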
flux_image = flux_image.env(
    {
        "TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache",
        "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
    }
)

app = modal.App(
    "flux-api-server",
    image=flux_image,
    secrets=[modal.Secret.from_name("huggingface-token")],
)

with flux_image.imports():
    import torch
    from diffusers import FluxPipeline
    from safetensors.torch import load_file

MINUTES = 60  # seconds
VARIANT = "dev"  # FLUX.1-dev; the model id is hardcoded below to match
NUM_INFERENCE_STEPS = 50  # steps for the torch.compile warm-up run in optimize()

class ImageRequest(BaseModel):
    prompt: str
    num_inference_steps: int = 50
    width: int = 1024   # output width in pixels
    height: int = 1024  # output height in pixels

class ImageResponse(BaseModel):
    image_base64: str
    generation_time: float

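# The volumes persist the HF model weights and the CUDA/Triton/Inductor compile
# caches across containers, which cuts cold-start time substantially.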
@app.cls(
    gpu="H200",
    scaledown_window=20 * MINUTES,
    timeout=60 * MINUTES,
    volumes={
        "/cache": modal.Volume.from_name("hf-hub-cache", create_if_missing=True),
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name("triton-cache", create_if_missing=True),
        "/root/.inductor-cache": modal.Volume.from_name(
            "inductor-cache", create_if_missing=True
        ),
    },
)
class Model:
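    # modal.parameter() exposes `compile` as a per-instance option,
    # e.g. Model(compile=True) enables the torch.compile path in optimize()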
    compile: bool = modal.parameter(default=False)

    lora_loaded = False
    lora_path = "/cache/flux.1_lora_flyway_doodle-poster.safetensors"
    lora_url = "https://huggingface.co/RajputVansh/SG161222-DISTILLED-IITI-VANSH-RUHELA/resolve/main/flux.1_lora_flyway_doodle-poster.safetensors?download=true"

    def download_lora_from_url(self, url, save_path):
        """Download LoRA with proper error handling"""
        try:
            print(f"πŸ“₯ Downloading LoRA from {url}")
            response = requests.get(url, timeout=300)  # 5 minute timeout
            response.raise_for_status()  # Raise exception for bad status codes
            
            with open(save_path, "wb") as f:
                f.write(response.content)
            
            print(f"βœ… LoRA downloaded successfully to {save_path}")
            print(f"πŸ“Š File size: {len(response.content)} bytes")
            return True
        except Exception as e:
            print(f"❌ LoRA download failed: {str(e)}")
            return False

    def verify_lora_file(self, lora_path):
        """Verify that the LoRA file is valid"""
        try:
            if not os.path.exists(lora_path):
                return False, "File does not exist"
            
            file_size = os.path.getsize(lora_path)
            if file_size == 0:
                return False, "File is empty"
            
            # Try to load the file to verify it's valid
            try:
                load_file(lora_path)
                return True, f"Valid LoRA file ({file_size} bytes)"
            except Exception as e:
                return False, f"Invalid LoRA file: {str(e)}"
                
        except Exception as e:
            return False, f"Error verifying file: {str(e)}"

    @modal.enter()
    def enter(self):
        from huggingface_hub import login

        # Log in to Hugging Face; the "huggingface-token" Modal secret must
        # expose an environment variable literally named "huggingface_token"
        token = os.environ["huggingface_token"]
        login(token)

        # Download and verify LoRA
        if not os.path.exists(self.lora_path):
            print("πŸ“₯ LoRA not found, downloading...")
            download_success = self.download_lora_from_url(self.lora_url, self.lora_path)
            if not download_success:
                print("❌ Failed to download LoRA, continuing without it")
                self.lora_loaded = False
        else:
            print("πŸ“ LoRA file found in cache")

        # Verify LoRA file
        is_valid, message = self.verify_lora_file(self.lora_path)
        print(f"πŸ” LoRA verification: {message}")

        # Load the base model
        from diffusers import FluxPipeline
        import torch

        print("πŸš€ Loading Flux model...")
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16
        ).to("cuda")

        # Load LoRA if available and valid
        if is_valid:
            try:
                print(f"πŸ”„ Loading LoRA from {self.lora_path}")
                pipe.load_lora_weights(self.lora_path)
                print("βœ… LoRA successfully loaded!")
                self.lora_loaded = True
                
                # A short test generation could be run here to confirm the
                # LoRA actually changes the model's outputs
                
            except Exception as e:
                print(f"❌ LoRA loading failed: {str(e)}")
                self.lora_loaded = False
        else:
            print("⚠️ LoRA not loaded due to verification failure")
            self.lora_loaded = False

        # Optimize the pipeline
        self.pipe = optimize(pipe, compile=self.compile)
        
        print(f"🎯 Model ready! LoRA status: {'βœ… Loaded' if self.lora_loaded else '❌ Not loaded'}")


    @modal.method()
    def get_model_status(self) -> dict:
        """Get detailed model and LoRA status"""
        lora_file_info = {}
        if os.path.exists(self.lora_path):
            try:
                file_size = os.path.getsize(self.lora_path)
                lora_file_info = {
                    "exists": True,
                    "size_bytes": file_size,
                    "size_mb": round(file_size / (1024 * 1024), 2)
                }
            except OSError:
                lora_file_info = {"exists": False}
        else:
            lora_file_info = {"exists": False}

        return {
            "status": "ready",
            "lora_loaded": self.lora_loaded,
            "lora_path": self.lora_path,
            "model_info": {
                "base_model": "black-forest-labs/FLUX.1-dev",
                "lora_file": lora_file_info,
                "lora_url": self.lora_url
            }
        }

    @modal.method()
    def inference(self, prompt: str, num_inference_steps: int = 50, width: int = 1024, height: int = 1024) -> dict:
        # The prompt is passed through unchanged; any preprocessing
        # (trigger words, cleanup) would go here
        final_prompt = prompt
        
        print(f"🎨 Generating image:")
        print(f"   Original prompt: {prompt}")
        print(f"   Final prompt: {final_prompt}")
        print(f"   Dimensions: {width}x{height}")
        print(f"   LoRA status: {'βœ… Active' if self.lora_loaded else '❌ Inactive'}")
        
        start_time = time.time()
        
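        # max_sequence_length=512 allows long prompts (the FLUX T5 encoder's limit)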
        out = self.pipe(
            final_prompt,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            max_sequence_length=512
        ).images[0]

        # Convert to base64
        byte_stream = BytesIO()
        out.save(byte_stream, format="PNG")
        image_bytes = byte_stream.getvalue()
        image_base64 = base64.b64encode(image_bytes).decode('utf-8')
        
        generation_time = time.time() - start_time
        print(f"βœ… Generated image in {generation_time:.2f} seconds")
        
        return {
            "image_base64": image_base64,
            "generation_time": generation_time,
            "final_prompt": final_prompt,
            "lora_used": self.lora_loaded
        }

# FastAPI server
fastapi_app = FastAPI(title="Flux Image Generation API")

# Local handle to the Modal class; .remote() calls execute in the GPU container
model_instance = Model(compile=False)

@fastapi_app.post("/generate", response_model=ImageResponse)
async def generate_image(request: ImageRequest):
    try:
        print(f"Received request: {request.prompt} at {request.width}x{request.height}")
        # Use the async variant so a long generation doesn't block the event loop
        result = await model_instance.inference.remote.aio(
            request.prompt,
            request.num_inference_steps,
            request.width,
            request.height,
        )
        return ImageResponse(**result)
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@fastapi_app.get("/health")
async def health_check():
    return {"status": "healthy", "message": "Flux API server is running"}

@app.function(
    image=flux_image,  # fastapi and uvicorn are already pinned in the base image
    min_containers=1,  # keep one container warm to avoid cold starts
    timeout=60 * MINUTES,
)
@modal.asgi_app()
def fastapi_server():
    return fastapi_app

def optimize(pipe, compile=True):
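    """Fuse QKV projections, use channels_last layout, and optionally torch.compile."""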
    # fuse QKV projections in Transformer and VAE
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # switch memory layout to Torch's preferred, channels_last
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    if not compile:
        return pipe

    # set torch compile flags
    config = torch._inductor.config
    config.disable_progress = False
    config.conv_1x1_as_mm = True
    config.coordinate_descent_tuning = True
    config.coordinate_descent_check_all_directions = True
    config.epilogue_fusion = False

    # compile the compute-intensive modules
    pipe.transformer = torch.compile(
        pipe.transformer, mode="max-autotune", fullgraph=True
    )
    pipe.vae.decode = torch.compile(
        pipe.vae.decode, mode="max-autotune", fullgraph=True
    )

    # trigger torch compilation
    print("πŸ”¦ Running torch compilation (may take up to 20 minutes)...")
    pipe(
        "dummy prompt to trigger torch compilation",
        output_type="pil",
        num_inference_steps=NUM_INFERENCE_STEPS,
    ).images[0]
    print("πŸ”¦ Finished torch compilation")

    return pipe

if __name__ == "__main__":
    print("Starting Modal Flux API server...")
    # Deployment is handled by Modal; see the usage sketch below
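
# Usage sketch (assumes this file is saved as flux_api.py; the URL shape is
# hypothetical -- Modal prints the real endpoint on deploy):
#
#   modal deploy flux_api.py
#
#   curl -X POST "https://<workspace>--flux-api-server-fastapi-server.modal.run/generate" \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a doodle poster of a cat", "width": 1024, "height": 1024}'
#
# The JSON response carries a base64-encoded PNG in "image_base64".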