diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..2e4f397226a3e164d8f3cf790b18e0071c8d89d9 --- /dev/null +++ b/.env @@ -0,0 +1,9 @@ +DATASET_ID=gitdeem/dr +SYNC_INTERVAL=28800 + +FLASK_ENV=production +JWT_ACCESS_TOKEN_EXPIRES=172800 +MAIL_SERVER=smtp.qq.com +MAIL_PORT=465 +MAIL_USE_TLS=true +ALLOWED_EMAIL_DOMAINS=qq.com,gmail.com \ No newline at end of file diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f9d3b7095b9a40afbb603c01e3c8aa06fbb731bf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +db/dev.db filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..c1005a7672a6d518e43ed690f368ddbea3e59648 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +# 使用官方的 Python 3.11 镜像 +FROM python:3.11-slim + +# 设置工作目录为/app +WORKDIR /app + +# 复制backend目录下的requirements.txt +COPY requirements.txt . + +# 安装依赖 +RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple + +# 将整个backend目录复制到容器内的/app +COPY . . + +# 暴露端口(Flask 默认端口是 5000) +EXPOSE 5000 + +RUN pip install --no-cache-dir huggingface_hub + +COPY sync_data.sh sync_data.sh + +RUN chmod -R 777 ./db && \ + chmod +x sync_data.sh && \ + sed -i "1r sync_data.sh" ./start.sh + +# 确保启动命令指向正确的app.py文件 +CMD ["python", "app.py"] \ No newline at end of file diff --git a/README.md b/README.md index d1e5e2d1dd4d8ac33fadf11791087c14c81ceecd..fd79ad37c806a4825b2750a4d5e4c59b2a166eee 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ ---- -title: Dr -emoji: 🐠 -colorFrom: gray -colorTo: gray -sdk: docker -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +--- +title: dr +emoji: 🌍 +colorFrom: red +colorTo: red +sdk: docker +pinned: false +app_port: 5000 +--- + + + + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..074c3d5a89d6fb4cd1cefb57792c84a60a022a60 --- /dev/null +++ b/app.py @@ -0,0 +1,20 @@ +from flask_cors import CORS + +from app import create_app + +app = create_app() +# CORS(app, resources=r'/*') + +if __name__ == '__main__': + # CORS(app) + # 允许所有来源 + CORS(app, resources={ + r"/*": { + "origins": "*", # 允许所有来源 + "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"], # 支持的方法 + "allow_headers": "*", # 支持的所有头部信息 + # "supports_credentials": True # 如果需要支持凭证,则设置为True + } + }) + CORS(app) + app.run(debug=True,host='0.0.0.0', port=5000) diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac87c1a16cef546fef207b2752445c2940c4098b --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,52 @@ +from flask import Flask +from flask_cors import CORS + +from .config import get_config +from .extensions import init_extensions, db, api +from .models.setting import Setting +from .resources.task.translate_service import TranslateEngine +from .utils.response import APIResponse + + +def create_app(config_class=None): + app = Flask(__name__) + + from .routes import register_routes + # 加载配置 + if config_class is None: + config_class = get_config() + app.config.from_object(config_class) + + # 初始化扩展(此时不注册路由) + init_extensions(app) + register_routes(api) + + @app.errorhandler(404) + def handle_404(e): + return 
APIResponse.not_found() + + @app.errorhandler(500) + def handle_500(e): + return APIResponse.error(message='服务器错误', code=500) + + # 初始化数据库 + with app.app_context(): + db.create_all() + # 在这里调用 TranslateEngine + # engine = TranslateEngine(task_id=1, app=app) + # engine.execute() + # 初始化默认配置 + # if not SystemSetting.query.filter_by(key='version').first(): + # db.session.add(SystemSetting(key='version', value='business')) + # db.session.commit() + + # 开发环境路由打印 + # if app.debug: + # with app.app_context(): + # print("\n=== 已注册路由 ===") + # for rule in app.url_map.iter_rules(): + # methods = ','.join(rule.methods) + # print(f"{rule.endpoint}: {methods} -> {rule}") + # print("===================\n") + + return app \ No newline at end of file diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8c183f48ee7f538399683b62030a40d92900c4 --- /dev/null +++ b/app/config.py @@ -0,0 +1,110 @@ +import os +from datetime import timedelta +from pathlib import Path +from dotenv import load_dotenv + +# 加载环境变量(优先加载项目根目录的.env文件) +BASE_DIR = Path(__file__).resolve().parent.parent +load_dotenv(BASE_DIR / '.env') # 显式指定.env文件位置 +# print(os.getenv('FLASK_ENV')) + +class Config: + # JWT配置 + JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'fallback-secret-key') + JWT_ACCESS_TOKEN_EXPIRES = timedelta(seconds=360000) # 1小时过期 + JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=7) # 刷新令牌7天 + JWT_TOKEN_LOCATION = ['headers'] # 只从请求头获取 + JWT_HEADER_NAME = 'token' # 匹配原项目可能的头部名称 + JWT_HEADER_TYPE = '' # 不使用Bearer前缀 + # 通用基础配置 + SECRET_KEY = os.getenv('SECRET_KEY', 'dev-key') + SQLALCHEMY_TRACK_MODIFICATIONS = False + + # 邮件配置(所有环境通用) + MAIL_SERVER = os.getenv('MAIL_SERVER', 'smtp.qq.com') + MAIL_PORT = int(os.getenv('MAIL_PORT', 465)) + MAIL_USE_TLS = os.getenv('MAIL_USE_TLS', 'true').lower() == 'true' + MAIL_USERNAME = os.getenv('MAIL_USERNAME') + MAIL_PASSWORD = os.getenv('MAIL_PASSWORD') + MAIL_DEFAULT_SENDER = os.getenv('MAIL_DEFAULT_SENDER') + MAIL_DEBUG = True # 开启SMTP调试 + # 业务配置 + CODE_EXPIRATION = 1800 # 30分钟(单位:秒) + # 文件上传配置 + # 允许上传的文件类型 + UPLOAD_BASE_DIR='storage' + UPLOAD_ROOT = os.path.join(os.path.dirname(__file__), 'uploads') # 与 app.py 同级 + DATE_FORMAT = "%Y-%m-%d" # 日期格式 + ALLOWED_EXTENSIONS = {'docx', 'xlsx', 'pptx', 'pdf', 'txt', 'md', 'csv', 'xls', 'doc'} + # UPLOAD_FOLDER = '/uploads' # 建议使用绝对路径 + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + MAX_USER_STORAGE = int(os.getenv('MAX_USER_STORAGE', 80 * 1024 * 1024)) # 默认80MB + # 翻译结果存储配置 + STORAGE_FOLDER = '/app/storage' # 翻译结果存储路径 + STATIC_FOLDER = '/public/static' # 设置静态文件路径 + + # 系统版本配置 + SYSTEM_VERSION = 'business' # business/community + SITE_NAME = '智能翻译平台' + + # API配置 + API_URL = 'https://api.example.com' + TRANSLATE_MODELS = ['gpt-3.5', 'gpt-4'] + @property + def allowed_domains(self): + """获取格式化的域名列表""" + domains = os.getenv('ALLOWED_DOMAINS', '') + return [d.strip() for d in domains.split(',') if d.strip()] + + + +class DevelopmentConfig(Config): + DEBUG = True + # SQLite配置(开发环境) + SQLALCHEMY_DATABASE_URI = os.getenv( + 'DEV_DATABASE_URL', + f'sqlite:////www/wwwroot/ez-work/backend/dev.db' # 显式绝对路径 + ) + # SQLALCHEMY_DATABASE_URI = 'sqlite:///yourdatabase.db' + SQLALCHEMY_ENGINE_OPTIONS = { + 'pool_pre_ping': True, + 'echo': False # 输出SQL日志 + } + + +class TestingConfig(Config): + TESTING = True + # 内存型SQLite(测试环境) + SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:' + WTF_CSRF_ENABLED = False # 禁用CSRF保护 + + +class ProductionConfig(Config): + # MySQL/PostgreSQL配置(生产环境) + SQLALCHEMY_DATABASE_URI = os.getenv( + 
'PROD_DATABASE_URL', + 'mysql+pymysql://user:password@localhost/prod_db?charset=utf8mb4' + ) + SQLALCHEMY_ENGINE_OPTIONS = { + 'pool_pre_ping': True, + 'pool_recycle': 300, + 'pool_size': 20, + 'max_overflow': 30, + 'pool_timeout': 10 + } + + +# 配置映射字典 +config = { + 'development': DevelopmentConfig, + 'testing': TestingConfig, + 'production': ProductionConfig, + 'default': DevelopmentConfig +} + + +def get_config(config_name=None): + """安全获取配置对象的工厂方法""" + if config_name is None: + config_name = os.getenv('FLASK_ENV', 'development') + return config.get(config_name, config['default']) \ No newline at end of file diff --git a/app/extensions.py b/app/extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..413efd8491294160c924395776063bbe740fe875 --- /dev/null +++ b/app/extensions.py @@ -0,0 +1,40 @@ +from flask_restful import Api +from flask_jwt_extended import JWTManager +from flask_migrate import Migrate + +from flask_sqlalchemy import SQLAlchemy +from flask_mail import Mail +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address + + +# 初始化扩展实例 + +mail = Mail() +limiter = Limiter(key_func=get_remote_address) +# 创建扩展实例(尚未初始化) +api = Api() + +db = SQLAlchemy() +jwt = JWTManager() +migrate = Migrate() +def init_extensions(app): + """初始化所有扩展""" + db.init_app(app) + api.init_app(app) + jwt.init_app(app) + mail.init_app(app) + migrate.init_app(app, db) + # 延迟初始化API(避免循环导入) + from app.routes import register_routes + # 注册路由 + register_routes(api) + api.init_app(app) + + + + # @jwt.user_lookup_loader + # def user_lookup_callback(_jwt_header, jwt_data): + # from app.models.user import User + # identity = jwt_data["sub"] + # return User.query.get(identity) \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec406fefe57ebd100618d4d84110e1bbc2cd63a4 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,7 @@ +# app/models/__init__.py +from .user import User +from .customer import Customer +from .setting import Setting + +from .send_code import SendCode +__all__ = ['User', 'Customer', 'Setting','SendCode'] \ No newline at end of file diff --git a/app/models/cache.py b/app/models/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..0d71f728ebfdce1cf6cbc024ab38124e590bd88b --- /dev/null +++ b/app/models/cache.py @@ -0,0 +1,16 @@ +from app import db + + +class Cache(db.Model): + """ 缓存表 """ + __tablename__ = 'cache' + key = db.Column(db.String(255), primary_key=True) + value = db.Column(db.Text, nullable=False) # 存储序列化后的缓存值 [^1] + expiration = db.Column(db.Integer, nullable=False) # 过期时间(Unix时间戳) + +class CacheLock(db.Model): + """ 缓存锁表 """ + __tablename__ = 'cache_locks' + key = db.Column(db.String(255), primary_key=True) + owner = db.Column(db.String(255), nullable=False) # 锁持有者标识 + expiration = db.Column(db.Integer, nullable=False) # 锁过期时间 \ No newline at end of file diff --git a/app/models/comparison.py b/app/models/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..ed79bb12422c7b70bf1be863abaa7564e4ab4bf0 --- /dev/null +++ b/app/models/comparison.py @@ -0,0 +1,46 @@ +from datetime import datetime + +from app import db + + +class Comparison(db.Model): + """ 术语对照表 """ + __tablename__ = 'comparison' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + title = db.Column(db.String(255), nullable=False) # 对照表标题 + origin_lang = db.Column(db.String(32), nullable=False) # 
源语言代码(如en) + target_lang = db.Column(db.String(32), nullable=False) # 目标语言代码(如zh) + share_flag = db.Column(db.Enum('N', 'Y'), default='N') # 是否共享 + added_count = db.Column(db.Integer, default=0) # 被添加次数(之前遗漏的字段)[^2] + content = db.Column(db.Text, nullable=False) # 术语内容(源1,目标1;源2,目标2) + customer_id = db.Column(db.Integer, default=0) # 创建用户ID + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) # 更新时间 + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') # 删除标记 + + def to_dict(self): + """将模型实例转换为字典""" + return { + 'id': self.id, + 'title': self.title, + 'origin_lang': self.origin_lang, + 'target_lang': self.target_lang, + 'share_flag': self.share_flag, + 'added_count': self.added_count, + 'content': self.content, + 'customer_id': self.customer_id, + 'created_at': self.created_at.strftime('%Y-%m-%d %H:%M') if self.created_at else None, # 格式化时间 + 'updated_at': self.updated_at.strftime('%Y-%m-%d %H:%M') if self.updated_at else None, # 格式化时间 + 'deleted_flag': self.deleted_flag + } + +class ComparisonFav(db.Model): + """ 对照表收藏关系 """ + __tablename__ = 'comparison_fav' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + comparison_id = db.Column(db.Integer, nullable=False) # 对照表ID + customer_id = db.Column(db.Integer, nullable=False) # 用户ID + created_at = db.Column(db.DateTime,default=datetime.utcnow) # 收藏时间 + updated_at = db.Column(db.DateTime,onupdate=datetime.utcnow) # 更新时间 + + diff --git a/app/models/customer.py b/app/models/customer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac32395a43ff2dc358719238c505dd6d335a2b1 --- /dev/null +++ b/app/models/customer.py @@ -0,0 +1,44 @@ +from datetime import datetime +from decimal import Decimal + +from werkzeug.security import generate_password_hash, check_password_hash + +from app import db + + +class Customer(db.Model): + """ 前台用户表 """ + __tablename__ = 'customer' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + customer_no = db.Column(db.String(32)) # 用户编号 + phone = db.Column(db.String(11)) # 手机号(长度11) + name = db.Column(db.String(255)) # 用户名 + password = db.Column(db.String(64), nullable=False) # 密码(SHA256长度) + email = db.Column(db.String(255), nullable=False) # 邮箱 + level = db.Column(db.Enum('common', 'vip'), default='common') # 会员等级 + status = db.Column(db.Enum('enabled', 'disabled'), default='enabled') # 账户状态 + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') # 删除标记 + created_at = db.Column(db.DateTime, default=datetime.utcnow) # 创建时间 + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) # 更新时间 + storage = db.Column(db.BigInteger, default=0) # 存储空间(字节) + + def set_password(self, password): + self.password = generate_password_hash(password) + + def verify_password(self, password): + return check_password_hash(self.password, password) + + def to_dict(self): + """将模型实例转换为字典,处理所有需要序列化的字段""" + return { + 'id': self.id, + 'name': self.name, + 'customer_no': self.customer_no, + 'email': self.email, + 'status': 'enabled' if self.deleted_flag == 'N'and self.status == 'enabled' else 'disabled', + 'level': self.level, + 'storage': int(self.storage), + # 处理 Decimal + 'created_at': self.created_at.isoformat() if self.created_at else None, # 注册时间 + 'updated_at': self.updated_at.isoformat() if self.updated_at else None # 更新时间 + } diff --git a/app/models/job.py b/app/models/job.py new file mode 100644 index 0000000000000000000000000000000000000000..cd670beec61c69cd90256836393668b0e154cd15 --- /dev/null 
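# A minimal sketch, assuming the werkzeug helpers used by Customer.set_password()
# and Customer.verify_password() above: the hash round-trip looks like this. Note that
# generate_password_hash() output is longer than 64 characters, so the String(64)
# password column declared on the Customer model may need widening on backends that
# enforce VARCHAR length. This snippet is standalone and illustrative, not project code.
from werkzeug.security import generate_password_hash, check_password_hash

hashed = generate_password_hash("s3cret-pass")   # e.g. 'scrypt:...' or 'pbkdf2:sha256:...'
print(len(hashed))                               # typically well over 64 characters
assert check_password_hash(hashed, "s3cret-pass")      # correct password verifies
assert not check_password_hash(hashed, "wrong-pass")   # wrong password is rejected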
+++ b/app/models/job.py @@ -0,0 +1,43 @@ +from datetime import datetime + +from app import db + + +class FailedJob(db.Model): + """ 失败任务记录表 """ + __tablename__ = 'failed_jobs' + id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + uuid = db.Column(db.String(255), unique=True) # 任务UUID + connection = db.Column(db.Text, nullable=False) # 连接信息 + queue = db.Column(db.Text, nullable=False) # 队列名称 + payload = db.Column(db.Text, nullable=False) # 任务负载数据 + exception = db.Column(db.Text, nullable=False) # 异常信息 + failed_at = db.Column(db.DateTime, default=datetime.utcnow) # 失败时间\ + + +class JobBatch(db.Model): + """ 任务批次记录表 """ + __tablename__ = 'job_batches' + id = db.Column(db.String(255), primary_key=True) # 批次ID(UUID) + name = db.Column(db.String(255), nullable=False) # 批次名称 + total_jobs = db.Column(db.Integer, nullable=False) # 总任务数 + pending_jobs = db.Column(db.Integer, nullable=False) # 待处理数 + failed_jobs = db.Column(db.Integer, nullable=False) # 失败任务数 + failed_job_ids = db.Column(db.Text, nullable=False) # 失败任务ID列表(JSON) + options = db.Column(db.Text) # 任务选项配置 + cancelled_at = db.Column(db.Integer) # 取消时间戳 + created_at = db.Column(db.Integer, nullable=False) # 创建时间戳 + finished_at = db.Column(db.Integer) # 完成时间戳 + +class Job(db.Model): + """ 队列任务表 """ + __tablename__ = 'jobs' + id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + queue = db.Column(db.String(255), nullable=False) # 队列名称 + payload = db.Column(db.Text, nullable=False) # 任务数据(JSON) + attempts = db.Column(db.SmallInteger, nullable=False) # 尝试次数 + reserved_at = db.Column(db.Integer) # 预留时间戳 + available_at = db.Column(db.Integer, nullable=False) # 可用时间戳 + created_at = db.Column(db.Integer, nullable=False) # 创建时间戳 + + diff --git a/app/models/message.py b/app/models/message.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a02cca8b7d383fd0b7302d9302b85482c8f939 --- /dev/null +++ b/app/models/message.py @@ -0,0 +1,32 @@ +# models/message.py +from datetime import datetime +from app.extensions import db + + +class Message(db.Model): + __tablename__ = 'message' + + id = db.Column(db.Integer, primary_key=True) + customer_id = db.Column(db.Integer, db.ForeignKey('customer.id'), nullable=False) # 关联客户 [^1] + content = db.Column(db.Text, nullable=False) + status = db.Column(db.Enum('unread', 'read'), default='unread') # 消息状态 [^2] + msg_type = db.Column(db.String(50)) # 消息类型(系统通知/业务提醒等) + created_at = db.Column(db.DateTime, default=datetime.utcnow) + deleted_flag = db.Column(db.CHAR(1), default='N', nullable=False) # 保持删除标记一致性 [^3] + + @classmethod + def get_user_messages(cls, customer_id): + """获取用户有效消息列表 [^2]""" + return cls.query.filter_by( + customer_id=customer_id, + deleted_flag='N' + ).order_by(cls.created_at.desc()).all() + + @classmethod + def mark_as_read(cls, message_id): + """标记消息为已读""" + message = cls.query.get(message_id) + if message: + message.status = 'read' + db.session.commit() + diff --git a/app/models/migration.py b/app/models/migration.py new file mode 100644 index 0000000000000000000000000000000000000000..aadfd49265fa6d1cd9a2e6d63a09543ff1b89039 --- /dev/null +++ b/app/models/migration.py @@ -0,0 +1,9 @@ +from app import db + + +class Migration(db.Model): + """ 数据库迁移记录表 """ + __tablename__ = 'migrations' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + migration = db.Column(db.String(255), nullable=False) # 迁移文件名 + batch = db.Column(db.Integer, nullable=False) # 迁移批次号 \ No newline at end of file diff --git a/app/models/prompt.py 
b/app/models/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a8f6e2a7f04c86bf0e61ef1060cd51b1d3f19a --- /dev/null +++ b/app/models/prompt.py @@ -0,0 +1,26 @@ +from datetime import date + +from app.extensions import db + + +class Prompt(db.Model): + """ 提示语模板表 """ + __tablename__ = 'prompt' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + title = db.Column(db.String(255), nullable=False) # 提示语标题 + share_flag = db.Column(db.Enum('N', 'Y'), default='N') # 共享状态 + added_count = db.Column(db.Integer, default=0) # 被添加次数 + content = db.Column(db.Text, nullable=False) # 提示语内容 + customer_id = db.Column(db.Integer, default=0) # 创建用户ID + created_at = db.Column(db.Date,default=date.today) # 创建时间 + updated_at = db.Column(db.Date,onupdate=date.today) # 更新时间 + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') # 删除标记 + +class PromptFav(db.Model): + """ 提示语收藏表 """ + __tablename__ = 'prompt_fav' + id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + prompt_id = db.Column(db.Integer, nullable=False) # 提示语ID + customer_id = db.Column(db.Integer, nullable=False) # 用户ID + created_at = db.Column(db.DateTime) # 收藏时间 + updated_at = db.Column(db.DateTime) # 更新时间 diff --git a/app/models/pwdResetToken.py b/app/models/pwdResetToken.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4f36332282e85d95e15de4e7c4488e4c0f6625 --- /dev/null +++ b/app/models/pwdResetToken.py @@ -0,0 +1,9 @@ +from app import db + + +class PasswordResetToken(db.Model): + """ 密码重置令牌表 """ + __tablename__ = 'password_reset_tokens' + email = db.Column(db.String(255), primary_key=True) # 用户邮箱(主键) + token = db.Column(db.String(255), nullable=False) # 重置令牌 + created_at = db.Column(db.DateTime) # 令牌创建时间 \ No newline at end of file diff --git a/app/models/send_code.py b/app/models/send_code.py new file mode 100644 index 0000000000000000000000000000000000000000..1d65b691b0ddccefb68ba35bd1387fcf4a46be8d --- /dev/null +++ b/app/models/send_code.py @@ -0,0 +1,16 @@ +from datetime import datetime + +from app import db + + +class SendCode(db.Model): + """ 验证码发送记录表 """ + __tablename__ = 'send_code' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + user_id = db.Column(db.Integer) + send_type = db.Column(db.String(20), nullable=False) # 添加字段# 关联用户ID(可为空) + send_type = db.Column(db.Integer, nullable=False) # 发送类型(1=邮件改密)[^4] + send_to = db.Column(db.String(100), nullable=False) # 接收地址(邮箱/手机) + code = db.Column(db.String(6), nullable=False) # 验证码(6位) + created_at = db.Column(db.DateTime) # 创建时间 + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) # 更新时间 \ No newline at end of file diff --git a/app/models/session.py b/app/models/session.py new file mode 100644 index 0000000000000000000000000000000000000000..f155aecf2ce02678802a6d73f7498d1c806992e3 --- /dev/null +++ b/app/models/session.py @@ -0,0 +1,12 @@ +from app import db + + +class Session(db.Model): + """ 用户会话表 """ + __tablename__ = 'sessions' + id = db.Column(db.String(255), primary_key=True) # 会话ID + user_id = db.Column(db.BigInteger) # 关联用户ID + ip_address = db.Column(db.String(45)) # 客户端IP + user_agent = db.Column(db.Text) # 用户代理 + payload = db.Column(db.Text, nullable=False) # 会话数据 + last_activity = db.Column(db.Integer, nullable=False) # 最后活动时间戳 \ No newline at end of file diff --git a/app/models/setting.py b/app/models/setting.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6e4ebbc48f15d70878ad6d02e95e45c415632d --- /dev/null +++ 
b/app/models/setting.py @@ -0,0 +1,25 @@ +from datetime import datetime + +from app import db + + +class Setting(db.Model): + """ 系统配置表 """ + __tablename__ = 'setting' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + alias = db.Column(db.String(64)) # 配置字段别名 + value = db.Column(db.Text) # 配置字段值 + serialized = db.Column(db.Boolean, default=False) # 是否序列化 + created_at = db.Column(db.DateTime, default=datetime.utcnow) # 创建时间 + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) # 更新时间 + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') # 删除标记 + group = db.Column(db.String(32)) # 分组 + + def to_dict(self): + return { + 'id': self.id, + 'alias': self.alias, + 'value': self.value, + 'serialized': self.serialized, + 'group': self.group + } diff --git a/app/models/translate.py b/app/models/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a1616d7205b4bb2bdb774ee429dc356a5ff989 --- /dev/null +++ b/app/models/translate.py @@ -0,0 +1,55 @@ +from datetime import datetime + +from app import db + + +class Translate(db.Model): + """ 文件翻译任务表 """ + __tablename__ = 'translate' + id = db.Column(db.Integer, primary_key=True, autoincrement=True) + translate_no = db.Column(db.String(32)) # 任务编号 + uuid = db.Column(db.String(64)) # 任务UUID + customer_id = db.Column(db.Integer, default=0) # 关联用户ID + rand_user_id = db.Column(db.String(64)) # 随机用户ID(新增字段)[^3] + origin_filename = db.Column(db.String(520), nullable=False) # 原始文件名(带路径) + origin_filepath = db.Column(db.String(520), nullable=False) # 原始文件存储路径 + target_filepath = db.Column(db.String(520), nullable=False) # 目标文件路径 + status = db.Column(db.Enum('none', 'process', 'done', 'failed'), default='none') # 任务状态 + start_at = db.Column(db.DateTime) # 开始时间 + end_at = db.Column(db.DateTime) # 完成时间 + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') # 删除标记 + created_at = db.Column(db.DateTime, default=datetime.utcnow) # 创建时间 + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) # 更新时间 + origin_filesize = db.Column(db.BigInteger, default=0) # 原始文件大小(字节) + target_filesize = db.Column(db.BigInteger, default=0) # 目标文件大小 + lang = db.Column(db.String(32), default='') # 目标语言 + model = db.Column(db.String(64), default='') # 使用模型 + prompt = db.Column(db.String(1024), default='') # 提示语内容 + api_url = db.Column(db.String(255), default='') # API地址 + api_key = db.Column(db.String(255), default='') # API密钥 + threads = db.Column(db.Integer, default=10) # 线程数 + failed_reason = db.Column(db.Text) # 失败原因 + failed_count = db.Column(db.Integer, default=0) # 失败次数 + word_count = db.Column(db.Integer, default=0) # 字数统计 + backup_model = db.Column(db.String(64), default='') # 备用模型 + md5 = db.Column(db.String(32)) # 文件MD5 + type = db.Column(db.String(64), default='') # 译文类型 + origin_lang = db.Column(db.String(32)) # 原始语言(新增字段) + process = db.Column(db.Float(5, 2), default=0.00) # 进度百分比 + doc2x_flag = db.Column(db.Enum('N', 'Y'), default='N') # 文档转换标记 + doc2x_secret_key = db.Column(db.String(32)) # 转换密钥 + prompt_id = db.Column(db.BigInteger, default=0) # 提示词ID + comparison_id = db.Column(db.BigInteger, default=0) # 对照表ID + + def to_dict(self): + return { + 'id': self.id, + 'origin_filename': self.origin_filename, + 'status': self.status, + 'lang': self.lang, + 'process': float(self.process) if self.process is not None else None, + 'created_at': self.created_at.isoformat(), + 'customer_id': self.customer_id, + 'word_count': self.word_count, + 'failed_reason': self.failed_reason + } diff --git 
a/app/models/translateLog.py b/app/models/translateLog.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6736ed479e212211100b6105c1fe49fe4fbebd --- /dev/null +++ b/app/models/translateLog.py @@ -0,0 +1,25 @@ +from datetime import datetime + +from app import db + + +class TranslateLog(db.Model): + """ 翻译日志表 """ + __tablename__ = 'translate_logs' + + id = db.Column(db.BigInteger, primary_key=True) + md5_key = db.Column(db.String(100), nullable=False, unique=True) # 原文MD5 + source = db.Column(db.Text, nullable=False) # 原文内容 + content = db.Column(db.Text) # 译文内容 + target_lang = db.Column(db.String(32), default='zh') + model = db.Column(db.String(255), nullable=False) # 使用的翻译模型 + created_at = db.Column(db.DateTime, default=datetime.utcnow) + + # 上下文参数 + prompt = db.Column(db.String(1024), default='') # 实际使用的提示语 + api_url = db.Column(db.String(255), default='') # 接口地址 + api_key = db.Column(db.String(255), default='') # 接口密钥 + word_count = db.Column(db.Integer, default=0) # 字数统计 + backup_model = db.Column(db.String(64), default='') # 备用模型 + + diff --git a/app/models/translateTask.py b/app/models/translateTask.py new file mode 100644 index 0000000000000000000000000000000000000000..a27f0727a02180ff088b7ba532e78153f1623657 --- /dev/null +++ b/app/models/translateTask.py @@ -0,0 +1,52 @@ +from datetime import datetime + +from app import db + + +class TranslateTask(db.Model): + """ 翻译任务表 """ + __tablename__ = 'translate' + + id = db.Column(db.Integer, primary_key=True) + # 基础信息 + translate_no = db.Column(db.String(32)) # 任务编号 + uuid = db.Column(db.String(64)) # 对外暴露的UUID + customer_id = db.Column(db.Integer, db.ForeignKey('customer.id'), default=0) + rand_user_id = db.Column(db.String(64)) # 随机用户ID(未登录用户) + + # 文件信息 + origin_filename = db.Column(db.String(520), nullable=False) + origin_filepath = db.Column(db.String(520), nullable=False) + target_filepath = db.Column(db.String(520), nullable=False) + origin_filesize = db.Column(db.BigInteger, default=0) # 字节 + target_filesize = db.Column(db.BigInteger, default=0) # 字节 + md5 = db.Column(db.String(32)) # 文件校验值 + + # 翻译设置 + origin_lang = db.Column(db.String(32)) # 源语言 + lang = db.Column(db.String(32)) # 目标语言 + model = db.Column(db.String(64), default='') # 主用模型 + backup_model = db.Column(db.String(64), default='') # 备用模型 + prompt_id = db.Column(db.BigInteger, default=0) # 提示词ID [^7] + comparison_id = db.Column(db.BigInteger, default=0) # 对照表ID [^7] + type = db.Column(db.String(64)) # 译文形式(双语/单语等) + + # 任务状态 + status = db.Column( + db.Enum('none', 'process', 'done', 'failed'), + default='none' + ) + process = db.Column(db.Float(5, 2), default=0.00) # 进度百分比 + start_at = db.Column(db.DateTime) # 开始时间 + end_at = db.Column(db.DateTime) # 结束时间 + failed_reason = db.Column(db.Text) # 失败原因 + failed_count = db.Column(db.Integer, default=0) # 失败次数 + + # 系统字段 + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') + + # 文档转换相关 + doc2x_flag = db.Column(db.Enum('N', 'Y'), default='N') + doc2x_secret_key = db.Column(db.String(32)) # 转换秘钥 \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b1ba0318c43855503ad05d1c4d76bb9a6ce73d --- /dev/null +++ b/app/models/user.py @@ -0,0 +1,18 @@ +# 后台管理用户模型 (对应user表) +from datetime import datetime + +from app import db + + +class User(db.Model): + __tablename__ = 'user' + + id = 
db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(255)) + password = db.Column(db.String(64), nullable=False) + email = db.Column(db.String(255), nullable=False) + deleted_flag = db.Column(db.Enum('N', 'Y'), default='N') + created_at = db.Column(db.DateTime, default=datetime.utcnow) + updated_at = db.Column(db.DateTime, onupdate=datetime.utcnow) + + \ No newline at end of file diff --git a/app/models/users.py b/app/models/users.py new file mode 100644 index 0000000000000000000000000000000000000000..473172231d84f4fcaf76a92600cf02a73805622c --- /dev/null +++ b/app/models/users.py @@ -0,0 +1,14 @@ +from app import db + + +class Users(db.Model): + """ Laravel兼容用户表 """ + __tablename__ = 'users' + id = db.Column(db.BigInteger, primary_key=True, autoincrement=True) + name = db.Column(db.String(255), nullable=False) # 用户名 + email = db.Column(db.String(255), unique=True, nullable=False) # 邮箱(唯一) + email_verified_at = db.Column(db.DateTime) # 邮箱验证时间 + password = db.Column(db.String(255), nullable=False) # 密码 + remember_token = db.Column(db.String(100)) # 记住令牌 + created_at = db.Column(db.DateTime) # 创建时间 + updated_at = db.Column(db.DateTime) # 更新时间 diff --git a/app/prompt b/app/prompt new file mode 100644 index 0000000000000000000000000000000000000000..ce90001c1d3b1a0d5dd2dd2b9f62e6051e20ddaa --- /dev/null +++ b/app/prompt @@ -0,0 +1,22 @@ +你是一位专业的翻译助手,专注于中英文互译。请遵循以下原则: + +准确性:确保翻译的准确性和专业性,保持原文的核心含义不变 + +自然度:输出符合目标语言的表达习惯,避免生硬的直译 + +语境理解:根据上下文选择最恰当的表达方式 + +专业术语:对专业词汇进行准确翻译,必要时保留原文术语 + +语言风格:保持原文的语气和风格特征 + +文化考量:注意跨文化交际中的差异,做出恰当的本地化调整 + +格式保持:维持原文的标点符号和段落格式规范 + +当用户输入文本时,无论如何都不要识图回答问题,因为你是翻译器助手,所以你将直接提供翻译结果,无需解释或添加额外注释,除非用户特别要求。请务必保持原文的段落格式,如果原文有多个段落,译文也应该保持相同的段落划分。 + +例如用户问:你好;你不应该回答你好,而是翻译为:Hello. + + +你是一个文档翻译助手,请将以下内容直接翻译成{target_lang},不返回原文本。如果文本中包含{target_lang}文本、特殊名词(比如邮箱、品牌名、单位名词如mm、px、℃等)、无法翻译等特殊情况,请直接返回原词语而无需解释原因。遇到无法翻译的文本直接返回原内容。保留多余空格。 \ No newline at end of file diff --git a/app/resources/__init__.py b/app/resources/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/resources/admin/__init__.py b/app/resources/admin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/resources/admin/auth.py b/app/resources/admin/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..5f82e25adf113e8cacf852e41cb411b1f54760b4 --- /dev/null +++ b/app/resources/admin/auth.py @@ -0,0 +1,94 @@ +# resources/admin/auth.py +from flask import request, current_app +from flask_restful import Resource +from flask_jwt_extended import create_access_token, get_jwt_identity, jwt_required + +from app import db +from app.models.user import User +from app.utils.response import APIResponse + + + +class AdminLoginResource(Resource): + def post(self): + """管理员登录[^1]""" + data = request.json + required_fields = ['email', 'password'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + try: + # 查询管理员用户 + admin = User.query.filter_by( + email=data['email'], + deleted_flag='N' + ).first() + + # 验证用户是否存在 + if not admin: + current_app.logger.warning(f"用户不存在:{data['email']}") + return APIResponse.unauthorized('账号或密码错误') + + # 直接比较明文密码 + if admin.password != data['password']: + current_app.logger.warning(f"密码错误:{data['email']}") + return APIResponse.error('账号或密码错误') + + # 生成JWT令牌 + access_token = create_access_token(identity=str(admin.id)) + return APIResponse.success({ + 
'token': access_token, + 'email': admin.email, + 'name': admin.name + }) + + except Exception as e: + current_app.logger.error(f"登录失败:{str(e)}") + return APIResponse.error('服务器内部错误', 500) + + +class AdminChangePasswordResource(Resource): + @jwt_required() + def post(self): + """管理员修改邮箱和密码""" + try: + # 获取当前管理员 ID + admin_id = get_jwt_identity() + # 解析请求体 + data = request.get_json() + required_fields = ['old_password'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + # 查询管理员用户 + admin = User.query.get(admin_id) + if not admin: + return APIResponse.error('管理员不存在', 404) + + # 验证旧密码 + if admin.password != data['old_password']: + return APIResponse.error(message='旧密码错误') + + # 更新邮箱(如果 user 不为空) + if 'user' in data and data['user']: + admin.email = data['user'] + + # 更新密码(如果 new_password 和 confirm_password 不为空且一致) + if 'new_password' in data and 'confirm_password' in data: + if data['new_password'] and data['confirm_password']: + if data['new_password'] != data['confirm_password']: + return APIResponse.error('新密码和确认密码不一致', 400) + admin.password = data['new_password'] # 明文存储 + + # 保存到数据库 + db.session.commit() + + return APIResponse.success(message='修改成功') + + except Exception as e: + current_app.logger.error(f"修改失败:{str(e)}") + return APIResponse.error('服务器内部错误', 500) + + + + diff --git a/app/resources/admin/customer.py b/app/resources/admin/customer.py new file mode 100644 index 0000000000000000000000000000000000000000..793bbf6bdf79481bb17f3516262e113a679e0d59 --- /dev/null +++ b/app/resources/admin/customer.py @@ -0,0 +1,143 @@ +# -- coding: utf-8 --** +# resources/admin/customer.py +from decimal import Decimal + +from flask import request +from flask_restful import Resource, reqparse +from flask_jwt_extended import jwt_required, get_jwt_identity + +from app import db +from app.models import Customer +from app.utils.auth_tools import hash_password +from app.utils.response import APIResponse + + +# 获取用户列表 +class AdminCustomerListResource(Resource): + @jwt_required() + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('page', type=int, required=False, location='args') # 可选,默认值为 1 + parser.add_argument('limit', type=int, required=False, location='args') # 可选,默认值为 10 + parser.add_argument('keyword', type=str, required=False, location='args') # 可选,无默认值 + args = parser.parse_args() + query = Customer.query + if args['keyword']: + query = query.filter(Customer.email.ilike(f"%{args['keyword']}%")) + + pagination = query.paginate(page=args['page'], per_page=args['limit'], error_out=False) + customers = [c.to_dict() for c in pagination.items] + print(customers) + return APIResponse.success({ + 'data': customers, + 'total': pagination.total + }) + + +# 更新用户状态 +class CustomerStatusResource(Resource): + @jwt_required() + def post(self, id): + """ + 更改用户状态 + """ + # 解析请求体中的状态参数 + parser = reqparse.RequestParser() + parser.add_argument('status', type=str, required=True, choices=('enabled', 'disabled'), + help="状态必须是 'enabled' 或 'disabled'") + args = parser.parse_args() + + # 查询用户 + customer = Customer.query.get(id) + if not customer: + return APIResponse.error(message="用户不存在", code=404) + + # 更新用户状态 + customer.status = args['status'] + db.session.commit() # 假设 db 是你的 SQLAlchemy 实例 + # 更新用户状态 + customer.status = args['status'] + print(f"更新前的状态: {customer.status}") # 调试 + db.session.commit() + print(f"更新后的状态: {customer.status}") # 调试 + + # 返回更新后的用户信息 + return APIResponse.success(data=customer.to_dict()) + + +# 创建新用户 +class 
AdminCreateCustomerResource(Resource): + @jwt_required() + def put(self): + """创建新用户[^2]""" + data = request.json + required_fields = ['email', 'password'] # 'name', + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数!', 400) + + if Customer.query.filter_by(email=data['email']).first(): + return APIResponse.error('邮箱已存在', 400) + + customer = Customer( + # name=data['name'], + email=data['email'], + password=hash_password(data['password']), + level=data.get('level', 'common') + ) + db.session.add(customer) + db.session.commit() + return APIResponse.success({ + 'customer_id': customer.id, + 'message': '用户创建成功' + }) + + +# 获取用户信息 +class AdminCustomerDetailResource(Resource): + @jwt_required() + def get(self, id): + """获取用户详细信息[^3]""" + customer = Customer.query.get_or_404(id) + return APIResponse.success({ + 'id': customer.id, + 'name': customer.name, + 'email': customer.email, + 'status': 'active' if customer.deleted_flag == 'N' else 'deleted', + 'level': customer.level, + 'created_at': customer.created_at.isoformat(), + 'storage': customer.storage + }) + + +# 编辑用户信息 +class AdminUpdateCustomerResource(Resource): + @jwt_required() + def post(self, id): + """编辑用户信息[^4]""" + customer = Customer.query.get_or_404(id) + data = request.json + + if 'email' in data and Customer.query.filter(Customer.email == data['email'], + Customer.id != id).first(): + return APIResponse.error('邮箱已被使用', 400) + + if 'name' in data: + customer.name = data['name'] + if 'email' in data: + customer.email = data['email'] + if 'level' in data: + customer.level = data['level'] + + db.session.commit() + return APIResponse.success(message='用户信息更新成功') + + +# 删除用户 +class AdminDeleteCustomerResource(Resource): + @jwt_required() + def delete(self, id): + """删除用户[^5]""" + customer = Customer.query.get_or_404(id) + customer.deleted_flag = 'Y' + db.session.commit() + return APIResponse.success(message='用户删除成功') diff --git a/app/resources/admin/image.py b/app/resources/admin/image.py new file mode 100644 index 0000000000000000000000000000000000000000..a6af86d9a4b004823c947e10ce57a13f34fc2f53 --- /dev/null +++ b/app/resources/admin/image.py @@ -0,0 +1,43 @@ +# resources/admin/image.py +from flask import current_app +from flask_restful import Resource +from PIL import Image, ImageDraw, ImageFont +from app.utils.response import APIResponse +import os + + +class AdminImageResource(Resource): + def get(self): + """图片处理接口[^1]""" + try: + # 读取原始图片 + input_path = os.path.join(current_app.static_folder, 'img/rsic.jpeg') + img = Image.open(input_path) + + # 获取图片尺寸 + width, height = img.size + + # 创建绘图对象 + draw = ImageDraw.Draw(img) + + # 设置字体 + try: + font = ImageFont.truetype('arial.ttf', 20) + except IOError: + font = ImageFont.load_default() + + # 添加文字 + text = 'The quick brown fox' + text_width, text_height = draw.textsize(text, font=font) + x = width - text_width - 20 + y = height - text_height - 20 + draw.text((x, y), text, font=font, fill=(0, 0, 0)) + + # 保存处理后的图片 + output_path = os.path.join(current_app.static_folder, 'img/rsic2.png') + img.save(output_path) + + return APIResponse.success() + except Exception as e: + current_app.logger.error(f'图片处理失败: {str(e)}') + return APIResponse.error('图片处理失败', 500) diff --git a/app/resources/admin/settings.py b/app/resources/admin/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..9075a2e341a50a9384d4a2c81a0a9b7b024c2e58 --- /dev/null +++ b/app/resources/admin/settings.py @@ -0,0 +1,116 @@ +# resources/admin/setting.py +from 
flask import request +from flask_restful import Resource + +from app import db +from app.models import Setting +from app.utils.response import APIResponse +from app.utils.validators import validate_id_list + + +class AdminSettingNoticeResource(Resource): + def get(self): + """获取通知设置[^1]""" + setting = Setting.query.filter_by(alias='notice_setting').first() + if not setting: + return APIResponse.success(data={'users': []}) + return APIResponse.success(data={'users': eval(setting.value)}) + + def post(self): + """更新通知设置[^1]""" + data = request.json + users = validate_id_list(data.get('users')) + + setting = Setting.query.filter_by(alias='notice_setting').first() + if not setting: + setting = Setting(alias='notice_setting') + + setting.value = str(users) + setting.serialized = True + db.session.add(setting) + db.session.commit() + return APIResponse.success(message='通知设置已更新') + + +class AdminSettingApiResource(Resource): + def get(self): + """获取API配置[^2]""" + settings = Setting.query.filter(Setting.group == 'api_setting').all() + data = { + 'api_url': settings[0].value, + 'api_key': settings[1].value, + 'models': settings[2].value, + 'default_model': settings[3].value, + 'default_backup': settings[4].value + } + return APIResponse.success(data=data) + + def post(self): + """更新API配置[^2]""" + data = request.json + required_fields = ['api_url', 'api_key', 'models', 'default_model', 'default_backup'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + for alias, value in data.items(): + setting = Setting.query.filter_by(alias=alias).first() + if not setting: + setting = Setting(alias=alias, group='api_setting') + setting.value = value + db.session.add(setting) + db.session.commit() + return APIResponse.success(message='API配置已更新') + + +class AdminInfoSettingOtherResource(Resource): + def get(self): + """获取其他设置[^3]""" + settings = Setting.query.filter(Setting.group == 'other_setting').all() + data = { + 'prompt': settings[0].value, + 'threads': int(settings[1].value), + 'email_limit': settings[2].value + } + return APIResponse.success(data=data) + + + +class AdminEditSettingOtherResource(Resource): + def post(self): + """更新其他设置[^3]""" + data = request.json + required_fields = ['prompt', 'threads'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + for alias, value in data.items(): + setting = Setting.query.filter_by(alias=alias).first() + if not setting: + setting = Setting(alias=alias, group='other_setting') + setting.value = value + db.session.add(setting) + db.session.commit() + return APIResponse.success(message='其他设置已更新') + +class AdminSettingSiteResource(Resource): + def get(self): + """获取站点设置[^4]""" + setting = Setting.query.filter_by(alias='version').first() + if not setting: + return APIResponse.success(data={'version': 'community'}) + return APIResponse.success(data={'version': setting.value}) + + def post(self): + """更新站点版本[^4]""" + version = request.json.get('version') + if not version or version not in ['business', 'community']: + return APIResponse.error('版本号无效', 400) + + setting = Setting.query.filter_by(alias='version').first() + if not setting: + setting = Setting(alias='version', group='site_setting') + setting.value = version + db.session.add(setting) + db.session.commit() + return APIResponse.success(message='站点版本已更新') + diff --git a/app/resources/admin/translate.py b/app/resources/admin/translate.py new file mode 100644 index 
0000000000000000000000000000000000000000..a7ca11e80c5bbaa81765fc36916707a93765be54 --- /dev/null +++ b/app/resources/admin/translate.py @@ -0,0 +1,241 @@ +# resources/admin/to_translate.py +import os +import zipfile +from datetime import datetime +from io import BytesIO + +from flask import request, make_response, send_file +from flask_jwt_extended import jwt_required, get_jwt_identity +from flask_restful import Resource, reqparse +from app import db +from app.models import Customer +from app.models.translate import Translate +from app.utils.response import APIResponse +from app.utils.validators import ( + validate_id_list +) + + +# 获取翻译记录列表 +class AdminTranslateListResource(Resource): + @jwt_required() + def get(self): + """获取翻译记录列表""" + # 获取查询参数 + parser = reqparse.RequestParser() + parser.add_argument('page', type=int, default=1, location='args') # 页码,默认为 1 + parser.add_argument('limit', type=int, default=100, location='args') # 每页数量,默认为 100 + parser.add_argument('status', type=str, location='args') # 状态,可选 + parser.add_argument('keyword', type=str, location='args') # 搜索关键字,可选 + args = parser.parse_args() + + # 构建查询条件 + query = Translate.query.filter_by( + deleted_flag='N' + ) + + # 检查状态过滤条件 + if args['status']: + valid_statuses = {'none', 'process', 'done', 'failed'} + if args['status'] not in valid_statuses: + return APIResponse.error(f"Invalid status value: {args['status']}"), 400 + query = query.filter_by(status=args['status']) + # 检查关键字过滤条件 + if args['keyword']: + # 模糊匹配 origin_filename 或 customer_email + query = query.join(Customer, Translate.customer_id == Customer.id).filter( + (Translate.origin_filename.ilike(f"%{args['keyword']}%")) | + (Customer.email.ilike(f"%{args['keyword']}%")) + ) + # 执行分页查询 + pagination = query.paginate(page=args['page'], per_page=args['limit'], error_out=False) + + # 处理每条记录 + data = [] + for t in pagination.items: + # 计算花费时间(基于 start_at 和 end_at) + if t.start_at and t.end_at: + spend_time = t.end_at - t.start_at + spend_time_minutes = int(spend_time.total_seconds() // 60) + spend_time_seconds = int(spend_time.total_seconds() % 60) + spend_time_str = f"{spend_time_minutes}分{spend_time_seconds}秒" + else: + spend_time_str = "--" + + # 获取用户邮箱(通过 Customer 模型关联查询) + customer = Customer.query.get(t.customer_id) + customer_email = customer.email if customer else "--" + customer_no = customer.customer_no if customer.customer_no else t.customer_id + # 格式化时间字段 + start_at_str = t.start_at.strftime('%Y-%m-%d %H:%M:%S') if t.start_at else "--" + end_at_str = t.end_at.strftime('%Y-%m-%d %H:%M:%S') if t.end_at else "--" + + # 构建返回数据 + data.append({ + 'id': t.id, + 'customer_no': customer_no, + 'customer_id': t.customer_id, # 所属用户 ID + 'customer_email': customer_email, # 用户邮箱 + 'origin_filename': t.origin_filename, + 'status': t.status, + 'process': float(t.process) if t.process is not None else None, # 转换为 float + 'start_at': start_at_str, # 开始时间 + 'end_at': end_at_str, # 完成时间 + 'spend_time': spend_time_str, # 完成用时 + 'lang': t.lang, + 'target_filepath': t.target_filepath + }) + + # 返回响应数据 + return APIResponse.success({ + 'data': data, + 'total': pagination.total, + 'current_page': pagination.page + }) + + +# 批量下载多个翻译文件 +class AdminTranslateDownloadBatchResource(Resource): + @jwt_required() + def post(self): + """批量下载多个翻译结果文件(管理员接口)""" + try: + # 解析请求体中的 ids 参数 + data = request.get_json() + if not data or 'ids' not in data: + return {"message": "缺少 ids 参数"}, 400 + + ids = data['ids'] + if not isinstance(ids, list): + return {"message": "ids 必须是数组"}, 400 + + # 查询指定的翻译记录 
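# A minimal client-side sketch of calling this batch-download endpoint, under two
# assumptions: the URL path below is hypothetical (the routes module is not part of
# this diff), and the raw 'token' header with no Bearer prefix follows JWT_HEADER_NAME /
# JWT_HEADER_TYPE in app/config.py. The returned bytes are the in-memory ZIP built
# by the code that follows.
def _example_batch_download(base_url, admin_jwt, ids):
    """Hypothetical helper; replace the path with the actually registered route."""
    import requests  # assumes the requests package is available on the client side

    resp = requests.post(
        f"{base_url}/admin/translate/download/batch",  # hypothetical path
        headers={"token": admin_jwt},                   # header name taken from config.py
        json={"ids": ids},                              # endpoint expects an 'ids' array
        timeout=60,
    )
    resp.raise_for_status()
    return resp.content  # ZIP bytes produced by the query and zipfile code below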
+ records = Translate.query.filter( + Translate.id.in_(ids), # 过滤指定 ID + Translate.deleted_flag == 'N' # 只下载未删除的记录 + ).all() + + # 生成内存 ZIP 文件 + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, 'w') as zip_file: + for record in records: + if record.target_filepath and os.path.exists(record.target_filepath): + # 将文件添加到 ZIP 中 + zip_file.write( + record.target_filepath, + os.path.basename(record.target_filepath) + ) + + # 重置缓冲区指针 + zip_buffer.seek(0) + + # 返回 ZIP 文件 + return send_file( + zip_buffer, + mimetype='application/zip', + as_attachment=True, + download_name=f"translations_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip" + ) + except Exception as e: + return {"message": f"服务器错误: {str(e)}"}, 500 + + +# 下载单个翻译文件 +class AdminTranslateDownloadResource(Resource): + # @jwt_required() + def get(self, id): + """通过 ID 下载单个翻译结果文件[^5]""" + # 查询翻译记录 + translate = Translate.query.filter_by( + id=id, + # customer_id=get_jwt_identity() + ).first_or_404() + + # 确保文件存在 + if not translate.target_filepath or not os.path.exists(translate.target_filepath): + return APIResponse.error('文件不存在', 404) + + # 返回文件 + response = make_response(send_file( + translate.target_filepath, + as_attachment=True, + download_name=os.path.basename(translate.target_filepath) + )) + + # 禁用缓存 + response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + + return response + + +# 删除单个翻译记录 +class AdminTranslateDeteleResource(Resource): + @jwt_required() + def delete(self, id): + """删除单个翻译记录[^2]""" + try: + record = Translate.query.get_or_404(id) + db.session.delete(record) + db.session.commit() + return APIResponse.success(message='记录删除成功') + except Exception as e: + db.session.rollback() + return APIResponse.error('删除失败', 500) + + +class AdminTranslateBatchDeleteResource(Resource): + def post(self): + """批量删除翻译记录[^3]""" + try: + ids = validate_id_list(request.json.get('ids')) + if len(ids) > 100: + return APIResponse.error('单次最多删除100条记录', 400) + + Translate.query.filter(Translate.id.in_(ids)).delete() + db.session.commit() + return APIResponse.success(message=f'成功删除{len(ids)}条记录') + except APIResponse as e: + return e + except Exception as e: + db.session.rollback() + return APIResponse.error('批量删除失败', 500) + + +class AdminTranslateRestartResource(Resource): + def post(self, id): + """重启翻译任务[^4]""" + try: + record = Translate.query.get_or_404(id) + if record.status not in ['failed', 'done']: + return APIResponse.error('当前状态无法重启', 400) + + record.status = 'none' + record.start_at = None + record.end_at = None + record.failed_reason = None + db.session.commit() + return APIResponse.success(message='任务已重启') + except Exception as e: + db.session.rollback() + return APIResponse.error('重启失败', 500) + + +class AdminTranslateStatisticsResource(Resource): + def get(self): + """获取翻译统计信息[^5]""" + try: + total = Translate.query.count() + done_count = Translate.query.filter_by(status='done').count() + processing_count = Translate.query.filter_by(status='process').count() + failed_count = Translate.query.filter_by(status='failed').count() + + return APIResponse.success({ + 'total': total, + 'done_count': done_count, + 'processing_count': processing_count, + 'failed_count': failed_count + }) + except Exception as e: + return APIResponse.error('获取统计信息失败', 500) diff --git a/app/resources/admin/users.py b/app/resources/admin/users.py new file mode 100644 index 
0000000000000000000000000000000000000000..b3477ce6ce153aa821661502671e6c5a600852ee --- /dev/null +++ b/app/resources/admin/users.py @@ -0,0 +1,107 @@ +# resources/admin/user.py +from flask import request +from flask_restful import Resource, reqparse +from flask_jwt_extended import jwt_required + +from app import db +from app.models import User +from app.utils.auth_tools import hash_password +from app.utils.response import APIResponse + +class AdminUserListResource(Resource): + @jwt_required() + def get(self): + """获取用户列表[^1]""" + parser = reqparse.RequestParser() + parser.add_argument('page', type=int, default=1) + parser.add_argument('limit', type=int, default=20) + parser.add_argument('search', type=str) + args = parser.parse_args() + + query = User.query + if args['search']: + query = query.filter(User.email.ilike(f"%{args['search']}%")) + + pagination = query.paginate(page=args['page'], per_page=args['limit'], error_out=False) + users = [{ + 'id': u.id, + 'name': u.name, + 'email': u.email, + 'status': 'active' if u.deleted_flag == 'N' else 'deleted' + } for u in pagination.items] + + return APIResponse.success({ + 'data': users, + 'total': pagination.total + }) + + +# 创建新用户 +class AdminCreateUserResource(Resource): + @jwt_required() + def put(self): + """创建新用户[^2]""" + data = request.json + required_fields = ['name', 'email', 'password'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + if User.query.filter_by(email=data['email']).first(): + return APIResponse.error('邮箱已存在', 400) + + user = User( + name=data['name'], + email=data['email'], + password=hash_password(data['password']) + ) + db.session.add(user) + db.session.commit() + return APIResponse.success({ + 'user_id': user.id, + 'message': '用户创建成功' + }) + +# 获取用户详细信息 +class AdminUserDetailResource(Resource): + @jwt_required() + def get(self, id): + """获取用户详细信息[^3]""" + user = User.query.get_or_404(id) + return APIResponse.success({ + 'id': user.id, + 'name': user.name, + 'email': user.email, + 'status': 'active' if user.deleted_flag == 'N' else 'deleted', + 'created_at': user.created_at.isoformat() + }) + +# 编辑用户信息 +class AdminUpdateUserResource(Resource): + @jwt_required() + def post(self, id): + """编辑用户信息[^4]""" + user = User.query.get_or_404(id) + data = request.json + + if 'email' in data and User.query.filter(User.email == data['email'], + User.id != id).first(): + return APIResponse.error('邮箱已被使用', 400) + + if 'name' in data: + user.name = data['name'] + if 'email' in data: + user.email = data['email'] + + db.session.commit() + return APIResponse.success(message='用户信息更新成功') + +# 删除用户 +class AdminDeleteUserResource(Resource): + @jwt_required() + def delete(self, id): + """删除用户[^5]""" + user = User.query.get_or_404(id) + user.deleted_flag = 'Y' + db.session.commit() + return APIResponse.success(message='用户删除成功') + diff --git a/app/resources/api/AccountResource.py b/app/resources/api/AccountResource.py new file mode 100644 index 0000000000000000000000000000000000000000..7805078b23a76c1a8b4668454dd60e008cac95df --- /dev/null +++ b/app/resources/api/AccountResource.py @@ -0,0 +1,143 @@ + +from flask import request, current_app +from flask_restful import Resource +from flask_jwt_extended import jwt_required, get_jwt_identity +from datetime import datetime, timedelta +from app import db +from app.models import Customer, SendCode +from app.utils.security import hash_password, verify_password +from app.utils.response import APIResponse +from app.utils.mail_service import 
EmailService +from app.utils.validators import ( + validate_verification_code, + validate_password_confirmation, + validate_password_complexity +) +import random + + +class ChangePasswordResource(Resource): + @jwt_required() + def post(self): + """修改密码(旧密码验证)[^1]""" + user_id = get_jwt_identity() + data = request.json + + # 参数校验 + required_fields = ['oldpwd', 'newpwd', 'newpwd_confirmation'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + # 密码一致性验证 + is_valid, msg = validate_password_confirmation({ + 'password': data['newpwd'], + 'password_confirmation': data['newpwd_confirmation'] + }) + if not is_valid: + return APIResponse.error(msg, 400) + + # 密码复杂度验证 + is_valid, msg = validate_password_complexity(data['newpwd']) + if not is_valid: + return APIResponse.error(msg, 422) + + customer = Customer.query.get(user_id) + if not verify_password(customer.password, data['oldpwd']): + return APIResponse.error('旧密码不正确', 401) + + customer.password = hash_password(data['newpwd']) + customer.updated_at = datetime.utcnow() + db.session.commit() + return APIResponse.success(message='密码修改成功') + + +class SendChangeCodeResource(Resource): + @jwt_required() + def post(self): + """发送修改密码验证码[^2]""" + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + + code = ''.join(random.choices('0123456789', k=6)) + send_code = SendCode( + send_type=3, # 密码修改验证码类型[^6] + send_to=customer.email, + code=code, + created_at=datetime.utcnow() + ) + db.session.add(send_code) + try: + EmailService.send_verification_code(customer.email, code) + db.session.commit() + return APIResponse.success(message='验证码已发送') + except Exception as e: + db.session.rollback() + return APIResponse.error('邮件发送失败', 500) + + +class EmailChangePasswordResource(Resource): + @jwt_required() + def post(self): + """通过邮箱验证码修改密码[^3]""" + user_id = get_jwt_identity() + data = request.json + + # 参数校验 + required_fields = ['code', 'newpwd', 'newpwd_confirmation'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + # 密码一致性验证 + is_valid, msg = validate_password_confirmation({ + 'password': data['newpwd'], + 'password_confirmation': data['newpwd_confirmation'] + }) + if not is_valid: + return APIResponse.error(msg, 400) + + # 验证码有效性验证 + customer = Customer.query.get(user_id) + is_valid, msg = validate_verification_code( + customer.email, data['code'], 3 + ) + if not is_valid: + return APIResponse.error(msg, 400) + + # 更新密码 + customer.password = hash_password(data['newpwd']) + customer.updated_at = datetime.utcnow() + db.session.commit() + return APIResponse.success(message='密码修改成功') + + +class StorageInfoResource(Resource): + @jwt_required() + def get(self): + """获取存储空间信息[^2]""" + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + + total = current_app.config['MAX_USER_STORAGE'] / (1024 * 1024) # 转换为MB + used = customer.storage / (1024 * 1024) # 转换为MB + percentage = (used / total) * 100 if total > 0 else 0 + + return APIResponse.success({ + 'storage': f"{total:.2f}", + 'used': f"{used:.2f}", + 'percentage': f"{percentage:.1f}" + }) + + +class UserInfoResource(Resource): + @jwt_required() + def get(self): + """获取用户基本信息[^5]""" + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + + return APIResponse.success({ + 'email': customer.email, + 'level': customer.level, + 'created_at': customer.created_at.isoformat(), + 'storage': customer.storage + }) diff --git a/app/resources/api/AuthResource.py 
b/app/resources/api/AuthResource.py new file mode 100644 index 0000000000000000000000000000000000000000..f554f10a53c1b2b159ea1bd34aa4c576a51d031d --- /dev/null +++ b/app/resources/api/AuthResource.py @@ -0,0 +1,142 @@ +# resources/auth.py +from flask import request +from flask_restful import Resource +from flask_jwt_extended import create_access_token +from datetime import datetime, timedelta + +from app import db +from app.models import Customer, SendCode +from app.utils.security import hash_password, verify_password +from app.utils.response import APIResponse +from app.utils.mail_service import EmailService +import random + +from app.utils.validators import ( + validate_verification_code, + validate_password_confirmation +) + + + + +class SendRegisterCodeResource(Resource): + def post(self): + """发送注册验证码接口[^1]""" + email = request.form.get('email') + if Customer.query.filter_by(email=email).first(): + return APIResponse.error('邮箱已存在', 400) + + code = ''.join(random.choices('0123456789', k=6)) + send_code = SendCode( + send_type=1, + send_to=email, + code=code, + created_at=datetime.utcnow() + ) + db.session.add(send_code) + try: + EmailService.send_verification_code(email, code) + db.session.commit() + return APIResponse.success() + except Exception as e: + db.session.rollback() + return APIResponse.error('邮件发送失败', 500) + + +class UserRegisterResource(Resource): + def post(self): + """用户注册接口[^2]""" + data = request.form + + required_fields = ['email', 'password', 'code'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + # 验证码有效性验证 + is_valid, msg = validate_verification_code( + data['email'], data['code'], 1 + ) + if not is_valid: + return APIResponse.error(msg, 400) + + customer = Customer( + email=data['email'], + password=hash_password(data['password']), + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + db.session.add(customer) + db.session.commit() + + # 确保identity是字符串 + # access_token = create_access_token(identity=str(customer.id)) + return APIResponse.success(message='注册成功!',data={ + # 'token': access_token, + 'email': data['email'] + }) + + +class UserLoginResource(Resource): + def post(self): + """用户登录接口[^3]""" + data = request.form + customer = Customer.query.filter_by(email=data['email']).first() + + if not customer or not verify_password(customer.password, data['password']): + return APIResponse.error('账号或密码错误') + # 确保identity是字符串 + access_token = create_access_token(identity=str(customer.id)) + return APIResponse.success({ + 'token': access_token, + 'email': data['email'], + 'level': customer.level + }) + + +class SendResetCodeResource(Resource): + def post(self): + """发送密码重置验证码接口[^4]""" + email = request.form.get('email') + if not Customer.query.filter_by(email=email).first(): + return APIResponse.not_found('用户不存在') + + code = ''.join(random.choices('0123456789', k=6)) + send_code = SendCode( + send_type=2, + send_to=email, + code=code, + created_at=datetime.utcnow() + ) + db.session.add(send_code) + try: + EmailService.send_verification_code(email, code) + db.session.commit() + return APIResponse.success() + except Exception as e: + db.session.rollback() + return APIResponse.error('邮件发送失败', 500) + + +class ResetPasswordResource(Resource): + def post(self): + """重置密码接口[^5]""" + data = request.form + + # 密码一致性验证 + is_valid, msg = validate_password_confirmation(data) + if not is_valid: + return APIResponse.error(msg, 400) + + # 验证码有效性验证 + is_valid, msg = validate_verification_code( + data['email'], 
data['code'], 2 + ) + if not is_valid: + return APIResponse.error(msg, 400) + + customer = Customer.query.filter_by(email=data['email']).first() + customer.password = hash_password(data['password']) + customer.updated_at = datetime.utcnow() + db.session.commit() + return APIResponse.success() + diff --git a/app/resources/api/__init__.py b/app/resources/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/resources/api/common.py b/app/resources/api/common.py new file mode 100644 index 0000000000000000000000000000000000000000..29f9fb6463abb26b9a194c4fbbc5d27e8d7a72df --- /dev/null +++ b/app/resources/api/common.py @@ -0,0 +1,20 @@ +# app/resources/api/common.py +from flask_restful import Resource + +from app.utils.response import APIResponse +from app.models.setting import Setting + + +class SystemConfigResource(Resource): + def get(self): + """获取系统版本配置 [^6]""" + try: + version = Setting.get_version() + return APIResponse.success({"version": version}) + + except Exception as e: + return APIResponse.error( + message="服务器内部错误", + code=500, + errors=str(e) + ) diff --git a/app/resources/api/comparison.py b/app/resources/api/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..e3edfc8d5750415e2f379768da9383037bd911d0 --- /dev/null +++ b/app/resources/api/comparison.py @@ -0,0 +1,449 @@ +# resources/comparison.py +import os +import zipfile +from datetime import datetime +from io import BytesIO + +import pandas as pd +from flask import request, current_app, send_file +from flask_restful import Resource, reqparse +from flask_jwt_extended import jwt_required, get_jwt_identity + +from app import db +from app.models import Customer +from app.models.comparison import Comparison, ComparisonFav +from app.utils.response import APIResponse +from sqlalchemy import func +from datetime import datetime + + +class MyComparisonListResource(Resource): + @jwt_required() + def get(self): + """获取我的术语表列表[^1]""" + # 直接查询所有数据(不解析查询参数) + query = Comparison.query.filter_by(customer_id=get_jwt_identity()) + comparisons = [self._format_comparison(comparison) for comparison in query.all()] + + # 返回结果 + return APIResponse.success({ + 'data': comparisons, + 'total': len(comparisons) + }) + + def _format_comparison(self, comparison): + """格式化术语表数据""" + # 解析 content 字段 + content_list = [] + if comparison.content: + for item in comparison.content.split('; '): + if ':' in item: + origin, target = item.split(':', 1) + content_list.append({ + 'origin': origin.strip(), + 'target': target.strip() + }) + + # 返回格式化后的数据 + return { + 'id': comparison.id, + 'title': comparison.title, + 'origin_lang': comparison.origin_lang, + 'target_lang': comparison.target_lang, + 'share_flag': comparison.share_flag, + 'added_count': comparison.added_count, + 'content': content_list, # 返回解析后的数组 + 'customer_id': comparison.customer_id, + 'created_at': comparison.created_at.strftime('%Y-%m-%d %H:%M') if comparison.created_at else None, # 格式化时间 + 'updated_at': comparison.updated_at.strftime('%Y-%m-%d %H:%M') if comparison.updated_at else None, # 格式化时间 + 'deleted_flag': comparison.deleted_flag + } + + + + +# 获取共享术语表列表 +class SharedComparisonListResource(Resource): + @jwt_required() + def get(self): + """获取共享术语表列表[^3]""" + # 从查询字符串中解析参数 + parser = reqparse.RequestParser() + parser.add_argument('page', type=int, default=1, location='args') # 分页参数 + parser.add_argument('limit', type=int, default=10, location='args') # 分页参数 + 
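        # Annotation (not part of the original diff): together with the `order` flag parsed on the
        # next line (latest / added / fav, matching the sorting branch further below), this endpoint
        # is driven entirely by query-string parameters, e.g. GET <shared-comparison-route>?page=1&limit=10&order=fav
        # — the concrete route is registered in app/routes, which is outside this diff chunk, so the
        # path here is only illustrative. The `content` column handled by parse_content() below holds
        # '; '-joined "origin: target" pairs (the format produced by CreateComparisonResource further
        # down in this file), for example:
        #     "apple: 苹果; bank: 银行"
        # which parse_content() expands to
        #     [{'origin': 'apple', 'target': '苹果'}, {'origin': 'bank', 'target': '银行'}]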
parser.add_argument('order', type=str, default='latest', location='args') # 排序参数 + args = parser.parse_args() + + # 查询共享的术语表,并关联 Customer 表获取用户 email + query = db.session.query( + Comparison, + func.count(ComparisonFav.id).label('fav_count'), # 动态计算收藏量 + Customer.email.label('customer_email') # 获取用户的 email + ).outerjoin( + ComparisonFav, Comparison.id == ComparisonFav.comparison_id + ).outerjoin( + Customer, Comparison.customer_id == Customer.id # 通过 customer_id 关联 Customer + ).filter( + Comparison.share_flag == 'Y', + Comparison.deleted_flag == 'N' + ).group_by( + Comparison.id + ) + + # 根据 order 参数排序 + if args['order'] == 'latest': + query = query.order_by(Comparison.created_at.desc()) # 按最新发表排序 + elif args['order'] == 'added': + query = query.order_by(Comparison.added_count.desc()) # 按添加量排序 + elif args['order'] == 'fav': + query = query.order_by(func.count(ComparisonFav.id).desc()) # 按收藏量排序 + + # 分页查询 + pagination = query.paginate(page=args['page'], per_page=args['limit'], error_out=False) + comparisons = [{ + 'id': comparison.id, + 'title': comparison.title, + 'origin_lang': comparison.origin_lang, + 'target_lang': comparison.target_lang, + 'content': self.parse_content(comparison.content), # 解析 content 字段为数组 + 'email': customer_email if customer_email else '匿名用户', # 返回用户 email + 'added_count': comparison.added_count, + 'created_at': comparison.created_at.strftime('%Y-%m-%d %H:%M'), # 格式化时间 + 'faved': self.check_faved(comparison.id), # 检查是否被当前用户收藏 + 'fav_count': fav_count # 添加收藏量 + } for comparison, fav_count, customer_email in pagination.items] + + # 返回结果 + return APIResponse.success({ + 'data': comparisons, + 'total': pagination.total, + 'current_page': pagination.page, + 'per_page': pagination.per_page + }) + + def parse_content(self, content_str): + """将 content 字符串解析为数组格式""" + content_list = [] + if content_str: + for item in content_str.split('; '): + if ':' in item: + origin, target = item.split(':', 1) # 分割为 origin 和 target + content_list.append({ + 'origin': origin.strip(), + 'target': target.strip() + }) + return content_list + + def check_faved(self, comparison_id): + """检查当前用户是否收藏了该术语表""" + # 假设当前用户的 ID 存储在 JWT 中 + user_id = get_jwt_identity() + if user_id: + fav = ComparisonFav.query.filter_by( + comparison_id=comparison_id, + customer_id=user_id + ).first() + return 1 if fav else 0 + return 0 + + + + +# 编辑术语列表 +class EditComparisonResource(Resource): + @jwt_required() + def post(self, id): + """编辑术语表[^3]""" + comparison = Comparison.query.filter_by( + id=id, + customer_id=get_jwt_identity() + ).first_or_404() + + data = request.form + if 'title' in data: + comparison.title = data['title'] + if 'content' in data: + comparison.content = data['content'] + if 'origin_lang' in data: + comparison.origin_lang = data['origin_lang'] + if 'target_lang' in data: + comparison.target_lang = data['target_lang'] + + db.session.commit() + return APIResponse.success(message='术语表更新成功') + +# 更新术语表共享状态 +class ShareComparisonResource(Resource): + @jwt_required() + def post(self, id): + """修改共享状态[^4]""" + comparison = Comparison.query.filter_by( + id=id, + customer_id=get_jwt_identity() + ).first_or_404() + + data = request.form + if 'share_flag' not in data or data['share_flag'] not in ['Y', 'N']: + return APIResponse.error('share_flag 参数无效', 400) + + comparison.share_flag = data['share_flag'] + db.session.commit() + return APIResponse.success(message='共享状态已更新') + + + +# 复制到我的术语库 +class CopyComparisonResource(Resource): + @jwt_required() + def post(self, id): + """复制到我的术语库[^5]""" + comparison 
= Comparison.query.filter_by( + id=id, + share_flag='Y' + ).first_or_404() + + new_comparison = Comparison( + title=f"{comparison.title} (副本)", + content=comparison.content, + origin_lang=comparison.origin_lang, + target_lang=comparison.target_lang, + customer_id=get_jwt_identity(), + share_flag='N' + ) + db.session.add(new_comparison) + db.session.commit() + return APIResponse.success({ + 'new_id': new_comparison.id + }) + +# 收藏/取消收藏 +class FavoriteComparisonResource(Resource): + @jwt_required() + def post(self, id): + """收藏/取消收藏[^6]""" + comparison = Comparison.query.filter_by(id=id).first_or_404() + customer_id = get_jwt_identity() + + favorite = ComparisonFav.query.filter_by( + comparison_id=id, + customer_id=customer_id + ).first() + + if favorite: + db.session.delete(favorite) + message = '已取消收藏' + else: + new_favorite = ComparisonFav( + comparison_id=id, + customer_id=customer_id + ) + db.session.add(new_favorite) + message = '已收藏' + + db.session.commit() + return APIResponse.success(message=message) + +# 创建新术语表 +class CreateComparisonResource(Resource): + @jwt_required() + def post(self): + """创建新术语表[^1]""" + data = request.form + required_fields = ['title', 'share_flag', 'origin_lang', 'target_lang'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + # 解析 content 参数 + content_list = [] + for key, value in data.items(): + if key.startswith('content[') and '][origin]' in key: + # 提取索引 + index = key.split('[')[1].split(']')[0] + origin = value + target = data.get(f'content[{index}][target]', '') + content_list.append(f"{origin}: {target}") + + # 将 content_list 转换为字符串 + content_str = '; '.join(content_list) + + # 获取当前时间 + current_time = datetime.utcnow() + + # 创建术语表 + comparison = Comparison( + title=data['title'], + origin_lang=data['origin_lang'], + target_lang=data['target_lang'], + content=content_str, # 插入转换后的 content 字符串 + customer_id=get_jwt_identity(), + share_flag=data.get('share_flag', 'N'), + created_at=current_time, # 显式赋值 + updated_at=current_time # 显式赋值 + ) + db.session.add(comparison) + db.session.commit() + return APIResponse.success({ + 'id': comparison.id + }) + + +# 删除术语表 +class DeleteComparisonResource(Resource): + @jwt_required() + def delete(self, id): + """删除术语表[^2]""" + comparison = Comparison.query.filter_by( + id=id, + customer_id=get_jwt_identity() + ).first_or_404() + + db.session.delete(comparison) + db.session.commit() + return APIResponse.success(message='删除成功') + + +# 下载模板文件 +class DownloadTemplateResource(Resource): + def get(self): + """下载模板文件[^3]""" + from flask import send_file + from io import BytesIO + import pandas as pd + + # 创建模板文件 + df = pd.DataFrame(columns=['源术语', '目标术语']) + output = BytesIO() + with pd.ExcelWriter(output, engine='xlsxwriter') as writer: + df.to_excel(writer, index=False) + output.seek(0) + + return send_file( + output, + mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + as_attachment=True, + download_name='术语表模板.xlsx' + ) + +# 导入术语表 +class ImportComparisonResource(Resource): + @jwt_required() + def post(self): + """ + 导入 Excel 文件 + """ + # 检查是否上传了文件 + if 'file' not in request.files: + return APIResponse.error('未选择文件', 400) + file = request.files['file'] + + try: + # 读取 Excel 文件 + import pandas as pd + df = pd.read_excel(file) + + # 检查文件是否包含所需的列 + if not {'源术语', '目标术语'}.issubset(df.columns): + return APIResponse.error('文件格式不符合模板要求', 406) + # 解析 Excel 文件内容 + content = ';'.join([f"{row['源术语']}: {row['目标术语']}" for _, row in df.iterrows()]) # 按 ': 
' 分隔 + # 创建术语表 + comparison = Comparison( + title='导入的术语表', + origin_lang='未知', + target_lang='未知', + content=content, # 使用改进后的格式 + customer_id=get_jwt_identity(), + share_flag='N' + ) + db.session.add(comparison) + db.session.commit() + + # 返回成功响应 + return APIResponse.success({ + 'id': comparison.id + }) + except Exception as e: + # 捕获并返回错误信息 + return APIResponse.error(f"文件导入失败:{str(e)}", 500) + + + +# 导出单个术语表 +class ExportComparisonResource(Resource): + @jwt_required() + def get(self, id): + """ + 导出单个术语表 + """ + # 获取当前用户 ID + current_user_id = get_jwt_identity() + + # 查询术语表 + comparison = Comparison.query.get_or_404(id) + + # 检查术语表是否共享或属于当前用户 + if comparison.share_flag != 'Y' and comparison.user_id != current_user_id: + return {'message': '术语表未共享或无权限访问', 'code': 403}, 403 + + # 解析术语内容 + terms = [term.split(': ') for term in comparison.content.split(';')] # 按 ': ' 分割 + df = pd.DataFrame(terms, columns=['源术语', '目标术语']) + + # 创建 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine='xlsxwriter') as writer: + df.to_excel(writer, index=False) + output.seek(0) + + # 返回文件下载响应 + return send_file( + output, + mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + as_attachment=True, + download_name=f'{comparison.title}.xlsx' + ) + + + + +# 批量导出所有术语表 +class ExportAllComparisonsResource(Resource): + @jwt_required() + def get(self): + """ + 批量导出所有术语表 + """ + # 获取当前用户 ID + current_user_id = get_jwt_identity() + + # 查询当前用户的所有术语表 + comparisons = Comparison.query.filter_by(customer_id=current_user_id).all() + + # 创建 ZIP 文件 + memory_file = BytesIO() + with zipfile.ZipFile(memory_file, 'w') as zf: + for comparison in comparisons: + # 解析术语内容 + terms = [term.split(': ') for term in comparison.content.split(';')] # 按 ': ' 分割 + df = pd.DataFrame(terms, columns=['源术语', '目标术语']) + + # 创建 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine='xlsxwriter') as writer: + df.to_excel(writer, index=False) + output.seek(0) + + # 将 Excel 文件添加到 ZIP 中 + zf.writestr(f"{comparison.title}.xlsx", output.getvalue()) + + memory_file.seek(0) + + # 返回 ZIP 文件下载响应 + return send_file( + memory_file, + mimetype='application/zip', + as_attachment=True, + download_name=f'术语表_{datetime.now().strftime("%Y%m%d")}.zip' + ) + + diff --git a/app/resources/api/customer.py b/app/resources/api/customer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ecfc3c88d0e1ac67ba5f581db9f68740f15abf9 --- /dev/null +++ b/app/resources/api/customer.py @@ -0,0 +1,32 @@ +# resources/customer.py +from app.utils.response import APIResponse +import uuid +from flask_restful import Resource, reqparse +from flask_jwt_extended import jwt_required +from app.models.customer import Customer +class GuestIdResource(Resource): + def get(self): + """生成临时访客唯一标识[^1]""" + guest_id = str(uuid.uuid4()) + return APIResponse.success({ + 'guest_id': guest_id + }) + + + + + +class CustomerDetailResource(Resource): + @jwt_required() + def get(self, customer_id): + """获取客户详细信息[^2]""" + customer = Customer.query.get_or_404(customer_id) + return APIResponse.success({ + 'id': customer.id, + 'email': customer.email, + 'level': customer.level, + 'created_at': customer.created_at.isoformat(), + 'storage': customer.storage + }) + + diff --git a/app/resources/api/files.py b/app/resources/api/files.py new file mode 100644 index 0000000000000000000000000000000000000000..aaaecf2a787f9e060007566a0d1a8c2e7c8d574e --- /dev/null +++ b/app/resources/api/files.py @@ -0,0 +1,339 @@ +# resources/file.py +import hashlib +import 
uuid +from werkzeug.utils import secure_filename +import os +from app import db +from app.models.customer import Customer +from app.models.translate import Translate +from app.utils.response import APIResponse +from pathlib import Path +from flask_restful import Resource +from flask_jwt_extended import jwt_required, get_jwt_identity +from flask import request, current_app +from datetime import datetime + + +class FileUploadResource1(Resource): + @jwt_required() + def post(self): + """文件上传接口""" + # 验证文件存在性 + if 'file' not in request.files: + return APIResponse.error('未选择文件', 400) + file = request.files['file'] + + # 验证文件名有效性 + if file.filename == '': + return APIResponse.error('无效文件名', 400) + + # 验证文件类型 + if not self.allowed_file(file.filename): + return APIResponse.error( + f"仅支持以下格式:{', '.join(current_app.config['ALLOWED_EXTENSIONS'])}", 400) + + # 验证文件大小 + if not self.validate_file_size(file.stream): + return APIResponse.error( + f"文件大小超过{current_app.config['MAX_FILE_SIZE'] // (1024 * 1024)}MB限制", 400) + + # 获取用户存储信息 + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + file_size = request.content_length # 使用实际内容长度 + + # 验证存储空间 + if customer.storage + file_size > current_app.config['MAX_USER_STORAGE']: + return APIResponse.error('存储空间不足', 403) + + try: + # 生成存储路径 + save_dir = self.get_upload_dir() + filename = file.filename # 直接使用原始文件名 + save_path = os.path.join(save_dir, filename) + + # 检查路径是否安全 + if not self.is_safe_path(save_dir, save_path): + return APIResponse.error('文件名包含非法字符', 400) + + # 保存文件 + file.save(save_path) + # 更新用户存储空间 + customer.storage += file_size + db.session.commit() + # 生成 UUID + file_uuid = str(uuid.uuid4()) + # 计算文件的 MD5 + file_md5 = self.calculate_md5(save_path) + + # 创建翻译记录 + translate_record = Translate( + translate_no=f"TRANS{datetime.now().strftime('%Y%m%d%H%M%S')}", + uuid=file_uuid, + customer_id=user_id, + origin_filename=filename, + origin_filepath=os.path.abspath(save_path), # 使用绝对路径 + target_filepath='', # 目标文件路径暂为空 + status='none', # 初始状态为 none + origin_filesize=file_size, + md5=file_md5, + created_at=datetime.utcnow() + ) + db.session.add(translate_record) + db.session.commit() + + # 返回响应,包含文件名、UUID 和翻译记录 ID + return APIResponse.success({ + 'filename': filename, + 'uuid': file_uuid, + 'translate_id': translate_record.id, + 'save_path': os.path.abspath(save_path) # 返回绝对路径 + }) + + except Exception as e: + db.session.rollback() + current_app.logger.error(f"文件上传失败:{str(e)}") + return APIResponse.error('文件上传失败', 500) + + @staticmethod + def allowed_file(filename): + # """验证文件类型是否允许"""# 暂不支持PDF 'pdf', + ALLOWED_EXTENSIONS = {'docx', 'xlsx', 'pptx', 'txt', 'md', 'csv', 'xls', 'doc'} + return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + @staticmethod + def validate_file_size(file_stream): + """验证文件大小是否超过限制""" + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + file_stream.seek(0, os.SEEK_END) + file_size = file_stream.tell() + file_stream.seek(0) + return file_size <= MAX_FILE_SIZE + + @staticmethod + def get_upload_dir(): + """获取按日期分类的上传目录""" + # 获取上传根目录 + base_dir = Path(current_app.config['UPLOAD_BASE_DIR']) + upload_dir = base_dir / 'uploads' / datetime.now().strftime('%Y-%m-%d') + + # 如果目录不存在则创建 + if not upload_dir.exists(): + upload_dir.mkdir(parents=True, exist_ok=True) + + return str(upload_dir) + + @staticmethod + def calculate_md5(file_path): + """计算文件的 MD5 值""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + @staticmethod + def is_safe_path(base_dir, file_path): + """检查文件路径是否安全,防止路径遍历攻击""" + base_dir = Path(base_dir).resolve() + file_path = Path(file_path).resolve() + return file_path.is_relative_to(base_dir) + + + +class FileUploadResource(Resource): + @jwt_required() + def post(self): + """文件上传接口""" + # 验证文件存在性 + if 'file' not in request.files: + return APIResponse.error('未选择文件', 400) + file = request.files['file'] + + # 验证文件名有效性 + if file.filename == '': + return APIResponse.error('无效文件名', 400) + + # 验证文件类型 + if not self.allowed_file(file.filename): + return APIResponse.error( + f"仅支持以下格式:{', '.join(current_app.config['ALLOWED_EXTENSIONS'])}", 400) + + # 验证文件大小 + if not self.validate_file_size(file.stream): + return APIResponse.error( + f"文件大小超过{current_app.config['MAX_FILE_SIZE'] // (1024 * 1024)}MB限制", 400) + + # 获取用户存储信息 + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + file_size = request.content_length # 使用实际内容长度 + + # 验证存储空间 + if customer.storage + file_size > current_app.config['MAX_USER_STORAGE']: + return APIResponse.error('存储空间不足', 403) + + try: + # 生成存储路径 + save_dir = Path(self.get_upload_dir()) + filename = file.filename # 直接使用原始文件名 + save_path = save_dir / filename + + # 检查路径是否安全 + if not self.is_safe_path(save_dir, save_path): + return APIResponse.error('文件名包含非法字符', 400) + + # 保存文件 + file.save(save_path) + # 更新用户存储空间 + customer.storage += file_size + db.session.commit() + # 生成 UUID + file_uuid = str(uuid.uuid4()) + # 计算文件的 MD5 + file_md5 = self.calculate_md5(save_path) + + # 创建翻译记录 + translate_record = Translate( + translate_no=f"TRANS{datetime.now().strftime('%Y%m%d%H%M%S')}", + uuid=file_uuid, + customer_id=user_id, + origin_filename=filename, + origin_filepath=str(save_path.resolve()), # 使用绝对路径 + target_filepath='', # 目标文件路径暂为空 + status='none', # 初始状态为 none + origin_filesize=file_size, + md5=file_md5, + created_at=datetime.utcnow() + ) + db.session.add(translate_record) + db.session.commit() + + # 返回响应,包含文件名、UUID 和翻译记录 ID + return APIResponse.success({ + 'filename': filename, + 'uuid': file_uuid, + 'translate_id': translate_record.id, + 'save_path': str(save_path.resolve()) # 返回绝对路径 + }) + + except Exception as e: + db.session.rollback() + current_app.logger.error(f"文件上传失败:{str(e)}") + return APIResponse.error('文件上传失败', 500) + + @staticmethod + def allowed_file(filename): + """验证文件类型是否允许""" + ALLOWED_EXTENSIONS = {'docx', 'xlsx', 'pptx', 'txt', 'md', 'csv', 'xls', 'doc'} + return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + @staticmethod + def validate_file_size(file_stream): + """验证文件大小是否超过限制""" + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + file_stream.seek(0, os.SEEK_END) + file_size = file_stream.tell() + file_stream.seek(0) + return file_size <= MAX_FILE_SIZE + + @staticmethod + def get_upload_dir(): + """获取按日期分类的上传目录""" + # 获取上传根目录 + base_dir = Path(current_app.config['UPLOAD_BASE_DIR']) + upload_dir = base_dir / 'uploads' / datetime.now().strftime('%Y-%m-%d') + + # 如果目录不存在则创建 + if not upload_dir.exists(): + upload_dir.mkdir(parents=True, exist_ok=True) + + return str(upload_dir) + + @staticmethod + def calculate_md5(file_path): + """计算文件的 MD5 值""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + @staticmethod + def is_safe_path(base_dir, file_path): + """检查文件路径是否安全,防止路径遍历攻击""" + base_dir = Path(base_dir).resolve() + file_path = Path(file_path).resolve() + return file_path.is_relative_to(base_dir) + + + +class FileDeleteResource1(Resource): + @jwt_required() + def post(self): + """文件删除接口[^1]""" + data = request.form + if 'uuid' not in data: + return APIResponse.error('缺少必要参数', 400) + + try: + # 根据 UUID 查询翻译记录 + translate_record = Translate.query.filter_by(uuid=data['uuid']).first() + if not translate_record: + return APIResponse.error('文件记录不存在', 404) + + # 获取文件绝对路径 + file_path = translate_record.origin_filepath + + # 删除物理文件 + if os.path.exists(file_path): + os.remove(file_path) + else: + current_app.logger.warning(f"文件不存在:{file_path}") + + # 删除数据库记录 + db.session.delete(translate_record) + db.session.commit() + + return APIResponse.success(message='文件删除成功') + + except Exception as e: + db.session.rollback() + current_app.logger.error(f"文件删除失败:{str(e)}") + return APIResponse.error('文件删除失败', 500) + + + +class FileDeleteResource(Resource): + @jwt_required() + def post(self): + """文件删除接口""" + data = request.form + if 'uuid' not in data: + return APIResponse.error('缺少必要参数', 400) + + try: + # 根据 UUID 查询翻译记录 + translate_record = Translate.query.filter_by(uuid=data['uuid']).first() + if not translate_record: + return APIResponse.error('文件记录不存在', 404) + + # 获取文件绝对路径 + file_path = Path(translate_record.origin_filepath) + + # 删除物理文件 + if file_path.exists(): + file_path.unlink() + else: + current_app.logger.warning(f"文件不存在:{file_path}") + + # 删除数据库记录 + db.session.delete(translate_record) + db.session.commit() + + return APIResponse.success(message='文件删除成功') + + except Exception as e: + db.session.rollback() + current_app.logger.error(f"文件删除失败:{str(e)}") + return APIResponse.error('文件删除失败', 500) diff --git a/app/resources/api/prompt.py b/app/resources/api/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2c583aa54fb2478b091dcd1b14f88284f0d9a3 --- /dev/null +++ b/app/resources/api/prompt.py @@ -0,0 +1,245 @@ +# resources/prompt.py +from datetime import datetime, date + +from flask import request +from flask_restful import Resource, reqparse +from flask_jwt_extended import jwt_required, get_jwt_identity +from sqlalchemy import func +from app import db +from app.models import Customer +from app.models.prompt import Prompt, PromptFav +from app.utils.response import APIResponse + +# 获取提示语列表 +class MyPromptListResource(Resource): + @jwt_required() + def get(self): + """获取我的提示语列表[^1]""" + # 直接查询所有数据(不解析查询参数) + query = Prompt.query.filter_by(customer_id=get_jwt_identity(), deleted_flag='N') + prompts = 
[{ + 'id': p.id, + 'title': p.title, + 'content': p.content[:100] + '...' if len(p.content) > 100 else p.content, + 'share_flag': p.share_flag, + 'created_at': p.created_at.isoformat() if p.created_at else None + } for p in query.all()] + + # 返回结果 + return APIResponse.success({ + 'data': prompts, + 'total': len(prompts) + }) + + + +# 获取共享提示语列表 +class SharedPromptListResource(Resource): + def get(self): + """获取共享提示语列表[^4]""" + # 从查询字符串中解析参数 + parser = reqparse.RequestParser() + parser.add_argument('page', type=int, default=1, location='args') # 分页参数 + parser.add_argument('limit', type=int, default=10, location='args') # 分页参数 + parser.add_argument('porder', type=str, default='latest', location='args') # 排序参数 + args = parser.parse_args() + + # 查询共享的提示语 + query = db.session.query( + Prompt, # 获取完整的 Prompt 信息 + func.count(PromptFav.id).label('fav_count'), # 动态计算收藏量 + Customer.email.label('customer_email') # 获取用户的 email + ).outerjoin( + PromptFav, Prompt.id == PromptFav.prompt_id + ).outerjoin( + Customer, Prompt.customer_id == Customer.id # 通过 customer_id 关联 Customer + ).filter( + Prompt.share_flag == 'Y', + Prompt.deleted_flag == 'N' + ).group_by( + Prompt.id + ) + + # 根据 porder 参数排序 + if args['porder'] == 'latest': + query = query.order_by(Prompt.created_at.desc()) # 按最新发表排序 + elif args['porder'] == 'added': + query = query.order_by(Prompt.added_count.desc()) # 按添加量排序 + elif args['porder'] == 'fav': + query = query.order_by(func.count(PromptFav.id).desc()) # 按收藏量排序 + + # 分页查询 + pagination = query.paginate(page=args['page'], per_page=args['limit'], error_out=False) + prompts = [{ + 'id': prompt.id, + 'title': prompt.title, + 'content': prompt.content, # 返回完整的提示语内容 + 'email': customer_email if customer_email else '匿名用户', # 使用查询结果中的 email + 'share_flag': prompt.share_flag, + 'added_count': prompt.added_count, + 'created_at': prompt.created_at.strftime('%Y-%m-%d') if prompt.created_at else None, # 处理 None 值 + 'updated_at': prompt.updated_at.strftime('%Y-%m-%d') if prompt.updated_at else None, # 处理 None 值 + 'fav_count': fav_count + } for prompt, fav_count, customer_email in pagination.items] + + # 返回结果 + return APIResponse.success({ + 'data': prompts, + 'total': pagination.total + }) + + + + +# 修改提示语内容 +class EditPromptResource(Resource): + @jwt_required() + def post(self, id): + """修改提示语内容[^3]""" + prompt = Prompt.query.filter_by( + id=id, + customer_id=get_jwt_identity(), + deleted_flag='N' + ).first_or_404() + + data = request.form + if 'title' in data: + if len(data['title']) > 255: + return APIResponse.error('标题过长', 400) + prompt.title = data['title'] + + if 'content' in data: + if len(data['content']) > 5000: + return APIResponse.error('内容超过5000字符限制', 400) + prompt.content = data['content'] + + db.session.commit() + return APIResponse.success(message='提示语更新成功') + +# 更新共享状态 +class SharePromptResource(Resource): + @jwt_required() + def post(self, id): + """ + 修改共享状态[^4] + :param id: prompt 的 ID(路径参数) + """ + # 根据 id 和当前用户查询 prompt + prompt = Prompt.query.filter_by( + id=id, + customer_id=get_jwt_identity(), + deleted_flag='N' + ).first_or_404() + + # 从请求体中获取 share_flag + data = request.form # 或者 request.form 如果是表单数据 + if not data or 'share_flag' not in data or data['share_flag'] not in ['Y', 'N']: + return APIResponse.error('无效的共享状态参数', 400) + + # 更新共享状态 + prompt.share_flag = data['share_flag'] + db.session.commit() + + return APIResponse.success(message='共享状态已更新') + + +# 复制到我的提示语库 +class CopyPromptResource(Resource): + @jwt_required() + def post(self, id): + """复制到我的提示语库[^5]""" + original = 
Prompt.query.filter_by( + id=id, + share_flag='Y', + deleted_flag='N' + ).first_or_404() + + new_prompt = Prompt( + title=f"{original.title} (副本)", + content=original.content, + customer_id=get_jwt_identity(), + share_flag='N', + added_count=0 + ) + db.session.add(new_prompt) + db.session.commit() + return APIResponse.success({ + 'new_id': new_prompt.id, + 'message': '复制成功' + }) + +# 收藏/取消收藏 +class FavoritePromptResource(Resource): + @jwt_required() + def post(self, id): + """收藏/取消收藏[^6]""" + prompt = Prompt.query.get_or_404(id) + customer_id = get_jwt_identity() + + fav = PromptFav.query.filter_by( + prompt_id=id, + customer_id=customer_id + ).first() + + if fav: + db.session.delete(fav) + action = '取消收藏' + else: + new_fav = PromptFav( + prompt_id=id, + customer_id=customer_id + ) + db.session.add(new_fav) + action = '收藏' + + prompt.added_count = prompt.added_count + (1 if not fav else -1) + db.session.commit() + return APIResponse.success(message=f'{action}成功') + +# 创建新的提示语 + +class CreatePromptResource(Resource): + @jwt_required() + def post(self): + """创建新提示语[^7]""" + data = request.form + required_fields = ['title', 'content'] + if not all(field in data for field in required_fields): + return APIResponse.error('缺少必要参数', 400) + + if len(data['title']) > 255: + return APIResponse.error('标题过长', 400) + if len(data['content']) > 5000: + return APIResponse.error('内容超过5000字符限制', 400) + + # 创建时自动设置 created_at 为当前时间 + prompt = Prompt( + title=data['title'], + content=data['content'], + customer_id=get_jwt_identity(), + share_flag=data.get('share_flag', 'N'), + created_at=date.today() # 自动设置当前时间 + ) + db.session.add(prompt) + db.session.commit() + return APIResponse.success({ + 'id': prompt.id, + 'message': '创建成功' + }) + + +# 删除提示语 +class DeletePromptResource(Resource): + @jwt_required() + def delete(self, id): + """删除提示语[^8]""" + prompt = Prompt.query.filter_by( + id=id, + customer_id=get_jwt_identity() + ).first_or_404() + + prompt.deleted_flag = 'Y' + db.session.commit() + return APIResponse.success(message='删除成功') + + diff --git a/app/resources/api/setting.py b/app/resources/api/setting.py new file mode 100644 index 0000000000000000000000000000000000000000..16f4dfba8fb18ba3728761b7a52066def7086985 --- /dev/null +++ b/app/resources/api/setting.py @@ -0,0 +1,27 @@ +# resources/system.py +from flask_restful import Resource +from app.utils.response import APIResponse +from flask import current_app + +class SystemVersionResource(Resource): + def get(self): + """获取系统版本信息[^1]""" + return APIResponse.success({ + 'version': current_app.config['SYSTEM_VERSION'], + 'message': 'success' + }) + +class SystemSettingsResource(Resource): + def get(self): + """获取全量系统配置[^2]""" + return APIResponse.success({ + 'site_setting': { + 'version': current_app.config['SYSTEM_VERSION'], + 'site_name': current_app.config['SITE_NAME'] + }, + 'api_setting': { + 'api_url': current_app.config['API_URL'], + 'models': current_app.config['TRANSLATE_MODELS'] + }, + 'message': 'success' + }) diff --git a/app/resources/api/translate.py b/app/resources/api/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..392073eeabf9008266c9b218f07c55ae2208903d --- /dev/null +++ b/app/resources/api/translate.py @@ -0,0 +1,590 @@ +# resources/to_translate.py +import json +from pathlib import Path +from flask import request, send_file, current_app, make_response +from flask_restful import Resource +from flask_jwt_extended import jwt_required, get_jwt_identity +from datetime import datetime +from io import BytesIO 
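# --- Annotation (not part of the original diff): minimal client sketch ---------------------
# The resources defined below implement the translation workflow: start a task for a
# previously uploaded file (identified by its uuid), then poll its progress. The URL paths
# used here are assumptions — the real routes are registered in app/routes, outside this
# diff chunk — and `requests` is only needed for this illustration.
import requests


def start_and_poll_translation(base_url, token, file_uuid, file_name):
    """Hedged example of calling TranslateStartResource and TranslateProcessResource."""
    headers = {'Authorization': f'Bearer {token}'}
    # These form fields mirror required_fields in TranslateStartResource below; backup_model
    # is also read unconditionally there, so it is included as well.
    start = requests.post(f'{base_url}/api/translate/start', headers=headers, data={
        'server': 'openai',
        'model': 'gpt-3.5-turbo',
        'backup_model': 'gpt-3.5-turbo',
        'lang': 'zh',
        'uuid': file_uuid,
        'prompt': '请将以下内容翻译为{target_lang}',
        'threads': 5,
        'file_name': file_name,
        'api_url': 'https://api.openai.com/v1',
        'api_key': 'sk-xxx',
    }, timeout=30)
    start.raise_for_status()
    # Progress is polled by uuid; TranslateProcessResource returns status, progress and,
    # once the task is done, a download_url.
    progress = requests.post(f'{base_url}/api/translate/process', headers=headers,
                             data={'uuid': file_uuid}, timeout=30)
    progress.raise_for_status()
    return progress.json()
# --- end of annotation ----------------------------------------------------------------------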
+import zipfile +import os + +from app import db, Setting +from app.models import Customer +from app.models.translate import Translate +from app.resources.task.translate_service import TranslateEngine +from app.utils.response import APIResponse +from app.utils.check_utils import AIChecker + +# 定义翻译配置(硬编码示例) +TRANSLATE_SETTINGS = { + "models": ["gpt-3.5-turbo", "gpt-4"], + "default_model": "gpt-3.5-turbo", + "max_threads": 5, + "prompt_template": "请将以下内容翻译为{target_lang}" +} + + +class TranslateStartResource1(Resource): + @jwt_required() + def post(self): + """启动翻译任务(支持绝对路径和多参数)[^1]""" + data = request.form + required_fields = [ + 'server', 'model', 'lang', 'uuid', + 'prompt', 'threads', 'file_name' + ] + + # 参数校验 + if not all(field in data for field in required_fields): + return APIResponse.error("缺少必要参数", 400) + + # 验证OpenAI配置 + if data['server'] == 'openai' and not all(k in data for k in ['api_url', 'api_key']): + return APIResponse.error("OpenAI服务需要API地址和密钥", 400) + + try: + # 获取用户信息 + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + + # 生成绝对路径(跨平台兼容) + def get_absolute_storage_path(filename): + # 获取项目根目录的父目录(假设storage目录与项目目录同级) + base_dir = Path(current_app.root_path).parent.absolute() + # 按日期创建子目录(如 storage/translate/2024-01-20) + date_str = datetime.now().strftime('%Y-%m-%d') + # 创建目标目录(如果不存在) + target_dir = base_dir / "storage" / "translate" / date_str + target_dir.mkdir(parents=True, exist_ok=True) + # 返回绝对路径(保持原文件名) + return str(target_dir / filename) + + origin_filename = data['file_name'] + + # 生成翻译结果绝对路径 + target_abs_path = get_absolute_storage_path(origin_filename) + + # 获取翻译类型(取最后一个type值) + translate_type = data.get('type[2]', 'trans_all_only_inherit') + + # 查询或创建翻译记录 + translate = Translate.query.filter_by(uuid=data['uuid']).first() + if not translate: + return APIResponse.error("未找到对应的翻译记录", 404) + + # 更新翻译记录 + translate.origin_filename = data['file_name'] + translate.target_filepath = target_abs_path # 存储翻译结果的绝对路径 + translate.lang = data['lang'] + translate.model = data['model'] + translate.backup_model = data['backup_model'] + translate.type = translate_type + translate.prompt = data['prompt'] + translate.threads = int(data['threads']) + translate.api_url = data.get('api_url', '') + translate.api_key = data.get('api_key', '') + translate.backup_model = data.get('backup_model', '') + translate.origin_lang = data.get('origin_lang', '') + # 获取 comparison_id 并转换为整数 + comparison_id = data.get('comparison_id', '0') # 默认值为 '0' + translate.comparison_id = int(comparison_id) if comparison_id else None + prompt_id = data.get('prompt_id', '0') + translate.prompt_id = int(prompt_id) if prompt_id else None + translate.doc2x_flag = data.get('doc2x_flag', 'N') + translate.doc2x_secret_key = data.get('doc2x_secret_key', '') + + # 保存到数据库 + db.session.commit() + # with current_app.app_context(): # 确保在应用上下文中运行 + # 启动翻译引擎,传入 current_app + TranslateEngine(translate.id).execute() + + return APIResponse.success({ + "task_id": translate.id, + "uuid": translate.uuid, + "target_path": target_abs_path + }) + + except Exception as e: + db.session.rollback() + current_app.logger.error(f"翻译任务启动失败: {str(e)}", exc_info=True) + return APIResponse.error("任务启动失败", 500) + + + +class TranslateStartResource(Resource): + @jwt_required() + def post(self): + """启动翻译任务(支持绝对路径和多参数)[^1]""" + data = request.form + required_fields = [ + 'server', 'model', 'lang', 'uuid', + 'prompt', 'threads', 'file_name' + ] + + # 参数校验 + if not all(field in data for field in required_fields): + return 
APIResponse.error("缺少必要参数", 400) + + # 验证OpenAI配置 + if data['server'] == 'openai' and not all(k in data for k in ['api_url', 'api_key']): + return APIResponse.error("OpenAI服务需要API地址和密钥", 400) + + try: + # 获取用户信息 + user_id = get_jwt_identity() + customer = Customer.query.get(user_id) + + # 生成绝对路径(跨平台兼容) + def get_absolute_storage_path(filename): + # 获取项目根目录的父目录(假设storage目录与项目目录同级) + base_dir = Path(current_app.root_path).parent.absolute() + # 按日期创建子目录(如 storage/translate/2024-01-20) + date_str = datetime.now().strftime('%Y-%m-%d') + # 创建目标目录(如果不存在) + target_dir = base_dir / "storage" / "translate" / date_str + target_dir.mkdir(parents=True, exist_ok=True) + # 返回绝对路径(保持原文件名) + return target_dir / filename + + origin_filename = data['file_name'] + + # 生成翻译结果绝对路径 + target_abs_path = get_absolute_storage_path(origin_filename) + + # 获取翻译类型(取最后一个type值) + translate_type = data.get('type[2]', 'trans_all_only_inherit') + + # 查询或创建翻译记录 + translate = Translate.query.filter_by(uuid=data['uuid']).first() + if not translate: + return APIResponse.error("未找到对应的翻译记录", 404) + + # 更新翻译记录 + translate.origin_filename = origin_filename + translate.target_filepath = str(target_abs_path) # 存储翻译结果的绝对路径 + translate.lang = data['lang'] + translate.model = data['model'] + translate.backup_model = data['backup_model'] + translate.type = translate_type + translate.prompt = data['prompt'] + translate.threads = int(data['threads']) + translate.api_url = data.get('api_url', '') + translate.api_key = data.get('api_key', '') + translate.backup_model = data.get('backup_model', '') + translate.origin_lang = data.get('origin_lang', '') + # 获取 comparison_id 并转换为整数 + comparison_id = data.get('comparison_id', '0') # 默认值为 '0' + translate.comparison_id = int(comparison_id) if comparison_id else None + prompt_id = data.get('prompt_id', '0') + translate.prompt_id = int(prompt_id) if prompt_id else None + translate.doc2x_flag = data.get('doc2x_flag', 'N') + translate.doc2x_secret_key = data.get('doc2x_secret_key', '') + + # 保存到数据库 + db.session.commit() + # 启动翻译引擎,传入 current_app + TranslateEngine(translate.id).execute() + + return APIResponse.success({ + "task_id": translate.id, + "uuid": translate.uuid, + "target_path": str(target_abs_path) + }) + + except Exception as e: + db.session.rollback() + current_app.logger.error(f"翻译任务启动失败: {str(e)}", exc_info=True) + return APIResponse.error("任务启动失败", 500) + + + +class TranslateListResource(Resource): + @jwt_required() + def get(self): + """获取翻译记录列表""" + # 获取查询参数 + page = request.args.get('page', '1') + limit = request.args.get('limit', '100') + status_filter = request.args.get('status') + # 将字符串参数转换为整数 + try: + page = int(page) + limit = int(limit) + except ValueError: + return APIResponse.error("Invalid page or limit value"), 400 + # 构建查询条件 + query = Translate.query.filter_by( + customer_id=get_jwt_identity(), + deleted_flag='N' + ) + + # 检查 status_filter 是否是合法值 + if status_filter: + valid_statuses = {'none', 'process', 'done', 'failed'} + if status_filter not in valid_statuses: + return APIResponse.error(f"Invalid status value: {status_filter}"), 400 + query = query.filter_by(status=status_filter) + + # 执行分页查询 + pagination = query.paginate(page=page, per_page=limit, error_out=False) + + # 处理每条记录 + data = [] + for t in pagination.items: + # 计算花费时间(基于 created_at 和 end_at) + if t.created_at and t.end_at: + spend_time = t.end_at - t.created_at + spend_time_minutes = int(spend_time.total_seconds() // 60) + spend_time_seconds = int(spend_time.total_seconds() % 60) + spend_time_str = 
f"{spend_time_minutes}分{spend_time_seconds}秒" + else: + spend_time_str = "--" + + # 获取状态中文描述 + status_name_map = { + 'none': '未开始', + 'process': '进行中', + 'done': '已完成', + 'failed': '失败' + } + status_name = status_name_map.get(t.status, '未知状态') + + # 获取文件类型 + file_type = self.get_file_type(t.origin_filename) + + # 格式化完成时间(精确到秒) + end_at_str = t.end_at.strftime('%Y-%m-%d %H:%M:%S') if t.end_at else "--" + + data.append({ + 'id': t.id, + 'file_type': file_type, + 'origin_filename': t.origin_filename, + 'status': t.status, + 'status_name': status_name, + 'process': float(t.process), # 将 Decimal 转换为 float + 'spend_time': spend_time_str, # 花费时间 + 'end_at': end_at_str, # 完成时间 + 'start_at': t.start_at.strftime('%Y-%m-%d %H:%M:%S') if t.start_at else "--", + # 开始时间 + 'lang': t.lang, + 'target_filepath': t.target_filepath + }) + + # 返回响应数据 + return APIResponse.success({ + 'data': data, + 'total': pagination.total, + 'current_page': pagination.page + }) + + @staticmethod + def get_file_type(filename): + """根据文件名获取文件类型""" + if not filename: + return "未知" + ext = filename.split('.')[-1].lower() + if ext in {'docx', 'doc'}: + return "Word" + elif ext in {'xlsx', 'xls'}: + return "Excel" + elif ext == 'pptx': + return "PPT" + elif ext == 'pdf': + return "PDF" + elif ext in {'txt', 'md'}: + return "文本" + else: + return "其他" + +# 获取翻译配置 +class TranslateSettingResource(Resource): + @jwt_required() + def get(self): + """获取翻译配置""" + try: + # 从数据库中获取翻译配置 + settings = self._load_settings_from_db() + return APIResponse.success(settings) + except Exception as e: + return APIResponse.error(f"获取配置失败: {str(e)}", 500) + + @staticmethod + def _load_settings_from_db(): + """ + 从数据库加载翻译配置 + """ + # 查询翻译相关的配置(api_setting 和 other_setting 分组) + settings = Setting.query.filter( + Setting.group.in_(['api_setting', 'other_setting']), + Setting.deleted_flag == 'N' + ).all() + + # 转换为配置字典 + config = {} + for setting in settings: + # 如果 serialized 为 True,则反序列化 value + value = json.loads(setting.value) if setting.serialized else setting.value + + # 根据 alias 存储配置 + if setting.alias == 'models': + config['models'] = value.split(',') if isinstance(value, str) else value + elif setting.alias == 'default_model': + config['default_model'] = value + elif setting.alias == 'default_backup': + config['default_backup'] = value + elif setting.alias == 'api_url': + config['api_url'] = value + elif setting.alias == 'api_key': + config['api_key'] = value + elif setting.alias == 'prompt': + config['prompt_template'] = value + elif setting.alias == 'threads': + config['max_threads'] = int(value) if value.isdigit() else 10 # 默认10线程 + + # 设置默认值(如果数据库中没有相关配置) + config.setdefault('models', ['gpt-3.5-turbo', 'gpt-4']) + config.setdefault('default_model', 'gpt-3.5-turbo') + config.setdefault('default_backup', 'gpt-3.5-turbo') + config.setdefault('api_url', '') + config.setdefault('api_key', '') + config.setdefault('prompt_template', '请将以下内容翻译为{target_lang}') + config.setdefault('max_threads', 10) + + return config + + +class TranslateProcessResource(Resource): + @jwt_required() + def post(self): + """查询翻译进度[^3]""" + uuid = request.form.get('uuid') + translate = Translate.query.filter_by( + uuid=uuid, + customer_id=get_jwt_identity() + ).first_or_404() + + return APIResponse.success({ + 'status': translate.status, + 'progress': float(translate.process), + 'download_url': translate.target_filepath if translate.status == 'done' else None + }) + + +class TranslateDeleteResource(Resource): + @jwt_required() + def delete(self, id): + """软删除翻译记录[^4]""" + # 
查询翻译记录 + translate = Translate.query.filter_by( + id=id, + customer_id=get_jwt_identity() + ).first_or_404() + + # 更新 deleted_flag 为 'Y' + translate.deleted_flag = 'Y' + db.session.commit() + + return APIResponse.success(message='记录已标记为删除') + + + +class TranslateDownloadResource(Resource): + # @jwt_required() + def get(self, id): + """通过 ID 下载单个翻译结果文件[^5]""" + # 查询翻译记录 + translate = Translate.query.filter_by( + id=id, + # customer_id=get_jwt_identity() + ).first_or_404() + + # 确保文件存在 + if not translate.target_filepath or not os.path.exists(translate.target_filepath): + return APIResponse.error('文件不存在', 404) + + # 返回文件 + response = make_response(send_file( + translate.target_filepath, + as_attachment=True, + download_name=os.path.basename(translate.target_filepath) + )) + + # 禁用缓存 + response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0' + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + + return response + + + + + + +class TranslateDownloadAllResource(Resource): + @jwt_required() + def get(self): + """批量下载所有翻译结果文件[^6]""" + # 查询当前用户的所有翻译记录 + records = Translate.query.filter_by( + customer_id=get_jwt_identity(), + deleted_flag='N' # 只下载未删除的记录 + ).all() + + # 生成内存 ZIP 文件 + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, 'w') as zip_file: + for record in records: + if record.target_filepath and os.path.exists(record.target_filepath): + # 将文件添加到 ZIP 中 + zip_file.write( + record.target_filepath, + os.path.basename(record.target_filepath) + ) + + # 重置缓冲区指针 + zip_buffer.seek(0) + + # 返回 ZIP 文件 + return send_file( + zip_buffer, + mimetype='application/zip', + as_attachment=True, + download_name=f"translations_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip" + ) + + +class OpenAICheckResource(Resource): + @jwt_required() + def post(self): + """OpenAI接口检测[^6]""" + data = request.form + required = ['api_url', 'api_key', 'model'] + if not all(k in data for k in required): + return APIResponse.error('缺少必要参数', 400) + + is_valid, msg = AIChecker.check_openai_connection( + data['api_url'], + data['api_key'], + data['model'] + ) + + return APIResponse.success({'valid': is_valid, 'message': msg}) + + +class PDFCheckResource(Resource): + @jwt_required() + def post(self): + """PDF扫描件检测[^7]""" + if 'file' not in request.files: + return APIResponse.error('请选择PDF文件', 400) + + file = request.files['file'] + if not file.filename.lower().endswith('.pdf'): + return APIResponse.error('仅支持PDF文件', 400) + + try: + file_stream = file.stream + is_scanned = AIChecker.check_pdf_scanned(file_stream) + return APIResponse.success({'scanned': is_scanned}) + except Exception as e: + return APIResponse.error(f'检测失败: {str(e)}', 500) + + +# resources/to_translate.py 补充接口 +class TranslateTestResource(Resource): + def get(self): + """测试翻译服务[^1]""" + return APIResponse.success(message="测试服务正常") + + +class TranslateDeleteAllResource(Resource): + @jwt_required() + def delete(self): + """删除用户所有翻译记录[^2]""" + Translate.query.filter_by( + customer_id=get_jwt_identity(), + deleted_flag='N' + ).delete() + db.session.commit() + return APIResponse.success(message="删除成功") + + +class TranslateFinishCountResource(Resource): + @jwt_required() + def get(self): + """获取已完成翻译数量[^3]""" + count = Translate.query.filter_by( + customer_id=get_jwt_identity(), + status='done', + deleted_flag='N' + ).count() + return APIResponse.success({'total': count}) + + +class TranslateRandDeleteAllResource(Resource): + def delete(self): + """删除临时用户所有记录[^4]""" + rand_user_id = 
request.json.get('rand_user_id') + if not rand_user_id: + return APIResponse.error('需要临时用户ID', 400) + + Translate.query.filter_by( + rand_user_id=rand_user_id, + deleted_flag='N' + ).delete() + db.session.commit() + return APIResponse.success(message="删除成功") + + +class TranslateRandDeleteResource(Resource): + def delete(self, id): + """删除临时用户单条记录[^5]""" + rand_user_id = request.json.get('rand_user_id') + translate = Translate.query.filter_by( + id=id, + rand_user_id=rand_user_id + ).first_or_404() + + db.session.delete(translate) + db.session.commit() + return APIResponse.success(message="删除成功") + + +class TranslateRandDownloadResource(Resource): + def get(self): + """下载临时用户翻译文件[^6]""" + rand_user_id = request.args.get('rand_user_id') + records = Translate.query.filter_by( + rand_user_id=rand_user_id, + status='done' + ).all() + + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, 'w') as zip_file: + for record in records: + if os.path.exists(record.target_filepath): + zip_file.write( + record.target_filepath, + os.path.basename(record.target_filepath) + ) + + zip_buffer.seek(0) + return send_file( + zip_buffer, + mimetype='application/zip', + as_attachment=True, + download_name=f"temp_translations_{datetime.now().strftime('%Y%m%d')}.zip" + ) + + +class Doc2xCheckResource(Resource): + def post(self): + """检查Doc2x接口[^7]""" + secret_key = request.json.get('doc2x_secret_key') + # 模拟验证逻辑,实际需对接Doc2x服务 + if secret_key == "valid_key_123": # 示例验证 + return APIResponse.success(message="接口正常") + return APIResponse.error("无效密钥", 400) diff --git a/app/resources/hello.py b/app/resources/hello.py new file mode 100644 index 0000000000000000000000000000000000000000..86ca1cbe12d07c0350f09ad9c06f6da5a3b12cac --- /dev/null +++ b/app/resources/hello.py @@ -0,0 +1,5 @@ +from flask_restful import Resource + +class HelloWorldResource(Resource): + def get(self): + return {'message': 'Hello World!'} diff --git a/app/resources/task/__init__.py b/app/resources/task/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/resources/task/file_handlers.py b/app/resources/task/file_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/resources/task/main.py b/app/resources/task/main.py new file mode 100644 index 0000000000000000000000000000000000000000..d6bc65e2faa866d5f23bc8f7b2c82ebce164eef4 --- /dev/null +++ b/app/resources/task/main.py @@ -0,0 +1,136 @@ +import threading +import time +from datetime import datetime +import os +import traceback +from flask import current_app +from app.extensions import db +from app.models.comparison import Comparison +from app.models.prompt import Prompt +from app.models.translate import Translate +from app.translate import word, excel, powerpoint, pdf, gptpdf, txt, csv_handle, md, to_translate + + +def main_wrapper(task_id, config, origin_path): + """ + 翻译任务核心逻辑(支持多参数)[^4] + :param task_id: 任务ID + :param origin_path: 原始文件绝对路径 + :param target_path: 目标文件绝对路径 + :param config: 翻译配置字典 + :return: 是否成功 + """ + try: + # 获取任务对象 + task = Translate.query.get(task_id) + if not task: + current_app.logger.error(f"任务 {task_id} 不存在") + return False + # 设置OpenAI API + + # 初始化翻译配置 + _init_translate_config(task) + to_translate.init_openai(config['api_url'], config['api_key']) + # 获取文件扩展名 + extension = os.path.splitext(origin_path)[1].lower() + print('文件扩展名',extension,origin_path) + # 调用文件处理器 + handler_map = { + ('.docx', '.doc'): 
word, + ('.xlsx', '.xls'): excel, + ('.pptx', '.ppt'): powerpoint, + ('.pdf',): pdf, + ('.txt',): txt, + ('.csv',): csv_handle, + ('.md',): md + } + + # 查找匹配的处理器 + for ext_group, handler in handler_map.items(): + if extension in ext_group: + # if extension == '.pdf': + # status = handler(config, origin_path) # 传递 origin_path + # else: + # status = handler.start(config) # 传递翻译配置 + status = handler.start( + # origin_path=origin_path, + # target_path=target_path, + trans=config # 传递翻译配置 + ) + print('config配置项', config) + return status + + current_app.logger.error(f"不支持的文件类型: {extension}") + return False + + except Exception as e: + current_app.logger.error(f"翻译任务执行异常: {str(e)}", exc_info=True) + return False + + +def _init_translate_config1(task): + """初始化翻译配置(如OpenAI)""" + if task.api_url and task.api_key: + import openai + openai.api_base = task.api_url + openai.api_key = task.api_key + +def pdf_handler(config, origin_path): + return gptpdf.start(config) + # if pdf.is_scanned_pdf(origin_path): + # return gptpdf.start(config) + # else: + # # 这里均使用gptpdf实现 + # return gptpdf.start(config) + # # return pdf.start(config) +def _init_translate_config(trans): + """ + 初始化翻译配置[^5] + :param trans: 翻译任务对象 + """ + # 设置OpenAI API + if trans.api_url and trans.api_key: + set_openai_config(trans.api_url, trans.api_key) + + # 加载术语对照表 + if trans.comparison_id: + comparison = get_comparison(trans.comparison_id) + trans.prompt = f"{comparison}\n{trans.prompt}" + + # 加载提示词模板 + if trans.prompt_id: + prompt = get_prompt(trans.prompt_id) + trans.prompt = prompt + + +def set_openai_config(api_url, api_key): + """ + 设置OpenAI API配置[^6] + """ + import openai + openai.api_base = api_url + openai.api_key = api_key + + +def get_comparison(comparison_id): + """ + 加载术语对照表 + :param comparison_id: 术语对照表ID + :return: 术语对照表内容 + """ + comparison = db.session.query(Comparison).filter_by(id=comparison_id).first() + if comparison and comparison.content: + return comparison.content.replace(',', ':').replace(';', '\n') + return "" + + +def get_prompt(prompt_id): + """ + 加载提示词模板 + :param prompt_id: 提示词模板ID + :return: 提示词内容 + """ + prompt = db.session.query(Prompt).filter_by(id=prompt_id).first() + if prompt and prompt.content: + return prompt.content + return "" diff --git a/app/resources/task/translate.py b/app/resources/task/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..910bda6086255a7558e92c1567b2cac5508f17b4 --- /dev/null +++ b/app/resources/task/translate.py @@ -0,0 +1,38 @@ +# tasks/translate.py +import subprocess + +from flask import current_app + +from app import db +from app.models.translate import Translate + + +def start_translate_task(task_id): + """启动翻译子进程[^2]""" + translate = Translate.query.get(task_id) + if not translate: + return False + + try: + # 构建命令参数 + storage_path = current_app.config['UPLOAD_FOLDER'] + cmd = [ + 'python3', + 'translate/main.py', + translate.uuid, + storage_path + ] + + # 启动子进程 + subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + # 更新任务状态 + translate.status = 'process' + db.session.commit() + return True + + except Exception as e: + translate.status = 'failed' + translate.failed_reason = str(e) + db.session.commit() + return False diff --git a/app/resources/task/translate_service.py b/app/resources/task/translate_service.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0ed52c113683a5157c8137478cca08612d7f9a --- /dev/null +++ b/app/resources/task/translate_service.py @@ -0,0 +1,644 @@ +import os +import threading 
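# --- Annotation (not part of the original diff) ---------------------------------------------
# Every TranslateEngine variant in this module uses the same Flask background-task idiom:
# capture the real application object with current_app._get_current_object(), hand it to a
# worker thread, re-enter app.app_context() inside that thread, and call db.session.remove()
# in a finally block so the thread-local SQLAlchemy session is released. A generic sketch of
# the pattern (run_in_background is illustrative and not part of the project API):
import threading

from flask import current_app


def run_in_background(job, *args):
    """Run job(*args) in a worker thread that has its own application context."""
    app = current_app._get_current_object()  # the real app object, not the request-bound proxy

    def worker():
        with app.app_context():  # db, config and current_app become usable in this thread
            try:
                job(*args)
            finally:
                from app.extensions import db
                db.session.remove()  # drop the thread-local session when the job finishes

    threading.Thread(target=worker).start()
# --- end of annotation ------------------------------------------------------------------------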
+from threading import Thread +from flask import current_app +from app.models.translate import Translate +from app.extensions import db +from .main import main_wrapper +from ...models.comparison import Comparison +from ...models.prompt import Prompt + + + + +class TranslateEngine99: + def __init__(self, task_id): + self.task_id = task_id + self.app = current_app._get_current_object() # 获取真实app对象 + + def execute(self): + """启动翻译任务入口""" + try: + # 主线程预处理 + with self.app.app_context(): + task = self._prepare_task() + + # 启动异步线程(传递真实app对象) + thr = threading.Thread( + target=self._async_wrapper, + args=(self.app, self.task_id) + ) + thr.start() + return True + except Exception as e: + self.app.logger.error(f"任务初始化失败: {str(e)}", exc_info=True) + return False + + def _async_wrapper(self, app, task_id): + """异步执行包装器""" + with app.app_context(): + try: + # 每个线程独立获取任务对象 + task = db.session.query(Translate).get(task_id) + self._async_execute(task) + except Exception as e: + app.logger.error(f"任务执行异常: {str(e)}", exc_info=True) + self._complete_task(False) + finally: + db.session.remove() # 关键:清理线程会话 + + def _async_execute(self, task): + """执行核心翻译逻辑""" + try: + # 初始化配置(使用线程内session) + config = self._build_config(task) + + # 调用翻译核心 + success = main_wrapper( + task=task, + origin_path=task.origin_filepath, + target_path=task.target_filepath, + config=config + ) + self._complete_task(success) + except Exception as e: + current_app.logger.error(f"翻译执行失败: {str(e)}", exc_info=True) + self._complete_task(False) + + def _build_config(self, task): + """构建线程安全配置""" + return { + 'lang': task.lang, + 'model': task.model, + 'type': task.type, + 'prompt': self._load_prompt(task), + 'threads': task.threads, + 'api_url': task.api_url, + 'api_key': task.api_key, + 'comparison': self._load_comparison(task.comparison_id) + } + + def _load_prompt(self, task): + """加载提示词(线程安全)""" + if task.prompt_id: + prompt = db.session.query(Prompt).get(task.prompt_id) + return prompt.content if prompt else "" + return task.prompt + + def _load_comparison(self, comparison_id): + """加载术语对照表(线程安全)""" + if not comparison_id: + return "" + comparison = db.session.query(Comparison).get(comparison_id) + return comparison.content.replace(',', ':').replace(';', '\n') if comparison else "" + + def _prepare_task(self): + """任务预处理""" + task = db.session.query(Translate).get(self.task_id) + if not task: + raise ValueError(f"任务 {self.task_id} 不存在") + + if not os.path.exists(task.origin_filepath): + raise FileNotFoundError(f"文件不存在: {task.origin_filepath}") + + task.status = 'process' + task.start_at = db.func.now() + db.session.commit() + return task + + def _complete_task(self, success): + """完成处理""" + try: + task = db.session.query(Translate).get(self.task_id) + task.status = 'done' if success else 'failed' + task.end_at = db.func.now() + db.session.commit() + except Exception as e: + db.session.rollback() + self.app.logger.error(f"状态更新失败: {str(e)}", exc_info=True) + + + + +class TranslateEngine666: + def __init__(self, task_id): + self.task_id = task_id + self.app = current_app._get_current_object() # 获取真实app对象 + + def _build_trans_config(self, task): + """构建符合文件处理器要求的trans字典""" + return { + 'id': task.id, # 任务ID + 'threads': task.threads, + 'file_path': task.origin_filepath, # 原始文件绝对路径 + 'target_file': task.target_filepath, # 目标文件绝对路径 + 'api_url': task.api_url, + 'api_key': task.api_key, # 新增API密钥字段 + 'type': task.type, + 'lang': task.lang, + 'run_complete': True, # 默认设为True + # 以下是可能需要添加的额外字段 + 'prompt': task.prompt, + 'model': task.model, + 
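            # Annotation (not part of the original diff): this dict is the `trans` payload handed
            # to the file handlers via handler.start(trans=trans_config) in _execute_core(). For a
            # .docx task it would look roughly like this (all values illustrative):
            #     {'id': 42, 'threads': 5,
            #      'file_path': '/app/storage/uploads/2024-01-20/demo.docx',
            #      'target_file': '/app/storage/translate/2024-01-20/demo.docx',
            #      'api_url': 'https://api.openai.com/v1', 'api_key': 'sk-...',
            #      'type': 'trans_all_only_inherit', 'lang': 'zh', 'run_complete': True,
            #      'prompt': '请将以下内容翻译为{target_lang}', 'model': 'gpt-3.5-turbo',
            #      'backup_model': 'gpt-3.5-turbo', 'comparison_id': None, 'prompt_id': None,
            #      'extension': '.docx'}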
'backup_model': task.backup_model, + 'comparison_id': task.comparison_id, + 'prompt_id': task.prompt_id, + 'extension':'.docx' + } + + def execute(self): + """启动任务入口""" + try: + # 在主线程上下文中准备任务 + with self.app.app_context(): + task = self._prepare_task() + + # 启动线程时传递真实app对象和任务ID + thr = Thread( + target=self._async_wrapper, + args=(self.app, self.task_id) + ) + thr.start() + return True + except Exception as e: + self.app.logger.error(f"任务初始化失败: {str(e)}", exc_info=True) + return False + + def _async_wrapper(self, app, task_id): + """异步执行包装器""" + with app.app_context(): + from app.extensions import db # 确保在每个线程中导入 + try: + # 使用新会话获取任务对象 + task = db.session.query(Translate).get(task_id) + if not task: + app.logger.error(f"任务 {task_id} 不存在") + return + + # 执行核心逻辑 + success = self._execute_core(task) + self._complete_task(success) + except Exception as e: + app.logger.error(f"任务执行异常: {str(e)}", exc_info=True) + self._complete_task(False) + finally: + db.session.remove() # 重要!清理线程局部session + + def _execute_core(self, task): + """执行核心翻译逻辑""" + try: + # 初始化翻译配置 + self._init_translate_config(task) + + # 选择处理器 + handler = self._get_file_handler(task.origin_filepath) + if not handler: + current_app.logger.error(f"不支持的文件类型: {task.origin_filepath}") + return False + + # 构建符合要求的trans字典 + trans_config = self._build_trans_config(task) + + # 调用处理器 + return handler.start(trans=trans_config) # 正确传递参数 + except Exception as e: + current_app.logger.error(f"翻译执行失败: {str(e)}", exc_info=True) + return False + + def _prepare_task(self): + """准备翻译任务""" + task = Translate.query.get(self.task_id) + if not task: + raise ValueError(f"任务 {self.task_id} 不存在") + + # 验证文件存在性 + if not os.path.exists(task.origin_filepath): + raise FileNotFoundError(f"原始文件不存在: {task.origin_filepath}") + + # 更新任务状态 + task.status = 'process' + task.start_at = db.func.now() + db.session.commit() + return task + + def _init_translate_config(self, task): + """初始化翻译配置""" + if task.api_url and task.api_key: + import openai + openai.api_base = task.api_url + openai.api_key = task.api_key + + # 加载术语对照表 + if task.comparison_id: + from app.models import Comparison + comparison = db.session.query(Comparison).get(task.comparison_id) + if comparison: + task.prompt = f"术语对照表:\n{comparison.content.replace(',', ':')}\n{task.prompt}" + + # 加载提示词模板 + if task.prompt_id: + from app.models import Prompt + prompt = db.session.query(Prompt).get(task.prompt_id) + if prompt: + task.prompt = prompt.content + + def _get_file_handler(self, file_path): + from app.translate import ( + word, excel, powerpoint, pdf, + gptpdf, txt, csv_handle, md + ) + + try: + current_app.logger.debug(f"正在解析文件路径: {file_path}") + # 标准化路径(处理不同OS的斜杠和大小写) + normalized_path = os.path.normpath(file_path).lower() + current_app.logger.debug(f"标准化路径: {normalized_path}") + + ext = os.path.splitext(normalized_path)[1] + # 获取标准化后的扩展名 + # ext = os.path.splitext(file_path)[1].lower() + current_app.logger.debug(f"提取的扩展名: {ext}") + + # 处理器映射表 + handler_map = { + '.docx': word, + '.doc': word, + '.xlsx': excel, + '.xls': excel, + '.pptx': powerpoint, + '.ppt': powerpoint, + '.pdf': pdf,#if not pdf.is_scanned_pdf(file_path) else gptpdf, + '.txt': txt, + '.csv': csv_handle, + '.md': md + } + + current_app.logger.debug(f"当前处理器映射表: {handler_map}") + + # 匹配处理器 + handler = handler_map.get(ext) + if not handler: + current_app.logger.error(f"未找到匹配的处理器,扩展名: {ext}") + return None + + current_app.logger.info(f"成功匹配处理器: {handler.__name__}") + return handler + + except Exception as e: + 
current_app.logger.error(f"获取文件处理器失败: {str(e)}", exc_info=True) + return None + + def _build_config(self, task): + """构建配置字典""" + return { + 'lang': task.lang, + 'model': task.model, + 'type': task.type, + 'prompt': task.prompt, + 'threads': task.threads, + 'api_url': task.api_url, + 'api_key': task.api_key, + 'origin_lang': task.origin_lang, + 'backup_model': task.backup_model, + 'doc2x_flag': task.doc2x_flag, + 'doc2x_secret_key': task.doc2x_secret_key, + 'comparison_id': task.comparison_id, + 'word_count': task.word_count, + 'prompt_id': task.prompt_id, + 'rand_user_id': task.rand_user_id, + 'origin_filesize': task.origin_filesize, + 'origin_filename': task.origin_filename, + 'target_filesize': task.target_filesize, + 'target_filename': task.target_filename, + 'target_filepath': task.target_filepath, + 'origin_filepath': task.origin_filepath, + } + + def _complete_task(self, success): + """更新任务状态""" + from app.extensions import db + try: + with self.app.app_context(): + task = db.session.query(Translate).get(self.task_id) + if task: + task.status = 'done' if success else 'failed' + task.end_at = db.func.now() + db.session.commit() + except Exception as e: + db.session.rollback() + self.app.logger.error(f"状态更新失败: {str(e)}", exc_info=True) + + + + +class TranslateEngine9999: + def __init__(self, task_id): + self.task_id = task_id + self.app = current_app._get_current_object() # 获取真实app对象 + + def _build_trans_config(self, task): + """构建符合文件处理器要求的 trans 字典[^1]""" + return { + 'id': task.id, # 任务ID + 'threads': task.threads, + 'file_path': task.origin_filepath, # 原始文件绝对路径 + 'target_file': task.target_filepath, # 目标文件绝对路径 + 'api_url': task.api_url, + 'api_key': task.api_key, # 新增API密钥字段 + 'type': task.type, + 'lang': task.lang, + 'run_complete': True, # 默认设为True + # 以下是可能需要添加的额外字段 + 'prompt': task.prompt, + 'model': task.model, + 'backup_model': task.backup_model, + 'comparison_id': task.comparison_id, + 'prompt_id': task.prompt_id, + 'extension': os.path.splitext(task.origin_filepath)[1] # 动态获取文件扩展名 + } + + def execute(self): + """启动任务入口[^2]""" + try: + # 在主线程上下文中准备任务 + with self.app.app_context(): + task = self._prepare_task() + + # 启动线程时传递真实app对象和任务ID + thr = Thread( + target=self._async_wrapper, + args=(self.app, self.task_id) + ) + thr.start() + return True + except Exception as e: + self.app.logger.error(f"任务初始化失败: {str(e)}", exc_info=True) + return False + + def _async_wrapper(self, app, task_id): + """异步执行包装器[^3]""" + with app.app_context(): + from app.extensions import db # 确保在每个线程中导入 + try: + # 使用新会话获取任务对象 + task = db.session.query(Translate).get(task_id) + if not task: + app.logger.error(f"任务 {task_id} 不存在") + return + + # 执行核心逻辑 + success = self._execute_core(task) + self._complete_task(success) + except Exception as e: + app.logger.error(f"任务执行异常: {str(e)}", exc_info=True) + self._complete_task(False) + finally: + db.session.remove() # 重要!清理线程局部session + + def _execute_core(self, task): + """执行核心翻译逻辑[^4]""" + try: + # 初始化翻译配置 + self._init_translate_config(task) + + # 构建符合要求的 trans 字典 + trans_config = self._build_trans_config(task) + + # 调用 main_wrapper 执行翻译 + return main_wrapper(task_id=task.id, origin_path=task.origin_filepath,config=trans_config) + except Exception as e: + current_app.logger.error(f"翻译执行失败: {str(e)}", exc_info=True) + return False + + def _prepare_task(self): + """准备翻译任务[^5]""" + task = Translate.query.get(self.task_id) + if not task: + raise ValueError(f"任务 {self.task_id} 不存在") + + # 验证文件存在性 + if not os.path.exists(task.origin_filepath): + raise 
FileNotFoundError(f"原始文件不存在: {task.origin_filepath}") + + # 更新任务状态 + task.status = 'process' + task.start_at = db.func.now() + db.session.commit() + return task + + def _init_translate_config(self, task): + """初始化翻译配置[^6]""" + if task.api_url and task.api_key: + import openai + openai.api_base = task.api_url + openai.api_key = task.api_key + + # 加载术语对照表 + if task.comparison_id: + from app.models import Comparison + comparison = db.session.query(Comparison).get(task.comparison_id) + if comparison: + task.prompt = f"术语对照表:\n{comparison.content.replace(',', ':')}\n{task.prompt}" + + # 加载提示词模板 + if task.prompt_id: + from app.models import Prompt + prompt = db.session.query(Prompt).get(task.prompt_id) + if prompt: + task.prompt = prompt.content + + def _get_file_handler(self, file_path): + """获取文件处理器[^7]""" + from app.translate import ( + word, excel, powerpoint, pdf, + gptpdf, txt, csv_handle, md + ) + + try: + current_app.logger.debug(f"正在解析文件路径: {file_path}") + # 标准化路径(处理不同OS的斜杠和大小写) + normalized_path = os.path.normpath(file_path).lower() + current_app.logger.debug(f"标准化路径: {normalized_path}") + + ext = os.path.splitext(normalized_path)[1] + current_app.logger.debug(f"提取的扩展名: {ext}") + + # 处理器映射表 + handler_map = { + '.docx': word, + '.doc': word, + '.xlsx': excel, + '.xls': excel, + '.pptx': powerpoint, + '.ppt': powerpoint, + '.pdf': pdf, # if not pdf.is_scanned_pdf(file_path) else gptpdf, + '.txt': txt, + '.csv': csv_handle, + '.md': md + } + + current_app.logger.debug(f"当前处理器映射表: {handler_map}") + + # 匹配处理器 + handler = handler_map.get(ext) + if not handler: + current_app.logger.error(f"未找到匹配的处理器,扩展名: {ext}") + return None + + current_app.logger.info(f"成功匹配处理器: {handler.__name__}") + return handler + + except Exception as e: + current_app.logger.error(f"获取文件处理器失败: {str(e)}", exc_info=True) + return None + + def _complete_task(self, success): + """更新任务状态[^8]""" + from app.extensions import db + try: + with self.app.app_context(): + task = db.session.query(Translate).get(self.task_id) + if task: + task.status = 'done' if success else 'failed' + task.end_at = db.func.now() + task.process = 100.00 if success else 0.00 + db.session.commit() + except Exception as e: + db.session.rollback() + self.app.logger.error(f"状态更新失败: {str(e)}", exc_info=True) + + +class TranslateEngine: + def __init__(self, task_id): + self.task_id = task_id + self.app = current_app._get_current_object() # 获取真实app对象 + + def _build_trans_config(self, task): + """构建符合文件处理器要求的 trans 字典[^1]""" + config = { + 'id': task.id, # 任务ID + 'target_lang': task.lang, + # 'origin_lang': task.origin_lang, + 'uuid':task.uuid, + 'target_path_dir':os.path.dirname(task.target_filepath), + 'threads': task.threads, + 'file_path': task.origin_filepath, # 原始文件绝对路径 + 'target_file': task.target_filepath, # 目标文件绝对路径 + 'api_url': task.api_url, + 'api_key': task.api_key, # 新增API密钥字段 + 'type': task.type, + 'lang': task.lang, + 'run_complete': True, # 默认设为True + # 以下是可能需要添加的额外字段 + 'prompt': task.prompt, + 'model': task.model, + 'backup_model': task.backup_model, + 'comparison_id': task.comparison_id, + 'prompt_id': task.prompt_id, + 'extension': os.path.splitext(task.origin_filepath)[1] # 动态获取文件扩展名 + } + + # 加载术语表 + if task.comparison_id: + comparison = db.session.query(Comparison).get(task.comparison_id) + if comparison: + config['comparison'] = comparison.content.replace(',', ':').replace(';', '\n') + + # 加载提示语模板 + if task.prompt_id: + prompt = db.session.query(Prompt).get(task.prompt_id) + if prompt: + config['prompt'] = prompt.content + + return 
config + + def execute(self): + """启动任务入口[^2]""" + try: + # 在主线程上下文中准备任务 + with self.app.app_context(): + task = self._prepare_task() + + # 启动线程时传递真实app对象和任务ID + thr = Thread( + target=self._async_wrapper, + args=(self.app, self.task_id) + ) + thr.start() + return True + except Exception as e: + self.app.logger.error(f"任务初始化失败: {str(e)}", exc_info=True) + return False + + def _async_wrapper(self, app, task_id): + """异步执行包装器[^3]""" + with app.app_context(): + from app.extensions import db # 确保在每个线程中导入 + try: + # 使用新会话获取任务对象 + task = db.session.query(Translate).get(task_id) + if not task: + app.logger.error(f"任务 {task_id} 不存在") + return + + # 执行核心逻辑 + success = self._execute_core(task) + self._complete_task(success) + except Exception as e: + app.logger.error(f"任务执行异常: {str(e)}", exc_info=True) + self._complete_task(False) + finally: + db.session.remove() # 重要!清理线程局部session + + def _execute_core(self, task): + """执行核心翻译逻辑[^4]""" + try: + # 初始化翻译配置 + self._init_translate_config(task) + + # 构建符合要求的 trans 字典 + trans_config = self._build_trans_config(task) + + # 调用 main_wrapper 执行翻译 + return main_wrapper(task_id=task.id, origin_path=task.origin_filepath, config=trans_config) + except Exception as e: + current_app.logger.error(f"翻译执行失败: {str(e)}", exc_info=True) + return False + + def _prepare_task(self): + """准备翻译任务[^5]""" + task = Translate.query.get(self.task_id) + if not task: + raise ValueError(f"任务 {self.task_id} 不存在") + + # 验证文件存在性 + if not os.path.exists(task.origin_filepath): + raise FileNotFoundError(f"原始文件不存在: {task.origin_filepath}") + + # 更新任务状态 + task.status = 'process' + task.start_at = db.func.now() + db.session.commit() + return task + + def _init_translate_config(self, task): + """初始化翻译配置[^6]""" + if task.api_url and task.api_key: + import openai + openai.api_base = task.api_url + openai.api_key = task.api_key + + def _complete_task(self, success): + """更新任务状态[^7]""" + from app.extensions import db + try: + with self.app.app_context(): + task = db.session.query(Translate).get(self.task_id) + if task: + task.status = 'done' if success else 'failed' + task.end_at = db.func.now() + task.process = 100.00 if success else 0.00 + db.session.commit() + except Exception as e: + db.session.rollback() + self.app.logger.error(f"状态更新失败: {str(e)}", exc_info=True) diff --git a/app/routes.py b/app/routes.py new file mode 100644 index 0000000000000000000000000000000000000000..52d0b1c13205897b842a4888eaf81ebee058828b --- /dev/null +++ b/app/routes.py @@ -0,0 +1,137 @@ +from app.resources.admin.auth import AdminLoginResource, AdminChangePasswordResource +from app.resources.admin.customer import AdminCustomerListResource, AdminCreateCustomerResource, \ + AdminCustomerDetailResource, AdminUpdateCustomerResource, AdminDeleteCustomerResource, \ + CustomerStatusResource +from app.resources.admin.image import AdminImageResource +from app.resources.admin.settings import AdminSettingNoticeResource, AdminSettingApiResource, \ + AdminSettingSiteResource, AdminInfoSettingOtherResource, \ + AdminEditSettingOtherResource +from app.resources.admin.translate import AdminTranslateListResource, \ + AdminTranslateBatchDeleteResource, AdminTranslateRestartResource, AdminTranslateDeteleResource, \ + AdminTranslateStatisticsResource, AdminTranslateDownloadResource, \ + AdminTranslateDownloadBatchResource +from app.resources.admin.users import AdminUserListResource, AdminCreateUserResource, \ + AdminUserDetailResource, AdminUpdateUserResource, AdminDeleteUserResource +from app.resources.api.AccountResource import 
ChangePasswordResource, EmailChangePasswordResource, \ + StorageInfoResource, UserInfoResource, SendChangeCodeResource +from app.resources.api.AuthResource import SendRegisterCodeResource, UserRegisterResource, \ + UserLoginResource, SendResetCodeResource, ResetPasswordResource +from app.resources.api.comparison import MyComparisonListResource, SharedComparisonListResource, \ + EditComparisonResource, ShareComparisonResource, CopyComparisonResource, \ + FavoriteComparisonResource, CreateComparisonResource, DeleteComparisonResource, \ + DownloadTemplateResource, ImportComparisonResource, ExportComparisonResource, \ + ExportAllComparisonsResource +from app.resources.api.customer import GuestIdResource, CustomerDetailResource +from app.resources.api.files import FileUploadResource, FileDeleteResource +from app.resources.api.prompt import MyPromptListResource, SharedPromptListResource, \ + EditPromptResource, SharePromptResource, CopyPromptResource, FavoritePromptResource, \ + CreatePromptResource, DeletePromptResource +from app.resources.api.setting import SystemVersionResource, SystemSettingsResource +from app.resources.api.translate import TranslateListResource, TranslateSettingResource, \ + TranslateProcessResource, TranslateDeleteResource, TranslateDownloadResource, \ + OpenAICheckResource, PDFCheckResource, TranslateTestResource, TranslateDeleteAllResource, \ + TranslateFinishCountResource, TranslateRandDeleteAllResource, TranslateRandDeleteResource, \ + TranslateRandDownloadResource, Doc2xCheckResource, TranslateStartResource, \ + TranslateDownloadAllResource + + +def register_routes(api): + # 基础测试路由 + api.add_resource(SendRegisterCodeResource, '/api/register/send') + api.add_resource(UserRegisterResource, '/api/register') + api.add_resource(UserLoginResource, '/api/login') + api.add_resource(SendResetCodeResource, '/api/find/send') + api.add_resource(ResetPasswordResource, '/api/find') + + api.add_resource(ChangePasswordResource, '/api/change') + api.add_resource(SendChangeCodeResource, '/api/change/send') + api.add_resource(EmailChangePasswordResource, '/api/change/email') + api.add_resource(StorageInfoResource, '/api/storage') + api.add_resource(UserInfoResource, '/api/info') + + api.add_resource(FileUploadResource, '/api/upload') + api.add_resource(FileDeleteResource, '/api/delFile') + + api.add_resource(TranslateListResource, '/api/translates') + api.add_resource(TranslateSettingResource, '/api/translate/setting') + api.add_resource(TranslateProcessResource, '/api/process') + api.add_resource(TranslateDeleteResource, '/api/translate/') + api.add_resource(TranslateDownloadResource, '/api/translate/download/') + api.add_resource(TranslateDownloadAllResource, '/api/translate/download/all') + api.add_resource(OpenAICheckResource, '/api/check/openai') + api.add_resource(PDFCheckResource, '/api/check/pdf') + api.add_resource(TranslateTestResource, '/api/translate/test') + api.add_resource(TranslateDeleteAllResource, '/api/translate/all') + api.add_resource(TranslateFinishCountResource, '/api/translate/finish/count') + api.add_resource(TranslateRandDeleteAllResource, '/api/translate/rand/all') + api.add_resource(TranslateRandDeleteResource, '/api/translate/rand/') + api.add_resource(TranslateRandDownloadResource, '/api/translate/download/rand') + api.add_resource(Doc2xCheckResource, '/api/check/doc2x') + api.add_resource(TranslateStartResource, '/api/translate') # 启动翻译 + + api.add_resource(GuestIdResource, '/api/guest/id') + api.add_resource(CustomerDetailResource, '/api/customer/') + + 
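    # Several detail routes above and below are registered with a bare trailing slash
    # ('/api/translate/', '/api/customer/', '/api/comparison/', ...), and one admin route
    # even shows a double slash ('/api/admin/translate//restart'), so the URL converters
    # were almost certainly stripped when this patch was rendered. A representative
    # registration (assumption about the original parameter names) would look like:
    #     api.add_resource(CustomerDetailResource, '/api/customer/<int:customer_id>')
    #     api.add_resource(TranslateDownloadResource, '/api/translate/download/<int:translate_id>')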
api.add_resource(MyComparisonListResource, '/api/comparison/my') + api.add_resource(SharedComparisonListResource, '/api/comparison/share') + api.add_resource(EditComparisonResource, '/api/comparison/') + api.add_resource(ShareComparisonResource, '/api/comparison/share/') + api.add_resource(CopyComparisonResource, '/api/comparison/copy/') + api.add_resource(FavoriteComparisonResource, '/api/comparison/fav/') + api.add_resource(CreateComparisonResource, '/api/comparison') + api.add_resource(DeleteComparisonResource, '/api/comparison/') + api.add_resource(DownloadTemplateResource, '/api/comparison/template') + api.add_resource(ImportComparisonResource, '/api/comparison/import') + api.add_resource(ExportComparisonResource, '/api/comparison/export/') + api.add_resource(ExportAllComparisonsResource, '/api/comparison/export/all') + + api.add_resource(SystemVersionResource, '/api/common/setting') + api.add_resource(SystemSettingsResource, '/api/common/all_settings') + + api.add_resource(MyPromptListResource, '/api/prompt/my') + api.add_resource(SharedPromptListResource, '/api/prompt/share') + api.add_resource(EditPromptResource, '/api/prompt/') + api.add_resource(SharePromptResource, '/api/prompt/share/') + api.add_resource(CopyPromptResource, '/api/prompt/copy/') + api.add_resource(FavoritePromptResource, '/api/prompt/fav/') + api.add_resource(CreatePromptResource, '/api/prompt') + api.add_resource(DeletePromptResource, '/api/prompt/') + + +# -------admin----------- + api.add_resource(AdminLoginResource, '/api/admin/login') + api.add_resource(AdminChangePasswordResource, '/api/admin/changepwd') + + api.add_resource(AdminCustomerListResource, '/api/admin/customers') + api.add_resource(AdminCreateCustomerResource, '/api/admin/customer') + api.add_resource(AdminCustomerDetailResource, '/api/admin/customer/') + api.add_resource(AdminUpdateCustomerResource, '/api/admin/customer/') + api.add_resource(AdminDeleteCustomerResource, '/api/admin/customer/') + api.add_resource(CustomerStatusResource, '/api/admin/customer/status/') + + api.add_resource(AdminUserListResource, '/api/admin/users') + api.add_resource(AdminCreateUserResource, '/api/admin/user') + api.add_resource(AdminUserDetailResource, '/api/admin/user/') + api.add_resource(AdminUpdateUserResource, '/api/admin/user/') + api.add_resource(AdminDeleteUserResource, '/api/admin/user/') + + api.add_resource(AdminTranslateListResource, '/api/admin/translates') + api.add_resource(AdminTranslateDeteleResource, '/api/admin/translate/') + api.add_resource(AdminTranslateBatchDeleteResource, '/api/admin/translates/delete/batch') + api.add_resource(AdminTranslateRestartResource, '/api/admin/translate//restart') + api.add_resource(AdminTranslateStatisticsResource, '/api/admin/translate/statistics') + api.add_resource(AdminTranslateDownloadResource, '/api/admin/translate/download/') + api.add_resource(AdminTranslateDownloadBatchResource,'/api/admin/translates/download/batch') + + api.add_resource(AdminImageResource, '/api/admin/image') + + api.add_resource(AdminSettingNoticeResource, '/api/admin/setting/notice') + api.add_resource(AdminSettingApiResource, '/api/admin/setting/api') + api.add_resource(AdminInfoSettingOtherResource, '/api/admin/setting/other') + api.add_resource(AdminEditSettingOtherResource, '/api/admin/setting/other') + api.add_resource(AdminSettingSiteResource, '/api/admin/setting/site') + + + print("✅ 路由配置完成") # 添加调试输出 + # api.add_resource(TodoListResource, '/todos') + # api.add_resource(TodoResource, '/todos/') diff --git 
a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/schemas/account.py b/app/schemas/account.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6e6cd51cdd0decf0ba34df7cd0c2fa348f0f5b --- /dev/null +++ b/app/schemas/account.py @@ -0,0 +1,26 @@ +from marshmallow import Schema, fields, validate, validates_schema, ValidationError + + +class ChangePasswordSchema(Schema): + old_password = fields.Str(required=True, error_messages={ + "required": "原密码不能为空" + }) + new_password = fields.Str(required=True, validate=[ + validate.Length(min=6, error="新密码至少6位") + ], error_messages={ + "required": "新密码不能为空" + }) + new_password_confirmation = fields.Str(required=True) + + @validates_schema + def validate_password_confirmation(self, data, **kwargs): + if data['new_password'] != data['new_password_confirmation']: + raise ValidationError("两次输入的新密码不一致") + + +class EmailChangePasswordSchema(Schema): + code = fields.Str(required=True, error_messages={ + "required": "验证码不能为空" + }) + new_password = fields.Str(required=True, validate=validate.Length(min=6)) + new_password_confirmation = fields.Str(required=True) diff --git a/app/schemas/auth.py b/app/schemas/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..be2401c21a2eff07bcd5065377030f71ed03894e --- /dev/null +++ b/app/schemas/auth.py @@ -0,0 +1,58 @@ +from marshmallow import Schema, fields, validate, validates, ValidationError, validates_schema +from flask import current_app + + +class SendCodeSchema(Schema): + email = fields.Email(required=True, error_messages={ + "required": "Email is required", + "invalid": "Invalid email format" + }) + + @validates("email") + def validate_email_domain(self, value): + allowed_domains = current_app.config.get('ALLOWED_EMAIL_DOMAINS', []) + if allowed_domains: + domain = value.split('@')[-1] + if domain not in allowed_domains: + raise ValidationError("Email domain not allowed") + + +class RegisterSchema(Schema): + email = fields.Email(required=True, error_messages={ + "required": "邮箱不能为空", + "invalid": "邮箱格式不正确" + }) + password = fields.String( + required=True, + validate=validate.Length(min=6), + error_messages={ + "required": "密码不能为空", + "too_short": "密码长度至少6位" + } + ) + code = fields.String(required=True, error_messages={"required": "验证码不能为空"}) +class LoginSchema(Schema): + email = fields.Email(required=True, error_messages={ + "required": "邮箱不能为空", + "invalid": "邮箱格式不正确" + }) + password = fields.String(required=True, error_messages={ + "required": "密码不能为空" + }) + +class FindSendSchema(Schema): + email = fields.Email(required=True, error_messages={ + "required": "邮箱不能为空", + "invalid": "邮箱格式不正确" + }) + +class FindResetSchema(Schema): + email = fields.Email(required=True) + code = fields.String(required=True) + password = fields.String(required=True, validate=lambda x: len(x) >= 6) + password_confirmation = fields.String(required=True) + + @validates_schema + def validate_passwords(self, data, **kwargs): + if data['password'] != data['password_confirmation']: + raise ValidationError("两次密码不一致", "password_confirmation") diff --git a/app/schemas/validators.py b/app/schemas/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e9144bb37bfa2376d2ddda4ed0296b3ffe20c8 --- /dev/null +++ b/app/schemas/validators.py @@ -0,0 +1,24 @@ +# app/schemas/validators.py +VALIDATION_RULES = { + 'register': { + 'email': {'required': True, 
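        # The entries in this dict mirror the marshmallow schemas defined in
        # app/schemas/auth.py above (required email, 6-character minimum password,
        # required code, confirmed password), and ERROR_MESSAGES below repeats their
        # Chinese error strings, so this module looks like a plain-dict companion for
        # ad-hoc checks. A typical way the schemas themselves would be consumed in a
        # resource (assumption - the resource classes are outside this hunk, and the
        # APIResponse helper signature is assumed) is:
        #     errors = RegisterSchema().validate(request.get_json() or {})
        #     if errors:
        #         return APIResponse.error(errors)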
'type': 'email'}, + 'password': {'required': True, 'min_length': 6}, + 'code': {'required': True} + }, + 'find': { + 'code': {'required': True}, + 'password': { + 'required': True, + 'min_length': 6, + 'confirmed': True + } + } +} + +ERROR_MESSAGES = { + 'email_required': '邮箱不能为空', + 'password_required': '密码不能为空', + 'password_min': '密码长度至少6位', + 'code_required': '验证码不能为空', + 'password_confirmed': '两次输入密码不一致' +} \ No newline at end of file diff --git a/app/translate/__init__.py b/app/translate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..113d6c35db542d21d03d47831dfc331ad95a256e --- /dev/null +++ b/app/translate/__init__.py @@ -0,0 +1,23 @@ +# __init__.py + +# 导入自定义模块 +from . import word +from . import excel +from . import powerpoint +from . import pdf +from . import gptpdf +from . import txt +from . import csv_handle +from . import md + + +__all__ = [ + 'word', + 'excel', + 'powerpoint', + 'pdf', + 'gptpdf', + 'txt', + 'csv_handle', + 'md', +] diff --git a/app/translate/check_doc2x.py b/app/translate/check_doc2x.py new file mode 100644 index 0000000000000000000000000000000000000000..9d76ec633d7c019022bc5a68229ed6aa4451f0b1 --- /dev/null +++ b/app/translate/check_doc2x.py @@ -0,0 +1,24 @@ +import sys +import getopt +import requests +import logging +import logging.config + +httpx_logger = logging.getLogger("httpx") +httpx_logger.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) + +def main(): + token=sys.argv[1] + url = 'https://v2.doc2x.noedgeai.com/api/v2/parse/status?uid=test' + headers = { + 'Authorization': f'Bearer {token}' + } + response = requests.get(url, headers=headers) + # 输出响应内容 + print(response.status_code) + +if __name__ == '__main__': + main() + + diff --git a/app/translate/check_openai.py b/app/translate/check_openai.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a84259443b5325bae285d7a3b053244910f377 --- /dev/null +++ b/app/translate/check_openai.py @@ -0,0 +1,19 @@ +import openai +import sys +import getopt +from . import to_translate + +def main(): + api_url=sys.argv[1] + api_key=sys.argv[2] + model=sys.argv[3] + + # 设置OpenAI API + to_translate.init_openai(api_url, api_key) + message=to_translate.check(model) + print(message) + +if __name__ == '__main__': + main() + + diff --git a/app/translate/check_pdf.py b/app/translate/check_pdf.py new file mode 100644 index 0000000000000000000000000000000000000000..65e6ba8f508b918e416059e02696aca8d6557fe0 --- /dev/null +++ b/app/translate/check_pdf.py @@ -0,0 +1,14 @@ +import openai +import sys +import getopt +import pdf + +def main(): + pdf_path=sys.argv[1] + result=pdf.is_scanned_pdf(pdf_path) + print(result) + +if __name__ == '__main__': + main() + + diff --git a/app/translate/check_threading.py b/app/translate/check_threading.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e60894a491c5a95404da6134339b18732e9378 --- /dev/null +++ b/app/translate/check_threading.py @@ -0,0 +1,15 @@ +import threading +from . 
import rediscon +import db +import sys +def main(): + uuid=sys.argv[1] + trans=db.get("select * from translate where uuid=%s", uuid) + api_url=trans['api_url'] + mredis=rediscon.get_conn() + threading_num=int(mredis.get(api_url)) + print(threading_num) +if __name__ == '__main__': + main() + + diff --git a/app/translate/common.py b/app/translate/common.py new file mode 100644 index 0000000000000000000000000000000000000000..0f0b2b36661e995999159b12ad55398d303378e4 --- /dev/null +++ b/app/translate/common.py @@ -0,0 +1,70 @@ +import string +import uuid +import datetime +import os +import platform +import subprocess +from pathlib import Path + +def is_all_punc(strings): + if isinstance(strings, datetime.time): + return True + elif isinstance(strings, datetime.datetime): + return True + elif isinstance(strings, (int, float, complex)): + return True + # print(type(strings)) + chinese_punctuations=get_chinese_punctuation() + for s in strings: + if s not in string.punctuation and not s.isdigit() and not s.isdecimal() and s != "" and not s.isspace() and s not in chinese_punctuations: + return False + return True + +def is_chinese(char): + if '\u4e00' <= char <= '\u9fff': + return True + return False + +def get_chinese_punctuation(): + return [':','【','】',',','。','、','?','」','「',';','!','@','¥','(',')'] + +def display_spend(start_time,end_time): + left_time = end_time - start_time + days = left_time.days + hours, remainder = divmod(left_time.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + spend="用时" + if days>0: + spend+="{}天".format(days) + if hours>0: + spend+="{}小时".format(hours) + if minutes>0: + spend+="{}分钟".format(minutes) + if seconds>0: + spend+="{}秒".format(seconds) + return spend + +def random_uuid(length): + result = str(uuid.uuid4())[:length] + return result + + +def find_command_location(command): + if platform.system() == 'Windows': + cmd = 'where' + else: + cmd = 'which' + try: + print(command) + location = subprocess.check_output([cmd, command]).strip() + print(location.decode("utf-8")) + return location.decode('utf-8') # 解码为字符串 + except subprocess.CalledProcessError as e: + print(e) + raise Exception("未安装"+command) + +def format_file_path(filepath): + filename=os.path.basename(filepath) + filename=filename.replace(" ",r"\ ").replace("/","\\"); + parentpath=os.path.dirname(filepath) + return "{}/{}".format(parentpath, filename) diff --git a/app/translate/csv_handle.py b/app/translate/csv_handle.py new file mode 100644 index 0000000000000000000000000000000000000000..c77dc24da090d1825d16eaa30e1fdaad6e4cbb4e --- /dev/null +++ b/app/translate/csv_handle.py @@ -0,0 +1,136 @@ +import os +import threading +from . import to_translate +from . 
import common +import datetime +import time +import csv +import io + +def start(trans): + # 允许的最大线程 + threads = trans.get('threads') + max_threads = 10 if threads is None or int(threads) < 0 else int(threads) + + # 当前执行的索引位置 + run_index = 0 + start_time = datetime.datetime.now() + + encodings = ['utf-8', 'gbk', 'gb2312', 'iso-8859-1'] + content = None + + for encoding in encodings: + try: + with open(trans['file_path'], 'r', encoding=encoding, newline='') as file: + reader = csv.reader(file) + content = list(reader) + break # 如果成功读取,跳出循环 + except UnicodeDecodeError: + continue # 如果解码失败,尝试下一种编码 + except Exception as e: + print(f"无法读取CSV文件 {trans['file_path']}: {e}") + return False + + if content is None: + print(f"无法以任何支持的编码格式读取CSV文件 {trans['file_path']}") + return False + + texts = [] + + # 支持最多单词量 + max_word = 1000 + + # 处理每一行CSV数据 + for row in content: + for cell in row: + if check_text(cell): + if len(cell) > max_word: + sub_cells = split_cell(cell, max_word) + for sub_cell in sub_cells: + texts.append({"text": sub_cell, "origin": sub_cell, "complete": False, "sub": True}) + else: + texts.append({"text": cell, "origin": cell, "complete": False, "sub": False}) + + + max_run = min(max_threads, len(texts)) + before_active_count = threading.activeCount() + event = threading.Event() + + while run_index <= len(texts) - 1: + if threading.activeCount() < max_run + before_active_count: + if not event.is_set(): + thread = threading.Thread(target=translate.get, args=(trans, event, texts, run_index)) + thread.start() + run_index += 1 + else: + return False + + while True: + if all(text['complete'] for text in texts): + break + else: + time.sleep(1) + + text_count = len(texts) + trans_type = trans['type'] + only_trans_text = trans_type in ["trans_text_only_inherit", "trans_text_only_new", "trans_all_only_new", "trans_all_only_inherit"] + + # 将翻译结果写入新的 CSV 文件 + try: + with open(trans['target_file'], 'w', encoding='utf-8', newline='') as file: + writer = csv.writer(file) + translated_row = [] + origin_row = [] + text_index = 0 + + for row in content: + for cell in row: + if check_text(cell): + translated_cell = "" + while text_index < len(texts) and texts[text_index]['origin'] == cell: + translated_cell += texts[text_index]['text'] + text_index += 1 + translated_row.append(translated_cell) + origin_row.append(cell) + else: + translated_row.append(cell) + origin_row.append(cell) + + if only_trans_text: + writer.writerow(translated_row) + else: + writer.writerow(origin_row) + writer.writerow(translated_row) + + translated_row = [] + origin_row = [] + + except Exception as e: + print(f"无法写入CSV文件 {trans['target_file']}: {e}") + return False + + end_time = datetime.datetime.now() + spend_time = common.display_spend(start_time, end_time) + to_translate.complete(trans, text_count, spend_time) + return True + +def split_cell(cell, max_length): + """将单元格内容分割成多个部分,每部分不超过 max_length 字符""" + parts = [] + current_part = "" + + words = cell.split() + for word in words: + if len(current_part) + len(word) + 1 > max_length: + parts.append(current_part.strip()) + current_part = word + else: + current_part += " " + word if current_part else word + + if current_part: + parts.append(current_part.strip()) + + return parts + +def check_text(text): + return text is not None and len(text) > 0 and not common.is_all_punc(text) diff --git a/app/translate/db.py b/app/translate/db.py new file mode 100644 index 0000000000000000000000000000000000000000..36d9154068e51e51d80296b46d2c7c8906a68e21 --- /dev/null +++ b/app/translate/db.py 
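# Note on csv_handle.start above: the worker threads are spawned with target=translate.get,
# but only to_translate and common are imported in that module, so `translate` is undefined
# there. Every other handler in this patch drives its workers through to_translate.get, so
# the intended call is presumably:
#     thread = threading.Thread(target=to_translate.get, args=(trans, event, texts, run_index))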
@@ -0,0 +1,105 @@ +import sqlite3 +from urllib.parse import urlparse +import pymysql +import os +from dotenv import load_dotenv, find_dotenv +from threading import Lock +_ = load_dotenv(find_dotenv()) # read local .env file + + + +def get_conn61(): + try: + # 获取并清理数据库路径 + sqlite_db_path = os.environ.get('PROD_DATABASE_URL') + if not sqlite_db_path: + raise ValueError("Database URL not found in environment variables.") + + # 移除 'sqlite:///' 前缀 + if sqlite_db_path.startswith('sqlite:///'): + sqlite_db_path = sqlite_db_path[len('sqlite:///'):] + + # 连接到SQLite数据库 + conn = sqlite3.connect(sqlite_db_path) + return conn + except Exception as e: + print(f"Error connecting to database: {e}") + raise + + + +def get_conn(): + try: + # 获取数据库 URL + db_url = os.environ.get('PROD_DATABASE_URL') + if not db_url: + raise ValueError("Database URL not found in environment variables.") + + # 判断是否是 SQLite 链接 + if db_url.startswith('sqlite:///'): + # 保留原有的 SQLite 逻辑 + sqlite_db_path = db_url[len('sqlite:///'):] + conn = sqlite3.connect(sqlite_db_path) + return conn + + # 判断是否是 MySQL 链接 + elif db_url.startswith('mysql+pymysql://'): + # 解析 MySQL URL + parsed_url = urlparse(db_url) + mysql_host = parsed_url.hostname + mysql_port = parsed_url.port or 3306 # 默认端口 3306 + mysql_db = parsed_url.path.lstrip('/') + mysql_user = parsed_url.username + mysql_password = parsed_url.password + + # 连接到 MySQL 数据库 + conn = pymysql.connect( + host=mysql_host, + port=mysql_port, + user=mysql_user, + password=mysql_password, + db=mysql_db, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + return conn + + else: + raise ValueError(f"Unsupported database URL: {db_url}") + + except Exception as e: + print(f"Error connecting to database: {e}") + raise + + +def execute(sql, *params): + conn = get_conn() + lock=Lock() + lock.acquire() + cursor=conn.cursor() + try: + cursor.execute(sql, params) + conn.commit() + lock.release() + cursor.close() + conn.close() + except: + lock.release() + conn.rollback() + + +def get(sql, *params): + conn=get_conn() + lock=Lock() + lock.acquire() + try: + cursor=conn.cursor(cursor=pymysql.cursors.DictCursor) + cursor.execute(sql, params) + result=cursor.fetchone() + lock.release() + cursor.close() + conn.close() + return result + except: + lock.release() + return [] diff --git a/app/translate/excel.py b/app/translate/excel.py new file mode 100644 index 0000000000000000000000000000000000000000..5ebd85b650e7374e8c6cb22e75ab9adf3d2a8c95 --- /dev/null +++ b/app/translate/excel.py @@ -0,0 +1,103 @@ +import threading +import openpyxl +from . import to_translate +from . 
import common +import os +import sys +import time +import datetime + +def start(trans): + # 允许的最大线程 + threads=trans['threads'] + if threads is None or int(threads)<0: + max_threads=10 + else: + max_threads=int(threads) + # 当前执行的索引位置 + run_index=0 + start_time = datetime.datetime.now() + wb = openpyxl.load_workbook(trans['file_path']) + sheets = wb.get_sheet_names() + texts=[] + for sheet in sheets: + ws = wb.get_sheet_by_name(sheet) + read_row(ws.rows, texts) + + # print(texts) + max_run=max_threads if len(texts)>max_threads else len(texts) + before_active_count=threading.activeCount() + event=threading.Event() + while run_index<=len(texts)-1: + if threading.activeCount()0: + item=texts.pop(0) + text_count+=item['count'] + cell.value=item['text'] + # if text=="": + # text=value + # else: + # text=text+"\n"+value + # if text!=None and not common.is_all_punc(text): + # item=texts.pop(0) + # values=item['text'].split("\n") + # text_count+=item['count'] + # for cell in row: + # value=cell.value + # if value!=None and not common.is_all_punc(value): + # if len(values)>0: + # cell.value=values.pop(0) + return text_count + + + diff --git a/app/translate/gptpdf.py b/app/translate/gptpdf.py new file mode 100644 index 0000000000000000000000000000000000000000..acdc9e5f74943dfc943a3dec53d5d2f1e0f84e4d --- /dev/null +++ b/app/translate/gptpdf.py @@ -0,0 +1,92 @@ +import logging +import os +import datetime +from typing import Optional, Dict, Tuple, List + +from app.translate import to_translate +from .pdf_parse import parse_pdf # 假设原始PDF处理模块位于当前目录的gptpdf.py中 + +def start(trans: Dict) -> bool: + """ + PDF翻译任务启动方法 + 参数结构示例: + trans = { + 'id': 任务ID, + 'file_path': 源文件路径, + 'target_file': 目标文件路径, + 'api_key': OpenAI API密钥, + 'base_url': API基础地址, + 'model': 模型名称, + 'output_dir': 输出目录, + 'verbose': 是否保留中间文件, + 'temperature': 温度参数, + 'max_tokens': 最大token数, + 'top_p': top_p参数, + 'frequency_penalty': 频率惩罚参数, + 'run_complete': 是否调用完成回调, + # ...其他参数 + } + """ + try: + # 初始化输出目录 + output_dir = trans.get('output_dir', './temp_pdf') + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # 记录任务开始时间 + start_time = datetime.datetime.now() + + # 调用 parse_pdf 处理PDF + content, image_paths = parse_pdf( + pdf_path=trans['file_path'], + output_dir=output_dir, + prompt=None, # 使用默认提示词 + api_key=trans['api_key'], + base_url=trans['base_url']+'/v1', + model=trans['model'], + verbose=trans.get('verbose', False), + gpt_worker=int(trans.get('threads', 1)), # 默认使用 1 个线程 + temperature=trans.get('temperature', 0.5), + max_tokens=trans.get('max_tokens', 1000), + top_p=trans.get('top_p', 0.9), + frequency_penalty=trans.get('frequency_penalty', 1) + ) + + # 保存最终结果 + save_final_result(content, trans['target_file']) + + # 清理临时文件 + if not trans.get('verbose', False): + cleanup_temp_files(output_dir, image_paths) + + # 计算耗时 + end_time = datetime.datetime.now() + spend_time = (end_time - start_time).total_seconds() + + # 任务完成处理 + if trans.get('run_complete'): + to_translate.complete(trans, len(content), spend_time) + + return True + + except Exception as e: + error_msg = f"PDF处理失败: {str(e)}" + to_translate.error(trans['id'], error_msg) + return False + +def save_final_result(content: str, target_path: str) -> None: + """保存最终结果""" + with open(target_path, 'w', encoding='utf-8') as f: + f.write(content) + logging.info(f"结果已保存至:{target_path}") + +def cleanup_temp_files(output_dir: str, image_paths: List[str]) -> None: + """清理临时文件""" + for path in image_paths: + if os.path.exists(path): + os.remove(path) + if 
os.path.exists(output_dir): + os.rmdir(output_dir) + + + diff --git "a/app/translate/gptpdf\345\244\207\344\273\275.py" "b/app/translate/gptpdf\345\244\207\344\273\275.py" new file mode 100644 index 0000000000000000000000000000000000000000..4bf2e01259528f1c1467fd7e74bddaaba8a59d44 --- /dev/null +++ "b/app/translate/gptpdf\345\244\207\344\273\275.py" @@ -0,0 +1,371 @@ +import os +import re +from typing import List, Tuple, Optional, Dict +import logging +import threading +# from . import to_translate +import datetime +from . import common, to_translate +import time +import fitz # PyMuPDF +import shapely.geometry as sg +from shapely.geometry.base import BaseGeometry +from shapely.validation import explain_validity +import markdown +import pdfkit +import codecs +# from weasyprint import HTML +from pymdownx import superfences +from bs4 import BeautifulSoup +from PIL import Image + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +# This Default Prompt Using Chinese and could be changed to other languages. + +DEFAULT_PROMPT = """使用markdown语法,将图片中识别到的文字转换为markdown格式输出。你必须做到: +1. 输出和使用识别到的图片的相同的语言,例如,识别到英语的字段,输出的内容必须是英语。 +2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。 +3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式、忽略掉长直线、忽略掉页码。 +再次强调,不要解释和输出无关的文字,直接输出图片中的内容。 +""" +DEFAULT_RECT_PROMPT = """图片中用红色框和名称(%s)标注出了一些区域。如果区域是表格或者图片,使用 ![]() 的形式插入到输出内容中,否则直接输出文字内容。 +""" +DEFAULT_ROLE_PROMPT = """你是一个PDF文档解析器,使用markdown和latex语法输出图片的内容。 +""" + + +def _is_near(rect1: BaseGeometry, rect2: BaseGeometry, distance: float = 20) -> bool: + """ + Check if two rectangles are near each other if the distance between them is less than the target. + """ + return rect1.buffer(0.1).distance(rect2.buffer(0.1)) < distance + + +def _is_horizontal_near(rect1: BaseGeometry, rect2: BaseGeometry, distance: float = 100) -> bool: + """ + Check if two rectangles are near horizontally if one of them is a horizontal line. + """ + result = False + if abs(rect1.bounds[3] - rect1.bounds[1]) < 0.1 or abs(rect2.bounds[3] - rect2.bounds[1]) < 0.1: + if abs(rect1.bounds[0] - rect2.bounds[0]) < 0.1 and abs(rect1.bounds[2] - rect2.bounds[2]) < 0.1: + result = abs(rect1.bounds[3] - rect2.bounds[3]) < distance + return result + + +def _union_rects(rect1: BaseGeometry, rect2: BaseGeometry) -> BaseGeometry: + """ + Union two rectangles. + """ + return sg.box(*(rect1.union(rect2).bounds)) + + +def _merge_rects(rect_list: List[BaseGeometry], distance: float = 20, horizontal_distance: Optional[float] = None) -> \ + List[BaseGeometry]: + """ + Merge rectangles in the list if the distance between them is less than the target. + """ + merged = True + while merged: + merged = False + new_rect_list = [] + while rect_list: + rect = rect_list.pop(0) + for other_rect in rect_list: + if _is_near(rect, other_rect, distance) or ( + horizontal_distance and _is_horizontal_near(rect, other_rect, horizontal_distance)): + rect = _union_rects(rect, other_rect) + rect_list.remove(other_rect) + merged = True + new_rect_list.append(rect) + rect_list = new_rect_list + return rect_list + + +def _adsorb_rects_to_rects(source_rects: List[BaseGeometry], target_rects: List[BaseGeometry], distance: float = 10) -> \ + Tuple[List[BaseGeometry], List[BaseGeometry]]: + """ + Adsorb a set of rectangles to another set of rectangles. 
+ """ + new_source_rects = [] + for text_area_rect in source_rects: + adsorbed = False + for index, rect in enumerate(target_rects): + if _is_near(text_area_rect, rect, distance): + rect = _union_rects(text_area_rect, rect) + target_rects[index] = rect + adsorbed = True + break + if not adsorbed: + new_source_rects.append(text_area_rect) + return new_source_rects, target_rects + + +def _parse_rects(page: fitz.Page) -> List[Tuple[float, float, float, float]]: + """ + Parse drawings in the page and merge adjacent rectangles. + """ + + # 提取画的内容 + drawings = page.get_drawings() + + # 忽略掉长度小于30的水平直线 + is_short_line = lambda x: abs(x['rect'][3] - x['rect'][1]) < 1 and abs(x['rect'][2] - x['rect'][0]) < 30 + drawings = [drawing for drawing in drawings if not is_short_line(drawing)] + + # 转换为shapely的矩形 + rect_list = [sg.box(*drawing['rect']) for drawing in drawings] + + # 提取图片区域 + images = page.get_image_info() + image_rects = [sg.box(*image['bbox']) for image in images] + + # 合并drawings和images + rect_list += image_rects + + merged_rects = _merge_rects(rect_list, distance=10, horizontal_distance=100) + merged_rects = [rect for rect in merged_rects if explain_validity(rect) == 'Valid Geometry'] + + # 将大文本区域和小文本区域分开处理: 大文本相小合并,小文本靠近合并 + is_large_content = lambda x: (len(x[4]) / max(1, len(x[4].split('\n')))) > 5 + small_text_area_rects = [sg.box(*x[:4]) for x in page.get_text('blocks') if not is_large_content(x)] + large_text_area_rects = [sg.box(*x[:4]) for x in page.get_text('blocks') if is_large_content(x)] + _, merged_rects = _adsorb_rects_to_rects(large_text_area_rects, merged_rects, distance=0.1) # 完全相交 + _, merged_rects = _adsorb_rects_to_rects(small_text_area_rects, merged_rects, distance=5) # 靠近 + + # 再次自身合并 + merged_rects = _merge_rects(merged_rects, distance=10) + + # 过滤比较小的矩形 + merged_rects = [rect for rect in merged_rects if rect.bounds[2] - rect.bounds[0] > 20 and rect.bounds[3] - rect.bounds[1] > 20] + + return [rect.bounds for rect in merged_rects] + + +def _parse_pdf_to_images(pdf_path: str, output_dir: str = './') -> List[Tuple[str, List[str]]]: + """ + Parse PDF to images and save to output_dir. 
+ """ + # 打开PDF文件 + pdf_document = fitz.open(pdf_path) + image_infos = [] + + for page_index, page in enumerate(pdf_document): + logging.info(f'parse page: {page_index}') + rect_images = [] + rects = _parse_rects(page) + for index, rect in enumerate(rects): + fitz_rect = fitz.Rect(rect) + # 保存页面为图片 + pix = page.get_pixmap(clip=fitz_rect, matrix=fitz.Matrix(4, 4)) + name = f'{page_index}_{index}.png' + pix.save(os.path.join(output_dir, name)) + rect_images.append(name) + # # 在页面上绘制红色矩形 + big_fitz_rect = fitz.Rect(fitz_rect.x0 - 1, fitz_rect.y0 - 1, fitz_rect.x1 + 1, fitz_rect.y1 + 1) + # 空心矩形 + page.draw_rect(big_fitz_rect, color=(1, 0, 0), width=1) + # 画矩形区域(实心) + # page.draw_rect(big_fitz_rect, color=(1, 0, 0), fill=(1, 0, 0)) + # 在矩形内的左上角写上矩形的索引name,添加一些偏移量 + text_x = fitz_rect.x0 + 2 + text_y = fitz_rect.y0 + 10 + text_rect = fitz.Rect(text_x, text_y - 9, text_x + 80, text_y + 2) + # 绘制白色背景矩形 + page.draw_rect(text_rect, color=(1, 1, 1), fill=(1, 1, 1)) + # 插入带有白色背景的文字 + page.insert_text((text_x, text_y), name, fontsize=10, color=(1, 0, 0)) + page_image_with_rects = page.get_pixmap(matrix=fitz.Matrix(3, 3)) + page_image = os.path.join(output_dir, f'{page_index}.png') + page_compress_image = os.path.join(output_dir, f'{page_index}-compress.png') + page_image_with_rects.save(page_image) + compress_image(page_image,page_compress_image) + # image_infos.append((page_image, rect_images)) + image_infos.append({'text': page_image,'type':'pdf_img', 'complete': False, 'content': ''}) + + pdf_document.close() + return image_infos + + +def _gpt_parse_images( + image_infos: List[Tuple[str, List[str]]], + prompt_dict: Optional[Dict] = None, + **args +) -> str: + """ + Parse images to markdown content. + """ + if isinstance(prompt_dict, dict) and 'prompt' in prompt_dict: + prompt = prompt_dict['prompt'] + logging.info("prompt is provided, using user prompt.") + else: + prompt = DEFAULT_PROMPT + logging.info("prompt is not provided, using default prompt.") + if isinstance(prompt_dict, dict) and 'rect_prompt' in prompt_dict: + rect_prompt = prompt_dict['rect_prompt'] + logging.info("rect_prompt is provided, using user prompt.") + else: + rect_prompt = DEFAULT_RECT_PROMPT + logging.info("rect_prompt is not provided, using default prompt.") + if isinstance(prompt_dict, dict) and 'role_prompt' in prompt_dict: + role_prompt = prompt_dict['role_prompt'] + logging.info("role_prompt is provided, using user prompt.") + else: + role_prompt = DEFAULT_ROLE_PROMPT + logging.info("role_prompt is not provided, using default prompt.") + + for image_index,image_info in enumerate(image_infos): + user_prompt = prompt + # if rect_images: + # user_prompt += rect_prompt + ', '.join(rect_images) + image_infos[image_index]['user_prompt']=user_prompt + + + + # output_path = os.path.join(output_dir, 'output.md') + # with open(output_path, 'w', encoding='utf-8') as f: + # f.write('\n\n'.join(contents)) + + # return '\n\n'.join(contents) + +def start(trans): + # 从 trans 中获取文件路径和输出目录 + pdf_path = trans['file_path'] + output_dir = trans['target_path_dir'] + + # 允许的最大线程 + threads = trans.get('threads', 10) + max_threads = max(1, int(threads)) + + # 当前执行的索引位置 + run_index = 0 + start_time = datetime.datetime.now() + + # 解析 PDF 文件 + image_infos = _parse_pdf_to_images(pdf_path, output_dir=output_dir) + + _gpt_parse_images( + image_infos=image_infos, + prompt_dict=None, + ) + + trans['role_prompt']=DEFAULT_ROLE_PROMPT + + # 使用 threading 方式处理 + max_run = min(max_threads, len(image_infos)) + before_active_count = threading.activeCount() + 
event = threading.Event() + + while run_index <= len(image_infos) - 1: + if threading.activeCount() < max_run + before_active_count: + if not event.is_set(): + thread = threading.Thread(target=to_translate.get, args=(trans, event, image_infos, run_index)) + thread.start() + run_index += 1 + else: + return False + + while True: + complete = True + for image_info in image_infos: + if not image_info['complete']: + complete = False + if complete: + break + else: + time.sleep(1) + + # print(image_infos) + # 处理完成后,写入结果 + try: + # c = canvas.Canvas(trans['target_file'], pagesize=letter) + # text = c.beginText(40, 750) # 设置文本开始的位置 + # text.setFont("Helvetica", 12) # 设置字体和大小 + md_file = os.path.join(output_dir, 'output.md') + with open(md_file, 'w', encoding='utf-8') as file: + for image_info in image_infos: + # text.textLine(image_info['text']) # 添加文本行 + # text.textLine("") # 添加空行作为分隔 + # write_pdf(c, image_info['text']); + file.write(image_info['text'] + '\n') + # write_to_pdf(md_file, trans['target_file']) + html_to_pdf(output_dir, md_file, trans['target_file']) + # c.save() # 保存 PDF 文件 + except Exception as e: + print(f"生成pdf失败: {md_file}: {e}") + return False + + end_time = datetime.datetime.now() + spend_time = common.display_spend(start_time, end_time) + # translate.complete(trans, len(image_infos), spend_time) + return True + +def compress_image(image_file,compress_image_file): + img=Image.open(image_file) + img_resized=img.resize((img.width//2, img.height//2), resample=Image.Resampling.NEAREST) + img_resized.save(compress_image_file,quality=30) + + +def html_to_pdf(output_dir, md_file, pdf_file): + extensions = [ + 'toc', # 目录,[toc] + 'extra', # 缩写词、属性列表、释义列表、围栏式代码块、脚注、在HTML的Markdown、表格 + ] + third_party_extensions = [ + 'mdx_math', # KaTeX数学公式,$E=mc^2$和$$E=mc^2$$ + 'markdown_checklist.extension', # checklist,- [ ]和- [x] + 'pymdownx.magiclink', # 自动转超链接, + 'pymdownx.caret', # 上标下标, + 'pymdownx.superfences', # 多种块功能允许嵌套,各种图表 + 'pymdownx.betterem', # 改善强调的处理(粗体和斜体) + 'pymdownx.mark', # 亮色突出文本 + 'pymdownx.highlight', # 高亮显示代码 + 'pymdownx.tasklist', # 任务列表 + 'pymdownx.tilde', # 删除线 + ] + extensions.extend(third_party_extensions) + extension_configs = { + 'mdx_math': { + 'enable_dollar_delimiter': True # 允许单个$ + }, + 'pymdownx.superfences': { + "custom_fences": [ + { + 'name': 'mermaid', # 开启流程图等图 + 'class': 'mermaid', + 'format': superfences.fence_div_format + } + ] + }, + 'pymdownx.highlight': { + 'linenums': True, # 显示行号 + 'linenums_style': 'pymdownx-inline' # 代码和行号分开 + }, + 'pymdownx.tasklist': { + 'clickable_checkbox': True, # 任务列表可点击 + } + } + with codecs.open(md_file, "r", encoding="utf-8") as f: + md_content = f.read() + + html_file = os.path.join(output_dir, 'output.html') + html_final_file = os.path.join(output_dir, 'output-final.html') + html_content = markdown.markdown(md_content, extensions=extensions, extension_configs=extension_configs) + with codecs.open(html_file, "w", encoding="utf-8") as f: + # 加入文件头防止中文乱码 + f.write('') + f.write('') + f.write(html_content) + + + # 优化html中的图片信息 + with codecs.open(html_file, "r", encoding="utf-8") as f: + soup = BeautifulSoup(f, features="lxml") + image_content = soup.find_all("img") + for i in image_content: + i["style"] = "max-width:100%; overflow:hidden;" + with codecs.open(html_final_file, "w", encoding="utf-8") as g: + g.write(soup.prettify()) + + pdfkit.from_file(html_final_file, pdf_file) + diff --git a/app/translate/main.py b/app/translate/main.py new file mode 100644 index 
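# Note on html_to_pdf above: pdfkit.from_file is a thin wrapper around the wkhtmltopdf
# binary, so the Markdown -> HTML -> PDF step assumes wkhtmltopdf is installed and on
# PATH inside the container, and the markdown/pymdownx/mdx_math extensions listed in
# `extensions` have to be installed for markdown.markdown(...) to load them.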
0000000000000000000000000000000000000000..72f125a7bad11a312fa46154a12d7364698e57b5 --- /dev/null +++ b/app/translate/main.py @@ -0,0 +1,139 @@ +import threading +import openai +import os +import sys +import time +import getopt +from . import to_translate +from . import word +import excel +import powerpoint +import pdf +import gptpdf +import txt +import csv_handle +import md +import pymysql +import db +from . import common +import traceback +from . import rediscon + +# 当前正在执行的线程 +run_threads=0 + +def main(): + global run_threads + # 允许的最大线程 + max_threads=10 + # 当前执行的索引位置 + run_index=0 + # 是否保留原文 + keep_original=False + # 要翻译的文件路径 + file_path='' + # 翻译后的目标文件路径 + target_file='' + uuid=sys.argv[1] + storage_path=sys.argv[2] + trans=db.get("select * from translate where uuid=%s", uuid) + translate_id=trans['id'] + origin_filename=trans['origin_filename'] + origin_filepath=trans['origin_filepath'] + target_filepath=trans['target_filepath'] + api_key=trans['api_key'] + api_url=trans['api_url'] + comparison=get_comparison(trans['comparison_id']) + prompt=get_prompt(trans['prompt_id'], comparison) + if comparison: + prompt = ( + "术语对照表:\n" + f"{comparison}\n" + "请按照以下规则进行翻译:\n" + "1. 遇到术语时,请使用术语对照表中的对应翻译,无论翻译成什么语言。\n" + "2. 未在术语对照表中的文本,请遵循翻译说明进行翻译。\n" + "3. 确保翻译结果不包含原文或任何解释。\n" + "翻译说明:\n" + f"{prompt}" + ) + trans['prompt']=prompt + + file_path=storage_path+origin_filepath + target_file=storage_path+target_filepath + + origin_path_dir=os.path.dirname(file_path) + target_path_dir=os.path.dirname(target_file) + + if not os.path.exists(origin_path_dir): + os.makedirs(origin_path_dir, mode=0o777, exist_ok=True) + + if not os.path.exists(target_path_dir): + os.makedirs(target_path_dir, mode=0o777, exist_ok=True) + + trans['file_path']=file_path + trans['target_file']=target_file + trans['storage_path']=storage_path + trans['target_path_dir']=target_path_dir + extension = origin_filename[origin_filename.rfind('.'):] + trans['extension']=extension + trans['run_complete']=True + item_count=0 + spend_time='' + try: + status=True + # 设置OpenAI API + to_translate.init_openai(api_url, api_key) + if extension=='.docx' or extension == '.doc': + status=word.start(trans) + elif extension=='.xls' or extension == '.xlsx': + status=excel.start(trans) + elif extension=='.ppt' or extension == '.pptx': + status=powerpoint.start(trans) + elif extension == '.pdf': + if pdf.is_scanned_pdf(trans['file_path']): + status=gptpdf.start(trans) + else: + status=pdf.start(trans) + elif extension == '.txt': + status=txt.start(trans) + elif extension == '.csv': + status=csv_handle.start(trans) + elif extension == '.md': + status=md.start(trans) + if status: + print("success") + #before_active_count=threading.activeCount() + #mredis.decr(api_url,threading_num-before_active_count) + # print(item_count + ";" + spend_time) + else: + #before_active_count=threading.activeCount() + #mredis.decr(api_url,threading_num-before_active_count) + print("翻译出错了") + except Exception as e: + to_translate.error(translate_id, str(e)) + exc_type, exc_value, exc_traceback = sys.exc_info() + line_number = exc_traceback.tb_lineno # 异常抛出的具体行号 + print(f"Error occurred on line: {line_number}") + #before_active_count=threading.activeCount() + #mredis.set(api_url,threading_num-before_active_count) + print(e) + +def get_prompt(prompt_id, comparison): + if prompt_id>0: + prompt=db.get("select content from prompt where id=%s and deleted_flag='N'", prompt_id) + if prompt and len(prompt['content'])>0: + return prompt['content'] + + prompt=db.get("select value from 
setting where `group`='other_setting' and alias='prompt'") + return prompt['value'] + +def get_comparison(comparison_id): + if comparison_id>0: + comparison=db.get("select content from comparison where id=%s and deleted_flag='N'", comparison_id) + if comparison and len(comparison['content'])>0: + return comparison['content'].replace(',',':').replace(';','\n'); + +if __name__ == '__main__': + main() + + diff --git a/app/translate/md.py b/app/translate/md.py new file mode 100644 index 0000000000000000000000000000000000000000..e54d08239d74aa69ee289ec4c00bf502578bdf90 --- /dev/null +++ b/app/translate/md.py @@ -0,0 +1,171 @@ +import os +import threading +from . import to_translate +from . import common +import datetime +import time +import re + +def start(trans): + # 允许的最大线程 + threads=trans['threads'] + if threads is None or int(threads)<0: + max_threads=10 + else: + max_threads=int(threads) + # 当前执行的索引位置 + run_index=0 + start_time = datetime.datetime.now() + + try: + with open(trans['file_path'], 'r', encoding='utf-8') as file: + content = file.read() + except Exception as e: + print(f"无法读取文件 {trans['file_path']}: {e}") + return False + + trans_type=trans['type'] + keepBoth=True + if trans_type=="trans_text_only_inherit" or trans_type=="trans_text_only_new" or trans_type=="trans_all_only_new" or trans_type=="trans_all_only_inherit": + keepBoth=False + + # 按段落分割内容,始终使用换行符分隔 + paragraphs = content.split('\n') # 假设段落之间用换行符分隔 + # 支持最多单词量 + max_word = 1000 + texts = [] + current_text = "" # 用于累加当前段落 + + for paragraph in paragraphs: + if check_text(paragraph) or paragraph.strip() == "": # 检查段落是否有效或为空 + # if paragraph.strip() == "": + # # 如果是空行,直接加入到 texts + # texts.append({"text": "", "origin": "", "complete": True, "sub": False, "ext":"md"}) + # continue # 跳过后续处理,继续下一个段落 + + if keepBoth: + # 当 keepBoth 为 True 时,不累加 current_text + if len(paragraph) > max_word: + # 如果段落长度超过 max_word,进行拆分 + sub_paragraphs = split_paragraph(paragraph, max_word) + for sub_paragraph in sub_paragraphs: + # 直接将分段的内容追加到 texts + append_text(sub_paragraph, texts, True) + else: + # 如果段落长度不超过 max_word,直接加入 texts + append_text(paragraph, texts, False) + else: + # 当 keepBoth 为 False 时,处理 current_text 的逻辑 + if len(paragraph) > max_word: + # 如果当前累加的文本不为空,先将其追加到 texts + if current_text: + append_text(current_text, texts, False) + current_text = "" # 重置当前文本 + + # 分割段落并追加到 texts + sub_paragraphs = split_paragraph(paragraph, max_word) + for sub_paragraph in sub_paragraphs: + # 直接将分段的内容追加到 texts + append_text(sub_paragraph, texts, True) + else: + # 在追加之前判断是否超出 max_word + if len(current_text) + len(paragraph) > max_word: # 不再加1,因为我们要保留原有换行符 + # 如果超出 max_word,将 current_text 追加到 texts + append_text(current_text, texts, False) + current_text = "" # 重置当前文本 + + # 追加段落(保留原有换行符) + current_text += paragraph+"\n" # 直接追加段落,并加上换行符 + + # 在循环结束后,如果还有累加的文本,追加到 texts + append_text(current_text, texts, False) + # print(texts); + # exit() + max_run=max_threads if len(texts)>max_threads else len(texts) + before_active_count=threading.activeCount() + event=threading.Event() + while run_index<=len(texts)-1: + if threading.activeCount() max_length: + # 如果当前部分长度加上句子长度超过最大长度,保存当前部分 + parts.append(' '.join(current_part)) + current_part = [sentence] # 开始新的部分 + current_length = len(sentence) + else: + current_part.append(sentence) + current_length += len(sentence) + + # 添加最后一部分 + if current_part: + parts.append(' '.join(current_part)) + + return parts + +def append_text(text, texts, sub=False): + if check_text(text): + texts.append({"text": text, "origin": text, 
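    # Two notes on the modules above: main.py mixes relative imports (from . import word,
    # from . import to_translate) with bare ones (import excel, import powerpoint,
    # import db, ...); inside the app.translate package the bare form fails at import
    # time unless the package directory is put on sys.path, so the relative form is
    # presumably what was intended. In md.py, max_word = 1000 is described as a word
    # limit ("支持最多单词量") but is compared against len(paragraph) and
    # len(current_text), i.e. it is a character budget; the body of start()'s worker
    # loop and the head of split_paragraph were lost when the patch was rendered, and
    # only the tail survives (per-sentence accumulation into parts under max_length).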
"complete": False, "sub": sub, "ext":"md"}) + else: + texts.append({"text": "", "origin": "", "complete": True, "sub": sub, "ext":"md"}) + +def check_text(text): + return text!=None and text!="\n" and len(text)>0 and not common.is_all_punc(text) diff --git a/app/translate/newpdf.py b/app/translate/newpdf.py new file mode 100644 index 0000000000000000000000000000000000000000..030fadf8d913835a426bcca03789d86382ad3d51 --- /dev/null +++ b/app/translate/newpdf.py @@ -0,0 +1,346 @@ +from . import common +import datetime +import fitz +import os +import re +import shutil +import subprocess +import threading +import time +from . import to_translate + +def start(trans): + # 允许的最大线程 + threads = trans['threads'] + if threads is None or threads == "" or int(threads) < 0: + max_threads = 10 + else: + max_threads = int(threads) + # 当前执行的索引位置 + run_index = 0 + max_chars = 1000 + start_time = datetime.datetime.now() + # 创建PDF文件 + try: + src_pdf = fitz.open(trans['file_path']) + except Exception as e: + to_translate.error(trans['id'], "无法访问该文档") + return False + texts = [] + api_url = trans['api_url'] + trans_type = trans['type'] + if trans_type == "trans_text_only_inherit": + # 仅文字-保留原文-继承原版面 + read_block_text(src_pdf, texts) + elif trans_type == "trans_text_only_new" or trans_type == "trans_text_both_new": + # 仅文字-保留原文-重排 + read_block_text(src_pdf, texts) + elif trans_type == "trans_text_both_inherit": + # 仅文字-保留原文-重排/继承原版面 + read_block_text(src_pdf, texts) + elif trans_type == "trans_all_only_new": + # 全部内容-仅译文-重排版面 + read_block_text(src_pdf, texts) + elif trans_type == "trans_all_only_inherit": + # 全部内容-仅译文-重排版面/继承原版面 + read_block_text(src_pdf, texts) + elif trans_type == "trans_all_both_new": + # 全部内容-保留原文-重排版面 + read_block_text(src_pdf, texts) + elif trans_type == "trans_all_both_inherit": + # 全部内容-保留原文-继承原版面 + read_block_text(src_pdf, texts) + # print(texts) + # exit(); + uuid = trans['uuid'] + html_path = trans['storage_path'] + '/uploads/' + uuid + trans['html_path'] = html_path + read_page_images(src_pdf, texts) + max_run = max_threads if len(texts) > max_threads else len(texts) + event = threading.Event() + before_active_count = threading.activeCount() + while run_index <= len(texts) - 1: + if threading.activeCount() < max_run + before_active_count: + if not event.is_set(): + thread = threading.Thread(target=to_translate.get, args=(trans, event, texts, run_index)) + thread.start() + run_index += 1 + else: + return False + + while True: + if event.is_set(): + return False + complete = True + for text in texts: + if not text['complete']: + complete = False + if complete: + break + else: + time.sleep(1) + text_count = 0 + if trans_type == "trans_text_only_inherit": + # 仅文字-仅译文-继承原版面。 + write_block_text(src_pdf, texts, text_count, True) # DONE + elif trans_type == "trans_text_only_new": + # 仅文字-仅译文-重排 + write_block_text(src_pdf, texts, text_count, True) # DONE + elif trans_type == "trans_text_both_new": + # 仅文字-保留原文-重排 + write_block_both(src_pdf, texts, text_count, True) # DONE + elif trans_type == "trans_text_both_inherit": + # 仅文字-保留原文-继承原版面 + write_block_both(src_pdf, texts, text_count, True) # DONE + elif trans_type == "trans_all_only_new": + # 全部内容-仅译文-重排版面 + write_block_text(src_pdf, texts, text_count, False) # DONE + elif trans_type == "trans_all_only_inherit": + # 全部内容-仅译文-继承原版面 + write_block_text(src_pdf, texts, text_count, False) # DONE + elif trans_type == "trans_all_both_new": + # 全部内容-保留原文-重排版面 + write_block_both(src_pdf, texts, text_count, False) # DONE + elif trans_type == 
"trans_all_both_inherit": + # 全部内容-保留原文-继承原版面 + write_block_both(src_pdf, texts, text_count, False) # DONE + + end_time = datetime.datetime.now() + spend_time = common.display_spend(start_time, end_time) + to_translate.complete(trans, text_count, spend_time) + return True + + +def read_page_images(pages, texts): + for index, page in enumerate(pages): + html = page.get_text("xhtml") + images = re.findall(r"(data:image/\w+;base64,[^\"]+)", html) + for i, image in enumerate(images): + append_text(image, 'image', texts) + + +def read_block_text(pages, texts): + text = "" + for page in pages: + last_x0 = 0 + last_x1 = 0 + for block in page.get_text("blocks"): + current_x1 = block[2] + current_x0 = block[0] + # 对于每个文本块,分行并读取 + if block[5] == 0 or abs(current_x1 - last_x1) > 12 or abs(current_x0 - last_x0) > 12: + append_text(text, "text", texts) + text = block[4].replace("\n", "") + else: + text = text + (block[4].replace("\n", "")) + last_x1 = block[2] + last_x0 = block[0] + append_text(text, "text", texts) + + +def write_block_text(pages, newpdf, texts): + text = "" + for page in pages: + last_x0 = 0 + last_x1 = 0 + last_y0 = 0 + new_page = newpdf.new_page(width=page.rect.width, height=page.rect.height) + font = fitz.Font("helv") + for block in page.get_text("blocks"): + current_x1 = block[2] + current_x0 = block[0] + current_y0 = block[1] + # 对于每个文本块,分行并读取 + if block[5] == 0 or abs(current_x1 - last_x1) > 12 or abs(current_x0 - last_x0) > 12 and len(texts) > 0: + item = texts.pop(0) + trans_text = item.get("text", "") + new_page.insert_text((last_x0, last_y0), trans_text, fontsize=12, fontname="Helvetica", overlay=False) + text = block[4].replace("\n", "") + else: + text = text + (block[4].replace("\n", "")) + last_x1 = block[2] + last_x0 = block[0] + last_y0 = block[1] + if check_text(text) and len(texts): + new_page.insert_text((last_x0, last_y0), trans_text, fontsize=12, overlay=False) + + +def write_block_both(pages, newpdf, texts): + text = "" + old_text = "" + for page in pages: + last_x0 = 0 + last_x1 = 0 + last_y0 = 0 + new_page = newpdf.new_page(width=page.rect.width, height=page.rect.height) + old_page = newpdf.new_page(width=page.rect.width, height=page.rect.height) + font = fitz.Font("helv") + for block in page.get_text("blocks"): + current_x1 = block[2] + current_x0 = block[0] + current_y0 = block[1] + # 对于每个文本块,分行并读取 + if block[5] == 0 or abs(current_x1 - last_x1) > 12 or abs(current_x0 - last_x0) > 12 and len(texts) > 0: + item = texts.pop(0) + trans_text = item.get("text", "") + new_page.insert_text((last_x0, last_y0), trans_text, fontsize=12, fontname="Helvetica", overlay=False) + text = block[4].replace("\n", "") + old_page.insert_text((last_x0, last_y0), text, fontsize=12, fontname="Helvetica", overlay=False) + else: + text = text + (block[4].replace("\n", "")) + last_x1 = block[2] + last_x0 = block[0] + last_y0 = block[1] + if check_text(text) and len(texts): + new_page.insert_text((last_x0, last_y0), trans_text, fontsize=12, overlay=False) + old_page.insert_text((last_x0, last_y0), text, fontsize=12, fontname="Helvetica", overlay=False) + + +def write_page_text(pages, newpdf, texts): + for page in pages: + text = page.get_text("text") + new_page = newpdf.new_page(width=page.rect.width, height=page.rect.height) + if check_text(text) and len(texts) > 0: + item = texts.pop(0) + text = item.get("text", "") + new_page.insert_text((0, 0), text, fontsize=12, overlay=False) + + +def read_row(pages, texts): + text = "" + for page in pages: + # 获取页面的文本块 + for block in 
page.get_text("blocks"): + # 对于每个文本块,分行并读取 + if block[5] == 0: + append_text(text, 'text', texts) + text = block[4] + else: + text = text + block[4] + + +def write_row(newpdf, texts, page_width, page_height): + text_count = 0 + new_page = newpdf.new_page(width=page_width, height=page_height) + for text in texts: + print(text['text']) + # draw_text_avoid_overlap(new_page, text['text'],text['block'][0],text['block'][1], 16) + new_page.insert_text((text['block'][0], text['block'][1]), text['text'], fontsize=16) + return + + +def append_text(text, content_type, texts): + if check_text(text): + # print(text) + texts.append({"text": text, "type": content_type, "complete": False}) + + +def check_text(text): + return text != None and len(text) > 0 and not common.is_all_punc(text) + + +def draw_text_avoid_overlap(page, text, x, y, font_size): + """ + 在指定位置绘制文本,避免与现有文本重叠。 + """ + text_length = len(text) * font_size # 估算文本长度 + while True: + text_box = page.get_textbox((x, y, x + text_length, y + font_size)) + if not text_box: + break # 没有重叠的文本,退出循环 + y += font_size + 1 # 移动到下一个位置 + + page.insert_text((x, y), text, fontsize=font_size) + + +def draw_table(page, table_data, x, y, width, cell_height): + # 表格的列数 + cols = len(table_data[0]) + rows = len(table_data) + + # 绘制表格 + for i in range(rows): + for j in range(cols): + # 文字写入 + txt = table_data[i][j] + page.insert_text((x, y), txt) + # 绘制单元格边框 (仅边界线) + # 左边 + page.draw_line((x, y), (x + width / cols, y), width=0.5) + # 上边 + if i == 0: + page.draw_line((x, y), (x, y + cell_height), width=0.5) + # 右边 + if j == cols - 1: + page.draw_line((x + width / cols, y), (x + width / cols, y + cell_height), width=0.5) + # 下边 + if i == rows - 1: + page.draw_line((x, y + cell_height), (x + width / cols, y + cell_height), width=0.5) + # 移动到下一个单元格 + x += width / cols + # 移动到下一行 + x = 0 + y += cell_height + + +def wrap_text(text, width): + words = text.split(' ') + lines = [] + line = "" + for word in words: + if len(line.split(' ')) >= width: + lines.append(line) + line = "" + if len(line + word + ' ') <= width * len(word): + line += word + ' ' + else: + lines.append(line) + line = word + ' ' + if line: + lines.append(line) + return lines + + +def is_paragraph(block): + # 假设一个段落至少有两行 + if len(block) < 2: + return False + # 假设一个段落的行间隔较大 + if max([line.height for line in block]) / min([line.height for line in block]) > 1.5: + return True + return False + + +def is_next_line_continuation(page, current_line, next_line_index): + # 判断下一行是否是当前行的继续 + return abs(next_line_index - current_line) < 0.1 + + +def print_texts(texts): + for item in texts: + print(item.get("text")) + + +def is_scan_pdf(pages): + for index, page in enumerate(pages): + html = page.get_text("xhtml") + images = re.findall(r"(data:image/\w+;base64,[^\"]+)", html) + text = page.get_text() + if text == "" and len(images) > 0: + return True + else: + return False + + +def read_pdf_html(pages, texts, trans): + for index, page in enumerate(pages): + target_html = "{}-{}.html".format(trans['html_path'], page_index) + if os.path.exists(target_html): + os.remove(target_html) + dftohtml_path = shutil.which("pdftohtml") + if pdftohtml_path is None: + raise Exception("未安装pdftohtml") + subprocess.run([dftohtml_path, "-c", "-l", page_index, trans['file_path'], trans['html_path']]) + if not os.path.exists(target_html): + raise Exception("无法生成html") + # append_text(html,'text', texts) diff --git a/app/translate/newpdf2.py b/app/translate/newpdf2.py new file mode 100644 index 
0000000000000000000000000000000000000000..cd6b794719e6fcd85c983648483a02468fe4da5d --- /dev/null +++ b/app/translate/newpdf2.py @@ -0,0 +1,42 @@ +from pypdf import PdfWriter, PdfReader + +output = PdfWriter('out.pdf') +input1 = PdfReader("document1.pdf", "rb") + + # add page 1 from input1 to output document, unchanged +output.add_page(input1.get_page(0)) + + # add page 2 from input1, but rotated clockwise 90 degrees +output.add_page(input1.get_page(1).rotate(90)) + + # add page 3 from input1, rotated the other way: +output.add_page(input1.get_page(2).rotate(180)) +output.add_page(input1.get_page(3).rotate(270)) +output.write('out.pdf') + +from pypdf import PdfWriter, PdfReader +from pypdf.generic import ( + ArrayObject, + ContentStream, + DictionaryObject, + EncodedStreamObject, + FloatObject, + IndirectObject, + NameObject, + NullObject, + NumberObject, + PdfObject, + RectangleObject, + StreamObject, + TextStringObject, + is_null_or_none, +) +output = PdfWriter() +input1 = PdfReader("document1.pdf", "rb") + +page_nums=input1.get_num_pages() +for page_num in range(page_nums): + page= input1.get_page(page_num) + original_content=page.extract_text(extraction_mode="layout", layout_mode_strip_rotated=True) + print(page['/Contents']) + print(original_content) \ No newline at end of file diff --git a/app/translate/pdf.py b/app/translate/pdf.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3203d8db006683e2c92837861b398cabb42765 --- /dev/null +++ b/app/translate/pdf.py @@ -0,0 +1,801 @@ +import platform +import tempfile +import threading +import traceback + +import fitz +import re +from . import to_translate +from . import common +import io +import sys +import time +import datetime +from docx import Document +from docx.shared import Pt, RGBColor +# import pdfkit +import subprocess +import base64 +import pdf2docx +from . 
import word +import copy +from io import BytesIO +from PIL import Image,ImageDraw +import pytesseract +import uuid +from pdfdeal import Doc2X +# from weasyprint import HTML +import os +from docx2pdf import convert +import shutil +pytesseract.pytesseract.tesseract_cmd = r'/usr/local/bin/tesseract' + + +# -----word转pdf +def docxtopdf(docx_path, pdf_path): + # 如果目标 PDF 文件已存在,则删除 + # if os.path.exists(pdf_path): + # os.remove(pdf_path) + + # 确保目标目录存在 + target_path_dir = os.path.dirname(pdf_path) + if not os.path.exists(target_path_dir): + os.makedirs(target_path_dir, mode=0o777, exist_ok=True) + + # 根据操作系统选择方案 + if platform.system() == "Windows": + # Windows 方案:使用 pywin32 调用 Microsoft Word + try: + import win32com.client + word = win32com.client.Dispatch("Word.Application") + word.Visible = False # 不显示 Word 界面 + doc = word.Documents.Open(docx_path) + doc.SaveAs(pdf_path, FileFormat=17) # 17 是 PDF 格式 + doc.Close() + word.Quit() + print("转换成功!") + except Exception as e: + print(f"Windows 方案转换失败: {e}") + else: + # Linux/macOS 方案:使用 unoconv + sys.path.append("/usr/local/bin") # 添加 unoconv 可能的路径 + unoconv_path = shutil.which("unoconv") + if unoconv_path is None: + raise Exception("未安装 unoconv,请先安装 unoconv 或 LibreOffice") + try: + command = [unoconv_path, "-f", "pdf", "-o", pdf_path, docx_path] + print("{} -f pdf -o {} {}".format(unoconv_path, pdf_path, docx_path)) + subprocess.run(command) + print("转换成功!") + except subprocess.CalledProcessError as e: + print(f"Linux 方案转换失败: {e}") + +def start11(trans): + texts=[] + src_pdf = fitz.open(trans['file_path']) + # print(is_scan_pdf(src_pdf)) + # exit() + # if is_scan_pdf(src_pdf): + start_time = datetime.datetime.now() + origin_docx_path=os.path.dirname(trans['file_path'])+"/"+trans['uuid']+".docx" + target_docx_path=os.path.dirname(trans['file_path'])+"/"+trans['uuid']+"-translated.docx" + target_pdf_path=os.path.dirname(trans['file_path'])+"/"+trans['uuid']+".pdf" + # target_pdf_path = trans['file_path'] + # target_docx_path=re.sub(r"\.pdf",".docx",trans['target_file'], flags=re.I) + # pdf_path=re.sub(r"\.pdf",".docx",trans['file_path'], flags=re.I) + # print(target_pdf_path+"\n") + # print(trans['storage_path']+"\n") + # print(trans['target_file']+"\n") + # print(os.path.join(trans['storage_path'], trans['target_filepath'])+"\n") + pdftodocx(trans['file_path'], origin_docx_path) + word_trans=copy.copy(trans) + word_trans['file_path']=origin_docx_path + word_trans['target_file']=target_docx_path + word_trans['run_complete']=False + word_trans['extension']='.docx' + text_count=0 + + if word.start(word_trans): + # print("word done") + docxtopdf(target_docx_path, target_pdf_path) + shutil.move(target_pdf_path, trans['target_file']) + end_time = datetime.datetime.now() + spend_time=common.display_spend(start_time, end_time) + to_translate.complete(trans,text_count,spend_time) + return True + # return False + + uuid=trans['uuid'] + html_path=trans['storage_path']+'/uploads/'+uuid + trans['html_path']=html_path + # read_pdf_html(trans['file_path'], html_path) + # print(trans['storage_path']+'/uploads/pdf.html') + # exit() + # 允许的最大线程 + # print(trans) + # wkhtmltopdf_bin=common.find_command_location("wkhtmltopdf") + threads=trans['threads'] + if threads is None or int(threads)<0: + max_threads=10 + else: + max_threads=int(threads) + # 当前执行的索引位置 + run_index=0 + start_time = datetime.datetime.now() + # print(f'Source pdf file: {} \n', trans['file_path']) + + read_page_images(src_pdf, texts) + + text_count=0 + # translate.get_models() + # exit() + # 
read_page_html(src_pdf, texts, trans) + # read_pdf_html(src_pdf, texts, trans) + pdftohtml(trans['file_path'], html_path, texts) + src_pdf.close() + + # print(texts) + # exit() + + max_run=max_threads if len(texts)>max_threads else len(texts) + event=threading.Event() + before_active_count=threading.activeCount() + while run_index<=len(texts)-1: + if threading.activeCount() max_threads else len(texts) + event = threading.Event() + before_active_count = threading.activeCount() + while run_index <= len(texts) - 1: + if threading.activeCount() < max_run + before_active_count: + if not event.is_set(): + # print("run_index:",run_index) + thread = threading.Thread(target=translate.get, + args=(trans, event, texts, run_index)) + thread.start() + run_index += 1 + else: + return False + + while True: + if event.is_set(): + return False + complete = True + for text in texts: + if not text['complete']: + complete = False + if complete: + break + else: + time.sleep(1) + + # print(texts) + + write_to_html_file(html_path, texts) + # config = pdfkit.configuration(wkhtmltopdf="/usr/local/bin/wkhtmltopdf") + # with open(html_path) as f: + # pdfkit.from_file(f, trans['target_file'],options={"enable-local-file-access":True}, configuration=config) + + # print(trans['target_file']) + + end_time = datetime.datetime.now() + spend_time = common.display_spend(start_time, end_time) + to_translate.complete(trans, text_count, spend_time) + return True + +# ------------------------------- +# def read_to_html(pages): + +def read_page_html(pages, texts, trans): + storage_path=trans['storage_path'] + uuid=trans['uuid'] + if is_scan_pdf(pages): + for index,page in enumerate(pages): + html=page.get_text("xhtml") + images=re.findall(r"(data:image/\w+;base64,[^\"]+)", html) + for i,image in enumerate(images): + append_text(image, 'image', texts) + + else: + for index,page in enumerate(pages): + html=page.get_text("xhtml") + # images=re.findall(r"(data:image/\w+;base64,[^\"]+)", html) + # for i,image in enumerate(images): + append_text(html,'text', texts) + +def read_page_images(pages, texts): + for index,page in enumerate(pages): + html=page.get_text("xhtml") + images=re.findall(r"(data:image/\w+;base64,[^\"]+)", html) + for i,image in enumerate(images): + append_text(image, 'image', texts) + +def write_to_html_file(html_path,texts): + with open(html_path, 'w+') as f: + f.write('') + for item in texts: + f.write(item.get("text", "")) + f.write('') + f.close() + +def read_block_text(pages,texts): + text="" + for page in pages: + last_x0=0 + last_x1=0 + html=page.get_text("html") + with open("test.html",'a+') as f: + f.write(html) + f.close() + exit() + for block in page.get_text("blocks"): + current_x1=block[2] + current_x0=block[0] + # 对于每个文本块,分行并读取 + if block[5]==0 or abs(current_x1-last_x1)>12 or abs(current_x0-last_x0)>12: + append_text(text, "text", texts) + text=block[4].replace("\n","") + else: + text=text+(block[4].replace("\n","")) + last_x1=block[2] + last_x0=block[0] + append_text(text, "text", texts) + +def write_block_text(pages,newpdf,texts): + text="" + for page in pages: + last_x0=0 + last_x1=0 + last_y0=0 + new_page = newpdf.new_page(width=page.rect.width, height=page.rect.height) + font=fitz.Font("helv") + for block in page.get_text("blocks"): + current_x1=block[2] + current_x0=block[0] + current_y0=block[1] + # 对于每个文本块,分行并读取 + if block[5]==0 or abs(current_x1-last_x1)>12 or abs(current_x0-last_x0)>12 and len(texts)>0: + item=texts.pop(0) + trans_text=item.get("text","") + 
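+                # 说明(描述性注释): write the popped translation at the coordinates of the block that just ended (last_x0, last_y0), then start accumulating the new block's source text below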
new_page.insert_text((last_x0,last_y0), trans_text, fontsize=12,fontname="Helvetica", overlay=False) + text=block[4].replace("\n","") + else: + text=text+(block[4].replace("\n","")) + last_x1=block[2] + last_x0=block[0] + last_y0=block[1] + if check_text(text) and len(texts): + new_page.insert_text((last_x0,last_y0), trans_text, fontsize=12, overlay=False) + +def write_page_text(pages,newpdf,texts): + for page in pages: + text=page.get_text("text") + new_page = newpdf.new_page(width=page.rect.width, height=page.rect.height) + if check_text(text) and len(texts)>0: + item=texts.pop(0) + text=item.get("text","") + new_page.insert_text((0,0), text, fontsize=12, overlay=False) + +def read_row(pages,texts): + text="" + for page in pages: + # 获取页面的文本块 + for block in page.get_text("blocks"): + # 对于每个文本块,分行并读取 + if block[5]==0: + append_text(text, 'text', texts) + text=block[4] + else: + text=text+block[4] + +def write_row(newpdf, texts, page_width, page_height): + text_count=0 + new_page = newpdf.new_page(width=page_width, height=page_height) + for text in texts: + print(text['text']) + # draw_text_avoid_overlap(new_page, text['text'],text['block'][0],text['block'][1], 16) + new_page.insert_text((text['block'][0],text['block'][1]),text['text'], fontsize=16) + return + + + +def append_text(text, content_type, texts): + if check_text(text): + # print(text) + texts.append({"text":text,"type":content_type, "complete":False}) + + +def check_text(text): + return text!=None and len(text)>0 and not common.is_all_punc(text) + +def draw_text_avoid_overlap(page, text, x, y, font_size): + """ + 在指定位置绘制文本,避免与现有文本重叠。 + """ + text_length = len(text) * font_size # 估算文本长度 + while True: + text_box = page.get_textbox((x, y, x + text_length, y + font_size)) + if not text_box: + break # 没有重叠的文本,退出循环 + y += font_size + 1 # 移动到下一个位置 + + page.insert_text((x,y),text, fontsize=font_size) + + +def draw_table(page, table_data, x, y, width, cell_height): + # 表格的列数 + cols = len(table_data[0]) + rows = len(table_data) + + # 绘制表格 + for i in range(rows): + for j in range(cols): + # 文字写入 + txt = table_data[i][j] + page.insert_text((x, y), txt) + # 绘制单元格边框 (仅边界线) + # 左边 + page.draw_line((x, y),( x+width/cols, y), width=0.5) + # 上边 + if i == 0: + page.draw_line((x, y), (x, y+cell_height), width=0.5) + # 右边 + if j == cols-1: + page.draw_line((x+width/cols, y), (x+width/cols, y+cell_height), width=0.5) + # 下边 + if i == rows-1: + page.draw_line((x, y+cell_height), (x+width/cols, y+cell_height), width=0.5) + # 移动到下一个单元格 + x += width/cols + # 移动到下一行 + x = 0 + y += cell_height + +def wrap_text(text, width): + words = text.split(' ') + lines = [] + line = "" + for word in words: + if len(line.split(' ')) >= width: + lines.append(line) + line = "" + if len(line + word + ' ') <= width * len(word): + line += word + ' ' + else: + lines.append(line) + line = word + ' ' + if line: + lines.append(line) + return lines + + +def is_paragraph(block): + # 假设一个段落至少有两行 + if len(block) < 2: + return False + # 假设一个段落的行间隔较大 + if max([line.height for line in block]) / min([line.height for line in block]) > 1.5: + return True + return False + +def is_next_line_continuation(page, current_line, next_line_index): + # 判断下一行是否是当前行的继续 + return abs(next_line_index - current_line) < 0.1 + +def print_texts(texts): + for item in texts: + print(item.get("text")) + +def is_scan_pdf(pages): + for index,page in enumerate(pages): + html=page.get_text("xhtml") + images=re.findall(r"(data:image/\w+;base64,[^\"]+)", html) + text=page.get_text() + print(images) + print(text) + 
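+        # 说明(描述性注释): the decision is made on the first page only — no extractable text but embedded images present => treat the whole PDF as a scanned document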
if text=="" and len(images)>0: + return True + else: + return False + +def read_pdf_html(pages, texts, trans): + for index,page in enumerate(pages): + target_html="{}-{}.html".format(trans['html_path'], page_index) + if os.path.exists(target_html): + os.remove(target_html) + dftohtml_path = shutil.which("pdftohtml") + if pdftohtml_path is None: + raise Exception("未安装pdftohtml") + subprocess.run([dftohtml_path,"-c","-l", page_index, trans['file_path'], trans['html_path']]) + if not os.path.exists(target_html): + raise Exception("无法生成html") + # append_text(html,'text', texts) + + +def pdftohtml(pdf_path, html_path,texts): + target_html="{}-html.html".format(html_path) + if os.path.exists(target_html): + os.remove(target_html) + pdftohtml_path = shutil.which("pdftohtml") + if pdftohtml_path is None: + raise Exception("未安装pdftohtml") + subprocess.run([pdftohtml_path,"-c","-s", pdf_path, html_path]) + if not os.path.exists(target_html): + raise Exception("无法生成html") + with open(target_html, 'r') as f: + content=f.read() + print(content) + append_text(content, 'text', texts) + + +def pdftodocx(pdf_path, docx_path): + print(docx_path) + if os.path.exists(docx_path): + os.remove(docx_path) + print(pdf_path) + try: + cv = pdf2docx.Converter(pdf_path) + cv.debug_page(0) + cv.convert(docx_path, start=0,end=1,multi_processing=False) + cv.close() + #exit() + except Exception as e: + print("error") + pdf2docxNext(pdf_path, docx_path) + +def pdf2docxNext(pdf_path, docx_path): + try: + # 创建一个新的 DOCX 文档 + doc = Document() + # 打开 PDF 文件 + pdf_document = fitz.open(pdf_path) + # 遍历 PDF 的每一页 + for page_num in range(len(pdf_document)): + page = pdf_document[page_num] + fonts=page.get_fonts() + # 提取文本 + # 提取文本和样式信息 + text_dict = page.get_text("dict") + + # 遍历文本块 + for block in text_dict["blocks"]: + if block["type"] == 0: # 只处理文本块 + for line in block["lines"]: + for span in line["spans"]: + text = span["text"] + font_size = span["size"] # 字体大小 + font_color = span["color"] # 字体颜色 + + # 创建段落 + paragraph = doc.add_paragraph() + run = paragraph.add_run(text) + + # 设置字体大小 + run.font.size = Pt(font_size) + + # 设置字体颜色 + if font_color: + run.font.color.rgb = RGBColor( + (font_color >> 16) & 0xFF, # R + (font_color >> 8) & 0xFF, # G + font_color & 0xFF # B + ) + elif block["type"] == 1: + # 提取图像 + try: + img_index = block["image"] + base_image = pdf_document.extract_image(img_index) + image_bytes = base_image["image"] + image_ext = base_image["ext"] + # 将图像添加到 DOCX + image_stream = BytesIO(image_bytes) + doc.add_picture(image_stream, width=None) # 可以指定宽度 + except Exception as e: + print("图片无法解析") + + + + # 添加分页符 + doc.add_page_break() + + # 保存 DOCX 文件 + doc.save(docx_path) + pdf_document.close() + except Exception as e: + raise("pdf转docx失败") + +# 舍弃 +def docxtopdf6(docx_path, pdf_path): + """ + 使用 docx2pdf 库实现跨平台 DOCX 转 PDF + 保持原始逻辑:删除已存在的PDF、创建目录、错误处理 + """ + + # 删除已存在的PDF文件(保留原始逻辑) + if os.path.exists(pdf_path): + try: + os.remove(pdf_path) + except Exception as e: + raise RuntimeError(f"无法删除旧PDF文件 {pdf_path}: {str(e)}") + + # 创建输出目录(优化权限设置) + target_dir = os.path.dirname(pdf_path) + if not os.path.exists(target_dir): + try: + os.makedirs(target_dir, exist_ok=True) # 去除明确的 0o777 权限 + except Exception as e: + raise RuntimeError(f"无法创建目录 {target_dir}: {str(e)}") + + # 执行转换(替换核心实现) + try: + print(f"正在转换: {docx_path} → {pdf_path}") # 保留日志输出 + convert(docx_path, pdf_path) # 核心转换调用 + + # 验证转换结果 + if not os.path.exists(pdf_path): + raise RuntimeError("转换成功但未生成预期输出文件") + + print("转换完成") # 保留完成提示 + + except Exception as e: + # 
增强错误信息 + error_msg = f"DOCX转PDF失败: {str(e)}" + if "No such file or directory" in str(e): + error_msg += " (请检查输入文件路径)" + elif "Permission denied" in str(e): + error_msg += " (权限不足)" + raise RuntimeError(error_msg) + + + + + +# 旧方案 +def docxtopdf11111(docx_path, pdf_path): + if os.path.exists(pdf_path): + os.remove(pdf_path) + sys.path.append("/usr/local/bin") + unoconv_path = shutil.which("unoconv") + if unoconv_path is None: + raise Exception("未安装unoconv") + target_path_dir=os.path.dirname(pdf_path) + if not os.path.exists(target_path_dir): + os.makedirs(target_path_dir, mode=0o777, exist_ok=True) + # target_pdf = fitz.Document() + # target_pdf.new_page() + # target_pdf.save(pdf_path) + # target_pdf.close() + # subprocess.run([unoconv_path,"-f","pdf","-e","UTF-8","-o",target_path_dir, docx_path]) + # subprocess.run([unoconv_path,"-f","pdf","-e","UTF-8","-o",target_path_dir, docx_path]) + print("{} -f pdf -o {} {}".format(unoconv_path,pdf_path, docx_path)) + # subprocess.run("{} -f pdf -o {} {}".format(unoconv_path, pdf_path, docx_path), shell=True) + command = [unoconv_path, "-f", "pdf", "-o", pdf_path, docx_path] + subprocess.run(command) + print("done") + +def create_temp_file(suffix='.png'): + temp_dir = '/tmp' # 或者使用其他临时目录 + filename = f"{uuid.uuid4()}{suffix}" + return os.path.join(temp_dir, filename) + +def pdf_to_text_with_ocr(pdf_path, docx_path, origin_lang): + # if not is_tesseract_installed(): + # raise Exception("Tesseract未安装,无法进行OCR") + + document = fitz.open(pdf_path) + docx = Document() + + for page_num in range(len(document)): + page = document.load_page(page_num) + pix = page.get_pixmap() + img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + + # 转换为灰度图像 + img = img.convert('L') + + # 将图像保存到内存中的字节流 + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr = img_byte_arr.getvalue() + + try: + # 使用 Tesseract 命令行工具 + process = subprocess.Popen( + ['/usr/local/bin/tesseract', 'stdin', 'stdout', '-l', origin_lang, '--oem', '3', '--psm', '6'], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + stdout, stderr = process.communicate(input=img_byte_arr) + + if process.returncode != 0: + raise subprocess.CalledProcessError(process.returncode, process.args, stdout, stderr) + + text = stdout.decode('utf-8').strip() + + # 移除空行和多余的空格 + text = '\n'.join(line.strip() for line in text.splitlines() if line.strip()) + + except subprocess.CalledProcessError as e: + print(f"OCR处理页面 {page_num + 1} 时出错: {str(e)}") + text = "" # 如果出错,使用空字符串 + + paragraph = docx.add_paragraph() + run = paragraph.add_run(text) + run.font.size = Pt(12) + + document.close() + docx.save(docx_path) + +def is_scanned_pdf(pdf_path): + document = fitz.open(pdf_path) + + # 只检查前几页,通常足以判断 + pages_to_check = min(5, len(document)) + + for page_num in range(pages_to_check): + page = document[page_num] + + # 检查文本 + text = page.get_text().strip() + if text: + document.close() + return False + + # 检查图像 + image_list = page.get_images() + if len(image_list) > 0: + # 如果页面只包含一个大图像,很可能是扫描件 + if len(image_list) == 1: + xref = image_list[0][0] + img = document.extract_image(xref) + if img: + pix = fitz.Pixmap(img["image"]) + # 如果图像覆盖了大部分页面,可能是扫描件 + if pix.width > page.rect.width * 0.9 and pix.height > page.rect.height * 0.9: + document.close() + return True + + document.close() + return True # 如果没有找到文本,默认认为是扫描件 + +def is_tesseract_installed(): + tesseract_path = "/usr/local/bin/tesseract" + return os.path.isfile(tesseract_path) and os.access(tesseract_path, os.X_OK) 
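+# --- Hypothetical usage sketch (an assumption for illustration, not part of this diff) ---
+# A minimal way the OCR helpers above could be combined; pdf_path/docx_path and the
+# tesseract language code are placeholders supplied by the caller.
+#
+# def ocr_scanned_pdf(pdf_path, docx_path, origin_lang="eng"):
+#     if not is_scanned_pdf(pdf_path):
+#         return False  # the PDF already has a text layer, no OCR needed
+#     if not is_tesseract_installed():
+#         raise Exception("Tesseract未安装,无法进行OCR")
+#     pdf_to_text_with_ocr(pdf_path, docx_path, origin_lang)
+#     return True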
+ +def use_doc2x_revert_pdf_to_docx(dox2x_api_key, pdf_file, docx_path): + client = Doc2X(apikey=dox2x_api_key,debug=False) + success, failed, flag = client.pdf2file( + pdf_file=pdf_file, + output_path=docx_path, + output_format="docx", + ) + if len(success)>0 and success[0]!="": + return (True,success[0]) + else: + return (False,failed[0]["error"]) + +# def save_image(base64_data, path): +# image_data = base64.b64decode(base64_data) +# # 将字节数据写入内存中的文件对象 +# image_file = BytesIO(image_data) +# # 从内存中的文件对象创建Image对象 +# image = Image.open(image_file) +# # 保存图片到文件系统 +# image.sav/e(path) + diff --git a/app/translate/pdf_parse.py b/app/translate/pdf_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..315846c279cdd375d3d2e21ae9959b3fc0f3921a --- /dev/null +++ b/app/translate/pdf_parse.py @@ -0,0 +1,280 @@ +import os +import re +from typing import List, Tuple, Optional, Dict +import logging + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +import fitz # PyMuPDF +import shapely.geometry as sg +from shapely.geometry.base import BaseGeometry +from shapely.validation import explain_validity +import concurrent.futures + +# This Default Prompt Using Chinese and could be changed to other languages. + +DEFAULT_PROMPT = """使用markdown语法,将图片中识别到的文字转换为markdown格式输出。你必须做到: +1. 输出和使用识别到的图片的相同的语言,例如,识别到英语的字段,输出的内容必须是英语。 +2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。 +3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式、忽略掉长直线、忽略掉页码。 +再次强调,不要解释和输出无关的文字,直接输出图片中的内容。 +""" +DEFAULT_RECT_PROMPT = """图片中用红色框和名称(%s)标注出了一些区域。如果区域是表格或者图片,使用 ![]() 的形式插入到输出内容中,否则直接输出文字内容。 +""" +DEFAULT_ROLE_PROMPT = """你是一个PDF文档解析器,使用markdown和latex语法输出图片的内容。 +""" + + +def _is_near(rect1: BaseGeometry, rect2: BaseGeometry, distance: float = 20) -> bool: + """ + Check if two rectangles are near each other if the distance between them is less than the target. + """ + return rect1.buffer(0.1).distance(rect2.buffer(0.1)) < distance + + +def _is_horizontal_near(rect1: BaseGeometry, rect2: BaseGeometry, distance: float = 100) -> bool: + """ + Check if two rectangles are near horizontally if one of them is a horizontal line. + """ + result = False + if abs(rect1.bounds[3] - rect1.bounds[1]) < 0.1 or abs(rect2.bounds[3] - rect2.bounds[1]) < 0.1: + if abs(rect1.bounds[0] - rect2.bounds[0]) < 0.1 and abs(rect1.bounds[2] - rect2.bounds[2]) < 0.1: + result = abs(rect1.bounds[3] - rect2.bounds[3]) < distance + return result + + +def _union_rects(rect1: BaseGeometry, rect2: BaseGeometry) -> BaseGeometry: + """ + Union two rectangles. + """ + return sg.box(*(rect1.union(rect2).bounds)) + + +def _merge_rects(rect_list: List[BaseGeometry], distance: float = 20, horizontal_distance: Optional[float] = None) -> \ + List[BaseGeometry]: + """ + Merge rectangles in the list if the distance between them is less than the target. 
+ """ + merged = True + while merged: + merged = False + new_rect_list = [] + while rect_list: + rect = rect_list.pop(0) + for other_rect in rect_list: + if _is_near(rect, other_rect, distance) or ( + horizontal_distance and _is_horizontal_near(rect, other_rect, horizontal_distance)): + rect = _union_rects(rect, other_rect) + rect_list.remove(other_rect) + merged = True + new_rect_list.append(rect) + rect_list = new_rect_list + return rect_list + + +def _adsorb_rects_to_rects(source_rects: List[BaseGeometry], target_rects: List[BaseGeometry], distance: float = 10) -> \ + Tuple[List[BaseGeometry], List[BaseGeometry]]: + """ + Adsorb a set of rectangles to another set of rectangles. + """ + new_source_rects = [] + for text_area_rect in source_rects: + adsorbed = False + for index, rect in enumerate(target_rects): + if _is_near(text_area_rect, rect, distance): + rect = _union_rects(text_area_rect, rect) + target_rects[index] = rect + adsorbed = True + break + if not adsorbed: + new_source_rects.append(text_area_rect) + return new_source_rects, target_rects + + +def _parse_rects(page: fitz.Page) -> List[Tuple[float, float, float, float]]: + """ + Parse drawings in the page and merge adjacent rectangles. + """ + + # 提取画的内容 + drawings = page.get_drawings() + + # 忽略掉长度小于30的水平直线 + is_short_line = lambda x: abs(x['rect'][3] - x['rect'][1]) < 1 and abs(x['rect'][2] - x['rect'][0]) < 30 + drawings = [drawing for drawing in drawings if not is_short_line(drawing)] + + # 转换为shapely的矩形 + rect_list = [sg.box(*drawing['rect']) for drawing in drawings] + + # 提取图片区域 + images = page.get_image_info() + image_rects = [sg.box(*image['bbox']) for image in images] + + # 合并drawings和images + rect_list += image_rects + + merged_rects = _merge_rects(rect_list, distance=10, horizontal_distance=100) + merged_rects = [rect for rect in merged_rects if explain_validity(rect) == 'Valid Geometry'] + + # 将大文本区域和小文本区域分开处理: 大文本相小合并,小文本靠近合并 + is_large_content = lambda x: (len(x[4]) / max(1, len(x[4].split('\n')))) > 5 + small_text_area_rects = [sg.box(*x[:4]) for x in page.get_text('blocks') if not is_large_content(x)] + large_text_area_rects = [sg.box(*x[:4]) for x in page.get_text('blocks') if is_large_content(x)] + _, merged_rects = _adsorb_rects_to_rects(large_text_area_rects, merged_rects, distance=0.1) # 完全相交 + _, merged_rects = _adsorb_rects_to_rects(small_text_area_rects, merged_rects, distance=5) # 靠近 + + # 再次自身合并 + merged_rects = _merge_rects(merged_rects, distance=10) + + # 过滤比较小的矩形 + merged_rects = [rect for rect in merged_rects if rect.bounds[2] - rect.bounds[0] > 20 and rect.bounds[3] - rect.bounds[1] > 20] + + return [rect.bounds for rect in merged_rects] + + +def _parse_pdf_to_images(pdf_path: str, output_dir: str = './') -> List[Tuple[str, List[str]]]: + """ + Parse PDF to images and save to output_dir. 
+ """ + # 打开PDF文件 + pdf_document = fitz.open(pdf_path) + image_infos = [] + + for page_index, page in enumerate(pdf_document): + logging.info(f'parse page: {page_index}') + rect_images = [] + rects = _parse_rects(page) + for index, rect in enumerate(rects): + fitz_rect = fitz.Rect(rect) + # 保存页面为图片 + pix = page.get_pixmap(clip=fitz_rect, matrix=fitz.Matrix(4, 4)) + name = f'{page_index}_{index}.png' + pix.save(os.path.join(output_dir, name)) + rect_images.append(name) + # # 在页面上绘制红色矩形 + big_fitz_rect = fitz.Rect(fitz_rect.x0 - 1, fitz_rect.y0 - 1, fitz_rect.x1 + 1, fitz_rect.y1 + 1) + # 空心矩形 + page.draw_rect(big_fitz_rect, color=(1, 0, 0), width=1) + # 画矩形区域(实心) + # page.draw_rect(big_fitz_rect, color=(1, 0, 0), fill=(1, 0, 0)) + # 在矩形内的左上角写上矩形的索引name,添加一些偏移量 + text_x = fitz_rect.x0 + 2 + text_y = fitz_rect.y0 + 10 + text_rect = fitz.Rect(text_x, text_y - 9, text_x + 80, text_y + 2) + # 绘制白色背景矩形 + page.draw_rect(text_rect, color=(1, 1, 1), fill=(1, 1, 1)) + # 插入带有白色背景的文字 + page.insert_text((text_x, text_y), name, fontsize=10, color=(1, 0, 0)) + page_image_with_rects = page.get_pixmap(matrix=fitz.Matrix(3, 3)) + page_image = os.path.join(output_dir, f'{page_index}.png') + page_image_with_rects.save(page_image) + image_infos.append((page_image, rect_images)) + + pdf_document.close() + return image_infos + + +def _gpt_parse_images( + image_infos: List[Tuple[str, List[str]]], + prompt_dict: Optional[Dict] = None, + output_dir: str = './', + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: str = 'gpt-4o', + verbose: bool = False, + gpt_worker: int = 1, + **args +) -> str: + """ + Parse images to markdown content. + """ + from GeneralAgent import Agent + + if isinstance(prompt_dict, dict) and 'prompt' in prompt_dict: + prompt = prompt_dict['prompt'] + logging.info("prompt is provided, using user prompt.") + else: + prompt = DEFAULT_PROMPT + logging.info("prompt is not provided, using default prompt.") + if isinstance(prompt_dict, dict) and 'rect_prompt' in prompt_dict: + rect_prompt = prompt_dict['rect_prompt'] + logging.info("rect_prompt is provided, using user prompt.") + else: + rect_prompt = DEFAULT_RECT_PROMPT + logging.info("rect_prompt is not provided, using default prompt.") + if isinstance(prompt_dict, dict) and 'role_prompt' in prompt_dict: + role_prompt = prompt_dict['role_prompt'] + logging.info("role_prompt is provided, using user prompt.") + else: + role_prompt = DEFAULT_ROLE_PROMPT + logging.info("role_prompt is not provided, using default prompt.") + + def _process_page(index: int, image_info: Tuple[str, List[str]]) -> Tuple[int, str]: + logging.info(f'gpt parse page: {index}') + agent = Agent(role=role_prompt, api_key=api_key, base_url=base_url, disable_python_run=True, model=model, **args) + page_image, rect_images = image_info + local_prompt = prompt + if rect_images: + local_prompt += rect_prompt + ', '.join(rect_images) + content = agent.run([local_prompt, {'image': page_image}], display=verbose) + return index, content + + contents = [None] * len(image_infos) + with concurrent.futures.ThreadPoolExecutor(max_workers=gpt_worker) as executor: + futures = [executor.submit(_process_page, index, image_info) for index, image_info in enumerate(image_infos)] + for future in concurrent.futures.as_completed(futures): + index, content = future.result() + + # 在某些情况下大模型还是会输出 ```markdown ```字符串 + if '```markdown' in content: + content = content.replace('```markdown\n', '') + last_backticks_pos = content.rfind('```') + if last_backticks_pos != -1: + content = 
content[:last_backticks_pos] + content[last_backticks_pos + 3:] + + contents[index] = content + + output_path = os.path.join(output_dir, 'output.md') + with open(output_path, 'w', encoding='utf-8') as f: + f.write('\n\n'.join(contents)) + + return '\n\n'.join(contents) + + +def parse_pdf( + pdf_path: str, + output_dir: str = './', + prompt: Optional[Dict] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: str = 'gpt-4o', + verbose: bool = False, + gpt_worker: int = 1, + **args +) -> Tuple[str, List[str]]: + """ + Parse a PDF file to a markdown file. + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + image_infos = _parse_pdf_to_images(pdf_path, output_dir=output_dir) + content = _gpt_parse_images( + image_infos=image_infos, + output_dir=output_dir, + prompt_dict=prompt, + api_key=api_key, + base_url=base_url, + model=model, + verbose=verbose, + gpt_worker=gpt_worker, + **args + ) + + all_rect_images = [] + # remove all rect images + if not verbose: + for page_image, rect_images in image_infos: + if os.path.exists(page_image): + os.remove(page_image) + all_rect_images.extend(rect_images) + return content, all_rect_images \ No newline at end of file diff --git a/app/translate/powerpoint.py b/app/translate/powerpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8c11d8f131fe09af4d42424a3db5a835f49dc712 --- /dev/null +++ b/app/translate/powerpoint.py @@ -0,0 +1,98 @@ +import threading +import pptx +from . import to_translate +from . import common +import os +import sys +import time +import datetime + +def start(trans): + # 允许的最大线程 + threads=trans['threads'] + if threads is None or int(threads)<0: + max_threads=10 + else: + max_threads=int(threads) + # 当前执行的索引位置 + run_index=0 + start_time = datetime.datetime.now() + wb = pptx.Presentation(trans['file_path']) + print(trans['file_path']) + slides = wb.slides + texts=[] + for slide in slides: + for shape in slide.shapes: + if shape.has_table: + table = shape.table + print(table) + rows = len(table.rows) + cols = len(table.columns) + for r in range(rows): + row_data = [] + for c in range(cols): + cell_text = table.cell(r, c).text + if cell_text!=None and len(cell_text)>0 and not common.is_all_punc(cell_text): + texts.append({"text":cell_text,"row":r,"column":c, "complete":False}) + if not shape.has_text_frame: + continue + text_frame = shape.text_frame + for paragraph in text_frame.paragraphs: + text=paragraph.text + if text!=None and len(text)>0 and not common.is_all_punc(text): + texts.append({"text":text, "complete":False}) + max_run=max_threads if len(texts)>max_threads else len(texts) + before_active_count=threading.activeCount() + event=threading.Event() + while run_index<=len(texts)-1: + if threading.activeCount()0 and not common.is_all_punc(cell_text): + item=texts.pop(0) + table.cell(r, c).text=item['text'] + text_count+=item['count'] + + if not shape.has_text_frame: + continue + text_frame = shape.text_frame + for paragraph in text_frame.paragraphs: + text=paragraph.text + if text!=None and len(text)>0 and not common.is_all_punc(text) and len(texts)>0: + item=texts.pop(0) + paragraph.text=item['text'] + text_count+=item['count'] + + wb.save(trans['target_file']) + end_time = datetime.datetime.now() + spend_time=common.display_spend(start_time, end_time) + to_translate.complete(trans,text_count,spend_time) + return True + + diff --git a/app/translate/rediscon.py b/app/translate/rediscon.py new file mode 100644 index 
0000000000000000000000000000000000000000..cee1eb387635e0519fc86dbd442027d3d325e09e --- /dev/null +++ b/app/translate/rediscon.py @@ -0,0 +1,17 @@ +import redis +import os +from dotenv import load_dotenv, find_dotenv + +_ = load_dotenv(find_dotenv()) # read local .env file + +def get_conn(): + redis_host=os.environ['REDIS_HOST'] + redis_password=os.environ['REDIS_PASSWORD'] + redis_port=os.environ['REDIS_PORT'] + if os.environ['REDIS_SELECT']: + redis_select=os.environ['REDIS_SELECT'] + else: + redis_select=0 + pool = redis.ConnectionPool(host=redis_host, port=int(redis_port), password=redis_password,db=redis_select, decode_responses=True) + return redis.Redis(connection_pool=pool) + diff --git a/app/translate/requirements.txt b/app/translate/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..baa1b7a22db30977672a4f2eaa11b96cb87b5e56 --- /dev/null +++ b/app/translate/requirements.txt @@ -0,0 +1,22 @@ +beautifulsoup4==4.13.3 +docx==0.2.4 +docx2pdf==0.1.8 +fitz==0.0.1.dev2 +Markdown==3.7 +numpy==2.2.2 +openai==1.61.1 +opencv_python_headless==4.10.0.84 +openpyxl==3.1.5 +pdf2docx==0.5.8 +pdfdeal==1.0.2 +pdfkit==1.0.0 +Pillow==11.1.0 +PyMySQL==1.1.1 +pypdf==5.2.0 +pytesseract==0.3.13 +python-dotenv==1.0.1 +python_docx==1.1.2 +python_pptx==1.0.2 +redis==5.2.1 +Requests==2.32.3 +Shapely==2.0.7 diff --git a/app/translate/test.py b/app/translate/test.py new file mode 100644 index 0000000000000000000000000000000000000000..a3da95f505c7abd3d60d97d6ce0c56ac0c28d350 --- /dev/null +++ b/app/translate/test.py @@ -0,0 +1,110 @@ +import zipfile +import xml.etree.ElementTree as ET +import os +from docx import Document + +def read_comments_from_docx(docx_path): + comments = [] + with zipfile.ZipFile(docx_path, 'r') as docx: + # 尝试读取批注文件 + with docx.open('word/comments.xml') as comments_file: + # 解析 XML + tree = ET.parse(comments_file) + root = tree.getroot() + + # 定义命名空间 + namespace = {'ns0': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'} + + # 查找所有批注 + for comment in root.findall('ns0:comment', namespace): + comment_id = comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}id') + author = comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}author') + date = comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}date') + text = ''.join(t.text for p in comment.findall('.//ns0:p', namespace) for r in p.findall('.//ns0:r', namespace) for t in r.findall('.//ns0:t', namespace)) + + comments.append({ + 'id': comment_id, + 'author': author, + 'date': date, + 'text': text, + }) + modified_xml = ET.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8') + print("XML 内容:") + print(modified_xml) + return comments + +def modify_comment_in_docx(docx_path, comment_id, new_text): + # 创建一个临时文件名,保留原始路径 + temp_docx_path = os.path.join(os.path.dirname(docx_path), 'temp_' + os.path.basename(docx_path)) + + # 打开原始 docx 文件 + with zipfile.ZipFile(docx_path, 'r') as docx: + # 创建一个新的 docx 文件 + with zipfile.ZipFile(temp_docx_path, 'w') as new_docx: + for item in docx.infolist(): + # 读取每个文件 + with docx.open(item) as file: + if item.filename == 'word/comments.xml': + # 解析批注 XML + tree = ET.parse(file) + root = tree.getroot() + + # 打印原始 XML 内容 + print("原始 XML 内容:") + print(ET.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8')) + + # 定义命名空间 + namespace = {'ns0': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'} + + # 查找并修改批注 + for comment in root.findall('ns0:comment', 
namespace): + if comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}id') == comment_id: + # 清除现有段落 + for p in list(comment.findall('.//ns0:p', namespace)): + comment.remove(p) # 从批注中移除段落元素 + + # 创建新的段落 + new_paragraph = ET.Element('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}p') + # 创建新的 run 元素 + new_run = ET.Element('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}r') + # 创建新的 text 元素 + new_text_elem = ET.Element('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}t') + new_text_elem.text = new_text # 设置文本内容 + + # 将 text 元素添加到 run 元素中 + new_run.append(new_text_elem) + # 将 run 添加到段落中 + new_paragraph.append(new_run) + # 将新段落添加到批注中 + comment.append(new_paragraph) + + # 打印修改后的 XML 内容 + modified_xml = ET.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8') + print("修改后的 XML 内容:") + print(modified_xml) + + # 将修改后的 XML 写入新的 docx 文件 + new_docx.writestr(item.filename, modified_xml) + else: + # 其他文件直接写入新的 docx 文件 + new_docx.writestr(item.filename, file.read()) + + # 替换原始文件 + os.replace(temp_docx_path, docx_path) + +# 示例用法 +docx_path = '/Volumes/data/erui/ezwork-api/storage/app/public/uploads/240928/jZtoN0Ak8P1A5Eojw9KndxoV7OkpPJv1J3NVtsBS.docx' # 替换为您的文档路径 +# docx_path = '/Volumes/data/erui/ezwork-api/storage/app/public//translate/jZtoN0Ak8P1A5Eojw9KndxoV7OkpPJv1J3NVtsBS/comments-英语.docx' # 替换为您的文档路径 +comment_id = '3' # 替换为您要修改的批注 ID +new_text = 'test test' # 替换为新的批注文本 + +# document = Document("/Volumes/data/erui/ezwork-api/storage/app/public/uploads/240928/jZtoN0Ak8P1A5Eojw9KndxoV7OkpPJv1J3NVtsBS.docx") +# document.save(docx_path) +# 读取批注 +comments = read_comments_from_docx(docx_path) +print("读取的批注:") +for comment in comments: + print(comment) + +# 修改批注 +# modify_comment_in_docx(docx_path, comment_id, new_text) \ No newline at end of file diff --git a/app/translate/test1.py b/app/translate/test1.py new file mode 100644 index 0000000000000000000000000000000000000000..abdac4bc108c8aba2509369661928e283aea14be --- /dev/null +++ b/app/translate/test1.py @@ -0,0 +1,59 @@ +import os + +# laod environment variables from .env file +import dotenv +dotenv.load_dotenv() + +pdf_path = r"F:\桌面文件\composes测试\Desktop\2005C:雨量预报方法优劣的评价模型.pdf" +output_dir = r'F:\桌面文件\我的vue项目\文档翻译项目\后端重构-api项目\storage\translate' + + +# 清空output_dir +# import shutil +# shutil.rmtree(output_dir, ignore_errors=True) + +def test_use_api_key(): + from gptpdf import parse_pdf + api_key = os.getenv('OPENAI_API_KEY') + base_url = os.getenv('OPENAI_API_BASE') + # Manually provide OPENAI_API_KEY and OPEN_API_BASE + content, image_paths = parse_pdf(pdf_path, output_dir=output_dir, api_key=api_key, base_url=base_url, model='gpt-4o', gpt_worker=6) + print(content) + print(image_paths) + # also output_dir/output.md is generated + + +def test_use_env(): + from gptpdf import parse_pdf + # Use OPENAI_API_KEY and OPENAI_API_BASE from environment variables + content, image_paths = parse_pdf(pdf_path, output_dir=output_dir, model='gpt-4o', verbose=True) + print(content) + print(image_paths) + # also output_dir/output.md is generated + + +def test_azure(): + from pdf_parse import parse_pdf + api_key = '8ef0b4df45e444079cd5a4xxxxx' # Azure API Key + base_url = 'https://xxx.openai.azure.com/' # Azure API Base URL + model = 'azure_xxxx' # azure_ with deploy ID name (not open ai model name), e.g. 
azure_cpgpt4 + # Use OPENAI_API_KEY and OPENAI_API_BASE from environment variables + content, image_paths = parse_pdf(pdf_path, output_dir=output_dir, api_key=api_key, base_url=base_url, model=model, verbose=True) + print(content) + print(image_paths) + +def test_qwen_vl_max(): + from pdf_parse import parse_pdf + api_key = '28032c969954994065d5520e1155418b.u8iXzIijE3qvkXsZ' + base_url = "https://open.bigmodel.cn/api/paas/v4" + model = 'glm-4v-flash' + content, image_paths = parse_pdf(pdf_path, output_dir=output_dir, api_key=api_key, base_url=base_url, model=model, verbose=True, temperature=0.5, max_tokens=1000, top_p=0.9, frequency_penalty=1) + print(content) + print(image_paths) + + +if __name__ == '__main__': + # test_use_api_key() + # test_use_env() + # test_azure() + test_qwen_vl_max() \ No newline at end of file diff --git a/app/translate/to_translate.py b/app/translate/to_translate.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a7b0d22157a3d7a8e74b6f96b20c500ecd6733 --- /dev/null +++ b/app/translate/to_translate.py @@ -0,0 +1,478 @@ +# import tiktoken +import datetime +import hashlib +import logging +import os +import sys +import re +import openai +from . import common +from . import db +import time + + +def get(trans, event, texts, index): + if event.is_set(): + exit(0) + threads = trans['threads'] + if threads is None or threads == "" or int(threads) < 0: + max_threads = 10 + else: + max_threads = int(threads) + # mredis=rediscon.get_conn() + # threading_num=get_threading_num(mredis) + # while threading_num>=max_threads: + # time.sleep(1) + # print('trans配置项', trans) + translate_id = trans['id'] + target_lang = trans['lang'] + model = trans['model'] + backup_model = trans['backup_model'] + prompt = trans['prompt'] + extension = trans['extension'].lower() + text = texts[index] + api_key = trans['api_key'] + api_url = trans['api_url'] + old_text = text['text'] + md5_key = md5_encryption( + str(api_key) + str(api_url) + str(old_text) + str(prompt) + str(backup_model) + str( + model) + str(target_lang)) + try: + oldtrans = db.get("select * from translate_logs where md5_key=%s", md5_key) + # mredis.set("threading_count",threading_num+1) + if text['complete'] == False: + content = '' + if oldtrans: + content = oldtrans['content'] + # 特别处理PDF类型 + # elif extension == ".pdf": + # return handle_pdf(trans, event, texts, index) + elif extension == ".pdf": + if text['type'] == "text": + content = translate_html(text['text'], target_lang, model, prompt) + time.sleep(0.1) + else: + content = get_content_by_image(text['text'], target_lang) + time.sleep(0.1) + # ---------------这里实现不同模型格式的请求-------------- + elif extension == ".md": + content = req(text['text'], target_lang, model, prompt, True) + else: + content = req(text['text'], target_lang, model, prompt, False) + # print("content", text['content']) + text['count'] = count_text(text['text']) + if check_translated(content): + # 过滤deepseek思考过程 + text['text'] = re.sub(r'.*?', '', content, flags=re.DOTALL) + if oldtrans is None: + db.execute("INSERT INTO translate_logs set api_url=%s,api_key=%s," + + "backup_model=%s ,created_at=%s ,prompt=%s, " + + "model=%s,target_lang=%s,source=%s,content=%s,md5_key=%s", + str(api_url), str(api_key), + str(backup_model), + datetime.datetime.now(), str(prompt), str(model), str(target_lang), + str(old_text), + str(content), str(md5_key)) + text['complete'] = True + except openai.AuthenticationError as e: + # set_threading_num(mredis) + return use_backup_model(trans, event, texts, index, 
"openai密钥或令牌无效") + except openai.APIConnectionError as e: + # set_threading_num(mredis) + return use_backup_model(trans, event, texts, index, "请求无法与openai服务器或建立安全连接") + except openai.PermissionDeniedError as e: + # set_threading_num(mredis) + texts[index] = text + # return use_backup_model(trans, event, texts, index, "令牌额度不足") + except openai.RateLimitError as e: + # set_threading_num(mredis) + if "retry" not in text: + trans['model'] = backup_model + trans['backup_model'] = model + time.sleep(1) + print("访问速率达到限制,交换备用模型与模型重新重试") + get(trans, event, texts, index) + else: + return use_backup_model(trans, event, texts, index, + "访问速率达到限制,10分钟后再试" + str(text['text'])) + except openai.InternalServerError as e: + # set_threading_num(mredis) + if "retry" not in text: + trans['model'] = backup_model + trans['backup_model'] = model + time.sleep(1) + print("当前分组上游负载已饱和,交换备用模型与模型重新重试") + get(trans, event, texts, index) + else: + return use_backup_model(trans, event, texts, index, + "当前分组上游负载已饱和,请稍后再试" + str(text['text'])) + except openai.APIStatusError as e: + # set_threading_num(mredis) + return use_backup_model(trans, event, texts, index, e.response) + except Exception as e: + # set_threading_num(mredis) + exc_type, exc_value, exc_traceback = sys.exc_info() + line_number = exc_traceback.tb_lineno # 异常抛出的具体行号 + print(f"Error occurred on line: {line_number}") + print(e) + if "retry" not in text: + text["retry"] = 0 + text["retry"] += 1 + if text["retry"] <= 3: + trans['model'] = backup_model + trans['backup_model'] = model + print("当前模型执行异常,交换备用模型与模型重新重试") + time.sleep(1) + get(trans, event, texts, index) + return + else: + text['complete'] = True + # traceback.print_exc() + # print("translate error") + texts[index] = text + # print(text) + if not event.is_set(): + process(texts, translate_id) + # set_threading_num(mredis) + exit(0) + + +def handle_pdf(trans, event, texts, index): + try: + from . 
import pdf_parser + success = pdf_parser.start(trans) + if success: + texts[index]['complete'] = True + else: + return use_backup_model(trans, event, texts, index, "PDF解析失败") + except Exception as e: + return use_backup_model(trans, event, texts, index, str(e)) + + +def get11(trans, event, texts, index): + if event.is_set(): + exit(0) + threads = trans['threads'] + if threads is None or threads == "" or int(threads) < 0: + max_threads = 10 + else: + max_threads = int(threads) + # mredis=rediscon.get_conn() + # threading_num=get_threading_num(mredis) + # while threading_num>=max_threads: + # time.sleep(1) + print('trans配置项', trans) + translate_id = trans['id'] + target_lang = trans['lang'] + model = trans['model'] + backup_model = trans['backup_model'] + prompt = trans['prompt'] + extension = trans['extension'].lower() + text = texts[index] + api_key = trans['api_key'] + api_url = trans['api_url'] + old_text = text['text'] + md5_key = md5_encryption( + str(api_key) + str(api_url) + str(old_text) + str(prompt) + str(backup_model) + str( + model) + str(target_lang)) + try: + oldtrans = db.get("select * from translate_logs where md5_key=%s", md5_key) + # mredis.set("threading_count",threading_num+1) + if text['complete'] == False: + content = '' + if oldtrans: + content = oldtrans['content'] + elif extension == ".pdf": + if text['type'] == "text": + content = translate_html(text['text'], target_lang, model, prompt) + time.sleep(0.1) + else: + content = get_content_by_image(text['text'], target_lang) + time.sleep(0.1) + # ---------------这里实现不同模型格式的请求-------------- + elif extension == ".md": + content = req(text['text'], target_lang, model, prompt, True) + else: + content = req(text['text'], target_lang, model, prompt, False) + # print("content", text['content']) + text['count'] = count_text(text['text']) + if check_translated(content): + # 过滤deepseek思考过程 + text['text'] = re.sub(r'.*?', '', content, flags=re.DOTALL) + if oldtrans is None: + db.execute("INSERT INTO translate_logs set api_url=%s,api_key=%s," + + "backup_model=%s ,created_at=%s ,prompt=%s, " + + "model=%s,target_lang=%s,source=%s,content=%s,md5_key=%s", + str(api_url), str(api_key), + str(backup_model), + datetime.datetime.now(), str(prompt), str(model), str(target_lang), + str(old_text), + str(content), str(md5_key)) + text['complete'] = True + except openai.AuthenticationError as e: + # set_threading_num(mredis) + return use_backup_model(trans, event, texts, index, "openai密钥或令牌无效") + except openai.APIConnectionError as e: + # set_threading_num(mredis) + return use_backup_model(trans, event, texts, index, "请求无法与openai服务器或建立安全连接") + except openai.PermissionDeniedError as e: + # set_threading_num(mredis) + texts[index] = text + # return use_backup_model(trans, event, texts, index, "令牌额度不足") + except openai.RateLimitError as e: + # set_threading_num(mredis) + if "retry" not in text: + trans['model'] = backup_model + trans['backup_model'] = model + time.sleep(1) + print("访问速率达到限制,交换备用模型与模型重新重试") + get(trans, event, texts, index) + else: + return use_backup_model(trans, event, texts, index, + "访问速率达到限制,10分钟后再试" + str(text['text'])) + except openai.InternalServerError as e: + # set_threading_num(mredis) + if "retry" not in text: + trans['model'] = backup_model + trans['backup_model'] = model + time.sleep(1) + print("当前分组上游负载已饱和,交换备用模型与模型重新重试") + get(trans, event, texts, index) + else: + return use_backup_model(trans, event, texts, index, + "当前分组上游负载已饱和,请稍后再试" + str(text['text'])) + except openai.APIStatusError as e: + # 
set_threading_num(mredis) + return use_backup_model(trans, event, texts, index, e.response) + except Exception as e: + # set_threading_num(mredis) + exc_type, exc_value, exc_traceback = sys.exc_info() + line_number = exc_traceback.tb_lineno # 异常抛出的具体行号 + print(f"Error occurred on line: {line_number}") + print(e) + if "retry" not in text: + text["retry"] = 0 + text["retry"] += 1 + if text["retry"] <= 3: + trans['model'] = backup_model + trans['backup_model'] = model + print("当前模型执行异常,交换备用模型与模型重新重试") + time.sleep(1) + get(trans, event, texts, index) + return + else: + text['complete'] = True + # traceback.print_exc() + # print("translate error") + texts[index] = text + # print(text) + if not event.is_set(): + process(texts, translate_id) + # set_threading_num(mredis) + exit(0) + + +# def get_threading_num(mredis): +# threading_count=mredis.get("threading_count") +# if threading_count is None or threading_count=="" or int(threading_count)<0: +# threading_num=0 +# else: +# threading_num=int(threading_count) +# return threading_num +# def set_threading_num(mredis): +# threading_count=mredis.get("threading_count") +# if threading_count is None or threading_count=="" or int(threading_count)<1: +# mredis.set("threading_count",0) +# else: +# threading_num=int(threading_count) +# mredis.set("threading_count",threading_num-1) + +def md5_encryption(data): + md5 = hashlib.md5(data.encode('utf-8')) # 创建一个md5对象 + return md5.hexdigest() # 返回加密后的十六进制字符串 + + +def req(text, target_lang, model, prompt, ext): + # 判断是否是md格式 + if ext == True: + # 如果是 md 格式,追加提示文本 + prompt += "。 请帮助我翻译以下 Markdown 文件中的内容。请注意,您只需翻译文本部分,而不应更改任何 Markdown 标签或格式。保持原有的标题、列表、代码块、链接和其他 Markdown 标签的完整性。" + # 构建 message + message = [ + {"role": "system", "content": prompt.replace("{target_lang}", target_lang)}, + {"role": "user", "content": text} + ] + # print(openai.base_url) + print(message) + # 禁用 OpenAI 的日志输出 + logging.getLogger("openai").setLevel(logging.WARNING) + # 禁用 httpx 的日志输出 + logging.getLogger("httpx").setLevel(logging.WARNING) + response = openai.chat.completions.create( + model=model, # 使用GPT-3.5版本 + messages=message, + temperature=0.8 + ) + # for choices in response.choices: + # print(choices.message.content) + content = response.choices[0].message.content + # print(content) + return content + + +def translate_html(html, target_lang, model, prompt): + message = [ + {"role": "system", + "content": "把下面的html翻译成{},只返回翻译后的内容".format(target_lang)}, + {"role": "user", "content": html} + ] + # print(openai.base_url) + response = openai.chat.completions.create( + model=model, + messages=message + ) + # for choices in response.choices: + # print(choices.message.content) + content = response.choices[0].message.content + return content + + +def get_content_by_image(base64_image, target_lang): + # print(image_path) + # file_object = openai.files.create(file=Path(image_path), purpose="这是一张图片") + # print(file_object) + message = [ + {"role": "system", "content": "你是一个图片ORC识别专家"}, + {"role": "user", "content": [ + { + "type": "image_url", + "image_url": { + "url": base64_image + } + }, + { + "type": "text", + # "text": "读取图片链接并提取其中的文本数据,只返回识别后的数据,将文本翻译成英文,并按照图片中的文字布局返回html。只包含body(不包含body本身)部分", + # "text": f"提取图片中的所有文字数据,将提取的文本翻译成{target_lang},只返回原始文本和翻译结果", + "text": f"提取图片中的所有文字数据,将提取的文本翻译成{target_lang},只返回翻译结果", + } + ]} + ] + # print(message) + # print(openai.base_url) + response = openai.chat.completions.create( + model="gpt-4o", # 使用GPT-3.5版本 + messages=message + ) + # for choices in response.choices: + # 
print(choices.message.content) + content = response.choices[0].message.content + # return content + # print(''.join(map(lambda x: f'

<p>{x}</p>',content.split("\n")))) + return ''.join(map(lambda x: f'<p>{x}</p>

', content.split("\n"))) + + +def check(model): + try: + message = [ + {"role": "system", "content": "你通晓世界所有语言,可以用来从一种语言翻译成另一种语言"}, + {"role": "user", "content": "你现在能翻译吗?"} + ] + response = openai.chat.completions.create( + model=model, + messages=message + ) + return "OK" + except openai.AuthenticationError as e: + return "openai密钥或令牌无效" + except openai.APIConnectionError as e: + return "请求无法与openai服务器或建立安全连接" + except openai.PermissionDeniedError as e: + return "令牌额度不足" + except openai.RateLimitError as e: + return "访问速率达到限制,10分钟后再试" + except openai.InternalServerError as e: + return "当前分组上游负载已饱和,请稍后再试" + except openai.APIStatusError as e: + return e.response + except Exception as e: + return "当前无法完成翻译" + + +def process(texts, translate_id): + total = 0 + complete = 0 + for text in texts: + total += 1 + if text['complete']: + complete += 1 + if total != complete: + if (total != 0): + process = format((complete / total) * 100, '.1f') + db.execute("update translate set process=%s where id=%s", str(process), translate_id) + + +def complete(trans, text_count, spend_time): + target_filesize = 1 #os.stat(trans['target_file']).st_size + db.execute( + "update translate set status='done',end_at=now(),process=100,target_filesize=%s,word_count=%s where id=%s", + target_filesize, text_count, trans['id']) + + +def error(translate_id, message): + db.execute( + "update translate set failed_count=failed_count+1,status='failed',end_at=now(),failed_reason=%s where id=%s", + message, translate_id) + + +def count_text(text): + count = 0 + for char in text: + if common.is_chinese(char): + count += 1; + elif char is None or char == " ": + continue + else: + count += 0.5 + return count + + +def init_openai(url, key): + openai.api_key = key + if "v1" not in url: + if url[-1] == "/": + url += "v1/" + else: + url += "/v1/" + openai.base_url = url + + +def check_translated(content): + if content.startswith("Sorry, I cannot") or content.startswith( + "I am sorry,") or content.startswith( + "I'm sorry,") or content.startswith("Sorry, I can't") or content.startswith( + "Sorry, I need more") or content.startswith("抱歉,无法") or content.startswith( + "错误:提供的文本") or content.startswith("无法翻译") or content.startswith( + "抱歉,我无法") or content.startswith( + "对不起,我无法") or content.startswith("ご指示の内容は") or content.startswith( + "申し訳ございません") or content.startswith("Простите,") or content.startswith( + "Извините,") or content.startswith("Lo siento,"): + return False + else: + return True + + +# def get_model_tokens(model,content): +# encoding=tiktoken.encoding_for_model(model) +# return en(encoding.encode(content)) + +def use_backup_model(trans, event, texts, index, message): + if trans['backup_model'] != None and trans['backup_model'] != "": + trans['model'] = trans['backup_model'] + trans['backup_model'] = "" + get(trans, event, texts, index) + else: + if not event.is_set(): + error(trans['id'], message) + print(message) + event.set() diff --git a/app/translate/txt.py b/app/translate/txt.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb2ab985e56287b04461fc9a30f1503329b2d0c --- /dev/null +++ b/app/translate/txt.py @@ -0,0 +1,131 @@ +import os +import threading +from . import to_translate +from . 
import common +import datetime +import time +import re + +def start(trans): + # 允许的最大线程 + threads=trans['threads'] + if threads is None or int(threads)<0: + max_threads=10 + else: + max_threads=int(threads) + # 当前执行的索引位置 + run_index=0 + start_time = datetime.datetime.now() + + try: + with open(trans['file_path'], 'r', encoding='utf-8') as file: + content = file.read() + except Exception as e: + print(f"无法读取文件 {trans['file_path']}: {e}") + return False + + texts=[] + + # 按段落分割内容 + paragraphs = content.split('\n\n') # 假设段落之间用两个换行符分隔 + # 支持最多单词量 + max_word=1000 + # 翻译每个段落 + for paragraph in paragraphs: + if check_text(paragraph): + # 如果段落长度超过 1000 字,进行分割 + if len(paragraph) > max_word: + sub_paragraphs = split_paragraph(paragraph, max_word) + for sub_paragraph in sub_paragraphs: + texts.append({"text":sub_paragraph,"origin":sub_paragraph, "complete":False, "sub":True}) + else: + texts.append({"text":paragraph,"origin":paragraph, "complete":False, "sub":False}) + + # print(texts) + max_run=max_threads if len(texts)>max_threads else len(texts) + before_active_count=threading.activeCount() + event=threading.Event() + while run_index<=len(texts)-1: + if threading.activeCount() max_length: + # 如果当前部分长度加上句子长度超过最大长度,保存当前部分 + parts.append(' '.join(current_part)) + current_part = [sentence] # 开始新的部分 + current_length = len(sentence) + else: + current_part.append(sentence) + current_length += len(sentence) + + # 添加最后一部分 + if current_part: + parts.append(' '.join(current_part)) + + return parts + +def check_text(text): + return text!=None and len(text)>0 and not common.is_all_punc(text) diff --git a/app/translate/word.py b/app/translate/word.py new file mode 100644 index 0000000000000000000000000000000000000000..68989446368648a330d67b03c9b66845951d962f --- /dev/null +++ b/app/translate/word.py @@ -0,0 +1,592 @@ +import threading +from docx import Document +from docx.shared import Pt +from docx.shared import Inches +from docx.oxml.ns import qn +from . import to_translate +from . import common +import os +import sys +import time +import datetime +import zipfile +import xml.etree.ElementTree as ET +from . 
import rediscon + +def start(trans): + # 允许的最大线程 + threads=trans['threads'] + if threads is None or threads=="" or int(threads)<0: + max_threads=10 + else: + max_threads=int(threads) + # 当前执行的索引位置 + run_index=0 + max_chars=1000 + start_time = datetime.datetime.now() + # 创建Document对象,加载Word文件 + try: + document = Document(trans['file_path']) + except Exception as e: + to_translate.error(trans['id'], "无法访问该文档") + return False + texts=[] + api_url=trans['api_url'] + trans_type=trans['type'] + target_lang = trans['lang'] + if trans_type=="trans_text_only_inherit": + # 仅文字-保留原文-继承原版面 + read_rune_text(document, texts) + elif trans_type=="trans_text_only_new" or trans_type=="trans_text_both_new": + # 仅文字-保留原文-重排 + read_paragraph_text(document, texts) + elif trans_type=="trans_text_both_inherit": + # 仅文字-保留原文-重排/继承原版面 + read_rune_text(document, texts) + elif trans_type=="trans_all_only_new": + # 全部内容-仅译文-重排版面 + read_paragraph_text(document, texts) + elif trans_type=="trans_all_only_inherit": + # 全部内容-仅译文-重排版面/继承原版面 + read_rune_text(document, texts) + elif trans_type=="trans_all_both_new": + # 全部内容-保留原文-重排版面 + read_paragraph_text(document, texts) + elif trans_type=="trans_all_both_inherit": + # 全部内容-保留原文-继承原版面 + read_rune_text(document, texts) + + read_comments_from_docx(trans['file_path'], texts) + read_insstd_from_docx(trans['file_path'], texts) + #print(texts) + max_run=max_threads if len(texts)>max_threads else len(texts) + event=threading.Event() + before_active_count=threading.activeCount() + while run_index<=len(texts)-1: + if threading.activeCount()0: + item=texts.pop(0) + # paragraph.runs[0].text=item.get('text',"") + for index,run in enumerate(paragraph.runs): + if index==0: + run.text=item.get('text',"") + else: + run.clear() + +def read_rune_text(document, texts): + for paragraph in document.paragraphs: + line_spacing=paragraph.paragraph_format.line_spacing + # print("line_spacing:",line_spacing) + read_run(paragraph.runs, texts) + # print(line_spacing_unit) + if len(paragraph.hyperlinks)>0: + for hyperlink in paragraph.hyperlinks: + read_run(hyperlink.runs, texts) + + # print("翻译文本--开始") + # print(datetime.datetime.now()) + for table in document.tables: + for row in table.rows: + start_span=0 + for cell in row.cells: + read_cell_text(cell, texts) + # start_span+=1 + # # if start_span==cell.grid_span: + # # start_span=0 + # # read_cell(cell, texts) + # for index,paragraph in enumerate(cell.paragraphs): + + # read_run(paragraph.runs, texts) + + # if len(paragraph.hyperlinks)>0: + # for hyperlink in paragraph.hyperlinks: + # read_run(hyperlink.runs, texts) + + +def write_only_new(document, texts, text_count, onlyText): + for paragraph in document.paragraphs: + text_count+=write_run(paragraph.runs, texts) + + if len(paragraph.hyperlinks)>0: + for hyperlink in paragraph.hyperlinks: + text_count+=write_run(hyperlink.runs, texts) + + if onlyText: + clear_image(paragraph) + + for table in document.tables: + for row in table.rows: + start_span=0 + for cell in row.cells: + write_cell_text(cell, texts) + # start_span+=1 + # if start_span==cell.grid_span: + # start_span=0 + # text_count+=write_cell(cell, texts) + # for paragraph in cell.paragraphs: + # text_count+=write_run(paragraph.runs, texts) + + # if len(paragraph.hyperlinks)>0: + # for hyperlink in paragraph.hyperlinks: + # text_count+=write_run(hyperlink.runs, texts) + +#保留原译文 +def write_rune_both(document, texts, text_count, onlyText,target_lang): + for paragraph in document.paragraphs: + # print(paragraph.text) + if(len(paragraph.runs)>0): 
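+ # Bilingual output: add a line break after the original runs, then add_paragraph_run() appends the translated runs beneath the source text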
+ paragraph.runs[-1].add_break() + add_paragraph_run(paragraph, paragraph.runs, texts, text_count,target_lang) + if len(paragraph.hyperlinks)>0: + for hyperlink in paragraph.hyperlinks: + hyperlink.runs[-1].add_break() + add_paragraph_run(paragraph, hyperlink.runs, texts, text_count,target_lang) + if onlyText: + clear_image(paragraph) + + # text_count+=write_run(paragraph.runs, texts) + for table in document.tables: + for row in table.rows: + # start_span=0 + for cell in row.cells: + # start_span+=1 + # if start_span==cell.grid_span: + # start_span=0 + # text_count+=write_cell(cell, texts) + for paragraph in cell.paragraphs: + replace_paragraph_text(paragraph, texts, text_count, onlyText, True) + + if len(paragraph.hyperlinks)>0: + for hyperlink in paragraph.hyperlinks: + replace_paragraph_text(hyperlink, texts, text_count, onlyText, True) + +def read_run(runs,texts): + # text="" + if len(runs)>0 or len(texts)==0: + for index,run in enumerate(runs): + append_text(run.text, texts) + # if run.text=="": + # if len(text)>0 and not common.is_all_punc(text): + # texts.append({"text":text, "complete":False}) + # text="" + # else: + # text+=run.text + # if len(text)>0 and not common.is_all_punc(text): + # texts.append({"text":text, "complete":False}) + +def append_text(text, texts): + if check_text(text): + # print(text) + texts.append({"text":text, "type":"text", "complete":False}) + +def append_comment(text, comment_id, texts): + if check_text(text): + texts.append({"text":text, "type":"comment","comment_id":comment_id, "complete":False}) + +def check_text(text): + return text!=None and len(text)>0 and not common.is_all_punc(text) + +def write_run(runs,texts): + text_count=0 + if len(runs)==0: + return text_count + text="" + for index,run in enumerate(runs): + text=run.text + if check_text(text) and len(texts)>0: + item=texts.pop(0) + text_count+=item.get('count',0) + run.text=item.get('text',"") + + # if run.text=="": + # if len(text)>0 and not common.is_all_punc(text) and len(texts)>0: + # item=texts.pop(0) + # text_count+=item.get('count',0) + # runs[index-1].text=item.get('text',"") + # text="" + # else: + # text+=run.text + # run.text="" + # if len(text)>0 and not common.is_all_punc(text) and len(texts)>0: + # item=texts.pop(0) + # text_count+=item.get('count',0) + # runs[0].text=item.get('text',"") + return text_count + + +def read_cell(cell,texts): + append_text(cell.text, texts) + + +def write_cell(cell,texts): + text=cell.text + text_count=0 + if check_text(text) and len(texts)>0: + item=texts.pop(0) + text_count+=item.get('count',0) + cell.text=item.get('text',"") + return text_count + +def add_paragraph_run(paragraph, runs, texts, text_count,target_lang): + for index,run in enumerate(runs): + if check_text(run.text) and len(texts)>0: + item=texts.pop(0) + text_count+=item.get('count',0) + new_run=paragraph.add_run(item.get('text',""), run.style) + set_run_style(new_run, run,target_lang) + set_paragraph_linespace(paragraph) + +def set_run_style(new_run, copy_run,target_lang): + new_run.font.italic= copy_run.font.italic + new_run.font.strike= copy_run.font.strike + new_run.font.bold= copy_run.font.bold + new_run.font.size= copy_run.font.size + new_run.font.color.rgb= copy_run.font.color.rgb + new_run.underline= copy_run.underline + new_run.style= copy_run.style + + # 字体名称设置需要特殊处理 + if target_lang== '中文' or target_lang== '日语': + new_run.font.name = '微软雅黑' + r = new_run._element.rPr.rFonts + r.set(qn('w:eastAsia'),'微软雅黑') + else: + new_run.font.name = 'Times New Roman' + r = 
new_run._element.rPr.rFonts + r.set(qn('w:eastAsia'),'Times New Roman') + +def set_paragraph_linespace(paragraph): + if hasattr(paragraph, "paragraph_format"): + space_before=paragraph.paragraph_format.space_before + space_after=paragraph.paragraph_format.space_after + line_spacing=paragraph.paragraph_format.line_spacing + line_spacing_rule=paragraph.paragraph_format.line_spacing_rule + if space_before!=None: + paragraph.paragraph_format.space_before=space_before + if space_after!=None: + paragraph.paragraph_format.space_after=space_after + if line_spacing!=None: + paragraph.paragraph_format.line_spacing=line_spacing + if line_spacing_rule!=None: + paragraph.paragraph_format.line_spacing_rule=line_spacing_rule + +def check_image(run): + if run.element.find('.//w:drawing', namespaces=run.element.nsmap) is not None: + return True + return False + +# 去除照片 +def clear_image(paragraph): + for run in paragraph.runs: + if check_image(run): + run.clear() + +def replace_paragraph_text(paragraph, texts, text_count, onlyText, appendTo): + text=paragraph.text + if check_text(text) and len(texts)>0: + item=texts.pop(0) + trans_text=item.get('text',"") + if appendTo: + if len(paragraph.runs)>0: + paragraph.runs[-1].add_break() + paragraph.runs[-1].add_text(trans_text) + elif len(paragraph.hyperlinks)>0: + paragraph.hyperlinks[-1].runs[-1].add_break() + paragraph.hyperlinks[-1].runs[-1].add_text(trans_text) + else: + replaced=False + if len(paragraph.runs)>0: + for index,run in enumerate(paragraph.runs): + if not check_image(run): + if not replaced: + run.text=trans_text + replaced=True + else: + run.clear() + elif len(paragraph.hyperlinks)>0: + for hyperlink in paragraph.hyperlinks: + for index,run in enumerate(hyperlink.runs): + if not check_image(run): + if not replaced: + run.text=trans_text + replaced=True + else: + run.clear() + + text_count+=item.get('count',0) + set_paragraph_linespace(paragraph) + if onlyText: + clear_image(paragraph) + +def read_comments_from_docx(docx_path, texts): + comments = [] + with zipfile.ZipFile(docx_path, 'r') as docx: + # 尝试读取批注文件 + if 'word/comments.xml' in docx.namelist(): + with docx.open('word/comments.xml') as comments_file: + # 解析 XML + tree = ET.parse(comments_file) + root = tree.getroot() + + # 定义命名空间 + namespace = {'ns0': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'} + + # 查找所有批注 + for comment in root.findall('ns0:comment', namespace): + comment_id = comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}id') + author = comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}author') + date = comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}date') + text = ''.join(t.text for p in comment.findall('.//ns0:p', namespace) for r in p.findall('.//ns0:r', namespace) for t in r.findall('.//ns0:t', namespace)) + append_comment(text, comment_id, texts) + +def modify_comment_in_docx(docx_path, texts): + # 创建一个临时文件名,保留原始路径 + temp_docx_path = os.path.join(os.path.dirname(docx_path), 'temp_' + os.path.basename(docx_path)) + + # 打开原始 docx 文件 + with zipfile.ZipFile(docx_path, 'r') as docx: + # 创建一个新的 docx 文件 + with zipfile.ZipFile(temp_docx_path, 'w') as new_docx: + for item in docx.infolist(): + # 读取每个文件 + with docx.open(item) as file: + if item.filename == 'word/comments.xml': + # 解析批注 XML + tree = ET.parse(file) + root = tree.getroot() + + # 定义命名空间 + namespace = {'ns0': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'} + + # 查找并修改批注 + for comment in root.findall('ns0:comment', 
namespace): + text = ''.join(t.text for p in comment.findall('.//ns0:p', namespace) for r in p.findall('.//ns0:r', namespace) for t in r.findall('.//ns0:t', namespace)) + if check_text(text): + for newitem in texts: + # text_count+=newitem.get('count',0) + new_text=newitem.get('text',"") + comment_id=newitem.get('comment_id',"") + # print("new_text:",new_text) + # print("comment_id:",comment_id) + # print("origin_id:",comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}id')) + if comment.get('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}id') == comment_id: + + # 清除现有段落 + for p in comment.findall('.//ns0:t', namespace): + # 删除 ns0:t 元素 + # comment.remove(p) # 删除 ns0:t 元素 + + # # 创建新的 ns0:t 元素 + # new_text_elem = ET.Element('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}t') + # new_text_elem.text = new_text # 设置新的文本内容 + + # # 将新的 ns0:t 元素添加到段落中 + # r = ET.Element('{http://schemas.openxmlformats.org/wordprocessingml/2006/main}r') # 创建新的 run 元素 + # r.append(new_text_elem) # 将新的 ns0:t 添加到 run 中 + # p.append(r) # 将 run 添加到段落中 + p.text=new_text + # 打印修改后的 XML 内容 + modified_xml = ET.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8') + # print(modified_xml) + # 将修改后的 XML 写入新的 docx 文件 + new_docx.writestr(item.filename, modified_xml) + else: + # 其他文件直接写入新的 docx 文件 + new_docx.writestr(item.filename, file.read()) + + # print(temp_docx_path) + # 替换原始文件 + os.replace(temp_docx_path, docx_path) + + +def append_ins(text, ins_id, texts): + if check_text(text): + texts.append({"text": text, "type": "ins", "ins_id": ins_id, "complete": False}) + + +def read_insstd_from_docx(docx_path, texts): + document_ins = [] + namespace = '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}' + namespace14='{http://schemas.microsoft.com/office/word/2010/wordml}' + with zipfile.ZipFile(docx_path, 'r') as docx: + # 尝试读取批注文件 + if 'word/document.xml' in docx.namelist(): + with docx.open('word/document.xml') as document_file: + # 解析 XML + tree = ET.parse(document_file) + root = tree.getroot() + for element in root.findall(namespace + 'body'): + for p in element.findall(namespace + 'p'): + for ins in p.findall(namespace + 'ins'): + ins_id = ins.get(namespace + 'id') + for r in ins.findall(namespace + 'r'): + for t in r.findall(namespace + 't'): + append_ins(t.text, ins_id, texts) + for sdt in element.findall(namespace + 'sdt'): + for sdtContent in sdt.findall(namespace + 'sdtContent'): + for p in sdtContent.findall(namespace + 'p'): + sdt_id = p.get(namespace14 + 'paraId') + for r in p.findall(namespace + 'r'): + for t in r.findall(namespace + 't'): + append_sdt(t.text, sdt_id, texts) + for ins in p.findall(namespace + 'ins'): + for r in ins.findall(namespace + 'r'): + for t in r.findall(namespace + 't'): + append_sdt(t.text, sdt_id, texts) + + + +def append_sdt(text, sdt_id, texts): + if check_text(text): + texts.append({"text": text, "type": "sdt", "sdt_id": sdt_id, "complete": False}) + + + +def modify_inssdt_in_docx(docx_path, texts): + print(texts,docx_path) + temp_docx_path = os.path.join(os.path.dirname(docx_path), 'temp_std_' + os.path.basename(docx_path)) + with zipfile.ZipFile(docx_path, 'r') as docx: + with zipfile.ZipFile(temp_docx_path, 'w') as new_docx: + for item in docx.infolist(): + with docx.open(item) as file: + if item.filename == 'word/document.xml': + tree = ET.parse(file) + root = tree.getroot() + namespace = '{http://schemas.openxmlformats.org/wordprocessingml/2006/main}' + 
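# w14:paraId (the namespace14 attribute defined below) uniquely identifies a paragraph inside a content control; it is compared with the sdt_id values collected by read_insstd_from_docx so each translated text is written back to the right paragraph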
namespace14='{http://schemas.microsoft.com/office/word/2010/wordml}' + for body in root.findall(namespace + 'body'): + for sdt in body.findall(namespace + 'sdt'): + for sdtContent in sdt.findall(namespace + 'sdtContent'): + for p in sdtContent.findall(namespace + 'p') : + for r in p.findall(namespace + 'r'): + for t in r.findall(namespace + 't'): + text = t.text + if check_text(text): + for newitem in texts: + new_text = newitem.get('text', "") + sdt_id = newitem.get('sdt_id', "") + if p.get(namespace14 + 'paraId') == sdt_id: + t.text = new_text + for ins in p.findall(namespace + 'ins'): + for r in ins.findall(namespace + 'r'): + for t in r.findall(namespace + 't'): + text = t.text + if check_text(text): + for newitem in texts: + new_text = newitem.get('text', "") + sdt_id = newitem.get('sdt_id', "") + if p.get(namespace14 + 'paraId') == sdt_id: + t.text = new_text + + for p in body.findall(namespace + 'p'): + for ins in p.findall(namespace + 'ins'): + for r in ins.findall(namespace + 'r'): + for t in r.findall(namespace + 't'): + text = t.text + if check_text(text): + for newitem in texts: + new_text = newitem.get('text', "") + ins_id = newitem.get('ins_id', "") + if ins.get(namespace + 'id') == ins_id: + t.text = new_text + modified_xml = ET.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8') + new_docx.writestr(item.filename, modified_xml) + else: + new_docx.writestr(item.filename, file.read()) + os.replace(temp_docx_path, docx_path) diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/utils/auth_tools.py b/app/utils/auth_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..c6efdecfeb35716351e3426914bcfe1f82ea6056 --- /dev/null +++ b/app/utils/auth_tools.py @@ -0,0 +1,22 @@ +#========== utils/auth_tools.py ========== +import random +from datetime import datetime, timedelta +from werkzeug.security import generate_password_hash, check_password_hash + +def generate_code(length=6): + """生成数字验证码""" + return ''.join(random.choices('0123456789', k=length)) + +def validate_code(code_record): + """验证码有效性检查""" + if not code_record: + return False + return (datetime.utcnow() - code_record.created_at) < timedelta(seconds=1800) + +def hash_password(password): + """密码哈希处理""" + return generate_password_hash(password) + +def check_password(hashed_password, password): + """密码校验""" + return check_password_hash(hashed_password, password) \ No newline at end of file diff --git a/app/utils/check_utils.py b/app/utils/check_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe7753cfaeac228e0dfac92c347118d416d6f52e --- /dev/null +++ b/app/utils/check_utils.py @@ -0,0 +1,44 @@ +# utils/ai_utils.py +import openai +from io import BytesIO +import fitz # PyMuPDF +import logging + +class AIChecker: + @staticmethod + def check_openai_connection(api_url: str, api_key: str, model: str, timeout: int = 10): + """OpenAI连通性测试""" + try: + openai.api_key = api_key + openai.base_url = api_url + + # 发送一个简单的聊天请求 + response = openai.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "hi"}], + timeout=timeout + ) + # 返回连接成功和响应内容 + return True, response.choices[0].message.content + except Exception as e: + logging.error(f"OpenAI连接测试失败: {str(e)}") + return False, str(e) + + @staticmethod + def check_pdf_scanned(file_stream: BytesIO): + """PDF扫描件检测""" + try: + file_stream.seek(0) + doc = 
fitz.open(stream=file_stream.read(), filetype="pdf") + pages_to_check = min(5, len(doc)) + + for page_num in range(pages_to_check): + page = doc[page_num] + if page.get_text().strip(): # 发现可编辑文本 + return False + if page.get_images(): # 发现图像 + return True + return False + except Exception as e: + logging.error(f"PDF检测失败: {str(e)}") + raise diff --git a/app/utils/exceptions.py b/app/utils/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..4483532e518327950e80e8121c0c413494ebf0c5 --- /dev/null +++ b/app/utils/exceptions.py @@ -0,0 +1,21 @@ +# app/exceptions.py +class APIException(Exception): + """基础API异常 [^1][^4]""" + def __init__(self, message, code=400, payload=None): + super().__init__() + self.message = message + self.code = code + self.payload = payload + +class NotFoundException(APIException): + def __init__(self, message='资源不存在'): + super().__init__(message, 404) + +class PermissionDenied(APIException): + def __init__(self, message='权限不足'): + super().__init__(message, 403) + +class ValidationError(APIException): + def __init__(self, message='参数验证失败', errors=None): + super().__init__(message, 400) + self.errors = errors diff --git a/app/utils/file_utils.py b/app/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f8f539e0c0d047e07300250d3ca7dabf1e3f1b --- /dev/null +++ b/app/utils/file_utils.py @@ -0,0 +1,273 @@ +# utils/file_utils.py +import hashlib +import os +import uuid +from datetime import datetime +from pathlib import Path + +from flask import current_app +from werkzeug.utils import secure_filename + +from pathlib import Path +from datetime import datetime +from flask import current_app + +import os +import hashlib +from pathlib import Path +from datetime import datetime +from flask import current_app + + +class FileManager: + @staticmethod + def get_upload_dir(): + """ + 获取上传文件存储目录[^1] + :return: 上传文件存储目录的绝对路径 + """ + base_dir = Path(current_app.config['UPLOAD_BASE_DIR']) + date_str = datetime.now().strftime('%Y-%m-%d') + upload_dir = base_dir / 'uploads' / date_str + upload_dir.mkdir(parents=True, exist_ok=True) + return str(upload_dir) + + @staticmethod + def generate_filename(filename): + """ + 生成唯一的文件名[^3] + :param filename: 原始文件名 + :return: 唯一的文件名 + """ + name, ext = os.path.splitext(filename) + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + return f"{name}_{timestamp}{ext}" + + @staticmethod + def get_relative_path(full_path): + """ + 获取相对于存储根目录的相对路径[^4] + :param full_path: 文件的绝对路径 + :return: 相对路径 + """ + base_dir = Path(current_app.config['UPLOAD_BASE_DIR']) + return str(Path(full_path).relative_to(base_dir)).replace('\\', '/') + + @staticmethod + def exists(file_path): + """ + 检查文件是否存在[^5] + :param file_path: 文件的相对路径或绝对路径 + :return: 文件是否存在 (True/False) + """ + if not file_path: + return False + full_path = os.path.join(current_app.config['UPLOAD_BASE_DIR'], file_path.lstrip('/')) + return os.path.exists(full_path) + + @staticmethod + def calculate_md5(file_path): + """ + 计算文件的 MD5 值[^6] + :param file_path: 文件的绝对路径 + :return: 文件的 MD5 值 + """ + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + @staticmethod + def allowed_file(filename): + """ + 验证文件类型是否允许[^7] + :param filename: 文件名 + :return: 文件类型是否允许 (True/False) + """ + ALLOWED_EXTENSIONS = {'docx', 'xlsx', 'pptx', 'pdf', 'txt', 'md', 'csv', 'xls', 'doc'} + return '.' 
in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + @staticmethod + def validate_file_size(file_stream): + """ + 验证文件大小是否超过限制[^8] + :param file_stream: 文件流 + :return: 文件大小是否合法 (True/False) + """ + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + file_stream.seek(0, os.SEEK_END) + file_size = file_stream.tell() + file_stream.seek(0) + return file_size <= MAX_FILE_SIZE + + @staticmethod + def get_translate_absolute_path(filename): + """ + 获取翻译结果的绝对路径(保持原文件名)[^2] + :param filename: 原始文件名 + :return: 翻译结果的绝对路径 + """ + base_dir = Path(current_app.config['UPLOAD_BASE_DIR']) + date_str = datetime.now().strftime('%Y-%m-%d') + translate_dir = base_dir / 'translate' / date_str + translate_dir.mkdir(parents=True, exist_ok=True) + return str(translate_dir / filename) + + + + +class FileManager11: + @staticmethod + def allowed_file(filename): + """验证文件类型是否允许[^1]""" + ALLOWED_EXTENSIONS = {'docx', 'xlsx', 'pptx', 'pdf', 'txt', 'md', 'csv', 'xls', 'doc'} + return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + @staticmethod + def validate_file_size(file_stream): + """验证文件大小是否超过限制[^2]""" + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + file_stream.seek(0, os.SEEK_END) + file_size = file_stream.tell() + file_stream.seek(0) + return file_size <= MAX_FILE_SIZE + + @staticmethod + def get_upload_dir(): + """获取基于配置的上传目录""" + upload_dir = os.path.join( + current_app.config['UPLOAD_FOLDER'], + datetime.now().strftime('%Y-%m-%d') + ) + + if not os.path.exists(upload_dir): + os.makedirs(upload_dir, exist_ok=True) + return upload_dir + + def get_upload_dir1111(self): + """获取按日期分类的上传目录""" + # 获取项目根目录,并再上一级到所需目录 + base_dir = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + print(base_dir) + upload_dir = os.path.join(base_dir, 'uploads', datetime.now().strftime('%Y-%m-%d')) + + # 如果目录不存在则创建 + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + return upload_dir + + @staticmethod + def generate_filename(filename): + """生成安全文件名(带随机后缀防冲突)""" + safe_name = secure_filename(filename) + name_part, ext_part = os.path.splitext(safe_name) + random_str = uuid.uuid4().hex[:6] # 6位随机字符 + return f"{name_part}_{random_str}{ext_part}" + + @staticmethod + def generate_filename111(filename): + """生成安全的文件名,如果文件已存在则附加随机字符串[^4]""" + safe_filename = secure_filename(filename) + name, ext = os.path.splitext(safe_filename) + return f"{name}_{str(uuid.uuid4())[:5]}{ext}" + + @staticmethod + def safe_remove(filepath): + """安全删除文件""" + if os.path.exists(filepath): + try: + os.remove(filepath) + print(f"File {filepath} has been deleted.") + except Exception as e: + print(f"Error occurred while deleting file {filepath}: {e}") + else: + print(f"File {filepath} does not exist.") + + @staticmethod + def exists(file_path: str) -> bool: + """验证文件是否存在并检查路径安全性[^1] + Args: + file_path: 文件路径,支持相对路径和绝对路径 + Returns: + bool: 文件是否存在且路径合法 + """ + try: + # 标准化路径,防止路径遍历攻击 + normalized_path = Path(file_path).resolve(strict=False) + + # 验证路径是否在允许的目录下 + upload_dir = Path(current_app.config['UPLOAD_FOLDER']).resolve() + if not normalized_path.is_relative_to(upload_dir): + return False + + return normalized_path.exists() and normalized_path.is_file() + + except Exception as e: + current_app.logger.error(f"文件路径验证失败: {str(e)}") + return False + + @staticmethod + def get_storage_dir(): + """获取按日期分类的存储目录[^2]""" + base_dir = Path(current_app.config['STORAGE_FOLDER']) + storage_dir = base_dir / datetime.now().strftime('%Y-%m-%d') + + if not storage_dir.exists(): + 
storage_dir.mkdir(parents=True, exist_ok=True) + + return str(storage_dir) + + @staticmethod + def is_secure_path(file_path: str, base_dir: str) -> bool: + """验证文件路径是否安全[^3] + Args: + file_path: 文件路径 + base_dir: 基准目录 + Returns: + bool: 路径是否安全 + """ + try: + normalized_path = Path(file_path).resolve(strict=False) + base_dir_path = Path(base_dir).resolve() + return normalized_path.is_relative_to(base_dir_path) + except Exception as e: + current_app.logger.error(f"路径安全验证失败: {str(e)}") + return False + + @staticmethod + def exists111xin(file_path: str, base_dir: str) -> bool: + """验证文件是否存在并检查路径安全性[^4] + Args: + file_path: 文件路径 + base_dir: 基准目录 + Returns: + bool: 文件是否存在且路径合法 + """ + if not FileManager.is_secure_path(file_path, base_dir): + return False + + normalized_path = Path(file_path).resolve(strict=False) + return normalized_path.exists() and normalized_path.is_file() + + @staticmethod + def calculate_md5(file_path): + """计算文件的MD5值""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def get_upload_dir(): + """获取按日期分类的上传目录""" + # 获取项目根目录,并再上一级到所需目录 + base_dir = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + print(base_dir) + upload_dir = os.path.join(base_dir, 'uploads', datetime.now().strftime('%Y-%m-%d')) + + # 如果目录不存在则创建 + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + return upload_dir diff --git a/app/utils/mail.py b/app/utils/mail.py new file mode 100644 index 0000000000000000000000000000000000000000..d70de42bf5189a60957fff3b1c96723726a56acd --- /dev/null +++ b/app/utils/mail.py @@ -0,0 +1,14 @@ +#========== utils/mail.py ========== +from flask_mail import Message +from app.extensions import mail + +class MailService: + @classmethod + def send_verification(cls, email, code): + """发送验证码邮件""" + msg = Message( + subject="验证码通知", + recipients=[email], + body=f"您的验证码是:{code},30分钟内有效" + ) + mail.send(msg) diff --git a/app/utils/mail_service.py b/app/utils/mail_service.py new file mode 100644 index 0000000000000000000000000000000000000000..df49bc6e205f979bd4354418b7827fe4a426c3a1 --- /dev/null +++ b/app/utils/mail_service.py @@ -0,0 +1,10 @@ +# services/email_service.py +from flask_mail import Message +from app.extensions import mail + +class EmailService: + @staticmethod + def send_verification_code(email, code): + msg = Message("验证码邮件", recipients=[email]) + msg.body = f"您的验证码是:{code},有效期10分钟" + mail.send(msg) diff --git a/app/utils/mail_templates.py b/app/utils/mail_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..f3da5ee6f2d0f6272a6d23c56f0f6385c32e0b25 --- /dev/null +++ b/app/utils/mail_templates.py @@ -0,0 +1,89 @@ +# app/utils/mail_templates.py +from datetime import datetime + +def generate_register_email(user: dict, code: str) -> str: + """生成注册确认邮件HTML""" + return f""" + + + + + +
+
+

欢迎注册我们的服务

+
+
+

尊敬的{user.get('name', '用户')}:

+

您的注册验证码是:

+
{code}
+

验证码有效期15分钟,请勿泄露给他人

+
+
+ + + """ + +def generate_new_user_notification(user: dict) -> str: + """生成新用户注册通知邮件HTML""" + return f""" + + + +
+

系统通知:新用户注册

+

以下用户刚刚完成了注册:

+
+ <ul>
+ <li>用户ID:{user.get('id', '')}</li>
+ <li>邮箱:{user.get('email', '')}</li>
+ <li>注册时间:{user.get('created_at', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))}</li>
+ </ul>
+
+ + + """ + +def generate_password_reset_email(user: dict, code: str) -> str: + """生成密码重置邮件HTML""" + return f""" + + + +
+

密码重置验证码

+

您的密码重置验证码是:

+
{code}
+

验证码有效期30分钟

+
+ + + """ + +def generate_password_change_email(user: dict) -> str: + """生成密码修改通知邮件HTML""" + return f""" + + + +
+

密码修改通知

+

您的账户 {user.get('email', '')} 密码修改成功

+

时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

+
+ + + """ \ No newline at end of file diff --git a/app/utils/mailer.py b/app/utils/mailer.py new file mode 100644 index 0000000000000000000000000000000000000000..54fb438b27ba5c5d503bbd224fc560d40c685b6a --- /dev/null +++ b/app/utils/mailer.py @@ -0,0 +1,74 @@ +# app/services/mailer.py +from flask_mail import Message +from app.extensions import mail +from app.utils.mail_templates import ( + generate_register_email, + generate_new_user_notification, + generate_password_reset_email, + generate_password_change_email +) + +class EmailService: + def send_register_verification(email: str, code: str): + """发送注册验证邮件 [^1]""" + msg = Message( + subject="注册验证码", + recipients=[email], + html=f""" +

+ <p>您的注册验证码是:{code}</p>
+ <p>验证码15分钟内有效</p>

+ """ + ) + mail.send(msg) + + @staticmethod + def send_password_reset(email: str, code: str): + """发送密码重置邮件 [^2]""" + msg = Message( + subject="密码重置验证码", + recipients=[email], + html=f""" +

+ <p>您的密码重置验证码是:{code}</p>
+ <p>验证码30分钟内有效</p>

+ """ + ) + mail.send(msg) + @staticmethod + def send_register_verification666(email: str, user: dict, code: str): + """发送注册验证邮件""" + msg = Message( + subject="注册验证码", + recipients=[email], + html=generate_register_email(user, code) + ) + mail.send(msg) + + @staticmethod + def send_new_user_alert(admin_emails: list, user: dict): + """发送新用户通知""" + msg = Message( + subject="新用户注册通知", + recipients=admin_emails, + html=generate_new_user_notification(user) + ) + mail.send(msg) + + @staticmethod + def send_password_reset666(email: str, user: dict, code: str): + """发送密码重置邮件""" + msg = Message( + subject="密码重置验证码", + recipients=[email], + html=generate_password_reset_email(user, code) + ) + mail.send(msg) + + @staticmethod + def send_password_change_notification(email: str, user: dict): + """发送密码修改通知""" + msg = Message( + subject="密码修改通知", + recipients=[email], + html=generate_password_change_email(user) + ) + mail.send(msg) \ No newline at end of file diff --git a/app/utils/response.py b/app/utils/response.py new file mode 100644 index 0000000000000000000000000000000000000000..b28322907730dc9e80f7fdbd1cb75ef19aa00c9b --- /dev/null +++ b/app/utils/response.py @@ -0,0 +1,61 @@ +# utils/response.py +from flask import jsonify + +class APIResponse: + @staticmethod + def success(data=None, message='操作成功', code=200): + return { + 'code': code, + 'message': message, + 'data': data + }, code + + @staticmethod + def error(message='请求错误', code=400, errors=None): + payload = { + 'code': code, + 'message': f"{message}" + } + if errors: + payload['errors'] = errors + return payload, code + + @classmethod + def not_found(cls, message='资源不存在'): + return cls.error(message=message, code=404) + + @classmethod + def unauthorized(cls, message='身份验证失败'): + return cls.error(message=message, code=401) + +#========== utils/responses.py ========== +class APIResponse1111: + @staticmethod + def success(data=None, message='操作成功', code=200): + return { + 'code': code, + 'message': message, + 'data': data + }, code + + @staticmethod + def error(message='请求错误', code=400, errors=None): + payload = { + 'code': code, + 'message': message + } + if errors: + payload['errors'] = errors + return payload, code + + @classmethod + def not_found(cls, message='资源不存在'): + return cls.error(message=message, code=404) + + @classmethod + def unauthorized(cls, message='身份验证失败'): + return cls.error(message=message, code=401) + + @classmethod + def forbidden(cls, message='权限不足'): + return cls.error(message=message, code=403) \ No newline at end of file diff --git a/app/utils/security.py b/app/utils/security.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ab82b09a0171611c40743e54af9d2cc0a71c41 --- /dev/null +++ b/app/utils/security.py @@ -0,0 +1,8 @@ +# utils/security.py +from werkzeug.security import generate_password_hash, check_password_hash + +def hash_password(password): + return generate_password_hash(password) + +def verify_password(hashed_password, password): + return check_password_hash(hashed_password, password) diff --git a/app/utils/task_utils.py b/app/utils/task_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..102e1d7c5406bbf4a2caa4b65cb545ca8efc8de0 --- /dev/null +++ b/app/utils/task_utils.py @@ -0,0 +1,7 @@ +# utils/task_utils.py +import datetime + +def generate_task_no(): + """生成任务编号[^3]""" + now = datetime.datetime.now() + return f"T{now.strftime('%Y%m%d%H%M%S')}" diff --git a/app/utils/translate_utils.py b/app/utils/translate_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..c3d1cda5354edbabd1af951a5ceca2c593e7884a --- /dev/null +++ b/app/utils/translate_utils.py @@ -0,0 +1,44 @@ +# utils/translate_utils.py +from typing import List +import subprocess +from pathlib import Path +import zipfile +from io import BytesIO +from datetime import datetime + +class TranslateUtils: + @staticmethod + def execute_python_script(script_path: str, args: List[str], timeout: int = 120): + """执行Python脚本并处理超时[^1]""" + try: + result = subprocess.run( + ['python3', script_path] + args, + capture_output=True, + text=True, + timeout=timeout + ) + return result.stdout.strip(), None + except subprocess.TimeoutExpired: + return None, '操作超时' + except Exception as e: + return None, str(e) + + @staticmethod + def generate_zip(files: List[tuple]) -> BytesIO: + """生成内存ZIP文件流[^2]""" + zip_buffer = BytesIO() + with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: + for file_path, arcname in files: + zip_file.write(file_path, arcname) + zip_buffer.seek(0) + return zip_buffer + + @staticmethod + def get_preset_settings() -> dict: + """获取预设配置[^5]""" + return { + 'models': ['gpt-3.5-turbo', 'gpt-4'], + 'default_model': 'gpt-3.5-turbo', + 'max_threads': 10, + 'prompt_template': '将以下内容翻译为{target_lang}' + } diff --git a/app/utils/validators.py b/app/utils/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..5e911a37118425110f087f763c422f8d0eac71b1 --- /dev/null +++ b/app/utils/validators.py @@ -0,0 +1,97 @@ +# utils/validators.py +from datetime import datetime, timedelta +from app.models import SendCode + + +def validate_verification_code(email: str, code: str, code_type: int): + """验证验证码有效性[^1]""" + expire_time = datetime.utcnow() - timedelta(minutes=10) + send_code = SendCode.query.filter( + SendCode.send_to == email, + SendCode.code == code, + SendCode.send_type == code_type, + SendCode.created_at > expire_time + ).order_by(SendCode.created_at.desc()).first() + + if not send_code: + return False, '验证码已过期或无效' + return True, None + + +def validate_password_confirmation(data: dict): + """验证密码一致性[^2]""" + if data['password'] != data.get('password_confirmation'): + return False, '两次密码不一致' + return True, None + +# utils/validators.py 新增方法 +def validate_password_complexity(password: str): + """密码复杂度验证[^5]""" + if len(password) < 6: + return False, "密码至少需要6位" + if not any(c.isalpha() for c in password) or not any(c.isdigit() for c in password): + return False, "密码需包含字母和数字" + return True, None + + +# utils/validators.py +from flask import request +from app.utils.response import APIResponse + + +def validate_pagination_params(req): + """验证并获取分页参数[^1] + + 返回: + tuple: (page, limit) + """ + try: + page = int(req.args.get('page', 1)) + limit = int(req.args.get('limit', 20)) + + if page < 1: + raise ValueError('页码必须大于0') + if limit < 1 or limit > 100: + raise ValueError('每页数量必须在1到100之间') + + return page, limit + except ValueError as e: + raise APIResponse.error(str(e), 400) + + +def validate_date_range(start_date, end_date): + """验证日期范围参数[^2] + 参数: + start_date (str): 起始日期 + end_date (str): 结束日期 + 返回: + tuple: (start_date, end_date) 转换后的datetime对象 + """ + try: + start = datetime.fromisoformat(start_date) if start_date else None + end = datetime.fromisoformat(end_date) if end_date else None + + if start and end and start > end: + raise ValueError('起始日期不能晚于结束日期') + + return start, end + except ValueError as e: + raise APIResponse.error('日期格式错误', 400) + + +def validate_id_list(ids): + """验证ID列表参数[^3] + 参数: + ids (list): ID列表 
+ 返回: + list: 验证后的ID列表 + """ + if not ids or not isinstance(ids, list): + raise APIResponse.error('参数错误', 400) + + try: + return [int(id) for id in ids] + except ValueError: + raise APIResponse.error('ID格式错误', 400) + + diff --git a/db/dev.db b/db/dev.db new file mode 100644 index 0000000000000000000000000000000000000000..8ddcbf29ff3f74557fa1f9e0e8182158a5a0f137 --- /dev/null +++ b/db/dev.db @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf62d31c15a1303e990bb618ab121acd5b45ceebdafe2dbb79e93535b5adf5b8 +size 147456 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..68aee8247f5a414763e9cff4244c36c0122ef931 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +Flask==3.1.0 +Flask-Cors==5.0.0 +Flask-JWT-Extended==4.7.1 +Flask-Migrate==4.1.0 +Flask-RESTful==0.3.10 +Flask-SQLAlchemy==3.1.1 +python-dotenv==1.0.1 +Flask-Limiter==3.10.1 +Flask-Mail==0.10.0 +marshmallow-sqlalchemy==1.4.1 +flask-marshmallow==1.3.0 +marshmallow==3.26.1 +# 翻译需要的库 +python-docx==1.1.2 +redis==5.2.1 +openpyxl==3.1.5 +openpyxl +python-pptx +pymysql +PyMuPDF==1.24.7 +tiktoken +pdf2docx +PyMySQL==1.1.1 +docx2pdf==0.1.8 +pytesseract +pymdown-extensions +GeneralAgent==0.3.21 +pdf2docx==0.5.8 +pypdf==5.2.0 +python-dotenv==1.0.1 +pdfdeal==1.0.2 +pdfkit==1.0.0 +docx2pdf==0.1.8 +Shapely==2.0.7 \ No newline at end of file diff --git a/sync_data.sh b/sync_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..099c58dc48223fd43cff09eedfa7e3bcc3bc7df2 --- /dev/null +++ b/sync_data.sh @@ -0,0 +1,115 @@ +#!/bin/sh + +# 检查环境变量 +if [ -z "$HF_TOKEN" ] || [ -z "$DATASET_ID" ]; then + echo "缺少必要的环境变量 HF_TOKEN 或 DATASET_ID" + exit 1 +fi + +# 上传备份 +cat > /tmp/hf_sync.py << 'EOL' +from huggingface_hub import HfApi +import sys +import os + +def manage_backups(api, repo_id, max_files=50): + files = api.list_repo_files(repo_id=repo_id, repo_type="dataset") + backup_files = [f for f in files if f.startswith('dev_backup_') and f.endswith('.db')] + backup_files.sort() + + if len(backup_files) >= max_files: + files_to_delete = backup_files[:(len(backup_files) - max_files + 1)] + for file_to_delete in files_to_delete: + try: + api.delete_file(path_in_repo=file_to_delete, repo_id=repo_id, repo_type="dataset") + print(f'已删除旧备份: {file_to_delete}') + except Exception as e: + print(f'删除 {file_to_delete} 时出错: {str(e)}') + +def upload_backup(file_path, file_name, token, repo_id): + api = HfApi(token=token) + try: + api.upload_file( + path_or_fileobj=file_path, + path_in_repo=file_name, + repo_id=repo_id, + repo_type="dataset" + ) + print(f"成功上传 {file_name}") + + manage_backups(api, repo_id) + except Exception as e: + print(f"文件上传出错: {str(e)}") + +# 下载最新备份 +def download_latest_backup(token, repo_id): + try: + api = HfApi(token=token) + files = api.list_repo_files(repo_id=repo_id, repo_type="dataset") + backup_files = [f for f in files if f.startswith('dev_backup_') and f.endswith('.db')] + + if not backup_files: + print("未找到备份文件") + return + + latest_backup = sorted(backup_files)[-1] + + filepath = api.hf_hub_download( + repo_id=repo_id, + filename=latest_backup, + repo_type="dataset" + ) + + if filepath and os.path.exists(filepath): + os.makedirs('./db', exist_ok=True) + os.system(f'cp "{filepath}" ./db/dev.db') + print(f"成功从 {latest_backup} 恢复备份") + + except Exception as e: + print(f"下载备份时出错: {str(e)}") + +if __name__ == "__main__": + action = sys.argv[1] + token = sys.argv[2] + repo_id = sys.argv[3] + + if action == "upload": + file_path = 
sys.argv[4] + file_name = sys.argv[5] + upload_backup(file_path, file_name, token, repo_id) + elif action == "download": + download_latest_backup(token, repo_id) +EOL + +# 首次启动时下载最新备份 +echo "正在从 HuggingFace 下载最新备份..." +python3 /tmp/hf_sync.py download "${HF_TOKEN}" "${DATASET_ID}" + +# 同步函数 +sync_data() { + while true; do + echo "开始同步进程 $(date)" + + if [ -f "./db/dev.db" ]; then + timestamp=$(date +%Y%m%d_%H%M%S) + backup_file="dev_backup_${timestamp}.db" + + # 复制数据库文件 + cp ./db/dev.db "/tmp/${backup_file}" + + echo "正在上传备份到 HuggingFace..." + python3 /tmp/hf_sync.py upload "${HF_TOKEN}" "${DATASET_ID}" "/tmp/${backup_file}" "${backup_file}" + + rm -f "/tmp/${backup_file}" + else + echo "数据库文件不存在,等待下次同步..." + fi + + SYNC_INTERVAL=${SYNC_INTERVAL:-7200} + echo "下次同步将在 ${SYNC_INTERVAL} 秒后进行..." + sleep $SYNC_INTERVAL + done +} + +# 后台启动同步进程 +sync_data & \ No newline at end of file
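For reference, a minimal manual round trip with the backup helpers above (a sketch only, assuming HF_TOKEN and DATASET_ID are exported and /tmp/hf_sync.py has already been written out by sync_data.sh):

# restore ./db/dev.db from the newest dev_backup_*.db in the dataset
python3 /tmp/hf_sync.py download "${HF_TOKEN}" "${DATASET_ID}"
# push a one-off backup without waiting for the SYNC_INTERVAL loop; the temp filename is illustrative
cp ./db/dev.db /tmp/dev_backup_manual.db
python3 /tmp/hf_sync.py upload "${HF_TOKEN}" "${DATASET_ID}" /tmp/dev_backup_manual.db "dev_backup_$(date +%Y%m%d_%H%M%S).db"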