import os

# Work around MKL threading-layer conflicts that can crash the PyTorch import.
os.environ["MKL_THREADING_LAYER"] = "GNU"
import spaces
from peft import PeftModel
import traceback
import torch
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForCausalLM,
StoppingCriteria,
StoppingCriteriaList,
)
from .prompts import format_rag_prompt
from .shared import generation_interrupt
models = {
"Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
"Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
"Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
"Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
"Gemma-3-1b-it": "google/gemma-3-1b-it",
"Gemma-3-4b-it": "google/gemma-3-4b-it",
"Gemma-2-2b-it": "google/gemma-2-2b-it",
"Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
"Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
"IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
# #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
# #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
"Qwen3-0.6b": "qwen/qwen3-0.6b",
"Qwen3-1.7b": "qwen/qwen3-1.7b",
"Qwen3-4b": "qwen/qwen3-4b",
"SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
"EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
"OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
"icecream-3b": "aizip-dev/icecream-3b",
}
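
# The dict above maps human-readable model names to Hugging Face repo ids. A small
# illustrative helper for resolving a display name defensively (a sketch; not used
# elsewhere in this module):
def resolve_model_id(display_name: str) -> str:
    """Return the Hugging Face repo id for a display name, or raise a clear error."""
    try:
        return models[display_name]
    except KeyError:
        raise ValueError(
            f"Unknown model '{display_name}'. Expected one of: {', '.join(models)}"
        ) from None
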
tokenizer_cache = {}
# List of model names for easy access
model_names = list(models.keys())
# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
def __init__(self, interrupt_event):
self.interrupt_event = interrupt_event
def __call__(self, input_ids, scores, **kwargs):
return self.interrupt_event.is_set()
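
# InterruptCriteria is defined above but not yet attached to the generate() calls in
# this module. A minimal sketch of how it could be wired in, assuming the shared
# generation_interrupt event (the helper name is illustrative, not existing API):
def build_stopping_criteria(interrupt_event=generation_interrupt):
    """Return a StoppingCriteriaList that aborts decoding once the event is set."""
    return StoppingCriteriaList([InterruptCriteria(interrupt_event)])
# Usage sketch: model.generate(..., stopping_criteria=build_stopping_criteria())
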
@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
"""
Generates summaries for the given example using the assigned models sequentially.
"""
if generation_interrupt.is_set():
return "", ""
context_text = ""
context_parts = []
if "full_contexts" in example and example["full_contexts"]:
for i, ctx in enumerate(example["full_contexts"]):
content = ""
# Extract content from either dict or string
if isinstance(ctx, dict) and "content" in ctx:
content = ctx["content"]
elif isinstance(ctx, str):
content = ctx
# Add document number if not already present
if not content.strip().startswith("Document"):
content = f"Document {i + 1}:\n{content}"
context_parts.append(content)
context_text = "\n\n".join(context_parts)
else:
# Provide a graceful fallback instead of raising an error
print("Warning: No full context found in the example, using empty context")
context_text = ""
question = example.get("question", "")
if generation_interrupt.is_set():
return "", ""
# Run model A
summary_a = run_inference(models[model_a_name], context_text, question)
if generation_interrupt.is_set():
return summary_a, ""
# Run model B
summary_b = run_inference(models[model_b_name], context_text, question)
return summary_a, summary_b
@spaces.GPU
def run_inference(model_name, context, question):
"""
Run inference using the specified model.
Returns the generated text or empty string if interrupted.
"""
# Check interrupt at the beginning
if generation_interrupt.is_set():
return ""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result = ""
    # Chat-template kwargs; enable_thinking is added below for Qwen3 models so they
    # answer directly instead of emitting a reasoning trace first.
    tokenizer_kwargs = {
        "add_generation_prompt": True,
    }
generation_kwargs = {
"max_new_tokens": 512,
}
if "qwen3" in model_name.lower():
print(
f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False."
)
tokenizer_kwargs["enable_thinking"] = False
try:
print("REACHED HERE BEFORE tokenizer")
if model_name in tokenizer_cache:
tokenizer = tokenizer_cache[model_name]
else:
# Common arguments for tokenizer loading
tokenizer_load_args = {"padding_side": "left", "token": True}
            actual_model_name_for_tokenizer = model_name
            # icecream-3b is a PEFT adapter on Llama-3.2-3B, so reuse the base model's tokenizer.
            if "icecream" in model_name.lower():
                actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"
            tokenizer = AutoTokenizer.from_pretrained(
                actual_model_name_for_tokenizer, **tokenizer_load_args
            )
tokenizer_cache[model_name] = tokenizer
        # Some chat templates (e.g. Gemma) raise "System role not supported"; check for
        # that marker to decide whether the prompt may include a system message.
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template
            else False  # Handle missing chat_template
        )
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Check interrupt before loading the model
if generation_interrupt.is_set():
return ""
print("REACHED HERE BEFORE pipe")
print(f"Loading model {model_name}...")
if "icecream" not in model_name.lower():
pipe = pipeline(
"text-generation",
model=model_name,
tokenizer=tokenizer,
device_map="cuda",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
                model_kwargs={
                    # Eager attention is the safe default here (recommended for Gemma checkpoints).
                    "attn_implementation": "eager",
                },
)
        else:
            # icecream-3b ships as a PEFT adapter, so load the Llama-3.2-3B base model
            # and apply the adapter weights on top of it.
            base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/llama-3.2-3b-instruct",
device_map="cuda",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
model = PeftModel.from_pretrained(
base_model,
"aizip-dev/icecream-3b",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
text_input = format_rag_prompt(question, context, accepts_sys)
if "Gemma-3".lower() in model_name.lower():
print("REACHED HERE BEFORE GEN")
result = pipe(
text_input,
max_new_tokens=512,
generation_kwargs={"skip_special_tokens": True},
)[0]["generated_text"]
result = result[-1]["content"]
elif "icecream" in model_name.lower():
print("ICECREAM")
model_inputs = tokenizer.apply_chat_template(
text_input,
tokenize=True,
return_tensors="pt",
return_dict=True,
**tokenizer_kwargs,
)
model_inputs = model_inputs.to(model.device)
input_ids = model_inputs.input_ids
attention_mask = model_inputs.attention_mask
prompt_tokens_length = input_ids.shape[1]
with torch.inference_mode():
# Check interrupt before generation
if generation_interrupt.is_set():
return ""
output_sequences = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=512,
eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,  # avoids the missing-pad-token warning
)
generated_token_ids = output_sequences[0][prompt_tokens_length:]
result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
        else:  # For other models
            formatted = pipe.tokenizer.apply_chat_template(
                text_input,
                tokenize=False,
                **tokenizer_kwargs,
            )
            input_length = len(formatted)
            # Check interrupt before generation
            if generation_interrupt.is_set():
                return ""
            outputs = pipe(
                formatted,
                max_new_tokens=512,
            )
            # The pipeline output contains the prompt plus the completion; strip the
            # prompt by character offset to keep only the newly generated text.
            result = outputs[0]["generated_text"][input_length:]
except Exception as e:
print(f"Error in inference for {model_name}: {e}")
print(traceback.format_exc())
result = f"Error generating response: {str(e)[:200]}..."
finally:
# Clean up resources
if torch.cuda.is_available():
torch.cuda.empty_cache()
return result
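

# Local smoke test (a sketch, not part of the Space's runtime path). The example dict
# mirrors the keys generate_summaries() expects ("full_contexts", "question"), and the
# model names must be keys of the `models` dict above. Because of the relative imports,
# run it as a module (python -m <package>.<this_module>); @spaces.GPU is documented to
# have no effect outside ZeroGPU hardware, but these models may still be slow on CPU.
if __name__ == "__main__":
    example = {
        "full_contexts": [
            {"content": "The Eiffel Tower is located in Paris, France."},
            "It was completed in 1889.",
        ],
        "question": "Where is the Eiffel Tower and when was it completed?",
    }
    summary_a, summary_b = generate_summaries(
        example, "Qwen3-0.6b", "Llama-3.2-1b-Instruct"
    )
    print("Model A:", summary_a)
    print("Model B:", summary_b)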