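"""Utterance segmentation with the talkbank/CHATUtterance-en token-classification model.

The model restores punctuation and utterance-initial capitalization on a
lower-cased, punctuation-free transcript; NLTK's Punkt tokenizer then splits
the restored text into sentences, which are converted into per-word boundary
labels (1 = last word of an utterance, 0 = any other word).
"""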
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

import nltk
from nltk.tokenize import sent_tokenize

# Punkt sentence-tokenizer data is required by sent_tokenize() below.
nltk.download('punkt_tab')
nltk.download('punkt')


def segment_batchalign(text: str) -> list[int]: |
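    """Return per-word utterance-boundary labels for `text`.

    The text is lower-cased and stripped of '.' and ',' before being split
    into words; the returned list has one entry per resulting word, where 1
    marks the last word of a predicted utterance and 0 marks every other word.
    """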
|
    # Run on GPU when available, otherwise fall back to CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: the tokenizer and model are reloaded on every call; cache them at
    # module level if this function is called repeatedly.
    model_path = "talkbank/CHATUtterance-en"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    model.to(DEVICE)
    model.eval()

    # The model expects lower-cased words with '.' and ',' removed; it predicts
    # punctuation and capitalization back as per-word labels.
    text = text.lower().replace(".", "").replace(",", "")
    words = text.split()

    # Tokenize the pre-split words so sub-token predictions can be mapped back
    # to them. Note: inputs longer than the model's maximum sequence length are
    # not truncated here and will raise an error.
    tokd = tokenizer([words], return_tensors="pt", is_split_into_words=True).to(DEVICE)
    with torch.no_grad():
        logits = model(**tokd).logits
    predictions = torch.argmax(logits, dim=2).squeeze(0).cpu().tolist()

    # Map sub-token predictions back onto whole words, keeping only the
    # prediction for the first sub-token of each word.
    word_ids = tokd.word_ids(0)
    result_words = []
    seen = set()

    for i, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx in seen:
            continue
        seen.add(word_idx)

        pred = predictions[i]
        word = words[word_idx]

        # Label interpretation used here: 1 = utterance-initial word
        # (capitalize), 2 = append '.', 3 = '?', 4 = '!', 5 = ','.
        if pred == 1:
            word = word[0].upper() + word[1:]
        elif pred == 2:
            word += "."
        elif pred == 3:
            word += "?"
        elif pred == 4:
            word += "!"
        elif pred == 5:
            word += ","

        result_words.append(word)

    # Reassemble the punctuated text and split it into sentences. result_words
    # are plain words, so a simple whitespace join is sufficient.
    sentence = " ".join(result_words)
    try:
        sentences = sent_tokenize(sentence)
    except LookupError:
        # Fallback in case the Punkt data was not available at import time.
        nltk.download('punkt')
        nltk.download('punkt_tab')
        sentences = sent_tokenize(sentence)

    # Emit one label per word: 1 for the last word of each sentence, 0 otherwise.
    boundaries = []
    for sent in sentences:
        sent_word_count = len(sent.split())
        boundaries += [0] * (sent_word_count - 1) + [1]

    # Merge one-word sentences into the preceding segment by clearing
    # back-to-back boundary markers.
    for i in range(1, len(boundaries)):
        if boundaries[i - 1] == 1 and boundaries[i] == 1:
            boundaries[i - 1] = 0

    return boundaries


if __name__ == "__main__":
    # Quick demo on an unsegmented child-language transcript.
    test_text = "sir can I have balloon and the sir say yes you can and he said five dollars that xxx and and he is like where is that they his tether is right there and and he said and the bunny said oopsies I do not have money and the doc and the and the and the bunny runned for the doctor an and he says doctor doctor I want a balloon here is the money and you can have the balloons both of them now they are happy the end"
    print(f"Input text: {test_text}")
    print(f"Words: {test_text.split()}")

    labels = segment_batchalign(test_text)
    print(f"Segment labels: {labels}")

    # Rebuild the segments from the per-word boundary labels.
    words = test_text.split()
    segments = []
    current_segment = []

    for word, label in zip(words, labels):
        current_segment.append(word)
        if label == 1:
            segments.append(" ".join(current_segment))
            current_segment = []

    # Flush any trailing words not closed by a boundary label.
    if current_segment:
        segments.append(" ".join(current_segment))

    print("\nSegmented text:")
    for i, segment in enumerate(segments, 1):
        print(f"Segment {i}: {segment}")