# SATEv1.5: segmentation/segment.py
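"""Segment transcribed utterances into C-units.

Helpers here map inline special tokens (e.g. <PAUSE>, <REPSTART>) to word
positions and split per-session transcription JSON into C-unit segments
using an external word-level boundary function.
"""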
import json
import os
import re
from typing import List, Dict, Any, Callable, Tuple
def map_special_tokens_to_word_positions(text: str, word_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
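    """Locate special tokens (e.g. <PAUSE>, <REPSTART>) relative to word positions.

    For each <...> token found in `text`, return a dict with the raw token
    string, its character offset in the original text ("char_start"), and
    "insert_after_word_idx": the index of the word in `word_list` that the
    token follows, or -1 if it precedes the first word. For example, in
    "<PAUSE> hello world" the <PAUSE> token gets insert_after_word_idx == -1.
    """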
special_token_map: List[Dict[str, Any]] = []
for m in re.finditer(r'<[^>]*?>', text):
special_token_map.append({
"token": m.group(),
"char_start": m.start(), # index in original text
})
if not special_token_map:
return []
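    # Map each character index in the original text to its index in the text
    # with all <...> tokens removed ("visible" characters only).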
visible_offset_map = {}
visible_idx = 0
i = 0
    while i < len(text):
        if text[i] == '<':
            j = text.find('>', i)
            if j != -1:
                # skip the whole <...> token
                i = j + 1
                continue
            # no closing '>': treat '<' as an ordinary visible character
        visible_offset_map[i] = visible_idx
        visible_idx += 1
        i += 1
clean_text = re.sub(r'<[^>]*?>', '', text)
    # locate each word in clean_text, remembering its index in word_list
    word_positions = []
    cur = 0
    for w_idx, w in enumerate(word_list):
        pos = clean_text.find(w["word"], cur)
        if pos == -1:
            # word not found (e.g. normalization mismatch); skip it
            continue
        word_positions.append({
            "word": w["word"],
            "word_idx": w_idx,
            "start": pos,
            "end": pos + len(w["word"])
        })
        cur = pos + len(w["word"])
# map each token
for sp in special_token_map:
# how many visible chars are before this token?
raw_idx = sp["char_start"]
visible_before = 0
        # find the largest raw index before this token that maps to a visible char
keys = [k for k in visible_offset_map.keys() if k < raw_idx]
if keys:
visible_before = visible_offset_map[max(keys)] + 1 # +1 because map stores idx of char at k
        # the token is inserted after the last word that ends at or before it
        insert_after = -1
        for wp in word_positions:
            if visible_before >= wp["end"]:
                insert_after = wp["word_idx"]
            else:
                break
        sp["insert_after_word_idx"] = insert_after
return special_token_map
def reorganize_transcription_c_unit(
session_id: str,
segment_func: Callable[[str], List[int]],
base_dir: str = "session_data",
device: str = "cuda"
) -> Tuple[int, int]:
"""Segment utterances into C-units with rules:
1. Boundaries inside <REPSTART>…<REPEND> or <REVSTART>…<REVEND> are ignored.
2. Trailing <PAUSE> <REVSTART> <REPSTART> moves to next C-unit prefix.
Returns (total_cunit_count, ignored_boundary_count).
"""
session_dir = os.path.join(base_dir, session_id)
input_file = os.path.join(session_dir, "transcription.json")
output_file = os.path.join(session_dir, "transcription_cunit.json")
if not os.path.exists(input_file):
raise FileNotFoundError(input_file)
with open(input_file, "r", encoding="utf-8") as f:
data = json.load(f)
    # Support both the wrapped {"segments": [...]} format and a bare list.
    if "segments" in data:
        transcription_data = data["segments"]
    else:
        transcription_data = data
cunit_data: List[Dict[str, Any]] = []
ignored_boundary_count = 0
for utt in transcription_data:
original_text = utt["text"]
words_meta = utt.get("words", [])
clean_text = re.sub(r'<[^>]*?>', '', original_text).strip()
if not clean_text:
continue
        # build the word list, dropping standalone punctuation tokens
        if words_meta:
            word_data = [w for w in words_meta if w["word"] not in {"?", ",", ".", "!"}]
            word_texts = [w["word"] for w in word_data]
        else:
            # no word-level timestamps: fall back to utterance-level timing
            word_texts = re.sub(r'[?.,!]', '', clean_text).split()
            word_data = [{"word": w, "start": utt["start"], "end": utt["end"]} for w in word_texts]
if not word_texts:
continue
# token positions & special ranges
special_token_map = map_special_tokens_to_word_positions(original_text, word_data)
rep_ranges, rev_ranges = _build_special_ranges(special_token_map)
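        # A word index is "inside" a repetition/revision span if it lies between
        # the matching START/END tokens (inclusive).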
def inside_special(idx: int) -> bool:
return any(s <= idx <= e for s, e in rep_ranges) or any(s <= idx <= e for s, e in rev_ranges)
# segmentation labels
labels = segment_func(' '.join(word_texts))
if len(labels) != len(word_texts):
raise ValueError(
f"Segmentation length mismatch: {len(word_texts)} words vs {len(labels)} labels"
)
current_words: List[str] = []
current_meta: List[Dict[str, Any]] = []
        cunit_start_idx = 0  # global word index of the first word in the current C-unit
cunit_start_time = word_data[0]["start"]
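        # tokens popped off the end of a finished C-unit, re-emitted as the
        # prefix of the next one (rule 2 in the docstring)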
carry_over_tokens: List[str] = []
for i, (word, label) in enumerate(zip(word_texts, labels)):
current_words.append(word)
current_meta.append(word_data[i])
is_last_word = i == len(word_texts) - 1
boundary_from_model = label == 1 and not inside_special(i)
if label == 1 and inside_special(i):
ignored_boundary_count += 1
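                # the words stay in the current C-unit; only the ignored count is recorded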
make_boundary = boundary_from_model or is_last_word
if not make_boundary:
continue
            # -------- assemble C-unit --------
            text_parts: List[str] = []
            # 2a. prefix: tokens carried over from the previous C-unit
if carry_over_tokens:
text_parts.extend(carry_over_tokens)
carry_over_tokens = []
for j, w in enumerate(current_words):
global_word_idx = cunit_start_idx + j
                # sentence-initial tokens (insert_after_word_idx == -1)
if global_word_idx == 0:
text_parts.extend(
[sp["token"] for sp in special_token_map if sp["insert_after_word_idx"] == -1]
)
text_parts.append(w)
# tokens that follow this word
text_parts.extend(
[sp["token"] for sp in special_token_map if sp["insert_after_word_idx"] == global_word_idx]
)
            # 2b/2c. move trailing <PAUSE>, <REPSTART>, or <REVSTART> tokens to the
            # start of the next C-unit, preserving their order (tokens carried past
            # the last C-unit of an utterance are dropped)
            while text_parts and text_parts[-1].upper() in {'<PAUSE>', '<REPSTART>', '<REVSTART>'}:
                carry_over_tokens.insert(0, text_parts.pop())
# Create text_token (with special tokens) and text (only words)
text_token = ' '.join(text_parts)
text_words_only = ' '.join(current_words)
cunit_data.append({
"start": cunit_start_time,
"end": current_meta[-1]["end"],
"speaker": "", # Initialize as empty
"text_token": text_token,
"text": text_words_only,
"words": [
{
"word": word["word"],
"start": word["start"],
"end": word["end"]
} for word in current_meta
]
})
            # reset for the next C-unit
cunit_start_idx = i + 1
current_words, current_meta = [], []
if cunit_start_idx < len(word_data):
cunit_start_time = word_data[cunit_start_idx]["start"]
# Wrap in segments structure to match original format
output_data = {
"segments": cunit_data
}
with open(output_file, "w", encoding="utf-8") as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"C-unit segmentation done → {output_file}")
return len(cunit_data), ignored_boundary_count
def _build_special_ranges(special_token_map: List[Dict[str, Any]]):
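    """Return (rep_ranges, rev_ranges): inclusive word-index ranges covered by
    <REPSTART>...<REPEND> and <REVSTART>...<REVEND> token pairs."""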
rep_ranges, rev_ranges = [], []
rep_start, rev_start = None, None
for sp in special_token_map:
tok = sp["token"].upper()
idx = sp["insert_after_word_idx"]
if tok == '<REPSTART>':
rep_start = idx + 1
elif tok == '<REPEND>' and rep_start is not None:
rep_ranges.append((rep_start, idx))
rep_start = None
elif tok == '<REVSTART>':
rev_start = idx + 1
elif tok == '<REVEND>' and rev_start is not None:
rev_ranges.append((rev_start, idx))
rev_start = None
return rep_ranges, rev_ranges
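

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline: "demo_session" is a
    # hypothetical session id and `naive_segment` is only a stand-in for a real
    # model-backed segment_func (it puts every utterance into a single C-unit).
    def naive_segment(text: str) -> List[int]:
        words = text.split()
        return ([0] * (len(words) - 1) + [1]) if words else []

    # Token-to-word mapping needs no session files, so it can be exercised directly.
    demo = map_special_tokens_to_word_positions(
        "<PAUSE> hello <REPSTART> the the <REPEND> world",
        [{"word": w} for w in ["hello", "the", "the", "world"]],
    )
    print(demo)

    # Full segmentation expects session_data/demo_session/transcription.json to exist.
    try:
        counts = reorganize_transcription_c_unit("demo_session", naive_segment)
        print("C-units, ignored boundaries:", counts)
    except FileNotFoundError as exc:
        print("No demo session found:", exc)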