"""Handles loading and running of models.""" from calendar import c import json import math import os import re import warnings from time import sleep, time import spaces from dotenv import load_dotenv from logger import logger warnings.filterwarnings("ignore") os.environ["VLLM_LOGGING_LEVEL"] = "ERROR" load_dotenv() safe_token = "No" risky_token = "Yes" nlogprobs = 20 inference_engine = os.getenv("INFERENCE_ENGINE", "TORCH") logger.debug(f"Inference engine is: {inference_engine}") if inference_engine == "TORCH": import torch from transformers import AutoTokenizer from vllm import LLM, SamplingParams from torch.nn.functional import softmax from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel # backend_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" backend_device = "cuda" if torch.cuda.is_available() else "cpu" logger.debug(f"Backend device is: {backend_device}") model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.2-3b-a800m") logger.debug(f"model_path is {model_path}") tokenizer = AutoTokenizer.from_pretrained(model_path) device = torch.device("cpu") model = AutoModelForCausalLM.from_pretrained(model_path) model = model.to(device).eval() def get_probablities(logprobs): safe_token_prob = 1e-50 unsafe_token_prob = 1e-50 for gen_token_i in logprobs: for logprob, index in zip(gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]): decoded_token = tokenizer.convert_ids_to_tokens(index) if decoded_token.strip().lower() == safe_token.lower(): safe_token_prob += math.exp(logprob) if decoded_token.strip().lower() == risky_token.lower(): unsafe_token_prob += math.exp(logprob) probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0) return probabilities def parse_output(output_ids, input_len): label, prob_of_risk = None, None if nlogprobs > 0: list_index_logprobs_i = [ torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output_ids.scores)[:-1] ] if list_index_logprobs_i is not None: prob = get_probablities(list_index_logprobs_i) prob_of_risk = round(prob[1].item(), 3) generated_text = tokenizer.decode(output_ids.sequences[:, input_len:][0], skip_special_tokens=True).strip() res = re.search(r"^\w+", generated_text, re.MULTILINE).group(0).strip() if risky_token.lower() == res.lower(): label = risky_token elif safe_token.lower() == res.lower(): label = safe_token else: label = "Failed" confidence_level = re.search(r" (.*?) 
", generated_text).group(1).strip() certainty = prob_of_risk if prob_of_risk > 0.5 else 1 - prob_of_risk return label, confidence_level, prob_of_risk, certainty @spaces.GPU def get_prompt(messages, criteria_name, criteria_description = None): """Todo""" logger.debug("Creating prompt for the model.") logger.debug(f"Messages are: {json.dumps(messages, indent=2)}") if criteria_name == "general_harm": criteria_name = "harm" elif criteria_name == "function_calling_hallucination": criteria_name = "function_call" logger.debug("Criteria name was changed too: " + criteria_name) guardian_config = {"risk_name": criteria_name} if criteria_description is not None: guardian_config['risk_definition'] = criteria_description logger.debug(f"guardian_config is: {guardian_config}") prompt = tokenizer.apply_chat_template( messages, guardian_config=guardian_config, tokenize=False, add_generation_prompt=True, ) logger.debug(f"Prompt is:\n{prompt}") return prompt @spaces.GPU def get_guardian_response(messages, criteria_name, criteria_description=None): start = time() if criteria_name == "general_harm": criteria_name = "harm" elif criteria_name == "function_calling_hallucination": criteria_name = "function_call" logger.debug(f"Messages are: {json.dumps(messages, indent=2)}") if inference_engine == "MOCK": logger.debug("Returning mocked model result.") sleep(1) label, confidence_level, prob_of_risk, certainty = "Yes", 'High', 0.97, 0.97 elif inference_engine == "TORCH": guardian_config = {"risk_name": criteria_name} if criteria_description is not None: guardian_config['risk_definition'] = criteria_description logger.debug(f"guardian_config is: {guardian_config}") input_ids = tokenizer.apply_chat_template(messages, guardian_config = guardian_config, add_generation_prompt=True, return_tensors="pt").to(model.device) input_len = input_ids.shape[1] with torch.no_grad(): output_ids = model.generate( input_ids, do_sample=False, max_new_tokens=nlogprobs, return_dict_in_generate=True, output_scores=True, ) label, confidence_level, prob_of_risk, certainty = parse_output(output_ids, input_len) else: raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [MOCK, TORCH]") logger.debug(f"label={label}, confidence_level={confidence_level}, prob_of_risk={prob_of_risk}, certainty={certainty}") end = time() total = end - start logger.debug(f"The evaluation took {total} secs") return {"label": label, "certainty": certainty}