|
|
|
import json
|
|
from pathlib import Path
|
|
from flask import request, send_file, current_app, make_response
|
|
from flask_restful import Resource
|
|
from flask_jwt_extended import jwt_required, get_jwt_identity
|
|
from datetime import datetime
|
|
from io import BytesIO
|
|
import zipfile
|
|
import os
|
|
|
|
from app import db, Setting
|
|
from app.models import Customer
|
|
from app.models.translate import Translate
|
|
from app.resources.task.translate_service import TranslateEngine
|
|
from app.utils.response import APIResponse
|
|
from app.utils.check_utils import AIChecker
|
|
|
|
|
|
# Built-in translation settings (model list, default model, thread cap, prompt).
# NOTE(review): not referenced elsewhere in the visible code, and max_threads
# here (5) differs from the fallback used by TranslateSettingResource (10);
# confirm which value is authoritative.
TRANSLATE_SETTINGS = dict(
    models=["gpt-3.5-turbo", "gpt-4"],
    default_model="gpt-3.5-turbo",
    max_threads=5,
    prompt_template="请将以下内容翻译为{target_lang}",
)
|
|
|
|
|
|
class TranslateStartResource1(Resource):
    """Start a translation task for an already uploaded file.

    NOTE(review): this resource duplicates ``TranslateStartResource`` in this
    module; confirm whether both are still routed and consider removing one.
    """

    # Form fields every start request must provide.
    _REQUIRED_FIELDS = ('server', 'model', 'lang', 'uuid', 'prompt', 'threads', 'file_name')

    @jwt_required()
    def post(self):
        """Start a translation task (absolute storage path, multiple parameters).

        Validates the form, updates the current user's ``Translate`` record
        identified by ``uuid``, commits it, then calls
        ``TranslateEngine(translate.id).execute()``.

        Returns:
            ``APIResponse.success`` with ``task_id``, ``uuid`` and
            ``target_path`` (a str) on success. Otherwise an error response:
            400 for missing or malformed input, 404 when no matching record
            belongs to the caller, 500 for any other failure (the DB session
            is rolled back and the error logged).
        """
        data = request.form

        if not all(field in data for field in self._REQUIRED_FIELDS):
            return APIResponse.error("缺少必要参数", 400)

        if data['server'] == 'openai' and not all(k in data for k in ('api_url', 'api_key')):
            return APIResponse.error("OpenAI服务需要API地址和密钥", 400)

        # Parse numeric fields before touching the DB so bad input yields a 400
        # instead of falling into the generic 500 handler below.
        try:
            threads = int(data['threads'])
            comparison_id = self._optional_int(data.get('comparison_id', '0'))
            prompt_id = self._optional_int(data.get('prompt_id', '0'))
        except ValueError:
            return APIResponse.error("参数格式错误", 400)

        origin_filename = data['file_name']
        # file_name is client-controlled: keep only its last path component so
        # the target path cannot escape the storage directory ("../../x").
        safe_filename = Path(origin_filename.replace('\\', '/')).name
        if safe_filename in ('', '.', '..'):
            return APIResponse.error("文件名无效", 400)

        translate_type = data.get('type[2]', 'trans_all_only_inherit')

        try:
            # Scope the lookup to the caller, as the other per-user resources in
            # this module do, so a uuid cannot be used to alter another user's task.
            translate = Translate.query.filter_by(
                uuid=data['uuid'],
                customer_id=get_jwt_identity(),
            ).first()
            if not translate:
                return APIResponse.error("未找到对应的翻译记录", 404)

            target_abs_path = self._get_absolute_storage_path(safe_filename)

            translate.origin_filename = origin_filename
            translate.target_filepath = target_abs_path
            translate.lang = data['lang']
            translate.model = data['model']
            # backup_model is optional (not in _REQUIRED_FIELDS), so read it with
            # a default rather than indexing, which raised KeyError when absent.
            translate.backup_model = data.get('backup_model', '')
            translate.type = translate_type
            translate.prompt = data['prompt']
            translate.threads = threads
            translate.api_url = data.get('api_url', '')
            translate.api_key = data.get('api_key', '')
            translate.origin_lang = data.get('origin_lang', '')
            translate.comparison_id = comparison_id
            translate.prompt_id = prompt_id
            translate.doc2x_flag = data.get('doc2x_flag', 'N')
            translate.doc2x_secret_key = data.get('doc2x_secret_key', '')

            db.session.commit()

            TranslateEngine(translate.id).execute()

            return APIResponse.success({
                "task_id": translate.id,
                "uuid": translate.uuid,
                "target_path": target_abs_path,
            })
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"翻译任务启动失败: {str(e)}", exc_info=True)
            return APIResponse.error("任务启动失败", 500)

    @staticmethod
    def _optional_int(value):
        """Return ``int(value)``, or ``None`` when *value* is empty/falsy."""
        return int(value) if value else None

    @staticmethod
    def _get_absolute_storage_path(filename):
        """Return ``<app root parent>/storage/translate/<YYYY-MM-DD>/<filename>`` as str.

        The dated directory is created if it does not exist yet.
        """
        base_dir = Path(current_app.root_path).parent.absolute()
        date_str = datetime.now().strftime('%Y-%m-%d')
        target_dir = base_dir / "storage" / "translate" / date_str
        target_dir.mkdir(parents=True, exist_ok=True)
        return str(target_dir / filename)
|
|
|
|
|
|
|
|
class TranslateStartResource(Resource):
    """Start a translation task for an already uploaded file (identified by uuid)."""

    # Form fields every start request must provide.
    _REQUIRED_FIELDS = ('server', 'model', 'lang', 'uuid', 'prompt', 'threads', 'file_name')

    @jwt_required()
    def post(self):
        """Start a translation task (absolute storage path, multiple parameters).

        Validates the form, updates the current user's ``Translate`` record
        identified by ``uuid``, commits it, then calls
        ``TranslateEngine(translate.id).execute()``.

        Returns:
            ``APIResponse.success`` with ``task_id``, ``uuid`` and
            ``target_path`` (a str) on success. Otherwise an error response:
            400 for missing or malformed input, 404 when no matching record
            belongs to the caller, 500 for any other failure (the DB session
            is rolled back and the error logged).
        """
        data = request.form

        if not all(field in data for field in self._REQUIRED_FIELDS):
            return APIResponse.error("缺少必要参数", 400)

        if data['server'] == 'openai' and not all(k in data for k in ('api_url', 'api_key')):
            return APIResponse.error("OpenAI服务需要API地址和密钥", 400)

        # Parse numeric fields before touching the DB so bad input yields a 400
        # instead of falling into the generic 500 handler below.
        try:
            threads = int(data['threads'])
            comparison_id = self._optional_int(data.get('comparison_id', '0'))
            prompt_id = self._optional_int(data.get('prompt_id', '0'))
        except ValueError:
            return APIResponse.error("参数格式错误", 400)

        origin_filename = data['file_name']
        # file_name is client-controlled: keep only its last path component so
        # the target path cannot escape the storage directory ("../../x").
        safe_filename = Path(origin_filename.replace('\\', '/')).name
        if safe_filename in ('', '.', '..'):
            return APIResponse.error("文件名无效", 400)

        translate_type = data.get('type[2]', 'trans_all_only_inherit')

        try:
            # Scope the lookup to the caller, as the other per-user resources in
            # this module do, so a uuid cannot be used to alter another user's task.
            translate = Translate.query.filter_by(
                uuid=data['uuid'],
                customer_id=get_jwt_identity(),
            ).first()
            if not translate:
                return APIResponse.error("未找到对应的翻译记录", 404)

            target_abs_path = str(self._get_absolute_storage_path(safe_filename))

            translate.origin_filename = origin_filename
            translate.target_filepath = target_abs_path
            translate.lang = data['lang']
            translate.model = data['model']
            # backup_model is optional (not in _REQUIRED_FIELDS), so read it with
            # a default rather than indexing, which raised KeyError when absent.
            translate.backup_model = data.get('backup_model', '')
            translate.type = translate_type
            translate.prompt = data['prompt']
            translate.threads = threads
            translate.api_url = data.get('api_url', '')
            translate.api_key = data.get('api_key', '')
            translate.origin_lang = data.get('origin_lang', '')
            translate.comparison_id = comparison_id
            translate.prompt_id = prompt_id
            translate.doc2x_flag = data.get('doc2x_flag', 'N')
            translate.doc2x_secret_key = data.get('doc2x_secret_key', '')

            db.session.commit()

            TranslateEngine(translate.id).execute()

            return APIResponse.success({
                "task_id": translate.id,
                "uuid": translate.uuid,
                "target_path": target_abs_path,
            })
        except Exception as e:
            db.session.rollback()
            current_app.logger.error(f"翻译任务启动失败: {str(e)}", exc_info=True)
            return APIResponse.error("任务启动失败", 500)

    @staticmethod
    def _optional_int(value):
        """Return ``int(value)``, or ``None`` when *value* is empty/falsy."""
        return int(value) if value else None

    @staticmethod
    def _get_absolute_storage_path(filename):
        """Return ``<app root parent>/storage/translate/<YYYY-MM-DD>/<filename>`` as a Path.

        The dated directory is created if it does not exist yet.
        """
        base_dir = Path(current_app.root_path).parent.absolute()
        date_str = datetime.now().strftime('%Y-%m-%d')
        target_dir = base_dir / "storage" / "translate" / date_str
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir / filename
|
|
|
|
|
|
|
|
class TranslateListResource(Resource):
    """Paginated listing of the current user's non-deleted translation records."""

    # Valid values for the ``status`` filter, mapped to their display names.
    _STATUS_NAMES = {
        'none': '未开始',
        'process': '进行中',
        'done': '已完成',
        'failed': '失败',
    }

    @jwt_required()
    def get(self):
        """Return the current user's translation records.

        Query args:
            page: 1-based page number (default ``'1'``).
            limit: page size (default ``'100'``).
            status: optional filter; must be a key of ``_STATUS_NAMES``.

        Returns:
            ``APIResponse.success`` with ``data`` (list of record dicts),
            ``total`` and ``current_page``; a 400 error for a non-integer
            page/limit or an unknown status.
        """
        try:
            page = int(request.args.get('page', '1'))
            limit = int(request.args.get('limit', '100'))
        except ValueError:
            # Pass the status code as an argument, like every other call site in
            # this module, rather than returning a (response, 400) tuple.
            return APIResponse.error("Invalid page or limit value", 400)

        query = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            deleted_flag='N',
        )

        status_filter = request.args.get('status')
        if status_filter:
            if status_filter not in self._STATUS_NAMES:
                return APIResponse.error(f"Invalid status value: {status_filter}", 400)
            query = query.filter_by(status=status_filter)

        pagination = query.paginate(page=page, per_page=limit, error_out=False)
        data = [self._serialize(t) for t in pagination.items]

        return APIResponse.success({
            'data': data,
            'total': pagination.total,
            'current_page': pagination.page,
        })

    @classmethod
    def _serialize(cls, t):
        """Convert one ``Translate`` row into the dict shape returned by ``get``."""
        return {
            'id': t.id,
            'file_type': cls.get_file_type(t.origin_filename),
            'origin_filename': t.origin_filename,
            'status': t.status,
            'status_name': cls._STATUS_NAMES.get(t.status, '未知状态'),
            # Guard against a NULL progress column; float(None) would raise.
            'process': float(t.process) if t.process is not None else 0.0,
            'spend_time': cls._format_spend_time(t.created_at, t.end_at),
            'end_at': cls._format_datetime(t.end_at),
            'start_at': cls._format_datetime(t.start_at),
            'lang': t.lang,
            'target_filepath': t.target_filepath,
        }

    @staticmethod
    def _format_spend_time(start, end):
        """Format ``end - start`` as "<m>分<s>秒", or "--" if either is missing."""
        if not (start and end):
            return "--"
        total_seconds = (end - start).total_seconds()
        return f"{int(total_seconds // 60)}分{int(total_seconds % 60)}秒"

    @staticmethod
    def _format_datetime(value):
        """Format a datetime as 'YYYY-mm-dd HH:MM:SS', or "--" if it is missing."""
        return value.strftime('%Y-%m-%d %H:%M:%S') if value else "--"

    @staticmethod
    def get_file_type(filename):
        """Return a display category for *filename* based on its extension.

        Returns "未知" for an empty/None name and "其他" for unrecognized or
        missing extensions.
        """
        if not filename:
            return "未知"
        # rpartition yields an empty separator when there is no dot, so a bare
        # name such as "pdf" is not mistaken for a ".pdf" file.
        _, dot, ext = filename.rpartition('.')
        ext = ext.lower() if dot else ''
        if ext in {'docx', 'doc'}:
            return "Word"
        elif ext in {'xlsx', 'xls'}:
            return "Excel"
        elif ext == 'pptx':
            return "PPT"
        elif ext == 'pdf':
            return "PDF"
        elif ext in {'txt', 'md'}:
            return "文本"
        else:
            return "其他"
|
|
|
|
|
|
class TranslateSettingResource(Resource):
    """Expose the translation configuration stored in the ``Setting`` table."""

    @jwt_required()
    def get(self):
        """Return the translation config, or a 500 error if loading fails."""
        try:
            settings = self._load_settings_from_db()
            return APIResponse.success(settings)
        except Exception as e:
            current_app.logger.error(f"获取配置失败: {str(e)}", exc_info=True)
            return APIResponse.error(f"获取配置失败: {str(e)}", 500)

    @staticmethod
    def _load_settings_from_db():
        """Build the translation config from non-deleted settings.

        Reads settings in the ``api_setting`` and ``other_setting`` groups.
        Values flagged ``serialized`` are JSON-decoded first. Keys not present
        in the DB fall back to built-in defaults.

        NOTE(review): the result includes ``api_key`` and is returned to any
        authenticated caller of ``get``; confirm exposing it is intended.
        """
        settings = Setting.query.filter(
            Setting.group.in_(['api_setting', 'other_setting']),
            Setting.deleted_flag == 'N',
        ).all()

        # Aliases whose value is copied as-is, mapped to their config key.
        plain_keys = {
            'default_model': 'default_model',
            'default_backup': 'default_backup',
            'api_url': 'api_url',
            'api_key': 'api_key',
            'prompt': 'prompt_template',
        }

        config = {}
        for setting in settings:
            value = json.loads(setting.value) if setting.serialized else setting.value
            alias = setting.alias

            if alias == 'models':
                if isinstance(value, str):
                    # Comma-separated list; drop surrounding whitespace and blanks.
                    value = [m.strip() for m in value.split(',') if m.strip()]
                # An empty value leaves the key unset so the default applies.
                if value:
                    config['models'] = value
            elif alias == 'threads':
                # A serialized value may decode to an int, which has no
                # .isdigit(); normalise to text before validating.
                text = str(value).strip() if value is not None else ''
                config['max_threads'] = int(text) if text.isdigit() else 10
            elif alias in plain_keys:
                config[plain_keys[alias]] = value

        config.setdefault('models', ['gpt-3.5-turbo', 'gpt-4'])
        config.setdefault('default_model', 'gpt-3.5-turbo')
        config.setdefault('default_backup', 'gpt-3.5-turbo')
        config.setdefault('api_url', '')
        config.setdefault('api_key', '')
        config.setdefault('prompt_template', '请将以下内容翻译为{target_lang}')
        config.setdefault('max_threads', 10)

        return config
|
|
|
|
|
|
class TranslateProcessResource(Resource):
    """Report the progress of one of the current user's translation tasks."""

    @jwt_required()
    def post(self):
        """Return status, progress and, when done, the result path for ``uuid``.

        Form fields:
            uuid: the task's uuid.

        Responds 404 (via ``first_or_404``) when no record with this uuid
        belongs to the caller.

        NOTE(review): ``download_url`` is the server-side ``target_filepath``,
        not a URL; confirm the client expects a filesystem path here.
        """
        uuid = request.form.get('uuid')
        translate = Translate.query.filter_by(
            uuid=uuid,
            customer_id=get_jwt_identity(),
        ).first_or_404()

        return APIResponse.success({
            'status': translate.status,
            # Guard against a NULL progress column; float(None) would raise.
            'progress': float(translate.process) if translate.process is not None else 0.0,
            'download_url': translate.target_filepath if translate.status == 'done' else None,
        })
|
|
|
|
|
|
class TranslateDeleteResource(Resource):
    """Soft-delete a single translation record owned by the current user."""

    @jwt_required()
    def delete(self, id):
        """Mark record ``id`` as deleted by setting ``deleted_flag`` to 'Y'.

        Responds 404 (via ``first_or_404``) when the record does not exist or
        belongs to another customer.
        """
        owner_id = get_jwt_identity()
        record = Translate.query.filter_by(id=id, customer_id=owner_id).first_or_404()

        record.deleted_flag = 'Y'
        db.session.commit()

        return APIResponse.success(message='记录已标记为删除')
|
|
|
|
|
|
|
|
class TranslateDownloadResource(Resource):
    """Download a single translation result file by record id."""

    def get(self, id):
        """Send the translated file of record ``id`` as an attachment.

        NOTE(review): unlike the other per-record resources in this module,
        this endpoint has no ``@jwt_required`` and no ``customer_id`` filter,
        so anyone who knows or guesses an id can download that file. Confirm
        this is intentional (e.g. plain browser download links) and consider a
        signed/expiring download token instead.

        Responds 404 (via ``first_or_404``) for unknown or soft-deleted
        records, and with a 404 error when the file is missing on disk.
        """
        translate = Translate.query.filter_by(
            id=id,
            # Soft-deleted records are hidden by every listing here; don't serve them.
            deleted_flag='N',
        ).first_or_404()

        file_path = translate.target_filepath
        if not file_path or not os.path.exists(file_path):
            return APIResponse.error('文件不存在', 404)

        response = make_response(send_file(
            file_path,
            as_attachment=True,
            download_name=os.path.basename(file_path),
        ))

        # Disable client and proxy caching of the downloaded file.
        response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0'
        response.headers['Pragma'] = 'no-cache'
        response.headers['Expires'] = '0'

        return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranslateDownloadAllResource(Resource):
    """Bundle all of the current user's translation result files into one zip."""

    @jwt_required()
    def get(self):
        """Return a zip of every existing result file of the current user.

        Records that are soft-deleted, or whose ``target_filepath`` is empty or
        missing on disk, are skipped. Entries are named after the file's
        basename. Repeated names get a ``_N`` suffix so one file cannot shadow
        another inside the archive (storage is split into dated directories, so
        the same basename can occur more than once).
        """
        records = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            deleted_flag='N',
        ).all()

        zip_buffer = BytesIO()
        used_names = set()
        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
            for record in records:
                path = record.target_filepath
                if path and os.path.exists(path):
                    zip_file.write(path, self._unique_arcname(path, used_names))

        zip_buffer.seek(0)

        return send_file(
            zip_buffer,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f"translations_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip",
        )

    @staticmethod
    def _unique_arcname(path, used_names):
        """Return a zip entry name for *path* not yet in *used_names*, and record it."""
        stem, ext = os.path.splitext(os.path.basename(path))
        name = stem + ext
        counter = 1
        while name in used_names:
            name = f"{stem}_{counter}{ext}"
            counter += 1
        used_names.add(name)
        return name
|
|
|
|
|
|
class OpenAICheckResource(Resource):
    """Probe an OpenAI-compatible endpoint with caller-supplied credentials."""

    @jwt_required()
    def post(self):
        """Check connectivity for the given ``api_url``, ``api_key`` and ``model``.

        Returns a 400 error if any of the three form fields is missing;
        otherwise ``{'valid': bool, 'message': str}`` as reported by
        ``AIChecker.check_openai_connection``.
        """
        form = request.form
        if any(key not in form for key in ('api_url', 'api_key', 'model')):
            return APIResponse.error('缺少必要参数', 400)

        api_url, api_key, model = form['api_url'], form['api_key'], form['model']
        is_valid, msg = AIChecker.check_openai_connection(api_url, api_key, model)

        return APIResponse.success({'valid': is_valid, 'message': msg})
|
|
|
|
|
|
class PDFCheckResource(Resource):
    """Detect whether an uploaded PDF is a scanned (image-only) document."""

    @jwt_required()
    def post(self):
        """Run ``AIChecker.check_pdf_scanned`` on the uploaded ``file``.

        Returns ``{'scanned': ...}`` on success. Returns a 400 error when no
        file is sent or it lacks a ``.pdf`` name, and a 500 error (logged
        server-side) when the check raises.
        """
        if 'file' not in request.files:
            return APIResponse.error('请选择PDF文件', 400)

        file = request.files['file']
        # Guard against a missing filename so .lower() cannot raise.
        if not (file.filename or '').lower().endswith('.pdf'):
            return APIResponse.error('仅支持PDF文件', 400)

        try:
            is_scanned = AIChecker.check_pdf_scanned(file.stream)
            return APIResponse.success({'scanned': is_scanned})
        except Exception as e:
            # Log with traceback; previously the failure only reached the client.
            current_app.logger.error(f"PDF检测失败: {str(e)}", exc_info=True)
            return APIResponse.error(f'检测失败: {str(e)}', 500)
|
|
|
|
|
|
|
|
class TranslateTestResource(Resource):
    """Unauthenticated liveness endpoint for the translation service."""

    def get(self):
        """Report that the translation test service is reachable."""
        ok_message = "测试服务正常"
        return APIResponse.success(message=ok_message)
|
|
|
|
|
|
class TranslateDeleteAllResource(Resource):
    """Remove every non-deleted translation record of the current user."""

    @jwt_required()
    def delete(self):
        """Bulk-delete the caller's records whose ``deleted_flag`` is 'N'.

        NOTE(review): ``Query.delete()`` removes the rows, while
        ``TranslateDeleteResource`` only soft-deletes (``deleted_flag = 'Y'``);
        confirm which behaviour is intended for "delete all".
        """
        owner_id = get_jwt_identity()
        active_records = Translate.query.filter_by(customer_id=owner_id, deleted_flag='N')
        active_records.delete()
        db.session.commit()

        return APIResponse.success(message="删除成功")
|
|
|
|
|
|
class TranslateFinishCountResource(Resource):
    """Count the current user's completed translations."""

    @jwt_required()
    def get(self):
        """Return ``{'total': n}``, where n counts done, non-deleted records."""
        finished = Translate.query.filter_by(
            customer_id=get_jwt_identity(),
            status='done',
            deleted_flag='N',
        )
        return APIResponse.success({'total': finished.count()})
|
|
|
|
|
|
class TranslateRandDeleteAllResource(Resource):
    """Delete all records of a temporary (anonymous) user."""

    def delete(self):
        """Delete every non-deleted record of the ``rand_user_id`` in the JSON body.

        Returns a 400 error when the body is missing, is not a JSON object, or
        lacks ``rand_user_id``.
        """
        # get_json(silent=True) returns None instead of raising on a missing or
        # non-JSON body; treat that the same as an absent rand_user_id.
        payload = request.get_json(silent=True)
        rand_user_id = payload.get('rand_user_id') if isinstance(payload, dict) else None
        if not rand_user_id:
            return APIResponse.error('需要临时用户ID', 400)

        Translate.query.filter_by(
            rand_user_id=rand_user_id,
            deleted_flag='N',
        ).delete()
        db.session.commit()

        return APIResponse.success(message="删除成功")
|
|
|
|
|
|
class TranslateRandDeleteResource(Resource):
    """Delete a single record of a temporary (anonymous) user."""

    def delete(self, id):
        """Delete record ``id`` if it belongs to the ``rand_user_id`` in the JSON body.

        Returns a 400 error when ``rand_user_id`` is missing. Responds 404 (via
        ``first_or_404``) when no such record belongs to that temporary user.
        """
        payload = request.get_json(silent=True)
        rand_user_id = payload.get('rand_user_id') if isinstance(payload, dict) else None
        # Without this check, filter_by(rand_user_id=None) compiles to
        # "rand_user_id IS NULL", letting an unauthenticated caller delete any
        # record that has no temporary-user id.
        if not rand_user_id:
            return APIResponse.error('需要临时用户ID', 400)

        translate = Translate.query.filter_by(
            id=id,
            rand_user_id=rand_user_id,
        ).first_or_404()

        db.session.delete(translate)
        db.session.commit()

        return APIResponse.success(message="删除成功")
|
|
|
|
|
|
class TranslateRandDownloadResource(Resource):
    """Bundle a temporary user's finished translation files into one zip."""

    def get(self):
        """Return a zip of the done results for the ``rand_user_id`` query arg.

        Returns a 400 error when ``rand_user_id`` is missing. Records whose
        ``target_filepath`` is empty or missing on disk are skipped. Repeated
        basenames get a ``_N`` suffix so entries do not shadow each other.
        """
        rand_user_id = request.args.get('rand_user_id')
        # Without this check, filter_by(rand_user_id=None) compiles to
        # "rand_user_id IS NULL" and would bundle other users' files.
        if not rand_user_id:
            return APIResponse.error('需要临时用户ID', 400)

        records = Translate.query.filter_by(
            rand_user_id=rand_user_id,
            status='done',
        ).all()

        zip_buffer = BytesIO()
        used_names = set()
        with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
            for record in records:
                path = record.target_filepath
                # os.path.exists(None) raises TypeError, so check for a path first.
                if path and os.path.exists(path):
                    zip_file.write(path, self._unique_arcname(path, used_names))

        zip_buffer.seek(0)

        return send_file(
            zip_buffer,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f"temp_translations_{datetime.now().strftime('%Y%m%d')}.zip",
        )

    @staticmethod
    def _unique_arcname(path, used_names):
        """Return a zip entry name for *path* not yet in *used_names*, and record it."""
        stem, ext = os.path.splitext(os.path.basename(path))
        name = stem + ext
        counter = 1
        while name in used_names:
            name = f"{stem}_{counter}{ext}"
            counter += 1
        used_names.add(name)
        return name
|
|
|
|
|
|
class Doc2xCheckResource(Resource):
    """Validate a Doc2x secret key."""

    def post(self):
        """Check the ``doc2x_secret_key`` from the JSON body.

        Returns a success response for the accepted key, otherwise a 400
        "invalid key" error. This also covers a missing or non-JSON body.

        NOTE(review): this compares against a hard-coded placeholder key and
        makes no call to the Doc2x service; replace with a real check.
        """
        # get_json(silent=True) returns None instead of raising on a missing or
        # non-JSON body, so a bad request is reported as an invalid key.
        payload = request.get_json(silent=True)
        secret_key = payload.get('doc2x_secret_key') if isinstance(payload, dict) else None

        if secret_key == "valid_key_123":
            return APIResponse.success(message="接口正常")
        return APIResponse.error("无效密钥", 400)
|
|
|