v1
app.py CHANGED
@@ -1,5 +1,5 @@
 # A100 Zero GPU
-
+import spaces
 
 # TroL Package
 import torch
@@ -18,8 +18,8 @@ from transformers import TextIteratorStreamer
 from torchvision.transforms.functional import pil_to_tensor
 
 # flash attention
-
-
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 # accel
 accel = Accelerator()
@@ -55,7 +55,7 @@ def threading_function(inputs, image_token_number, streamer, device, model, toke
     generation_kwargs.update({'use_cache': True})
     return model.generate(**generation_kwargs)
 
-
+@spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
     # model selection
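
For context, this commit applies the standard Hugging Face ZeroGPU setup: spaces is imported at the top of the file (the ZeroGPU docs ask for this before any CUDA initialization), flash-attn is installed at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE so pip skips compiling the CUDA extension (no GPU is visible at install time on a Space), and the GPU-dependent entry point is wrapped with @spaces.GPU so a device is attached only while the call runs. Below is a minimal sketch of the same pattern; run_inference is a hypothetical stand-in for bot_streaming.

# Minimal sketch of the ZeroGPU pattern applied in this commit.
# run_inference is a hypothetical stand-in for bot_streaming.
import subprocess

import spaces   # ZeroGPU helper; imported before torch/CUDA is touched
import torch

# Install flash-attn at startup, skipping the CUDA compile step.
# Note: env= replaces the whole environment here, matching the commit as-is.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

@spaces.GPU  # a GPU is attached only for the duration of each call
def run_inference(prompt: str) -> str:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # hypothetical body: the model forward pass would go here
    return f'ran on {device}: {prompt}'

If a call needs more than the default time slot, the decorator also accepts a duration, e.g. @spaces.GPU(duration=120).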