from typing import Union
import pysrt
from mllm_tools.litellm import LiteLLMWrapper
from mllm_tools.gemini import GeminiWrapper
from mllm_tools.utils import _prepare_text_inputs
from eval_suite.prompts_raw import _fix_transcript, _text_eval_new
from eval_suite.utils import extract_json, convert_score_fields

def parse_srt_to_text(srt_path) -> str:
    """
    Parse an SRT subtitle file into plain text.

    Args:
        srt_path: Path to the SRT subtitle file.

    Returns:
        str: The subtitle text with duplicates removed and ellipses replaced.
    """
    subs = pysrt.open(srt_path)
    full_text = []
    for sub in subs:
        sub.text = sub.text.replace("...", ".")
        for line in sub.text.splitlines():
            # .srt can contain repeated lines
            if full_text and full_text[-1] == line:
                continue
            full_text.append(line)
    return "\n".join(full_text)

def fix_transcript(text_eval_model: Union[LiteLLMWrapper, GeminiWrapper], transcript: str) -> str:
    """
    Fix and clean up a transcript using an LLM model.

    Args:
        text_eval_model: The LLM model wrapper to use for fixing the transcript.
        transcript: The input transcript text to fix.

    Returns:
        str: The fixed and cleaned transcript text.
    """
    print("Fixing transcript...")
    prompt = _fix_transcript.format(transcript=transcript)
    response = text_eval_model(_prepare_text_inputs(prompt))
    # The model is expected to wrap the cleaned transcript in <SCRIPT>...</SCRIPT> tags;
    # keep only the text between the first pair of tags.
    fixed_script = response.split("<SCRIPT>", maxsplit=1)[1].split("</SCRIPT>")[0]
    return fixed_script

def evaluate_text(text_eval_model: LiteLLMWrapper, transcript: str, retry_limit: int) -> dict:
    """
    Evaluate transcript text using an LLM model with retry logic.

    Args:
        text_eval_model: The LLM model wrapper to use for evaluation.
        transcript: The transcript text to evaluate.
        retry_limit: Maximum number of retry attempts on failure.

    Returns:
        dict: The evaluation results as a JSON object.

    Raises:
        ValueError: If all retry attempts fail.
    """
    # prompt = _text_eval.format(transcript=transcript)
    prompt = _text_eval_new.format(transcript=transcript)
    for attempt in range(retry_limit):
        try:
            evaluation = text_eval_model(_prepare_text_inputs(prompt))
            evaluation_json = extract_json(evaluation)
            evaluation_json = convert_score_fields(evaluation_json)
            return evaluation_json
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e.__class__.__name__}: {e}")
            if attempt + 1 == retry_limit:
                raise ValueError("Reached maximum retry limit. Evaluation failed.") from None
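

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module's public behavior): it
# shows how the three helpers above could be chained to score a subtitle
# track. The model identifier, subtitle path, and retry count below are
# assumptions for illustration only; LiteLLMWrapper's constructor arguments
# depend on the surrounding project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = LiteLLMWrapper("gpt-4o")            # hypothetical model identifier
    transcript = parse_srt_to_text("video.srt")  # hypothetical subtitle file
    transcript = fix_transcript(model, transcript)
    results = evaluate_text(model, transcript, retry_limit=3)
    print(results)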