14likhit commited on
Commit
09616bc
·
verified ·
1 Parent(s): 00247d6

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -49
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import streamlit as st
2
  import torch
3
- from PIL import Image
4
  import io
5
  import os
6
- import subprocess
7
- import sys
8
 
9
  # Set page config
10
  st.set_page_config(
@@ -17,36 +15,6 @@ st.set_page_config(
17
  st.title("AI Portrait Generator")
18
  st.markdown("Generate beautiful portraits using the AWPortraitCN2 model")
19
 
20
- # Check and install compatible versions if needed
21
- @st.cache_resource
22
- def install_dependencies():
23
- try:
24
- # Try to import diffusers to see if it works
25
- import diffusers
26
- return True
27
- except ImportError:
28
- st.warning("Installing required packages. This may take a few minutes...")
29
- # Install specific versions known to work together
30
- subprocess.check_call([
31
- sys.executable, "-m", "pip", "install",
32
- "huggingface-hub==0.16.4",
33
- "diffusers==0.20.0",
34
- "transformers==4.32.0",
35
- "accelerate==0.21.0"
36
- ])
37
- return True
38
- except Exception as e:
39
- st.error(f"Failed to install dependencies: {e}")
40
- return False
41
-
42
- # Try to install compatible dependencies
43
- dependencies_installed = install_dependencies()
44
-
45
- # If dependencies installation failed, show message and exit
46
- if not dependencies_installed:
47
- st.error("Could not set up the required environment. Please check the logs.")
48
- st.stop()
49
-
50
  # Model parameters
51
  with st.sidebar:
52
  st.header("Generation Settings")
@@ -64,16 +32,19 @@ prompt = st.text_area(
64
  value="Masterpiece portrait of a beautiful young woman with flowing hair, detailed face, photorealistic, 8k, professional photography"
65
  )
66
 
67
- # Function to load model with proper dependencies
68
  @st.cache_resource
69
  def load_model():
70
  try:
71
- from diffusers import StableDiffusionPipeline
 
72
 
73
- pipeline = StableDiffusionPipeline.from_pretrained(
 
74
  "Shakker-Labs/AWPortraitCN2",
75
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
76
- use_safetensors=True
 
77
  )
78
 
79
  # Move to GPU if available
@@ -82,8 +53,28 @@ def load_model():
82
 
83
  return pipeline
84
  except Exception as e:
85
- st.error(f"Error loading model: {e}")
86
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Generate button
89
  if st.button("Generate Portrait", type="primary"):
@@ -93,7 +84,7 @@ if st.button("Generate Portrait", type="primary"):
93
  pipeline = load_model()
94
 
95
  if pipeline is None:
96
- st.error("Failed to load the model. Please check the logs.")
97
  st.stop()
98
 
99
  # Set seed if specified
@@ -127,13 +118,17 @@ if st.button("Generate Portrait", type="primary"):
127
  )
128
 
129
  except Exception as e:
130
- st.error(f"An error occurred: {e}")
131
- st.info("Make sure you have enough GPU memory and the required dependencies installed.")
132
 
133
- # Add requirements info at the bottom
134
- st.markdown("---")
135
- st.markdown("""
136
- ### About This App
137
- This app uses the AWPortraitCN2 model to generate AI portraits based on your text prompts.
138
- Adjust the settings in the sidebar to customize your generation.
139
- """)
 
 
 
 
 
1
  import streamlit as st
2
  import torch
 
3
  import io
4
  import os
5
+ from PIL import Image
 
6
 
7
  # Set page config
8
  st.set_page_config(
 
15
  st.title("AI Portrait Generator")
16
  st.markdown("Generate beautiful portraits using the AWPortraitCN2 model")
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Model parameters
19
  with st.sidebar:
20
  st.header("Generation Settings")
 
32
  value="Masterpiece portrait of a beautiful young woman with flowing hair, detailed face, photorealistic, 8k, professional photography"
33
  )
34
 
35
+ # Function to load model using modern API
36
  @st.cache_resource
37
  def load_model():
38
  try:
39
+ # Import these inside the function to handle errors gracefully
40
+ from diffusers import AutoPipelineForText2Image
41
 
42
+ # Use AutoPipeline which is more compatible with newer versions
43
+ pipeline = AutoPipelineForText2Image.from_pretrained(
44
  "Shakker-Labs/AWPortraitCN2",
45
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
46
+ use_safetensors=True,
47
+ variant="fp16" if torch.cuda.is_available() else None
48
  )
49
 
50
  # Move to GPU if available
 
53
 
54
  return pipeline
55
  except Exception as e:
56
+ st.error(f"Error loading model: {str(e)}")
57
+ st.info("Debug info: Using modern API with AutoPipelineForText2Image")
58
+
59
+ # Fallback to traditional StableDiffusionPipeline if needed
60
+ try:
61
+ st.info("Trying alternative method...")
62
+ from diffusers import StableDiffusionPipeline
63
+
64
+ pipeline = StableDiffusionPipeline.from_pretrained(
65
+ "Shakker-Labs/AWPortraitCN2",
66
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
67
+ use_safetensors=True
68
+ )
69
+
70
+ # Move to GPU if available
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+ pipeline = pipeline.to(device)
73
+
74
+ return pipeline
75
+ except Exception as e2:
76
+ st.error(f"Alternative method also failed: {str(e2)}")
77
+ return None
78
 
79
  # Generate button
80
  if st.button("Generate Portrait", type="primary"):
 
84
  pipeline = load_model()
85
 
86
  if pipeline is None:
87
+ st.error("Failed to load the model. Check the logs for details.")
88
  st.stop()
89
 
90
  # Set seed if specified
 
118
  )
119
 
120
  except Exception as e:
121
+ st.error(f"An error occurred during generation: {str(e)}")
122
+ st.info("Make sure you have enough GPU memory (T4 or better recommended).")
123
 
124
+ # Add hardware info at the bottom
125
+ if torch.cuda.is_available():
126
+ st.markdown("---")
127
+ st.markdown(f"""
128
+ ### Hardware Info
129
+ - Running on: {torch.cuda.get_device_name(0)}
130
+ - Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB
131
+ """)
132
+ else:
133
+ st.markdown("---")
134
+ st.markdown("⚠️ Running on CPU. For better performance, use a GPU runtime.")