"""Flask app that visualizes per-token log-probabilities of a causal LM.

POST /analyze runs the selected Hugging Face model over the submitted text
and returns, for each token: the log-prob the model assigned to it, that
log-prob's percentile rank within the pooled full-vocabulary distribution,
and the model's top-5 alternative predictions at that position.
"""

from flask import Flask, render_template, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from scipy.stats import percentileofscore

app = Flask(__name__)

DEFAULT_MODEL = "gpt2"

# Process-wide caches so each model/tokenizer is loaded from disk only once.
# NOTE(review): unbounded — every distinct model requested stays resident.
model_cache = {}
tokenizer_cache = {}


def get_model_and_tokenizer(model_name):
    """Return a cached (model, tokenizer) pair, loading it on first use.

    Args:
        model_name: Hugging Face model identifier.

    Returns:
        Tuple of (model in eval mode, tokenizer).
    """
    if model_name not in model_cache:
        # phi-1_5 ships custom modeling code; trust_remote_code is a
        # security decision, so it is opted into per-model rather than
        # enabled globally.
        trust_code = model_name == "microsoft/phi-1_5"
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        # Inference only — set eval mode once at load time instead of on
        # every request (original called model.eval() per request).
        model.eval()
        model_cache[model_name] = model
        tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
    return model_cache[model_name], tokenizer_cache[model_name]


@app.route("/")
def index():
    """Serve the UI with the list of selectable models."""
    return render_template(
        "index.html",
        models=[
            DEFAULT_MODEL,
            # Larger models are disabled to keep memory usage down;
            # re-enable as hardware allows.
            # "gpt2-medium",
            # "gpt2-large",
            # "gpt2-xl",
            # "EleutherAI/pythia-1.4b",
            # "facebook/opt-1.3b",
            # "bigscience/bloom-1b7",
            # "microsoft/phi-1_5",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        ],
    )


@app.route("/analyze", methods=["POST"])
def analyze():
    """Score the posted text under the chosen model; return token stats.

    Expects a JSON body with "text" (required) and "model" (optional,
    defaults to DEFAULT_MODEL). Returns 400 on a missing/invalid body
    instead of the opaque 500 the previous KeyError produced.
    """
    data = request.get_json(silent=True) or {}
    text = data.get("text")
    model_name = data.get("model", DEFAULT_MODEL)
    if text is None:
        # Fix: data["text"] previously raised KeyError on malformed input.
        return jsonify({"error": "missing required field 'text'"}), 400

    model, tokenizer = get_model_and_tokenizer(model_name)
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        logits = model(**inputs).logits

    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    log_probs = []
    # Log-probs of the entire vocabulary at every position, pooled — used
    # as the reference distribution for the percentile ranks below.
    # NOTE(review): grows by vocab_size floats per token; heavy for long
    # inputs, but this is the intended reference population.
    all_log_probs_list = []
    top_k_predictions = []
    for i in range(len(input_ids) - 1):
        # Distribution over the next token, conditioned on tokens[0..i].
        probs_at_position = F.log_softmax(logits[0, i, :], dim=-1)
        all_log_probs_list.extend(probs_at_position.tolist())

        top_k_values, top_k_indices = torch.topk(probs_at_position, 5)
        top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices)
        top_k_predictions.append(
            [
                {"token": t, "log_prob": v.item()}
                for t, v in zip(top_k_tokens, top_k_values)
            ]
        )
        # Log-prob the model assigned to the token that actually came next.
        log_probs.append(probs_at_position[input_ids[i + 1]].item())

    percentiles = [
        percentileofscore(all_log_probs_list, lp) for lp in log_probs
    ]

    joint_log_likelihood = sum(log_probs)
    # Guard against single-token input, where no next-token scores exist.
    average_log_likelihood = (
        joint_log_likelihood / len(log_probs) if log_probs else 0
    )

    return jsonify({
        "tokens": tokens,
        "percentiles": percentiles,
        "log_probs": log_probs,
        "top_k_predictions": top_k_predictions,
        "joint_log_likelihood": joint_log_likelihood,
        "average_log_likelihood": average_log_likelihood,
    })


if __name__ == "__main__":
    # 0.0.0.0:7860 — the conventional binding for Hugging Face Spaces.
    app.run(host="0.0.0.0", port=7860)