minemaster01 committed on
Commit
e01a601
·
verified ·
1 Parent(s): 43e8b90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -79
app.py CHANGED
@@ -8,27 +8,28 @@ import torch.nn as nn
8
  from transformers import AutoTokenizer, AutoModel, AutoConfig
9
  from huggingface_hub import HfApi, create_repo, hf_hub_download
10
  from torchcrf import CRF
11
-
12
  # Constants
13
  HF_DATASET_REPO = "M2ai/mgtd-logs"
14
  HF_TOKEN = os.getenv("Mgtd")
15
  DATASET_CREATED = False
16
-
17
  # Model identifiers
18
  code = "ENG"
19
  pntr = 2
20
  model_name_or_path = "microsoft/mdeberta-v3-base"
21
- hf_token = os.environ.get("Mgtd") # Set this before running
22
-
23
  # Download model checkpoint
24
- file_path = hf_hub_download(
25
- repo_id="1024m/MGTD-Long-New",
26
- filename=f"{code}/mdeberta-epoch-{pntr}.pt",
27
- token=hf_token,
28
- local_dir="./checkpoints"
29
- )
 
 
 
 
 
30
 
31
- # Define CRF model
32
  class AutoModelCRF(nn.Module):
33
  def __init__(self, model_name_or_path, dropout=0.075):
34
  super().__init__()
@@ -38,7 +39,6 @@ class AutoModelCRF(nn.Module):
38
  self.dropout = nn.Dropout(dropout)
39
  self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
40
  self.crf = CRF(self.num_labels, batch_first=True)
41
-
42
  def forward(self, input_ids, attention_mask):
43
  outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
44
  seq_output = self.dropout(outputs[0])
@@ -46,7 +46,6 @@ class AutoModelCRF(nn.Module):
46
  tags = self.crf.decode(emissions, attention_mask.byte())
47
  return tags, emissions
48
 
49
- # Load model
50
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
52
  model = AutoModelCRF(model_name_or_path)
@@ -55,8 +54,6 @@ model.load_state_dict(checkpoint.get("model_state_dict", checkpoint), strict=Fal
55
  model = model.to(device)
56
  model.eval()
57
 
58
- # Inference function
59
-
60
  def get_word_probabilities(text):
61
  text = " ".join(text.split(" ")[:2048])
62
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -65,12 +62,10 @@ def get_word_probabilities(text):
65
  with torch.no_grad():
66
  tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
67
  probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
68
-
69
  word_probs = []
70
  word_colors = []
71
  current_word = ""
72
  current_probs = []
73
-
74
  for token, prob in zip(tokens, probs):
75
  if token in ["<s>", "</s>"]:
76
  continue
@@ -78,109 +73,54 @@ def get_word_probabilities(text):
78
  if current_word and current_probs:
79
  current_prob = sum(current_probs) / len(current_probs)
80
  word_probs.append(current_prob)
81
-
82
- # Determine color based on probability
83
- color = (
84
- "green" if current_prob < 0.25 else
85
- "yellow" if current_prob < 0.5 else
86
- "orange" if current_prob < 0.75 else
87
- "red"
88
- )
89
  word_colors.append(color)
90
-
91
  current_word = token[1:] if token != "▁" else ""
92
  current_probs = [prob]
93
  else:
94
  current_word += token
95
  current_probs.append(prob)
96
-
97
  if current_word and current_probs:
98
  current_prob = sum(current_probs) / len(current_probs)
99
  word_probs.append(current_prob)
100
-
101
- # Determine color for the last word
102
- color = (
103
- "green" if current_prob < 0.25 else
104
- "yellow" if current_prob < 0.5 else
105
- "orange" if current_prob < 0.75 else
106
- "red"
107
- )
108
  word_colors.append(color)
109
-
110
  word_probs = [float(p) for p in word_probs]
111
  return word_probs, word_colors
112
-
113
- # HF logging setup
114
- def setup_hf_dataset():
115
- global DATASET_CREATED
116
- if not DATASET_CREATED and HF_TOKEN:
117
- try:
118
- create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
119
- DATASET_CREATED = True
120
- print(f"Dataset {HF_DATASET_REPO} is ready.")
121
- except Exception as e:
122
- print(f"Error setting up dataset: {e}")
123
-
124
- # Main inference + logging function
125
  def infer_and_log(text_input):
126
  word_probs, word_colors = get_word_probabilities(text_input)
127
  timestamp = datetime.datetime.now().isoformat()
128
  submission_id = str(uuid.uuid4())
129
-
130
- log_data = {
131
- "id": submission_id,
132
- "timestamp": timestamp,
133
- "input": text_input,
134
- "output_probs": word_probs
135
- }
136
-
137
  os.makedirs("logs", exist_ok=True)
138
  log_file = f"logs/{timestamp.replace(':', '_')}.json"
139
  with open(log_file, "w") as f:
140
  json.dump(log_data, f, indent=2)
141
-
142
  if HF_TOKEN and DATASET_CREATED:
143
  try:
144
- HfApi().upload_file(
145
- path_or_fileobj=log_file,
146
- path_in_repo=f"logs/{os.path.basename(log_file)}",
147
- repo_id=HF_DATASET_REPO,
148
- repo_type="dataset",
149
- token=HF_TOKEN
150
- )
151
  print(f"Uploaded log {submission_id}")
152
  except Exception as e:
153
  print(f"Error uploading log: {e}")
154
-
155
  tokens = text_input.split()
156
- formatted_output = " ".join(
157
- f'<span style="color:{color}">{token}</span>' for token, color in zip(tokens, word_colors)
158
- )
159
-
160
  return formatted_output
161
 
162
-
163
  def clear_fields():
164
  return "", ""
165
-
166
- # Prepare dataset once
167
  setup_hf_dataset()
168
 
169
- # Gradio UI
170
  with gr.Blocks() as app:
171
  gr.Markdown("Machine Generated Text Detector")
172
-
173
  with gr.Row():
174
  input_box = gr.Textbox(label="Input Text", lines=10)
175
  output_box = gr.HTML(label="Output Text")
176
-
177
  with gr.Row():
178
  submit_btn = gr.Button("Submit")
179
  clear_btn = gr.Button("Clear")
180
-
181
-
182
  submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
183
  clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])
184
-
185
  if __name__ == "__main__":
186
  app.launch()
 
8
  from transformers import AutoTokenizer, AutoModel, AutoConfig
9
  from huggingface_hub import HfApi, create_repo, hf_hub_download
10
  from torchcrf import CRF
 
11
  # Constants
12
  HF_DATASET_REPO = "M2ai/mgtd-logs"
13
  HF_TOKEN = os.getenv("Mgtd")
14
  DATASET_CREATED = False
 
15
  # Model identifiers
16
  code = "ENG"
17
  pntr = 2
18
  model_name_or_path = "microsoft/mdeberta-v3-base"
19
+ hf_token = os.environ.get("Mgtd")
 
20
  # Download model checkpoint
21
+ file_path = hf_hub_download(repo_id="1024m/MGTD-Long-New",filename=f"{code}/mdeberta-epoch-{pntr}.pt",token=hf_token,local_dir="./checkpoints")
22
+
23
+ def setup_hf_dataset():
24
+ global DATASET_CREATED
25
+ if not DATASET_CREATED and HF_TOKEN:
26
+ try:
27
+ create_repo(HF_DATASET_REPO, repo_type="dataset", token=HF_TOKEN, exist_ok=True)
28
+ DATASET_CREATED = True
29
+ print(f"Dataset {HF_DATASET_REPO} is ready.")
30
+ except Exception as e:
31
+ print(f"Error setting up dataset: {e}")
32
 
 
33
  class AutoModelCRF(nn.Module):
34
  def __init__(self, model_name_or_path, dropout=0.075):
35
  super().__init__()
 
39
  self.dropout = nn.Dropout(dropout)
40
  self.linear = nn.Linear(self.config.hidden_size, self.num_labels)
41
  self.crf = CRF(self.num_labels, batch_first=True)
 
42
  def forward(self, input_ids, attention_mask):
43
  outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
44
  seq_output = self.dropout(outputs[0])
 
46
  tags = self.crf.decode(emissions, attention_mask.byte())
47
  return tags, emissions
48
 
 
49
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
  tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
51
  model = AutoModelCRF(model_name_or_path)
 
54
  model = model.to(device)
55
  model.eval()
56
 
 
 
57
  def get_word_probabilities(text):
58
  text = " ".join(text.split(" ")[:2048])
59
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
62
  with torch.no_grad():
63
  tags, emission = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
64
  probs = torch.softmax(emission, dim=-1)[0, :, 1].cpu().numpy()
 
65
  word_probs = []
66
  word_colors = []
67
  current_word = ""
68
  current_probs = []
 
69
  for token, prob in zip(tokens, probs):
70
  if token in ["<s>", "</s>"]:
71
  continue
 
73
  if current_word and current_probs:
74
  current_prob = sum(current_probs) / len(current_probs)
75
  word_probs.append(current_prob)
76
+ color = ("green" if current_prob < 0.25 else "yellow" if current_prob < 0.5 else "orange" if current_prob < 0.75 else "red")
 
 
 
 
 
 
 
77
  word_colors.append(color)
 
78
  current_word = token[1:] if token != "▁" else ""
79
  current_probs = [prob]
80
  else:
81
  current_word += token
82
  current_probs.append(prob)
 
83
  if current_word and current_probs:
84
  current_prob = sum(current_probs) / len(current_probs)
85
  word_probs.append(current_prob)
86
+ color = ("green" if current_prob < 0.25 else "yellow" if current_prob < 0.5 else "orange" if current_prob < 0.75 else "red")
 
 
 
 
 
 
 
87
  word_colors.append(color)
 
88
  word_probs = [float(p) for p in word_probs]
89
  return word_probs, word_colors
90
+
 
 
 
 
 
 
 
 
 
 
 
 
91
  def infer_and_log(text_input):
92
  word_probs, word_colors = get_word_probabilities(text_input)
93
  timestamp = datetime.datetime.now().isoformat()
94
  submission_id = str(uuid.uuid4())
95
+ log_data = {"id": submission_id,"timestamp": timestamp,"input": text_input,"output_probs": word_probs}
 
 
 
 
 
 
 
96
  os.makedirs("logs", exist_ok=True)
97
  log_file = f"logs/{timestamp.replace(':', '_')}.json"
98
  with open(log_file, "w") as f:
99
  json.dump(log_data, f, indent=2)
 
100
  if HF_TOKEN and DATASET_CREATED:
101
  try:
102
+ HfApi().upload_file(path_or_fileobj=log_file,path_in_repo=f"logs/{os.path.basename(log_file)}",repo_id=HF_DATASET_REPO,repo_type="dataset",token=HF_TOKEN)
 
 
 
 
 
 
103
  print(f"Uploaded log {submission_id}")
104
  except Exception as e:
105
  print(f"Error uploading log: {e}")
 
106
  tokens = text_input.split()
107
+ formatted_output = " ".join(f'<span style="color:{color}">{token}</span>' for token, color in zip(tokens, word_colors))
 
 
 
108
  return formatted_output
109
 
 
110
  def clear_fields():
111
  return "", ""
 
 
112
  setup_hf_dataset()
113
 
 
114
  with gr.Blocks() as app:
115
  gr.Markdown("Machine Generated Text Detector")
 
116
  with gr.Row():
117
  input_box = gr.Textbox(label="Input Text", lines=10)
118
  output_box = gr.HTML(label="Output Text")
 
119
  with gr.Row():
120
  submit_btn = gr.Button("Submit")
121
  clear_btn = gr.Button("Clear")
 
 
122
  submit_btn.click(fn=infer_and_log, inputs=input_box, outputs=output_box)
123
  clear_btn.click(fn=clear_fields, outputs=[input_box, output_box])
124
+
125
  if __name__ == "__main__":
126
  app.launch()