Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
444ccdb
1
Parent(s):
3a7a5c6
update
Browse files- src/attribution/attntrace.py +1 -1
- src/models/Llama.py +7 -6
src/attribution/attntrace.py
CHANGED
@@ -41,7 +41,7 @@ class AttnTraceAttribution(Attribution):
|
|
41 |
if self.llm.model!=None:
|
42 |
self.model = self.llm.model
|
43 |
else:
|
44 |
-
self.model = self.llm._load_model_if_needed()
|
45 |
self.layers = range(len(self.model.model.layers))
|
46 |
model = self.model
|
47 |
tokenizer = self.tokenizer
|
|
|
41 |
if self.llm.model!=None:
|
42 |
self.model = self.llm.model
|
43 |
else:
|
44 |
+
self.model = self.llm._load_model_if_needed().to("cuda")
|
45 |
self.layers = range(len(self.model.model.layers))
|
46 |
model = self.model
|
47 |
tokenizer = self.tokenizer
|
src/models/Llama.py
CHANGED
@@ -24,17 +24,18 @@ class Llama(Model):
|
|
24 |
]
|
25 |
|
26 |
def _load_model_if_needed(self):
|
27 |
-
if self.
|
28 |
-
|
29 |
self.name,
|
30 |
torch_dtype=torch.bfloat16,
|
31 |
-
|
32 |
-
|
33 |
)
|
34 |
-
|
|
|
35 |
|
36 |
def query(self, msg, max_tokens=128000):
|
37 |
-
model = self._load_model_if_needed()
|
38 |
messages = self.messages
|
39 |
messages[1]["content"] = msg
|
40 |
|
|
|
24 |
]
|
25 |
|
26 |
def _load_model_if_needed(self):
|
27 |
+
if self._model is None:
|
28 |
+
model = AutoModelForCausalLM.from_pretrained(
|
29 |
self.name,
|
30 |
torch_dtype=torch.bfloat16,
|
31 |
+
token=self.hf_token,
|
32 |
+
device_map="auto", # or omit entirely to default to CPU
|
33 |
)
|
34 |
+
self._model = model
|
35 |
+
return self._model
|
36 |
|
37 |
def query(self, msg, max_tokens=128000):
|
38 |
+
model = self._load_model_if_needed().to("cuda")
|
39 |
messages = self.messages
|
40 |
messages[1]["content"] = msg
|
41 |
|