File size: 6,644 Bytes
c105678
 
3a58e1b
 
 
cd1cc5d
3a58e1b
c105678
 
 
 
e5edf92
 
8dac441
cd1cc5d
bf928c6
 
 
cb8f4d2
 
 
c105678
 
cb8f4d2
c105678
cb8f4d2
e5edf92
 
 
c105678
cb8f4d2
c105678
 
 
e5edf92
 
cb8f4d2
e5edf92
 
c105678
 
cb8f4d2
c105678
cb8f4d2
 
bf928c6
cd1cc5d
 
e5edf92
cb8f4d2
cd1cc5d
 
 
 
c105678
 
 
cb8f4d2
 
 
 
cd1cc5d
cb8f4d2
 
8dac441
cb8f4d2
 
cd1cc5d
cb8f4d2
 
 
 
 
 
 
 
bf928c6
cb8f4d2
 
bf928c6
 
cb8f4d2
8dac441
cb8f4d2
f515ccd
cb8f4d2
8dac441
c1fca59
cb8f4d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f515ccd
afec0dd
c105678
 
 
 
 
 
 
cb8f4d2
 
3a58e1b
 
 
cb8f4d2
3a58e1b
cd1cc5d
3a58e1b
cb8f4d2
3a58e1b
 
 
 
cb8f4d2
3a58e1b
cb8f4d2
3a58e1b
 
cb8f4d2
 
 
 
 
 
 
 
 
 
 
 
 
 
8dac441
cb8f4d2
 
 
 
1c91a49
cb8f4d2
 
 
c105678
cb8f4d2
 
8dac441
cb8f4d2
 
 
8dac441
 
cb8f4d2
8dac441
cb8f4d2
c1fca59
8dac441
cb8f4d2
 
8dac441
 
cb8f4d2
 
f515ccd
cb8f4d2
 
 
 
 
 
8dac441
cb8f4d2
8dac441
cb8f4d2
 
 
 
 
 
 
 
 
 
 
8dac441
 
cb8f4d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
import torch
import time
import threading
import json
import gc
from flask import Flask, request, jsonify, send_file, Response, stream_with_context
from werkzeug.utils import secure_filename
from PIL import Image
import io
import zipfile
import uuid
import traceback
from huggingface_hub import snapshot_download
from flask_cors import CORS
import numpy as np
import trimesh
from transformers import pipeline
from diffusers import StableDiffusionZero123Pipeline
import imageio
from scipy.spatial.transform import Rotation

# Flask application object; every route below is registered on it.
app = Flask(__name__)
# Wrap the app with flask_cors using its default settings.
CORS(app)

# Configuration
# All working directories live under /tmp.
UPLOAD_FOLDER = '/tmp/uploads'      # uploaded source images (removed after each job)
RESULTS_FOLDER = '/tmp/results'     # one sub-directory per job id holding model.obj
CACHE_DIR = '/tmp/huggingface'      # passed as cache_dir when the models are loaded
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}  # compared against the lower-cased extension
VIEW_ANGLES = [(30, 0), (30, 90), (30, 180), (30, 270)]  # (elevation, azimuth)

# Create the working directories at import time so request handlers can
# assume they exist.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

# Environment variables for caching
# NOTE(review): these are assigned after huggingface_hub / transformers /
# diffusers have already been imported at the top of the file; if those
# libraries read the variables at import time the assignments have no effect.
# load_models() also passes cache_dir=CACHE_DIR explicitly -- confirm which
# mechanism is actually relied on.
os.environ['HF_HOME'] = CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_DIR, 'transformers')

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Upload size limit: 16 MiB.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

# Global models
# Populated by load_models(); None until then.
view_generator = None
depth_estimator = None
# Set to True by load_models() once both models are available.
model_loaded = False
# Not read or written anywhere else in the visible code.
model_loading = False

# In-memory job registry: job_id -> {'status', 'progress', 'result_url', 'error'}.
# Written by the /convert handler and its worker thread, read by /progress
# and /download. Entries are never removed in the visible code.
processing_jobs = {}

class TimeoutError(Exception):
    """Custom exception type derived directly from ``Exception``.

    Within this module the name shadows the built-in ``TimeoutError``
    (which is an ``OSError`` subclass). Nothing in the visible code raises
    or catches it.
    """

def allowed_file(filename):
    """Return True when *filename* carries an extension in ALLOWED_EXTENSIONS.

    The extension is the text after the last ``.`` and is compared
    case-insensitively.

    Args:
        filename: Client-supplied file name. May be ``None`` or empty when
            the upload has no name; both are rejected rather than raising.

    Returns:
        bool: ``True`` only for a non-empty string whose extension is allowed.
    """
    # A missing or non-string name cannot have a valid extension; the
    # original expression raised TypeError on None.
    if not filename or not isinstance(filename, str):
        return False

    _, dot, extension = filename.rpartition('.')
    # ``dot`` is empty when the name contains no '.' at all.
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS

def preprocess_image(image_path, size=256):
    """Load an image from disk as a square RGB image.

    Args:
        image_path: Path of the image file to read.
        size: Edge length in pixels of the returned square image. The source
            is resized to ``(size, size)`` without preserving aspect ratio.

    Returns:
        PIL.Image.Image: RGB image of ``size`` x ``size`` pixels, resampled
        with the LANCZOS filter.
    """
    # Open inside a context manager so the underlying file handle is closed
    # deterministically; the caller deletes the upload afterwards.
    # ``convert`` returns a fully loaded copy that does not depend on the file.
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
    return rgb.resize((size, size), Image.LANCZOS)

def load_models():
    """Load the view-generation and depth-estimation models into module globals.

    Does nothing when ``model_loaded`` is already True. On success
    ``view_generator`` and ``depth_estimator`` are set and ``model_loaded``
    becomes True.

    Raises:
        Exception: Whatever the underlying loaders raise; the error is printed
            and re-raised, and the module globals are left untouched.
    """
    global view_generator, depth_estimator, model_loaded
    if model_loaded:
        return

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision is only requested on GPU; a float16 pipeline moved to the
    # CPU is generally not runnable, so fall back to float32 there.
    dtype = torch.float16 if use_cuda else torch.float32

    try:
        # Load view generator
        generator = StableDiffusionZero123Pipeline.from_pretrained(
            "stabilityai/stable-zero123-6dof",
            torch_dtype=dtype,
            cache_dir=CACHE_DIR
        ).to(device)

        # Load depth estimator
        estimator = pipeline(
            "depth-estimation",
            model="Intel/dpt-hybrid-midas",
            cache_dir=CACHE_DIR
        )
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise

    # Publish both models together so a failure part-way through never leaves
    # one global set while ``model_loaded`` is still False.
    view_generator = generator
    depth_estimator = estimator
    model_loaded = True
    print("Models loaded successfully")

def generate_novel_views(image, num_views=4):
    """Render *image* from the camera poses listed in VIEW_ANGLES.

    Args:
        image: Input image handed to the global ``view_generator`` pipeline.
        num_views: Maximum number of poses to render, taken from the start of
            VIEW_ANGLES. Values larger than ``len(VIEW_ANGLES)`` render every
            configured pose; values <= 0 render none. The default of 4 covers
            all four configured poses.

    Returns:
        list: ``(generated_image, (elevation, azimuth))`` tuples, one per
        rendered pose, in VIEW_ANGLES order.
    """
    # The original ignored ``num_views``; honour it while keeping the default
    # behaviour (all four poses) unchanged.
    poses = VIEW_ANGLES[:max(0, num_views)]

    views = []
    for elevation, azimuth in poses:
        result = view_generator(
            image,
            num_inference_steps=50,
            elevation=elevation,
            azimuth=azimuth,
            guidance_scale=3.0
        ).images[0]
        views.append((result, (elevation, azimuth)))
    return views

def depth_to_pointcloud(depth_map, pose, fov=60):
    """Back-project a depth map into a rotated 3-D point set.

    Each pixel ``(col, row)`` with depth ``d`` becomes the point
    ``((col - w/2) * d / f, (row - h/2) * d / f, d)``, where the focal length
    ``f`` in pixels is derived from the image width and ``fov``.

    Args:
        depth_map: 2-D array of per-pixel depth values, shape ``(h, w)``.
        pose: Sequence whose first two items are ``(elevation, azimuth)`` in
            degrees.
        fov: Field of view in degrees used with the image width to compute
            the focal length.

    Returns:
        numpy.ndarray: Array of shape ``(h * w, 3)`` in row-major pixel order.
    """
    height, width = depth_map.shape
    half_fov = np.radians(fov / 2)
    focal_px = width / (2 * np.tan(half_fov))

    # Pixel coordinate grids; the image centre is used as the principal point.
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    cam_x = (cols - width / 2) * depth_map / focal_px
    cam_y = (rows - height / 2) * depth_map / focal_px

    cloud = np.column_stack((cam_x.ravel(), cam_y.ravel(), depth_map.ravel()))

    # Azimuth is applied about z and elevation about y ('zyx' Euler angles).
    # NOTE(review): z is also the depth axis here -- confirm this matches the
    # camera convention of the view generator.
    elevation, azimuth = pose[0], pose[1]
    rotation = Rotation.from_euler('zyx', [azimuth, elevation, 0], degrees=True)
    return rotation.apply(cloud)

def create_mesh_from_pointcloud(points, image):
    """Build a coloured triangle mesh enclosing a point cloud.

    The original body called ``scene.delaunay_3d`` (not part of trimesh's
    ``Scene``) and used ``Image.resize`` as if it reshaped pixel data, so it
    could not run. This version builds the convex hull of the points with
    trimesh and colours each hull vertex from the image.

    Args:
        points: Array-like of shape ``(n, 3)``. Rows containing NaN/inf are
            discarded.
        image: PIL image supplying vertex colours. Colours are applied only
            when the number of points is a whole multiple of the number of
            pixels; the pixel colours are then repeated once per multiple
            (assumes each view contributes one point per pixel in row-major
            order -- TODO confirm against the caller).

    Returns:
        trimesh.Trimesh: Convex-hull mesh of the finite points. This is a
        coarse outer shell, not a detailed surface reconstruction.

    Raises:
        ValueError: If ``points`` is not ``(n, 3)`` or has fewer than four
            finite rows. Degenerate (e.g. coplanar) inputs may additionally
            raise from the underlying hull computation.
    """
    cloud = np.asarray(points, dtype=np.float64)
    if cloud.ndim != 2 or cloud.shape[1] != 3:
        raise ValueError("points must have shape (n, 3)")

    finite_rows = np.isfinite(cloud).all(axis=1)
    if int(finite_rows.sum()) < 4:
        raise ValueError("at least four finite points are required to build a mesh")

    # Derive one colour per input point before filtering so that colours and
    # points stay index-aligned.
    point_colors = None
    pixel_colors = np.asarray(image.convert("RGB"), dtype=np.uint8).reshape(-1, 3)
    if len(pixel_colors) and len(cloud) % len(pixel_colors) == 0:
        repeats = len(cloud) // len(pixel_colors)
        point_colors = np.tile(pixel_colors, (repeats, 1))[finite_rows]

    cloud = cloud[finite_rows]
    mesh = trimesh.PointCloud(cloud).convex_hull

    if point_colors is not None:
        # Hull vertices are a re-indexed subset of the input points; recover
        # each vertex's source point by nearest-neighbour lookup.
        from scipy.spatial import cKDTree

        _, nearest = cKDTree(cloud).query(mesh.vertices)
        mesh.visual.vertex_colors = point_colors[nearest]

    return mesh

@app.route('/convert', methods=['POST'])
def convert_image_to_3d():
    """Accept an uploaded image and start a background image-to-3D job.

    Expects a multipart form upload with an ``image`` file field. The file is
    saved under the upload folder, a job entry is registered in
    ``processing_jobs`` and a worker thread runs the pipeline
    (preprocess -> novel views -> depth -> point cloud -> mesh -> OBJ export).

    Returns:
        tuple: ``({"job_id": ...}, 202)`` when the job was started, or
        ``({"error": ...}, 400)`` when the file is missing or its extension is
        not allowed.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400

    file = request.files['image']
    if not allowed_file(file.filename):
        return jsonify({"error": "Invalid file type"}), 400

    job_id = str(uuid.uuid4())
    output_dir = os.path.join(RESULTS_FOLDER, job_id)
    os.makedirs(output_dir, exist_ok=True)

    # The job id prefix keeps concurrent uploads with the same name apart.
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
    file.save(filepath)

    job = {
        'status': 'processing',
        'progress': 0,
        'result_url': None,
        'error': None
    }
    processing_jobs[job_id] = job

    def process_image():
        """Run the full pipeline for this job and record its outcome."""
        try:
            # Preprocess input image
            img = preprocess_image(filepath)
            job['progress'] = 20

            # Generate novel views
            views = generate_novel_views(img)
            job['progress'] = 40

            # Process each view
            all_points = []
            for index, (view_img, pose) in enumerate(views, start=1):
                # Estimate depth
                depth_result = depth_estimator(view_img)
                depth_map = np.array(depth_result["depth"])

                # Convert to point cloud
                all_points.append(depth_to_pointcloud(depth_map, pose))

                # Spread the 40 -> 80 range evenly over however many views
                # were generated (a fixed +10 is only right for four views).
                job['progress'] = 40 + (40 * index) // len(views)

            # Combine point clouds
            combined_points = np.vstack(all_points)
            job['progress'] = 80

            # Create mesh
            mesh = create_mesh_from_pointcloud(combined_points, img)

            # Export
            obj_path = os.path.join(output_dir, "model.obj")
            mesh.export(obj_path)

            job['status'] = 'completed'
            job['result_url'] = f"/download/{job_id}"
            job['progress'] = 100

        except Exception as e:
            # Top-level boundary of the worker thread: keep the traceback in
            # the server log, since the client only ever sees the message.
            traceback.print_exc()
            job['status'] = 'error'
            # Some exceptions have an empty message; fall back to the type.
            job['error'] = str(e) or type(e).__name__
        finally:
            # The upload is only needed while the job runs.
            if os.path.exists(filepath):
                os.remove(filepath)
            gc.collect()
            torch.cuda.empty_cache()

    thread = threading.Thread(target=process_image)
    thread.start()
    return jsonify({"job_id": job_id}), 202

@app.route('/download/<job_id>')
def download_model(job_id):
    """Send the exported OBJ file of a completed job as an attachment.

    Args:
        job_id: Job identifier taken from the URL.

    Returns:
        The ``model.obj`` file for the job, or a JSON error with status 404
        when the job id is unknown or the job has not completed.
    """
    known = job_id in processing_jobs
    if not known or processing_jobs[job_id]['status'] != 'completed':
        error_body = jsonify({"error": "Job not complete"})
        return error_body, 404

    job_dir = os.path.join(RESULTS_FOLDER, job_id)
    return send_file(os.path.join(job_dir, "model.obj"), as_attachment=True)

@app.route('/progress/<job_id>')
def get_progress(job_id):
    """Report the status, progress and error of a job as JSON.

    Args:
        job_id: Job identifier taken from the URL.

    Returns:
        JSON object with ``status``, ``progress`` and ``error``. All three
        are null when the job id is not present in ``processing_jobs``.
    """
    job = processing_jobs.get(job_id, {})
    status = job.get('status')
    progress = job.get('progress')
    error = job.get('error')
    return jsonify({'status': status, 'progress': progress, 'error': error})

# Script entry point: only runs when the file is executed directly, not when
# it is imported as a module.
if __name__ == '__main__':
    # Load both models before serving; load_models() re-raises on failure, so
    # the server does not start without them.
    load_models()
    # Listen on all interfaces, port 7860.
    app.run(host='0.0.0.0', port=7860)