import fal_client
from PIL import Image
import requests
import io
import os
import base64

FAL_MODEl_NAME_MAP = {"SDXL": "fast-sdxl", "SDXLTurbo": "fast-turbo-diffusion", "SDXLLightning": "fast-lightning-sdxl",
                      "LCM(v1.5/XL)": "fast-lcm-diffusion", "PixArtSigma": "pixart-sigma", "StableCascade": "stable-cascade"}


class FalModel:
    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type
        # fal_client reads its credentials from FAL_KEY; mirror the FalAPI secret into it
        os.environ['FAL_KEY'] = os.environ['FalAPI']

    def __call__(self, *args, **kwargs):
        def decode_data_url(data_url):
            # Find the start of the Base64 encoded data
            base64_start = data_url.find(",") + 1
            if base64_start == 0:
                raise ValueError("Invalid data URL provided")
            # Extract the Base64 encoded data
            base64_string = data_url[base64_start:]
            # Decode the Base64 string
            decoded_bytes = base64.b64decode(base64_string)
            return decoded_bytes

        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            handler = fal_client.submit(
                f"fal-ai/{FAL_MODEl_NAME_MAP[self.model_name]}",
                arguments={
                    "prompt": kwargs["prompt"]
                },
            )
            for event in handler.iter_events(with_logs=True):
                if isinstance(event, fal_client.InProgress):
                    print('Request in progress')
                    print(event.logs)
            result = handler.get()
            print(result)
            result_url = result['images'][0]['url']
            if self.model_name in ["SDXLTurbo", "LCM(v1.5/XL)"]:
                # These endpoints return the image inline as a base64 data URL
                result_url = io.BytesIO(decode_data_url(result_url))
                result = Image.open(result_url)
            else:
                # Other endpoints return a hosted URL that must be downloaded
                response = requests.get(result_url)
                result = Image.open(io.BytesIO(response.content))
            return result
        elif self.model_type == "image2image":
            raise NotImplementedError("image2image model is not implemented yet")
            # assert "image" in kwargs or "image_url" in kwargs, "image or image_url is required for image2image model"
            # if "image" in kwargs:
            #     image_url = None
            #     pass
            # handler = fal_client.submit(
            #     f"fal-ai/{self.model_name}",
            #     arguments={
            #         "image_url": image_url
            #     },
            # )
            #
            # for event in handler.iter_events():
            #     if isinstance(event, fal_client.InProgress):
            #         print('Request in progress')
            #         print(event.logs)
            #
            # result = handler.get()
            # return result
        elif self.model_type == "text2video":
            assert "prompt" in kwargs, "prompt is required for text2video model"
            if self.model_name == 'AnimateDiff':
                fal_model_name = 'fast-animatediff/text-to-video'
            elif self.model_name == 'AnimateDiffTurbo':
                fal_model_name = 'fast-animatediff/turbo/text-to-video'
            elif self.model_name == 'StableVideoDiffusion':
                fal_model_name = 'fast-svd/text-to-video'
            else:
                raise NotImplementedError(f"text2video model of {self.model_name} in fal is not implemented yet")
            handler = fal_client.submit(
                f"fal-ai/{fal_model_name}",
                arguments={
                    "prompt": kwargs["prompt"]
                },
            )
            for event in handler.iter_events(with_logs=True):
                if isinstance(event, fal_client.InProgress):
                    print('Request in progress')
                    print(event.logs)
            result = handler.get()
            print("result video: ====")
            print(result)
            result_url = result['video']['url']
            return result_url
        else:
            raise ValueError("model_type must be text2image, image2image, or text2video")


def load_fal_model(model_name, model_type):
    return FalModel(model_name, model_type)
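

# Minimal usage sketch, assuming the FalAPI environment variable holds a valid
# fal.ai API key and network access is available. The model names and prompts
# below are illustrative; valid names are the keys of FAL_MODEl_NAME_MAP and
# the text2video names handled above.
if __name__ == "__main__":
    # Text-to-image: __call__ returns a PIL.Image
    t2i = load_fal_model("SDXL", "text2image")
    image = t2i(prompt="a watercolor painting of a lighthouse at dusk")
    image.save("sdxl_sample.png")

    # Text-to-video: __call__ returns a URL pointing to the generated clip
    t2v = load_fal_model("AnimateDiff", "text2video")
    video_url = t2v(prompt="a corgi running on a beach")
    print(video_url)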