# Author: svjack
# Commit e5f98db (verified): "Update app.py"
r'''
# Example batch client for this Space (kept here for reference; never executed).
# Start this app locally, then run this script to upscale every PNG in
# `input_dir` with the "v2" model and copy each image's matching .txt file.
import os
from gradio_client import Client, handle_file
from PIL import Image

# Initialize Gradio client
client = Client("http://localhost:7860")

# Path configuration
input_dir = "Aesthetics_X_Phone_720p_Images_Rec_Captioned_16_9"
output_dir = "processed_output_v2_single_thread"
os.makedirs(output_dir, exist_ok=True)


def process_single_image(png_path):
    """Process a single image and save results"""
    try:
        # Get base filename without extension
        base_name = os.path.splitext(os.path.basename(png_path))[0]
        # Corresponding text file path
        txt_path = os.path.join(input_dir, f"{base_name}.txt")

        print(f"Processing: {png_path}...", end=" ", flush=True)

        # Process image through API (returns WEBP path)
        webp_result = client.predict(
            img=handle_file(png_path),
            model_selection="v2",
            api_name="/predict"
        )

        # Output paths
        output_image_path = os.path.join(output_dir, f"{base_name}.png")
        output_text_path = os.path.join(output_dir, f"{base_name}.txt")

        # Convert WEBP to PNG
        with Image.open(webp_result) as img:
            img.save(output_image_path, "PNG")

        # Copy corresponding text file if exists
        if os.path.exists(txt_path):
            with open(txt_path, 'r', encoding='utf-8') as src, \
                 open(output_text_path, 'w', encoding='utf-8') as dst:
                dst.write(src.read())

        print("Done")
        return True
    except Exception as e:
        print(f"Failed: {str(e)}")
        return False


def main():
    # Get all PNG files in input directory
    png_files = sorted([
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if f.lower().endswith('.png')
    ])
    print(f"Found {len(png_files)} PNG files to process")

    # Process files one by one
    success_count = 0
    for i, png_path in enumerate(png_files, 1):
        print(f"\n[{i}/{len(png_files)}] ", end="")
        if process_single_image(png_path):
            success_count += 1

    print(f"\nProcessing complete! Success: {success_count}/{len(png_files)}")


if __name__ == "__main__":
    main()
'''
from aura_sr import AuraSR
import gradio as gr
import spaces
class ZeroGPUAuraSR(AuraSR):
    """AuraSR subclass that provides its own ``from_pretrained`` loader.

    The loader accepts either a Hugging Face Hub repo id or a path to a local
    ``.safetensors`` / ``.ckpt`` weights file.

    NOTE(review): the "ZeroGPU" name suggests this override exists so loading
    works on Hugging Face ZeroGPU hardware; confirm against the upstream
    ``aura_sr.AuraSR.from_pretrained`` before relying on that.
    """

    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
        """Build a model instance from pretrained weights.

        Args:
            model_id: Hugging Face Hub repo id, or a path to a local
                ``.safetensors`` / ``.ckpt`` file. A local file must have a
                ``config.json`` in the same directory.
            use_safetensors: For a Hub snapshot, load ``model.safetensors``
                (True) or ``model.ckpt`` (False). Ignored for local files,
                where the file suffix decides.

        Returns:
            An instance of ``cls`` constructed from the config, with the
            checkpoint loaded into ``model.upsampler`` using ``strict=True``.

        Raises:
            ValueError: the local file has an unsupported suffix.
            FileNotFoundError: the local file has no sibling ``config.json``.
            ImportError: safetensors loading was selected but the
                ``safetensors`` package is not installed.
        """
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        local_file = Path(model_id)
        # Decide once, so the config and the weights always come from the same source.
        is_local = local_file.is_file()

        if is_local:
            # The file extension picks the loader and overrides use_safetensors.
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")
            # For local files, we need to provide the config separately
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )
            weights_path = local_file
        else:
            hf_model_path = Path(snapshot_download(model_id))
            config_path = hf_model_path / "config.json"
            weights_path = hf_model_path / ("model.safetensors" if use_safetensors else "model.ckpt")

        # JSON is UTF-8 by spec; do not depend on the platform locale encoding.
        config = json.loads(config_path.read_text(encoding="utf-8"))
        model = cls(config)

        if use_safetensors:
            # Guard only the import, so an ImportError raised while reading the
            # file is not misreported as a missing library.
            try:
                from safetensors.torch import load_file
            except ImportError as err:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                ) from err
            checkpoint = load_file(weights_path)
        else:
            # NOTE(review): torch.load unpickles the file, which can run arbitrary
            # code from an untrusted checkpoint. The result is only used as a
            # state dict, so weights_only=True should suffice; confirm that the
            # installed torch version supports it before switching.
            checkpoint = torch.load(weights_path)

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model
# Both checkpoints are loaded once at import time and reused by every request.
# "v2" weights (the dropdown default in the UI below).
aura_sr = ZeroGPUAuraSR.from_pretrained(model_id="fal/AuraSR-v2")
# "v1" weights, selectable as 'v1' in the dropdown.
aura_sr_v1 = ZeroGPUAuraSR.from_pretrained(model_id="fal-ai/AuraSR")
@spaces.GPU()
def predict(img, model_selection):
    """Upscale *img* 4x with the AuraSR checkpoint chosen in the dropdown.

    Args:
        img: Image value delivered by the ``gr.Image()`` input component;
            ``None`` when the user submits without uploading anything.
        model_selection: ``'v1'`` or ``'v2'``, from the dropdown.

    Returns:
        Whatever ``upscale_4x`` returns, shown by the ``gr.Image()`` output.

    Raises:
        gr.Error: no image was provided, or the model key is unknown. Gradio
            shows the message to the user instead of a bare traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image to upscale.")
    models = {'v1': aura_sr_v1, 'v2': aura_sr}
    model = models.get(model_selection)
    if model is None:
        # Previously this surfaced as "'NoneType' object has no attribute 'upscale_4x'".
        raise gr.Error(
            f"Unknown model selection {model_selection!r}; expected one of {sorted(models)}."
        )
    return model.upscale_4x(img)
# Kept at module level so tooling that imports this file (e.g. `gradio app.py`
# reload mode) can find `demo` without starting a server.
demo = gr.Interface(
    predict,
    inputs=[gr.Image(), gr.Dropdown(value='v2', choices=['v1', 'v2'])],
    outputs=gr.Image(),
)

if __name__ == "__main__":
    # share=True asks Gradio for a temporary public link (see Blocks.launch docs).
    demo.launch(share=True)