from elm.model import *
from elm.utils import batchify
from transformers import AutoTokenizer
import torch
import json


def load_elm_model_and_tokenizer(local_path,
                                 model_config_dict,
                                 device="cuda",
                                 load_partial=True,
                                 get_num_layers_from_ckpt=True):
    """Load ELM model and tokenizer from a local checkpoint."""
    model_args = ModelArgs(**model_config_dict)
    model = load_elm_model_from_ckpt(local_path,
                                     device=device,
                                     model_args=model_args,
                                     load_partial=load_partial,
                                     get_num_layers_from_ckpt=get_num_layers_from_ckpt)

    # Left padding/truncation keeps the most recent context adjacent to the generation position.
    tokenizer = AutoTokenizer.from_pretrained(local_path)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    return model, tokenizer

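
# Example (not from the original source): a hypothetical way to call the loader
# directly. The checkpoint path is a placeholder, and the config keys shown are
# the same ones generate_elm_responses() fills in below; ModelArgs may require
# all of them, so passing a complete dict is the safe assumption.
#
#     example_config = {"hidden_size": 2048, "max_inp_len": 2048,
#                       "num_attention_heads": 32, "num_layers": 48,
#                       "bits": 256, "vocab_size": 50304, "dropout": 0.1,
#                       "use_rotary_embeddings": True}
#     model, tokenizer = load_elm_model_and_tokenizer("path/to/elm_checkpoint",
#                                                     example_config,
#                                                     device="cpu")
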
def generate_elm_response_given_model(prompts, model, tokenizer,
                                      device="cuda",
                                      max_ctx_word_len=1024,
                                      max_ctx_token_len=0,
                                      max_new_tokens=500,
                                      temperature=0.8,
                                      top_k=200,
                                      return_tok_cnt=False,
                                      return_gen_only=False,
                                      early_stop_on_eos=False):
    """Generate responses from an ELM model given an input list of prompts ([str])."""
    if max_ctx_token_len > 0:
        # Token-level truncation of the prompt context.
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                           max_length=max_ctx_token_len).to(device)
    else:
        # Fall back to word-level truncation, keeping the last max_ctx_word_len words.
        prompts = [" ".join(p.split(" ")[-max_ctx_word_len:]) for p in prompts]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

    input_tok_cnt = torch.numel(inputs.input_ids)

    model.eval()

    out_tok_cnt = 0
    with torch.no_grad():
        outputs = model.generate(inputs.input_ids, max_new_tokens, temperature=temperature, top_k=top_k,
                                 return_gen_only=return_gen_only)

        if return_tok_cnt:
            out_tok_cnt += torch.numel(outputs)

        if early_stop_on_eos:
            # Truncate each sequence at the first EOS token, if one was generated.
            mod_outputs = []
            for i in range(len(outputs)):
                curr_out = outputs[i]

                eos_loc_id = -1
                for j in range(len(outputs[i])):
                    if outputs[i][j] == tokenizer.eos_token_id:
                        eos_loc_id = j
                        break
                if eos_loc_id >= 0:
                    curr_out = outputs[i][:eos_loc_id]
                mod_outputs.append(curr_out)
            outputs = mod_outputs

        results = tokenizer.batch_decode(outputs, skip_special_tokens=False)

    if return_tok_cnt:
        return results, (input_tok_cnt, out_tok_cnt)

    return results

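
# Example (not from the original source): a sketch of calling the generator
# directly with deterministic decoding and token accounting. `model` and
# `tokenizer` are assumed to come from load_elm_model_and_tokenizer(); the
# prompt string is a placeholder.
#
#     responses, (n_input_toks, n_output_toks) = generate_elm_response_given_model(
#         ["[INST] Classify the sentiment of: 'great battery life' [/INST]"],
#         model, tokenizer,
#         device="cuda",
#         max_ctx_token_len=512,
#         max_new_tokens=32,
#         temperature=0.0,
#         return_tok_cnt=True,
#         early_stop_on_eos=True)
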
def generate_elm_responses(elm_model_path,
                           prompts,
                           device=None,
                           elm_model_config=None,
                           eval_batch_size=1,
                           verbose=True):
    """Load the ELM checkpoint at elm_model_path and generate a response for each prompt."""
    elm_model_config = elm_model_config or {}

    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Setting device to {device}")

    model_config_dict = {
        "hidden_size": elm_model_config.get("hidden_size", 2048),
        "max_inp_len": elm_model_config.get("max_inp_len", 2048),
        "num_attention_heads": elm_model_config.get("num_attention_heads", 32),
        "num_layers": elm_model_config.get("num_layers", 48),
        "bits": elm_model_config.get("bits", 256),
        "vocab_size": elm_model_config.get("vocab_size", 50304),
        "dropout": elm_model_config.get("dropout", 0.1),
        "use_rotary_embeddings": elm_model_config.get("use_rotary_embeddings", True)
    }

    model, tokenizer = load_elm_model_and_tokenizer(local_path=elm_model_path,
                                                    model_config_dict=model_config_dict,
                                                    device=device,
                                                    load_partial=True)

    # Classification/detection checkpoints only need short completions.
    max_new_tokens = 128
    if "classification" in elm_model_path or "detection" in elm_model_path:
        max_new_tokens = 12

    result = []
    for prompt_batch in batchify(prompts, eval_batch_size):
        responses, _ = generate_elm_response_given_model(prompt_batch,
                                                         model,
                                                         tokenizer,
                                                         device=device,
                                                         max_ctx_word_len=1024,
                                                         max_ctx_token_len=512,
                                                         max_new_tokens=max_new_tokens,
                                                         return_tok_cnt=True,
                                                         return_gen_only=False,
                                                         temperature=0.0,
                                                         early_stop_on_eos=True)

        for prompt, response in zip(prompt_batch, responses):
            # Keep only the text after the instruction template's [/INST] marker.
            response = response.split("[/INST]")[-1].strip()
            result.append(response)
            if verbose:
                print(json.dumps({"prompt": prompt, "response": response}, indent=4))
                print("\n***\n")
    return result
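

# Minimal usage sketch (not part of the original module): guarded so it only runs
# when the file is executed directly. The checkpoint path and prompt below are
# placeholders/assumptions, not shipped defaults.
if __name__ == "__main__":
    example_model_path = "path/to/elm_checkpoint"  # placeholder path
    example_prompts = [
        "[INST] Summarize: ELM is a family of efficient language models. [/INST]",
    ]
    generate_elm_responses(example_model_path, example_prompts,
                           eval_batch_size=1, verbose=True)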