Spaces:
Paused
Paused
Upload 4 files
Browse files- .gitattributes +1 -0
- README.md +6 -5
- app.py +182 -0
- requirements.txt +10 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 4.
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
|
|
|
| 1 |
---
|
| 2 |
+
title: TripoSR
|
| 3 |
+
emoji: 🐳
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.20.1
|
| 8 |
+
python_version: 3.10.13
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
license: mit
|
app.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import boto3
|
| 4 |
+
import json
|
| 5 |
+
import shlex
|
| 6 |
+
import subprocess
|
| 7 |
+
import tempfile
|
| 8 |
+
import time
|
| 9 |
+
import base64
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import numpy as np
|
| 12 |
+
import rembg
|
| 13 |
+
import spaces
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from functools import partial
|
| 17 |
+
import io
|
| 18 |
+
|
| 19 |
+
# s3 = boto3.client(
|
| 20 |
+
# 's3',
|
| 21 |
+
# aws_access_key_id="AKIAZW3QSPMIH4RF42UA",
|
| 22 |
+
# aws_secret_access_key="iH8UDkDS2tMuB0GUiyq+QpM0jTxm+00mhDz0PgZz",
|
| 23 |
+
# region_name='us-east-1'
|
| 24 |
+
# )
|
| 25 |
+
|
| 26 |
+
subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
|
| 27 |
+
|
| 28 |
+
from tsr.system import TSR
|
| 29 |
+
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
HEADER = """FRAME AI"""
|
| 33 |
+
|
| 34 |
+
if torch.cuda.is_available():
|
| 35 |
+
device = "cuda:0"
|
| 36 |
+
else:
|
| 37 |
+
device = "cpu"
|
| 38 |
+
|
| 39 |
+
model = TSR.from_pretrained(
|
| 40 |
+
"stabilityai/TripoSR",
|
| 41 |
+
config_name="config.yaml",
|
| 42 |
+
weight_name="model.ckpt",
|
| 43 |
+
)
|
| 44 |
+
model.renderer.set_chunk_size(131072)
|
| 45 |
+
model.to(device)
|
| 46 |
+
|
| 47 |
+
rembg_session = rembg.new_session()
|
| 48 |
+
|
| 49 |
+
def generate_image_from_text(pos_prompt):
|
| 50 |
+
# bedrock_runtime = boto3.client(region_name = 'us-east-1', service_name='bedrock-runtime')
|
| 51 |
+
bedrock_runtime = boto3.client(service_name='bedrock-runtime', aws_access_key_id = "AKIAZW3QSPMIH4RF42UA", aws_secret_access_key = "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", region_name='us-east-1')
|
| 52 |
+
parameters = {'text_prompts': [{'text':pos_prompt, 'weight':1},
|
| 53 |
+
{'text': """Blurry, unnatural, ugly, pixelated obscure, dull, artifacts, duplicate, bad quality, low resolution, cropped, out of frame, out of focus""", 'weight': -1}],
|
| 54 |
+
'cfg_scale': 7, 'seed': 0, 'samples': 1}
|
| 55 |
+
request_body = json.dumps(parameters)
|
| 56 |
+
response = bedrock_runtime.invoke_model(body=request_body,modelId = 'stability.stable-diffusion-xl-v1')
|
| 57 |
+
response_body = json.loads(response.get('body').read())
|
| 58 |
+
base64_image_data = base64.b64decode(response_body['artifacts'][0]['base64'])
|
| 59 |
+
|
| 60 |
+
return Image.open(io.BytesIO(base64_image_data))
|
| 61 |
+
|
| 62 |
+
def check_input_image(input_image):
|
| 63 |
+
if input_image is None:
|
| 64 |
+
raise gr.Error("No image uploaded!")
|
| 65 |
+
|
| 66 |
+
def preprocess(input_image, do_remove_background, foreground_ratio):
|
| 67 |
+
def fill_background(image):
|
| 68 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 69 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
| 70 |
+
image = Image.fromarray((image * 255.0).astype(np.uint8))
|
| 71 |
+
return image
|
| 72 |
+
|
| 73 |
+
if do_remove_background:
|
| 74 |
+
image = input_image.convert("RGB")
|
| 75 |
+
image = remove_background(image, rembg_session)
|
| 76 |
+
image = resize_foreground(image, foreground_ratio)
|
| 77 |
+
image = fill_background(image)
|
| 78 |
+
else:
|
| 79 |
+
image = input_image
|
| 80 |
+
if image.mode == "RGBA":
|
| 81 |
+
image = fill_background(image)
|
| 82 |
+
return image
|
| 83 |
+
|
| 84 |
+
@spaces.GPU
|
| 85 |
+
def generate(image, mc_resolution, formats=["obj", "glb"]):
|
| 86 |
+
scene_codes = model(image, device=device)
|
| 87 |
+
mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
|
| 88 |
+
mesh = to_gradio_3d_orientation(mesh)
|
| 89 |
+
|
| 90 |
+
mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
|
| 91 |
+
mesh.export(mesh_path_glb.name)
|
| 92 |
+
|
| 93 |
+
mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
|
| 94 |
+
mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
|
| 95 |
+
mesh.export(mesh_path_obj.name)
|
| 96 |
+
|
| 97 |
+
return mesh_path_obj.name, mesh_path_glb.name
|
| 98 |
+
|
| 99 |
+
def run_example(text_prompt, do_remove_background, foreground_ratio, mc_resolution):
|
| 100 |
+
# Step 1: Generate the image from text prompt
|
| 101 |
+
image_pil = generate_image_from_text(text_prompt)
|
| 102 |
+
|
| 103 |
+
# Step 2: Preprocess the image
|
| 104 |
+
preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
|
| 105 |
+
|
| 106 |
+
# Step 3: Generate the 3D model
|
| 107 |
+
mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution, ["obj", "glb"])
|
| 108 |
+
|
| 109 |
+
return preprocessed, mesh_name_obj, mesh_name_glb
|
| 110 |
+
|
| 111 |
+
with gr.Blocks() as demo:
|
| 112 |
+
gr.Markdown(HEADER)
|
| 113 |
+
with gr.Row(variant="panel"):
|
| 114 |
+
with gr.Column():
|
| 115 |
+
with gr.Row():
|
| 116 |
+
text_prompt = gr.Textbox(
|
| 117 |
+
label="Text Prompt",
|
| 118 |
+
placeholder="Enter a text prompt for image generation"
|
| 119 |
+
)
|
| 120 |
+
input_image = gr.Image(
|
| 121 |
+
label="Generated Image",
|
| 122 |
+
image_mode="RGBA",
|
| 123 |
+
sources="upload",
|
| 124 |
+
type="pil",
|
| 125 |
+
elem_id="content_image",
|
| 126 |
+
visible=False # Hidden since we generate the image from text
|
| 127 |
+
)
|
| 128 |
+
processed_image = gr.Image(label="Processed Image", interactive=False)
|
| 129 |
+
with gr.Row():
|
| 130 |
+
with gr.Group():
|
| 131 |
+
do_remove_background = gr.Checkbox(
|
| 132 |
+
label="Remove Background", value=True
|
| 133 |
+
)
|
| 134 |
+
foreground_ratio = gr.Slider(
|
| 135 |
+
label="Foreground Ratio",
|
| 136 |
+
minimum=0.5,
|
| 137 |
+
maximum=1.0,
|
| 138 |
+
value=0.85,
|
| 139 |
+
step=0.05,
|
| 140 |
+
)
|
| 141 |
+
mc_resolution = gr.Slider(
|
| 142 |
+
label="Marching Cubes Resolution",
|
| 143 |
+
minimum=32,
|
| 144 |
+
maximum=320,
|
| 145 |
+
value=256,
|
| 146 |
+
step=32
|
| 147 |
+
)
|
| 148 |
+
with gr.Row():
|
| 149 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
| 150 |
+
with gr.Column():
|
| 151 |
+
with gr.Tab("OBJ"):
|
| 152 |
+
output_model_obj = gr.Model3D(
|
| 153 |
+
label="Output Model (OBJ Format)",
|
| 154 |
+
interactive=False,
|
| 155 |
+
)
|
| 156 |
+
gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
|
| 157 |
+
with gr.Tab("GLB"):
|
| 158 |
+
output_model_glb = gr.Model3D(
|
| 159 |
+
label="Output Model (GLB Format)",
|
| 160 |
+
interactive=False,
|
| 161 |
+
)
|
| 162 |
+
gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
|
| 163 |
+
with gr.Row(variant="panel"):
|
| 164 |
+
gr.Examples(
|
| 165 |
+
examples=[
|
| 166 |
+
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
|
| 167 |
+
],
|
| 168 |
+
inputs=[text_prompt],
|
| 169 |
+
outputs=[processed_image, output_model_obj, output_model_glb],
|
| 170 |
+
cache_examples=True,
|
| 171 |
+
fn=partial(run_example, do_remove_background=True, foreground_ratio=0.85, mc_resolution=256),
|
| 172 |
+
label="Examples",
|
| 173 |
+
examples_per_page=20
|
| 174 |
+
)
|
| 175 |
+
submit.click(fn=check_input_image, inputs=[text_prompt]).success(
|
| 176 |
+
fn=run_example,
|
| 177 |
+
inputs=[text_prompt, do_remove_background, foreground_ratio, mc_resolution],
|
| 178 |
+
outputs=[processed_image, output_model_obj, output_model_glb],
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
demo.queue(max_size=10)
|
| 182 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
omegaconf==2.3.0
|
| 2 |
+
Pillow==10.1.0
|
| 3 |
+
einops==0.7.0
|
| 4 |
+
torch==2.0.1
|
| 5 |
+
transformers==4.35.0
|
| 6 |
+
trimesh==4.0.5
|
| 7 |
+
rembg
|
| 8 |
+
huggingface-hub
|
| 9 |
+
gradio
|
| 10 |
+
boto3
|