# SLM-RAG-Arena/utils/models.py
import os

# Set before torch is imported so MKL uses the GNU (libgomp) threading layer and
# avoids OpenMP runtime conflicts.
os.environ["MKL_THREADING_LAYER"] = "GNU"
import spaces
from peft import PeftModel
import traceback
import torch
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
from .prompts import format_rag_prompt
from .shared import generation_interrupt
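# Inference helpers for the SLM RAG arena: generate_summaries() runs two small
# instruction-tuned models on the same question and retrieved context, and
# run_inference() handles loading and prompting a single model.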
models = {
"Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
"Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
"Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
"Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
"Gemma-3-1b-it": "google/gemma-3-1b-it",
"Gemma-3-4b-it": "google/gemma-3-4b-it",
"Gemma-2-2b-it": "google/gemma-2-2b-it",
"Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
"Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
"IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
# #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
# #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
"Qwen3-0.6b": "qwen/qwen3-0.6b",
"Qwen3-1.7b": "qwen/qwen3-1.7b",
"Qwen3-4b": "qwen/qwen3-4b",
"SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
"EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
"OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
"icecream-3b": "aizip-dev/icecream-3b",
}
tokenizer_cache = {}
# List of model names for easy access
model_names = list(models.keys())
# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
def __init__(self, interrupt_event):
self.interrupt_event = interrupt_event
def __call__(self, input_ids, scores, **kwargs):
return self.interrupt_event.is_set()
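# Note: InterruptCriteria is defined but not currently attached to any generate()
# call below. A minimal sketch of how it could be wired in (illustrative only;
# `inputs` stands for already-tokenized model inputs):
#
#     criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
#     model.generate(**inputs, max_new_tokens=512, stopping_criteria=criteria)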
@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
"""
Generates summaries for the given example using the assigned models sequentially.
"""
if generation_interrupt.is_set():
return "", ""
context_text = ""
context_parts = []
if "full_contexts" in example and example["full_contexts"]:
for i, ctx in enumerate(example["full_contexts"]):
content = ""
# Extract content from either dict or string
if isinstance(ctx, dict) and "content" in ctx:
content = ctx["content"]
elif isinstance(ctx, str):
content = ctx
# Add document number if not already present
if not content.strip().startswith("Document"):
content = f"Document {i + 1}:\n{content}"
context_parts.append(content)
context_text = "\n\n".join(context_parts)
else:
# Provide a graceful fallback instead of raising an error
print("Warning: No full context found in the example, using empty context")
context_text = ""
question = example.get("question", "")
if generation_interrupt.is_set():
return "", ""
# Run model A
summary_a = run_inference(models[model_a_name], context_text, question)
if generation_interrupt.is_set():
return summary_a, ""
# Run model B
summary_b = run_inference(models[model_b_name], context_text, question)
return summary_a, summary_b
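# Illustrative call (the example dict shape is inferred from the context handling
# above; the question and document text are made up):
#
#     example = {
#         "question": "When was the Eiffel Tower completed?",
#         "full_contexts": [{"content": "The Eiffel Tower was completed in 1889."}],
#     }
#     summary_a, summary_b = generate_summaries(
#         example, "Qwen2.5-1.5b-Instruct", "Gemma-2-2b-it"
#     )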
@spaces.GPU
def run_inference(model_name, context, question):
"""
Run inference using the specified model.
Returns the generated text or empty string if interrupted.
"""
# Check interrupt at the beginning
if generation_interrupt.is_set():
return ""
    device_map = "cuda" if torch.cuda.is_available() else "cpu"
    result = ""
    # add_generation_prompt appends the assistant prompt so the model starts a fresh
    # reply; enable_thinking is added below for Qwen3 models to disable thinking mode.
    tokenizer_kwargs = {
        "add_generation_prompt": True,
    }
    # Shared generation settings for all branches below.
    generation_kwargs = {
        "max_new_tokens": 512,
    }
if "qwen3" in model_name.lower():
print(
f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False."
)
tokenizer_kwargs["enable_thinking"] = False
try:
print("REACHED HERE BEFORE tokenizer")
if model_name in tokenizer_cache:
tokenizer = tokenizer_cache[model_name]
else:
# Common arguments for tokenizer loading
tokenizer_load_args = {"padding_side": "left", "token": True}
actual_model_name_for_tokenizer = model_name
if "icecream" in model_name.lower():
actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"
tokenizer = AutoTokenizer.from_pretrained(actual_model_name_for_tokenizer, **tokenizer_load_args)
tokenizer_cache[model_name] = tokenizer
accepts_sys = (
"System role not supported" not in tokenizer.chat_template
if tokenizer.chat_template
else False # Handle missing chat_template
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Check interrupt before loading the model
if generation_interrupt.is_set():
return ""
print("REACHED HERE BEFORE pipe")
print(f"Loading model {model_name}...")
if "icecream" not in model_name.lower():
pipe = pipeline(
"text-generation",
model=model_name,
tokenizer=tokenizer,
device_map="cuda",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
model_kwargs={
"attn_implementation": "eager",
},
)
else:
base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/llama-3.2-3b-instruct",
device_map="cuda",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
model = PeftModel.from_pretrained(
base_model,
"aizip-dev/icecream-3b",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
text_input = format_rag_prompt(question, context, accepts_sys)
if "Gemma-3".lower() in model_name.lower():
print("REACHED HERE BEFORE GEN")
result = pipe(
text_input,
max_new_tokens=512,
generation_kwargs={"skip_special_tokens": True},
)[0]["generated_text"]
result = result[-1]["content"]
elif "icecream" in model_name.lower():
print("ICECREAM")
model_inputs = tokenizer.apply_chat_template(
text_input,
tokenize=True,
return_tensors="pt",
return_dict=True,
**tokenizer_kwargs,
)
model_inputs = model_inputs.to(model.device)
input_ids = model_inputs.input_ids
attention_mask = model_inputs.attention_mask
prompt_tokens_length = input_ids.shape[1]
with torch.inference_mode():
# Check interrupt before generation
if generation_interrupt.is_set():
return ""
                output_sequences = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    eos_token_id=tokenizer.eos_token_id,
                    # An explicit pad_token_id silences the missing-pad-token warning.
                    pad_token_id=tokenizer.pad_token_id,
                    **generation_kwargs,
                )
generated_token_ids = output_sequences[0][prompt_tokens_length:]
result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
        else:  # All other models go through the text-generation pipeline
            formatted = pipe.tokenizer.apply_chat_template(
                text_input,
                tokenize=False,
                **tokenizer_kwargs,
            )
            # Check interrupt before generation
            if generation_interrupt.is_set():
                return ""
            outputs = pipe(
                formatted,
                # return_full_text=False returns only the newly generated continuation.
                return_full_text=False,
                **generation_kwargs,
            )
            result = outputs[0]["generated_text"]
except Exception as e:
print(f"Error in inference for {model_name}: {e}")
print(traceback.format_exc())
result = f"Error generating response: {str(e)[:200]}..."
finally:
# Clean up resources
if torch.cuda.is_available():
torch.cuda.empty_cache()
return result
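# Illustrative local smoke test, kept commented out because it downloads model
# weights and assumes the Spaces GPU runtime; the document and question are made up:
#
#     if __name__ == "__main__":
#         ctx = "Document 1:\nThe Eiffel Tower was completed in 1889."
#         question = "When was the Eiffel Tower completed?"
#         print(run_inference(models["Qwen2.5-1.5b-Instruct"], ctx, question))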