import json
import os
import re
from typing import List, Dict, Any, Callable, Tuple


def map_special_tokens_to_word_positions(text: str, word_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Locate special tokens (e.g. <PAUSE>, <REPSTART>) in ``text`` and map each
    one to the visible word it should be re-inserted after.

    Returns one dict per token with keys ``token``, ``char_start`` and
    ``insert_after_word_idx`` (-1 means the token precedes the first word).
    """
    special_token_map: List[Dict[str, Any]] = []
    for m in re.finditer(r'<[^>]*?>', text):
        special_token_map.append({
            "token": m.group(),
            "char_start": m.start(),
        })

    if not special_token_map:
        return []

    # Map each raw character index of the tag-free ("visible") text to its
    # position in the cleaned string, so token offsets can be translated.
    visible_offset_map = {}
    visible_idx = 0
    i = 0
    while i < len(text):
        if text[i] == '<':
            close = text.find('>', i)
            if close != -1:
                i = close + 1
                continue
            # An unmatched '<' has no closing '>', so treat it as visible text.
        visible_offset_map[i] = visible_idx
        visible_idx += 1
        i += 1

    clean_text = re.sub(r'<[^>]*?>', '', text)

    # Locate each word of word_list in the cleaned text, scanning left to right.
    word_positions = []
    cur = 0
    for w in word_list:
        pos = clean_text.find(w["word"], cur)
        if pos != -1:
            word_positions.append({
                "word": w["word"],
                "start": pos,
                "end": pos + len(w["word"])
            })
            cur = pos + len(w["word"])

    # For every special token, find the last word that ends before the token
    # appears; the token will be re-inserted right after that word.
    for sp in special_token_map:
        raw_idx = sp["char_start"]
        visible_before = 0

        keys = [k for k in visible_offset_map.keys() if k < raw_idx]
        if keys:
            visible_before = visible_offset_map[max(keys)] + 1

        insert_after = -1
        for i, wp in enumerate(word_positions):
            if visible_before >= wp["end"]:
                insert_after = i
            else:
                break
        sp["insert_after_word_idx"] = insert_after

    return special_token_map
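

# Illustration (hypothetical input, not taken from any session file): for
#   text = "<REPSTART> um um <REPEND> I went home <PAUSE>"
# and a word_list whose "word" values are ["um", "um", "I", "went", "home"],
# the mapping yields
#   {"token": "<REPSTART>", "char_start": 0,  "insert_after_word_idx": -1}
#   {"token": "<REPEND>",   "char_start": 17, "insert_after_word_idx": 1}
#   {"token": "<PAUSE>",    "char_start": 38, "insert_after_word_idx": 4}
# i.e. <REPEND> is re-inserted after the second word and <PAUSE> after the last.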


def reorganize_transcription_c_unit(
    session_id: str,
    segment_func: Callable[[str], List[int]],
    base_dir: str = "session_data",
    device: str = "cuda"
) -> Tuple[int, int]:
    """Segment transcribed utterances into C-units.

    ``segment_func`` receives the space-joined visible words of one utterance and
    must return one 0/1 label per word, where 1 marks a C-unit boundary after
    that word. Two post-processing rules are applied:

    1. Boundaries that fall inside a <REPSTART>…<REPEND> or <REVSTART>…<REVEND>
       span are ignored, so repetitions and revisions stay within one C-unit.
    2. Trailing <PAUSE>, <REVSTART> and <REPSTART> tokens at the end of a C-unit
       are moved to the start of the next C-unit.

    ``device`` is currently unused in this function.

    Returns (total_cunit_count, ignored_boundary_count).
    """

    # Resolve per-session input/output paths.
    session_dir = os.path.join(base_dir, session_id)
    input_file = os.path.join(session_dir, "transcription.json")
    output_file = os.path.join(session_dir, "transcription_cunit.json")

    if not os.path.exists(input_file):
        raise FileNotFoundError(input_file)

    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Accept either a {"segments": [...]} wrapper or a bare list of utterances.
    if "segments" in data:
        transcription_data = data["segments"]
    else:
        transcription_data = data

    cunit_data: List[Dict[str, Any]] = []
    ignored_boundary_count = 0

    for utt in transcription_data:
        original_text = utt["text"]
        words_meta = utt.get("words", [])

        # Skip utterances that contain nothing but special tokens and whitespace.
        clean_text = re.sub(r'<[^>]*?>', '', original_text).strip()
        if not clean_text:
            continue

        # Prefer word-level timestamps when available; otherwise fall back to
        # whitespace-split words that all share the utterance-level timing.
        if words_meta:
            word_data = [w for w in words_meta if w["word"] not in {"?", ",", ".", "!"}]
            word_texts = [w["word"] for w in word_data]
        else:
            word_texts = re.sub(r'[\?\.,!]', '', clean_text).split()
            word_data = [{"word": w, "start": utt["start"], "end": utt["end"]} for w in word_texts]

        if not word_texts:
            continue

        # Map each special token back to the word it follows, and derive the
        # word-index ranges covered by repetition/revision spans.
        special_token_map = map_special_tokens_to_word_positions(original_text, word_data)
        rep_ranges, rev_ranges = _build_special_ranges(special_token_map)

        def inside_special(idx: int) -> bool:
            return any(s <= idx <= e for s, e in rep_ranges) or any(s <= idx <= e for s, e in rev_ranges)

        # segment_func returns one 0/1 boundary label per visible word.
        labels = segment_func(' '.join(word_texts))
        if len(labels) != len(word_texts):
            raise ValueError(
                f"Segmentation length mismatch: {len(word_texts)} words vs {len(labels)} labels"
            )

        current_words: List[str] = []
        current_meta: List[Dict[str, Any]] = []
        cunit_start_idx = 0
        cunit_start_time = word_data[0]["start"]
        carry_over_tokens: List[str] = []  # special tokens deferred to the next C-unit

        for i, (word, label) in enumerate(zip(word_texts, labels)):
            current_words.append(word)
            current_meta.append(word_data[i])

            is_last_word = i == len(word_texts) - 1
            # Rule 1: ignore model boundaries inside repetition/revision spans.
            boundary_from_model = label == 1 and not inside_special(i)
            if label == 1 and inside_special(i):
                ignored_boundary_count += 1

            # The end of the utterance always closes the current C-unit.
            make_boundary = boundary_from_model or is_last_word
            if not make_boundary:
                continue

            # Rebuild the C-unit text, re-inserting special tokens where they occurred.
            text_parts: List[str] = []

            # Rule 2 (carry-in): tokens moved off the previous C-unit become the
            # prefix of this one.
            if carry_over_tokens:
                text_parts.extend(carry_over_tokens)
                carry_over_tokens = []

            for j, w in enumerate(current_words):
                global_word_idx = cunit_start_idx + j

                # Tokens that precede the very first word of the utterance.
                if global_word_idx == 0:
                    text_parts.extend(
                        [sp["token"] for sp in special_token_map if sp["insert_after_word_idx"] == -1]
                    )

                text_parts.append(w)

                # Tokens that follow this word in the original utterance.
                text_parts.extend(
                    [sp["token"] for sp in special_token_map if sp["insert_after_word_idx"] == global_word_idx]
                )

            # Rule 2: trailing <PAUSE>/<REPSTART>/<REVSTART> tokens belong to the
            # next C-unit; move them there while preserving their order.
            while text_parts and text_parts[-1].upper() in {'<PAUSE>', '<REPSTART>', '<REVSTART>'}:
                carry_over_tokens.insert(0, text_parts.pop())

            text_token = ' '.join(text_parts)           # words plus special tokens
            text_words_only = ' '.join(current_words)   # words only

            # Emit the finished C-unit.
            cunit_data.append({
                "start": cunit_start_time,
                "end": current_meta[-1]["end"],
                "speaker": "",
                "text_token": text_token,
                "text": text_words_only,
                "words": [
                    {
                        "word": word["word"],
                        "start": word["start"],
                        "end": word["end"]
                    } for word in current_meta
                ]
            })

            # Start collecting the next C-unit.
            cunit_start_idx = i + 1
            current_words, current_meta = [], []
            if cunit_start_idx < len(word_data):
                cunit_start_time = word_data[cunit_start_idx]["start"]

    # Persist the segmented output next to the input transcription.
    output_data = {
        "segments": cunit_data
    }

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)

    print(f"C-unit segmentation done → {output_file}")
    return len(cunit_data), ignored_boundary_count
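

# Shape of each entry written to transcription_cunit.json (values illustrative):
#   {
#     "start": 3.12, "end": 5.40, "speaker": "",
#     "text_token": "<REVSTART> I went <REVEND> I walked home",
#     "text": "I went I walked home",
#     "words": [{"word": "I", "start": 3.12, "end": 3.20}, ...]
#   }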


def _build_special_ranges(
    special_token_map: List[Dict[str, Any]]
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Turn paired <REPSTART>/<REPEND> and <REVSTART>/<REVEND> tokens into
    inclusive word-index ranges.

    Example: <REPSTART> with insert_after_word_idx == -1 followed by <REPEND>
    with insert_after_word_idx == 1 yields the repetition range (0, 1), i.e.
    words 0 and 1 lie inside the repetition.
    """
    rep_ranges, rev_ranges = [], []
    rep_start, rev_start = None, None
    for sp in special_token_map:
        tok = sp["token"].upper()
        idx = sp["insert_after_word_idx"]
        if tok == '<REPSTART>':
            rep_start = idx + 1
        elif tok == '<REPEND>' and rep_start is not None:
            rep_ranges.append((rep_start, idx))
            rep_start = None
        elif tok == '<REVSTART>':
            rev_start = idx + 1
        elif tok == '<REVEND>' and rev_start is not None:
            rev_ranges.append((rev_start, idx))
            rev_start = None
    return rep_ranges, rev_ranges
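

# Minimal usage sketch (illustrative only): it fabricates a throwaway session on
# disk and uses a dummy segment_func in place of a real segmentation model, so
# the session id, paths and labels below are assumptions, not project defaults.
if __name__ == "__main__":
    import tempfile

    demo_base = tempfile.mkdtemp()
    demo_session = "demo_session"
    os.makedirs(os.path.join(demo_base, demo_session), exist_ok=True)

    demo_transcription = {
        "segments": [
            {
                "start": 0.0,
                "end": 2.5,
                "text": "<REPSTART> I I <REPEND> went home <PAUSE> and then I slept",
                "words": [],
            }
        ]
    }
    with open(os.path.join(demo_base, demo_session, "transcription.json"), "w", encoding="utf-8") as f:
        json.dump(demo_transcription, f)

    def dummy_segment_func(utterance: str) -> List[int]:
        # Pretend the model places a boundary after "home"; the utterance end is
        # closed by reorganize_transcription_c_unit itself.
        return [1 if w == "home" else 0 for w in utterance.split()]

    n_cunits, n_ignored = reorganize_transcription_c_unit(
        demo_session, dummy_segment_func, base_dir=demo_base
    )
    print(f"{n_cunits} C-units, {n_ignored} ignored boundaries")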