# Flask app: per-token log-probability analysis with causal language models.
# (Removed Hugging Face Spaces page residue that was scraped into the file.)
from flask import Flask, render_template, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from scipy.stats import percentileofscore

app = Flask(__name__)

# Smallest checkpoint kept as the default; larger ones are listed (commented
# out) in index() and can be re-enabled as resources allow.
DEFAULT_MODEL = "gpt2"

# Per-process caches so each checkpoint is downloaded and loaded only once.
model_cache = {}
tokenizer_cache = {}
def get_model_and_tokenizer(model_name):
    """Return a cached (model, tokenizer) pair for *model_name*, loading on first use.

    The caches are populated only after BOTH loads succeed.  The original
    code cached the model before loading the tokenizer, so a tokenizer
    failure left the model cached alone and every subsequent call raised
    KeyError instead of retrying the load.
    """
    if model_name not in model_cache:
        # phi-1_5 ships custom modeling code that must be explicitly trusted.
        trust_code = model_name == "microsoft/phi-1_5"
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        # Assign only once both objects exist, keeping the caches in sync.
        model_cache[model_name] = model
        tokenizer_cache[model_name] = tokenizer
    return model_cache[model_name], tokenizer_cache[model_name]
@app.route("/")
def index():
    """Serve the main page with the list of selectable models.

    NOTE(review): the route decorator was missing in the extracted source;
    without it Flask never dispatches to this view, so it is restored here —
    confirm "/" matches the original deployment.
    """
    return render_template(
        "index.html",
        models=[
            DEFAULT_MODEL,
            # Larger checkpoints, disabled to keep memory/download cost low;
            # uncomment to expose them in the UI.
            # "gpt2-medium",
            # "gpt2-large",
            # "gpt2-xl",
            # "EleutherAI/pythia-1.4b",
            # "facebook/opt-1.3b",
            # "bigscience/bloom-1b7",
            # "microsoft/phi-1_5",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        ],
    )
@app.route("/analyze", methods=["POST"])
def analyze():
    """Score the submitted text under the chosen model and return JSON stats.

    Expects a JSON body ``{"text": str, "model": str}``.  For each token
    transition (position i predicting token i+1) the response contains the
    observed next-token log-probability, its percentile against the full
    vocabulary distribution pooled over all positions, and the model's
    top-5 predictions, plus the joint and average log-likelihood.

    NOTE(review): the route decorator was missing in the extracted source;
    restored here so the endpoint is actually reachable — confirm the path
    matches the front-end's fetch URL.
    """
    data = request.get_json()
    text = data["text"]
    model_name = data["model"]
    model, tokenizer = get_model_and_tokenizer(model_name)
    model.eval()
    # No gradients needed for inference; saves memory on larger models.
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    log_probs = []
    all_log_probs_list = []
    top_k_predictions = []
    # The last position has no observed next token, hence len - 1.
    for i in range(len(input_ids) - 1):
        probs_at_position = F.log_softmax(logits[0, i, :], dim=-1)
        # Pool every vocabulary log-prob across positions; percentiles are
        # computed against this pooled distribution after the loop.
        all_log_probs_list.extend(probs_at_position.tolist())
        top_k_values, top_k_indices = torch.topk(probs_at_position, 5)
        top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices)
        top_k_predictions.append(
            [
                {"token": t, "log_prob": v.item()}
                for t, v in zip(top_k_tokens, top_k_values)
            ]
        )
        log_prob = probs_at_position[input_ids[i + 1]].item()
        log_probs.append(log_prob)
    percentiles = [percentileofscore(all_log_probs_list, lp) for lp in log_probs]
    joint_log_likelihood = sum(log_probs)
    # Guard against empty/one-token input, where there are no transitions.
    average_log_likelihood = (
        joint_log_likelihood / len(log_probs) if log_probs else 0
    )
    return jsonify({
        "tokens": tokens,
        "percentiles": percentiles,
        "log_probs": log_probs,
        "top_k_predictions": top_k_predictions,
        "joint_log_likelihood": joint_log_likelihood,
        "average_log_likelihood": average_log_likelihood,
    })
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (the conventional Spaces port).
    app.run(host="0.0.0.0", port=7860)