minemaster01 commited on
Commit
b0d6a30
·
verified ·
1 Parent(s): 132a2dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -34
app.py CHANGED
@@ -4,86 +4,139 @@ import json
4
  import uuid
5
  import torch
6
  import datetime
7
- import pandas as pd
8
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
- from huggingface_hub import HfApi, create_repo, upload_file
10
- from datasets import Dataset
11
 
12
- # Configuration
13
- MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
14
  HF_DATASET_REPO = "M2ai/mgtd-logs"
15
  HF_TOKEN = os.getenv("Mgtd")
16
  DATASET_CREATED = False
17
 
18
- # Load model and tokenizer
19
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
21
-
22
- # Make directories
23
- os.makedirs("logs", exist_ok=True)
24
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def setup_hf_dataset():
26
  global DATASET_CREATED
27
  if not DATASET_CREATED and HF_TOKEN:
28
  try:
29
- api = HfApi()
30
  create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
31
  DATASET_CREATED = True
32
- print(f"Dataset {HF_DATASET_REPO} is ready")
33
  except Exception as e:
34
  print(f"Error setting up dataset: {e}")
35
- elif not HF_TOKEN:
36
- print("Warning: HF_TOKEN not set. Logs will be saved locally only.")
37
 
 
38
  def infer_and_log(text_input):
39
- inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
40
- with torch.no_grad():
41
- outputs = model(**inputs)
42
- logits = outputs.logits.tolist()
43
- predicted = torch.argmax(outputs.logits, dim=-1).item()
44
- label = model.config.id2label[predicted]
45
-
46
  timestamp = datetime.datetime.now().isoformat()
47
  submission_id = str(uuid.uuid4())
 
48
  log_data = {
49
  "id": submission_id,
50
  "timestamp": timestamp,
51
  "input": text_input,
52
- "logits": logits
53
  }
54
 
 
55
  log_file = f"logs/{timestamp.replace(':', '_')}.json"
56
  with open(log_file, "w") as f:
57
  json.dump(log_data, f, indent=2)
58
 
59
  if HF_TOKEN and DATASET_CREATED:
60
  try:
61
- api = HfApi()
62
- api.upload_file(
63
  path_or_fileobj=log_file,
64
  path_in_repo=f"logs/{os.path.basename(log_file)}",
65
  repo_id=HF_DATASET_REPO,
66
  repo_type="dataset",
67
  token=HF_TOKEN
68
  )
69
- print(f"Uploaded log {submission_id} to {HF_DATASET_REPO}")
70
  except Exception as e:
71
- print(f"Error uploading to HF dataset: {e}")
72
 
73
- return label
74
 
75
  def clear_fields():
76
  return "", ""
77
 
78
- # Setup the dataset on startup
79
  setup_hf_dataset()
80
 
 
81
  with gr.Blocks() as app:
82
- gr.Markdown("## AI Text Detector")
83
 
84
  with gr.Row():
85
- input_box = gr.Textbox(label="Input Text", lines=10, interactive=True)
86
- output_box = gr.Textbox(label="Output", lines=2, interactive=False)
87
 
88
  with gr.Row():
89
  submit_btn = gr.Button("Submit")
 
4
  import uuid
5
  import torch
6
  import datetime
7
+ import torch.nn as nn
8
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
9
+ from huggingface_hub import HfApi, create_repo, hf_hub_download
10
+ from torchcrf import CRF
11
 
12
+ # Constants
 
13
  HF_DATASET_REPO = "M2ai/mgtd-logs"
14
  HF_TOKEN = os.getenv("Mgtd")
15
  DATASET_CREATED = False
16
 
17
+ # Model identifiers
18
+ code = "ENG"
19
+ pntr = 2
20
+ model_name_or_path = "microsoft/mdeberta-v3-base"
21
+ hf_token = os.environ.get("HF_WRITE") # Set this before running
22
+
23
+ # Download model checkpoint
24
+ file_path = hf_hub_download(
25
+ repo_id="1024m/MGTD-Long-New",
26
+ filename=f"{code}/mdeberta-epoch-{pntr}.pt",
27
+ token=hf_token,
28
+ local_dir="./checkpoints"
29
+ )
30
+
31
+ # Define CRF model
32
+ class AutoModelCRF(nn.Module):
33
+ def __init__(self, model_name_or_path, dropout=0.075):
34
+ super().__init__()
35
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
36
+ self.num_labels = 2
37
+ self.encoder = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True, config=self.config)
38
+ self.dropout = nn.Dropout(dropout)
39
+ self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
40
+ self.crf = CRF(self.num_labels, batch_first=True)
41
+
42
+ def forward(self, input_ids, attention_mask):
43
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
44
+ seq_output = self.dropout(outputs[0])
45
+ emissions = self.linear(seq_output)
46
+ tags = self.crf.decode(emissions, attention_mask.byte())
47
+ return tags, emissions
48
+
49
+ # Load model
50
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
52
+ model = AutoModelCRF(model_name_or_path)
53
+ checkpoint = torch.load(file_path, map_location="cpu")
54
+ model.load_state_dict(checkpoint.get("model_state_dict", checkpoint), strict=False)
55
+ model = model.to(device)
56
+ model.eval()
57
+
58
+ # Inference function
59
+ def get_word_classifications(text):
60
+ text = " ".join(text.split(" ")[:2048])
61
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
62
+ inputs = {k: v.to(device) for k, v in inputs.items()}
63
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
64
+ with torch.no_grad():
65
+ tags, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
66
+ word_tags = []
67
+ current_word = ""
68
+ current_tag = None
69
+ for token, tag in zip(tokens, tags[0]):
70
+ if token in ["<s>", "</s>"]:
71
+ continue
72
+ if token.startswith("▁"):
73
+ if current_word:
74
+ word_tags.append(str(current_tag))
75
+ current_word = token[1:] if token != "▁" else ""
76
+ current_tag = tag
77
+ else:
78
+ current_word += token
79
+ if current_word:
80
+ word_tags.append(str(current_tag))
81
+ return word_tags
82
+
83
+ # HF logging setup
84
  def setup_hf_dataset():
85
  global DATASET_CREATED
86
  if not DATASET_CREATED and HF_TOKEN:
87
  try:
 
88
  create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
89
  DATASET_CREATED = True
90
+ print(f"Dataset {HF_DATASET_REPO} is ready.")
91
  except Exception as e:
92
  print(f"Error setting up dataset: {e}")
 
 
93
 
94
+ # Main inference + logging function
95
  def infer_and_log(text_input):
96
+ word_tags = get_word_classifications(text_input)
 
 
 
 
 
 
97
  timestamp = datetime.datetime.now().isoformat()
98
  submission_id = str(uuid.uuid4())
99
+
100
  log_data = {
101
  "id": submission_id,
102
  "timestamp": timestamp,
103
  "input": text_input,
104
+ "output_tags": word_tags
105
  }
106
 
107
+ os.makedirs("logs", exist_ok=True)
108
  log_file = f"logs/{timestamp.replace(':', '_')}.json"
109
  with open(log_file, "w") as f:
110
  json.dump(log_data, f, indent=2)
111
 
112
  if HF_TOKEN and DATASET_CREATED:
113
  try:
114
+ HfApi().upload_file(
 
115
  path_or_fileobj=log_file,
116
  path_in_repo=f"logs/{os.path.basename(log_file)}",
117
  repo_id=HF_DATASET_REPO,
118
  repo_type="dataset",
119
  token=HF_TOKEN
120
  )
121
+ print(f"Uploaded log {submission_id}")
122
  except Exception as e:
123
+ print(f"Error uploading log: {e}")
124
 
125
+ return " ".join(word_tags)
126
 
127
  def clear_fields():
128
  return "", ""
129
 
130
+ # Prepare dataset once
131
  setup_hf_dataset()
132
 
133
+ # Gradio UI
134
  with gr.Blocks() as app:
135
+ gr.Markdown("## MDeBERTa+CRF Word Tagger")
136
 
137
  with gr.Row():
138
+ input_box = gr.Textbox(label="Input Text", lines=10)
139
+ output_box = gr.Textbox(label="Output Tags", lines=2, interactive=False)
140
 
141
  with gr.Row():
142
  submit_btn = gr.Button("Submit")