# Spaces status: Build error
import hashlib
import io
import logging
import os
from enum import Enum
from typing import Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.staticfiles import StaticFiles
from paddleocr import PaddleOCR
from PIL import Image
# Directory that structure-recognition results are saved into; it is
# created up front because it is also mounted as static files below.
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)

# The interactive API documentation is served at the site root.
app = FastAPI(docs_url='/')
class LangEnum(str, Enum):
    """Language codes accepted by the OCR and table endpoints.

    Members subclass ``str``, so each one compares equal to its plain code
    (``LangEnum.ch == "ch"``). Declaration order is the order in which
    ``health_check`` lists supported languages.
    """

    ch = "ch"                    # Chinese
    en = "en"                    # English
    japan = "japan"              # Japanese
    korean = "korean"            # Korean
    chinese_cht = "chinese_cht"  # Traditional Chinese
    fr = "fr"                    # French
    de = "de"                    # German
# Process-wide cache of OCR / structure engines, keyed by model family,
# language and device, so each engine is constructed only once.
ocr_cache = {}


def get_ocr_instance(lang: str = "ch", use_gpu: bool = False):
    """Return a cached PaddleOCR pipeline configured for PP-OCRv5.

    The pipeline is built on first use for each ``(lang, use_gpu)`` pair and
    stored in ``ocr_cache``; later calls return the same object.

    Args:
        lang: Language code passed to ``PaddleOCR(lang=...)``. A ``LangEnum``
            member is accepted and normalised to its plain string value.
        use_gpu: Build the pipeline with ``device="gpu"`` instead of
            ``device="cpu"``.

    Returns:
        The cached ``PaddleOCR`` instance.
    """
    # LangEnum members are (str, Enum). On newer Python versions, f-string
    # formatting renders such a member as "LangEnum.ch" rather than "ch".
    # Using the plain value means enum and str callers share one cache entry,
    # and PaddleOCR receives an ordinary language code.
    lang = getattr(lang, "value", lang)
    cache_key = f"v5_{lang}_{use_gpu}"
    engine = ocr_cache.get(cache_key)
    if engine is None:
        # NOTE(review): the explicit *_model_name arguments presumably take
        # precedence over ``lang``. Confirm that the PP-OCRv5 server models
        # cover every LangEnum value (e.g. korean, fr, de) in the installed
        # PaddleOCR version.
        engine = PaddleOCR(
            ocr_version="PP-OCRv5",
            lang=lang,
            text_detection_model_name="PP-OCRv5_server_det",
            text_recognition_model_name="PP-OCRv5_server_rec",
            # Document-orientation classification, document unwarping and
            # text-line orientation classification are all disabled.
            use_doc_orientation_classify=False,
            use_doc_unwarping=False,
            use_textline_orientation=False,
            device="gpu" if use_gpu else "cpu",
        )
        ocr_cache[cache_key] = engine
    return engine
def validate_image(file: UploadFile):
    """Reject uploads that are not images or that report a size over 10MB.

    Args:
        file: The uploaded file to check.

    Raises:
        HTTPException: status 400 when the content type is missing or not
            ``image/*``, or when ``file.size`` is set and exceeds 10MB.
    """
    content_type = file.content_type
    looks_like_image = bool(content_type) and content_type.startswith('image/')
    if not looks_like_image:
        raise HTTPException(status_code=400, detail="文件必须是图片格式")

    # The size is only enforced when the upload object actually carries one.
    reported_size = getattr(file, 'size', None)
    size_limit = 10 * 1024 * 1024
    if reported_size and reported_size > size_limit:
        raise HTTPException(status_code=400, detail="图片文件大小不能超过10MB")
def _extract_ocr_lines(result):
    """Flatten one PP-OCRv5 prediction into ``[{id, text, confidence, bbox}]``.

    The recognised fields are looked up in three places, in order:
    1. the prediction's ``json`` attribute;
    2. a ``"res"`` dict nested inside ``json``;
    3. the prediction itself, when it is a dict.
    The first mapping that contains ``rec_texts`` is used.
    """
    candidates = []
    json_data = getattr(result, "json", None)
    if isinstance(json_data, dict):
        candidates.append(json_data)
        # NOTE(review): PaddleOCR 3.x appears to wrap the fields as
        # {"res": {...}} in ``.json``; unwrap that level when it is present.
        nested = json_data.get("res")
        if isinstance(nested, dict):
            candidates.append(nested)
    if isinstance(result, dict):
        candidates.append(result)
    data = next((c for c in candidates if "rec_texts" in c), {})

    texts = data.get("rec_texts")
    scores = data.get("rec_scores")
    # Prefer rec_polys, which presumably holds only the boxes whose text
    # survived recognition filtering and so lines up index-for-index with
    # rec_texts. Fall back to dt_polys for outputs that lack it.
    polys = data.get("rec_polys")
    if polys is None:
        polys = data.get("dt_polys")
    # Use explicit None checks: these may be numpy arrays, whose truth value
    # is ambiguous.
    if texts is None or scores is None or polys is None:
        return []
    return [
        {
            "id": idx,
            "text": text,
            "confidence": round(float(score), 4),
            "bbox": poly.tolist() if hasattr(poly, "tolist") else poly,
        }
        for idx, (text, score, poly) in enumerate(zip(texts, scores, polys))
    ]


async def ocr_recognition(
    file: UploadFile = File(...),
    lang: LangEnum = LangEnum.ch,
    use_gpu: bool = False,
):
    """Run PP-OCRv5 text recognition on an uploaded image.

    NOTE(review): no route decorator is visible on this handler; confirm how
    it is registered with ``app``.

    Args:
        file: Uploaded image, checked by ``validate_image``.
        lang: Recognition language.
        use_gpu: Run the pipeline on GPU instead of CPU.

    Returns:
        A dict with ``success``, ``model_version``, ``language``, ``count``
        and ``results`` (one entry per recognised line). A ``message`` key is
        added when the pipeline returns no prediction at all.

    Raises:
        HTTPException: 400 for a non-image, oversized, empty or undecodable
            upload; 500 for any other failure during recognition.
    """
    try:
        validate_image(file)
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="文件内容为空")
        # validate_image only enforces the limit when file.size is populated,
        # so apply the same 10MB limit to the bytes actually received.
        if len(contents) > 10 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="图片文件大小不能超过10MB")

        try:
            image = Image.open(io.BytesIO(contents))
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except OSError as err:
            # Pillow reports unreadable or corrupt data with OSError subclasses
            # (e.g. UnidentifiedImageError). That is a client error, not a 500.
            raise HTTPException(status_code=400, detail=f"无法解析图片文件: {err}") from err

        ocr = get_ocr_instance(lang=lang, use_gpu=use_gpu)
        # NOTE(review): np.array() of a PIL RGB image is RGB-ordered; confirm
        # whether the pipeline expects BGR for ndarray input.
        results = ocr.predict(np.array(image))
        if not results:
            return {
                "success": True,
                "message": "未检测到文字",
                "model_version": "PP-OCRv5",
                "language": lang,
                "count": 0,
                "results": []
            }

        # Only the first prediction is used (a single image was submitted).
        ocr_results = _extract_ocr_lines(results[0])
        return {
            "success": True,
            "model_version": "PP-OCRv5",
            "language": lang,
            "count": len(ocr_results),
            "results": ocr_results
        }
    except HTTPException:
        # Deliberate 4xx errors raised above must reach the client unchanged
        # instead of being re-wrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR识别失败: {str(e)}") from e
def _structure_region_text(res):
    """Return the text of a layout region that is neither a table nor a figure.

    Handling depends on the type of ``res``:
    - dict: its ``"text"`` value is returned.
    - list of dicts (presumably one per recognised line; TODO confirm against
      the installed PaddleOCR): the ``"text"`` values are joined with newlines.
    - anything else: falls back to ``str(res)``.
    """
    if isinstance(res, dict):
        return res.get('text', '')
    if isinstance(res, list) and all(isinstance(line, dict) for line in res):
        return "\n".join(str(line.get('text', '')) for line in res)
    return str(res)


def _group_structure_regions(regions):
    """Split structure-engine regions into ``(tables, images, texts)`` payload lists.

    Regions are classified by their ``type`` field: ``'table'`` goes to
    tables, ``'figure'`` goes to images, and every other type goes to texts.
    """
    tables, images, texts = [], [], []
    for item in regions:
        item_type = item.get('type', '')
        bbox = item.get('bbox', [])
        # Keep the response JSON-serialisable if a bbox arrives as an ndarray.
        if hasattr(bbox, 'tolist'):
            bbox = bbox.tolist()
        res = item.get('res', {})
        if item_type == 'table':
            table_res = res if isinstance(res, dict) else {}
            tables.append({
                "type": item_type,
                "bbox": bbox,
                "html": table_res.get('html', ''),
                "confidence": table_res.get('confidence', 0.0)
            })
        elif item_type == 'figure':
            images.append({
                "type": item_type,
                "bbox": bbox
            })
        else:
            texts.append({
                "type": item_type,
                "bbox": bbox,
                "text": _structure_region_text(res)
            })
    return tables, images, texts


async def table_recognition(
    file: UploadFile = File(...),
    lang: LangEnum = LangEnum.ch,
    use_gpu: bool = False,
):
    """Run table / layout recognition on an uploaded image.

    NOTE(review): no route decorator is visible on this handler; confirm how
    it is registered with ``app``.

    Args:
        file: Uploaded image, checked by ``validate_image``.
        lang: Language passed to the structure engine.
        use_gpu: Request the GPU device instead of CPU.

    Returns:
        A dict with ``success``, ``model_version``, ``language``, ``hash``
        (first 12 hex chars of the upload's SHA-256), a ``summary`` of region
        counts, and the ``tables``, ``images`` and ``texts`` region lists.

    Raises:
        HTTPException: 400 for a non-image, oversized, empty or undecodable
            upload; 500 for any other failure during recognition.
    """
    try:
        validate_image(file)
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="文件内容为空")
        # validate_image only enforces the limit when file.size is populated.
        if len(contents) > 10 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="图片文件大小不能超过10MB")
        # Short content hash: used as the save name for result files and
        # echoed back to the client.
        file_hash = hashlib.sha256(contents).hexdigest()[:12]

        try:
            image = Image.open(io.BytesIO(contents))
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except OSError as err:
            # Pillow reports unreadable or corrupt data with OSError subclasses.
            raise HTTPException(status_code=400, detail=f"无法解析图片文件: {err}") from err

        # NOTE(review): PPStructure, save_structure_res and show_log are
        # PaddleOCR 2.x APIs, while get_ocr_instance uses PaddleOCR 3.x style
        # arguments (ocr_version=..., device=..., .predict). If 3.x is
        # installed, this import presumably fails and every call returns 500.
        # The 3.x counterpart appears to be PPStructureV3; confirm and migrate.
        from paddleocr import PPStructure

        # Normalise LangEnum to its plain code for the cache key and engine.
        lang_code = getattr(lang, 'value', lang)
        table_key = f"table_v3_{lang_code}_{use_gpu}"
        table_engine = ocr_cache.get(table_key)
        if table_engine is None:
            table_engine = PPStructure(
                table=True,
                lang=lang_code,
                device="gpu" if use_gpu else "cpu",
                show_log=True
            )
            ocr_cache[table_key] = table_engine

        result = table_engine(np.array(image))

        # Saving result files is best-effort: a failure is logged and the
        # recognition response is still returned.
        try:
            from paddleocr import save_structure_res
            save_structure_res(result, output_dir, file_hash)
        except Exception as save_error:
            logging.getLogger(__name__).warning("保存结果文件失败: %s", save_error)

        tables, images, texts = _group_structure_regions(result)
        return {
            "success": True,
            "model_version": "PP-StructureV3",
            "language": lang,
            "hash": file_hash,
            "summary": {
                "total_elements": len(result),
                "tables": len(tables),
                "images": len(images),
                "texts": len(texts)
            },
            "tables": tables,
            "images": images,
            "texts": texts
        }
    except HTTPException:
        # Keep deliberate 4xx errors intact instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"表格识别失败: {str(e)}") from e
async def health_check():
    """Report service liveness.

    The response includes the advertised model versions, the number of
    cached engine instances, and the supported language codes.
    """
    language_codes = [member.value for member in LangEnum]
    payload = dict(
        status="healthy",
        ocr_version="PP-OCRv5",
        structure_version="PP-StructureV3",
        cache_instances=len(ocr_cache),
        supported_languages=language_codes,
    )
    return payload
async def get_model_info():
    """Describe the OCR / structure models and headline features of the service.

    Returns:
        A static dict with ``ocr_models``, ``structure_models`` and
        ``features`` sections, each mapping a model or feature name to a
        human-readable description.
    """
    return {
        "ocr_models": {
            "PP-OCRv5_server_det": "高精度文本检测模型",
            # Per the PaddleOCR PP-OCRv5 documentation, the single recognition
            # model covers Simplified Chinese, Traditional Chinese, Pinyin,
            # English and Japanese. The previous text listed Korean instead of
            # Pinyin.
            "PP-OCRv5_server_rec": "高精度文本识别模型 - 支持简体中文、繁体中文、中文拼音、英文、日文5种文字类型"
        },
        "structure_models": {
            "PP-StructureV3": "通用文档解析方案 - 支持表格、图像、文本混合识别"
        },
        "features": {
            "multi_language": "单模型支持5种文字类型",
            "handwriting": "显著提升手写体识别能力",
            "accuracy_improvement": "相比PP-OCRv4提升13个百分点"
        }
    }
async def root():
    """Return a short service banner with version info and the docs location."""
    return {
        "message": "PP-OCRv5 OCR API 服务正常运行",
        "version": "3.0",
        "models": "PP-OCRv5 + PP-StructureV3",
        # The app is created with docs_url='/', so the previously hard-coded
        # "/docs" pointed at a path not registered in this file. Report the
        # configured docs URL instead.
        "docs": app.docs_url
    }
# Serve everything under output_dir (where structure results are saved) as
# static files at /output.
_output_files = StaticFiles(directory=output_dir, html=True, follow_symlink=True)
app.mount("/output", _output_files, name="output")


def _run_server():
    """Start uvicorn for ``app`` on all interfaces, port 7860."""
    uvicorn.run(app=app, host="0.0.0.0", port=7860)


if __name__ == '__main__':
    _run_server()