# SLM-RAG-Arena/utils/models.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt

# Mapping of display names to Hugging Face model identifiers
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}
# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    """Stops generation as soon as the shared interrupt event is set."""

    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        # Returning True tells model.generate() to stop at the current decoding step.
        return self.interrupt_event.is_set()
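
# Illustrative sketch (not part of the app's code path): how the shared interrupt
# event is expected to be used from another thread. The thread setup below is
# hypothetical; only `generation_interrupt` and `generate_summaries` come from
# this module and utils.shared.
#
#   import threading
#
#   worker = threading.Thread(
#       target=generate_summaries,
#       args=(example, "Qwen2.5-1.5b-Instruct", "Gemma-3-1b-it"),
#   )
#   worker.start()
#   ...
#   generation_interrupt.set()  # InterruptCriteria sees this at the next
#                               # decoding step and model.generate() stops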


def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the two assigned models.

    Returns a (summary_a, summary_b) tuple; an entry is an empty string if
    generation was interrupted before that model ran.
    """
    if generation_interrupt.is_set():
        return "", ""

    # Concatenate all retrieved context passages into a single block.
    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b
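
# Expected shape of `example` (inferred from the access pattern above; the exact
# dataset schema is an assumption and the strings are placeholders):
#
#   example = {
#       "question": "What does the warranty cover?",
#       "full_contexts": [
#           {"content": "Retrieved passage 1 ..."},
#           {"content": "Retrieved passage 2 ..."},
#       ],
#   }
#   summary_a, summary_b = generate_summaries(
#       example, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
#   )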


def run_inference(model_name, context, question):
    """
    Run inference using the specified model and return the decoded completion.
    """
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        # Some chat templates (e.g. Gemma's) reject a system role; detect this
        # from the template text so format_rag_prompt can build the messages
        # accordingly.
        accepts_sys = (
            tokenizer.chat_template is not None
            and "System role not supported" not in tokenizer.chat_template
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if generation_interrupt.is_set():
            return ""

        # bfloat16 + eager attention keeps memory use modest and avoids
        # optional attention backends.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
        ).to(device)

        text_input = format_rag_prompt(question, context, accepts_sys)

        if generation_interrupt.is_set():
            return ""

        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            truncation=True,  # without this, max_length below has no effect
            max_length=2048,
            add_generation_prompt=True,
        ).to(device)

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)

        if generation_interrupt.is_set():
            return ""

        # Check the shared interrupt flag at every decoding step so a user
        # abort stops generation promptly.
        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])

        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
            )

        if generation_interrupt.is_set():
            return ""

        # Decode only the newly generated tokens, not the prompt.
        result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        return result

    except Exception as e:
        print(f"Error in inference: {e}")
        return f"Error generating response: {str(e)[:100]}..."

    finally:
        # Free GPU memory between runs; models are loaded per request.
        if "model" in locals():
            del model
        if "tokenizer" in locals():
            del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
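

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: it assumes the models above can be
    # downloaded (gated models need a configured Hugging Face token); a GPU is
    # optional. The example content is a placeholder, not arena data.
    demo_example = {
        "question": "What is the return policy?",
        "full_contexts": [
            {"content": "Items may be returned within 30 days of purchase with a receipt."}
        ],
    }
    summary_a, summary_b = generate_summaries(
        demo_example, model_names[0], model_names[1]
    )
    print("Model A summary:", summary_a)
    print("Model B summary:", summary_b)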