Spaces:
Running
Running
Commit
·
550ec39
1
Parent(s):
57d51b7
Add PDF conversion API endpoints
- Add PDF upload endpoint at /api/convert
- Add status checking endpoint at /api/status/{task_id}
- Add download endpoint at /api/download/{task_id}
- Add basic PDF processing dependency (PyMuPDF)
- Prepare structure for MinerU integration
- Update API info with new endpoints
This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +118 -6
- config/magic-pdf.json +9 -0
- pdf_converter_mineru.py +272 -0
- requirements.txt +4 -1
- vendor/mineru/mineru/__init__.py +1 -0
- vendor/mineru/mineru/backend/__init__.py +1 -0
- vendor/mineru/mineru/backend/pipeline/__init__.py +1 -0
- vendor/mineru/mineru/backend/pipeline/batch_analyze.py +331 -0
- vendor/mineru/mineru/backend/pipeline/model_init.py +182 -0
- vendor/mineru/mineru/backend/pipeline/model_json_to_middle_json.py +249 -0
- vendor/mineru/mineru/backend/pipeline/model_list.py +6 -0
- vendor/mineru/mineru/backend/pipeline/para_split.py +381 -0
- vendor/mineru/mineru/backend/pipeline/pipeline_analyze.py +198 -0
- vendor/mineru/mineru/backend/pipeline/pipeline_magic_model.py +501 -0
- vendor/mineru/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +298 -0
- vendor/mineru/mineru/backend/vlm/__init__.py +1 -0
- vendor/mineru/mineru/backend/vlm/base_predictor.py +186 -0
- vendor/mineru/mineru/backend/vlm/hf_predictor.py +211 -0
- vendor/mineru/mineru/backend/vlm/predictor.py +111 -0
- vendor/mineru/mineru/backend/vlm/sglang_client_predictor.py +443 -0
- vendor/mineru/mineru/backend/vlm/sglang_engine_predictor.py +246 -0
- vendor/mineru/mineru/backend/vlm/token_to_middle_json.py +113 -0
- vendor/mineru/mineru/backend/vlm/utils.py +40 -0
- vendor/mineru/mineru/backend/vlm/vlm_analyze.py +93 -0
- vendor/mineru/mineru/backend/vlm/vlm_magic_model.py +521 -0
- vendor/mineru/mineru/backend/vlm/vlm_middle_json_mkcontent.py +221 -0
- vendor/mineru/mineru/cli/__init__.py +1 -0
- vendor/mineru/mineru/cli/client.py +212 -0
- vendor/mineru/mineru/cli/common.py +403 -0
- vendor/mineru/mineru/cli/fast_api.py +198 -0
- vendor/mineru/mineru/cli/gradio_app.py +343 -0
- vendor/mineru/mineru/cli/models_download.py +150 -0
- vendor/mineru/mineru/cli/vlm_sglang_server.py +4 -0
- vendor/mineru/mineru/data/__init__.py +1 -0
- vendor/mineru/mineru/data/data_reader_writer/__init__.py +17 -0
- vendor/mineru/mineru/data/data_reader_writer/base.py +63 -0
- vendor/mineru/mineru/data/data_reader_writer/dummy.py +11 -0
- vendor/mineru/mineru/data/data_reader_writer/filebase.py +62 -0
- vendor/mineru/mineru/data/data_reader_writer/multi_bucket_s3.py +144 -0
- vendor/mineru/mineru/data/data_reader_writer/s3.py +72 -0
- vendor/mineru/mineru/data/io/__init__.py +6 -0
- vendor/mineru/mineru/data/io/base.py +42 -0
- vendor/mineru/mineru/data/io/http.py +37 -0
- vendor/mineru/mineru/data/io/s3.py +114 -0
- vendor/mineru/mineru/data/utils/__init__.py +1 -0
- vendor/mineru/mineru/data/utils/exceptions.py +40 -0
- vendor/mineru/mineru/data/utils/path_utils.py +33 -0
- vendor/mineru/mineru/data/utils/schemas.py +20 -0
- vendor/mineru/mineru/model/__init__.py +1 -0
- vendor/mineru/mineru/model/layout/__init__.py +1 -0
app.py
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
-
from fastapi import FastAPI
|
2 |
-
from fastapi.responses import HTMLResponse
|
3 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
app = FastAPI()
|
6 |
|
7 |
@app.get("/")
|
8 |
async def root():
|
@@ -59,12 +65,118 @@ async def api_info():
|
|
59 |
"""API information endpoint"""
|
60 |
return {
|
61 |
"name": "PDF to Markdown Converter API",
|
62 |
-
"version": "0.
|
63 |
"endpoints": {
|
64 |
"/": "Main endpoint",
|
65 |
"/health": "Health check",
|
66 |
"/test": "Test HTML page",
|
67 |
"/docs": "FastAPI automatic documentation",
|
68 |
-
"/api/info": "This endpoint"
|
|
|
|
|
|
|
69 |
}
|
70 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
|
2 |
+
from fastapi.responses import HTMLResponse, FileResponse
|
3 |
import os
|
4 |
+
import tempfile
|
5 |
+
import shutil
|
6 |
+
from pathlib import Path
|
7 |
+
import asyncio
|
8 |
+
from typing import Dict, Optional
|
9 |
+
import uuid
|
10 |
|
11 |
+
app = FastAPI(title="MinerU PDF Converter", version="0.2.0")
|
12 |
|
13 |
@app.get("/")
|
14 |
async def root():
|
|
|
65 |
"""API information endpoint"""
|
66 |
return {
|
67 |
"name": "PDF to Markdown Converter API",
|
68 |
+
"version": "0.2.0",
|
69 |
"endpoints": {
|
70 |
"/": "Main endpoint",
|
71 |
"/health": "Health check",
|
72 |
"/test": "Test HTML page",
|
73 |
"/docs": "FastAPI automatic documentation",
|
74 |
+
"/api/info": "This endpoint",
|
75 |
+
"/api/convert": "Convert PDF to Markdown (POST)",
|
76 |
+
"/api/status/{task_id}": "Check conversion status",
|
77 |
+
"/api/download/{task_id}": "Download converted markdown"
|
78 |
}
|
79 |
+
}
|
80 |
+
|
81 |
+
# In-memory store of conversion tasks, keyed by task id.
# NOTE(review): unbounded — entries are never evicted, so a long-running
# process will grow this dict; consider TTL-based cleanup.
conversion_tasks: Dict[str, dict] = {}

@app.post("/api/convert")
async def convert_pdf(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...)
):
    """Accept a PDF upload and start a background conversion to Markdown.

    Returns a task id the client can poll via /api/status/{task_id}.
    Raises 400 for non-PDF uploads and 500 if the upload cannot be saved.
    """
    # Fix: the original check was case-sensitive ("FILE.PDF" rejected) and
    # raised AttributeError when no filename was supplied (filename is None).
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    # Generate unique task ID
    task_id = str(uuid.uuid4())

    # Persist the upload into a private temp dir. Use only the basename so a
    # crafted filename (e.g. "../../x.pdf") cannot escape temp_dir.
    temp_dir = Path(tempfile.mkdtemp())
    pdf_path = temp_dir / Path(file.filename).name

    try:
        with open(pdf_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        # Best-effort cleanup of the temp dir before reporting the failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")

    # Initialize task status
    conversion_tasks[task_id] = {
        "status": "processing",
        "filename": file.filename,
        "result": None,
        "error": None,
        "temp_dir": str(temp_dir)
    }

    # Start conversion in background
    background_tasks.add_task(process_pdf_conversion, task_id, str(pdf_path))

    return {
        "task_id": task_id,
        "status": "processing",
        "message": "PDF conversion started",
        "check_status_url": f"/api/status/{task_id}"
    }
async def process_pdf_conversion(task_id: str, pdf_path: str):
    """Background worker: convert *pdf_path* and record the outcome.

    Currently a placeholder — it sleeps briefly and writes a stub .md file
    next to the uploaded PDF. On success it sets the task's status to
    "completed" with the output path; on any exception it sets "failed"
    with the error message. Updates the module-level conversion_tasks dict.
    """
    try:
        # For now, just simulate conversion
        await asyncio.sleep(2)  # Simulate processing

        # Create a dummy markdown file next to the uploaded PDF.
        output_path = Path(pdf_path).with_suffix('.md')
        # Fix: write with an explicit encoding — the platform default codec
        # may not be able to encode a non-ASCII PDF filename in the heading.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(f"# Converted from {Path(pdf_path).name}\n\n")
            f.write("This is a placeholder conversion. Full MinerU integration coming soon.\n")

        conversion_tasks[task_id]["status"] = "completed"
        conversion_tasks[task_id]["result"] = str(output_path)

    except Exception as e:
        conversion_tasks[task_id]["status"] = "failed"
        conversion_tasks[task_id]["error"] = str(e)
@app.get("/api/status/{task_id}")
|
146 |
+
async def get_conversion_status(task_id: str):
|
147 |
+
"""Check conversion status"""
|
148 |
+
if task_id not in conversion_tasks:
|
149 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
150 |
+
|
151 |
+
task = conversion_tasks[task_id]
|
152 |
+
response = {
|
153 |
+
"task_id": task_id,
|
154 |
+
"status": task["status"],
|
155 |
+
"filename": task["filename"]
|
156 |
+
}
|
157 |
+
|
158 |
+
if task["status"] == "completed":
|
159 |
+
response["download_url"] = f"/api/download/{task_id}"
|
160 |
+
elif task["status"] == "failed":
|
161 |
+
response["error"] = task["error"]
|
162 |
+
|
163 |
+
return response
|
164 |
+
|
165 |
+
@app.get("/api/download/{task_id}")
|
166 |
+
async def download_converted_file(task_id: str):
|
167 |
+
"""Download converted markdown file"""
|
168 |
+
if task_id not in conversion_tasks:
|
169 |
+
raise HTTPException(status_code=404, detail="Task not found")
|
170 |
+
|
171 |
+
task = conversion_tasks[task_id]
|
172 |
+
if task["status"] != "completed":
|
173 |
+
raise HTTPException(status_code=400, detail="Conversion not completed")
|
174 |
+
|
175 |
+
if not task["result"] or not Path(task["result"]).exists():
|
176 |
+
raise HTTPException(status_code=404, detail="Converted file not found")
|
177 |
+
|
178 |
+
return FileResponse(
|
179 |
+
task["result"],
|
180 |
+
media_type="text/markdown",
|
181 |
+
filename=Path(task["result"]).name
|
182 |
+
)
|
config/magic-pdf.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bucket_info":{
|
3 |
+
"bucket-name-1":["ak", "sk", "endpoint"],
|
4 |
+
"bucket-name-2":["ak", "sk", "endpoint"]
|
5 |
+
},
|
6 |
+
"temp-output-dir":"/tmp",
|
7 |
+
"models-dir":"/tmp/models",
|
8 |
+
"device-mode":"cpu"
|
9 |
+
}
|
pdf_converter_mineru.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
PDF to Markdown Converter using MinerU (vendor/mineru)
|
4 |
+
This is the main conversion script that uses the local MinerU installation
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import logging
|
10 |
+
import argparse
|
11 |
+
from pathlib import Path
|
12 |
+
import subprocess
|
13 |
+
|
14 |
+
# Configure logging
|
15 |
+
logging.basicConfig(
|
16 |
+
level=logging.INFO,
|
17 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
18 |
+
handlers=[
|
19 |
+
logging.StreamHandler(),
|
20 |
+
logging.FileHandler('pdf_converter.log')
|
21 |
+
]
|
22 |
+
)
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class PdfConverterResult:
    """Outcome of a single PDF → Markdown conversion attempt."""

    def __init__(self, pdf_path: str, success: bool, md_path: str = None,
                 time_taken: float = 0, error: str = None):
        # Path of the source PDF (as passed by the caller).
        self.pdf_path = pdf_path
        # True when a markdown file was produced.
        self.success = success
        # Path to the generated .md file, or None on failure.
        self.md_path = md_path
        # Wall-clock seconds spent converting (0 when not measured).
        self.time_taken = time_taken
        # Human-readable error message, or None on success.
        self.error = error

    def __repr__(self):
        # Debug-friendly form; __str__ below remains the user-facing summary.
        return (f"{type(self).__name__}(pdf_path={self.pdf_path!r}, "
                f"success={self.success!r}, md_path={self.md_path!r}, "
                f"time_taken={self.time_taken!r}, error={self.error!r})")

    def __str__(self):
        if self.success:
            return f"✅ Successfully converted {self.pdf_path} in {self.time_taken:.2f}s"
        else:
            return f"❌ Failed to convert {self.pdf_path}: {self.error}"
class MineruPdfConverter:
    """Wraps the local `mineru` CLI to turn a PDF into a Markdown file."""

    def __init__(self, output_dir: str = "output"):
        # All conversion results land under this directory (created eagerly).
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def convert_file(self, pdf_path: str, delete_after: bool = False) -> PdfConverterResult:
        """Convert one PDF via the MinerU CLI.

        Returns a PdfConverterResult; never raises — any exception is
        captured into the result's error field.
        """
        import time
        started = time.time()

        try:
            source = Path(pdf_path)
            if not source.exists():
                return PdfConverterResult(
                    str(source), False, error=f"File not found: {source}"
                )

            logger.info(f"Processing: {source}")

            # Each PDF gets its own subdirectory named after the file stem.
            target_dir = os.path.join(self.output_dir, source.stem)

            # Text mode with formula/table parsing disabled for speed.
            # NOTE(review): no timeout is set, so a hung mineru process
            # blocks this call indefinitely — confirm whether that is acceptable.
            cmd = [
                "mineru",
                "-p", str(source),
                "-o", target_dir,
                "-m", "txt",
                "-f", "false",
                "-t", "false",
            ]

            logger.info(f"Running command: {' '.join(cmd)}")

            proc = subprocess.run(cmd, capture_output=True, text=True)

            if proc.returncode != 0:
                return PdfConverterResult(
                    str(source), False,
                    error=proc.stderr if proc.stderr else "Unknown error"
                )

            # Locate the produced markdown: try the expected layout first,
            # then fall back to the first *.md anywhere under target_dir.
            markdown = None
            expected = Path(target_dir) / source.stem / "txt" / f"{source.stem}.md"
            if expected.exists():
                markdown = str(expected)
                logger.info(f"✅ Markdown file created: {markdown}")
            else:
                for candidate in Path(target_dir).rglob("*.md"):
                    markdown = str(candidate)
                    logger.info(f"✅ Found markdown file: {markdown}")
                    break

            if not markdown:
                return PdfConverterResult(
                    str(source), False, error="No markdown file generated"
                )

            # Optionally remove the source PDF once conversion succeeded.
            if delete_after and source.exists():
                source.unlink()
                logger.info(f"🗑️ Deleted original PDF: {source}")

            return PdfConverterResult(
                str(source), True, md_path=markdown,
                time_taken=time.time() - started
            )

        except Exception as e:
            logger.error(f"Error processing {pdf_path}: {e}")
            import traceback
            traceback.print_exc()

            return PdfConverterResult(
                str(pdf_path), False, error=str(e)
            )
class BatchProcessor:
    """Convert every PDF found under a directory, one file at a time."""

    def __init__(self, batch_dir: str = "batch-files", output_dir: str = "output",
                 workers: int = 1, delete_after: bool = False):
        self.batch_dir = batch_dir
        self.output_dir = output_dir
        # NOTE(review): `workers` is stored but process_batch() below runs
        # sequentially; MinerU handles parallelism internally.
        self.workers = workers
        self.delete_after = delete_after
        self.converter = MineruPdfConverter(output_dir)

    def find_pdf_files(self) -> list[Path]:
        """Recursively collect *.pdf paths under batch_dir ([] if it is missing)."""
        root = Path(self.batch_dir)

        if not root.exists():
            logger.warning(f"Batch directory not found: {self.batch_dir}")
            return []

        # Find all PDFs recursively
        found = list(root.rglob("*.pdf"))
        logger.info(f"Found {len(found)} PDF files in {self.batch_dir}")

        return found

    def process_batch(self) -> tuple[int, int]:
        """Convert each discovered PDF; return (successful, failed) counts."""
        pdfs = self.find_pdf_files()

        if not pdfs:
            logger.info("No PDF files found to process")
            return 0, 0

        ok = 0
        bad = 0

        logger.info(f"Starting batch processing of {len(pdfs)} files...")

        # Sequential on purpose (MinerU already handles parallelism internally).
        for pdf in pdfs:
            outcome = self.converter.convert_file(str(pdf), self.delete_after)

            if outcome.success:
                ok += 1
                logger.info(f"✅ {outcome}")
            else:
                bad += 1
                logger.error(f"❌ {outcome}")

        return ok, bad
def main():
    """CLI entry point: `convert <file.pdf>` or `batch [options]`.

    For convenience a bare PDF path (no subcommand) is treated as
    `convert <path>`, and no arguments at all defaults to `batch`.
    """
    parser = argparse.ArgumentParser(
        description="Convert PDF files to Markdown using MinerU",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert a single PDF
  %(prog)s convert path/to/file.pdf

  # Batch convert all PDFs in batch-files directory
  %(prog)s batch

  # Batch convert with custom settings
  %(prog)s batch --batch-dir /path/to/pdfs --output-dir /path/to/output --workers 4

  # Delete PDFs after successful conversion
  %(prog)s batch --delete-after
"""
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # Convert command
    convert_parser = subparsers.add_parser('convert', help='Convert a single PDF file')
    convert_parser.add_argument('pdf_file', help='Path to PDF file')
    convert_parser.add_argument('--output-dir', default='output', help='Output directory')
    convert_parser.add_argument('--delete-after', action='store_true',
                                help='Delete PDF after successful conversion')

    # Batch command
    batch_parser = subparsers.add_parser('batch', help='Batch convert PDF files')
    batch_parser.add_argument('--batch-dir', default='batch-files',
                              help='Directory containing PDF files')
    batch_parser.add_argument('--output-dir', default='output',
                              help='Output directory')
    batch_parser.add_argument('--workers', type=int, default=1,
                              help='Number of parallel workers')
    batch_parser.add_argument('--delete-after', action='store_true',
                              help='Delete PDFs after successful conversion')

    # Fix: the original auto-detect ran *after* parse_args(), but argparse
    # rejects a bare PDF path as an invalid subcommand before that code could
    # ever run (dead code). Inject the implied 'convert' subcommand up front.
    argv = sys.argv[1:]
    if argv and argv[0] not in ('convert', 'batch', '-h', '--help') and \
            (argv[0].endswith('.pdf') or Path(argv[0]).exists()):
        argv = ['convert'] + argv

    args = parser.parse_args(argv)

    # No arguments at all: default to batch mode with standard settings.
    if not args.command:
        args.command = 'batch'
        args.batch_dir = 'batch-files'
        args.output_dir = 'output'
        args.workers = 1
        args.delete_after = False

    # Execute command
    if args.command == 'convert':
        converter = MineruPdfConverter(args.output_dir)
        result = converter.convert_file(args.pdf_file, args.delete_after)
        print(result)
        sys.exit(0 if result.success else 1)

    elif args.command == 'batch':
        processor = BatchProcessor(
            args.batch_dir,
            args.output_dir,
            args.workers,
            args.delete_after
        )
        successful, failed = processor.process_batch()

        print("\n📊 Batch processing complete:")
        print(f"   ✅ Successful: {successful}")
        print(f"   ❌ Failed: {failed}")
        print(f"   📁 Output directory: {args.output_dir}")

        sys.exit(0 if failed == 0 else 1)

    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
fastapi==0.104.1
|
2 |
uvicorn==0.24.0
|
3 |
python-multipart==0.0.6
|
4 |
-
aiofiles==23.2.1
|
|
|
|
|
|
|
|
1 |
fastapi==0.104.1
|
2 |
uvicorn==0.24.0
|
3 |
python-multipart==0.0.6
|
4 |
+
aiofiles==23.2.1
|
5 |
+
|
6 |
+
# Basic PDF processing (will add MinerU later)
|
7 |
+
PyMuPDF>=1.18.16
|
vendor/mineru/mineru/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/backend/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/backend/pipeline/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/backend/pipeline/batch_analyze.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
from loguru import logger
|
3 |
+
from tqdm import tqdm
|
4 |
+
from collections import defaultdict
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from .model_init import AtomModelSingleton
|
8 |
+
from ...utils.config_reader import get_formula_enable, get_table_enable
|
9 |
+
from ...utils.model_utils import crop_img, get_res_list_from_layout_res
|
10 |
+
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
|
11 |
+
|
12 |
+
YOLO_LAYOUT_BASE_BATCH_SIZE = 8
|
13 |
+
MFD_BASE_BATCH_SIZE = 1
|
14 |
+
MFR_BASE_BATCH_SIZE = 16
|
15 |
+
|
16 |
+
|
17 |
+
class BatchAnalyze:
|
18 |
+
def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
|
19 |
+
self.batch_ratio = batch_ratio
|
20 |
+
self.formula_enable = get_formula_enable(formula_enable)
|
21 |
+
self.table_enable = get_table_enable(table_enable)
|
22 |
+
self.model_manager = model_manager
|
23 |
+
self.enable_ocr_det_batch = enable_ocr_det_batch
|
24 |
+
|
25 |
+
def __call__(self, images_with_extra_info: list) -> list:
|
26 |
+
if len(images_with_extra_info) == 0:
|
27 |
+
return []
|
28 |
+
|
29 |
+
images_layout_res = []
|
30 |
+
|
31 |
+
self.model = self.model_manager.get_model(
|
32 |
+
lang=None,
|
33 |
+
formula_enable=self.formula_enable,
|
34 |
+
table_enable=self.table_enable,
|
35 |
+
)
|
36 |
+
atom_model_manager = AtomModelSingleton()
|
37 |
+
|
38 |
+
images = [image for image, _, _ in images_with_extra_info]
|
39 |
+
|
40 |
+
# doclayout_yolo
|
41 |
+
layout_images = []
|
42 |
+
for image_index, image in enumerate(images):
|
43 |
+
layout_images.append(image)
|
44 |
+
|
45 |
+
|
46 |
+
images_layout_res += self.model.layout_model.batch_predict(
|
47 |
+
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
|
48 |
+
)
|
49 |
+
|
50 |
+
if self.formula_enable:
|
51 |
+
# 公式检测
|
52 |
+
images_mfd_res = self.model.mfd_model.batch_predict(
|
53 |
+
images, MFD_BASE_BATCH_SIZE
|
54 |
+
)
|
55 |
+
|
56 |
+
# 公式识别
|
57 |
+
images_formula_list = self.model.mfr_model.batch_predict(
|
58 |
+
images_mfd_res,
|
59 |
+
images,
|
60 |
+
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
|
61 |
+
)
|
62 |
+
mfr_count = 0
|
63 |
+
for image_index in range(len(images)):
|
64 |
+
images_layout_res[image_index] += images_formula_list[image_index]
|
65 |
+
mfr_count += len(images_formula_list[image_index])
|
66 |
+
|
67 |
+
# 清理显存
|
68 |
+
# clean_vram(self.model.device, vram_threshold=8)
|
69 |
+
|
70 |
+
ocr_res_list_all_page = []
|
71 |
+
table_res_list_all_page = []
|
72 |
+
for index in range(len(images)):
|
73 |
+
_, ocr_enable, _lang = images_with_extra_info[index]
|
74 |
+
layout_res = images_layout_res[index]
|
75 |
+
pil_img = images[index]
|
76 |
+
|
77 |
+
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
|
78 |
+
get_res_list_from_layout_res(layout_res)
|
79 |
+
)
|
80 |
+
|
81 |
+
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
|
82 |
+
'lang':_lang,
|
83 |
+
'ocr_enable':ocr_enable,
|
84 |
+
'pil_img':pil_img,
|
85 |
+
'single_page_mfdetrec_res':single_page_mfdetrec_res,
|
86 |
+
'layout_res':layout_res,
|
87 |
+
})
|
88 |
+
|
89 |
+
for table_res in table_res_list:
|
90 |
+
table_img, _ = crop_img(table_res, pil_img)
|
91 |
+
table_res_list_all_page.append({'table_res':table_res,
|
92 |
+
'lang':_lang,
|
93 |
+
'table_img':table_img,
|
94 |
+
})
|
95 |
+
|
96 |
+
# OCR检测处理
|
97 |
+
if self.enable_ocr_det_batch:
|
98 |
+
# 批处理模式 - 按语言和分辨率分组
|
99 |
+
# 收集所有需要OCR检测的裁剪图像
|
100 |
+
all_cropped_images_info = []
|
101 |
+
|
102 |
+
for ocr_res_list_dict in ocr_res_list_all_page:
|
103 |
+
_lang = ocr_res_list_dict['lang']
|
104 |
+
|
105 |
+
for res in ocr_res_list_dict['ocr_res_list']:
|
106 |
+
new_image, useful_list = crop_img(
|
107 |
+
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
|
108 |
+
)
|
109 |
+
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
110 |
+
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
111 |
+
)
|
112 |
+
|
113 |
+
# BGR转换
|
114 |
+
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
115 |
+
|
116 |
+
all_cropped_images_info.append((
|
117 |
+
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
|
118 |
+
))
|
119 |
+
|
120 |
+
# 按语言分组
|
121 |
+
lang_groups = defaultdict(list)
|
122 |
+
for crop_info in all_cropped_images_info:
|
123 |
+
lang = crop_info[5]
|
124 |
+
lang_groups[lang].append(crop_info)
|
125 |
+
|
126 |
+
# 对每种语言按分辨率分组并批处理
|
127 |
+
for lang, lang_crop_list in lang_groups.items():
|
128 |
+
if not lang_crop_list:
|
129 |
+
continue
|
130 |
+
|
131 |
+
# logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
|
132 |
+
|
133 |
+
# 获取OCR模型
|
134 |
+
ocr_model = atom_model_manager.get_atom_model(
|
135 |
+
atom_model_name='ocr',
|
136 |
+
det_db_box_thresh=0.3,
|
137 |
+
lang=lang
|
138 |
+
)
|
139 |
+
|
140 |
+
# 按分辨率分组并同时完成padding
|
141 |
+
resolution_groups = defaultdict(list)
|
142 |
+
for crop_info in lang_crop_list:
|
143 |
+
cropped_img = crop_info[0]
|
144 |
+
h, w = cropped_img.shape[:2]
|
145 |
+
# 使用更大的分组容差,减少分组数量
|
146 |
+
# 将尺寸标准化到32的倍数
|
147 |
+
normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
|
148 |
+
normalized_w = ((w + 32) // 32) * 32
|
149 |
+
group_key = (normalized_h, normalized_w)
|
150 |
+
resolution_groups[group_key].append(crop_info)
|
151 |
+
|
152 |
+
# 对每个分辨率组进行批处理
|
153 |
+
for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
|
154 |
+
|
155 |
+
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
|
156 |
+
max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
|
157 |
+
max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
|
158 |
+
target_h = ((max_h + 32 - 1) // 32) * 32
|
159 |
+
target_w = ((max_w + 32 - 1) // 32) * 32
|
160 |
+
|
161 |
+
# 对所有图像进行padding到统一尺寸
|
162 |
+
batch_images = []
|
163 |
+
for crop_info in group_crops:
|
164 |
+
img = crop_info[0]
|
165 |
+
h, w = img.shape[:2]
|
166 |
+
# 创建目标尺寸的白色背景
|
167 |
+
padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
|
168 |
+
# 将原图像粘贴到左上角
|
169 |
+
padded_img[:h, :w] = img
|
170 |
+
batch_images.append(padded_img)
|
171 |
+
|
172 |
+
# 批处理检测
|
173 |
+
batch_size = min(len(batch_images), self.batch_ratio * 16) # 增加批处理大小
|
174 |
+
# logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
|
175 |
+
batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
|
176 |
+
|
177 |
+
# 处理批处理结果
|
178 |
+
for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
|
179 |
+
new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
|
180 |
+
|
181 |
+
if dt_boxes is not None and len(dt_boxes) > 0:
|
182 |
+
# 直接应用原始OCR流程中的关键处理步骤
|
183 |
+
from mineru.utils.ocr_utils import (
|
184 |
+
merge_det_boxes, update_det_boxes, sorted_boxes
|
185 |
+
)
|
186 |
+
|
187 |
+
# 1. 排序检测框
|
188 |
+
if len(dt_boxes) > 0:
|
189 |
+
dt_boxes_sorted = sorted_boxes(dt_boxes)
|
190 |
+
else:
|
191 |
+
dt_boxes_sorted = []
|
192 |
+
|
193 |
+
# 2. 合并相邻检测框
|
194 |
+
if dt_boxes_sorted:
|
195 |
+
dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
|
196 |
+
else:
|
197 |
+
dt_boxes_merged = []
|
198 |
+
|
199 |
+
# 3. 根据公式位置更新检测框(关键步骤!)
|
200 |
+
if dt_boxes_merged and adjusted_mfdetrec_res:
|
201 |
+
dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
|
202 |
+
else:
|
203 |
+
dt_boxes_final = dt_boxes_merged
|
204 |
+
|
205 |
+
# 构造OCR结果格式
|
206 |
+
ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
|
207 |
+
|
208 |
+
if ocr_res:
|
209 |
+
ocr_result_list = get_ocr_result_list(
|
210 |
+
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
|
211 |
+
)
|
212 |
+
|
213 |
+
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
214 |
+
else:
|
215 |
+
# 原始单张处理模式
|
216 |
+
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
|
217 |
+
# Process each area that requires OCR processing
|
218 |
+
_lang = ocr_res_list_dict['lang']
|
219 |
+
# Get OCR results for this language's images
|
220 |
+
ocr_model = atom_model_manager.get_atom_model(
|
221 |
+
atom_model_name='ocr',
|
222 |
+
ocr_show_log=False,
|
223 |
+
det_db_box_thresh=0.3,
|
224 |
+
lang=_lang
|
225 |
+
)
|
226 |
+
for res in ocr_res_list_dict['ocr_res_list']:
|
227 |
+
new_image, useful_list = crop_img(
|
228 |
+
res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
|
229 |
+
)
|
230 |
+
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
|
231 |
+
ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
|
232 |
+
)
|
233 |
+
# OCR-det
|
234 |
+
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
|
235 |
+
ocr_res = ocr_model.ocr(
|
236 |
+
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
237 |
+
)[0]
|
238 |
+
|
239 |
+
# Integration results
|
240 |
+
if ocr_res:
|
241 |
+
ocr_result_list = get_ocr_result_list(
|
242 |
+
ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
|
243 |
+
)
|
244 |
+
|
245 |
+
ocr_res_list_dict['layout_res'].extend(ocr_result_list)
|
246 |
+
|
247 |
+
# 表格识别 table recognition
|
248 |
+
if self.table_enable:
|
249 |
+
for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
|
250 |
+
_lang = table_res_dict['lang']
|
251 |
+
table_model = atom_model_manager.get_atom_model(
|
252 |
+
atom_model_name='table',
|
253 |
+
lang=_lang,
|
254 |
+
)
|
255 |
+
html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
|
256 |
+
# 判断是否返回正常
|
257 |
+
if html_code:
|
258 |
+
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
|
259 |
+
if expected_ending:
|
260 |
+
table_res_dict['table_res']['html'] = html_code
|
261 |
+
else:
|
262 |
+
logger.warning(
|
263 |
+
'table recognition processing fails, not found expected HTML table end'
|
264 |
+
)
|
265 |
+
else:
|
266 |
+
logger.warning(
|
267 |
+
'table recognition processing fails, not get html return'
|
268 |
+
)
|
269 |
+
|
270 |
+
# Create dictionaries to store items by language
|
271 |
+
need_ocr_lists_by_lang = {} # Dict of lists for each language
|
272 |
+
img_crop_lists_by_lang = {} # Dict of lists for each language
|
273 |
+
|
274 |
+
for layout_res in images_layout_res:
|
275 |
+
for layout_res_item in layout_res:
|
276 |
+
if layout_res_item['category_id'] in [15]:
|
277 |
+
if 'np_img' in layout_res_item and 'lang' in layout_res_item:
|
278 |
+
lang = layout_res_item['lang']
|
279 |
+
|
280 |
+
# Initialize lists for this language if not exist
|
281 |
+
if lang not in need_ocr_lists_by_lang:
|
282 |
+
need_ocr_lists_by_lang[lang] = []
|
283 |
+
img_crop_lists_by_lang[lang] = []
|
284 |
+
|
285 |
+
# Add to the appropriate language-specific lists
|
286 |
+
need_ocr_lists_by_lang[lang].append(layout_res_item)
|
287 |
+
img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
|
288 |
+
|
289 |
+
# Remove the fields after adding to lists
|
290 |
+
layout_res_item.pop('np_img')
|
291 |
+
layout_res_item.pop('lang')
|
292 |
+
|
293 |
+
if len(img_crop_lists_by_lang) > 0:
|
294 |
+
|
295 |
+
# Process OCR by language
|
296 |
+
total_processed = 0
|
297 |
+
|
298 |
+
# Process each language separately
|
299 |
+
for lang, img_crop_list in img_crop_lists_by_lang.items():
|
300 |
+
if len(img_crop_list) > 0:
|
301 |
+
# Get OCR results for this language's images
|
302 |
+
|
303 |
+
ocr_model = atom_model_manager.get_atom_model(
|
304 |
+
atom_model_name='ocr',
|
305 |
+
det_db_box_thresh=0.3,
|
306 |
+
lang=lang
|
307 |
+
)
|
308 |
+
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
|
309 |
+
|
310 |
+
# Verify we have matching counts
|
311 |
+
assert len(ocr_res_list) == len(
|
312 |
+
need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
|
313 |
+
|
314 |
+
# Process OCR results for this language
|
315 |
+
for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
|
316 |
+
ocr_text, ocr_score = ocr_res_list[index]
|
317 |
+
layout_res_item['text'] = ocr_text
|
318 |
+
layout_res_item['score'] = float(f"{ocr_score:.3f}")
|
319 |
+
if ocr_score < OcrConfidence.min_confidence:
|
320 |
+
layout_res_item['category_id'] = 16
|
321 |
+
else:
|
322 |
+
layout_res_bbox = [layout_res_item['poly'][0], layout_res_item['poly'][1],
|
323 |
+
layout_res_item['poly'][4], layout_res_item['poly'][5]]
|
324 |
+
layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
|
325 |
+
layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
|
326 |
+
if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
|
327 |
+
layout_res_item['category_id'] = 16
|
328 |
+
|
329 |
+
total_processed += len(img_crop_list)
|
330 |
+
|
331 |
+
return images_layout_res
|
vendor/mineru/mineru/backend/pipeline/model_init.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
from .model_list import AtomicModel
|
7 |
+
from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
|
8 |
+
from ...model.mfd.yolo_v8 import YOLOv8MFDModel
|
9 |
+
from ...model.mfr.unimernet.Unimernet import UnimernetModel
|
10 |
+
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
|
11 |
+
from ...model.table.rapid_table import RapidTableModel
|
12 |
+
from ...utils.enum_class import ModelPath
|
13 |
+
from ...utils.models_download_utils import auto_download_and_get_model_root_path
|
14 |
+
|
15 |
+
|
16 |
+
def table_model_init(lang=None):
    """Build a RapidTable model that shares the cached OCR engine for *lang*."""
    manager = AtomModelSingleton()
    # Table recognition uses its own detection thresholds, distinct from the
    # general-purpose OCR engine (0.5 / 1.6 instead of 0.3 / 1.8).
    ocr_engine = manager.get_atom_model(
        atom_model_name='ocr',
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang=lang,
    )
    return RapidTableModel(ocr_engine)
|
26 |
+
|
27 |
+
|
28 |
+
def mfd_model_init(weight, device='cpu'):
    """Create the YOLOv8 math-formula-detection model on *device*."""
    # Ascend NPU backends require a torch.device object rather than a string.
    if str(device).startswith('npu'):
        device = torch.device(device)
    return YOLOv8MFDModel(weight, device)
|
33 |
+
|
34 |
+
|
35 |
+
def mfr_model_init(weight_dir, device='cpu'):
    """Create the UniMERNet math-formula-recognition model from *weight_dir*."""
    return UnimernetModel(weight_dir, device)
|
38 |
+
|
39 |
+
|
40 |
+
def doclayout_yolo_model_init(weight, device='cpu'):
    """Create the DocLayout-YOLO layout-detection model on *device*."""
    # Ascend NPU backends require a torch.device object rather than a string.
    if str(device).startswith('npu'):
        device = torch.device(device)
    return DocLayoutYOLOModel(weight, device)
|
45 |
+
|
46 |
+
def ocr_model_init(det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a PytorchPaddleOCR engine.

    Args:
        det_db_box_thresh: DB text-detection box score threshold.
        lang: OCR language code. When None or '' the engine falls back to
            its own default language selection.
        use_dilation: whether to dilate the detection segmentation map.
        det_db_unclip_ratio: DB unclip ratio controlling detected-box expansion.

    Returns:
        A configured PytorchPaddleOCR instance.
    """
    # The original code duplicated the whole constructor call in two branches
    # that differed only by the `lang` keyword; build the kwargs once and add
    # `lang` conditionally instead.
    kwargs = {
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    if lang is not None and lang != '':
        kwargs['lang'] = lang
    return PytorchPaddleOCR(**kwargs)
|
65 |
+
|
66 |
+
|
67 |
+
class AtomModelSingleton:
    """Process-wide cache of atomic models.

    Models are keyed by name (plus language for OCR, and table-model name +
    language for tables) so each configuration is initialized only once.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Classic singleton: every instantiation returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for the given name/options, building it on
        first use via atom_model_init.

        NOTE(review): only `lang` (and `table_model_name` for tables) is part
        of the cache key — other kwargs such as det_db_box_thresh are ignored
        on a cache hit; presumably callers use consistent options per
        (name, lang) — confirm.
        """

        lang = kwargs.get('lang', None)
        table_model_name = kwargs.get('table_model_name', None)

        if atom_model_name in [AtomicModel.OCR]:
            key = (atom_model_name, lang)
        elif atom_model_name in [AtomicModel.Table]:
            key = (atom_model_name, table_model_name, lang)
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
|
91 |
+
|
92 |
+
def atom_model_init(model_name: str, **kwargs):
    """Instantiate a single atomic model by name.

    Args:
        model_name: one of the AtomicModel identifiers.
        **kwargs: model-specific options (weight paths, device, lang, ...).

    Returns:
        The initialized model. On an unknown name or failed initialization
        the process exits with status 1 (original behavior, kept for CLI
        callers that rely on SystemExit).
    """
    # Use explicit fallbacks so that an omitted key keeps the factory
    # defaults; plain kwargs.get(...) would pass None and silently override
    # e.g. device='cpu' or det_db_box_thresh=0.3.
    if model_name == AtomicModel.Layout:
        atom_model = doclayout_yolo_model_init(
            kwargs.get('doclayout_yolo_weights'),
            kwargs.get('device', 'cpu'),
        )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device', 'cpu'),
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('device', 'cpu'),
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('det_db_box_thresh', 0.3),
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('lang'),
        )
    else:
        logger.error('model name not allow')
        exit(1)

    if atom_model is None:
        logger.error('model init failed')
        exit(1)
    else:
        return atom_model
|
127 |
+
|
128 |
+
|
129 |
+
class MineruPipelineModel:
    """Lazily builds every model the pipeline backend needs: layout,
    optional formula detection/recognition, OCR, and an optional table model.

    All models are obtained through AtomModelSingleton, so repeated
    construction with the same options reuses the cached instances.
    """

    def __init__(self, **kwargs):
        # Feature configs; the `enable` flags default to True when absent.
        # NOTE(review): formula_config/table_config are assumed to be dicts —
        # passing None would raise here; confirm callers always supply them.
        self.formula_config = kwargs.get('formula_config')
        self.apply_formula = self.formula_config.get('enable', True)
        self.table_config = kwargs.get('table_config')
        self.apply_table = self.table_config.get('enable', True)
        self.lang = kwargs.get('lang', None)
        self.device = kwargs.get('device', 'cpu')
        logger.info(
            'DocAnalysis init, this may take some times......'
        )
        atom_model_manager = AtomModelSingleton()

        if self.apply_formula:
            # Initialize the formula-detection (MFD) model; weights are
            # downloaded on demand.
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
                ),
                device=self.device,
            )

            # Initialize the formula-recognition (MFR) model.
            mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)

            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                device=self.device,
            )

        # Initialize the layout-detection model.
        self.layout_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.Layout,
            doclayout_yolo_weights=str(
                os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
            ),
            device=self.device,
        )
        # Initialize OCR.
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            det_db_box_thresh=0.3,
            lang=self.lang
        )
        # Initialize the table model only when table parsing is enabled.
        if self.apply_table:
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                lang=self.lang,
            )

        logger.info('DocAnalysis init done!')
|
vendor/mineru/mineru/backend/pipeline/model_json_to_middle_json.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
from loguru import logger
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
|
9 |
+
from mineru.backend.pipeline.model_init import AtomModelSingleton
|
10 |
+
from mineru.backend.pipeline.para_split import para_split
|
11 |
+
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
|
12 |
+
from mineru.utils.block_sort import sort_blocks_by_bbox
|
13 |
+
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
|
14 |
+
from mineru.utils.cut_image import cut_image_and_table
|
15 |
+
from mineru.utils.enum_class import ContentType
|
16 |
+
from mineru.utils.llm_aided import llm_aided_title
|
17 |
+
from mineru.utils.model_utils import clean_memory
|
18 |
+
from mineru.backend.pipeline.pipeline_magic_model import MagicModel
|
19 |
+
from mineru.utils.ocr_utils import OcrConfidence
|
20 |
+
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
|
21 |
+
from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
|
22 |
+
remove_overlaps_min_spans, txt_spans_extract
|
23 |
+
from mineru.version import __version__
|
24 |
+
from mineru.utils.hash_utils import str_md5
|
25 |
+
|
26 |
+
|
27 |
+
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
    """Convert one page's raw model output into a middle-json page_info dict.

    Steps: extract typed blocks via MagicModel, reclassify text-like images,
    merge all block bboxes, filter/dedupe spans, crop image/table/equation
    snapshots, fill spans into blocks, fix and sort blocks.

    Args:
        page_model_info: raw model output for this page.
        image_dict: dict with 'scale', 'img_pil' and 'img_base64' of the page render.
        page: page object providing get_size() (pdf document page).
        image_writer: writer used by cut_image_and_table to persist crops.
        page_index: zero-based page number.
        ocr_enable: when False, text is extracted from the PDF text layer.
        formula_enabled: when True, interline_equation_blocks are discarded
            in favor of the span-level interline equations.

    Returns:
        The page_info dict, or None when the page has no valid bboxes.
    """
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    page_w, page_h = map(int, page.get_size())
    magic_model = MagicModel(page_model_info, scale)

    """从magic_model对象中获取后面会用到的区块信息"""
    discarded_blocks = magic_model.get_discarded()
    text_blocks = magic_model.get_text_blocks()
    title_blocks = magic_model.get_title_blocks()
    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()

    img_groups = magic_model.get_imgs()
    table_groups = magic_model.get_tables()

    """对image和table的区块分组"""
    img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
    )

    table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
    )

    """获取所有的spans信息"""
    spans = magic_model.get_all_spans()

    """某些图可能是文本块,通过简单的规则判断一下"""
    # Heuristic: an "image" densely covered by text spans is really a text block.
    if len(maybe_text_image_blocks) > 0:
        for block in maybe_text_image_blocks:
            span_in_block_list = []
            for span in spans:
                # A span counts as inside when >70% of its area overlaps the block.
                if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
                    span_in_block_list.append(span)
            if len(span_in_block_list) > 0:
                # Total area of all span bboxes inside the block.
                spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
                # Ratio of covered span area to the block area.
                block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
                if block_area > 0:
                    ratio = spans_area / block_area
                    if ratio > 0.25 and ocr_enable:
                        # Drop the block's group_id before promoting it.
                        block.pop('group_id', None)
                        # Dense enough in text: treat as a text block.
                        text_blocks.append(block)
                    else:
                        # Not text-like enough: keep as an image block.
                        img_body_blocks.append(block)
            else:
                img_body_blocks.append(block)

    """将所有区块的bbox整理到一起"""
    # NOTE(review): when formulas are enabled the coarse interline_equation_blocks
    # are discarded in favor of span-level interline_equations — confirm intent.
    if formula_enabled:
        interline_equation_blocks = []

    if len(interline_equation_blocks) > 0:

        for block in interline_equation_blocks:
            spans.append({
                "type": ContentType.INTERLINE_EQUATION,
                'score': block['score'],
                "bbox": block['bbox'],
            })

        all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equation_blocks,
            page_w,
            page_h,
        )
    else:
        all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equations,
            page_w,
            page_h,
        )

    """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
    """顺便删除大水印并保留abandon的span"""
    spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)

    """删除重叠spans中置信度较低的那些"""
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    """删除重叠spans中较小的那些"""
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

    """根据parse_mode,构造spans,主要是文本类的字符填充"""
    if ocr_enable:
        # Pure-OCR mode: text content is filled later by the post-OCR pass.
        pass
    else:
        """使用新版本的混合ocr方案."""
        spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)

    """先处理不需要排版的discarded_blocks"""
    discarded_block_with_spans, spans = fill_spans_in_blocks(
        all_discarded_blocks, spans, 0.4
    )
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    """如果当前页面没有有效的bbox则跳过"""
    if len(all_bboxes) == 0:
        return None

    """对image/table/interline_equation截图"""
    for span in spans:
        if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
            span = cut_image_and_table(
                span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
            )

    """span填充进block"""
    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)

    """对block进行fix操作"""
    fix_blocks = fix_block_spans(block_with_spans)

    """对block进行排序"""
    sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)

    """构造page_info"""
    page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)

    return page_info
|
162 |
+
|
163 |
+
|
164 |
+
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
    """Convert per-page model outputs into the full middle-json structure.

    Builds page_info for each page, runs a batched post-OCR pass over all
    spans that still carry cropped images ('np_img'), splits paragraphs,
    optionally applies LLM-aided title fixing, then closes the document and
    frees accelerator memory.

    Args:
        model_list: per-page raw model outputs.
        images_list: per-page dicts with the rendered page image data.
        pdf_doc: the open pdf document; it is closed before returning.
        image_writer: writer used to persist image/table crops.
        lang: OCR language for the post-OCR pass.
        ocr_enable: whether pages are parsed in OCR mode.
        formula_enabled: formula toggle, refined via get_formula_enable.

    Returns:
        The middle_json dict with 'pdf_info', '_backend' and '_version_name'.
    """
    middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
    formula_enabled = get_formula_enable(formula_enabled)
    for page_index, page_model_info in tqdm(enumerate(model_list), total=len(model_list), desc="Processing pages"):
        page = pdf_doc[page_index]
        image_dict = images_list[page_index]
        page_info = page_model_info_to_page_info(
            page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
        )
        if page_info is None:
            # Page had no valid bboxes: emit an empty page entry instead.
            page_w, page_h = map(int, page.get_size())
            page_info = make_page_info_dict([], page_index, page_w, page_h, [])
        middle_json["pdf_info"].append(page_info)

    """后置ocr处理"""
    # Collect every span across all pages that still carries a cropped image
    # ('np_img') so recognition can run in one batched OCR call.
    need_ocr_list = []
    img_crop_list = []
    text_block_list = []
    for page_info in middle_json["pdf_info"]:
        for block in page_info['preproc_blocks']:
            if block['type'] in ['table', 'image']:
                for sub_block in block['blocks']:
                    if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
                        text_block_list.append(sub_block)
            elif block['type'] in ['text', 'title']:
                text_block_list.append(block)
        for block in page_info['discarded_blocks']:
            text_block_list.append(block)
    for block in text_block_list:
        for line in block['lines']:
            for span in line['spans']:
                if 'np_img' in span:
                    need_ocr_list.append(span)
                    img_crop_list.append(span['np_img'])
                    # Drop the crop so it is not serialized into the json.
                    span.pop('np_img')
    if len(img_crop_list) > 0:
        atom_model_manager = AtomModelSingleton()
        ocr_model = atom_model_manager.get_atom_model(
            atom_model_name='ocr',
            ocr_show_log=False,
            det_db_box_thresh=0.3,
            lang=lang
        )
        # Recognition-only pass (det=False) over the collected crops.
        ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
        assert len(ocr_res_list) == len(
            need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
        for index, span in enumerate(need_ocr_list):
            ocr_text, ocr_score = ocr_res_list[index]
            if ocr_score > OcrConfidence.min_confidence:
                span['content'] = ocr_text
                span['score'] = float(f"{ocr_score:.3f}")
            else:
                # Low-confidence recognition is discarded rather than kept.
                span['content'] = ''
                span['score'] = 0.0

    """分段"""
    para_split(middle_json["pdf_info"])

    """llm优化"""
    llm_aided_config = get_llm_aided_config()

    if llm_aided_config is not None:
        """标题优化"""
        title_aided_config = llm_aided_config.get('title_aided', None)
        if title_aided_config is not None:
            if title_aided_config.get('enable', False):
                llm_aided_title_start_time = time.time()
                llm_aided_title(middle_json["pdf_info"], title_aided_config)
                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')

    """清理内存"""
    pdf_doc.close()
    # Skip the (slow) memory sweep for small documents or when disabled by env.
    if os.getenv('MINERU_DONOT_CLEAN_MEM') is None and len(model_list) >= 10:
        clean_memory(get_device())

    return middle_json
|
240 |
+
|
241 |
+
|
242 |
+
def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
    """Assemble the per-page entry stored under middle_json['pdf_info'].

    Args:
        blocks: sorted content blocks for the page.
        page_id: zero-based page index.
        page_w: page width.
        page_h: page height.
        discarded_blocks: blocks excluded from the main reading order.

    Returns:
        The page_info dict.
    """
    return {
        'preproc_blocks': blocks,
        'page_idx': page_id,
        'page_size': [page_w, page_h],
        'discarded_blocks': discarded_blocks,
    }
|
vendor/mineru/mineru/backend/pipeline/model_list.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class AtomicModel:
    """String identifiers for the atomic models managed by the pipeline."""

    Layout = "layout"  # document layout detection
    MFD = "mfd"        # math formula detection
    MFR = "mfr"        # math formula recognition
    OCR = "ocr"        # text detection / recognition
    Table = "table"    # table structure recognition
|
vendor/mineru/mineru/backend/pipeline/para_split.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from loguru import logger
|
3 |
+
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
|
4 |
+
from mineru.utils.language import detect_lang
|
5 |
+
|
6 |
+
|
7 |
+
# Characters that typically terminate a sentence/line; used when deciding
# whether a paragraph continues across lines (ASCII and full-width forms).
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
# Characters that typically terminate a list item.
LIST_END_FLAG = ('.', '。', ';', ';')
|
9 |
+
|
10 |
+
|
11 |
+
class ListLineTag:
    """Keys attached to line dicts to mark list-item start/end lines."""

    IS_LIST_START_LINE = 'is_list_start_line'
    IS_LIST_END_LINE = 'is_list_end_line'
|
15 |
+
|
16 |
+
def __process_blocks(blocks):
    """Group consecutive text blocks, tightening each block's bbox to its lines.

    Each text block gets a 'bbox_fs' field (the union of its line bboxes when
    lines exist, otherwise a copy of its own bbox) and is appended to the
    current group. A group is closed when the block that follows a text block
    is a title or an interline equation. Non-text blocks are not collected.

    Returns:
        A list of groups, each a list of (mutated) text blocks.
    """
    groups = []
    pending = []

    for idx, block in enumerate(blocks):
        if block['type'] != 'text':
            continue

        # Default to the block's own bbox; refine from line bboxes when present.
        block['bbox_fs'] = copy.deepcopy(block['bbox'])
        lines = block.get('lines')
        if lines:
            xs0, ys0, xs1, ys1 = zip(*(line['bbox'] for line in lines))
            block['bbox_fs'] = [min(xs0), min(ys0), max(xs1), max(ys1)]
        pending.append(block)

        # A following title/equation terminates the current group.
        if idx + 1 < len(blocks) and blocks[idx + 1]['type'] in ('title', 'interline_equation'):
            groups.append(pending)
            pending = []

    # Flush whatever remains after the last separator.
    if pending:
        groups.append(pending)

    return groups
|
52 |
+
|
53 |
+
|
54 |
+
def __is_list_or_index_block(block):
    """Classify a text block as BlockType.LIST, BlockType.INDEX or BlockType.TEXT.

    A list block needs multiple lines plus one of:
      - several flush-left lines with ragged ("sawtooth") right edges,
      - several flush-left lines ending with LIST_END_FLAG characters,
      - several flush-left lines mixed with clearly indented lines.
    An index block (e.g. table of contents) is a special list whose lines are
    flush on both sides and mostly start or end with digits.

    Side effect: tags individual lines with ListLineTag start/end markers.
    """
    if len(block['lines']) >= 2:
        first_line = block['lines'][0]
        line_height = first_line['bbox'][3] - first_line['bbox'][1]
        block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]  # block width
        block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
        page_weight, page_height = block['page_size']

        left_close_num = 0                # lines flush with the block's left edge
        left_not_close_num = 0            # lines clearly indented from the left
        right_not_close_num = 0           # lines with a large right-side gap
        right_close_num = 0               # lines flush with the right edge
        lines_text_list = []
        center_close_num = 0              # lines centered within the block
        external_sides_not_close_num = 0  # lines with gaps on both sides
        multiple_para_flag = False
        last_line = block['lines'][-1]

        # Guard against a degenerate page width.
        if page_weight == 0:
            block_weight_radio = 0
        else:
            block_weight_radio = block_weight / page_weight

        # First line indented but last line flush-left with a ragged right
        # edge: looks like ordinary multi-paragraph prose, not a list.
        if (
            first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
            and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
            and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
        ):
            multiple_para_flag = True

        block_text = ''

        for line in block['lines']:
            line_text = ''

            for span in line['spans']:
                span_type = span['type']
                if span_type == ContentType.TEXT:
                    line_text += span['content'].strip()
            # Append every line's text, including empty ones, so that
            # lines_text_list stays index-aligned with block['lines'].
            lines_text_list.append(line_text)
            block_text = ''.join(lines_text_list)

        block_lang = detect_lang(block_text)

        for line in block['lines']:
            line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
            block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
            if (
                line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
                and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
            ):
                external_sides_not_close_num += 1
            if abs(line_mid_x - block_mid_x) < line_height / 2:
                center_close_num += 1

            # A line is "flush left" when its left edge is within half a
            # line-height of the block's left edge.
            if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
                left_close_num += 1
            elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
                left_not_close_num += 1

            # Right-edge flushness.
            if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
                right_close_num += 1
            else:
                # CJK-like scripts have no over-long words, so a single
                # threshold works.
                if block_lang in ['zh', 'ja', 'ko']:
                    closed_area = 0.26 * block_weight
                else:
                    # Empirical threshold (~0.3 block width): wide blocks
                    # get a smaller threshold, narrow blocks a larger one.
                    if block_weight_radio >= 0.5:
                        closed_area = 0.26 * block_weight
                    else:
                        closed_area = 0.36 * block_weight
                if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
                    right_not_close_num += 1

        # Do >= 80% of lines end with a LIST_END_FLAG character?
        line_end_flag = False
        # Do >= 80% of lines start with a digit, or >= 80% end with one?
        line_num_flag = False
        num_start_count = 0
        num_end_count = 0
        flag_end_count = 0

        if len(lines_text_list) > 0:
            for line_text in lines_text_list:
                if len(line_text) > 0:
                    if line_text[-1] in LIST_END_FLAG:
                        flag_end_count += 1
                    if line_text[0].isdigit():
                        num_start_count += 1
                    if line_text[-1].isdigit():
                        num_end_count += 1

            if (
                num_start_count / len(lines_text_list) >= 0.8
                or num_end_count / len(lines_text_list) >= 0.8
            ):
                line_num_flag = True
            if flag_end_count / len(lines_text_list) >= 0.8:
                line_end_flag = True

        # Some tables of contents are not flush on the right; treat the block
        # as an index when either side is mostly flush and the digit rule holds.
        if (
            left_close_num / len(block['lines']) >= 0.8
            or right_close_num / len(block['lines']) >= 0.8
        ) and line_num_flag:
            for line in block['lines']:
                line[ListLineTag.IS_LIST_START_LINE] = True
            return BlockType.INDEX

        # Special fully-centered list: every line centered, most lines have
        # gaps on both sides, and the block is tall relative to its width.
        elif (
            external_sides_not_close_num >= 2
            and center_close_num == len(block['lines'])
            and external_sides_not_close_num / len(block['lines']) >= 0.5
            and block_height / block_weight > 0.4
        ):
            for line in block['lines']:
                line[ListLineTag.IS_LIST_START_LINE] = True
            return BlockType.LIST

        elif (
            left_close_num >= 2
            and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
            and not multiple_para_flag
            # and block_weight_radio > 0.27
        ):
            # Special non-indented list: all lines flush left; item ends are
            # inferred from the right-side gap.
            if left_close_num / len(block['lines']) > 0.8:
                # Short one-line items, all flush left, no end markers.
                if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
                    for line in block['lines']:
                        if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
                            line[ListLineTag.IS_LIST_START_LINE] = True
                # Most items carry an end marker; split items on it.
                elif line_end_flag:
                    for i, line in enumerate(block['lines']):
                        if (
                            len(lines_text_list[i]) > 0
                            and lines_text_list[i][-1] in LIST_END_FLAG
                        ):
                            line[ListLineTag.IS_LIST_END_LINE] = True
                            if i + 1 < len(block['lines']):
                                block['lines'][i + 1][
                                    ListLineTag.IS_LIST_START_LINE
                                ] = True
                # No end markers and no indentation: use the right-side gap
                # to find item ends.
                else:
                    line_start_flag = False
                    for i, line in enumerate(block['lines']):
                        if line_start_flag:
                            line[ListLineTag.IS_LIST_START_LINE] = True
                            line_start_flag = False

                        if (
                            abs(block['bbox_fs'][2] - line['bbox'][2])
                            > 0.1 * block_weight
                        ):
                            line[ListLineTag.IS_LIST_END_LINE] = True
                            line_start_flag = True
            # Special indented ordered list: start lines begin with a digit,
            # end lines carry LIST_END_FLAG, one end per start.
            elif num_start_count >= 2 and num_start_count == flag_end_count:
                for i, line in enumerate(block['lines']):
                    if len(lines_text_list[i]) > 0:
                        if lines_text_list[i][0].isdigit():
                            line[ListLineTag.IS_LIST_START_LINE] = True
                        if lines_text_list[i][-1] in LIST_END_FLAG:
                            line[ListLineTag.IS_LIST_END_LINE] = True
            else:
                # Regular indented list.
                for line in block['lines']:
                    if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
                        line[ListLineTag.IS_LIST_START_LINE] = True
                    if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
                        line[ListLineTag.IS_LIST_END_LINE] = True

            return BlockType.LIST
        else:
            return BlockType.TEXT
    else:
        return BlockType.TEXT
|
251 |
+
|
252 |
+
|
253 |
+
def __merge_2_text_blocks(block1, block2):
    """Fold block1 (the later block) into block2 when it reads as a continuation.

    Callers iterate in reverse, so block2 precedes block1 in reading order.
    When all merge conditions hold, block1's lines are moved into block2,
    block1 is emptied and tagged SplitFlag.LINES_DELETED, and spans that
    migrate across a page break are tagged SplitFlag.CROSS_PAGE.

    Returns (block1, block2) in all cases.
    """
    if len(block1['lines']) > 0:
        first_line = block1['lines'][0]
        line_height = first_line['bbox'][3] - first_line['bbox'][1]
        block1_weight = block1['bbox'][2] - block1['bbox'][0]
        block2_weight = block2['bbox'][2] - block2['bbox'][0]
        min_block_weight = min(block1_weight, block2_weight)
        # block1's first line must be flush left (no paragraph indent),
        # otherwise it starts a fresh paragraph.
        if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
            last_line = block2['lines'][-1]
            if len(last_line['spans']) > 0:
                last_span = last_line['spans'][-1]
                # NOTE: line_height is deliberately re-measured from block2's
                # last line from here on.
                line_height = last_line['bbox'][3] - last_line['bbox'][1]
                if len(first_line['spans']) > 0:
                    first_span = first_line['spans'][0]
                    if len(first_span['content']) > 0:
                        span_start_with_num = first_span['content'][0].isdigit()
                        span_start_with_big_char = first_span['content'][0].isupper()
                        if (
                            # block2's last line ends within a line-height of
                            # block2's right edge (i.e. the line is "full")
                            abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
                            # block2's last span does not end with
                            # sentence-final punctuation
                            and not last_span['content'].endswith(LINE_STOP_FLAG)
                            # the two blocks' widths must be comparable
                            # (difference under 2x)
                            and abs(block1_weight - block2_weight) < min_block_weight
                            # block1 must not start with a digit
                            # (likely a new list item / section number)
                            and not span_start_with_num
                            # ...nor with an uppercase letter
                            # (likely a new sentence)
                            and not span_start_with_big_char
                        ):
                            if block1['page_num'] != block2['page_num']:
                                # Mark spans migrating across a page boundary.
                                for line in block1['lines']:
                                    for span in line['spans']:
                                        span[SplitFlag.CROSS_PAGE] = True
                            block2['lines'].extend(block1['lines'])
                            block1['lines'] = []
                            block1[SplitFlag.LINES_DELETED] = True

    return block1, block2
|
291 |
+
|
292 |
+
|
293 |
+
def __merge_2_list_blocks(block1, block2):
    """Move every line of block1 (the later list block) into block2.

    Spans crossing a page boundary are tagged SplitFlag.CROSS_PAGE; block1
    is emptied and tagged SplitFlag.LINES_DELETED.  Returns (block1, block2).
    """
    crosses_page = block1['page_num'] != block2['page_num']
    if crosses_page:
        for moved_line in block1['lines']:
            for moved_span in moved_line['spans']:
                moved_span[SplitFlag.CROSS_PAGE] = True

    block2['lines'] += block1['lines']
    block1['lines'] = []
    block1[SplitFlag.LINES_DELETED] = True

    return block1, block2
|
303 |
+
|
304 |
+
|
305 |
+
def __is_list_group(text_blocks_group):
|
306 |
+
# list group的特征是一个group内的所有block都满足以下条件
|
307 |
+
# 1.每个block都不超过3行 2. 每个block 的左边界都比较接近(逻辑简单点先不加这个规则)
|
308 |
+
for block in text_blocks_group:
|
309 |
+
if len(block['lines']) > 3:
|
310 |
+
return False
|
311 |
+
return True
|
312 |
+
|
313 |
+
|
314 |
+
def __para_merge_page(blocks):
    """Classify and merge text/list/index blocks in place.

    Blocks are first grouped (__process_blocks), every block in a group is
    classified via __is_list_or_index_block, then adjacent blocks of the
    same kind are merged pairwise.  Iteration is in reverse so that later
    fragments fold into the block that precedes them.
    """
    page_text_blocks_groups = __process_blocks(blocks)
    for text_blocks_group in page_text_blocks_groups:
        if len(text_blocks_group) > 0:
            # Classify every block BEFORE any merging happens.
            for block in text_blocks_group:
                block_type = __is_list_or_index_block(block)
                block['type'] = block_type

            if len(text_blocks_group) > 1:
                # Decide up front whether the whole group behaves as a list;
                # list groups are never text-merged.
                is_list_group = __is_list_group(text_blocks_group)

                # Reverse traversal: merge each block into its predecessor.
                for i in range(len(text_blocks_group) - 1, -1, -1):
                    current_block = text_blocks_group[i]

                    # Only merge when a predecessor exists.
                    if i - 1 >= 0:
                        prev_block = text_blocks_group[i - 1]

                        if (
                            current_block['type'] == 'text'
                            and prev_block['type'] == 'text'
                            and not is_list_group
                        ):
                            __merge_2_text_blocks(current_block, prev_block)
                        elif (
                            current_block['type'] == BlockType.LIST
                            and prev_block['type'] == BlockType.LIST
                        ) or (
                            current_block['type'] == BlockType.INDEX
                            and prev_block['type'] == BlockType.INDEX
                        ):
                            __merge_2_list_blocks(current_block, prev_block)

        else:
            continue
|
353 |
+
|
354 |
+
|
355 |
+
def para_split(page_info_list):
    """Merge preprocessed blocks into paragraphs across all pages.

    Every page's 'preproc_blocks' are deep-copied and tagged with their
    owning page index and size, the merge pass runs over the whole flat
    sequence (so paragraphs can join across page breaks), and the results
    are written back into each page's 'para_blocks' with the temporary
    tags stripped.
    """
    merged_blocks = []
    for page_info in page_info_list:
        page_blocks = copy.deepcopy(page_info['preproc_blocks'])
        for blk in page_blocks:
            blk['page_num'] = page_info['page_idx']
            blk['page_size'] = page_info['page_size']
        merged_blocks.extend(page_blocks)

    __para_merge_page(merged_blocks)

    for page_info in page_info_list:
        page_info['para_blocks'] = []
        for blk in merged_blocks:
            if 'page_num' not in blk:
                continue
            if blk['page_num'] == page_info['page_idx']:
                page_info['para_blocks'].append(blk)
                # The temporary ownership tags are no longer needed.
                del blk['page_num']
                del blk['page_size']
|
374 |
+
|
375 |
+
|
376 |
+
if __name__ == '__main__':
    # Ad-hoc smoke run: feed an (empty) block list through the grouping
    # pass and print the resulting groups.
    input_blocks = []
    # Invoke the grouping function directly.
    groups = __process_blocks(input_blocks)
    for group_index, group in enumerate(groups):
        print(f'Group {group_index}: {group}')
|
vendor/mineru/mineru/backend/pipeline/pipeline_analyze.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from typing import List, Tuple
|
4 |
+
import PIL.Image
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from .model_init import MineruPipelineModel
|
8 |
+
from mineru.utils.config_reader import get_device
|
9 |
+
from ...utils.pdf_classify import classify
|
10 |
+
from ...utils.pdf_image_tools import load_images_from_pdf
|
11 |
+
from ...utils.model_utils import get_vram, clean_memory
|
12 |
+
|
13 |
+
|
14 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
|
15 |
+
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
|
16 |
+
|
17 |
+
class ModelSingleton:
    """Process-wide cache of pipeline models keyed by configuration.

    A single shared instance hands out models so that each
    (lang, formula_enable, table_enable) combination is initialised
    exactly once per process.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Classic lazy singleton: create the sole instance on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        lang=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Return the cached model for this configuration, building it on demand."""
        cache_key = (lang, formula_enable, table_enable)
        if cache_key not in self._models:
            self._models[cache_key] = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[cache_key]
|
40 |
+
|
41 |
+
|
42 |
+
def custom_model_init(
    lang=None,
    formula_enable=True,
    table_enable=True,
):
    """Build a MineruPipelineModel for the given language/feature switches.

    The compute device comes from the user configuration file; model
    initialisation time is logged for diagnostics.
    """
    started_at = time.time()
    # Device selection is driven by the config file / environment.
    device = get_device()

    model_kwargs = {
        'device': device,
        'table_config': {'enable': table_enable},
        'formula_config': {'enable': formula_enable},
        'lang': lang,
    }
    custom_model = MineruPipelineModel(**model_kwargs)

    elapsed = time.time() - started_at
    logger.info(f'model init cost: {elapsed}')

    return custom_model
|
67 |
+
|
68 |
+
|
69 |
+
def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=True,
    table_enable=True,
):
    """Run pipeline layout analysis over a batch of PDF documents.

    Args:
        pdf_bytes_list: raw bytes of each PDF document.
        lang_list: OCR language per document, parallel to pdf_bytes_list.
        parse_method: 'auto' classifies each PDF, 'ocr' forces OCR; any
            other value disables OCR.
        formula_enable, table_enable: feature switches forwarded to the
            batch analyzer.

    Returns:
        (infer_results, all_image_lists, all_pdf_docs, lang_list,
        ocr_enabled_list), where infer_results[i] is a list of per-page
        dicts {'layout_dets': ..., 'page_info': ...} for document i.

    Increasing MIN_BATCH_INFERENCE_SIZE can improve throughput at the cost
    of more VRAM; configurable via env MINERU_MIN_BATCH_INFERENCE_SIZE
    (default 128).
    """
    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 128))

    # Flattened work list of (pdf_idx, page_idx, image, ocr_enable, lang).
    all_pages_info = []

    all_image_lists = []
    all_pdf_docs = []
    ocr_enabled_list = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether this document needs OCR.
        _ocr_enable = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr_enable = True
        elif parse_method == 'ocr':
            _ocr_enable = True

        ocr_enabled_list.append(_ocr_enable)
        _lang = lang_list[pdf_idx]

        # Render every page of this document to an image.
        images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
        all_image_lists.append(images_list)
        all_pdf_docs.append(pdf_doc)
        for page_idx in range(len(images_list)):
            img_dict = images_list[page_idx]
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr_enable, _lang,
            ))

    # Chunk the flattened page list into inference batches.
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = min_batch_inference_size
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]

    # Run inference batch by batch; results stays parallel to all_pages_info.
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)

    # Regroup flat per-page results back into per-document lists.
    infer_results = []

    for _ in range(len(pdf_bytes_list)):
        infer_results.append([])

    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _ = page_info
        result = results[i]

        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}

        infer_results[pdf_idx].append(page_dict)

    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
|
147 |
+
|
148 |
+
|
149 |
+
def batch_image_analyze(
        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
        formula_enable=True,
        table_enable=True):
    """Analyze one batch of page images with the pipeline models.

    Each tuple is (page image, ocr_enable, lang).  A batch_ratio is picked
    from the available VRAM on CUDA/NPU devices, the work is delegated to
    BatchAnalyze, and accelerator memory is cleaned up afterwards.

    Returns the per-page detection lists, parallel to the input.
    """
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)

    # Imported lazily so model code is only loaded when analysis runs.
    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()

    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        try:
            import torch_npu
            if torch_npu.npu.is_available():
                # Disable JIT compilation on Ascend NPUs.
                torch_npu.npu.set_compile_mode(jit_compile=False)
        except Exception as e:
            raise RuntimeError(
                "NPU is selected as device, but torch_npu is not available. "
                "Please ensure that the torch_npu package is installed correctly."
            ) from e

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        # Scale the batch multiplier with the (possibly overridden) VRAM
        # budget; MINERU_VIRTUAL_VRAM_SIZE can cap/boost the detected size.
        vram = get_vram(device)
        if vram is not None:
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # Default batch_ratio when VRAM can't be determined
            batch_ratio = 1
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')

    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    # Free accelerator memory between batches.
    clean_memory(get_device())

    return results
|
vendor/mineru/mineru/backend/pipeline/pipeline_magic_model.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
|
2 |
+
from mineru.utils.enum_class import CategoryId, ContentType
|
3 |
+
|
4 |
+
|
5 |
+
class MagicModel:
|
6 |
+
"""每个函数没有得到元素的时候返回空list."""
|
7 |
+
def __init__(self, page_model_info: dict, scale: float):
    """Wrap one page's raw model output and normalise it in place.

    Args:
        page_model_info: detection output for a single page; its
            'layout_dets' entries carry 'poly', 'score', 'category_id', ...
        scale: rendering scale used to map poly coordinates back to PDF
            coordinates.
    """
    self.__page_model_info = page_model_info
    self.__scale = scale
    # Step 1: add a descaled 'bbox' to every detection (poly -> bbox).
    """为所有模型数据添加bbox信息(缩放,poly->bbox)"""
    self.__fix_axis()
    # Step 2: drop very low-confidence detections (score <= 0.05).
    """删除置信度特别低的模型数据(<0.05),提高质量"""
    self.__fix_by_remove_low_confidence()
    # Step 3: of any pair with IoU > 0.9, drop the lower-confidence one.
    """删除高iou(>0.9)数据中置信度较低的那个"""
    self.__fix_by_remove_high_iou_and_low_confidence()
    # Step 4: re-label table footnotes that actually belong to images.
    """将部分tbale_footnote修正为image_footnote"""
    self.__fix_footnote()
    # Step 5: merge heavily overlapping image bodies / table bodies.
    """处理重叠的image_body和table_body"""
    self.__fix_by_remove_overlap_image_table_body()
|
20 |
+
|
21 |
+
def __fix_by_remove_overlap_image_table_body(self):
    """Merge image bodies (and table bodies) that overlap by >= 80%.

    Within each category, the smaller box of a heavily-overlapping pair is
    removed and the surviving (larger) box is grown to the union of both.
    """
    need_remove_list = []
    layout_dets = self.__page_model_info['layout_dets']
    image_blocks = list(filter(
        lambda x: x['category_id'] == CategoryId.ImageBody, layout_dets
    ))
    table_blocks = list(filter(
        lambda x: x['category_id'] == CategoryId.TableBody, layout_dets
    ))

    def add_need_remove_block(blocks):
        # Pairwise scan within one category.
        for i in range(len(blocks)):
            for j in range(i + 1, len(blocks)):
                block1 = blocks[i]
                block2 = blocks[j]
                overlap_box = get_minbox_if_overlap_by_ratio(
                    block1['bbox'], block2['bbox'], 0.8
                )
                if overlap_box is not None:
                    # Keep the larger block; mark the smaller for removal.
                    area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
                    area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])

                    if area1 <= area2:
                        block_to_remove = block1
                        large_block = block2
                    else:
                        block_to_remove = block2
                        large_block = block1

                    if block_to_remove not in need_remove_list:
                        # Expand the survivor to the union of both bboxes.
                        x1, y1, x2, y2 = large_block['bbox']
                        sx1, sy1, sx2, sy2 = block_to_remove['bbox']
                        x1 = min(x1, sx1)
                        y1 = min(y1, sy1)
                        x2 = max(x2, sx2)
                        y2 = max(y2, sy2)
                        large_block['bbox'] = [x1, y1, x2, y2]
                        need_remove_list.append(block_to_remove)

    # Image-image overlaps.
    add_need_remove_block(image_blocks)
    # Table-table overlaps.
    add_need_remove_block(table_blocks)

    # Drop the marked blocks from the page layout.
    for need_remove in need_remove_list:
        if need_remove in layout_dets:
            layout_dets.remove(need_remove)
|
71 |
+
|
72 |
+
|
73 |
+
def __fix_axis(self):
    """Derive an integer 'bbox' from each detection's 'poly', descaled.

    Detections whose resulting bbox has non-positive width or height are
    removed from the page.
    """
    need_remove_list = []
    layout_dets = self.__page_model_info['layout_dets']
    for layout_det in layout_dets:
        # poly holds 8 values; positions 0-1 and 4-5 are used as opposite
        # corners (presumably top-left and bottom-right — TODO confirm
        # against the detector's poly ordering).
        x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
        bbox = [
            int(x0 / self.__scale),
            int(y0 / self.__scale),
            int(x1 / self.__scale),
            int(y1 / self.__scale),
        ]
        layout_det['bbox'] = bbox
        # Drop degenerate boxes (zero or negative width/height).
        if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
            need_remove_list.append(layout_det)
    for need_remove in need_remove_list:
        layout_dets.remove(need_remove)
|
90 |
+
|
91 |
+
def __fix_by_remove_low_confidence(self):
    """Remove detections whose confidence score is at or below 0.05."""
    layout_dets = self.__page_model_info['layout_dets']
    # Collect first, then remove, so we never mutate while scanning.
    low_confidence = [det for det in layout_dets if det['score'] <= 0.05]
    for det in low_confidence:
        layout_dets.remove(det)
|
101 |
+
|
102 |
+
def __fix_by_remove_high_iou_and_low_confidence(self):
    """For any pair of layout boxes with IoU > 0.9, drop the lower-score one.

    Only the main layout categories are de-duplicated; other detections
    (e.g. spans) are left untouched.
    """
    need_remove_list = []
    layout_dets = list(filter(
        lambda x: x['category_id'] in [
            CategoryId.Title,
            CategoryId.Text,
            CategoryId.ImageBody,
            CategoryId.ImageCaption,
            CategoryId.TableBody,
            CategoryId.TableCaption,
            CategoryId.TableFootnote,
            CategoryId.InterlineEquation_Layout,
            CategoryId.InterlineEquationNumber_Layout,
        ], self.__page_model_info['layout_dets']
    )
    )
    for i in range(len(layout_dets)):
        for j in range(i + 1, len(layout_dets)):
            layout_det1 = layout_dets[i]
            layout_det2 = layout_dets[j]

            if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:

                # Keep the higher-confidence detection of the pair.
                layout_det_need_remove = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2

                if layout_det_need_remove not in need_remove_list:
                    need_remove_list.append(layout_det_need_remove)

    for need_remove in need_remove_list:
        self.__page_model_info['layout_dets'].remove(need_remove)
|
132 |
+
|
133 |
+
def __fix_footnote(self):
    """Re-label table footnotes that actually belong to images.

    Each footnote is compared against the nearest figure body and the
    nearest table body; distances are only measured for pairs displaced
    along a single axis (see _bbox_distance).  A footnote closer to a
    figure than to any table becomes CategoryId.ImageFootnote.
    Mutates self.__page_model_info['layout_dets'] in place.
    """
    footnotes = []
    figures = []
    tables = []

    for obj in self.__page_model_info['layout_dets']:
        if obj['category_id'] == CategoryId.TableFootnote:
            footnotes.append(obj)
        elif obj['category_id'] == CategoryId.ImageBody:
            figures.append(obj)
        elif obj['category_id'] == CategoryId.TableBody:
            tables.append(obj)
    # Nothing to re-label without at least one footnote AND one figure.
    # Bug fix: this sits directly in the function body, not inside a loop,
    # so the original `continue` (left over from a per-document loop before
    # the per-page refactor) was a syntax error; `return` is the correct
    # early exit here.
    if len(footnotes) * len(figures) == 0:
        return
    dis_figure_footnote = {}  # footnote idx -> min distance to any figure
    dis_table_footnote = {}   # footnote idx -> min distance to any table

    for i in range(len(footnotes)):
        for j in range(len(figures)):
            # Count along how many axes the boxes are displaced; only
            # singly-adjacent pairs are considered related.
            pos_flag_count = sum(
                list(
                    map(
                        lambda x: 1 if x else 0,
                        bbox_relative_pos(
                            footnotes[i]['bbox'], figures[j]['bbox']
                        ),
                    )
                )
            )
            if pos_flag_count > 1:
                continue
            dis_figure_footnote[i] = min(
                self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
                dis_figure_footnote.get(i, float('inf')),
            )
    for i in range(len(footnotes)):
        for j in range(len(tables)):
            pos_flag_count = sum(
                list(
                    map(
                        lambda x: 1 if x else 0,
                        bbox_relative_pos(
                            footnotes[i]['bbox'], tables[j]['bbox']
                        ),
                    )
                )
            )
            if pos_flag_count > 1:
                continue

            dis_table_footnote[i] = min(
                self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
                dis_table_footnote.get(i, float('inf')),
            )
    for i in range(len(footnotes)):
        if i not in dis_figure_footnote:
            continue
        # Strictly closer to a figure than to any table -> image footnote.
        if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
            footnotes[i]['category_id'] = CategoryId.ImageFootnote
|
192 |
+
|
193 |
+
def _bbox_distance(self, bbox1, bbox2):
    """Distance between two boxes, or inf when they are not comparable.

    Rules:
      - boxes displaced along more than one axis are unrelated -> inf;
      - side-by-side boxes (left/right) compare heights, stacked boxes
        compare widths; if bbox2's span exceeds bbox1's by more than 30%,
        the pair is treated as unrelated -> inf;
      - otherwise delegate to bbox_distance().
    """
    left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
    flags = [left, right, bottom, top]
    count = sum([1 if v else 0 for v in flags])
    if count > 1:
        return float('inf')
    if left or right:
        # Horizontally adjacent: compare heights.
        l1 = bbox1[3] - bbox1[1]
        l2 = bbox2[3] - bbox2[1]
    else:
        # Vertically adjacent: compare widths.
        l1 = bbox1[2] - bbox1[0]
        l2 = bbox2[2] - bbox2[0]

    # Size mismatch beyond 30% -> treat as unrelated.
    if l2 > l1 and (l2 - l1) / l1 > 0.3:
        return float('inf')

    return bbox_distance(bbox1, bbox2)
|
210 |
+
|
211 |
+
def __reduct_overlap(self, bboxes):
    """Drop every box fully contained inside another box (O(N^2) scan).

    NOTE(review): if is_in() holds for two identical boxes, each contains
    the other and BOTH are dropped — confirm that is the intended
    behaviour of is_in for exact duplicates.
    """
    N = len(bboxes)
    keep = [True] * N
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            if is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
                keep[i] = False
    return [bboxes[i] for i in range(N) if keep[i]]
|
221 |
+
|
222 |
+
def __tie_up_category_by_distance_v3(
    self,
    subject_category_id: int,
    object_category_id: int,
):
    """Greedily pair subject boxes (e.g. image bodies) with object boxes
    (e.g. captions) by spatial proximity.

    Returns a list of {'sub_bbox', 'obj_bboxes', 'sub_idx'} records, one
    per subject; subjects with no matched object get an empty obj list.
    """
    # Collect and de-overlap candidates of each category from this page.
    subjects = self.__reduct_overlap(
        list(
            map(
                lambda x: {'bbox': x['bbox'], 'score': x['score']},
                filter(
                    lambda x: x['category_id'] == subject_category_id,
                    self.__page_model_info['layout_dets'],
                ),
            )
        )
    )
    objects = self.__reduct_overlap(
        list(
            map(
                lambda x: {'bbox': x['bbox'], 'score': x['score']},
                filter(
                    lambda x: x['category_id'] == object_category_id,
                    self.__page_model_info['layout_dets'],
                ),
            )
        )
    )

    ret = []
    # M is currently unused; kept for symmetry with N.
    N, M = len(subjects), len(objects)
    # Sort both lists by squared distance of the top-left corner from the
    # page origin, i.e. rough reading order.
    subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
    objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

    # Object indices are offset so both kinds share one index namespace.
    OBJ_IDX_OFFSET = 10000
    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

    all_boxes_with_idx = [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)] + [(i + OBJ_IDX_OFFSET , OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
    seen_idx = set()
    seen_sub_idx = set()

    # Phase 1: repeatedly take the unseen box closest to the top-left of
    # the remaining cluster and pair it with the nearest box of the
    # opposite kind.
    while N > len(seen_sub_idx):
        candidates = []
        for idx, kind, x0, y0 in all_boxes_with_idx:
            if idx in seen_idx:
                continue
            candidates.append((idx, kind, x0, y0))

        if len(candidates) == 0:
            break
        left_x = min([v[2] for v in candidates])
        top_y = min([v[3] for v in candidates])

        candidates.sort(key=lambda x: (x[2]-left_x) ** 2 + (x[3] - top_y) ** 2)

        # Anchor on the candidate nearest the cluster's top-left corner,
        # then re-sort by distance to that anchor.
        fst_idx, fst_kind, left_x, top_y = candidates[0]
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y)**2)
        nxt = None

        # First candidate of the opposite kind (XOR of the kind bits is 1).
        for i in range(1, len(candidates)):
            if candidates[i][1] ^ fst_kind == 1:
                nxt = candidates[i]
                break
        if nxt is None:
            break

        if fst_kind == SUB_BIT_KIND:
            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
        else:
            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

        # Reject the pair if some other unseen subject is much (3x) closer
        # to this object; the anchor subject is then retired unmatched.
        pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
        nearest_dis = float('inf')
        for i in range(N):
            if i in seen_idx or i == sub_idx:continue
            nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))

        if pair_dis >= 3*nearest_dis:
            seen_idx.add(sub_idx)
            continue

        seen_idx.add(sub_idx)
        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
        seen_sub_idx.add(sub_idx)

        ret.append(
            {
                'sub_bbox': {
                    'bbox': subjects[sub_idx]['bbox'],
                    'score': subjects[sub_idx]['score'],
                },
                'obj_bboxes': [
                    {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
                ],
                'sub_idx': sub_idx,
            }
        )

    # Phase 2: attach every still-unmatched object to its nearest subject,
    # appending to an existing record or creating a new one.
    for i in range(len(objects)):
        j = i + OBJ_IDX_OFFSET
        if j in seen_idx:
            continue
        seen_idx.add(j)
        nearest_dis, nearest_sub_idx = float('inf'), -1
        for k in range(len(subjects)):
            dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
            if dis < nearest_dis:
                nearest_dis = dis
                nearest_sub_idx = k

        for k in range(len(subjects)):
            if k != nearest_sub_idx: continue
            if k in seen_sub_idx:
                for kk in range(len(ret)):
                    if ret[kk]['sub_idx'] == k:
                        ret[kk]['obj_bboxes'].append({'score': objects[i]['score'], 'bbox': objects[i]['bbox']})
                        break
            else:
                ret.append(
                    {
                        'sub_bbox': {
                            'bbox': subjects[k]['bbox'],
                            'score': subjects[k]['score'],
                        },
                        'obj_bboxes': [
                            {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
                        ],
                        'sub_idx': k,
                    }
                )
            seen_sub_idx.add(k)
            seen_idx.add(k)

    # Phase 3: emit subjects that never received any object.
    for i in range(len(subjects)):
        if i in seen_sub_idx:
            continue
        ret.append(
            {
                'sub_bbox': {
                    'bbox': subjects[i]['bbox'],
                    'score': subjects[i]['score'],
                },
                'obj_bboxes': [],
                'sub_idx': i,
            }
        )

    return ret
|
373 |
+
|
374 |
+
def get_imgs(self):
    """Group each image body with its caption list and footnote list."""
    caption_pairs = self.__tie_up_category_by_distance_v3(
        CategoryId.ImageBody, CategoryId.ImageCaption
    )
    footnote_pairs = self.__tie_up_category_by_distance_v3(
        CategoryId.ImageBody, CategoryId.ImageFootnote
    )
    results = []
    for pair in caption_pairs:
        sub_idx = pair['sub_idx']
        # Both pairings share the same subject list, so the matching
        # footnote record for this subject index always exists.
        footnote_match = next(
            item for item in footnote_pairs if item['sub_idx'] == sub_idx
        )
        results.append({
            'image_body': pair['sub_bbox'],
            'image_caption_list': pair['obj_bboxes'],
            'image_footnote_list': footnote_match['obj_bboxes'],
        })
    return results
|
392 |
+
|
393 |
+
def get_tables(self) -> list:
    """Group each table body with its caption list and footnote list."""
    caption_pairs = self.__tie_up_category_by_distance_v3(
        CategoryId.TableBody, CategoryId.TableCaption
    )
    footnote_pairs = self.__tie_up_category_by_distance_v3(
        CategoryId.TableBody, CategoryId.TableFootnote
    )
    results = []
    for pair in caption_pairs:
        sub_idx = pair['sub_idx']
        # Subject lists are identical across the two pairings, so a
        # footnote record with this sub_idx is guaranteed to exist.
        footnote_match = next(
            item for item in footnote_pairs if item['sub_idx'] == sub_idx
        )
        results.append({
            'table_body': pair['sub_bbox'],
            'table_caption_list': pair['obj_bboxes'],
            'table_footnote_list': footnote_match['obj_bboxes'],
        })
    return results
|
411 |
+
|
412 |
+
def get_equations(self) -> tuple[list, list, list]:
    """Return (inline, interline, interline-layout) equation blocks.

    The first two carry a 'latex' field in addition to bbox/score.
    """
    inline = self.__get_blocks_by_type(CategoryId.InlineEquation, ['latex'])
    interline = self.__get_blocks_by_type(CategoryId.InterlineEquation_YOLO, ['latex'])
    interline_layout = self.__get_blocks_by_type(CategoryId.InterlineEquation_Layout)
    return inline, interline, interline_layout
|
423 |
+
|
424 |
+
def get_discarded(self) -> list:
    """Blocks the layout model marked as abandoned (bbox/score only)."""
    return self.__get_blocks_by_type(CategoryId.Abandon)
|
427 |
+
|
428 |
+
def get_text_blocks(self) -> list:
    """Plain-text layout blocks (coordinates only, no recognised text)."""
    return self.__get_blocks_by_type(CategoryId.Text)
|
431 |
+
|
432 |
+
def get_title_blocks(self) -> list:
    """Title layout blocks (coordinates only, no recognised text)."""
    return self.__get_blocks_by_type(CategoryId.Title)
|
435 |
+
|
436 |
+
def get_all_spans(self) -> list:
    """Collect every recognised span on the page (images, tables,
    equations, OCR text) as {'bbox', 'score', 'type', ...} dicts,
    de-duplicated while preserving first-seen order."""

    def _dedup(spans):
        # O(n^2) dict-equality de-dup, keeping the first occurrence.
        unique = []
        for candidate in spans:
            if all(candidate != kept for kept in unique):
                unique.append(candidate)
        return unique

    wanted_categories = [
        CategoryId.ImageBody,
        CategoryId.TableBody,
        CategoryId.InlineEquation,
        CategoryId.InterlineEquation_YOLO,
        CategoryId.OcrText,
    ]
    spans = []
    # Each matching layout detection becomes one span.
    for det in self.__page_model_info['layout_dets']:
        cid = det['category_id']
        if cid not in wanted_categories:
            continue
        span = {'bbox': det['bbox'], 'score': det['score']}
        if cid == CategoryId.ImageBody:
            span['type'] = ContentType.IMAGE
        elif cid == CategoryId.TableBody:
            # Attach the table-recognition result; latex takes priority
            # over html when both are present.
            latex = det.get('latex', None)
            html = det.get('html', None)
            if latex:
                span['latex'] = latex
            elif html:
                span['html'] = html
            span['type'] = ContentType.TABLE
        elif cid == CategoryId.InlineEquation:
            span['content'] = det['latex']
            span['type'] = ContentType.INLINE_EQUATION
        elif cid == CategoryId.InterlineEquation_YOLO:
            span['content'] = det['latex']
            span['type'] = ContentType.INTERLINE_EQUATION
        elif cid == CategoryId.OcrText:
            span['content'] = det['text']
            span['type'] = ContentType.TEXT
        spans.append(span)
    return _dedup(spans)
|
481 |
+
|
482 |
+
def __get_blocks_by_type(
    self, category_type: int, extra_col=None
) -> list:
    """Return {'bbox', 'score'} dicts (plus any *extra_col* fields) for
    every layout detection of the given category id."""
    cols = [] if extra_col is None else extra_col
    matched = []
    for det in self.__page_model_info.get('layout_dets', []):
        if det.get('category_id', -1) != category_type:
            continue
        block = {
            'bbox': det.get('bbox', None),
            'score': det.get('score'),
        }
        # Optional extra fields (e.g. 'latex'); missing values become None.
        block.update({col: det.get(col, None) for col in cols})
        matched.append(block)
    return matched
|
vendor/mineru/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from loguru import logger
|
3 |
+
|
4 |
+
from mineru.utils.config_reader import get_latex_delimiter_config
|
5 |
+
from mineru.backend.pipeline.para_split import ListLineTag
|
6 |
+
from mineru.utils.enum_class import BlockType, ContentType, MakeMode
|
7 |
+
from mineru.utils.language import detect_lang
|
8 |
+
|
9 |
+
|
10 |
+
def __is_hyphen_at_line_end(line):
    """Check whether *line* ends with letters followed by a hyphen.

    Args:
        line (str): The line of text to check.

    Returns:
        bool: True when the line ends in ``...letters-`` (optionally
        trailed by whitespace), i.e. a word broken across lines.
    """
    return re.search(r'[A-Za-z]+-\s*$', line) is not None
|
21 |
+
|
22 |
+
|
23 |
+
def make_blocks_to_markdown(paras_of_layout,
                            mode,
                            img_buket_path='',
                            ):
    """Render one page's para_blocks into a list of markdown fragments.

    Args:
        paras_of_layout: block dicts for one page (text/title/equation/image/table).
        mode: MakeMode.MM_MD (markdown with media) or MakeMode.NLP_MD (text only).
        img_buket_path: prefix for stored image paths (currently unused here —
            presumably consumed by the image-link emission; see NOTEs below).

    Returns:
        list[str]: one markdown string per non-empty block.
    """
    page_markdown = []
    for para_block in paras_of_layout:
        para_text = ''
        para_type = para_block['type']
        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
            para_text = merge_para_with_text(para_block)
        elif para_type == BlockType.TITLE:
            # Heading depth is clamped by get_title_level (1..4).
            title_level = get_title_level(para_block)
            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
        elif para_type == BlockType.INTERLINE_EQUATION:
            if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
                continue
            if para_block['lines'][0]['spans'][0].get('content', ''):
                para_text = merge_para_with_text(para_block)
            else:
                # NOTE(review): empty f-string — upstream likely emitted a
                # markdown image link to the equation crop here; the link
                # text appears lost in extraction. Confirm against upstream.
                para_text += f""
        elif para_type == BlockType.IMAGE:
            if mode == MakeMode.NLP_MD:
                continue  # text-only mode drops images entirely
            elif mode == MakeMode.MM_MD:
                # Does this image group carry a footnote block?
                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
                # With a footnote: order is caption, body, footnote.
                if has_image_footnote:
                    for block in para_block['blocks']:  # 1st: image captions
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += merge_para_with_text(block) + ' \n'
                    for block in para_block['blocks']:  # 2nd: image body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            # NOTE(review): empty f-string — looks like a
                                            # markdown image link lost in extraction.
                                            para_text += f""
                    for block in para_block['blocks']:  # 3rd: image footnotes
                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
                            para_text += ' \n' + merge_para_with_text(block)
                else:
                    # Without a footnote: body first, then caption.
                    for block in para_block['blocks']:  # 1st: image body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            # NOTE(review): empty f-string — see note above.
                                            para_text += f""
                    for block in para_block['blocks']:  # 2nd: image captions
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += ' \n' + merge_para_with_text(block)
        elif para_type == BlockType.TABLE:
            if mode == MakeMode.NLP_MD:
                continue
            elif mode == MakeMode.MM_MD:
                for block in para_block['blocks']:  # 1st: table captions
                    if block['type'] == BlockType.TABLE_CAPTION:
                        para_text += merge_para_with_text(block) + ' \n'
                for block in para_block['blocks']:  # 2nd: table body
                    if block['type'] == BlockType.TABLE_BODY:
                        for line in block['lines']:
                            for span in line['spans']:
                                if span['type'] == ContentType.TABLE:
                                    # if processed by table model
                                    if span.get('html', ''):
                                        para_text += f"\n{span['html']}\n"
                                    elif span.get('image_path', ''):
                                        # NOTE(review): empty f-string — fallback image
                                        # link apparently lost in extraction.
                                        para_text += f""
                for block in para_block['blocks']:  # 3rd: table footnotes
                    if block['type'] == BlockType.TABLE_FOOTNOTE:
                        para_text += '\n' + merge_para_with_text(block) + ' '

        if para_text.strip() == '':
            continue  # skip blocks that rendered to nothing
        else:
            page_markdown.append(para_text.strip())

    return page_markdown
|
103 |
+
|
104 |
+
|
105 |
+
def full_to_half(text: str) -> str:
    """Convert full-width ASCII letters and digits to half-width.

    Args:
        text: String possibly containing full-width characters.

    Returns:
        The string with full-width A-Z / a-z / 0-9 shifted into the
        ASCII range; every other character is left untouched.
    """
    out = []
    for ch in text:
        cp = ord(ch)
        fullwidth_alnum = (
            0xFF21 <= cp <= 0xFF3A      # full-width A-Z
            or 0xFF41 <= cp <= 0xFF5A   # full-width a-z
            or 0xFF10 <= cp <= 0xFF19   # full-width 0-9
        )
        # 0xFEE0 is the fixed offset between the full- and half-width blocks.
        out.append(chr(cp - 0xFEE0) if fullwidth_alnum else ch)
    return ''.join(out)
|
123 |
+
|
124 |
+
# Resolve the LaTeX delimiter configuration once at import time.
latex_delimiters_config = get_latex_delimiter_config()

# Fallback when no user configuration is present: standard $ / $$ math.
default_delimiters = {
    'display': {'left': '$$', 'right': '$$'},
    'inline': {'left': '$', 'right': '$'}
}

delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters

# Cached module-level delimiters, consumed by merge_para_with_text when
# wrapping inline and display formulas.
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
|
137 |
+
|
138 |
+
def merge_para_with_text(para_block):
    """Join a block's spans into one text string with language-aware spacing.

    Text spans are width-normalised and markdown-escaped; equation spans
    are wrapped in the configured LaTeX delimiters. CJK text is joined
    without separators, western text with single spaces, and a trailing
    hyphenated word at a line end is de-hyphenated.
    """
    # First pass: concatenate raw text (normalised to half-width, mutating
    # the spans in place) so the block language can be detected.
    block_text = ''
    for line in para_block['lines']:
        for span in line['spans']:
            if span['type'] in [ContentType.TEXT]:
                span['content'] = full_to_half(span['content'])
                block_text += span['content']
    block_lang = detect_lang(block_text)

    para_text = ''
    for i, line in enumerate(para_block['lines']):

        # A list item starting mid-block becomes a markdown hard break.
        if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
            para_text += ' \n'

        for j, span in enumerate(line['spans']):

            span_type = span['type']
            content = ''
            if span_type == ContentType.TEXT:
                content = escape_special_markdown_char(span['content'])
            elif span_type == ContentType.INLINE_EQUATION:
                if span.get('content', ''):
                    content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
            elif span_type == ContentType.INTERLINE_EQUATION:
                if span.get('content', ''):
                    content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"

            content = content.strip()

            if content:
                langs = ['zh', 'ja', 'ko']
                # In Chinese/Japanese/Korean context line breaks need no space
                # separator — unless the line ends with an inline formula,
                # which still gets a trailing space.
                if block_lang in langs:
                    if j == len(line['spans']) - 1 and span_type not in [ContentType.INLINE_EQUATION]:
                        para_text += content
                    else:
                        para_text += f'{content} '
                else:
                    if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
                        # If this is the line's last text span and it ends with a
                        # hyphenated word break, drop the hyphen and add no space.
                        if j == len(line['spans'])-1 and span_type == ContentType.TEXT and __is_hyphen_at_line_end(content):
                            para_text += content[:-1]
                        else:  # western text: spans are space-separated
                            para_text += f'{content} '
                    elif span_type == ContentType.INTERLINE_EQUATION:
                        para_text += content
                    else:
                        continue

    return para_text
|
189 |
+
|
190 |
+
|
191 |
+
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one para_block into a content-list entry dict.

    Args:
        para_block: a merged block (text/title/equation/image/table).
        img_buket_path: prefix joined onto stored image paths.
        page_idx: zero-based page index stamped onto the entry.

    Returns:
        dict entry, or None for an empty interline-equation block.
    """
    para_type = para_block['type']
    para_content = {}
    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.TITLE:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
        # text_level 0 (from a sub-1 input level) means "no heading level".
        title_level = get_title_level(para_block)
        if title_level != 0:
            para_content['text_level'] = title_level
    elif para_type == BlockType.INTERLINE_EQUATION:
        if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
            return None
        para_content = {
            'type': ContentType.EQUATION,
            'img_path': f"{img_buket_path}/{para_block['lines'][0]['spans'][0].get('image_path', '')}",
        }
        # Attach the LaTeX source when the span carries recognised content.
        if para_block['lines'][0]['spans'][0].get('content', ''):
            para_content['text'] = merge_para_with_text(para_block)
            para_content['text_format'] = 'latex'
    elif para_type == BlockType.IMAGE:
        para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.IMAGE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.IMAGE:
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.IMAGE_CAPTION:
                para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
            if block['type'] == BlockType.IMAGE_FOOTNOTE:
                para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
    elif para_type == BlockType.TABLE:
        para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.TABLE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.TABLE:
                            # NOTE(review): TABLE_BODY key is not pre-initialised
                            # above, so it only exists when html is present —
                            # consumers must use .get(); confirm this is intended.
                            if span.get('html', ''):
                                para_content[BlockType.TABLE_BODY] = f"{span['html']}"

                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"

            if block['type'] == BlockType.TABLE_CAPTION:
                para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
            if block['type'] == BlockType.TABLE_FOOTNOTE:
                para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))

    para_content['page_idx'] = page_idx

    return para_content
|
251 |
+
|
252 |
+
|
253 |
+
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Assemble per-page parse results into markdown or a content list.

    Args:
        pdf_info_dict: list of page-info dicts carrying 'para_blocks'.
        make_mode: MakeMode.MM_MD / NLP_MD for markdown, CONTENT_LIST for dicts.
        img_buket_path: prefix for image paths embedded in the output.

    Returns:
        Joined markdown string, list of content entries, or None for an
        unsupported mode (logged as an error).
    """
    markdown_modes = (MakeMode.MM_MD, MakeMode.NLP_MD)
    collected = []
    for page_info in pdf_info_dict:
        blocks = page_info.get('para_blocks')
        page_idx = page_info.get('page_idx')
        if not blocks:
            continue  # pages without merged blocks contribute nothing
        if make_mode in markdown_modes:
            collected.extend(
                make_blocks_to_markdown(blocks, make_mode, img_buket_path)
            )
        elif make_mode == MakeMode.CONTENT_LIST:
            for block in blocks:
                entry = make_blocks_to_content_list(block, img_buket_path, page_idx)
                if entry:
                    collected.append(entry)

    if make_mode in markdown_modes:
        return '\n\n'.join(collected)
    if make_mode == MakeMode.CONTENT_LIST:
        return collected
    logger.error(f"Unsupported make mode: {make_mode}")
    return None
|
279 |
+
|
280 |
+
|
281 |
+
def get_title_level(block):
    """Clamp a title block's heading level.

    Levels above 4 collapse to 4; levels below 1 map to 0 (meaning the
    caller should emit no explicit heading level). Missing level -> 1.
    """
    level = block.get('level', 1)
    if level > 4:
        return 4
    if level < 1:
        return 0
    return level
|
288 |
+
|
289 |
+
|
290 |
+
def escape_special_markdown_char(content):
    """Backslash-escape characters that carry markdown/LaTeX meaning
    (``*``, `` ` ``, ``~``, ``$``) in body text."""
    escape_table = str.maketrans({ch: f"\\{ch}" for ch in "*`~$"})
    return content.translate(escape_table)
|
vendor/mineru/mineru/backend/vlm/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/backend/vlm/base_predictor.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from typing import AsyncIterable, Iterable, List, Optional, Union
|
4 |
+
|
5 |
+
# Default chat-template text and sampling hyper-parameters shared by every
# predictor backend; per-call arguments override these.
DEFAULT_SYSTEM_PROMPT = (
    "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
)
DEFAULT_USER_PROMPT = "Document Parsing:"
DEFAULT_TEMPERATURE = 0.0  # 0.0 => effectively greedy decoding
DEFAULT_TOP_P = 0.8
DEFAULT_TOP_K = 20
DEFAULT_REPETITION_PENALTY = 1.0  # 1.0 disables the penalty
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
DEFAULT_MAX_NEW_TOKENS = 16384
|
16 |
+
|
17 |
+
|
18 |
+
class BasePredictor(ABC):
    """Abstract interface for VLM document-parsing predictors.

    Subclasses implement the three synchronous entry points
    (``predict`` / ``batch_predict`` / ``stream_predict``); the async
    variants defined here delegate to them via worker threads.
    """

    # Subclasses may override the system prompt used by build_prompt().
    system_prompt = DEFAULT_SYSTEM_PROMPT

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        """Store sampling defaults, used when a call omits the parameter."""
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        self.presence_penalty = presence_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.max_new_tokens = max_new_tokens

    @abstractmethod
    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str: ...
    # Single-image inference returning the full decoded output string.

    @abstractmethod
    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]: ...
    # Multi-image inference; one output string per input image.

    @abstractmethod
    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]: ...
    # Single-image inference yielding output incrementally as chunks.

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Async wrapper: run the blocking predict() in a worker thread."""
        return await asyncio.to_thread(
            self.predict,
            image,
            prompt,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Async wrapper: run the blocking batch_predict() in a worker thread."""
        return await asyncio.to_thread(
            self.batch_predict,
            images,
            prompts,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async wrapper around stream_predict().

        The blocking generator runs in a worker thread and hands each chunk
        back to the event loop through an asyncio.Queue; a None sentinel
        marks the end of the stream.
        """
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def synced_predict():
            # Executes in the worker thread; run_coroutine_threadsafe is the
            # thread-safe way to put items onto the loop-owned queue.
            for chunk in self.stream_predict(
                image=image,
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
            ):
                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)  # end-of-stream sentinel

        asyncio.create_task(
            asyncio.to_thread(synced_predict),
        )

        while True:
            chunk = await queue.get()
            if chunk is None:
                return
            assert isinstance(chunk, str)
            yield chunk

    def build_prompt(self, prompt: str) -> str:
        """Wrap a user prompt in the chat template.

        Prompts already starting with the template marker pass through
        unchanged; an empty prompt falls back to DEFAULT_USER_PROMPT.
        """
        if prompt.startswith("<|im_start|>"):
            return prompt
        if not prompt:
            prompt = DEFAULT_USER_PROMPT

        return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
        # Modify here. We add <|box_start|> at the end of the prompt to force the model to generate bounding box.
        # if "Document OCR" in prompt:
        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
        # else:
        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"

    def close(self):
        """Release backend resources; the default implementation is a no-op."""
        pass
|
vendor/mineru/mineru/backend/vlm/hf_predictor.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
from typing import Iterable, List, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig
|
8 |
+
|
9 |
+
from ...model.vlm_hf_model import Mineru2QwenForCausalLM
|
10 |
+
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
|
11 |
+
from .base_predictor import (
|
12 |
+
DEFAULT_MAX_NEW_TOKENS,
|
13 |
+
DEFAULT_NO_REPEAT_NGRAM_SIZE,
|
14 |
+
DEFAULT_PRESENCE_PENALTY,
|
15 |
+
DEFAULT_REPETITION_PENALTY,
|
16 |
+
DEFAULT_TEMPERATURE,
|
17 |
+
DEFAULT_TOP_K,
|
18 |
+
DEFAULT_TOP_P,
|
19 |
+
BasePredictor,
|
20 |
+
)
|
21 |
+
from .utils import load_resource
|
22 |
+
|
23 |
+
|
24 |
+
class HuggingfacePredictor(BasePredictor):
    """Run the MinerU VLM locally via HuggingFace ``transformers``.

    Loads ``Mineru2QwenForCausalLM`` (plus its vision tower, tokenizer and
    image processor) from a checkpoint and implements the ``BasePredictor``
    interface. ``batch_predict`` is sequential; streaming is not supported.
    """

    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
        """Load tokenizer, model and image processor from *model_path*.

        Sampling defaults are stored on the instance by the base class;
        remaining ``**kwargs`` are forwarded to ``from_pretrained``.
        """
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        kwargs = {"device_map": device_map, **kwargs}

        # An explicit non-CUDA device overrides device_map so the whole model
        # is placed on that single device.
        if device != "cuda":
            kwargs["device_map"] = {"": device}

        # 8-bit takes precedence over 4-bit; otherwise load at the requested
        # (or auto-detected) dtype.
        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            kwargs["torch_dtype"] = torch_dtype

        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        setattr(self.model.config, "_name_or_path", model_path)
        self.model.eval()

        vision_tower = self.model.get_model().vision_tower
        # With an explicit device_map the vision tower is not auto-dispatched,
        # so move it onto the target device/dtype by hand.
        if device_map != "auto":
            vision_tower.to(device=device_map, dtype=self.model.dtype)

        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # accepted for interface parity; not supported by HF generate
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Run one image + prompt through the model and return the raw text.

        *image* may be raw bytes or a resource locator resolved by
        ``load_resource``. ``None`` sampling arguments fall back to the
        instance defaults set in ``__init__``.
        """
        prompt = self.build_prompt(prompt)

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # Sampling only makes sense with a positive temperature and more than
        # one candidate token; otherwise generation is greedy.
        do_sample = (temperature > 0.0) and (top_k > 1)

        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        # temperature/top_p/top_k are only valid when sampling is enabled.
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        if isinstance(image, str):
            image = load_resource(image)

        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        # Original (width, height) of the image, as expected by the model.
        image_sizes = [[*image_obj.size]]

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        # Special tokens are kept because downstream parsing relies on them.
        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()

        return output

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        """Sequentially run ``predict`` over *images*, returning one string each.

        A scalar *prompts* value is broadcast to every image.
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Streaming generation is not implemented for the HF backend."""
        raise NotImplementedError("Streaming is not supported yet.")
|
vendor/mineru/mineru/backend/vlm/predictor.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
|
3 |
+
import time
|
4 |
+
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from .base_predictor import (
|
8 |
+
DEFAULT_MAX_NEW_TOKENS,
|
9 |
+
DEFAULT_NO_REPEAT_NGRAM_SIZE,
|
10 |
+
DEFAULT_PRESENCE_PENALTY,
|
11 |
+
DEFAULT_REPETITION_PENALTY,
|
12 |
+
DEFAULT_TEMPERATURE,
|
13 |
+
DEFAULT_TOP_K,
|
14 |
+
DEFAULT_TOP_P,
|
15 |
+
BasePredictor,
|
16 |
+
)
|
17 |
+
from .sglang_client_predictor import SglangClientPredictor
|
18 |
+
|
19 |
+
# Both inference backends are optional dependencies. Each import is attempted
# once at module load; a flag records whether the backend is usable so that
# get_predictor() can raise a precise error instead of an ImportError later.
hf_loaded = False
try:
    from .hf_predictor import HuggingfacePredictor

    hf_loaded = True
except ImportError as e:
    logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")

engine_loaded = False
try:
    # sglang import failures are not always ImportError (it may raise at
    # initialization), hence the broad Exception here.
    from sglang.srt.server_args import ServerArgs

    from .sglang_engine_predictor import SglangEnginePredictor

    engine_loaded = True
except Exception as e:
    logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
|
36 |
+
|
37 |
+
|
38 |
+
def get_predictor(
    backend: str = "sglang-client",
    model_path: str | None = None,
    server_url: str | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    top_k: int = DEFAULT_TOP_K,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    http_timeout: int = 600,
    **kwargs,
) -> BasePredictor:
    """Build a VLM predictor for the requested *backend*.

    Supported backends: ``transformers`` (local HF model, needs *model_path*),
    ``sglang-engine`` (in-process sglang engine, needs *model_path*) and
    ``sglang-client`` (remote sglang server, needs *server_url*).

    Raises:
        ValueError: when a required argument is missing or *backend* is unknown.
        ImportError: when the optional dependency for *backend* is absent.
    """
    started_at = time.time()

    # Sampling configuration is identical for every backend.
    sampling_kwargs = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        presence_penalty=presence_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
    )

    if backend == "transformers":
        if not model_path:
            raise ValueError("model_path must be provided for transformers backend.")
        if not hf_loaded:
            raise ImportError(
                "transformers is not installed, so huggingface backend cannot be used. "
                "If you need to use huggingface backend, please install transformers first."
            )
        predictor = HuggingfacePredictor(model_path=model_path, **sampling_kwargs, **kwargs)
    elif backend == "sglang-engine":
        if not model_path:
            raise ValueError("model_path must be provided for sglang-engine backend.")
        if not engine_loaded:
            raise ImportError(
                "sglang is not installed, so sglang-engine backend cannot be used. "
                "If you need to use sglang-engine backend for inference, "
                "please install sglang[all]==0.4.8 or a newer version."
            )
        predictor = SglangEnginePredictor(
            server_args=ServerArgs(model_path, **kwargs),
            **sampling_kwargs,
        )
    elif backend == "sglang-client":
        if not server_url:
            raise ValueError("server_url must be provided for sglang-client backend.")
        predictor = SglangClientPredictor(
            server_url=server_url,
            http_timeout=http_timeout,
            **sampling_kwargs,
        )
    else:
        raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")

    elapsed = round(time.time() - started_at, 2)
    logger.info(f"get_predictor cost: {elapsed}s")
    return predictor
|
vendor/mineru/mineru/backend/vlm/sglang_client_predictor.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from base64 import b64encode
|
5 |
+
from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
|
6 |
+
|
7 |
+
import httpx
|
8 |
+
|
9 |
+
from .base_predictor import (
|
10 |
+
DEFAULT_MAX_NEW_TOKENS,
|
11 |
+
DEFAULT_NO_REPEAT_NGRAM_SIZE,
|
12 |
+
DEFAULT_PRESENCE_PENALTY,
|
13 |
+
DEFAULT_REPETITION_PENALTY,
|
14 |
+
DEFAULT_TEMPERATURE,
|
15 |
+
DEFAULT_TOP_K,
|
16 |
+
DEFAULT_TOP_P,
|
17 |
+
BasePredictor,
|
18 |
+
)
|
19 |
+
from .utils import aio_load_resource, load_resource
|
20 |
+
|
21 |
+
|
22 |
+
class SglangClientPredictor(BasePredictor):
    """Predictor that talks to a running sglang server over HTTP.

    On construction it verifies the server is healthy, resolves the served
    model path, and then issues ``/generate`` requests (sync, async, batched
    and streaming variants) using ``httpx``.
    """

    def __init__(
        self,
        server_url: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        http_timeout: int = 600,
    ) -> None:
        """Store sampling defaults and validate the target server.

        Raises:
            ValueError: if *server_url* is not an http(s) URL.
            RuntimeError: if the server is unreachable or unhealthy.
        """
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.http_timeout = http_timeout

        base_url = self.get_base_url(server_url)
        self.check_server_health(base_url)
        self.model_path = self.get_model_path(base_url)
        self.server_url = f"{base_url}/generate"

    @staticmethod
    def get_base_url(server_url: str) -> str:
        """Return the scheme+host[:port] prefix of *server_url*."""
        matched = re.match(r"^(https?://[^/]+)", server_url)
        if not matched:
            raise ValueError(f"Invalid server URL: {server_url}")
        return matched.group(1)

    def check_server_health(self, base_url: str):
        """Probe ``/health_generate``; raise RuntimeError when unhealthy."""
        try:
            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
            )

    def get_model_path(self, base_url: str) -> str:
        """Ask the server which model it serves via ``/get_model_info``."""
        try:
            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
            )
        return response.json()["model_path"]

    def build_sampling_params(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        repetition_penalty: Optional[float],
        presence_penalty: Optional[float],
        no_repeat_ngram_size: Optional[int],
        max_new_tokens: Optional[int],
    ) -> dict:
        """Merge per-call overrides with instance defaults into the sglang
        ``sampling_params`` payload. ``None`` means "use the default"."""
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # see SamplingParams for more details
        return {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            # no_repeat_ngram_size is a custom extension, not a native
            # sglang sampling parameter.
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            # Special tokens are kept because downstream parsing relies on them.
            "skip_special_tokens": False,
        }

    def build_request_body(
        self,
        image: bytes,
        prompt: str,
        sampling_params: dict,
    ) -> dict:
        """Build the JSON body for a ``/generate`` call (image as base64)."""
        image_base64 = b64encode(image).decode("utf-8")
        return {
            "text": prompt,
            "image_data": image_base64,
            "sampling_params": sampling_params,
            "modalities": ["image"],
        }

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Synchronously generate text for one image + prompt."""
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        # A string image is a locator (file/url); resolve it to bytes.
        if isinstance(image, str):
            image = load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)
        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
        response_body = response.json()
        return response_body["text"]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        """Synchronous wrapper around :meth:`aio_batch_predict`.

        Works both from plain synchronous code and from inside a running
        event loop (e.g. a Jupyter notebook or an async web framework).
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        task = self.aio_batch_predict(
            images=images,
            prompts=prompts,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            max_concurrency=max_concurrency,
        )

        if loop is not None:
            # BUGFIX: the previous implementation called
            # loop.run_until_complete(task) here, which always raises
            # "RuntimeError: This event loop is already running" when a loop
            # is active. Instead, run the coroutine to completion on a
            # dedicated worker thread that owns its own event loop.
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as pool:
                return pool.submit(asyncio.run, task).result()
        else:
            return asyncio.run(task)

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Yield generated text incrementally via the server's SSE stream."""
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True

        with httpx.stream(
            "POST",
            self.server_url,
            json=request_body,
            timeout=self.http_timeout,
        ) as response:
            # The server resends the full text each event; `pos` tracks how
            # much has already been yielded so only the delta is emitted.
            pos = 0
            for chunk in response.iter_lines():
                if not (chunk or "").startswith("data:"):
                    continue
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                # meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        async_client: Optional[httpx.AsyncClient] = None,
    ) -> str:
        """Async variant of :meth:`predict`.

        When *async_client* is given it is reused (and NOT closed here);
        otherwise a one-shot client is created for this call.
        """
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = await aio_load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)

        if async_client is None:
            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
                response = await client.post(self.server_url, json=request_body)
                response_body = response.json()
        else:
            response = await async_client.post(self.server_url, json=request_body)
            response_body = response.json()

        return response_body["text"]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        """Predict all *images* concurrently (bounded by *max_concurrency*),
        preserving input order in the returned list."""
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        semaphore = asyncio.Semaphore(max_concurrency)
        # Pre-sized so each task can write its slot regardless of finish order.
        outputs = [""] * len(images)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                outputs[idx] = output

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            tasks = []
            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                tasks.append(predict_with_semaphore(idx, image, prompt, client))
            await asyncio.gather(*tasks)

        return outputs

    async def aio_batch_predict_as_iter(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> AsyncIterable[Tuple[int, str]]:
        """Like :meth:`aio_batch_predict`, but yields ``(index, text)`` pairs
        as each prediction completes (completion order, not input order)."""
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        semaphore = asyncio.Semaphore(max_concurrency)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                return (idx, output)

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            pending: Set[asyncio.Task[Tuple[int, str]]] = set()

            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                pending.add(
                    asyncio.create_task(
                        predict_with_semaphore(idx, image, prompt, client),
                    )
                )

            # Drain tasks as they finish so results stream out promptly.
            while len(pending) > 0:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in done:
                    yield task.result()

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async variant of :meth:`stream_predict`."""
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = await aio_load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            async with client.stream(
                "POST",
                self.server_url,
                json=request_body,
            ) as response:
                pos = 0
                async for chunk in response.aiter_lines():
                    if not (chunk or "").startswith("data:"):
                        continue
                    if chunk == "data: [DONE]":
                        break
                    data = json.loads(chunk[5:].strip("\n"))
                    chunk_text = data["text"][pos:]
                    # meta_info = data["meta_info"]
                    pos += len(chunk_text)
                    yield chunk_text
|
vendor/mineru/mineru/backend/vlm/sglang_engine_predictor.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from base64 import b64encode
|
2 |
+
from typing import AsyncIterable, Iterable, List, Optional, Union
|
3 |
+
|
4 |
+
from sglang.srt.server_args import ServerArgs
|
5 |
+
|
6 |
+
from ...model.vlm_sglang_model.engine import BatchEngine
|
7 |
+
from .base_predictor import (
|
8 |
+
DEFAULT_MAX_NEW_TOKENS,
|
9 |
+
DEFAULT_NO_REPEAT_NGRAM_SIZE,
|
10 |
+
DEFAULT_PRESENCE_PENALTY,
|
11 |
+
DEFAULT_REPETITION_PENALTY,
|
12 |
+
DEFAULT_TEMPERATURE,
|
13 |
+
DEFAULT_TOP_K,
|
14 |
+
DEFAULT_TOP_P,
|
15 |
+
BasePredictor,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class SglangEnginePredictor(BasePredictor):
    """Predictor that runs the VLM in-process through a sglang ``BatchEngine``.

    Sampling defaults come from :class:`BasePredictor`; every predict method
    accepts per-call overrides. Streaming is not supported by this backend.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.engine = BatchEngine(server_args=server_args)

    def load_image_string(self, image: str | bytes) -> str:
        """Normalize an image argument into the string form sglang accepts.

        bytes -> base64 string; ``file://...`` -> plain filesystem path; any
        other string (path, URL, raw base64) is passed through unchanged.
        """
        if not isinstance(image, (str, bytes)):
            raise ValueError("Image must be a string or bytes.")
        if isinstance(image, bytes):
            return b64encode(image).decode("utf-8")
        if image.startswith("file://"):
            return image[len("file://") :]
        return image

    def _prepare_request(
        self,
        images,
        prompts,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        presence_penalty,
        no_repeat_ngram_size,
        max_new_tokens,
    ):
        """Shared setup for the sync and async batch-predict paths.

        Broadcasts a scalar prompt over the batch, substitutes instance
        defaults for unset sampling options, and encodes the images.

        Returns:
            Tuple of (built prompts, image strings, sampling_params dict).
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."
        prompts = [self.build_prompt(prompt) for prompt in prompts]

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # see SamplingParams for more details
        sampling_params = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }

        image_strings = [self.load_image_string(img) for img in images]
        return prompts, image_strings, sampling_params

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Predict on a single image by delegating to :meth:`batch_predict`."""
        return self.batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )[0]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Run a synchronous batch generation and return one text per image."""
        prompts, image_strings, sampling_params = self._prepare_request(
            images,
            prompts,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

        output = self.engine.generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Async single-image prediction via :meth:`aio_batch_predict`."""
        output = await self.aio_batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        return output[0]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Run an asynchronous batch generation and return one text per image."""
        prompts, image_strings, sampling_params = self._prepare_request(
            images,
            prompts,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

        output = await self.engine.async_generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]  # type: ignore

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        raise NotImplementedError("Streaming is not supported yet.")

    def close(self):
        """Shut down the underlying engine and release its resources."""
        self.engine.shutdown()
|
vendor/mineru/mineru/backend/vlm/token_to_middle_json.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
from mineru.backend.pipeline.model_init import AtomModelSingleton
|
7 |
+
from mineru.utils.config_reader import get_llm_aided_config
|
8 |
+
from mineru.utils.cut_image import cut_image_and_table
|
9 |
+
from mineru.utils.enum_class import ContentType
|
10 |
+
from mineru.utils.hash_utils import str_md5
|
11 |
+
from mineru.backend.vlm.vlm_magic_model import MagicModel
|
12 |
+
from mineru.utils.llm_aided import llm_aided_title
|
13 |
+
from mineru.utils.pdf_image_tools import get_crop_img
|
14 |
+
from mineru.version import __version__
|
15 |
+
|
16 |
+
|
17 |
+
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
    """Convert one page's VLM output token string into a middle-json page_info dict.

    Args:
        token: Raw model output for one page; per-block format is
            <|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>.
        image_dict: Rendered page image info with "scale", "img_pil" and "img_base64" keys.
        page: PDF page object; only page.get_size() is used here.
        image_writer: Writer used by cut_image_and_table to persist cropped images.
        page_index: Zero-based page number, recorded as "page_idx".

    Returns:
        Dict with "para_blocks", "discarded_blocks", "page_size" and "page_idx".
    """
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    width, height = map(int, page.get_size())

    # Parse the token string into typed layout blocks.
    magic_model = MagicModel(token, width, height)
    image_blocks = magic_model.get_image_blocks()
    table_blocks = magic_model.get_table_blocks()
    title_blocks = magic_model.get_title_blocks()

    # When LLM-aided title optimization is enabled, run OCR detection over each
    # title crop to estimate its average text-line height (used downstream).
    llm_aided_config = get_llm_aided_config()
    if llm_aided_config is not None:
        title_aided_config = llm_aided_config.get('title_aided', None)
        if title_aided_config is not None:
            if title_aided_config.get('enable', False):
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang='ch_lite'
                )
                for title_block in title_blocks:
                    title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
                    title_np_img = np.array(title_pil_img)
                    # Pad a 50px white border on every side before detection.
                    title_np_img = cv2.copyMakeBorder(
                        title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
                    )
                    title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
                    # Detection only (rec=False); result is a list of quad boxes.
                    ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
                    if len(ocr_det_res) > 0:
                        # Average detected line height, mapped back to PDF units.
                        avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
                        title_block['line_avg_height'] = round(avg_height/scale)

    text_blocks = magic_model.get_text_blocks()
    interline_equation_blocks = magic_model.get_interline_equation_blocks()

    all_spans = magic_model.get_all_spans()
    # Crop and persist images for image/table/interline-equation spans.
    # NOTE(review): cut_image_and_table appears to mutate the span it receives;
    # the reassignment to the loop variable has no further effect — confirm.
    for span in all_spans:
        if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
            span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)

    page_blocks = []
    page_blocks.extend([*image_blocks, *table_blocks, *title_blocks, *text_blocks, *interline_equation_blocks])
    # Restore reading order across all block categories.
    page_blocks.sort(key=lambda x: x["index"])

    page_info = {"para_blocks": page_blocks, "discarded_blocks": [], "page_size": [width, height], "page_idx": page_index}
    return page_info
|
77 |
+
|
78 |
+
|
79 |
+
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
    """Assemble per-page token outputs into a middle-json document.

    Builds one page_info per token, optionally runs LLM-aided title
    optimization, closes the PDF document, and returns the middle json.
    """
    pdf_info = []
    for page_index, page_token in enumerate(token_list):
        pdf_info.append(
            token_to_page_info(
                page_token,
                images_list[page_index],
                pdf_doc[page_index],
                image_writer,
                page_index,
            )
        )
    middle_json = {"pdf_info": pdf_info, "_backend": "vlm", "_version_name": __version__}

    # Optional LLM post-processing.
    config = get_llm_aided_config()
    if config is not None:
        # Title optimization pass.
        title_cfg = config.get('title_aided', None)
        if title_cfg is not None:
            if title_cfg.get('enable', False):
                started_at = time.time()
                llm_aided_title(middle_json["pdf_info"], title_cfg)
                logger.info(f'llm aided title time: {round(time.time() - started_at, 2)}')

    # Close the PDF document now that all pages are processed.
    pdf_doc.close()
    return middle_json
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == "__main__":
    # Demo: parse a small VLM output token string into spans.
    #
    # The previous demo called token_to_page_info(output) with a single
    # argument, which no longer matches that function's five-parameter
    # signature and raised a TypeError; parse with MagicModel directly
    # instead, which only needs the token string and a page size.
    output = (
        "<|box_start|>079 624 285 638<|box_end|>"
        "<|ref_start|>title<|ref_end|>"
        "<|md_start|># 2.2. Zero flow day analysis<|md_end|>\n"
        "<|box_start|>079 656 482 801<|box_end|>"
        "<|ref_start|>text<|ref_end|>"
        "<|md_start|>A notable feature of Fig. 1 is the increase in the "
        "number of zero flow days.<|md_end|><|im_end|>"
    )

    magic_model = MagicModel(output, 1000, 1000)

    # Dump the parsed spans as JSON text.
    import json

    json_str = json.dumps(magic_model.get_all_spans(), ensure_ascii=False, indent=4)
    print(json_str)
|
vendor/mineru/mineru/backend/vlm/utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from base64 import b64decode
|
4 |
+
|
5 |
+
import httpx
|
6 |
+
|
7 |
+
# HTTP request timeout in seconds for fetching remote resources;
# overridable through the REQUEST_TIMEOUT environment variable.
_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
# Extensions treated as local file paths when the URI has no scheme.
_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
# Matches the "data:<mime>;base64," prefix of a base64 data URI.
_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")
|
10 |
+
|
11 |
+
|
12 |
+
def load_resource(uri: str) -> bytes:
    """Fetch the raw bytes behind *uri*.

    Accepts an http(s) URL, a file:// URI, a bare local file path (recognized
    by extension), a base64 data URI, or a plain base64 string.
    """
    if uri.startswith(("http://", "https://")):
        return httpx.get(uri, timeout=_timeout).content
    if uri.startswith("file://"):
        local_path = uri[len("file://") :]
        with open(local_path, "rb") as fp:
            return fp.read()
    if uri.lower().endswith(_file_exts):
        with open(uri, "rb") as fp:
            return fp.read()
    if _data_uri_regex.match(uri):
        # data URI: decode only the payload after the comma.
        return b64decode(uri.split(",")[1])
    # Fall back to treating the whole string as base64.
    return b64decode(uri)
|
25 |
+
|
26 |
+
|
27 |
+
async def aio_load_resource(uri: str) -> bytes:
    """Async variant of load_resource.

    Same URI handling; only the http(s) branch actually awaits, using an
    httpx.AsyncClient. File and base64 branches run synchronously.
    """
    if uri.startswith(("http://", "https://")):
        async with httpx.AsyncClient(timeout=_timeout) as client:
            resp = await client.get(uri)
            return resp.content
    if uri.startswith("file://"):
        local_path = uri[len("file://") :]
        with open(local_path, "rb") as fp:
            return fp.read()
    if uri.lower().endswith(_file_exts):
        with open(uri, "rb") as fp:
            return fp.read()
    if _data_uri_regex.match(uri):
        # data URI: decode only the payload after the comma.
        return b64decode(uri.split(",")[1])
    # Fall back to treating the whole string as base64.
    return b64decode(uri)
|
vendor/mineru/mineru/backend/vlm/vlm_analyze.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
import time
|
3 |
+
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
from ...data.data_reader_writer import DataWriter
|
7 |
+
from mineru.utils.pdf_image_tools import load_images_from_pdf
|
8 |
+
from .base_predictor import BasePredictor
|
9 |
+
from .predictor import get_predictor
|
10 |
+
from .token_to_middle_json import result_to_middle_json
|
11 |
+
from ...utils.models_download_utils import auto_download_and_get_model_root_path
|
12 |
+
|
13 |
+
|
14 |
+
class ModelSingleton:
    """Process-wide cache of predictor instances.

    All constructions return the same instance; predictors are cached by
    (backend, model_path, server_url) so repeated calls reuse loaded models.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Create the single shared instance lazily on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> BasePredictor:
        """Return the cached predictor for this configuration, creating it on first use."""
        cache_key = (backend, model_path, server_url)
        if cache_key not in self._models:
            # Local backends need concrete weights; auto-download when missing.
            if backend in ['transformers', 'sglang-engine'] and not model_path:
                model_path = auto_download_and_get_model_root_path("/", "vlm")
            predictor = get_predictor(
                backend=backend,
                model_path=model_path,
                server_url=server_url,
                **kwargs,
            )
            self._models[cache_key] = predictor
        return self._models[cache_key]
|
41 |
+
|
42 |
+
|
43 |
+
def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    """Run VLM analysis over a PDF and build the middle-json document.

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        image_writer: Writer used to persist cropped images; may be None.
        predictor: Pre-built predictor; when None, one is obtained from
            ModelSingleton using the remaining arguments.
        backend: Predictor backend name (e.g. "transformers", "sglang-engine").
        model_path: Model weights location; resolved/downloaded when needed.
        server_url: Remote inference server URL for client backends.
        **kwargs: Extra options forwarded to the predictor factory.

    Returns:
        Tuple of (middle_json dict, raw per-page inference results).
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # Render each PDF page once; predictors consume base64-encoded images.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]

    results = predictor.batch_predict(images=images_base64_list)

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
|
68 |
+
|
69 |
+
|
70 |
+
async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    """Async counterpart of doc_analyze.

    Identical pipeline, but awaits the predictor's async batch inference.

    Args:
        pdf_bytes: Raw bytes of the PDF document.
        image_writer: Writer used to persist cropped images; may be None.
        predictor: Pre-built predictor; when None, one is obtained from
            ModelSingleton using the remaining arguments.
        backend: Predictor backend name (e.g. "transformers", "sglang-engine").
        model_path: Model weights location; resolved/downloaded when needed.
        server_url: Remote inference server URL for client backends.
        **kwargs: Extra options forwarded to the predictor factory.

    Returns:
        Tuple of (middle_json dict, raw per-page inference results).
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # Render each PDF page once; predictors consume base64-encoded images.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    images_base64_list = [image_dict["img_base64"] for image_dict in images_list]

    results = await predictor.aio_batch_predict(images=images_base64_list)

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
|
vendor/mineru/mineru/backend/vlm/vlm_magic_model.py
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import Literal
|
3 |
+
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
from mineru.utils.boxbase import bbox_distance, is_in
|
7 |
+
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
|
8 |
+
from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
|
9 |
+
from mineru.utils.format_utils import convert_otsl_to_html
|
10 |
+
|
11 |
+
|
12 |
+
class MagicModel:
|
13 |
+
def __init__(self, token: str, width, height):
|
14 |
+
self.token = token
|
15 |
+
|
16 |
+
# 使用正则表达式查找所有块
|
17 |
+
pattern = (
|
18 |
+
r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)"
|
19 |
+
)
|
20 |
+
block_infos = re.findall(pattern, token, re.DOTALL)
|
21 |
+
|
22 |
+
blocks = []
|
23 |
+
self.all_spans = []
|
24 |
+
# 解析每个块
|
25 |
+
for index, block_info in enumerate(block_infos):
|
26 |
+
block_bbox = block_info[0].strip()
|
27 |
+
try:
|
28 |
+
x1, y1, x2, y2 = map(int, block_bbox.split())
|
29 |
+
x_1, y_1, x_2, y_2 = (
|
30 |
+
int(x1 * width / 1000),
|
31 |
+
int(y1 * height / 1000),
|
32 |
+
int(x2 * width / 1000),
|
33 |
+
int(y2 * height / 1000),
|
34 |
+
)
|
35 |
+
if x_2 < x_1:
|
36 |
+
x_1, x_2 = x_2, x_1
|
37 |
+
if y_2 < y_1:
|
38 |
+
y_1, y_2 = y_2, y_1
|
39 |
+
block_bbox = (x_1, y_1, x_2, y_2)
|
40 |
+
block_type = block_info[1].strip()
|
41 |
+
block_content = block_info[2].strip()
|
42 |
+
|
43 |
+
# print(f"坐标: {block_bbox}")
|
44 |
+
# print(f"类型: {block_type}")
|
45 |
+
# print(f"内容: {block_content}")
|
46 |
+
# print("-" * 50)
|
47 |
+
except Exception as e:
|
48 |
+
# 如果解析失败,可能是因为格式不正确,跳过这个块
|
49 |
+
logger.warning(f"Invalid block format: {block_info}, error: {e}")
|
50 |
+
continue
|
51 |
+
|
52 |
+
span_type = "unknown"
|
53 |
+
if block_type in [
|
54 |
+
"text",
|
55 |
+
"title",
|
56 |
+
"image_caption",
|
57 |
+
"image_footnote",
|
58 |
+
"table_caption",
|
59 |
+
"table_footnote",
|
60 |
+
"list",
|
61 |
+
"index",
|
62 |
+
]:
|
63 |
+
span_type = ContentType.TEXT
|
64 |
+
elif block_type in ["image"]:
|
65 |
+
block_type = BlockType.IMAGE_BODY
|
66 |
+
span_type = ContentType.IMAGE
|
67 |
+
elif block_type in ["table"]:
|
68 |
+
block_type = BlockType.TABLE_BODY
|
69 |
+
span_type = ContentType.TABLE
|
70 |
+
elif block_type in ["equation"]:
|
71 |
+
block_type = BlockType.INTERLINE_EQUATION
|
72 |
+
span_type = ContentType.INTERLINE_EQUATION
|
73 |
+
|
74 |
+
if span_type in ["image", "table"]:
|
75 |
+
span = {
|
76 |
+
"bbox": block_bbox,
|
77 |
+
"type": span_type,
|
78 |
+
}
|
79 |
+
if span_type == ContentType.TABLE:
|
80 |
+
if "<fcel>" in block_content or "<ecel>" in block_content:
|
81 |
+
lines = block_content.split("\n\n")
|
82 |
+
new_lines = []
|
83 |
+
for line in lines:
|
84 |
+
if "<fcel>" in line or "<ecel>" in line:
|
85 |
+
line = convert_otsl_to_html(line)
|
86 |
+
new_lines.append(line)
|
87 |
+
span["html"] = "\n\n".join(new_lines)
|
88 |
+
else:
|
89 |
+
span["html"] = block_content
|
90 |
+
elif span_type in [ContentType.INTERLINE_EQUATION]:
|
91 |
+
span = {
|
92 |
+
"bbox": block_bbox,
|
93 |
+
"type": span_type,
|
94 |
+
"content": isolated_formula_clean(block_content),
|
95 |
+
}
|
96 |
+
else:
|
97 |
+
if block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
|
98 |
+
# 生成包含文本和公式的span列表
|
99 |
+
spans = []
|
100 |
+
last_end = 0
|
101 |
+
|
102 |
+
# 查找所有公式
|
103 |
+
for match in re.finditer(r'\\\((.+?)\\\)', block_content):
|
104 |
+
start, end = match.span()
|
105 |
+
|
106 |
+
# 添加公式前的文本
|
107 |
+
if start > last_end:
|
108 |
+
text_before = block_content[last_end:start]
|
109 |
+
if text_before.strip():
|
110 |
+
spans.append({
|
111 |
+
"bbox": block_bbox,
|
112 |
+
"type": ContentType.TEXT,
|
113 |
+
"content": text_before
|
114 |
+
})
|
115 |
+
|
116 |
+
# 添加公式(去除\(和\))
|
117 |
+
formula = match.group(1)
|
118 |
+
spans.append({
|
119 |
+
"bbox": block_bbox,
|
120 |
+
"type": ContentType.INLINE_EQUATION,
|
121 |
+
"content": formula.strip()
|
122 |
+
})
|
123 |
+
|
124 |
+
last_end = end
|
125 |
+
|
126 |
+
# 添加最后一个公式后的文本
|
127 |
+
if last_end < len(block_content):
|
128 |
+
text_after = block_content[last_end:]
|
129 |
+
if text_after.strip():
|
130 |
+
spans.append({
|
131 |
+
"bbox": block_bbox,
|
132 |
+
"type": ContentType.TEXT,
|
133 |
+
"content": text_after
|
134 |
+
})
|
135 |
+
|
136 |
+
span = spans
|
137 |
+
else:
|
138 |
+
span = {
|
139 |
+
"bbox": block_bbox,
|
140 |
+
"type": span_type,
|
141 |
+
"content": block_content,
|
142 |
+
}
|
143 |
+
|
144 |
+
if isinstance(span, dict) and "bbox" in span:
|
145 |
+
self.all_spans.append(span)
|
146 |
+
line = {
|
147 |
+
"bbox": block_bbox,
|
148 |
+
"spans": [span],
|
149 |
+
}
|
150 |
+
elif isinstance(span, list):
|
151 |
+
self.all_spans.extend(span)
|
152 |
+
line = {
|
153 |
+
"bbox": block_bbox,
|
154 |
+
"spans": span,
|
155 |
+
}
|
156 |
+
else:
|
157 |
+
raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")
|
158 |
+
|
159 |
+
blocks.append(
|
160 |
+
{
|
161 |
+
"bbox": block_bbox,
|
162 |
+
"type": block_type,
|
163 |
+
"lines": [line],
|
164 |
+
"index": index,
|
165 |
+
}
|
166 |
+
)
|
167 |
+
|
168 |
+
self.image_blocks = []
|
169 |
+
self.table_blocks = []
|
170 |
+
self.interline_equation_blocks = []
|
171 |
+
self.text_blocks = []
|
172 |
+
self.title_blocks = []
|
173 |
+
for block in blocks:
|
174 |
+
if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
|
175 |
+
self.image_blocks.append(block)
|
176 |
+
elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
|
177 |
+
self.table_blocks.append(block)
|
178 |
+
elif block["type"] == BlockType.INTERLINE_EQUATION:
|
179 |
+
self.interline_equation_blocks.append(block)
|
180 |
+
elif block["type"] == BlockType.TEXT:
|
181 |
+
self.text_blocks.append(block)
|
182 |
+
elif block["type"] == BlockType.TITLE:
|
183 |
+
self.title_blocks.append(block)
|
184 |
+
else:
|
185 |
+
continue
|
186 |
+
|
187 |
+
    def get_image_blocks(self):
        """Return the page's image sub-blocks regrouped into two-layer image blocks (body + caption/footnote children)."""
        return fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
|
189 |
+
|
190 |
+
    def get_table_blocks(self):
        """Return the page's table sub-blocks regrouped into two-layer table blocks (body + caption/footnote children)."""
        return fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
|
192 |
+
|
193 |
+
    def get_title_blocks(self):
        """Return title blocks with markdown '#' prefixes converted into a 'level' field."""
        return fix_title_blocks(self.title_blocks)
|
195 |
+
|
196 |
+
    def get_text_blocks(self):
        """Return text blocks with <|txt_contd|> continuation markers resolved (continued blocks merged)."""
        return fix_text_blocks(self.text_blocks)
|
198 |
+
|
199 |
+
    def get_interline_equation_blocks(self):
        """Return display-equation blocks as collected (no post-processing needed)."""
        return self.interline_equation_blocks
|
201 |
+
|
202 |
+
    def get_all_spans(self):
        """Return every span accumulated while the page model was built."""
        return self.all_spans
|
204 |
+
|
205 |
+
|
206 |
+
def isolated_formula_clean(txt):
    """Strip the display-math ``\\[`` / ``\\]`` wrappers from *txt* and balance its delimiters.

    Returns the cleaned LaTeX body (whitespace-trimmed, passed through latex_fix).
    """
    latex = txt[:]
    if latex.startswith("\\["):
        latex = latex[2:]
    if latex.endswith("\\]"):
        latex = latex[:-2]
    return latex_fix(latex.strip())
|
212 |
+
|
213 |
+
|
214 |
+
def latex_fix(latex):
    """Repair unbalanced \\left / \\right delimiters in a LaTeX string.

    Valid pairs are:
        \\left\\{ ... \\right\\}
        \\left(   ... \\right)
        \\left|   ... \\right|
        \\left\\| ... \\right\\|
        \\left[   ... \\right]

    When the counts of \\left and \\right differ, every \\left/\\right prefix
    (including malformed variants such as ``\\left{`` or ``\\left\\(``) is
    stripped down to the bare bracket so the formula still renders.

    Note: the extracted source showed some patterns with a single backslash
    (e.g. ``r"\\left\\("`` rendered as ``r"\left\("``) — those are invalid
    ``re`` escapes (``bad escape \\l``) and would raise at runtime; all
    patterns below use the doubled form consistently.
    """
    LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')    # does not match \lefteqn etc.
    RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')  # does not match \rightarrow
    left_count = len(LEFT_COUNT_PATTERN.findall(latex))
    right_count = len(RIGHT_COUNT_PATTERN.findall(latex))

    if left_count != right_count:
        # Two passes so nested/adjacent occurrences are all rewritten.
        for _ in range(2):
            # replace valid pairs
            latex = re.sub(r'\\left\\\{', "{", latex)   # \left\{
            latex = re.sub(r"\\left\|", "|", latex)     # \left|
            latex = re.sub(r"\\left\\\|", "|", latex)   # \left\|
            latex = re.sub(r"\\left\(", "(", latex)     # \left(
            latex = re.sub(r"\\left\[", "[", latex)     # \left[

            latex = re.sub(r"\\right\\\}", "}", latex)  # \right\}
            latex = re.sub(r"\\right\|", "|", latex)    # \right|
            latex = re.sub(r"\\right\\\|", "|", latex)  # \right\|
            latex = re.sub(r"\\right\)", ")", latex)    # \right)
            latex = re.sub(r"\\right\]", "]", latex)    # \right]
            latex = re.sub(r"\\right\.", "", latex)     # \right. (null delimiter)

            # replace malformed pairs
            latex = re.sub(r'\\left\{', "{", latex)
            latex = re.sub(r'\\right\}', "}", latex)    # \left{ ... \right}
            latex = re.sub(r'\\left\\\(', "(", latex)
            latex = re.sub(r'\\right\\\)', ")", latex)  # \left\( ... \right\)
            latex = re.sub(r'\\left\\\[', "[", latex)
            latex = re.sub(r'\\right\\\]', "]", latex)  # \left\[ ... \right\]

    return latex
|
252 |
+
|
253 |
+
|
254 |
+
def __reduct_overlap(bboxes):
    """Drop every bbox record that is fully contained inside another record's bbox."""
    total = len(bboxes)
    keep = [True] * total
    for i in range(total):
        for j in range(total):
            if i == j:
                continue
            # is_in(a, b) -> True when bbox a lies inside bbox b (project helper)
            if is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
                keep[i] = False
    return [box for box, flag in zip(bboxes, keep) if flag]
|
264 |
+
|
265 |
+
|
266 |
+
def __tie_up_category_by_distance_v3(
    blocks: list,
    subject_block_type: str,
    object_block_type: str,
):
    """Pair "subject" blocks (e.g. image bodies) with nearby "object" blocks (e.g. captions).

    Greedy nearest-first matching by bbox distance, then a second pass that
    attaches leftover objects to their nearest subject, and a final pass that
    emits subjects with no objects at all.  Returns a list of
    {"sub_bbox": ..., "obj_bboxes": [...], "sub_idx": int} records.
    """
    # Deduplicate fully-overlapping boxes and keep only the fields we need.
    subjects = __reduct_overlap(
        list(
            map(
                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
                filter(
                    lambda x: x["type"] == subject_block_type,
                    blocks,
                ),
            )
        )
    )
    objects = __reduct_overlap(
        list(
            map(
                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
                filter(
                    lambda x: x["type"] == object_block_type,
                    blocks,
                ),
            )
        )
    )

    ret = []
    N, M = len(subjects), len(objects)
    # Sort both lists by squared distance of the top-left corner from the page origin.
    subjects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
    objects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)

    # Object indices are offset so both kinds can share one index space.
    OBJ_IDX_OFFSET = 10000
    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

    all_boxes_with_idx = [(i, SUB_BIT_KIND, sub["bbox"][0], sub["bbox"][1]) for i, sub in enumerate(subjects)] + [
        (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj["bbox"][0], obj["bbox"][1]) for i, obj in enumerate(objects)
    ]
    seen_idx = set()      # indices (either kind) already consumed by a pairing
    seen_sub_idx = set()  # subject indices that already own a result record

    # Greedy phase: repeatedly take the unseen box closest to the unseen
    # bounding region's top-left corner, then pair it with the closest unseen
    # box of the *other* kind.
    while N > len(seen_sub_idx):
        candidates = []
        for idx, kind, x0, y0 in all_boxes_with_idx:
            if idx in seen_idx:
                continue
            candidates.append((idx, kind, x0, y0))

        if len(candidates) == 0:
            break
        left_x = min([v[2] for v in candidates])
        top_y = min([v[3] for v in candidates])

        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)

        fst_idx, fst_kind, left_x, top_y = candidates[0]
        # Re-sort by distance to the chosen anchor box.
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
        nxt = None

        for i in range(1, len(candidates)):
            # XOR of kind bits == 1 means "the opposite kind of the anchor".
            if candidates[i][1] ^ fst_kind == 1:
                nxt = candidates[i]
                break
        if nxt is None:
            break

        if fst_kind == SUB_BIT_KIND:
            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
        else:
            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

        pair_dis = bbox_distance(subjects[sub_idx]["bbox"], objects[obj_idx]["bbox"])
        # Reject the pairing when some other unseen subject is much (3x) closer
        # to this object; the subject is retired without the object.
        nearest_dis = float("inf")
        for i in range(N):
            if i in seen_idx or i == sub_idx:
                continue
            nearest_dis = min(nearest_dis, bbox_distance(subjects[i]["bbox"], objects[obj_idx]["bbox"]))

        if pair_dis >= 3 * nearest_dis:
            seen_idx.add(sub_idx)
            continue

        seen_idx.add(sub_idx)
        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
        seen_sub_idx.add(sub_idx)

        ret.append(
            {
                "sub_bbox": {
                    "bbox": subjects[sub_idx]["bbox"],
                    "lines": subjects[sub_idx]["lines"],
                    "index": subjects[sub_idx]["index"],
                },
                "obj_bboxes": [
                    {"bbox": objects[obj_idx]["bbox"], "lines": objects[obj_idx]["lines"], "index": objects[obj_idx]["index"]}
                ],
                "sub_idx": sub_idx,
            }
        )

    # Second phase: attach every still-unpaired object to its nearest subject.
    for i in range(len(objects)):
        j = i + OBJ_IDX_OFFSET
        if j in seen_idx:
            continue
        seen_idx.add(j)
        nearest_dis, nearest_sub_idx = float("inf"), -1
        for k in range(len(subjects)):
            dis = bbox_distance(objects[i]["bbox"], subjects[k]["bbox"])
            if dis < nearest_dis:
                nearest_dis = dis
                nearest_sub_idx = k

        for k in range(len(subjects)):
            if k != nearest_sub_idx:
                continue
            if k in seen_sub_idx:
                # Subject already has a record: append this object to it.
                for kk in range(len(ret)):
                    if ret[kk]["sub_idx"] == k:
                        ret[kk]["obj_bboxes"].append(
                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
                        )
                        break
            else:
                ret.append(
                    {
                        "sub_bbox": {
                            "bbox": subjects[k]["bbox"],
                            "lines": subjects[k]["lines"],
                            "index": subjects[k]["index"],
                        },
                        "obj_bboxes": [
                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
                        ],
                        "sub_idx": k,
                    }
                )
            # NOTE(review): the extracted source does not show the exact indent of
            # these two adds; placing them at loop level is behavior-equivalent
            # (set adds are idempotent) — confirm against upstream.
            seen_sub_idx.add(k)
            seen_idx.add(k)

    # Final phase: subjects that never received any object.
    for i in range(len(subjects)):
        if i in seen_sub_idx:
            continue
        ret.append(
            {
                "sub_bbox": {
                    "bbox": subjects[i]["bbox"],
                    "lines": subjects[i]["lines"],
                    "index": subjects[i]["index"],
                },
                "obj_bboxes": [],
                "sub_idx": i,
            }
        )

    return ret
|
423 |
+
|
424 |
+
|
425 |
+
def get_type_blocks(blocks, block_type: Literal["image", "table"]):
    """Group each image/table body with its caption and footnote blocks.

    Returns a list of records keyed ``{block_type}_body``,
    ``{block_type}_caption_list`` and ``{block_type}_footnote_list``.
    """
    with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
    with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
    grouped = []
    for cap_rec in with_captions:
        body_idx = cap_rec["sub_idx"]
        # Both pairings index subjects identically, so the matching footnote
        # record for this body always exists.
        fn_rec = next(filter(lambda r: r["sub_idx"] == body_idx, with_footnotes))
        grouped.append({
            f"{block_type}_body": cap_rec["sub_bbox"],
            f"{block_type}_caption_list": cap_rec["obj_bboxes"],
            f"{block_type}_footnote_list": fn_rec["obj_bboxes"],
        })
    return grouped
|
439 |
+
|
440 |
+
|
441 |
+
def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
    """Convert flat image/table sub-blocks into nested two-layer blocks.

    Each result block carries the body's bbox/index and a ``blocks`` list of
    the body followed by its captions and footnotes, all re-tagged with the
    appropriate ``{fix_type}_*`` type strings.
    """
    fixed_blocks = []
    for group in get_type_blocks(blocks, fix_type):
        body = group[f"{fix_type}_body"]
        captions = group[f"{fix_type}_caption_list"]
        footnotes = group[f"{fix_type}_footnote_list"]

        # Re-tag the children (the grouping step stripped their types).
        body["type"] = f"{fix_type}_body"
        for cap in captions:
            cap["type"] = f"{fix_type}_caption"
        for note in footnotes:
            note["type"] = f"{fix_type}_footnote"

        fixed_blocks.append({
            "type": fix_type,
            "bbox": body["bbox"],
            "blocks": [body, *captions, *footnotes],
            "index": body["index"],
        })

    return fixed_blocks
|
468 |
+
|
469 |
+
|
470 |
+
def fix_title_blocks(blocks):
    """Extract markdown heading level from title blocks and strip the '#' prefix.

    Sets ``block['level']`` from the number of leading hashes and removes the
    hashes from the first span of the first line only (where the model emits
    them); other blocks pass through untouched.
    """
    for block in blocks:
        if block["type"] == BlockType.TITLE:
            full_text = merge_para_with_text(block)
            block['level'] = count_leading_hashes(full_text)
            for line in block['lines']:
                for span in line['spans']:
                    span['content'] = strip_leading_hashes(span['content'])
                    break  # only the first span carries the hashes
                break  # only the first line
    return blocks
|
482 |
+
|
483 |
+
|
484 |
+
def count_leading_hashes(text):
    """Return the number of '#' characters prefixing *text* (0 when none)."""
    m = re.match(r'^(#+)', text)
    if m is None:
        return 0
    return len(m.group(1))
|
487 |
+
|
488 |
+
|
489 |
+
def strip_leading_hashes(text):
    """Remove a leading run of '#' characters and any whitespace right after it."""
    without_hashes = text.lstrip('#')
    if len(without_hashes) == len(text):
        # No leading hashes: leave the text untouched.
        return text
    return without_hashes.lstrip()
|
492 |
+
|
493 |
+
|
494 |
+
def fix_text_blocks(blocks):
    """Merge text blocks split across pages/columns via the <|txt_contd|> marker.

    When a block's last span ends with the marker, the marker is stripped and
    the next non-emptied block's lines are pulled into it; the donor block is
    emptied and flagged with SplitFlag.LINES_DELETED.  The same block is then
    re-checked, so chains of continuations collapse into one block.
    """
    marker = '<|txt_contd|>'
    idx = 0
    while idx < len(blocks):
        block = blocks[idx]
        last_line = block["lines"][-1] if block["lines"] else None
        if last_line:
            last_span = last_line["spans"][-1] if last_line["spans"] else None
            if last_span and last_span['content'].endswith(marker):
                last_span['content'] = last_span['content'][:-len(marker)]

                # Locate the next block that still owns its lines.
                donor_idx = idx + 1
                while donor_idx < len(blocks) and blocks[donor_idx].get(SplitFlag.LINES_DELETED, False):
                    donor_idx += 1

                if donor_idx < len(blocks):
                    donor = blocks[donor_idx]
                    block["lines"].extend(donor["lines"])
                    donor["lines"] = []
                    donor[SplitFlag.LINES_DELETED] = True
                    # Re-examine the merged block without advancing.
                    continue
        idx += 1
    return blocks
|
vendor/mineru/mineru/backend/vlm/vlm_middle_json_mkcontent.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
|
4 |
+
from mineru.utils.enum_class import MakeMode, BlockType, ContentType
|
5 |
+
|
6 |
+
|
7 |
+
# Resolve the LaTeX delimiter configuration once at import time.
latex_delimiters_config = get_latex_delimiter_config()

# Standard dollar-sign delimiters used when no user config is present.
default_delimiters = {
    'display': {'left': '$$', 'right': '$$'},
    'inline': {'left': '$', 'right': '$'}
}

delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters

display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
|
20 |
+
|
21 |
+
def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
    """Concatenate a paragraph block's spans into one string.

    Text and inline-equation spans are space-joined within a line; display
    equations are wrapped in the configured display delimiters on their own
    lines.  When formula parsing is disabled, a display equation falls back to
    a markdown image link for its cut-out image.

    Fix: the extracted source showed ``content = f""`` for that fallback —
    the image link was lost in extraction; restored to the ``![](...)`` form
    (path join style matches make_blocks_to_content_list in this module).
    """
    para_text = ''
    for line in para_block['lines']:
        for j, span in enumerate(line['spans']):
            span_type = span['type']
            content = ''
            if span_type == ContentType.TEXT:
                content = span['content']
            elif span_type == ContentType.INLINE_EQUATION:
                content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
            elif span_type == ContentType.INTERLINE_EQUATION:
                if formula_enable:
                    content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
                else:
                    # Formula parsing disabled: embed the equation's image instead.
                    if span.get('image_path', ''):
                        content = f"![]({img_buket_path}/{span['image_path']})"
            if content:
                if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
                    # Space-separate spans except after the last one in the line.
                    if j == len(line['spans']) - 1:
                        para_text += content
                    else:
                        para_text += f'{content} '
                elif span_type == ContentType.INTERLINE_EQUATION:
                    para_text += content
    return para_text
|
47 |
+
|
48 |
+
def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
    """Render one page's para_blocks into a list of markdown fragments.

    make_mode: MakeMode.MM_MD keeps images/tables (as markdown links / HTML),
    MakeMode.NLP_MD drops them entirely.  Empty fragments are skipped.

    Fix: the extracted source showed ``para_text += f""`` at every image
    embed point — the ``![](...)`` markdown links were lost in extraction and
    are restored here (path join style matches make_blocks_to_content_list).
    """
    page_markdown = []
    for para_block in para_blocks:
        para_text = ''
        para_type = para_block['type']
        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
            para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
        elif para_type == BlockType.TITLE:
            title_level = get_title_level(para_block)
            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
        elif para_type == BlockType.IMAGE:
            if make_mode == MakeMode.NLP_MD:
                continue
            elif make_mode == MakeMode.MM_MD:
                # When a footnote exists, emit caption -> body -> footnote;
                # otherwise body -> caption.
                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
                if has_image_footnote:
                    for block in para_block['blocks']:  # 1st: image_caption
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += merge_para_with_text(block) + ' \n'
                    for block in para_block['blocks']:  # 2nd: image_body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 3rd: image_footnote
                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
                            para_text += ' \n' + merge_para_with_text(block)
                else:
                    for block in para_block['blocks']:  # 1st: image_body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 2nd: image_caption
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += ' \n' + merge_para_with_text(block)

        elif para_type == BlockType.TABLE:
            if make_mode == MakeMode.NLP_MD:
                continue
            elif make_mode == MakeMode.MM_MD:
                for block in para_block['blocks']:  # 1st: table_caption
                    if block['type'] == BlockType.TABLE_CAPTION:
                        para_text += merge_para_with_text(block) + ' \n'
                for block in para_block['blocks']:  # 2nd: table_body
                    if block['type'] == BlockType.TABLE_BODY:
                        for line in block['lines']:
                            for span in line['spans']:
                                if span['type'] == ContentType.TABLE:
                                    # if processed by table model
                                    if table_enable:
                                        if span.get('html', ''):
                                            para_text += f"\n{span['html']}\n"
                                        elif span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                                    else:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                for block in para_block['blocks']:  # 3rd: table_footnote
                    if block['type'] == BlockType.TABLE_FOOTNOTE:
                        para_text += '\n' + merge_para_with_text(block) + ' '

        if para_text.strip() == '':
            continue
        else:
            page_markdown.append(para_text.strip())

    return page_markdown
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one para_block into a content-list entry dict.

    Every entry carries ``page_idx``; image/table entries additionally carry
    ``img_path`` plus caption/footnote text lists, tables may carry the
    recognized HTML body.
    """
    para_type = para_block['type']
    para_content = {}
    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.TITLE:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
        level = get_title_level(para_block)
        if level != 0:
            para_content['text_level'] = level
    elif para_type == BlockType.INTERLINE_EQUATION:
        para_content = {
            'type': ContentType.EQUATION,
            'text': merge_para_with_text(para_block),
            'text_format': 'latex',
        }
    elif para_type == BlockType.IMAGE:
        para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
        for child in para_block['blocks']:
            child_type = child['type']
            if child_type == BlockType.IMAGE_BODY:
                for line in child['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.IMAGE and span.get('image_path', ''):
                            para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            elif child_type == BlockType.IMAGE_CAPTION:
                para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(child))
            elif child_type == BlockType.IMAGE_FOOTNOTE:
                para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(child))
    elif para_type == BlockType.TABLE:
        para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
        for child in para_block['blocks']:
            child_type = child['type']
            if child_type == BlockType.TABLE_BODY:
                for line in child['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.TABLE:
                            # Prefer the model-recognized HTML, but always record
                            # the table image path when available.
                            if span.get('html', ''):
                                para_content[BlockType.TABLE_BODY] = f"{span['html']}"
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            elif child_type == BlockType.TABLE_CAPTION:
                para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(child))
            elif child_type == BlockType.TABLE_FOOTNOTE:
                para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(child))

    para_content['page_idx'] = page_idx

    return para_content
|
185 |
+
|
186 |
+
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Assemble the whole document's output from per-page middle-JSON info.

    Returns a single markdown string for MakeMode.MM_MD / MakeMode.NLP_MD,
    a content list for MakeMode.CONTENT_LIST, and None for unknown modes.
    Formula/table rendering is toggled via MINERU_VLM_* env vars.
    """
    formula_enable = get_formula_enable(os.getenv('MINERU_VLM_FORMULA_ENABLE', 'True').lower() == 'true')
    table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')

    output_content = []
    for page_info in pdf_info_dict:
        para_blocks = page_info.get('para_blocks')
        page_idx = page_info.get('page_idx')
        if not para_blocks:
            # Pages without parsed blocks contribute nothing.
            continue
        if make_mode in (MakeMode.MM_MD, MakeMode.NLP_MD):
            output_content.extend(
                mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path)
            )
        elif make_mode == MakeMode.CONTENT_LIST:
            output_content.extend(
                make_blocks_to_content_list(pb, img_buket_path, page_idx) for pb in para_blocks
            )

    if make_mode in (MakeMode.MM_MD, MakeMode.NLP_MD):
        return '\n\n'.join(output_content)
    if make_mode == MakeMode.CONTENT_LIST:
        return output_content
    return None
|
213 |
+
|
214 |
+
|
215 |
+
def get_title_level(block):
    """Return the title block's heading level clamped to [1, 4]; 0 means 'no level'."""
    level = block.get('level', 1)
    if level < 1:
        return 0
    return min(level, 4)
|
vendor/mineru/mineru/cli/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/cli/client.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
import os
|
3 |
+
import click
|
4 |
+
from pathlib import Path
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from mineru.utils.cli_parser import arg_parse
|
8 |
+
from mineru.utils.config_reader import get_device
|
9 |
+
from mineru.utils.model_utils import get_vram
|
10 |
+
from ..version import __version__
|
11 |
+
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
|
12 |
+
|
13 |
+
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.version_option(__version__,
                      '--version',
                      '-v',
                      help='display the version and exit')
@click.option(
    '-p',
    '--path',
    'input_path',
    type=click.Path(exists=True),
    required=True,
    help='local filepath or directory. support pdf, png, jpg, jpeg files',
)
@click.option(
    '-o',
    '--output',
    'output_dir',
    type=click.Path(),
    required=True,
    help='output local directory',
)
@click.option(
    '-m',
    '--method',
    'method',
    type=click.Choice(['auto', 'txt', 'ocr']),
    help="""the method for parsing pdf:
    auto: Automatically determine the method based on the file type.
    txt: Use text extraction method.
    ocr: Use OCR method for image-based PDFs.
    Without method specified, 'auto' will be used by default.
    Adapted only for the case where the backend is set to "pipeline".""",
    default='auto',
)
@click.option(
    '-b',
    '--backend',
    'backend',
    type=click.Choice(['pipeline', 'vlm-transformers', 'vlm-sglang-engine', 'vlm-sglang-client']),
    help="""the backend for parsing pdf:
    pipeline: More general.
    vlm-transformers: More general.
    vlm-sglang-engine: Faster(engine).
    vlm-sglang-client: Faster(client).
    without method specified, pipeline will be used by default.""",
    default='pipeline',
)
@click.option(
    '-l',
    '--lang',
    'lang',
    type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka',
                       'latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']),
    help="""
    Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
    Without languages specified, 'ch' will be used by default.
    Adapted only for the case where the backend is set to "pipeline".
    """,
    default='ch',
)
@click.option(
    '-u',
    '--url',
    'server_url',
    type=str,
    help="""
    When the backend is `sglang-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
    """,
    default=None,
)
@click.option(
    '-s',
    '--start',
    'start_page_id',
    type=int,
    help='The starting page for PDF parsing, beginning from 0.',
    default=0,
)
@click.option(
    '-e',
    '--end',
    'end_page_id',
    type=int,
    help='The ending page for PDF parsing, beginning from 0.',
    default=None,
)
@click.option(
    '-f',
    '--formula',
    'formula_enable',
    type=bool,
    help='Enable formula parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
    default=True,
)
@click.option(
    '-t',
    '--table',
    'table_enable',
    type=bool,
    help='Enable table parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
    default=True,
)
@click.option(
    '-d',
    '--device',
    'device_mode',
    type=str,
    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Adapted only for the case where the backend is set to "pipeline". ',
    default=None,
)
@click.option(
    '--vram',
    'virtual_vram',
    type=int,
    help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
    default=None,
)
@click.option(
    '--source',
    'model_source',
    type=click.Choice(['huggingface', 'modelscope', 'local']),
    help="""
    The source of the model repository. Default is 'huggingface'.
    """,
    default='huggingface',
)
def main(
        ctx,
        input_path, output_dir, method, backend, lang, server_url,
        start_page_id, end_page_id, formula_enable, table_enable,
        device_mode, virtual_vram, model_source, **kwargs
):
    """CLI entry point: parse a PDF/image file or a directory of them with MinerU."""
    # Extra unknown CLI args are forwarded verbatim to do_parse.
    kwargs.update(arg_parse(ctx))

    # Local backends need device / VRAM / model-source env vars; the sglang
    # client backend delegates all of that to the remote server.
    if not backend.endswith('-client'):
        def get_device_mode() -> str:
            # Explicit -d flag wins over auto-detection.
            if device_mode is not None:
                return device_mode
            else:
                return get_device()
        if os.getenv('MINERU_DEVICE_MODE', None) is None:
            os.environ['MINERU_DEVICE_MODE'] = get_device_mode()

        def get_virtual_vram_size() -> int:
            if virtual_vram is not None:
                return virtual_vram
            if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
                return round(get_vram(get_device_mode()))
            return 1  # CPU/MPS fallback
        if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
            os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())

        if os.getenv('MINERU_MODEL_SOURCE', None) is None:
            os.environ['MINERU_MODEL_SOURCE'] = model_source

    os.makedirs(output_dir, exist_ok=True)

    def parse_doc(path_list: list[Path]):
        # Read every input up front, then hand the whole batch to do_parse.
        try:
            file_name_list = []
            pdf_bytes_list = []
            lang_list = []
            for path in path_list:
                file_name = str(Path(path).stem)
                pdf_bytes = read_fn(path)
                file_name_list.append(file_name)
                pdf_bytes_list.append(pdf_bytes)
                lang_list.append(lang)
            do_parse(
                output_dir=output_dir,
                pdf_file_names=file_name_list,
                pdf_bytes_list=pdf_bytes_list,
                p_lang_list=lang_list,
                backend=backend,
                parse_method=method,
                formula_enable=formula_enable,
                table_enable=table_enable,
                server_url=server_url,
                start_page_id=start_page_id,
                end_page_id=end_page_id,
                **kwargs,
            )
        except Exception as e:
            # Best-effort CLI: log the failure rather than crash mid-batch.
            logger.exception(e)

    if os.path.isdir(input_path):
        # Directory input: collect every supported file (non-recursive).
        doc_path_list = []
        for doc_path in Path(input_path).glob('*'):
            if doc_path.suffix in pdf_suffixes + image_suffixes:
                doc_path_list.append(doc_path)
        parse_doc(doc_path_list)
    else:
        parse_doc([Path(input_path)])


if __name__ == '__main__':
    main()
|
vendor/mineru/mineru/cli/common.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import copy
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import pypdfium2 as pdfium
|
9 |
+
from loguru import logger
|
10 |
+
|
11 |
+
from mineru.data.data_reader_writer import FileBasedDataWriter
|
12 |
+
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
|
13 |
+
from mineru.utils.enum_class import MakeMode
|
14 |
+
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
|
15 |
+
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
|
16 |
+
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
|
17 |
+
from mineru.backend.vlm.vlm_analyze import aio_doc_analyze as aio_vlm_doc_analyze
|
18 |
+
|
19 |
+
# Supported input suffixes (stored lower-case; matching is case-insensitive).
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg", ".webp", ".gif"]


def read_fn(path):
    """Read a PDF or image file and return its content as PDF bytes.

    Image files are converted to a single-page PDF; PDF files are returned
    as-is. The suffix comparison is case-insensitive so files such as
    ``X.PDF`` are accepted — this matches the API-side check in
    fast_api.py, which lowercases the suffix before calling this function.

    :param path: str or Path pointing to a .pdf or image file
    :return: PDF content as bytes
    :raises Exception: if the file suffix is not a supported type
    """
    if not isinstance(path, Path):
        path = Path(path)
    with open(str(path), "rb") as input_file:
        file_bytes = input_file.read()
    suffix = path.suffix.lower()  # accept upper-case extensions too
    if suffix in image_suffixes:
        return images_bytes_to_pdf_bytes(file_bytes)
    elif suffix in pdf_suffixes:
        return file_bytes
    else:
        raise Exception(f"Unknown file suffix: {path.suffix}")
|
34 |
+
|
35 |
+
|
36 |
+
def prepare_env(output_dir, pdf_file_name, parse_method):
    """Create (if needed) and return the output directories for one document.

    Layout: ``<output_dir>/<pdf_file_name>/<parse_method>/`` holds the
    markdown/json outputs, with an ``images`` subdirectory underneath it
    for extracted images.

    :return: tuple ``(local_image_dir, local_md_dir)`` as strings
    """
    md_dir = str(os.path.join(output_dir, pdf_file_name, parse_method))
    image_dir = os.path.join(str(md_dir), "images")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(md_dir, exist_ok=True)
    return image_dir, md_dir
|
42 |
+
|
43 |
+
|
44 |
+
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
    """Extract the page range [start_page_id, end_page_id] from a PDF.

    :param pdf_bytes: source PDF content
    :param start_page_id: first page to keep (0-based, inclusive)
    :param end_page_id: last page to keep (inclusive); ``None`` or a
        negative value means "through the last page"; values past the end
        are clamped with a warning
    :return: bytes of a new PDF containing only the selected pages
    """
    # Load the source PDF from bytes.
    pdf = pdfium.PdfDocument(pdf_bytes)
    output_pdf = None
    try:
        # Resolve and clamp the end page.
        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf) - 1
        if end_page_id > len(pdf) - 1:
            logger.warning("end_page_id is out of range, use pdf_docs length")
            end_page_id = len(pdf) - 1

        # Copy the selected pages into a fresh document.
        output_pdf = pdfium.PdfDocument.new()
        page_indices = list(range(start_page_id, end_page_id + 1))
        output_pdf.import_pages(pdf, page_indices)

        # Serialize the new document to an in-memory buffer.
        output_buffer = io.BytesIO()
        output_pdf.save(output_buffer)
        return output_buffer.getvalue()
    finally:
        # Close both documents even when import/save raises, so native
        # pdfium resources are never leaked.
        pdf.close()
        if output_pdf is not None:
            output_pdf.close()
|
75 |
+
|
76 |
+
|
77 |
+
def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
    """Clip every PDF in the batch to the requested page range."""
    return [
        convert_pdf_bytes_to_bytes_by_pypdfium2(raw, start_page_id, end_page_id)
        for raw in pdf_bytes_list
    ]
|
84 |
+
|
85 |
+
|
86 |
+
def _process_output(
    pdf_info,
    pdf_bytes,
    pdf_file_name,
    local_md_dir,
    local_image_dir,
    md_writer,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_orig_pdf,
    f_dump_md,
    f_dump_content_list,
    f_dump_middle_json,
    f_dump_model_output,
    f_make_md_mode,
    middle_json,
    model_output=None,
    is_pipeline=True
):
    """Write the requested output artifacts for one parsed document.

    Each ``f_*`` flag toggles one artifact: layout/span debug PDFs, the
    original PDF, markdown, content-list JSON, middle JSON and the raw
    model output. ``is_pipeline`` selects the markdown builder and the
    serialization of ``model_output`` (JSON for pipeline, plain text
    joined with separators for VLM).
    """
    # Imported lazily so the pipeline backend is only loaded when needed.
    from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
    """处理输出文件"""
    if f_draw_layout_bbox:
        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")

    if f_draw_span_bbox:
        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")

    if f_dump_orig_pdf:
        md_writer.write(
            f"{pdf_file_name}_origin.pdf",
            pdf_bytes,
        )

    # Markdown references images via a path relative to the output dir.
    image_dir = str(os.path.basename(local_image_dir))

    if f_dump_md:
        # Pick the markdown builder matching the backend that produced pdf_info.
        make_func = pipeline_union_make if is_pipeline else vlm_union_make
        md_content_str = make_func(pdf_info, f_make_md_mode, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}.md",
            md_content_str,
        )

    if f_dump_content_list:
        make_func = pipeline_union_make if is_pipeline else vlm_union_make
        content_list = make_func(pdf_info, MakeMode.CONTENT_LIST, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}_content_list.json",
            json.dumps(content_list, ensure_ascii=False, indent=4),
        )

    if f_dump_middle_json:
        md_writer.write_string(
            f"{pdf_file_name}_middle.json",
            json.dumps(middle_json, ensure_ascii=False, indent=4),
        )

    if f_dump_model_output:
        if is_pipeline:
            md_writer.write_string(
                f"{pdf_file_name}_model.json",
                json.dumps(model_output, ensure_ascii=False, indent=4),
            )
        else:
            # VLM model_output is a list of per-page strings; join them
            # with a visual separator line.
            output_text = ("\n" + "-" * 50 + "\n").join(model_output)
            md_writer.write_string(
                f"{pdf_file_name}_model_output.txt",
                output_text,
            )

    logger.info(f"local output dir is {local_md_dir}")
|
157 |
+
|
158 |
+
|
159 |
+
def _process_pipeline(
    output_dir,
    pdf_file_names,
    pdf_bytes_list,
    p_lang_list,
    parse_method,
    p_formula_enable,
    p_table_enable,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_md,
    f_dump_middle_json,
    f_dump_model_output,
    f_dump_orig_pdf,
    f_dump_content_list,
    f_make_md_mode,
):
    """Run the pipeline backend over a batch of PDFs and write all outputs."""
    # Imported lazily so pipeline models are not loaded for VLM-only runs.
    from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
    from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze

    # Batch inference over all documents at once.
    infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
        pipeline_doc_analyze(
            pdf_bytes_list, p_lang_list, parse_method=parse_method,
            formula_enable=p_formula_enable, table_enable=p_table_enable
        )
    )

    for idx, model_list in enumerate(infer_results):
        # Deep copy taken before conversion — presumably because
        # result_to_middle_json mutates model_list; the copy is what gets
        # dumped as the raw model output. TODO(review): confirm mutation.
        model_json = copy.deepcopy(model_list)
        pdf_file_name = pdf_file_names[idx]
        local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

        # Per-document results aligned with infer_results by index.
        images_list = all_image_lists[idx]
        pdf_doc = all_pdf_docs[idx]
        _lang = lang_list[idx]
        _ocr_enable = ocr_enabled_list[idx]

        middle_json = pipeline_result_to_middle_json(
            model_list, images_list, pdf_doc, image_writer,
            _lang, _ocr_enable, p_formula_enable
        )

        pdf_info = middle_json["pdf_info"]
        pdf_bytes = pdf_bytes_list[idx]

        _process_output(
            pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
            md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
            f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
            f_make_md_mode, middle_json, model_json, is_pipeline=True
        )
|
212 |
+
|
213 |
+
|
214 |
+
async def _async_process_vlm(
    output_dir,
    pdf_file_names,
    pdf_bytes_list,
    backend,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_md,
    f_dump_middle_json,
    f_dump_model_output,
    f_dump_orig_pdf,
    f_dump_content_list,
    f_make_md_mode,
    server_url=None,
    **kwargs,
):
    """Run the VLM backend asynchronously over a batch of PDFs and write outputs."""
    parse_method = "vlm"
    # The VLM backend does not produce span boxes, so force this off.
    f_draw_span_bbox = False
    # server_url only applies to *-client backends that talk to a remote server.
    if not backend.endswith("client"):
        server_url = None

    for idx, pdf_bytes in enumerate(pdf_bytes_list):
        pdf_file_name = pdf_file_names[idx]
        local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

        middle_json, infer_result = await aio_vlm_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
        )

        pdf_info = middle_json["pdf_info"]

        _process_output(
            pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
            md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
            f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
            f_make_md_mode, middle_json, infer_result, is_pipeline=False
        )
|
253 |
+
|
254 |
+
|
255 |
+
def _process_vlm(
    output_dir,
    pdf_file_names,
    pdf_bytes_list,
    backend,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_md,
    f_dump_middle_json,
    f_dump_model_output,
    f_dump_orig_pdf,
    f_dump_content_list,
    f_make_md_mode,
    server_url=None,
    **kwargs,
):
    """Run the VLM backend synchronously over a batch of PDFs and write outputs."""
    parse_method = "vlm"
    # Span boxes are not produced by the VLM backend.
    f_draw_span_bbox = False
    # Only the *-client backends talk to a remote server.
    if not backend.endswith("client"):
        server_url = None

    for name, raw_pdf in zip(pdf_file_names, pdf_bytes_list):
        image_dir, md_dir = prepare_env(output_dir, name, parse_method)
        image_writer = FileBasedDataWriter(image_dir)
        md_writer = FileBasedDataWriter(md_dir)

        middle_json, infer_result = vlm_doc_analyze(
            raw_pdf, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
        )

        _process_output(
            middle_json["pdf_info"], raw_pdf, name, md_dir, image_dir,
            md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
            f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
            f_make_md_mode, middle_json, infer_result, is_pipeline=False
        )
|
294 |
+
|
295 |
+
|
296 |
+
def do_parse(
    output_dir,
    pdf_file_names: list[str],
    pdf_bytes_list: list[bytes],
    p_lang_list: list[str],
    backend="pipeline",
    parse_method="auto",
    formula_enable=True,
    table_enable=True,
    server_url=None,
    f_draw_layout_bbox=True,
    f_draw_span_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_output=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    start_page_id=0,
    end_page_id=None,
    **kwargs,
):
    """Parse a batch of documents with the chosen backend and write outputs.

    ``backend`` is either ``"pipeline"`` or a VLM backend name, with or
    without a ``"vlm-"`` prefix. Pages outside
    ``[start_page_id, end_page_id]`` are dropped before parsing.
    """
    # Clip every document to the requested page range up front.
    pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)

    if backend == "pipeline":
        _process_pipeline(
            output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
            parse_method, formula_enable, table_enable,
            f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
            f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
        )
        return

    # Accept both "vlm-<name>" and bare "<name>" spellings.
    backend = backend.removeprefix("vlm-")

    # The VLM backend reads these toggles from the environment.
    os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
    os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)

    _process_vlm(
        output_dir, pdf_file_names, pdf_bytes_list, backend,
        f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
        f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
        server_url, **kwargs,
    )
|
341 |
+
|
342 |
+
|
343 |
+
async def aio_do_parse(
    output_dir,
    pdf_file_names: list[str],
    pdf_bytes_list: list[bytes],
    p_lang_list: list[str],
    backend="pipeline",
    parse_method="auto",
    formula_enable=True,
    table_enable=True,
    server_url=None,
    f_draw_layout_bbox=True,
    f_draw_span_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_output=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    start_page_id=0,
    end_page_id=None,
    **kwargs,
):
    """Async counterpart of :func:`do_parse`.

    The pipeline backend has no async implementation yet, so it runs
    synchronously; only the VLM path is awaited.
    """
    # Clip every document to the requested page range up front.
    pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)

    if backend == "pipeline":
        # Pipeline mode does not support async yet — process synchronously.
        _process_pipeline(
            output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
            parse_method, formula_enable, table_enable,
            f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
            f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
        )
        return

    # Accept both "vlm-<name>" and bare "<name>" spellings.
    backend = backend.removeprefix("vlm-")

    # The VLM backend reads these toggles from the environment.
    os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
    os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)

    await _async_process_vlm(
        output_dir, pdf_file_names, pdf_bytes_list, backend,
        f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
        f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
        server_url, **kwargs,
    )
|
389 |
+
|
390 |
+
|
391 |
+
|
392 |
+
if __name__ == "__main__":
    # Demo entry point: parse the PDF given on the command line, falling
    # back to the bundled demo document. (The previous hard-coded local
    # Windows user path was a leftover development artifact.)
    import sys

    pdf_path = sys.argv[1] if len(sys.argv) > 1 else "../../demo/pdfs/demo3.pdf"

    try:
        do_parse(
            "./output",
            [Path(pdf_path).stem],
            [read_fn(Path(pdf_path))],
            ["ch"],
            end_page_id=10,
            backend='vlm-huggingface',
            # backend='pipeline'
        )
    except Exception as e:
        logger.exception(e)
|
vendor/mineru/mineru/cli/fast_api.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uuid
|
2 |
+
import os
|
3 |
+
import uvicorn
|
4 |
+
import click
|
5 |
+
from pathlib import Path
|
6 |
+
from glob import glob
|
7 |
+
from fastapi import FastAPI, UploadFile, File, Form
|
8 |
+
from fastapi.middleware.gzip import GZipMiddleware
|
9 |
+
from fastapi.responses import JSONResponse
|
10 |
+
from typing import List, Optional
|
11 |
+
from loguru import logger
|
12 |
+
from base64 import b64encode
|
13 |
+
|
14 |
+
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
|
15 |
+
from mineru.utils.cli_parser import arg_parse
|
16 |
+
from mineru.version import __version__
|
17 |
+
|
18 |
+
app = FastAPI()
|
19 |
+
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
20 |
+
|
21 |
+
def encode_image(image_path: str) -> str:
    """Return the content of the file at *image_path* as a base64 string."""
    raw = Path(image_path).read_bytes()
    return b64encode(raw).decode()
|
25 |
+
|
26 |
+
|
27 |
+
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
    """Read one result artifact (``<pdf_name><suffix>``) from *parse_dir*.

    :return: the file content as text, or ``None`` when the artifact
        was not produced.
    """
    result_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
    if not os.path.exists(result_path):
        return None
    with open(result_path, "r", encoding="utf-8") as fp:
        return fp.read()
|
34 |
+
|
35 |
+
|
36 |
+
@app.post(path="/file_parse",)
async def parse_pdf(
    files: List[UploadFile] = File(...),
    output_dir: str = Form("./output"),
    lang_list: List[str] = Form(["ch"]),
    backend: str = Form("pipeline"),
    parse_method: str = Form("auto"),
    formula_enable: bool = Form(True),
    table_enable: bool = Form(True),
    server_url: Optional[str] = Form(None),
    return_md: bool = Form(True),
    return_middle_json: bool = Form(False),
    return_model_output: bool = Form(False),
    return_content_list: bool = Form(False),
    return_images: bool = Form(False),
    start_page_id: int = Form(0),
    end_page_id: int = Form(99999),
):
    """Parse uploaded PDF/image files and return the requested artifacts.

    Files are parsed into a per-request unique output directory; the
    ``return_*`` flags select which artifacts are read back and included
    in the JSON response (images are inlined as base64 data URIs).
    """
    # Extra configuration captured from the command line in main().
    config = getattr(app.state, "config", {})

    try:
        # Per-request unique directory so concurrent requests never collide.
        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
        os.makedirs(unique_dir, exist_ok=True)

        # Load the uploaded files (images are converted to PDF bytes).
        pdf_file_names = []
        pdf_bytes_list = []

        for file in files:
            content = await file.read()
            file_path = Path(file.filename)

            if file_path.suffix.lower() not in pdf_suffixes + image_suffixes:
                return JSONResponse(
                    status_code=400,
                    content={"error": f"Unsupported file type: {file_path.suffix}"}
                )

            # read_fn expects a path, so spill the upload to a temp file.
            temp_path = Path(unique_dir) / file_path.name
            with open(temp_path, "wb") as f:
                f.write(content)

            try:
                pdf_bytes = read_fn(temp_path)
                pdf_bytes_list.append(pdf_bytes)
                pdf_file_names.append(file_path.stem)
            except Exception as e:
                return JSONResponse(
                    status_code=400,
                    content={"error": f"Failed to load file: {str(e)}"}
                )
            finally:
                # Remove the temp file even when read_fn fails, so failed
                # uploads do not leak into the output directory.
                if os.path.exists(temp_path):
                    os.remove(temp_path)

        # Align the language list with the number of files.
        actual_lang_list = lang_list
        if len(actual_lang_list) != len(pdf_file_names):
            # Length mismatch: fall back to the first language, or "ch".
            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)

        # Run the actual conversion asynchronously.
        await aio_do_parse(
            output_dir=unique_dir,
            pdf_file_names=pdf_file_names,
            pdf_bytes_list=pdf_bytes_list,
            p_lang_list=actual_lang_list,
            backend=backend,
            parse_method=parse_method,
            formula_enable=formula_enable,
            table_enable=table_enable,
            server_url=server_url,
            f_draw_layout_bbox=False,
            f_draw_span_bbox=False,
            f_dump_md=return_md,
            f_dump_middle_json=return_middle_json,
            f_dump_model_output=return_model_output,
            f_dump_orig_pdf=False,
            f_dump_content_list=return_content_list,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            **config
        )

        # Collect the requested artifacts for every parsed file.
        result_dict = {}
        for pdf_name in pdf_file_names:
            result_dict[pdf_name] = {}
            data = result_dict[pdf_name]

            # Output subdirectory depends on the backend (see prepare_env).
            if backend.startswith("pipeline"):
                parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
            else:
                parse_dir = os.path.join(unique_dir, pdf_name, "vlm")

            if os.path.exists(parse_dir):
                if return_md:
                    data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
                if return_middle_json:
                    data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
                if return_model_output:
                    if backend.startswith("pipeline"):
                        data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
                    else:
                        data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
                if return_content_list:
                    data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
                if return_images:
                    image_paths = glob(f"{parse_dir}/images/*.jpg")
                    data["images"] = {
                        os.path.basename(
                            image_path
                        ): f"data:image/jpeg;base64,{encode_image(image_path)}"
                        for image_path in image_paths
                    }
        return JSONResponse(
            status_code=200,
            content={
                "backend": backend,
                "version": __version__,
                "results": result_dict
            }
        )
    except Exception as e:
        logger.exception(e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process file: {str(e)}"}
        )
|
169 |
+
|
170 |
+
|
171 |
+
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
    """Command-line entry point that launches the MinerU FastAPI server."""
    # Fold any unknown command-line options into the extra-config dict.
    kwargs.update(arg_parse(ctx))

    # Expose the extra config to request handlers via application state.
    app.state.config = kwargs

    print(f"Start MinerU FastAPI Service: http://{host}:{port}")
    print("The API documentation can be accessed at the following address:")
    print(f"- Swagger UI: http://{host}:{port}/docs")
    print(f"- ReDoc: http://{host}:{port}/redoc")

    uvicorn.run(
        "mineru.cli.fast_api:app",
        host=host,
        port=port,
        reload=reload
    )
|
195 |
+
|
196 |
+
|
197 |
+
# Allow running the API module directly (e.g. `python fast_api.py`).
if __name__ == "__main__":
    main()
|
vendor/mineru/mineru/cli/gradio_app.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
|
3 |
+
import base64
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import time
|
7 |
+
import zipfile
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import click
|
11 |
+
import gradio as gr
|
12 |
+
from gradio_pdf import PDF
|
13 |
+
from loguru import logger
|
14 |
+
|
15 |
+
from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
|
16 |
+
from mineru.utils.cli_parser import arg_parse
|
17 |
+
from mineru.utils.hash_utils import str_sha256
|
18 |
+
|
19 |
+
|
20 |
+
async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
    """Parse one document and return ``(local_md_dir, file_name)``.

    NOTE(review): on failure this returns a single ``None`` while the
    caller unpacks two values (``local_md_dir, file_name = await
    parse_pdf(...)``), which raises TypeError at the call site — confirm
    whether the error path is intended.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Timestamped name keeps repeated runs of the same file separate.
        file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
        pdf_data = read_fn(doc_path)
        if is_ocr:
            parse_method = 'ocr'
        else:
            parse_method = 'auto'

        # VLM backends always use their own fixed parse method.
        if backend.startswith("vlm"):
            parse_method = "vlm"

        local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
        await aio_do_parse(
            output_dir=output_dir,
            pdf_file_names=[file_name],
            pdf_bytes_list=[pdf_data],
            p_lang_list=[language],
            parse_method=parse_method,
            end_page_id=end_page_id,
            formula_enable=formula_enable,
            table_enable=table_enable,
            backend=backend,
            server_url=url,
        )
        return local_md_dir, file_name
    except Exception as e:
        logger.exception(e)
        return None
|
51 |
+
|
52 |
+
|
53 |
+
def compress_directory_to_zip(directory_path, output_zip_path):
    """Zip the contents of *directory_path* into *output_zip_path*.

    Archive entries are stored relative to *directory_path*.

    :param directory_path: directory to compress
    :param output_zip_path: destination ZIP file path
    :return: 0 on success, -1 on failure (the error is logged)
    """
    try:
        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            # Walk the tree and add every file under its relative name.
            for current_root, _dirs, filenames in os.walk(directory_path):
                for filename in filenames:
                    absolute_path = os.path.join(current_root, filename)
                    archive.write(absolute_path, os.path.relpath(absolute_path, directory_path))
        return 0
    except Exception as e:
        logger.exception(e)
        return -1
|
75 |
+
|
76 |
+
|
77 |
+
def image_to_base64(image_path):
    """Return the image file content as a base64-encoded UTF-8 string."""
    raw = Path(image_path).read_bytes()
    return base64.b64encode(raw).decode('utf-8')
|
80 |
+
|
81 |
+
|
82 |
+
def replace_image_with_base64(markdown_text, image_dir_path):
    """Replace every Markdown image reference in *markdown_text*.

    Image paths are resolved relative to *image_dir_path* and read with
    :func:`image_to_base64`.

    NOTE(review): the replacement literal below appears as an empty
    f-string in this copy of the source — likely garbled during
    extraction (upstream MinerU emits a base64 data-URI image tag here).
    Verify against the original repository before relying on it.
    """
    # Matches Markdown image tags: ![alt](path)
    pattern = r'\!\[(?:[^\]]*)\]\(([^)]+)\)'

    # Substitution callback for each matched image link.
    def replace(match):
        relative_path = match.group(1)
        full_path = os.path.join(image_dir_path, relative_path)
        base64_image = image_to_base64(full_path)
        return f''

    # Apply the substitution over the whole document.
    return re.sub(pattern, replace, markdown_text)
|
95 |
+
|
96 |
+
|
97 |
+
async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
    """Convert an uploaded document to markdown and package the outputs.

    :return: tuple ``(md_content, txt_content, archive_zip_path,
        new_pdf_path)`` where *md_content* has images inlined as base64,
        *txt_content* is the raw markdown, and *new_pdf_path* is the
        layout-annotated debug PDF.
    """
    file_path = to_pdf(file_path)
    # Get the recognized markdown directory and output file name.
    local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
    # Zip the whole output directory for download.
    archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
    zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
    if zip_archive_success == 0:
        logger.info('Compression successful')
    else:
        logger.error('Compression failed')
    md_path = os.path.join(local_md_dir, file_name + '.md')
    with open(md_path, 'r', encoding='utf-8') as f:
        txt_content = f.read()
    # Inline images so the markdown preview renders without file access.
    md_content = replace_image_with_base64(txt_content, local_md_dir)
    # Path of the layout-annotated PDF used for the preview pane.
    new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')

    return md_content, txt_content, archive_zip_path, new_pdf_path
|
115 |
+
|
116 |
+
|
117 |
+
# Delimiters the Gradio Markdown component uses to render LaTeX math.
latex_delimiters = [
    {'left': '$$', 'right': '$$', 'display': True},
    {'left': '$', 'right': '$', 'display': False},
    {'left': '\\(', 'right': '\\)', 'display': False},
    {'left': '\\[', 'right': '\\]', 'display': True},
]

# Page-header HTML shipped with the package (../resources/header.html),
# read once at import time.
header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
with open(header_path, 'r') as header_file:
    header = header_file.read()


# OCR language groups; `all_lang` below is what the UI dropdown offers.
latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr', # noqa: E126
    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
    'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
    'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
    'sa', 'bgc'
]
other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
# Script-group aliases offered instead of listing every language variant.
add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']

# all_lang = ['', 'auto']
all_lang = []
# all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
all_lang.extend([*other_lang, *add_lang])
|
152 |
+
|
153 |
+
|
154 |
+
def safe_stem(file_path):
    """Return the file stem sanitized for use as a filename component.

    Word characters and dots are kept; every other character is replaced
    with an underscore.
    """
    raw_stem = Path(file_path).stem
    return re.sub(r'[^\w.]', '_', raw_stem)
|
158 |
+
|
159 |
+
|
160 |
+
def to_pdf(file_path):
    """Normalize an input document to a PDF file on disk.

    Images are converted to PDF via ``read_fn``; PDFs are rewritten
    as-is. The result is written next to the input file under a
    sanitized name.

    :return: path of the written PDF, or ``None`` when no file was given.
    """
    if file_path is None:
        return None

    pdf_bytes = read_fn(file_path)

    # Sanitized name, placed in the same directory as the input.
    target_name = f'{safe_stem(file_path)}.pdf'
    target_path = os.path.join(os.path.dirname(file_path), target_name)

    with open(target_path, 'wb') as out_file:
        out_file.write(pdf_bytes)

    return target_path
|
178 |
+
|
179 |
+
|
180 |
+
# Show/hide backend-specific widgets when the backend selection changes.
def update_interface(backend_choice):
    """Return visibility updates for (server-url input, pipeline options)."""
    visibility = {
        "vlm-transformers": (False, False),
        "vlm-sglang-engine": (False, False),
        "vlm-sglang-client": (True, False),
        "pipeline": (False, True),
    }
    if backend_choice not in visibility:
        # Unknown backend: no updates (mirrors the original fall-through).
        return None
    url_visible, pipeline_visible = visibility[backend_choice]
    return gr.update(visible=url_visible), gr.update(visible=pipeline_visible)
|
190 |
+
|
191 |
+
|
192 |
+
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
    '--enable-example',
    'example_enable',
    type=bool,
    help="Enable example files for input."
    "The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
    default=True,
)
@click.option(
    '--enable-sglang-engine',
    'sglang_engine_enable',
    type=bool,
    help="Enable SgLang engine backend for faster processing.",
    default=False,
)
@click.option(
    '--enable-api',
    'api_enable',
    type=bool,
    help="Enable gradio API for serving the application.",
    default=True,
)
@click.option(
    '--max-convert-pages',
    'max_convert_pages',
    type=int,
    help="Set the maximum number of pages to convert from PDF to Markdown.",
    default=1000,
)
@click.option(
    '--server-name',
    'server_name',
    type=str,
    help="Set the server name for the Gradio app.",
    default=None,
)
@click.option(
    '--server-port',
    'server_port',
    type=int,
    help="Set the server port for the Gradio app.",
    default=None,
)
def main(ctx,
        example_enable, sglang_engine_enable, api_enable, max_convert_pages,
        server_name, server_port, **kwargs
):
    """Launch the MinerU Gradio web UI.

    Builds the upload/options/preview layout, wires the conversion
    callbacks (`to_pdf`, `to_markdown`, `update_interface`), optionally
    pre-initializes the SgLang engine backend, and starts the server.
    """
    # Fold any extra CLI args (allowed by ignore_unknown_options) into kwargs.
    kwargs.update(arg_parse(ctx))

    if sglang_engine_enable:
        try:
            print("Start init SgLang engine...")
            from mineru.backend.vlm.vlm_analyze import ModelSingleton
            model_singleton = ModelSingleton()
            # NOTE(review): `predictor` is unused afterwards — presumably the
            # call is made for its side effect of warming up the singleton.
            predictor = model_singleton.get_model(
                "sglang-engine",
                None,
                None,
                **kwargs
            )
            print("SgLang engine init successfully.")
        except Exception as e:
            # Engine init failure is non-fatal: the UI still starts with the
            # remaining backends.
            logger.exception(e)

    suffixes = pdf_suffixes + image_suffixes
    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            # Left panel: input file, conversion options, PDF preview.
            with gr.Column(variant='panel', scale=5):
                with gr.Row():
                    input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
                with gr.Row():
                    # Default slider position is half of the allowed maximum.
                    max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages/2), step=1, label='Max convert pages')
                with gr.Row():
                    # Backend choices depend on whether the in-process SgLang
                    # engine was enabled at startup.
                    if sglang_engine_enable:
                        drop_list = ["pipeline", "vlm-sglang-engine"]
                        preferred_option = "vlm-sglang-engine"
                    else:
                        drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
                        preferred_option = "pipeline"
                    backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
                # Only shown for the sglang-client backend (see update_interface).
                with gr.Row(visible=False) as client_options:
                    url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
                with gr.Row(equal_height=True):
                    with gr.Column():
                        gr.Markdown("**Recognition Options:**")
                        formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
                        table_enable = gr.Checkbox(label='Enable table recognition', value=True)
                    # Only shown for the pipeline backend (see update_interface).
                    with gr.Column(visible=False) as ocr_options:
                        language = gr.Dropdown(all_lang, label='Language', value='ch')
                        is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
                with gr.Row():
                    change_bu = gr.Button('Convert')
                    clear_bu = gr.ClearButton(value='Clear')
                pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
                if example_enable:
                    example_root = os.path.join(os.getcwd(), 'examples')
                    if os.path.exists(example_root):
                        with gr.Accordion('Examples:'):
                            gr.Examples(
                                examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                          _.endswith(tuple(suffixes))],
                                inputs=input_file
                            )

            # Right panel: conversion result download and markdown views.
            with gr.Column(variant='panel', scale=5):
                output_file = gr.File(label='convert result', interactive=False)
                with gr.Tabs():
                    with gr.Tab('Markdown rendering'):
                        md = gr.Markdown(label='Markdown rendering', height=1100, show_copy_button=True,
                                         latex_delimiters=latex_delimiters,
                                         line_breaks=True)
                    with gr.Tab('Markdown text'):
                        md_text = gr.TextArea(lines=45, show_copy_button=True)

        # Wire up event handlers.
        backend.change(
            fn=update_interface,
            inputs=[backend],
            outputs=[client_options, ocr_options],
            api_name=False
        )
        # Trigger one interface update on page load so visibility matches the
        # initially selected backend.
        demo.load(
            fn=update_interface,
            inputs=[backend],
            outputs=[client_options, ocr_options],
            api_name=False
        )
        clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])

        # api_name=None exposes the endpoint under its default name;
        # api_name=False hides it from the Gradio API entirely.
        if api_enable:
            api_name = None
        else:
            api_name = False

        input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
        change_bu.click(
            fn=to_markdown,
            inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
            outputs=[md, md_text, output_file, pdf_show],
            api_name=api_name
        )

    demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable)


if __name__ == '__main__':
    main()
|
vendor/mineru/mineru/cli/models_download.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import click
|
5 |
+
import requests
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
from mineru.utils.enum_class import ModelPath
|
9 |
+
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
10 |
+
|
11 |
+
|
12 |
+
def download_json(url, timeout=60):
    """Download a JSON document and return it parsed.

    Args:
        url: address of the JSON resource.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests.get`` can block indefinitely on a stalled
            connection; the default keeps existing call sites working.

    Returns:
        The parsed JSON payload (typically a dict).

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the request exceeds *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()
|
17 |
+
|
18 |
+
|
19 |
+
def download_and_modify_json(url, local_filename, modifications):
    """Fetch (or reuse) a JSON config, apply *modifications*, and save it.

    A local copy at *local_filename* is reused when its ``config_version``
    is at least ``1.3.0``; otherwise a fresh template is downloaded from
    *url*. Only keys already present in the document are changed: dict
    values are merged, everything else is replaced; keys missing from the
    document are silently ignored (so unknown options are never injected
    into the template).

    Args:
        url: source of the template JSON.
        local_filename: path of the config file to update in place.
        modifications: mapping of top-level keys to new values.
    """
    if os.path.exists(local_filename):
        # Fix: the previous json.load(open(...)) leaked the file handle and
        # relied on the platform default encoding.
        with open(local_filename, encoding='utf-8') as f:
            data = json.load(f)
        config_version = data.get('config_version', '0.0.0')
        # NOTE: lexicographic version comparison — adequate for the
        # single-digit versions used here, but fragile in general.
        if config_version < '1.3.0':
            data = download_json(url)
    else:
        data = download_json(url)

    # Apply the requested modifications to the (possibly fresh) template.
    for key, value in modifications.items():
        if key in data:
            if isinstance(data[key], dict):
                # Merge into the existing sub-dict instead of replacing it.
                data[key].update(value)
            else:
                data[key] = value

    # Persist the modified config.
    with open(local_filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
|
42 |
+
|
43 |
+
|
44 |
+
def configure_model(model_dir, model_type):
    """Register *model_dir* as the model root for *model_type* in the
    user-level mineru config file (``~/mineru.json`` by default)."""
    template_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/mineru.template.json'
    config_file = os.path.join(
        os.path.expanduser('~'),
        os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json'),
    )

    download_and_modify_json(
        template_url,
        config_file,
        {'models-dir': {f'{model_type}': model_dir}},
    )
    logger.info(f'The configuration file has been successfully configured, the path is: {config_file}')
|
59 |
+
|
60 |
+
|
61 |
+
def download_pipeline_models():
    """Download every model required by the pipeline backend and register
    the final download root in the user config."""
    required_models = (
        ModelPath.doclayout_yolo,
        ModelPath.yolo_v8_mfd,
        ModelPath.unimernet_small,
        ModelPath.pytorch_paddle,
        ModelPath.layout_reader,
        ModelPath.slanet_plus,
    )
    download_finish_path = ""
    for model_path in required_models:
        logger.info(f"Downloading model: {model_path}")
        # All models share one root; the last returned root is recorded.
        download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
    logger.info(f"Pipeline models downloaded successfully to: {download_finish_path}")
    configure_model(download_finish_path, "pipeline")
|
77 |
+
|
78 |
+
|
79 |
+
def download_vlm_models():
    """Download the VLM model repository and register its root in the
    user config."""
    # "/" downloads the whole vlm repo rather than a single sub-path.
    download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
    logger.info(f"VLM models downloaded successfully to: {download_finish_path}")
    configure_model(download_finish_path, "vlm")
|
84 |
+
|
85 |
+
|
86 |
+
@click.command()
@click.option(
    '-s',
    '--source',
    'model_source',
    type=click.Choice(['huggingface', 'modelscope']),
    help="""
    The source of the model repository.
    """,
    default=None,
)
@click.option(
    '-m',
    '--model_type',
    'model_type',
    type=click.Choice(['pipeline', 'vlm', 'all']),
    help="""
    The type of the model to download.
    """,
    default=None,
)
def download_models(model_source, model_type):
    """Download MinerU model files.

    Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
    """
    # Prompt interactively for the download source when not given on the CLI.
    if model_source is None:
        model_source = click.prompt(
            "Please select the model download source: ",
            type=click.Choice(['huggingface', 'modelscope']),
            default='huggingface'
        )

    # Respect an already-set MINERU_MODEL_SOURCE; only set it when absent.
    if os.getenv('MINERU_MODEL_SOURCE', None) is None:
        os.environ['MINERU_MODEL_SOURCE'] = model_source

    # Prompt interactively for the model type when not given on the CLI.
    if model_type is None:
        model_type = click.prompt(
            "Please select the model type to download: ",
            type=click.Choice(['pipeline', 'vlm', 'all']),
            default='all'
        )

    logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")

    try:
        if model_type == 'pipeline':
            download_pipeline_models()
        elif model_type == 'vlm':
            download_vlm_models()
        elif model_type == 'all':
            download_pipeline_models()
            download_vlm_models()
        else:
            # Unreachable via click.Choice, but kept as a defensive guard.
            click.echo(f"Unsupported model type: {model_type}", err=True)
            sys.exit(1)

    except Exception as e:
        logger.exception(f"An error occurred while downloading models: {str(e)}")
        sys.exit(1)

if __name__ == '__main__':
    download_models()
|
vendor/mineru/mineru/cli/vlm_sglang_server.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Thin launcher module: delegates straight to the sglang model server.
from ..model.vlm_sglang_model.server import main

if __name__ == "__main__":
    main()
|
vendor/mineru/mineru/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/data/data_reader_writer/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import DataReader, DataWriter
|
2 |
+
from .dummy import DummyDataWriter
|
3 |
+
from .filebase import FileBasedDataReader, FileBasedDataWriter
|
4 |
+
from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
|
5 |
+
from .s3 import S3DataReader, S3DataWriter
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
"DataReader",
|
9 |
+
"DataWriter",
|
10 |
+
"FileBasedDataReader",
|
11 |
+
"FileBasedDataWriter",
|
12 |
+
"S3DataReader",
|
13 |
+
"S3DataWriter",
|
14 |
+
"MultiBucketS3DataReader",
|
15 |
+
"MultiBucketS3DataWriter",
|
16 |
+
"DummyDataWriter",
|
17 |
+
]
|
vendor/mineru/mineru/data/data_reader_writer/base.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
|
5 |
+
class DataReader(ABC):
    """Abstract source of bytes addressed by a path string."""

    def read(self, path: str) -> bytes:
        """Return the full content stored at *path*.

        Args:
            path (str): file path to read

        Returns:
            bytes: the content of the file
        """
        return self.read_at(path)

    @abstractmethod
    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Return up to *limit* bytes of *path*, starting at byte *offset*.

        Args:
            path (str): the file path
            offset (int, optional): number of leading bytes to skip. Defaults to 0.
            limit (int, optional): number of bytes to read; -1 means read to the end. Defaults to -1.

        Returns:
            bytes: the content of the file
        """
|
31 |
+
|
32 |
+
|
33 |
+
class DataWriter(ABC):
    """Abstract sink of bytes addressed by a path string."""

    @abstractmethod
    def write(self, path: str, data: bytes) -> None:
        """Persist *data* at *path*.

        Args:
            path (str): the target file where to write
            data (bytes): the data want to write
        """

    def write_string(self, path: str, data: str) -> None:
        """Encode *data* and delegate to :meth:`write`.

        Tries utf-8 first, then ascii, replacing unencodable characters;
        the first encoding that succeeds is written.

        Args:
            path (str): the target file where to write
            data (str): the data want to write
        """
        for encoding in ('utf-8', 'ascii'):
            try:
                encoded = data.encode(encoding=encoding, errors='replace')
            except:  # noqa
                continue
            self.write(path, encoded)
            return
|
vendor/mineru/mineru/data/data_reader_writer/dummy.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import DataWriter
|
2 |
+
|
3 |
+
|
4 |
+
class DummyDataWriter(DataWriter):
|
5 |
+
def write(self, path: str, data: bytes) -> None:
|
6 |
+
"""Dummy write method that does nothing."""
|
7 |
+
pass
|
8 |
+
|
9 |
+
def write_string(self, path: str, data: str) -> None:
|
10 |
+
"""Dummy write_string method that does nothing."""
|
11 |
+
pass
|
vendor/mineru/mineru/data/data_reader_writer/filebase.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from .base import DataReader, DataWriter
|
4 |
+
|
5 |
+
|
6 |
+
class FileBasedDataReader(DataReader):
    """Reader backed by the local filesystem."""

    def __init__(self, parent_dir: str = ''):
        """Remember *parent_dir*, prepended to relative paths in reads.

        Args:
            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
        """
        self._parent_dir = parent_dir

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read *limit* bytes of *path* starting at byte *offset*.

        Args:
            path (str): the path of file; relative paths are joined with parent_dir.
            offset (int, optional): number of leading bytes to skip. Defaults to 0.
            limit (int, optional): number of bytes to read; -1 means to EOF. Defaults to -1.

        Returns:
            bytes: the content of file
        """
        if os.path.isabs(path) or not self._parent_dir:
            full_path = path
        else:
            full_path = os.path.join(self._parent_dir, path)

        with open(full_path, 'rb') as fh:
            fh.seek(offset)
            return fh.read() if limit == -1 else fh.read(limit)
|
36 |
+
|
37 |
+
|
38 |
+
class FileBasedDataWriter(DataWriter):
    """Writer backed by the local filesystem."""

    def __init__(self, parent_dir: str = '') -> None:
        """Remember *parent_dir*, prepended to relative paths in writes.

        Args:
            parent_dir (str, optional): the parent directory that may be used within methods. Defaults to ''.
        """
        self._parent_dir = parent_dir

    def write(self, path: str, data: bytes) -> None:
        """Write *data* to *path*, creating parent directories as needed.

        Args:
            path (str): the path of file; relative paths are joined with parent_dir.
            data (bytes): the data want to write
        """
        if os.path.isabs(path) or not self._parent_dir:
            full_path = path
        else:
            full_path = os.path.join(self._parent_dir, path)

        target_dir = os.path.dirname(full_path)
        if target_dir and not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)

        with open(full_path, 'wb') as fh:
            fh.write(data)
|
vendor/mineru/mineru/data/data_reader_writer/multi_bucket_s3.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from ..utils.exceptions import InvalidConfig, InvalidParams
|
3 |
+
from .base import DataReader, DataWriter
|
4 |
+
from ..io.s3 import S3Reader, S3Writer
|
5 |
+
from ..utils.schemas import S3Config
|
6 |
+
from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args
|
7 |
+
|
8 |
+
|
9 |
+
class MultiS3Mixin:
    """Shared initialization/validation for multi-bucket S3 readers/writers."""

    def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
        """Initialized with multiple s3 configs.

        Args:
            default_prefix (str): default prefix of relative paths, e.g. {some_bucket}/{some_prefix} or {some_bucket}
            s3_configs (list[S3Config]): list of s3 configs; bucket_name must be unique in the list.

        Raises:
            InvalidConfig: default bucket config not in s3_configs.
            InvalidConfig: bucket name not unique in s3_configs.
            InvalidConfig: default bucket must be provided.
        """
        if len(default_prefix) == 0:
            raise InvalidConfig('default_prefix must be provided')

        # First path segment is the bucket; the rest is the key prefix.
        parts = default_prefix.strip('/').split('/')
        self.default_bucket = parts[0]
        self.default_prefix = '/'.join(parts[1:])

        if all(conf.bucket_name != self.default_bucket for conf in s3_configs):
            raise InvalidConfig(
                f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
            )

        if len({conf.bucket_name for conf in s3_configs}) != len(s3_configs):
            raise InvalidConfig(
                f'the bucket_name in s3_configs: {s3_configs} must be unique'
            )

        self.s3_configs = s3_configs
        # Per-bucket client cache, filled lazily by subclasses.
        self._s3_clients_h: dict = {}
|
48 |
+
|
49 |
+
|
50 |
+
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
    def read(self, path: str) -> bytes:
        """Read the path from s3, selecting a different bucket client for each
        request based on the bucket; range reads are also supported.

        Args:
            path (str): the s3 path of file, in the format s3://bucket_name/path?offset,limit.
            for example: s3://bucket_name/path?0,100.

        Returns:
            bytes: the content of s3 file.
        """
        # "?offset,limit" suffix (non-official S3 syntax) selects a byte range.
        may_range_params = parse_s3_range_params(path)
        if may_range_params is None or 2 != len(may_range_params):
            byte_start, byte_len = 0, -1
        else:
            byte_start, byte_len = int(may_range_params[0]), int(may_range_params[1])
        path = remove_non_official_s3_args(path)
        return self.read_at(path, byte_start, byte_len)

    def __get_s3_client(self, bucket_name: str):
        # Lazily construct and cache one S3Reader per bucket.
        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
            raise InvalidParams(
                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
            )
        if bucket_name not in self._s3_clients_h:
            conf = next(
                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
            )
            self._s3_clients_h[bucket_name] = S3Reader(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
        return self._s3_clients_h[bucket_name]

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read the file with offset and limit, selecting a different bucket
        client for each request based on the bucket.

        Args:
            path (str): the file path.
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the number of bytes want to read. Defaults to -1 which means infinite.

        Returns:
            bytes: the file content.
        """
        if path.startswith('s3://'):
            bucket_name, path = parse_s3path(path)
            s3_reader = self.__get_s3_client(bucket_name)
        else:
            # Relative path: use the default bucket and prepend the prefix.
            s3_reader = self.__get_s3_client(self.default_bucket)
            if self.default_prefix:
                path = self.default_prefix + '/' + path
        return s3_reader.read_at(path, offset, limit)
|
108 |
+
|
109 |
+
|
110 |
+
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
    def __get_s3_client(self, bucket_name: str):
        # Lazily construct and cache one S3Writer per bucket.
        if bucket_name not in set([conf.bucket_name for conf in self.s3_configs]):
            raise InvalidParams(
                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
            )
        if bucket_name not in self._s3_clients_h:
            conf = next(
                filter(lambda conf: conf.bucket_name == bucket_name, self.s3_configs)
            )
            self._s3_clients_h[bucket_name] = S3Writer(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
        return self._s3_clients_h[bucket_name]

    def write(self, path: str, data: bytes) -> None:
        """Write file with data, selecting a different bucket client for each
        request based on the bucket.

        Args:
            path (str): the path of file; relative paths are written under the default bucket/prefix.
            data (bytes): the data want to write.
        """
        if path.startswith('s3://'):
            bucket_name, path = parse_s3path(path)
            s3_writer = self.__get_s3_client(bucket_name)
        else:
            # Relative path: use the default bucket and prepend the prefix.
            s3_writer = self.__get_s3_client(self.default_bucket)
            if self.default_prefix:
                path = self.default_prefix + '/' + path
        return s3_writer.write(path, data)
|
vendor/mineru/mineru/data/data_reader_writer/s3.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
|
2 |
+
from ..utils.schemas import S3Config
|
3 |
+
|
4 |
+
|
5 |
+
class S3DataReader(MultiBucketS3DataReader):
    """Single-bucket convenience wrapper around MultiBucketS3DataReader."""

    def __init__(
        self,
        default_prefix_without_bucket: str,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 reader client.

        Args:
            default_prefix_without_bucket: prefix that does not contain the bucket
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        single_config = S3Config(
            bucket_name=bucket,
            access_key=ak,
            secret_key=sk,
            endpoint_url=endpoint_url,
            addressing_style=addressing_style,
        )
        # Delegate to the multi-bucket implementation with a one-entry config.
        super().__init__(f'{bucket}/{default_prefix_without_bucket}', [single_config])
|
38 |
+
|
39 |
+
|
40 |
+
class S3DataWriter(MultiBucketS3DataWriter):
    """Single-bucket convenience wrapper around MultiBucketS3DataWriter."""

    def __init__(
        self,
        default_prefix_without_bucket: str,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 writer client.

        Args:
            default_prefix_without_bucket: prefix that does not contain the bucket
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        single_config = S3Config(
            bucket_name=bucket,
            access_key=ak,
            secret_key=sk,
            endpoint_url=endpoint_url,
            addressing_style=addressing_style,
        )
        # Delegate to the multi-bucket implementation with a one-entry config.
        super().__init__(f'{bucket}/{default_prefix_without_bucket}', [single_config])
|
vendor/mineru/mineru/data/io/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .base import IOReader, IOWriter
|
3 |
+
from .http import HttpReader, HttpWriter
|
4 |
+
from .s3 import S3Reader, S3Writer
|
5 |
+
|
6 |
+
__all__ = ['IOReader', 'IOWriter', 'HttpReader', 'HttpWriter', 'S3Reader', 'S3Writer']
|
vendor/mineru/mineru/data/io/base.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class IOReader(ABC):
    """Low-level byte-source interface (implemented by HTTP/S3 backends)."""

    @abstractmethod
    def read(self, path: str) -> bytes:
        """Read the file.

        Args:
            path (str): file path to read

        Returns:
            bytes: the content of the file
        """
        pass

    @abstractmethod
    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read at offset and limit.

        Args:
            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the length of bytes want to read. Defaults to -1.

        Returns:
            bytes: the content of file
        """
        pass
|
30 |
+
|
31 |
+
|
32 |
+
class IOWriter(ABC):
    """Low-level byte-sink interface (implemented by HTTP/S3 backends)."""

    @abstractmethod
    def write(self, path: str, data: bytes) -> None:
        """Write file with data.

        Args:
            path (str): the path of file, if the path is relative path, it will be joined with parent_dir.
            data (bytes): the data want to write
        """
        pass
|
vendor/mineru/mineru/data/io/http.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import io
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
from .base import IOReader, IOWriter
|
7 |
+
|
8 |
+
|
9 |
+
class HttpReader(IOReader):
    """Reader that fetches content over HTTP(S)."""

    def read(self, url: str, timeout: float = 60) -> bytes:
        """Download *url* and return the response body.

        Args:
            url (str): the URL to fetch
            timeout (float, optional): seconds to wait before aborting.
                Fix: previously no timeout was set, so a stalled server
                could hang the caller forever.

        Returns:
            bytes: the response body

        Raises:
            requests.HTTPError: on a non-2xx response. Fix: previously an
                error page body was silently returned as file content.
        """
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.content

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Not implemented: range reads are not supported for HTTP here."""
        raise NotImplementedError
|
25 |
+
|
26 |
+
|
27 |
+
class HttpWriter(IOWriter):
    """Writer that uploads content via an HTTP(S) multipart POST."""

    def write(self, url: str, data: bytes, timeout: float = 60) -> None:
        """Upload *data* to *url* as a multipart file field named 'file'.

        Args:
            url (str): the upload endpoint
            data (bytes): the data want to write
            timeout (float, optional): seconds to wait before aborting.

        Raises:
            IOError: when the server responds with a non-2xx status.
                Fix: the previous `assert` on the status code is stripped
                under `python -O`, silently disabling the check.
        """
        files = {'file': io.BytesIO(data)}
        response = requests.post(url, files=files, timeout=timeout)
        if not (200 <= response.status_code < 300):
            raise IOError(
                f'upload to {url} failed with status {response.status_code}'
            )
|
vendor/mineru/mineru/data/io/s3.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import boto3
|
2 |
+
from botocore.config import Config
|
3 |
+
|
4 |
+
from ..io.base import IOReader, IOWriter
|
5 |
+
|
6 |
+
|
7 |
+
class S3Reader(IOReader):
    def __init__(
        self,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 reader client.

        Args:
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        self._bucket = bucket
        self._ak = ak
        self._sk = sk
        # Retries use botocore's 'standard' mode with up to 5 attempts.
        self._s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=Config(
                s3={'addressing_style': addressing_style},
                retries={'max_attempts': 5, 'mode': 'standard'},
            ),
        )

    def read(self, key: str) -> bytes:
        """Read the whole object.

        Args:
            key (str): object key to read

        Returns:
            bytes: the content of the object
        """
        return self.read_at(key)

    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read at offset and limit via an HTTP Range request.

        Args:
            key (str): the object key
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the length of bytes want to read. Defaults to -1 (to end of object).

        Returns:
            bytes: the content of the object
        """
        if limit > -1:
            # HTTP Range is inclusive on both ends, hence the -1.
            range_header = f'bytes={offset}-{offset+limit-1}'
            res = self._s3_client.get_object(
                Bucket=self._bucket, Key=key, Range=range_header
            )
        else:
            # Open-ended range: from offset to the end of the object.
            res = self._s3_client.get_object(
                Bucket=self._bucket, Key=key, Range=f'bytes={offset}-'
            )
        return res['Body'].read()
|
72 |
+
|
73 |
+
|
74 |
+
class S3Writer(IOWriter):
    def __init__(
        self,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 writer client.

        Args:
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid options here are 'path' and 'virtual'
            refer to https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        self._bucket = bucket
        self._ak = ak
        self._sk = sk
        # Retry transient failures (up to 5 attempts, 'standard' mode).
        client_config = Config(
            s3={'addressing_style': addressing_style},
            retries={'max_attempts': 5, 'mode': 'standard'},
        )
        self._s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=client_config,
        )

    def write(self, key: str, data: bytes):
        """Upload ``data`` as the object stored under ``key``.

        Args:
            key (str): the object key to write to
            data (bytes): the payload to upload
        """
        self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
|
vendor/mineru/mineru/data/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/data/utils/exceptions.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
|
3 |
+
class FileNotExisted(Exception):
    """Raised when a file expected on disk does not exist."""

    def __init__(self, path):
        # Forward to Exception.__init__ so e.args and pickling work correctly.
        super().__init__(path)
        self.path = path

    def __str__(self):
        return f'File {self.path} does not exist.'
|
10 |
+
|
11 |
+
|
12 |
+
class InvalidConfig(Exception):
    """Raised when a configuration value or file is malformed."""

    def __init__(self, msg):
        # Forward to Exception.__init__ so e.args and pickling work correctly.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'Invalid config: {self.msg}'
|
18 |
+
|
19 |
+
|
20 |
+
class InvalidParams(Exception):
    """Raised when a caller supplies invalid parameters."""

    def __init__(self, msg):
        # Forward to Exception.__init__ so e.args and pickling work correctly.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'Invalid params: {self.msg}'
|
26 |
+
|
27 |
+
|
28 |
+
class EmptyData(Exception):
    """Raised when required input data is empty."""

    def __init__(self, msg):
        # Forward to Exception.__init__ so e.args and pickling work correctly.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'Empty data: {self.msg}'
|
34 |
+
|
35 |
+
class CUDA_NOT_AVAILABLE(Exception):
    """Raised when CUDA is required but not available.

    NOTE: the non-PEP8 name is part of the public interface and is kept
    for backward compatibility with existing callers.
    """

    def __init__(self, msg):
        # Forward to Exception.__init__ so e.args and pickling work correctly.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'CUDA not available: {self.msg}'
|
vendor/mineru/mineru/data/utils/path_utils.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
def remove_non_official_s3_args(s3path):
    """Strip any non-official query suffix from an s3 path.

    example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
    """
    # Everything before the first '?' is the official s3 path.
    path_without_query, _, _ = s3path.partition("?")
    return path_without_query
|
10 |
+
|
11 |
+
def parse_s3path(s3path: str):
    """Split an s3 path into ``(bucket_name, key)``.

    Args:
        s3path (str): path of the form 's3://bucket/key' or 's3a://bucket/key';
            a trailing '?...' query suffix (e.g. '?bytes=0,81350') is ignored.

    Returns:
        tuple[str, str]: the bucket name and the object key

    Raises:
        ValueError: if the path is not a valid s3/s3a URL or has no key part.
    """
    # Drop the non-official '?...' suffix, then trim surrounding whitespace.
    s3path = s3path.split("?")[0].strip()
    if s3path.startswith(('s3://', 's3a://')):
        _, path = s3path.split('://', 1)
        if '/' not in path:
            # e.g. 's3://bucket' with no key part — report a clear error
            # instead of an unpacking ValueError from split().
            raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")
        bucket_name, key = path.split('/', 1)
        return bucket_name, key
    elif s3path.startswith('/'):
        raise ValueError("The provided path starts with '/'. This does not conform to a valid S3 path format.")
    else:
        raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")
|
24 |
+
|
25 |
+
|
26 |
+
def parse_s3_range_params(s3path: str):
    """Extract the byte-range suffix of an s3 path, if present.

    example: s3://abc/xxxx.json?bytes=0,81350 ==> ['0', '81350']
    Returns None when the path carries no '?bytes=' suffix.
    """
    _, *range_parts = s3path.split("?bytes=")
    if not range_parts:
        return None
    # Only the segment immediately after the first '?bytes=' is the range.
    return range_parts[0].split(",")
|
vendor/mineru/mineru/data/utils/schemas.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
2 |
+
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
|
5 |
+
|
6 |
+
class S3Config(BaseModel):
    """Validated connection settings for an S3-compatible object store.

    All string fields must be non-empty (min_length=1); pydantic raises a
    ValidationError on violation.
    """
    bucket_name: str = Field(description='s3 bucket name', min_length=1)
    access_key: str = Field(description='s3 access key', min_length=1)
    secret_key: str = Field(description='s3 secret key', min_length=1)
    endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
    # presumably passed through as boto3's addressing_style
    # ('auto' / 'path' / 'virtual') — confirm against the S3 client code
    addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
|
14 |
+
|
15 |
+
|
16 |
+
class PageInfo(BaseModel):
    """The width and height of a single page.

    NOTE(review): units are not specified here (likely PDF points) —
    confirm against the producers of this model.
    """
    w: float = Field(description='the width of page')
    h: float = Field(description='the height of page')
|
vendor/mineru/mineru/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|
vendor/mineru/mineru/model/layout/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Copyright (c) Opendatalab. All rights reserved.
|