Spaces:
Running
on
Zero
Use the xformers attention implementation when CUDA is available
Browse files- app.py +3 -2
- bytelatent/entropy_model.py +1 -1
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import spaces
|
|
|
|
| 2 |
import os
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
|
@@ -30,7 +31,7 @@ class Config:
|
|
| 30 |
|
| 31 |
# Bytelatent Specific
|
| 32 |
BLT_WEIGHTS_DIR: str = "hf-weights"
|
| 33 |
-
BLT_MAX_BYTES_FOR_DEMO:
|
| 34 |
|
| 35 |
# Gradio
|
| 36 |
DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
|
|
@@ -158,7 +159,7 @@ class BytelatentProcessor:
|
|
| 158 |
|
| 159 |
return highlighted_data, patch_count
|
| 160 |
|
| 161 |
-
def process(self, prompt: str, max_bytes:
|
| 162 |
"""Processes the prompt using the loaded Bytelatent model."""
|
| 163 |
status = ""
|
| 164 |
if not self.is_available or self.tokenizer is None or self.patcher is None:
|
|
|
|
| 1 |
import spaces
|
| 2 |
+
import math
|
| 3 |
import os
|
| 4 |
import gradio as gr
|
| 5 |
import torch
|
|
|
|
| 31 |
|
| 32 |
# Bytelatent Specific
|
| 33 |
BLT_WEIGHTS_DIR: str = "hf-weights"
|
| 34 |
+
BLT_MAX_BYTES_FOR_DEMO: float = math.inf # Limit for this specific demo's entropy model
|
| 35 |
|
| 36 |
# Gradio
|
| 37 |
DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
|
|
|
|
| 159 |
|
| 160 |
return highlighted_data, patch_count
|
| 161 |
|
| 162 |
+
def process(self, prompt: str, max_bytes: float) -> Tuple[Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, str]:
|
| 163 |
"""Processes the prompt using the loaded Bytelatent model."""
|
| 164 |
status = ""
|
| 165 |
if not self.is_available or self.tokenizer is None or self.patcher is None:
|
bytelatent/entropy_model.py
CHANGED
|
@@ -28,7 +28,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
|
|
| 28 |
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
|
| 29 |
vocab_size=model_params["vocab_size"],
|
| 30 |
attn_bias_type="causal",
|
| 31 |
-
attn_impl="sdpa",
|
| 32 |
sliding_window=512,
|
| 33 |
)
|
| 34 |
)
|
|
|
|
| 28 |
ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
|
| 29 |
vocab_size=model_params["vocab_size"],
|
| 30 |
attn_bias_type="causal",
|
| 31 |
+
attn_impl="xformers" if torch.cuda.is_available() else "sdpa",
|
| 32 |
sliding_window=512,
|
| 33 |
)
|
| 34 |
)
|