SecureLLMSys commited on
Commit
be5ad12
·
1 Parent(s): 5612036
Files changed (1) hide show
  1. src/attribution/attntrace.py +5 -1
src/attribution/attntrace.py CHANGED
@@ -37,13 +37,17 @@ class AttnTraceAttribution(Attribution):
37
  #print(last_group,group, last_group_label,group_label)
38
  importances[feature_index[0]]+=(last_group_loss-group_loss)
39
  return importances
 
40
  @spaces.GPU
 
 
 
41
  def attribute(self, question: str, contexts: list, answer: str,explained_answer: str, customized_template: str = None):
42
  start_time = time.time()
43
  if self.llm.model!=None:
44
  self.model = self.llm.model
45
  else:
46
- self.model = self.llm._load_model_if_needed().to("cuda")
47
  self.layers = range(len(self.model.model.layers))
48
  model = self.model
49
  tokenizer = self.tokenizer
 
37
  #print(last_group,group, last_group_label,group_label)
38
  importances[feature_index[0]]+=(last_group_loss-group_loss)
39
  return importances
40
+
41
  @spaces.GPU
42
+ def load_model(self):
43
+ self.model = self.llm._load_model_if_needed().to("cuda")
44
+
45
  def attribute(self, question: str, contexts: list, answer: str,explained_answer: str, customized_template: str = None):
46
  start_time = time.time()
47
  if self.llm.model!=None:
48
  self.model = self.llm.model
49
  else:
50
+ self.load_model()
51
  self.layers = range(len(self.model.model.layers))
52
  model = self.model
53
  tokenizer = self.tokenizer