from dataclasses import dataclass, field
from typing import Set

from graphgen.models.evaluate.base_evaluator import BaseEvaluator
from graphgen.models.text.text_pair import TextPair
from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop

nltk_helper = NLTKHelper()

@dataclass
class MTLDEvaluator(BaseEvaluator):
    """
    Metric for measuring the lexical diversity of a text.
    """

    stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english")))
    stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese")))

    async def evaluate_single(self, pair: TextPair) -> float:
        # Run the CPU-bound MTLD computation in an executor so the event loop is not blocked
        loop = create_event_loop()
        return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer)
    def _calculate_mtld_score(self, text: str, threshold=0.72) -> float:
        """
        Compute MTLD as the mean of a forward and a backward pass.

        The minimum value is 1.0; higher values indicate greater lexical diversity.
        """
        if not text or not text.strip():
            return 0.0

        lang = detect_main_language(text)
        tokens = nltk_helper.word_tokenize(text, lang)

        # Drop stopwords and non-alphanumeric tokens before measuring diversity
        stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en
        filtered_tokens = [word for word in tokens if word not in stopwords]
        filtered_tokens = [word for word in filtered_tokens if word.isalnum()]

        if not filtered_tokens:
            return 0.0
        # Forward MTLD
        forward_factors = self._compute_factors(filtered_tokens, threshold)
        # Backward MTLD
        backward_factors = self._compute_factors(filtered_tokens[::-1], threshold)
        # Average the two directions
        return (forward_factors + backward_factors) / 2

    def _compute_factors(self, tokens: list, threshold: float) -> float:
        factors = 0
        current_segment = []
        unique_words = set()

        for token in tokens:
            current_segment.append(token)
            unique_words.add(token)
            # Type-token ratio of the current segment
            ttr = len(unique_words) / len(current_segment)
            if ttr <= threshold:
                # The segment's TTR has dropped to the threshold: count one full
                # factor and start a new segment
                factors += 1
                current_segment = []
                unique_words = set()

        # Handle the final, incomplete segment
        if current_segment:
            ttr = len(unique_words) / len(current_segment)
            if ttr <= threshold:
                factors += 1
            else:
                # Credit a partial factor for the remainder, proportional to how
                # far its TTR has fallen from 1.0 toward the threshold
                factors += (1 - (ttr - threshold) / (1 - threshold))

        return len(tokens) / factors if factors > 0 else len(tokens)
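

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of driving the evaluator. It assumes that TextPair accepts
# `question` and `answer` keyword arguments and that evaluate_single can be
# awaited via asyncio.run; the actual TextPair constructor and event-loop
# handling in graphgen may differ.
if __name__ == "__main__":
    import asyncio

    evaluator = MTLDEvaluator()
    sample = TextPair(
        question="What does MTLD measure?",
        answer=(
            "MTLD measures lexical diversity by counting how many segments of "
            "text keep their type-token ratio above a fixed threshold."
        ),
    )
    # evaluate_single offloads the computation to an executor internally
    score = asyncio.run(evaluator.evaluate_single(sample))
    print(f"MTLD score: {score:.2f}")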