oliver-aizip committed
Commit 8151596
Parent(s): 2d7d23d

switched models to using full context

Files changed (1):
  1. utils/models.py +4 -12
utils/models.py CHANGED
@@ -30,18 +30,13 @@ def generate_summaries(example, model_a_name, model_b_name):
     # Create a plain text version of the contexts for the models
     context_text = ""
     context_parts = []
-    if "contexts" in example and example["contexts"]:
-        for ctx in example["contexts"]:
+    if "full_contexts" in example:
+        for ctx in example["full_contexts"]:
             if isinstance(ctx, dict) and "content" in ctx:
                 context_parts.append(ctx["content"])
         context_text = "\n---\n".join(context_parts)
     else:
-        # Fallback to full contexts if highlighted contexts are not available
-        if "full_contexts" in example:
-            for ctx in example["full_contexts"]:
-                if isinstance(ctx, dict) and "content" in ctx:
-                    context_parts.append(ctx["content"])
-            context_text = "\n---\n".join(context_parts)
+        raise ValueError("No context found in the example.")
 
     # Pass 'Answerable' status to models (they might use it)
     answerable = example.get("Answerable", True)
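
For readers skimming the first hunk: the old code preferred the highlighted `contexts` field and quietly fell back to `full_contexts`; the new code requires `full_contexts` and raises otherwise. A minimal standalone sketch of the new assembly logic, assuming the dict shape implied by the diff (the `build_context_text` name and the sample `example` are illustrative, not part of the repo):

```python
def build_context_text(example: dict) -> str:
    """Join full-context passages with a separator, mirroring the post-change logic."""
    if "full_contexts" not in example:
        # No highlighted-context fallback anymore: missing contexts are an error.
        raise ValueError("No context found in the example.")
    parts = [
        ctx["content"]
        for ctx in example["full_contexts"]
        if isinstance(ctx, dict) and "content" in ctx
    ]
    return "\n---\n".join(parts)

# Hypothetical row shaped like the dataset entries this function expects:
example = {"full_contexts": [{"content": "Doc A"}, {"content": "Doc B"}]}
print(build_context_text(example))  # Doc A\n---\nDoc B
```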
@@ -85,17 +80,14 @@ def run_inference(model_name, context, question):
     ).to(device)
 
     input_length = actual_input.shape[1]
-
-    # Create attention mask (1 for all tokens since we're not padding)
     attention_mask = torch.ones_like(actual_input).to(device)
 
     # Generate output
     with torch.inference_mode():
-        # Disable gradient calculation for inference
         outputs = model.generate(
             actual_input,
             attention_mask=attention_mask,
-            max_new_tokens=512,  # Use max_new_tokens instead of max_length
+            max_new_tokens=512,
             pad_token_id=tokenizer.pad_token_id,
         )
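
On the second hunk: the removed comments were redundant (`torch.inference_mode()` already disables gradient tracking, and an all-ones attention mask is self-explanatory for a single unpadded sequence), while `max_new_tokens=512` caps only the generated continuation, unlike `max_length`, which counts the prompt as well. A self-contained sketch of an equivalent generate call, assuming a standard transformers causal LM; the model choice and the decoding step are illustrative, not taken from utils/models.py:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model for the sketch
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

actual_input = tokenizer("Summarize: ...", return_tensors="pt").input_ids.to(device)
input_length = actual_input.shape[1]

# One unpadded sequence, so attending to every position is correct.
attention_mask = torch.ones_like(actual_input).to(device)

with torch.inference_mode():  # already a no-grad context
    outputs = model.generate(
        actual_input,
        attention_mask=attention_mask,
        max_new_tokens=512,  # bounds generated tokens only, not prompt + output
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode just the continuation that follows the prompt.
print(tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True))
```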
 
 