mac9087 commited on
Commit
e2068e3
·
verified ·
1 Parent(s): 7cd5166

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -9,6 +9,7 @@ import uuid
9
  import traceback
10
  from diffusers import ShapEImg2ImgPipeline
11
  from diffusers.utils import export_to_obj
 
12
 
13
  app = Flask(__name__)
14
 
@@ -38,12 +39,29 @@ pipe = None
38
  def load_model():
39
  global pipe
40
  if pipe is None:
41
- pipe = ShapEImg2ImgPipeline.from_pretrained(
42
- "openai/shap-e-img2img",
43
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
44
- cache_dir=CACHE_DIR # Explicitly set cache directory
45
- )
46
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return pipe
48
 
49
  def allowed_file(filename):
@@ -138,6 +156,7 @@ def convert_image_to_3d():
138
 
139
  @app.route('/', methods=['GET'])
140
  def index():
 
141
  return """
142
  <html>
143
  <head>
 
9
  import traceback
10
  from diffusers import ShapEImg2ImgPipeline
11
  from diffusers.utils import export_to_obj
12
+ from huggingface_hub import snapshot_download # Add this import
13
 
14
  app = Flask(__name__)
15
 
 
39
  def load_model():
40
  global pipe
41
  if pipe is None:
42
+ try:
43
+ model_name = "openai/shap-e-img2img"
44
+
45
+ # Explicitly download the model first to ensure all components are available
46
+ snapshot_download(
47
+ repo_id=model_name,
48
+ cache_dir=CACHE_DIR,
49
+ resume_download=True, # Resume partial downloads
50
+ )
51
+
52
+ # Now initialize the pipeline with the downloaded model
53
+ pipe = ShapEImg2ImgPipeline.from_pretrained(
54
+ model_name,
55
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
56
+ cache_dir=CACHE_DIR,
57
+ # Avoid passing renderer parameter which causes the warning
58
+ )
59
+ pipe = pipe.to(device)
60
+ print(f"Model loaded successfully on {device}")
61
+ except Exception as e:
62
+ print(f"Error loading model: {str(e)}")
63
+ print(traceback.format_exc())
64
+ raise
65
  return pipe
66
 
67
  def allowed_file(filename):
 
156
 
157
  @app.route('/', methods=['GET'])
158
  def index():
159
+ # HTML content remains unchanged
160
  return """
161
  <html>
162
  <head>