Commit · 8151596
Parent(s): 2d7d23d

switched models to using full context

utils/models.py: +4 -12
utils/models.py CHANGED

@@ -30,18 +30,13 @@ def generate_summaries(example, model_a_name, model_b_name):
     # Create a plain text version of the contexts for the models
     context_text = ""
     context_parts = []
-    if "contexts" in example:
-        for ctx in example["contexts"]:
+    if "full_contexts" in example:
+        for ctx in example["full_contexts"]:
             if isinstance(ctx, dict) and "content" in ctx:
                 context_parts.append(ctx["content"])
         context_text = "\n---\n".join(context_parts)
     else:
-
-        if "full_contexts" in example:
-            for ctx in example["full_contexts"]:
-                if isinstance(ctx, dict) and "content" in ctx:
-                    context_parts.append(ctx["content"])
-            context_text = "\n---\n".join(context_parts)
+        raise ValueError("No context found in the example.")
 
     # Pass 'Answerable' status to models (they might use it)
     answerable = example.get("Answerable", True)

@@ -85,17 +80,14 @@ def run_inference(model_name, context, question):
     ).to(device)
 
     input_length = actual_input.shape[1]
-
-    # Create attention mask (1 for all tokens since we're not padding)
     attention_mask = torch.ones_like(actual_input).to(device)
 
     # Generate output
     with torch.inference_mode():
-        # Disable gradient calculation for inference
         outputs = model.generate(
             actual_input,
             attention_mask=attention_mask,
-            max_new_tokens=512,
+            max_new_tokens=512,
             pad_token_id=tokenizer.pad_token_id,
         )
 
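In effect, the commit makes full_contexts the only accepted context source and fails fast when it is missing. A minimal sketch of the new behavior, assuming the key shapes the diff relies on; build_context_text is a hypothetical standalone extraction of the logic inside generate_summaries, not a function in the repo, and the sample dict is illustrative:

# Hypothetical standalone version of the context assembly after this commit.
def build_context_text(example: dict) -> str:
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        return "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

# Illustrative example dict with the keys the diff expects.
example = {
    "full_contexts": [
        {"content": "Passage one."},
        {"content": "Passage two."},
    ],
    "Answerable": True,
}
print(build_context_text(example))  # Passage one. / --- / Passage two.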
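For reference, the trimmed generation path in run_inference reduces to the following minimal runnable sketch. The gpt2 checkpoint, the prompt, and the pad-token fallback are placeholders and assumptions, not part of the repo; the final decode step assumes input_length is recorded in order to strip the prompt from the output, which the diff itself does not show.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token          # assumption: GPT-2 ships without a pad token
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

actual_input = tokenizer("Summarize: ...", return_tensors="pt").input_ids.to(device)
input_length = actual_input.shape[1]
# A mask of ones is valid here because this is a single, unpadded sequence.
attention_mask = torch.ones_like(actual_input).to(device)

with torch.inference_mode():  # inference_mode already disables gradient tracking
    outputs = model.generate(
        actual_input,
        attention_mask=attention_mask,
        max_new_tokens=512,
        pad_token_id=tokenizer.pad_token_id,
    )

# Slice off the prompt so only newly generated tokens are decoded.
summary = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
print(summary)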