import re
import time
import tempfile
import torch
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from bert_score import score


# Model loading (cached so Streamlit reruns reuse the same objects).
@st.cache_resource
def load_models():
    """Load the fine-tuned T5 translator and the GPT-2 perplexity scorer.

    Returns:
        (finetuned_model, finetuned_tokenizer, perplexity_model,
        perplexity_tokenizer), with both models already moved to CUDA when
        available, otherwise CPU.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Fine-tuned zh->ru translation model.
    finetuned_model_path = "finetuned_model_v2/best_model"
    finetuned_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    finetuned_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path).to(device)
    # GPT-2 used only to score fluency (perplexity) of the Russian output.
    perplexity_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    perplexity_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return finetuned_model, finetuned_tokenizer, perplexity_model, perplexity_tokenizer


# Initialise session state on first run.
if 'processed' not in st.session_state:
    st.session_state.processed = False
if 'translated_code' not in st.session_state:
    st.session_state.translated_code = []

# Glossary of domain terms that must be translated consistently;
# translate_text substitutes these before calling the model (longest key first).
CUSTOM_TERMS = {
    "写入 CSV": "Запись в CSV",
    "CSV 表头": "Заголовок таблицы CSV",
}
prefix = 'translate to ru: '


def calculate_perplexity(text):
    """Return the GPT-2 perplexity of *text* (lower = more fluent)."""
    model = st.session_state.perplexity_model
    # BUG FIX: inputs were previously forced onto the CPU with .to('cpu'),
    # which raises a device-mismatch error whenever the models were loaded
    # onto CUDA. Use the model's own device instead.
    tokens = st.session_state.perplexity_tokenizer.encode(
        text, return_tensors='pt'
    ).to(model.device)
    with torch.no_grad():
        loss = model(tokens, labels=tokens).loss
    return torch.exp(loss).item()


def evaluate_translation(original, translated, scores):
    """Score one translation and append (BERTScore-F1, perplexity) to *scores*.

    NOTE(review): this compares the Russian output against the Chinese source
    with a multilingual model (xlm-roberta-large); the `lang` argument is
    ignored by bert_score when model_type is given explicitly.
    """
    P, R, F1 = score([translated], [original], model_type="xlm-roberta-large", lang="ru")
    ppl = calculate_perplexity(translated)
    scores.append((F1.item(), ppl))


def translate_text(text, term_dict=None):
    """Translate Chinese *text* to Russian with the fine-tuned T5 model.

    Windows-style file paths (e.g. C:\\foo\\bar) are masked with ||PATH_i||
    placeholders before translation and restored afterwards so the model
    cannot mangle them. Glossary terms from *term_dict* are substituted
    first, longest keys first, so longer terms win over their substrings.
    """
    # Protect file paths from the model.
    preserved_paths = re.findall(r'[a-zA-Z]:\\[^ \u4e00-\u9fff]+', text)
    for i, path in enumerate(preserved_paths):
        text = text.replace(path, f"||PATH_{i}||")
    # Apply the fixed glossary before machine translation.
    if term_dict:
        sorted_terms = sorted(term_dict.keys(), key=lambda x: len(x), reverse=True)
        pattern = re.compile('|'.join(map(re.escape, sorted_terms)))
        text = pattern.sub(lambda x: term_dict[x.group()], text)
    src_text = prefix + text
    model = st.session_state.finetuned_model
    input_ids = st.session_state.finetuned_tokenizer(
        src_text, return_tensors="pt", max_length=512, truncation=True
    )
    # BUG FIX: inputs were forced to CPU while the model may live on CUDA;
    # move them to the model's device so generation works on either.
    generated_tokens = model.generate(**input_ids.to(model.device))
    result = st.session_state.finetuned_tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )
    translated = result[0]
    # Restore the masked file paths.
    for i, path in enumerate(preserved_paths):
        translated = translated.replace(f"||PATH_{i}||", path)
    # Post-process spacing: separate a Cyrillic word fused with ".py", and
    # re-insert the space after punctuation that follows a Cyrillic letter.
    translated = re.sub(
        r'(\b[а-яА-ЯёЁ]+)(\.py\b)',
        lambda m: f"{m.group(1)} {m.group(2)}",
        translated,
    )
    translated = re.sub(r'(?<=[а-яА-ЯёЁ])([.,!?])(?=\S)', r' \1', translated)
    return translated


# Page layout.
st.set_page_config(layout="wide", page_icon="📝", page_title="Python C2R Code Comment Translator")

# Title section.
st.title("Python Chinese to Russian Code Comment Translator")
st.subheader("Upload a Python file with Chinese comments", divider='rainbow')

# File upload.
uploaded_file = st.file_uploader("Choose .py file", type=['py'], label_visibility='collapsed')

# Translation is triggered explicitly by this button.
start_translation = st.button("开始翻译 / Start Translation")

# Load models once per session and stash them in session_state.
if 'models_loaded' not in st.session_state:
    with st.spinner('Loading models...'):
        (finetuned_model, finetuned_tokenizer,
         perplexity_model, perplexity_tokenizer) = load_models()
        st.session_state.update({
            'finetuned_model': finetuned_model,
            'finetuned_tokenizer': finetuned_tokenizer,
            'perplexity_model': perplexity_model,
            'perplexity_tokenizer': perplexity_tokenizer,
            'models_loaded': True,
        })

# Process the uploaded file only when the button is pressed.
if uploaded_file and start_translation:
    st.session_state.processed = False      # reset processing state
    st.session_state.translated_code = []   # clear previous results
    # BUG FIX: remember the filename now — uploaded_file becomes None on a
    # later rerun if the uploader is cleared, which crashed the download button.
    st.session_state.source_name = uploaded_file.name
    with st.spinner('Processing file...'):
        # UploadedFile yields bytes; utf-8-sig also strips a leading BOM.
        code_lines = [
            line.decode('utf-8-sig') if isinstance(line, bytes) else line
            for line in uploaded_file.readlines()
        ]

        # Prefix each line with its 1-based number (no colon) for display.
        numbered_original = "\n".join(
            f"{i+1} {line.rstrip()}" for i, line in enumerate(code_lines)
        )

        # Two-column layout: original on the left, live translation on the right.
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Original Python Code")
            original_content = st.session_state.original_content = numbered_original
            st.code(original_content, language='python')

        with col2:
            st.subheader("Real-time Translation")
            translated_box = st.empty()
            progress_bar = st.progress(0)
            status_text = st.empty()

        # Per-run counters and score accumulator.
        translated_lines = []
        detected_count = 0
        translated_count = 0
        scores = []
        total_lines = len(code_lines)

        # Comment shapes we recognise; each must contain at least one CJK char.
        pure_comment_pattern = re.compile(r'^(\s*)#.*?([\u4e00-\u9fff]+.*)')
        inline_comment_pattern = re.compile(r'(\s+#)\s*([^#]*[\u4e00-\u9fff]+[^#]*)')
        multi_comment_pattern = re.compile(r'^(\s*)(["\']{3})(.*?)\2', re.DOTALL)

        # Translate line by line, updating the right-hand pane as we go.
        for idx, line in enumerate(code_lines):
            current_line = line.rstrip('\n')

            progress = (idx + 1) / total_lines
            progress_bar.progress(progress)
            status_text.markdown(
                f"**Processing line {idx+1}/{total_lines}** | Content: `{current_line[:50]}...`"
            )

            processed = False

            # Case 1: whole-line `# ...` comment containing Chinese.
            if pure_comment_pattern.search(line):
                detected_count += 1
                if match := pure_comment_pattern.match(line):
                    indent, comment = match.groups()
                    translated = translate_text(comment.strip(), CUSTOM_TERMS)
                    evaluate_translation(comment, translated, scores)
                    translated_lines.append(f"{indent}# {translated}\n")
                    translated_count += 1
                    processed = True

            # Case 2: trailing comment after code on the same line.
            if not processed and inline_comment_pattern.search(line):
                detected_count += 1
                if match := inline_comment_pattern.search(line):
                    code_part = line[:match.start()]
                    symbol, comment = match.groups()
                    translated = translate_text(comment.strip(), CUSTOM_TERMS)
                    evaluate_translation(comment, translated, scores)
                    translated_lines.append(f"{code_part}{symbol} {translated}\n")
                    translated_count += 1
                    processed = True

            # Case 3: triple-quoted string opened and closed on this line.
            if not processed and (multi_match := multi_comment_pattern.match(line)):
                detected_count += 1
                if re.search(r'[\u4e00-\u9fff]', multi_match.group(3)):
                    translated = translate_text(multi_match.group(3), CUSTOM_TERMS)
                    evaluate_translation(multi_match.group(3), translated, scores)
                    translated_lines.append(
                        f"{multi_match.group(1)}{multi_match.group(2)}{translated}{multi_match.group(2)}\n"
                    )
                    translated_count += 1
                    processed = True

            # Everything else passes through unchanged.
            if not processed:
                translated_lines.append(line)

            # Refresh the numbered live view (no colon after the number).
            numbered_translation = "\n".join(
                f"{i+1} {line.rstrip()}" for i, line in enumerate(translated_lines)
            )
            translated_box.code(numbered_translation, language='python')
            time.sleep(0.1)

        # Persist results for the statistics/download section below.
        st.session_state.translated_code = translated_lines
        st.session_state.detected_count = detected_count
        st.session_state.translated_count = translated_count
        st.session_state.scores = scores
        st.session_state.processed = True

        # Clear the progress widgets.
        progress_bar.empty()
        status_text.empty()

# Statistics section, shown once a file has been processed.
if st.session_state.processed:
    st.divider()

    # Right-aligned statistics layout.
    with st.container():
        col_right = st.columns([1, 3])[1]
        with col_right:
            # First row of metrics.
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Detected Comments", st.session_state.detected_count)
            with col2:
                st.metric("Translated Comments", st.session_state.translated_count)

            # Second row of metrics.
            col3, col4 = st.columns(2)
            with col3:
                if st.session_state.scores:
                    avg_bert = sum(f1 for f1, _ in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average BERTScore", f"{avg_bert:.4f}",
                              help="Higher is better (0-1)")
            with col4:
                if st.session_state.scores:
                    avg_ppl = sum(ppl for _, ppl in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average Perplexity", f"{avg_ppl:.4f}",
                              help="Lower is better (Typical range: 1~100+, lower means better translation)")

    # Download button, centred below the metrics.
    cols = st.columns([1, 2, 1])
    with cols[1]:
        # BUG FIX: the original wrote a NamedTemporaryFile(delete=False) on
        # every rerun (leaked files, never deleted) and dereferenced
        # uploaded_file.name, which is None after the uploader is cleared.
        # st.download_button accepts bytes directly; use the stored filename.
        source_name = st.session_state.get('source_name', 'code.py')
        st.download_button(
            label="⬇️ Download Translated File",
            data="".join(st.session_state.translated_code).encode('utf-8'),
            file_name=f"translated_{source_name}",
            mime='text/x-python',
            use_container_width=False,
        )