import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList

from .prompts import format_rag_prompt
from .shared import generation_interrupt
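
# Display name -> Hugging Face Hub repo id for each candidate model.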
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()


def generate_summaries(example, model_a_name, model_b_name):
    """
    Generate summaries for the given example using the two assigned models,
    returning early with empty strings if the user has interrupted generation.
    """
    if generation_interrupt.is_set():
        return "", ""

    # Join all retrieved context chunks into a single prompt context.
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    summary_a = run_inference(models[model_a_name], context_text, question)

    # If interrupted between models, return whatever has been produced so far.
    if generation_interrupt.is_set():
        return summary_a, ""

    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b


def run_inference(model_name, context, question):
    """
    Run inference with the specified model, checking the interrupt flag
    between the expensive steps so the request can be cancelled early.
    """
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        # Some chat templates (e.g. Gemma's) reject a system role; detect that from
        # the template text and guard against tokenizers that ship no template at all.
        accepts_sys = (
            tokenizer.chat_template is not None
            and "System role not supported" not in tokenizer.chat_template
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if generation_interrupt.is_set():
            return ""

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
        ).to(device)

        text_input = format_rag_prompt(question, context, accepts_sys)

        if generation_interrupt.is_set():
            return ""

        # Tokenize the chat-formatted prompt; truncation must be enabled for
        # max_length to take effect.
        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            truncation=True,
            max_length=2048,
            add_generation_prompt=True,
        ).to(device)

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)

        if generation_interrupt.is_set():
            return ""

        # Allow the user to abort generation token-by-token via the shared event.
        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])

        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
            )

        # Discard partial output if the request was interrupted mid-generation.
        if generation_interrupt.is_set():
            return ""

        # Decode only the newly generated tokens, skipping the prompt.
        result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        return result

    except Exception as e:
        print(f"Error in inference: {e}")
        return f"Error generating response: {str(e)[:100]}..."
    finally:
        # Free GPU memory between requests.
        if "model" in locals():
            del model
        if "tokenizer" in locals():
            del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
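

# --- Illustrative usage (a sketch, not part of the app's call path) ---
# Roughly how a caller might invoke generate_summaries. The example record
# below is hypothetical and only mirrors the dict layout assumed above
# ("full_contexts" as a list of {"content": ...} dicts plus a "question").
#
#   example = {
#       "full_contexts": [
#           {"content": "Doc chunk about topic A."},
#           {"content": "Doc chunk about topic B."},
#       ],
#       "question": "What do the retrieved documents say about topic A?",
#   }
#   summary_a, summary_b = generate_summaries(
#       example, "Qwen2.5-1.5b-Instruct", "Gemma-3-1b-it"
#   )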