# Captured from the Hugging Face Spaces page header: "Spaces: Runtime error"
# (status banner residue from the copy-paste, kept here as a comment).
'''
Archived single-threaded client script (disabled): batch-upscales PNGs through
the Gradio API defined below and copies matching caption .txt files alongside.

import os
from gradio_client import Client, handle_file
from PIL import Image

# Initialize Gradio client
client = Client("http://localhost:7860")

# Path configuration
input_dir = "Aesthetics_X_Phone_720p_Images_Rec_Captioned_16_9"
output_dir = "processed_output_v2_single_thread"
os.makedirs(output_dir, exist_ok=True)

def process_single_image(png_path):
    """Process a single image and save results"""
    try:
        # Get base filename without extension
        base_name = os.path.splitext(os.path.basename(png_path))[0]
        # Corresponding text file path
        txt_path = os.path.join(input_dir, f"{base_name}.txt")
        print(f"Processing: {png_path}...", end=" ", flush=True)
        # Process image through API (returns WEBP path)
        webp_result = client.predict(
            img=handle_file(png_path),
            model_selection="v2",
            api_name="/predict"
        )
        # Output paths
        output_image_path = os.path.join(output_dir, f"{base_name}.png")
        output_text_path = os.path.join(output_dir, f"{base_name}.txt")
        # Convert WEBP to PNG
        with Image.open(webp_result) as img:
            img.save(output_image_path, "PNG")
        # Copy corresponding text file if exists
        if os.path.exists(txt_path):
            with open(txt_path, 'r', encoding='utf-8') as src, \
                 open(output_text_path, 'w', encoding='utf-8') as dst:
                dst.write(src.read())
        print("Done")
        return True
    except Exception as e:
        print(f"Failed: {str(e)}")
        return False

def main():
    # Get all PNG files in input directory
    png_files = sorted([
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if f.lower().endswith('.png')
    ])
    print(f"Found {len(png_files)} PNG files to process")
    # Process files one by one
    success_count = 0
    for i, png_path in enumerate(png_files, 1):
        print(f"\n[{i}/{len(png_files)}] ", end="")
        if process_single_image(png_path):
            success_count += 1
    print(f"\nProcessing complete! Success: {success_count}/{len(png_files)}")

if __name__ == "__main__":
    main()
'''
# Third-party dependencies: Gradio UI, the HF Spaces runtime helper, and AuraSR.
import gradio as gr
import spaces

from aura_sr import AuraSR
class ZeroGPUAuraSR(AuraSR):
    """AuraSR subclass whose loader also accepts local checkpoint files.

    The override extends ``from_pretrained`` to take either a Hugging Face
    model ID or a path to a local ``.safetensors``/``.ckpt`` file (with a
    ``config.json`` next to it).
    """

    # BUG FIX: restored the @classmethod decorator. The parent exposes
    # from_pretrained as a classmethod; without the decorator,
    # ZeroGPUAuraSR.from_pretrained("fal/AuraSR-v2") binds the model ID to
    # ``cls`` and crashes at ``cls(config)`` ("str is not callable").
    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
        """Build an upscaler from a Hugging Face model ID or a local file.

        Args:
            model_id: HF repo ID (e.g. "fal/AuraSR-v2") or path to a local
                ``.safetensors`` / ``.ckpt`` checkpoint file.
            use_safetensors: load weights via safetensors; overridden to match
                the file suffix when ``model_id`` is a local file.

        Returns:
            A model instance with weights loaded into ``model.upsampler``.

        Raises:
            ValueError: local file with an unsupported suffix.
            FileNotFoundError: local checkpoint without an adjacent config.json.
            ImportError: safetensors requested but not installed.
        """
        import json
        from pathlib import Path

        import torch
        from huggingface_hub import snapshot_download

        model_path = Path(model_id)
        is_local_file = model_path.is_file()  # hoisted: original re-checked this per branch

        if is_local_file:
            # Infer the weight format from the suffix rather than trusting the flag.
            if model_path.suffix == '.safetensors':
                use_safetensors = True
            elif model_path.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {model_path.suffix}. Please use .safetensors or .ckpt files.")

            # A bare checkpoint carries no architecture config; require config.json beside it.
            config_path = model_path.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{model_path.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )
            config = json.loads(config_path.read_text())
            hf_model_path = model_path.parent
        else:
            # Resolve the repo to a cached local snapshot and read its bundled config.
            hf_model_path = Path(snapshot_download(model_id))
            config = json.loads((hf_model_path / "config.json").read_text())

        model = cls(config)

        if use_safetensors:
            try:
                from safetensors.torch import load_file
                checkpoint = load_file(model_id if is_local_file else hf_model_path / "model.safetensors")
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
        else:
            checkpoint = torch.load(model_id if is_local_file else hf_model_path / "model.ckpt")

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model
# Preload both upscaler generations once at startup so every request reuses them.
aura_sr_v1 = ZeroGPUAuraSR.from_pretrained("fal-ai/AuraSR")
aura_sr = ZeroGPUAuraSR.from_pretrained("fal/AuraSR-v2")
def predict(img, model_selection):
    """Upscale ``img`` 4x with the AuraSR generation named by ``model_selection``.

    ``model_selection`` is 'v1' or 'v2' (the dropdown choices); any other value
    yields ``None`` from the lookup and fails on the attribute access, exactly
    as the original lookup did.
    """
    generations = {'v1': aura_sr_v1, 'v2': aura_sr}
    chosen = generations.get(model_selection)
    return chosen.upscale_4x(img)
# Minimal web UI: an input image plus a model picker in, one upscaled image out.
model_picker = gr.Dropdown(value='v2', choices=['v1', 'v2'])
demo = gr.Interface(
    predict,
    inputs=[gr.Image(), model_picker],
    outputs=gr.Image(),
)
# share=True publishes a temporary public gradio.live URL in addition to localhost.
demo.launch(share=True)