Update app.py
Browse files
app.py
CHANGED
|
@@ -154,18 +154,21 @@ def install_flash_attn():
|
|
| 154 |
|
| 155 |
logging.info(f"Detected CUDA version: {cuda_version}")
|
| 156 |
|
| 157 |
-
# CUDA
|
| 158 |
-
if cuda_version.startswith("
|
|
|
|
|
|
|
| 159 |
flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
|
| 160 |
-
subprocess.run(
|
| 161 |
-
["pip", "install", flash_attn_url],
|
| 162 |
-
check=True,
|
| 163 |
-
capture_output=True
|
| 164 |
-
)
|
| 165 |
else:
|
| 166 |
logging.warning(f"Unsupported CUDA version: {cuda_version}, skipping flash-attn installation")
|
| 167 |
return False
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
logging.info("flash-attn installed successfully!")
|
| 170 |
return True
|
| 171 |
except Exception as e:
|
|
@@ -437,11 +440,12 @@ Stay with me forever, let our love just flow
|
|
| 437 |
)
|
| 438 |
|
| 439 |
# Run with server settings
|
| 440 |
-
|
|
|
|
| 441 |
server_name="0.0.0.0",
|
| 442 |
server_port=7860,
|
| 443 |
share=True,
|
| 444 |
-
enable_queue=True,
|
| 445 |
show_api=True,
|
| 446 |
-
show_error=True
|
| 447 |
-
)
|
|
|
|
|
|
| 154 |
|
| 155 |
logging.info(f"Detected CUDA version: {cuda_version}")
|
| 156 |
|
| 157 |
+
# Select the flash-attn wheel matching the detected CUDA version
|
| 158 |
+
if cuda_version.startswith("12.1"):
|
| 159 |
+
flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.5/flash_attn-2.7.5+cu121torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
|
| 160 |
+
elif cuda_version.startswith("11.8"):
|
| 161 |
flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
else:
|
| 163 |
logging.warning(f"Unsupported CUDA version: {cuda_version}, skipping flash-attn installation")
|
| 164 |
return False
|
| 165 |
|
| 166 |
+
subprocess.run(
|
| 167 |
+
["pip", "install", flash_attn_url],
|
| 168 |
+
check=True,
|
| 169 |
+
capture_output=True
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
logging.info("flash-attn installed successfully!")
|
| 173 |
return True
|
| 174 |
except Exception as e:
|
|
|
|
| 440 |
)
|
| 441 |
|
| 442 |
# Run with server settings
|
| 443 |
+
# Run with server settings
|
| 444 |
+
demo.queue(max_size=20).launch(
|
| 445 |
server_name="0.0.0.0",
|
| 446 |
server_port=7860,
|
| 447 |
share=True,
|
|
|
|
| 448 |
show_api=True,
|
| 449 |
+
show_error=True,
|
| 450 |
+
concurrency_count=2 # Moved here as a launch() parameter instead of queue() — NOTE(review): in Gradio 3.x concurrency_count belongs to queue(), and it was removed entirely in Gradio 4; confirm the installed Gradio version accepts this keyword on launch()
|
| 451 |
+
)
|