FagerholmEmil commited on
Commit
2ea0900
·
1 Parent(s): 53b4a79

Add GPU support and runtime check for CUDA availability

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -3,9 +3,18 @@ import os
3
  import gradio as gr
4
  from transformer_lens import HookedTransformer
5
  from transformer_lens.utils import to_numpy
 
6
 
7
  model_name = "gpt2-small"
8
- model = HookedTransformer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
9
 
10
  def get_neuron_acts(text, layer, neuron_index):
11
  cache = {}
@@ -80,7 +89,7 @@ Nested loops:
80
 
81
  The moon glows silver, wanes to shadow.
82
  Patterns persist: 11, 22, 33—harmonic echoes.
83
- Reshape,” calls the river, reflect, refract, renew.”
84
  Yellow hexagons tessellate, shifting into orange octagons.
85
  1/3 -> 1/9 -> 1/27: recursive reduction spirals infinitely.
86
 
@@ -117,7 +126,7 @@ Symmetry hums:
117
  Palindromes—"radar", "level", "madam"—appear and fade.
118
  Blue fades to white, white dissolves to black.
119
  Sequences echo: 1, 10, 100, 1000…
120
- Cycle,” whispers the clock, count forward, reverse.""" # Shortened for example
121
  default_layer = 1
122
  default_neuron_index = 1
123
  default_max_val = 4.0
 
3
  import gradio as gr
4
  from transformer_lens import HookedTransformer
5
  from transformer_lens.utils import to_numpy
6
+ import torch
7
 
8
  model_name = "gpt2-small"
9
+ model = HookedTransformer.from_pretrained(
10
+ model_name,
11
+ device="cuda"
12
+ )
13
+
14
+ if not torch.cuda.is_available():
15
+ raise RuntimeError("This application requires a GPU to run. No GPU was detected.")
16
+
17
+ print(f"Using GPU: {torch.cuda.get_device_name(0)}")
18
 
19
  def get_neuron_acts(text, layer, neuron_index):
20
  cache = {}
 
89
 
90
  The moon glows silver, wanes to shadow.
91
  Patterns persist: 11, 22, 33—harmonic echoes.
92
+ "Reshape," calls the river, "reflect, refract, renew."
93
  Yellow hexagons tessellate, shifting into orange octagons.
94
  1/3 -> 1/9 -> 1/27: recursive reduction spirals infinitely.
95
 
 
126
  Palindromes—"radar", "level", "madam"—appear and fade.
127
  Blue fades to white, white dissolves to black.
128
  Sequences echo: 1, 10, 100, 1000…
129
+ "Cycle," whispers the clock, "count forward, reverse.""" # Shortened for example
130
  default_layer = 1
131
  default_neuron_index = 1
132
  default_max_val = 4.0