import gradio as gr
import os
import json
import uuid
import torch
import datetime
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoConfig
from huggingface_hub import HfApi, create_repo, hf_hub_download
from torchcrf import CRF
# Constants
HF_DATASET_REPO = "M2ai/mgtd-logs"  # dataset repo that receives the inference logs
HF_TOKEN = os.getenv("Mgtd")  # access token from the "Mgtd" env var; None when unset
DATASET_CREATED = False  # switched to True by setup_hf_dataset() once the repo is ready

# Model identifiers
model_name_or_path = "microsoft/mdeberta-v3-base"
code = "ENG"  # sub-folder of the checkpoint inside the model repo
pntr = 2  # epoch number embedded in the checkpoint file name
hf_token = HF_TOKEN  # same "Mgtd" variable as above; set it before running

# Download model checkpoint
file_path = hf_hub_download(
    repo_id="1024m/MGTD-Long-New",
    filename=f"{code}/mdeberta-epoch-{pntr}.pt",
    local_dir="./checkpoints",
    token=hf_token,
)
# Define CRF model
class AutoModelCRF(nn.Module):
    """Pretrained encoder followed by a linear emission layer and a CRF.

    The encoder's first output is passed through dropout and a linear layer
    to produce per-token emission scores for ``num_labels`` (2) classes; a
    CRF layer then decodes a tag sequence from those scores.

    Args:
        model_name_or_path: Identifier or path handed to
            ``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained``.
        dropout: Dropout probability applied to the encoder output.
    """

    def __init__(self, model_name_or_path, dropout=0.075):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.num_labels = 2
        # NOTE(review): trust_remote_code=True allows the model repository to
        # execute its own Python code at load time; only point this at
        # repositories you trust.
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            config=self.config,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and decode tags with the CRF.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Mask tensor forwarded to the encoder and, as a
                boolean mask, to the CRF decoder.

        Returns:
            A ``(tags, emissions)`` tuple: the result of ``crf.decode`` and
            the raw output of the linear emission layer.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        seq_output = self.dropout(outputs[0])
        emissions = self.linear(seq_output)
        # Use a bool mask: uint8 masks/conditions are deprecated in PyTorch
        # (e.g. torch.where warns on, and may reject, a uint8 condition).
        tags = self.crf.decode(emissions, attention_mask.bool())
        return tags, emissions
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelCRF(model_name_or_path)
# NOTE(review): torch.load unpickles the downloaded file; only load
# checkpoints from a trusted repository.
checkpoint = torch.load(file_path, map_location="cpu")
# The checkpoint is either a bare state dict or a dict wrapping it under
# "model_state_dict".
state_dict = checkpoint.get("model_state_dict", checkpoint)
# strict=False tolerates key mismatches, so report them instead of silently
# serving layers that kept their random initialisation.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        "Warning: checkpoint/model key mismatch - "
        f"missing: {load_report.missing_keys}, "
        f"unexpected: {load_report.unexpected_keys}"
    )
model = model.to(device)
model.eval()
# Inference function
def _color_for_probability(probability):
    """Map a word-level class-1 probability to its display colour bucket."""
    if probability < 0.25:
        return "green"
    if probability < 0.5:
        return "yellow"
    if probability < 0.75:
        return "orange"
    return "red"


def get_word_probabilities(text):
    """Score each word of ``text`` with the module-level model.

    The text is cut to its first 2048 space-separated pieces, tokenized, and
    run through the model. The softmax of the emissions for label 1 is taken
    per token, and tokens are merged back into words (a token starting with
    "▁" opens a new word) by averaging their probabilities.

    Args:
        text: Raw input string.

    Returns:
        A ``(word_probs, word_colors)`` tuple of equal-length lists: the
        averaged probability per word as ``float`` and the colour name
        ("green", "yellow", "orange" or "red") for its probability bucket.
    """
    text = " ".join(text.split(" ")[:2048])
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    with torch.no_grad():
        tags, emission = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().tolist()

    # Sequence-boundary/padding markers added by the tokenizer are not part
    # of the user's text. They do not start with "▁", so unless skipped they
    # would be glued onto a neighbouring word (or emitted as a word of their
    # own) and shift the word/colour alignment.
    marker_tokens = {
        tokenizer.cls_token,
        tokenizer.sep_token,
        tokenizer.pad_token,
    } - {None}

    word_probs = []
    word_colors = []
    current_word = ""
    current_probs = []

    def flush_word():
        # Emit the word collected so far, if any, as (mean probability, colour).
        if current_word and current_probs:
            mean_prob = float(sum(current_probs) / len(current_probs))
            word_probs.append(mean_prob)
            word_colors.append(_color_for_probability(mean_prob))

    for token, prob in zip(tokens, probs):
        if token in marker_tokens:
            continue
        if token.startswith("▁"):
            flush_word()
            # A bare "▁" opens a word whose text arrives in later tokens.
            current_word = token[1:] if token != "▁" else ""
            current_probs = [prob]
        else:
            current_word += token
            current_probs.append(prob)
    flush_word()  # the last word has no following "▁" token to trigger it
    return word_probs, word_colors
# HF logging setup
def setup_hf_dataset():
    """Create the logging dataset repo once, if a token is configured.

    Does nothing when the repo was already prepared or when ``HF_TOKEN`` is
    empty. On success ``DATASET_CREATED`` is set to True; any exception is
    printed and otherwise ignored.
    """
    global DATASET_CREATED
    if DATASET_CREATED or not HF_TOKEN:
        return
    try:
        create_repo(
            HF_DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            exist_ok=True,
        )
        DATASET_CREATED = True
        print(f"Dataset {HF_DATASET_REPO} is ready.")
    except Exception as exc:
        print(f"Error setting up dataset: {exc}")
# Main inference + logging function
def infer_and_log(text_input):
    """Score ``text_input``, log the request, and return coloured HTML.

    The per-word probabilities are written to a JSON file under ``logs/``
    and, when a token is configured and the dataset repo is ready, uploaded
    to ``HF_DATASET_REPO``. Upload errors are printed and do not abort the
    request.

    Args:
        text_input: Raw text entered by the user.

    Returns:
        An HTML string in which each whitespace-separated word is wrapped in
        a ``<span>`` coloured by its probability bucket. Words beyond the
        number of scored words are not included (``zip`` stops at the
        shorter sequence).
    """
    # Local import so this function brings its own dependency into scope.
    import html

    word_probs, word_colors = get_word_probabilities(text_input)
    timestamp = datetime.datetime.now().isoformat()
    submission_id = str(uuid.uuid4())
    log_data = {
        "id": submission_id,
        "timestamp": timestamp,
        "input": text_input,
        "output_probs": word_probs,
    }
    os.makedirs("logs", exist_ok=True)
    # ":" is replaced because it is not a valid file-name character on all
    # platforms.
    log_file = f"logs/{timestamp.replace(':', '_')}.json"
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)
    if HF_TOKEN and DATASET_CREATED:
        try:
            HfApi().upload_file(
                path_or_fileobj=log_file,
                path_in_repo=f"logs/{os.path.basename(log_file)}",
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            print(f"Uploaded log {submission_id}")
        except Exception as e:
            print(f"Error uploading log: {e}")
    tokens = text_input.split()
    # The result is rendered by an HTML component: apply the computed colour
    # to each word and escape the user's text so it cannot inject markup.
    formatted_output = " ".join(
        f'<span style="color: {color}">{html.escape(token)}</span>'
        for token, color in zip(tokens, word_colors)
    )
    return formatted_output
def clear_fields():
    """Return a pair of empty strings, one for each UI field to reset."""
    blank = ""
    return blank, blank
# Prepare dataset once
setup_hf_dataset()

# Gradio UI
with gr.Blocks() as app:
    gr.Markdown(value="Machine Generated Text Detector")

    # Top row: text entry on the left, rendered result on the right.
    with gr.Row():
        input_box = gr.Textbox(lines=10, label="Input Text")
        output_box = gr.HTML(label="Output Text")

    # Bottom row: action buttons.
    with gr.Row():
        submit_btn = gr.Button(value="Submit")
        clear_btn = gr.Button(value="Clear")

    # Submit runs inference and logging; Clear blanks both fields.
    submit_btn.click(
        fn=infer_and_log,
        inputs=input_box,
        outputs=output_box,
    )
    clear_btn.click(
        fn=clear_fields,
        outputs=[input_box, output_box],
    )

if __name__ == "__main__":
    app.launch()