FagerholmEmil commited on
Commit
4047c81
·
1 Parent(s): 2ea0900

Improve device handling and add fallback to CPU

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -6,15 +6,18 @@ from transformer_lens.utils import to_numpy
6
  import torch
7
 
8
  model_name = "gpt2-small"
 
 
9
  model = HookedTransformer.from_pretrained(
10
  model_name,
11
- device="cuda"
12
  )
13
 
14
- if not torch.cuda.is_available():
15
- raise RuntimeError("This application requires a GPU to run. No GPU was detected.")
16
-
17
- print(f"Using GPU: {torch.cuda.get_device_name(0)}")
 
18
 
19
  def get_neuron_acts(text, layer, neuron_index):
20
  cache = {}
 
6
  import torch
7
 
8
  model_name = "gpt2-small"
9
+ # Determine device based on CUDA availability
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model = HookedTransformer.from_pretrained(
12
  model_name,
13
+ device=device
14
  )
15
 
16
+ # Only print GPU info if using CUDA
17
+ if device == "cuda":
18
+ print(f"Using GPU: {torch.cuda.get_device_name(0)}")
19
+ else:
20
+ print("Using CPU")
21
 
22
  def get_neuron_acts(text, layer, neuron_index):
23
  cache = {}