"""Gradio demo: word-level scores from a token-classification + CRF model.

At import time the script downloads a fine-tuned checkpoint from the Hugging
Face Hub, loads it into an encoder + linear + CRF model, and builds a small
Gradio UI.  Each submission is scored word by word, written to a local JSON
log and, when a token is configured, uploaded to a Hub dataset repository.
"""
import datetime
import json
import os
import uuid

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import HfApi, create_repo, hf_hub_download
from torchcrf import CRF
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Constants
HF_DATASET_REPO = "M2ai/mgtd-logs"
HF_TOKEN = os.getenv("Mgtd")
DATASET_CREATED = False

# Model identifiers
code = "ENG"
pntr = 2
model_name_or_path = "microsoft/mdeberta-v3-base"
# Same secret as HF_TOKEN (env var "Mgtd"); set it before running.
hf_token = HF_TOKEN

# Download model checkpoint
file_path = hf_hub_download(
    repo_id="1024m/MGTD-Long-New",
    filename=f"{code}/mdeberta-epoch-{pntr}.pt",
    token=hf_token,
    local_dir="./checkpoints",
)


# Define CRF model
class AutoModelCRF(nn.Module):
    """Pretrained encoder followed by dropout, a linear layer and a CRF.

    Args:
        model_name_or_path: Identifier passed to ``AutoConfig`` / ``AutoModel``.
        dropout: Dropout probability applied to the encoder output.
    """

    def __init__(self, model_name_or_path, dropout=0.075):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.num_labels = 2
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=True, config=self.config
        )
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
        self.crf = CRF(self.num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask):
        """Return ``(tags, emissions)`` for a batch.

        ``tags`` is the CRF-decoded label sequence per example and
        ``emissions`` is the raw output of the linear layer.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        seq_output = self.dropout(outputs[0])
        emissions = self.linear(seq_output)
        # A boolean mask avoids the uint8-condition deprecation in
        # torch.where, which the CRF's Viterbi decode relies on.
        tags = self.crf.decode(emissions, attention_mask.bool())
        return tags, emissions


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelCRF(model_name_or_path)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted repository.
checkpoint = torch.load(file_path, map_location="cpu")
_load_result = model.load_state_dict(
    checkpoint.get("model_state_dict", checkpoint), strict=False
)
# strict=False tolerates key mismatches; report them so a partially loaded
# (e.g. randomly initialised head) model does not go unnoticed.
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "Checkpoint key mismatch - "
        f"missing: {_load_result.missing_keys}, "
        f"unexpected: {_load_result.unexpected_keys}"
    )
model = model.to(device)
model.eval()

# Sequence-boundary / padding tokens that must not contribute to any word.
# Taken from the tokenizer itself rather than hard-coded, so the check works
# whatever markers the loaded tokenizer uses.  The empty string is kept from
# the original skip list.
_BOUNDARY_TOKEN_KEYS = ("bos_token", "eos_token", "cls_token", "sep_token", "pad_token")
_SKIP_TOKENS = frozenset(
    {""}
    | {
        str(tokenizer.special_tokens_map[key])
        for key in _BOUNDARY_TOKEN_KEYS
        if key in tokenizer.special_tokens_map
    }
)


def _aggregate_word_probs(tokens, probs, skip_tokens=frozenset()):
    """Average per-token probabilities into one value per word.

    A token starting with "▁" opens a new word; any other token continues the
    current one.  Tokens listed in ``skip_tokens`` are ignored entirely.

    Returns:
        A list of Python floats, one per word, in order of appearance.
    """
    word_probs = []
    current_word = ""
    current_probs = []
    for token, prob in zip(tokens, probs):
        if token in skip_tokens:
            continue
        if token.startswith("▁"):
            # Flush the previous word before starting a new one.
            if current_word and current_probs:
                word_probs.append(sum(current_probs) / len(current_probs))
            current_word = token[1:] if token != "▁" else ""
            current_probs = [prob]
        else:
            current_word += token
            current_probs.append(prob)
    if current_word and current_probs:
        word_probs.append(sum(current_probs) / len(current_probs))
    return [float(p) for p in word_probs]


# Inference function
def get_word_probabilities(text):
    """Score ``text`` and return one probability per word.

    The input is cut to its first 2048 space-separated pieces, tokenized and
    run through the model.  The softmax of the emission scores at label index
    1 is averaged over the tokens of each word (presumably the
    "machine-generated" label - confirm against the training setup).  The
    CRF-decoded tags are not used here.

    Returns:
        A list of floats, one per word.  Any failure is printed and an empty
        list is returned instead of raising.
    """
    try:
        text = " ".join(text.split(" ")[:2048])
    except Exception as e:
        print("Error during text preprocessing:", e)
        return []

    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
    except Exception as e:
        print("Error during tokenization or moving inputs to device:", e)
        return []

    try:
        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    except Exception as e:
        print("Error during token conversion:", e)
        return []

    try:
        with torch.no_grad():
            tags, emission = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
    except Exception as e:
        print("Error during model inference:", e)
        return []

    try:
        probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
    except Exception as e:
        print("Error during softmax or extracting class probabilities:", e)
        return []

    try:
        return _aggregate_word_probs(tokens, probs, _SKIP_TOKENS)
    except Exception as e:
        print("Error during word aggregation:", e)
        return []


# def get_word_classifications(text):
#     text = " ".join(text.split(" ")[:2048])
#     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
#     inputs = {k: v.to(device) for k, v in inputs.items()}
#     tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
#     with torch.no_grad():
#         tags, emissions = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
#     word_tags = []
#     color_output = []
#     current_word = ""
#     current_prob = 0.0
#     for token, prob in zip(tokens, tags[0]):
#         if token in ["", ""]:
#             continue
#         if token.startswith("▁"):
#             if current_word:
#                 word_tags.append(round(current_prob, 3))
#                 color = (
#                     "green" if current_prob < 0.25 else
#                     "yellow" if current_prob < 0.5 else
#                     "orange" if current_prob < 0.75 else
#                     "red"
#                 )
#                 color_output.append(f'{current_word}')
#             current_word = token[1:] if token != "▁" else ""
#             current_prob = prob
#         else:
#             current_word += token
#             current_prob = max(current_prob, prob)
#     if current_word:
#         word_tags.append(round(current_prob, 3))
#         color = (
#             "green" if current_prob < 0.25 else
#             "yellow" if current_prob < 0.5 else
#             "orange" if current_prob < 0.75 else
#             "red"
#         )
#         color_output.append(f'{current_word}')
#     output = " ".join(color_output)
#     return output, word_tags


# HF logging setup
def setup_hf_dataset():
    """Create the log dataset repo once, if a token is configured.

    Sets the module-level ``DATASET_CREATED`` flag on success; failures are
    printed and leave the flag unchanged, which disables uploads.
    """
    global DATASET_CREATED
    if not DATASET_CREATED and HF_TOKEN:
        try:
            create_repo(
                HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True
            )
            DATASET_CREATED = True
            print(f"Dataset {HF_DATASET_REPO} is ready.")
        except Exception as e:
            print(f"Error setting up dataset: {e}")


# Main inference + logging function
def infer_and_log(text_input):
    """Score ``text_input``, log the submission and return the scores as JSON.

    The submission is written to ``logs/<timestamp>.json`` and, when the
    dataset repo is available, uploaded to it.  Upload errors are printed and
    do not affect the returned value.
    """
    word_tags = get_word_probabilities(text_input)
    timestamp = datetime.datetime.now().isoformat()
    submission_id = str(uuid.uuid4())

    log_data = {
        "id": submission_id,
        "timestamp": timestamp,
        "input": text_input,
        "output_tags": word_tags,
    }

    os.makedirs("logs", exist_ok=True)
    # ":" is replaced because it is not allowed in file names on some systems.
    log_file = f"logs/{timestamp.replace(':', '_')}.json"
    with open(log_file, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)

    if HF_TOKEN and DATASET_CREATED:
        try:
            HfApi().upload_file(
                path_or_fileobj=log_file,
                path_in_repo=f"logs/{os.path.basename(log_file)}",
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
            print(f"Uploaded log {submission_id}")
        except Exception as e:
            print(f"Error uploading log: {e}")

    return json.dumps(word_tags, indent=2)


def clear_fields():
    """Return empty strings for the input and output text boxes."""
    return "", ""


# Prepare dataset once
setup_hf_dataset()

# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("Machine Generated Text Detector")
    with gr.Row():
        input_box = gr.Textbox(label="Input Text", lines=10)
        output_box = gr.Textbox(label="Output Text", lines=10, interactive=False)
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
    clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])

if __name__ == "__main__":
    app.launch()