marcosremar2 commited on
Commit
550ec39
·
1 Parent(s): 57d51b7

Add PDF conversion API endpoints

Browse files

- Add PDF upload endpoint at /api/convert
- Add status checking endpoint at /api/status/{task_id}
- Add download endpoint at /api/download/{task_id}
- Add basic PDF processing dependency (PyMuPDF)
- Prepare structure for MinerU integration
- Update API info with new endpoints

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. app.py +118 -6
  2. config/magic-pdf.json +9 -0
  3. pdf_converter_mineru.py +272 -0
  4. requirements.txt +4 -1
  5. vendor/mineru/mineru/__init__.py +1 -0
  6. vendor/mineru/mineru/backend/__init__.py +1 -0
  7. vendor/mineru/mineru/backend/pipeline/__init__.py +1 -0
  8. vendor/mineru/mineru/backend/pipeline/batch_analyze.py +331 -0
  9. vendor/mineru/mineru/backend/pipeline/model_init.py +182 -0
  10. vendor/mineru/mineru/backend/pipeline/model_json_to_middle_json.py +249 -0
  11. vendor/mineru/mineru/backend/pipeline/model_list.py +6 -0
  12. vendor/mineru/mineru/backend/pipeline/para_split.py +381 -0
  13. vendor/mineru/mineru/backend/pipeline/pipeline_analyze.py +198 -0
  14. vendor/mineru/mineru/backend/pipeline/pipeline_magic_model.py +501 -0
  15. vendor/mineru/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +298 -0
  16. vendor/mineru/mineru/backend/vlm/__init__.py +1 -0
  17. vendor/mineru/mineru/backend/vlm/base_predictor.py +186 -0
  18. vendor/mineru/mineru/backend/vlm/hf_predictor.py +211 -0
  19. vendor/mineru/mineru/backend/vlm/predictor.py +111 -0
  20. vendor/mineru/mineru/backend/vlm/sglang_client_predictor.py +443 -0
  21. vendor/mineru/mineru/backend/vlm/sglang_engine_predictor.py +246 -0
  22. vendor/mineru/mineru/backend/vlm/token_to_middle_json.py +113 -0
  23. vendor/mineru/mineru/backend/vlm/utils.py +40 -0
  24. vendor/mineru/mineru/backend/vlm/vlm_analyze.py +93 -0
  25. vendor/mineru/mineru/backend/vlm/vlm_magic_model.py +521 -0
  26. vendor/mineru/mineru/backend/vlm/vlm_middle_json_mkcontent.py +221 -0
  27. vendor/mineru/mineru/cli/__init__.py +1 -0
  28. vendor/mineru/mineru/cli/client.py +212 -0
  29. vendor/mineru/mineru/cli/common.py +403 -0
  30. vendor/mineru/mineru/cli/fast_api.py +198 -0
  31. vendor/mineru/mineru/cli/gradio_app.py +343 -0
  32. vendor/mineru/mineru/cli/models_download.py +150 -0
  33. vendor/mineru/mineru/cli/vlm_sglang_server.py +4 -0
  34. vendor/mineru/mineru/data/__init__.py +1 -0
  35. vendor/mineru/mineru/data/data_reader_writer/__init__.py +17 -0
  36. vendor/mineru/mineru/data/data_reader_writer/base.py +63 -0
  37. vendor/mineru/mineru/data/data_reader_writer/dummy.py +11 -0
  38. vendor/mineru/mineru/data/data_reader_writer/filebase.py +62 -0
  39. vendor/mineru/mineru/data/data_reader_writer/multi_bucket_s3.py +144 -0
  40. vendor/mineru/mineru/data/data_reader_writer/s3.py +72 -0
  41. vendor/mineru/mineru/data/io/__init__.py +6 -0
  42. vendor/mineru/mineru/data/io/base.py +42 -0
  43. vendor/mineru/mineru/data/io/http.py +37 -0
  44. vendor/mineru/mineru/data/io/s3.py +114 -0
  45. vendor/mineru/mineru/data/utils/__init__.py +1 -0
  46. vendor/mineru/mineru/data/utils/exceptions.py +40 -0
  47. vendor/mineru/mineru/data/utils/path_utils.py +33 -0
  48. vendor/mineru/mineru/data/utils/schemas.py +20 -0
  49. vendor/mineru/mineru/model/__init__.py +1 -0
  50. vendor/mineru/mineru/model/layout/__init__.py +1 -0
app.py CHANGED
@@ -1,8 +1,14 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import HTMLResponse
3
  import os
 
 
 
 
 
 
4
 
5
- app = FastAPI()
6
 
7
  @app.get("/")
8
  async def root():
@@ -59,12 +65,118 @@ async def api_info():
59
  """API information endpoint"""
60
  return {
61
  "name": "PDF to Markdown Converter API",
62
- "version": "0.1.0",
63
  "endpoints": {
64
  "/": "Main endpoint",
65
  "/health": "Health check",
66
  "/test": "Test HTML page",
67
  "/docs": "FastAPI automatic documentation",
68
- "/api/info": "This endpoint"
 
 
 
69
  }
70
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
2
+ from fastapi.responses import HTMLResponse, FileResponse
3
  import os
4
+ import tempfile
5
+ import shutil
6
+ from pathlib import Path
7
+ import asyncio
8
+ from typing import Dict, Optional
9
+ import uuid
10
 
11
+ app = FastAPI(title="MinerU PDF Converter", version="0.2.0")
12
 
13
  @app.get("/")
14
  async def root():
 
65
  """API information endpoint"""
66
  return {
67
  "name": "PDF to Markdown Converter API",
68
+ "version": "0.2.0",
69
  "endpoints": {
70
  "/": "Main endpoint",
71
  "/health": "Health check",
72
  "/test": "Test HTML page",
73
  "/docs": "FastAPI automatic documentation",
74
+ "/api/info": "This endpoint",
75
+ "/api/convert": "Convert PDF to Markdown (POST)",
76
+ "/api/status/{task_id}": "Check conversion status",
77
+ "/api/download/{task_id}": "Download converted markdown"
78
  }
79
+ }
80
+
81
# In-memory store of conversion tasks keyed by task_id.
# NOTE(review): not persisted and not shared between worker processes —
# state is lost on restart and inconsistent under multi-worker uvicorn.
conversion_tasks: Dict[str, dict] = {}
83
+
84
@app.post("/api/convert")
async def convert_pdf(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...)
):
    """Accept a PDF upload and start a background Markdown conversion.

    Returns a task_id the client can use to poll /api/status/{task_id}.

    Raises:
        HTTPException 400: if the upload is missing a filename or is not a PDF.
        HTTPException 500: if the upload cannot be written to disk.
    """
    # Guard against a missing filename and accept any capitalisation of ".pdf".
    if not file.filename or not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    # Generate unique task ID
    task_id = str(uuid.uuid4())

    # Save the upload under a private temp dir; strip any directory
    # components from the client-supplied name to prevent path traversal.
    temp_dir = Path(tempfile.mkdtemp())
    pdf_path = temp_dir / Path(file.filename).name

    try:
        with open(pdf_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        # Best-effort cleanup so failed uploads don't accumulate on disk.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")

    # Initialize task status before scheduling so /api/status can't race
    # the background task.
    conversion_tasks[task_id] = {
        "status": "processing",
        "filename": file.filename,
        "result": None,
        "error": None,
        "temp_dir": str(temp_dir)
    }

    # Start conversion in background
    background_tasks.add_task(process_pdf_conversion, task_id, str(pdf_path))

    return {
        "task_id": task_id,
        "status": "processing",
        "message": "PDF conversion started",
        "check_status_url": f"/api/status/{task_id}"
    }
125
+
126
async def process_pdf_conversion(task_id: str, pdf_path: str):
    """Background worker: produce a placeholder Markdown file for *pdf_path*.

    On success, marks the conversion_tasks entry for *task_id* as
    "completed" and records the output path; on any error, marks it
    "failed" and records the error text.
    """
    try:
        # Simulate processing time until real MinerU integration lands.
        await asyncio.sleep(2)

        source = Path(pdf_path)
        output_path = source.with_suffix('.md')
        placeholder = (
            f"# Converted from {source.name}\n\n"
            "This is a placeholder conversion. Full MinerU integration coming soon.\n"
        )
        with open(output_path, 'w') as handle:
            handle.write(placeholder)

        entry = conversion_tasks[task_id]
        entry["status"] = "completed"
        entry["result"] = str(output_path)

    except Exception as exc:
        conversion_tasks[task_id]["status"] = "failed"
        conversion_tasks[task_id]["error"] = str(exc)
144
+
145
@app.get("/api/status/{task_id}")
async def get_conversion_status(task_id: str):
    """Report the current state of a conversion task.

    Adds a download URL once completed, or the error text on failure.

    Raises:
        HTTPException 404: if *task_id* is unknown.
    """
    task = conversion_tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    status = task["status"]
    response = {
        "task_id": task_id,
        "status": status,
        "filename": task["filename"],
    }

    if status == "completed":
        response["download_url"] = f"/api/download/{task_id}"
    elif status == "failed":
        response["error"] = task["error"]

    return response
164
+
165
@app.get("/api/download/{task_id}")
async def download_converted_file(task_id: str):
    """Serve the converted Markdown file for a completed task.

    Raises:
        HTTPException 404: unknown task, or output file missing on disk.
        HTTPException 400: conversion not finished yet.
    """
    task = conversion_tasks.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="Conversion not completed")

    result_path = task["result"]
    if not result_path or not Path(result_path).exists():
        raise HTTPException(status_code=404, detail="Converted file not found")

    return FileResponse(
        result_path,
        media_type="text/markdown",
        filename=Path(result_path).name,
    )
config/magic-pdf.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bucket_info":{
3
+ "bucket-name-1":["ak", "sk", "endpoint"],
4
+ "bucket-name-2":["ak", "sk", "endpoint"]
5
+ },
6
+ "temp-output-dir":"/tmp",
7
+ "models-dir":"/tmp/models",
8
+ "device-mode":"cpu"
9
+ }
pdf_converter_mineru.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ PDF to Markdown Converter using MinerU (vendor/mineru)
4
+ This is the main conversion script that uses the local MinerU installation
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import logging
10
+ import argparse
11
+ from pathlib import Path
12
+ import subprocess
13
+
14
+ # Configure logging
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(levelname)s - %(message)s',
18
+ handlers=[
19
+ logging.StreamHandler(),
20
+ logging.FileHandler('pdf_converter.log')
21
+ ]
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class PdfConverterResult:
    """Outcome of converting one PDF to Markdown."""

    def __init__(self, pdf_path: str, success: bool, md_path: str = None,
                 time_taken: float = 0, error: str = None):
        self.pdf_path = pdf_path      # source PDF path
        self.success = success        # True when a markdown file was produced
        self.md_path = md_path        # path of the generated .md, if any
        self.time_taken = time_taken  # wall-clock seconds spent converting
        self.error = error            # failure description, if any

    def __str__(self):
        if not self.success:
            return f"❌ Failed to convert {self.pdf_path}: {self.error}"
        return f"✅ Successfully converted {self.pdf_path} in {self.time_taken:.2f}s"
42
+
43
+
44
class MineruPdfConverter:
    """PDF to Markdown converter that shells out to the ``mineru`` CLI.

    Each PDF is converted into ``output_dir/<pdf stem>/`` and the generated
    ``.md`` file is located afterwards.
    """

    def __init__(self, output_dir: str = "output"):
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def convert_file(self, pdf_path: str, delete_after: bool = False) -> PdfConverterResult:
        """Convert a single PDF file to Markdown using MinerU.

        Args:
            pdf_path: Path to the PDF to convert.
            delete_after: When True, delete the source PDF after a
                successful conversion.

        Returns:
            A PdfConverterResult describing success or failure; this
            method never raises — errors are captured in the result.
        """
        import time
        start_time = time.time()

        try:
            pdf_path = Path(pdf_path)
            if not pdf_path.exists():
                return PdfConverterResult(
                    str(pdf_path), False, error=f"File not found: {pdf_path}"
                )

            logger.info(f"Processing: {pdf_path}")

            # Each PDF gets its own sub-directory in the output tree.
            pdf_output_dir = os.path.join(self.output_dir, pdf_path.stem)

            # Build the MinerU command.
            # NOTE(review): "-f false" / "-t false" are passed as literal
            # strings — confirm the mineru CLI parses these as booleans.
            cmd = [
                "mineru",
                "-p", str(pdf_path),
                "-o", pdf_output_dir,
                "-m", "txt",    # Use text mode
                "-f", "false",  # Disable formula parsing for speed
                "-t", "false",  # Disable table parsing for speed
            ]

            logger.info(f"Running command: {' '.join(cmd)}")

            # Execute MinerU (list form, no shell — filename-safe).
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                # Prefer stderr but fall back to stdout, so a failure is not
                # reported as "Unknown error" when diagnostics went to stdout.
                error_msg = result.stderr or result.stdout or "Unknown error"
                return PdfConverterResult(
                    str(pdf_path), False, error=error_msg
                )

            # Locate the generated markdown: expected layout first,
            # then any *.md anywhere under the output directory.
            md_path = None
            expected_md = Path(pdf_output_dir) / pdf_path.stem / "txt" / f"{pdf_path.stem}.md"

            if expected_md.exists():
                md_path = str(expected_md)
                logger.info(f"✅ Markdown file created: {md_path}")
            else:
                for md_file in Path(pdf_output_dir).rglob("*.md"):
                    md_path = str(md_file)
                    logger.info(f"✅ Found markdown file: {md_path}")
                    break

            if not md_path:
                return PdfConverterResult(
                    str(pdf_path), False, error="No markdown file generated"
                )

            # Delete original PDF if requested
            if delete_after and pdf_path.exists():
                pdf_path.unlink()
                logger.info(f"🗑️ Deleted original PDF: {pdf_path}")

            elapsed_time = time.time() - start_time

            return PdfConverterResult(
                str(pdf_path), True, md_path=md_path, time_taken=elapsed_time
            )

        except Exception as e:
            # Route the traceback through the configured logger (console +
            # pdf_converter.log) instead of bare traceback.print_exc(),
            # which only reaches stderr.
            logger.exception(f"Error processing {pdf_path}: {e}")
            return PdfConverterResult(
                str(pdf_path), False, error=str(e)
            )
129
+
130
+
131
class BatchProcessor:
    """Batch-convert every PDF found under a directory tree."""

    def __init__(self, batch_dir: str = "batch-files", output_dir: str = "output",
                 workers: int = 1, delete_after: bool = False):
        self.batch_dir = batch_dir
        self.output_dir = output_dir
        self.workers = workers            # reserved; processing is sequential
        self.delete_after = delete_after
        self.converter = MineruPdfConverter(output_dir)

    def find_pdf_files(self) -> list[Path]:
        """Return every ``*.pdf`` beneath ``batch_dir`` (recursive search)."""
        batch_path = Path(self.batch_dir)

        if not batch_path.exists():
            logger.warning(f"Batch directory not found: {self.batch_dir}")
            return []

        found = list(batch_path.rglob("*.pdf"))
        logger.info(f"Found {len(found)} PDF files in {self.batch_dir}")
        return found

    def process_batch(self) -> tuple[int, int]:
        """Convert all discovered PDFs; return (successful, failed) counts."""
        pdf_files = self.find_pdf_files()

        if not pdf_files:
            logger.info("No PDF files found to process")
            return 0, 0

        logger.info(f"Starting batch processing of {len(pdf_files)} files...")

        successful = 0
        failed = 0

        # Sequential on purpose: MinerU already handles parallelism internally.
        for pdf_file in pdf_files:
            outcome = self.converter.convert_file(str(pdf_file), self.delete_after)

            if outcome.success:
                successful += 1
                logger.info(f"✅ {outcome}")
            else:
                failed += 1
                logger.error(f"❌ {outcome}")

        return successful, failed
182
+
183
+
184
def main():
    """Command-line entry point: dispatch to single-file or batch conversion."""
    parser = argparse.ArgumentParser(
        description="Convert PDF files to Markdown using MinerU",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Convert a single PDF
  %(prog)s convert path/to/file.pdf

  # Batch convert all PDFs in batch-files directory
  %(prog)s batch

  # Batch convert with custom settings
  %(prog)s batch --batch-dir /path/to/pdfs --output-dir /path/to/output --workers 4

  # Delete PDFs after successful conversion
  %(prog)s batch --delete-after
        """
    )

    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # Convert command
    convert_parser = subparsers.add_parser('convert', help='Convert a single PDF file')
    convert_parser.add_argument('pdf_file', help='Path to PDF file')
    convert_parser.add_argument('--output-dir', default='output', help='Output directory')
    convert_parser.add_argument('--delete-after', action='store_true',
                                help='Delete PDF after successful conversion')

    # Batch command
    batch_parser = subparsers.add_parser('batch', help='Batch convert PDF files')
    batch_parser.add_argument('--batch-dir', default='batch-files',
                              help='Directory containing PDF files')
    batch_parser.add_argument('--output-dir', default='output',
                              help='Output directory')
    batch_parser.add_argument('--workers', type=int, default=1,
                              help='Number of parallel workers')
    batch_parser.add_argument('--delete-after', action='store_true',
                              help='Delete PDFs after successful conversion')

    args = parser.parse_args()

    # Auto-detect command if none specified.
    # NOTE(review): this inspects sys.argv directly; argparse will already
    # have rejected a bare filename as an "invalid choice" before reaching
    # here, so the file-path branch likely only fires in edge cases —
    # verify the intended invocation.
    if not args.command:
        # If first argument looks like a file, assume convert command
        if len(sys.argv) > 1 and (sys.argv[1].endswith('.pdf') or Path(sys.argv[1]).exists()):
            args.command = 'convert'
            args.pdf_file = sys.argv[1]
            args.output_dir = 'output'
            args.delete_after = False
        else:
            # Default to batch mode
            args.command = 'batch'
            args.batch_dir = 'batch-files'
            args.output_dir = 'output'
            args.workers = 1
            args.delete_after = False

    # Execute command
    if args.command == 'convert':
        converter = MineruPdfConverter(args.output_dir)
        result = converter.convert_file(args.pdf_file, args.delete_after)
        print(result)
        # Exit status mirrors conversion success for shell scripting.
        sys.exit(0 if result.success else 1)

    elif args.command == 'batch':
        processor = BatchProcessor(
            args.batch_dir,
            args.output_dir,
            args.workers,
            args.delete_after
        )
        successful, failed = processor.process_batch()

        print(f"\n📊 Batch processing complete:")
        print(f"   ✅ Successful: {successful}")
        print(f"   ❌ Failed: {failed}")
        print(f"   📁 Output directory: {args.output_dir}")

        # Non-zero exit when any file failed, for CI / scripting use.
        sys.exit(0 if failed == 0 else 1)

    else:
        parser.print_help()
        sys.exit(1)
269
+
270
+
271
# Script entry point.
if __name__ == "__main__":
    main()
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
  fastapi==0.104.1
2
  uvicorn==0.24.0
3
  python-multipart==0.0.6
4
- aiofiles==23.2.1
 
 
 
 
1
  fastapi==0.104.1
2
  uvicorn==0.24.0
3
  python-multipart==0.0.6
4
+ aiofiles==23.2.1
5
+
6
+ # Basic PDF processing (will add MinerU later)
7
+ PyMuPDF>=1.18.16
vendor/mineru/mineru/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/backend/pipeline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/backend/pipeline/batch_analyze.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from loguru import logger
3
+ from tqdm import tqdm
4
+ from collections import defaultdict
5
+ import numpy as np
6
+
7
+ from .model_init import AtomModelSingleton
8
+ from ...utils.config_reader import get_formula_enable, get_table_enable
9
+ from ...utils.model_utils import crop_img, get_res_list_from_layout_res
10
+ from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
11
+
12
# Base batch sizes per model stage; MFR is additionally scaled by
# BatchAnalyze.batch_ratio at call time.
YOLO_LAYOUT_BASE_BATCH_SIZE = 8
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
15
+
16
+
17
class BatchAnalyze:
    """Batched page analysis pipeline: layout detection, optional formula
    detection/recognition, OCR text detection, table recognition, and
    final OCR text recognition, all over a list of page images.
    """

    def __init__(self, model_manager, batch_ratio: int, formula_enable, table_enable, enable_ocr_det_batch: bool = True):
        # batch_ratio scales the base batch sizes for MFR and OCR detection.
        self.batch_ratio = batch_ratio
        self.formula_enable = get_formula_enable(formula_enable)
        self.table_enable = get_table_enable(table_enable)
        self.model_manager = model_manager
        # When True, OCR detection runs in resolution/language-grouped batches.
        self.enable_ocr_det_batch = enable_ocr_det_batch

    def __call__(self, images_with_extra_info: list) -> list:
        """Run the full analysis pipeline.

        Args:
            images_with_extra_info: list of (image, ocr_enable, lang) tuples,
                one per page. (Image type is whatever the layout model
                accepts — presumably PIL images given the cv2 RGB→BGR
                conversions below; confirm against callers.)

        Returns:
            One list of layout result dicts per input image, enriched in
            place with formula, OCR, and table results.
        """
        if len(images_with_extra_info) == 0:
            return []

        images_layout_res = []

        self.model = self.model_manager.get_model(
            lang=None,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )
        atom_model_manager = AtomModelSingleton()

        images = [image for image, _, _ in images_with_extra_info]

        # doclayout_yolo: layout detection over all pages in one batched call.
        layout_images = []
        for image_index, image in enumerate(images):
            layout_images.append(image)


        images_layout_res += self.model.layout_model.batch_predict(
            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
        )

        if self.formula_enable:
            # Formula detection (MFD).
            images_mfd_res = self.model.mfd_model.batch_predict(
                images, MFD_BASE_BATCH_SIZE
            )

            # Formula recognition (MFR), batch size scaled by batch_ratio.
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            mfr_count = 0
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]
                mfr_count += len(images_formula_list[image_index])

        # Free GPU memory (disabled).
        # clean_vram(self.model.device, vram_threshold=8)

        # Split each page's layout results into OCR regions, table regions,
        # and formula-detection boxes.
        ocr_res_list_all_page = []
        table_res_list_all_page = []
        for index in range(len(images)):
            _, ocr_enable, _lang = images_with_extra_info[index]
            layout_res = images_layout_res[index]
            pil_img = images[index]

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            ocr_res_list_all_page.append({'ocr_res_list': ocr_res_list,
                                          'lang': _lang,
                                          'ocr_enable': ocr_enable,
                                          'pil_img': pil_img,
                                          'single_page_mfdetrec_res': single_page_mfdetrec_res,
                                          'layout_res': layout_res,
                                          })

            for table_res in table_res_list:
                table_img, _ = crop_img(table_res, pil_img)
                table_res_list_all_page.append({'table_res': table_res,
                                                'lang': _lang,
                                                'table_img': table_img,
                                                })

        # OCR text detection.
        if self.enable_ocr_det_batch:
            # Batch mode — group crops by language and resolution.
            # Collect every cropped image that needs OCR detection.
            all_cropped_images_info = []

            for ocr_res_list_dict in ocr_res_list_all_page:
                _lang = ocr_res_list_dict['lang']

                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )

                    # Convert RGB to BGR for the OCR detector.
                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)

                    all_cropped_images_info.append((
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
                    ))

            # Group crops by language.
            lang_groups = defaultdict(list)
            for crop_info in all_cropped_images_info:
                lang = crop_info[5]
                lang_groups[lang].append(crop_info)

            # For each language, group by resolution and run batched detection.
            for lang, lang_crop_list in lang_groups.items():
                if not lang_crop_list:
                    continue

                # logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")

                # Fetch the OCR model for this language.
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    det_db_box_thresh=0.3,
                    lang=lang
                )

                # Group by resolution (normalized up to multiples of 32) to
                # reduce the number of padding groups.
                resolution_groups = defaultdict(list)
                for crop_info in lang_crop_list:
                    cropped_img = crop_info[0]
                    h, w = cropped_img.shape[:2]
                    # Round dimensions up to the next multiple of 32.
                    normalized_h = ((h + 32) // 32) * 32
                    normalized_w = ((w + 32) // 32) * 32
                    group_key = (normalized_h, normalized_w)
                    resolution_groups[group_key].append(crop_info)

                # Run batched detection per resolution group.
                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):

                    # Target size: the group's max dims, rounded up to x32.
                    max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
                    max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
                    target_h = ((max_h + 32 - 1) // 32) * 32
                    target_w = ((max_w + 32 - 1) // 32) * 32

                    # Pad every crop onto a white canvas of the target size.
                    batch_images = []
                    for crop_info in group_crops:
                        img = crop_info[0]
                        h, w = img.shape[:2]
                        # White background at the target size.
                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
                        # Paste the original image in the top-left corner.
                        padded_img[:h, :w] = img
                        batch_images.append(padded_img)

                    # Batched detection; batch size scales with batch_ratio.
                    batch_size = min(len(batch_images), self.batch_ratio * 16)
                    # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)

                    # Post-process each crop's detected boxes.
                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info

                        if dt_boxes is not None and len(dt_boxes) > 0:
                            # Apply the same key post-processing steps as the
                            # original single-image OCR flow.
                            from mineru.utils.ocr_utils import (
                                merge_det_boxes, update_det_boxes, sorted_boxes
                            )

                            # 1. Sort the detected boxes.
                            if len(dt_boxes) > 0:
                                dt_boxes_sorted = sorted_boxes(dt_boxes)
                            else:
                                dt_boxes_sorted = []

                            # 2. Merge adjacent boxes.
                            if dt_boxes_sorted:
                                dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
                            else:
                                dt_boxes_merged = []

                            # 3. Adjust boxes around formula regions (critical step).
                            if dt_boxes_merged and adjusted_mfdetrec_res:
                                dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
                            else:
                                dt_boxes_final = dt_boxes_merged

                            # Convert boxes into the OCR result format.
                            ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]

                            if ocr_res:
                                ocr_result_list = get_ocr_result_list(
                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
                                )

                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
        else:
            # Original single-image mode.
            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
                # Process each area that requires OCR processing
                _lang = ocr_res_list_dict['lang']
                # Get OCR results for this language's images
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=_lang
                )
                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )
                    # OCR-det
                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                    ocr_res = ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                    # Integration results
                    if ocr_res:
                        ocr_result_list = get_ocr_result_list(
                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],new_image, _lang
                        )

                        ocr_res_list_dict['layout_res'].extend(ocr_result_list)

        # Table recognition.
        if self.table_enable:
            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                _lang = table_res_dict['lang']
                table_model = atom_model_manager.get_atom_model(
                    atom_model_name='table',
                    lang=_lang,
                )
                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
                # Validate the returned HTML before accepting it.
                if html_code:
                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
                    if expected_ending:
                        table_res_dict['table_res']['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, not found expected HTML table end'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, not get html return'
                    )

        # Create dictionaries to store items by language
        need_ocr_lists_by_lang = {}  # Dict of lists for each language
        img_crop_lists_by_lang = {}  # Dict of lists for each language

        # category_id 15 marks text regions awaiting OCR recognition.
        for layout_res in images_layout_res:
            for layout_res_item in layout_res:
                if layout_res_item['category_id'] in [15]:
                    if 'np_img' in layout_res_item and 'lang' in layout_res_item:
                        lang = layout_res_item['lang']

                        # Initialize lists for this language if not exist
                        if lang not in need_ocr_lists_by_lang:
                            need_ocr_lists_by_lang[lang] = []
                            img_crop_lists_by_lang[lang] = []

                        # Add to the appropriate language-specific lists
                        need_ocr_lists_by_lang[lang].append(layout_res_item)
                        img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])

                        # Remove the fields after adding to lists
                        layout_res_item.pop('np_img')
                        layout_res_item.pop('lang')

        if len(img_crop_lists_by_lang) > 0:

            # Process OCR by language
            total_processed = 0

            # Process each language separately
            for lang, img_crop_list in img_crop_lists_by_lang.items():
                if len(img_crop_list) > 0:
                    # Get OCR results for this language's images

                    ocr_model = atom_model_manager.get_atom_model(
                        atom_model_name='ocr',
                        det_db_box_thresh=0.3,
                        lang=lang
                    )
                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]

                    # Verify we have matching counts
                    assert len(ocr_res_list) == len(
                        need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'

                    # Process OCR results for this language
                    for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
                        ocr_text, ocr_score = ocr_res_list[index]
                        layout_res_item['text'] = ocr_text
                        layout_res_item['score'] = float(f"{ocr_score:.3f}")
                        # Low-confidence text is demoted to category 16.
                        if ocr_score < OcrConfidence.min_confidence:
                            layout_res_item['category_id'] = 16
                        else:
                            layout_res_bbox = [layout_res_item['poly'][0], layout_res_item['poly'][1],
                                               layout_res_item['poly'][4], layout_res_item['poly'][5]]
                            layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
                            layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
                            # Heuristic: demote specific low-confidence
                            # narrow/tall artifacts that are known OCR noise.
                            if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
                                layout_res_item['category_id'] = 16

                    total_processed += len(img_crop_list)

        return images_layout_res
vendor/mineru/mineru/backend/pipeline/model_init.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from loguru import logger
5
+
6
+ from .model_list import AtomicModel
7
+ from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
8
+ from ...model.mfd.yolo_v8 import YOLOv8MFDModel
9
+ from ...model.mfr.unimernet.Unimernet import UnimernetModel
10
+ from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
11
+ from ...model.table.rapid_table import RapidTableModel
12
+ from ...utils.enum_class import ModelPath
13
+ from ...utils.models_download_utils import auto_download_and_get_model_root_path
14
+
15
+
16
def table_model_init(lang=None):
    """Build the table-recognition model on top of a shared OCR engine.

    The OCR engine is obtained through the ``AtomModelSingleton`` cache, so
    repeated calls with the same language reuse a single engine instance.
    """
    manager = AtomModelSingleton()
    engine = manager.get_atom_model(
        atom_model_name='ocr',
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        lang=lang,
    )
    return RapidTableModel(engine)
26
+
27
+
28
def mfd_model_init(weight, device='cpu'):
    """Create the YOLOv8-based formula-detection (MFD) model.

    NPU device strings are wrapped in ``torch.device`` first — presumably a
    requirement of the Ascend backend; other device specs pass through
    unchanged.
    """
    resolved = torch.device(device) if str(device).startswith('npu') else device
    return YOLOv8MFDModel(weight, resolved)
33
+
34
+
35
def mfr_model_init(weight_dir, device='cpu'):
    """Create the UniMERNet formula-recognition (MFR) model from ``weight_dir``."""
    return UnimernetModel(weight_dir, device)
38
+
39
+
40
def doclayout_yolo_model_init(weight, device='cpu'):
    """Create the DocLayout-YOLO layout-detection model.

    Mirrors ``mfd_model_init``: an ``npu*`` device string is converted to a
    ``torch.device`` before construction.
    """
    resolved = torch.device(device) if str(device).startswith('npu') else device
    return DocLayoutYOLOModel(weight, resolved)
45
+
46
def ocr_model_init(det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    """Create a ``PytorchPaddleOCR`` engine.

    ``lang`` is forwarded only when it is a non-empty string; otherwise the
    engine falls back to its own default language handling.  The remaining
    detection parameters are always forwarded.
    """
    engine_kwargs = {
        'det_db_box_thresh': det_db_box_thresh,
        'use_dilation': use_dilation,
        'det_db_unclip_ratio': det_db_unclip_ratio,
    }
    if lang is not None and lang != '':
        engine_kwargs['lang'] = lang
    return PytorchPaddleOCR(**engine_kwargs)
65
+
66
+
67
class AtomModelSingleton:
    """Process-wide cache of atomic models (layout / MFD / MFR / OCR / table).

    Implemented as a singleton: every instantiation returns the same object,
    and all initialised models are shared through the class-level ``_models``
    dict.
    """
    _instance = None
    # Cache keyed by model name, extended with language (OCR) or
    # (table_model_name, language) for models that are language-specific.
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        """Return the cached model for ``atom_model_name``, creating it on first use.

        OCR models are cached per language; table models per
        (table_model_name, lang) pair; every other model by name alone.
        All ``kwargs`` are forwarded to ``atom_model_init`` on a cache miss.
        """
        lang = kwargs.get('lang', None)
        table_model_name = kwargs.get('table_model_name', None)

        if atom_model_name in [AtomicModel.OCR]:
            key = (atom_model_name, lang)
        elif atom_model_name in [AtomicModel.Table]:
            key = (atom_model_name, table_model_name, lang)
        else:
            key = atom_model_name

        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
        return self._models[key]
91
+
92
def atom_model_init(model_name: str, **kwargs):
    """Factory: build one atomic model selected by ``model_name``.

    Dispatches to the matching ``*_model_init`` helper, pulling that
    helper's arguments out of ``kwargs``.  An unknown name or a failed
    initialisation aborts the whole process with ``exit(1)``.
    NOTE(review): exiting from library code prevents callers from
    recovering; raising an exception would be friendlier — left unchanged
    here to preserve behaviour.
    """
    atom_model = None
    if model_name == AtomicModel.Layout:
        atom_model = doclayout_yolo_model_init(
            kwargs.get('doclayout_yolo_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang'),
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('lang'),
        )
    else:
        logger.error('model name not allow')
        exit(1)

    if atom_model is None:
        logger.error('model init failed')
        exit(1)
    else:
        return atom_model
127
+
128
+
129
class MineruPipelineModel:
    """Bundle of all per-document analysis models for the pipeline backend.

    Depending on the supplied configuration, initialises (via the shared
    ``AtomModelSingleton`` cache) the formula detection/recognition models,
    the layout model, the OCR engine and, optionally, the table model.
    """

    def __init__(self, **kwargs):
        # formula_config / table_config are dicts carrying at least an
        # 'enable' flag (default True when the key is absent).
        self.formula_config = kwargs.get('formula_config')
        self.apply_formula = self.formula_config.get('enable', True)
        self.table_config = kwargs.get('table_config')
        self.apply_table = self.table_config.get('enable', True)
        self.lang = kwargs.get('lang', None)
        self.device = kwargs.get('device', 'cpu')
        logger.info(
            'DocAnalysis init, this may take some times......'
        )
        atom_model_manager = AtomModelSingleton()

        if self.apply_formula:
            # Initialise the formula-detection (MFD) model; weights are
            # auto-downloaded on first use.
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
                ),
                device=self.device,
            )

            # Initialise the formula-recognition (MFR) model.
            mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)

            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                device=self.device,
            )

        # Initialise the layout-detection model.
        self.layout_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.Layout,
            doclayout_yolo_weights=str(
                os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
            ),
            device=self.device,
        )
        # Initialise the OCR engine (language-specific, cached).
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            det_db_box_thresh=0.3,
            lang=self.lang
        )
        # Initialise the table model, only when table parsing is enabled.
        if self.apply_table:
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                lang=self.lang,
            )

        logger.info('DocAnalysis init done!')
vendor/mineru/mineru/backend/pipeline/model_json_to_middle_json.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import os
3
+ import time
4
+
5
+ from loguru import logger
6
+ from tqdm import tqdm
7
+
8
+ from mineru.utils.config_reader import get_device, get_llm_aided_config, get_formula_enable
9
+ from mineru.backend.pipeline.model_init import AtomModelSingleton
10
+ from mineru.backend.pipeline.para_split import para_split
11
+ from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
12
+ from mineru.utils.block_sort import sort_blocks_by_bbox
13
+ from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
14
+ from mineru.utils.cut_image import cut_image_and_table
15
+ from mineru.utils.enum_class import ContentType
16
+ from mineru.utils.llm_aided import llm_aided_title
17
+ from mineru.utils.model_utils import clean_memory
18
+ from mineru.backend.pipeline.pipeline_magic_model import MagicModel
19
+ from mineru.utils.ocr_utils import OcrConfidence
20
+ from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
21
+ from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
22
+ remove_overlaps_min_spans, txt_spans_extract
23
+ from mineru.version import __version__
24
+ from mineru.utils.hash_utils import str_md5
25
+
26
+
27
def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
    """Convert one page's raw model output into a middle-json page dict.

    Args:
        page_model_info: per-page detection results wrapped by MagicModel.
        image_dict: rendered page image ('img_pil', 'img_base64') and its
            render 'scale'.
        page: the pdf page object (size; native text in non-OCR mode).
        image_writer: sink used when cropping image/table/equation snippets.
        page_index: zero-based page number.
        ocr_enable: True when the document is processed in OCR mode.
        formula_enabled: True when formula blocks are handled separately.

    Returns:
        The page-info dict, or None when the page has no usable bboxes.
    """
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    page_w, page_h = map(int, page.get_size())
    magic_model = MagicModel(page_model_info, scale)

    # Fetch the block groups used below from the MagicModel wrapper.
    """从magic_model对象中获取后面会用到的区块信息"""
    discarded_blocks = magic_model.get_discarded()
    text_blocks = magic_model.get_text_blocks()
    title_blocks = magic_model.get_title_blocks()
    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()

    img_groups = magic_model.get_imgs()
    table_groups = magic_model.get_tables()

    # Split the image and table groups into body / caption / footnote blocks.
    """对image和table的区块分组"""
    img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
    )

    table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
    )

    # Collect every span detected on the page.
    """获取所有的spans信息"""
    spans = magic_model.get_all_spans()

    # Some "image" blocks are actually text — reclassify with a simple
    # text-span area heuristic.
    """某些图可能是文本块,通过简单的规则判断一下"""
    if len(maybe_text_image_blocks) > 0:
        for block in maybe_text_image_blocks:
            span_in_block_list = []
            for span in spans:
                if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
                    span_in_block_list.append(span)
            if len(span_in_block_list) > 0:
                # Total area of the text spans that fall inside the block.
                spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
                # Ratio of text-span area to the block's own area.
                block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
                if block_area > 0:
                    ratio = spans_area / block_area
                    if ratio > 0.25 and ocr_enable:
                        # Looks like text: drop the image-group link and move
                        # the block to the text-block list.
                        block.pop('group_id', None)
                        text_blocks.append(block)
                    else:
                        # Not text enough — keep it as an image body block.
                        img_body_blocks.append(block)
                # NOTE(review): a zero-area block falls through without being
                # re-added anywhere — presumably impossible in practice; verify.
            else:
                # No text spans inside: it really is an image.
                img_body_blocks.append(block)

    # Consolidate the bboxes of every block type.
    """将所有区块的bbox整理到一起"""
    if formula_enabled:
        # With formula handling on, raw interline-equation *blocks* are
        # discarded and the recognised interline equations are used instead.
        interline_equation_blocks = []

    if len(interline_equation_blocks) > 0:

        for block in interline_equation_blocks:
            spans.append({
                "type": ContentType.INTERLINE_EQUATION,
                'score': block['score'],
                "bbox": block['bbox'],
            })

        all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equation_blocks,
            page_w,
            page_h,
        )
    else:
        all_bboxes, all_discarded_blocks, footnote_blocks = prepare_block_bboxes(
            img_body_blocks, img_caption_blocks, img_footnote_blocks,
            table_body_blocks, table_caption_blocks, table_footnote_blocks,
            discarded_blocks,
            text_blocks,
            title_blocks,
            interline_equations,
            page_w,
            page_h,
        )

    # Before de-duplicating spans, filter image/table spans against the
    # image_body / table_body blocks; this also drops large watermarks while
    # keeping spans inside discarded (abandon) blocks.
    """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
    """顺便删除大水印并保留abandon的span"""
    spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)

    # Among overlapping spans, drop the lower-confidence ones …
    """删除重叠spans中置信度较低的那些"""
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    # … and then the smaller ones.
    """删除重叠spans中较小的那些"""
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

    # Fill text content into spans depending on the parse mode; in non-OCR
    # mode the hybrid native-text extraction path is used.
    """根据parse_mode,构造spans,主要是文本类的字符填充"""
    if ocr_enable:
        pass
    else:
        """使用新版本的混合ocr方案."""
        spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)

    # Handle discarded blocks first — they are excluded from layout sorting.
    """先处理不需要排版的discarded_blocks"""
    discarded_block_with_spans, spans = fill_spans_in_blocks(
        all_discarded_blocks, spans, 0.4
    )
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    # Skip the page entirely when it has no valid bboxes.
    """如果当前页面没有有效的bbox则跳过"""
    if len(all_bboxes) == 0:
        return None

    # Crop snapshot images for image / table / interline-equation spans.
    """对image/table/interline_equation截图"""
    for span in spans:
        if span['type'] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
            span = cut_image_and_table(
                span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale
            )

    # Distribute the remaining spans into their blocks.
    """span填充进block"""
    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)

    # Normalise/repair block structures.
    """对block进行fix操作"""
    fix_blocks = fix_block_spans(block_with_spans)

    # Reading-order sort.
    """对block进行排序"""
    sorted_blocks = sort_blocks_by_bbox(fix_blocks, page_w, page_h, footnote_blocks)

    # Assemble the final per-page dict.
    """构造page_info"""
    page_info = make_page_info_dict(sorted_blocks, page_index, page_w, page_h, fix_discarded_blocks)

    return page_info
162
+
163
+
164
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
    """Assemble per-page model results into the pipeline middle-json.

    Converts every page via ``page_model_info_to_page_info``, then runs a
    deferred OCR pass over spans still carrying cropped images
    (``'np_img'``), merges paragraphs, optionally applies LLM-aided title
    fixes, closes the pdf and frees model memory.

    Returns:
        The middle-json dict ``{'pdf_info': [...], '_backend': ..., '_version_name': ...}``.
    """
    middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
    formula_enabled = get_formula_enable(formula_enabled)
    for page_index, page_model_info in tqdm(enumerate(model_list), total=len(model_list), desc="Processing pages"):
        page = pdf_doc[page_index]
        image_dict = images_list[page_index]
        page_info = page_model_info_to_page_info(
            page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
        )
        if page_info is None:
            # Page had no usable bboxes — emit an empty placeholder entry so
            # page indices stay aligned.
            page_w, page_h = map(int, page.get_size())
            page_info = make_page_info_dict([], page_index, page_w, page_h, [])
        middle_json["pdf_info"].append(page_info)

    # Deferred OCR: collect every text-bearing block (captions, footnotes,
    # text, titles, discarded) whose spans still hold a cropped image.
    """后置ocr处理"""
    need_ocr_list = []
    img_crop_list = []
    text_block_list = []
    for page_info in middle_json["pdf_info"]:
        for block in page_info['preproc_blocks']:
            if block['type'] in ['table', 'image']:
                for sub_block in block['blocks']:
                    if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
                        text_block_list.append(sub_block)
            elif block['type'] in ['text', 'title']:
                text_block_list.append(block)
        for block in page_info['discarded_blocks']:
            text_block_list.append(block)
    for block in text_block_list:
        for line in block['lines']:
            for span in line['spans']:
                if 'np_img' in span:
                    need_ocr_list.append(span)
                    img_crop_list.append(span['np_img'])
                    span.pop('np_img')
    if len(img_crop_list) > 0:
        atom_model_manager = AtomModelSingleton()
        ocr_model = atom_model_manager.get_atom_model(
            atom_model_name='ocr',
            ocr_show_log=False,
            det_db_box_thresh=0.3,
            lang=lang
        )
        # Recognition only (det=False): exactly one result per crop expected.
        ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
        assert len(ocr_res_list) == len(
            need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
        for index, span in enumerate(need_ocr_list):
            ocr_text, ocr_score = ocr_res_list[index]
            if ocr_score > OcrConfidence.min_confidence:
                span['content'] = ocr_text
                span['score'] = float(f"{ocr_score:.3f}")
            else:
                # Low-confidence recognition is discarded.
                span['content'] = ''
                span['score'] = 0.0

    # Paragraph merging across pages.
    """分段"""
    para_split(middle_json["pdf_info"])

    # Optional LLM-aided post-processing.
    """llm优化"""
    llm_aided_config = get_llm_aided_config()

    if llm_aided_config is not None:
        # Title optimisation, when enabled in the config.
        """标题优化"""
        title_aided_config = llm_aided_config.get('title_aided', None)
        if title_aided_config is not None:
            if title_aided_config.get('enable', False):
                llm_aided_title_start_time = time.time()
                llm_aided_title(middle_json["pdf_info"], title_aided_config)
                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')

    # Release resources: close the pdf; free accelerator memory for large
    # documents unless explicitly disabled via the environment.
    """清理内存"""
    pdf_doc.close()
    if os.getenv('MINERU_DONOT_CLEAN_MEM') is None and len(model_list) >= 10:
        clean_memory(get_device())

    return middle_json
240
+
241
+
242
def make_page_info_dict(blocks, page_id, page_w, page_h, discarded_blocks):
    """Assemble the per-page entry stored under ``middle_json['pdf_info']``."""
    return {
        'preproc_blocks': blocks,
        'page_idx': page_id,
        'page_size': [page_w, page_h],
        'discarded_blocks': discarded_blocks,
    }
vendor/mineru/mineru/backend/pipeline/model_list.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
class AtomicModel:
    """String identifiers for the atomic models managed by AtomModelSingleton."""
    Layout = "layout"  # document layout detection (DocLayout-YOLO)
    MFD = "mfd"        # math formula detection
    MFR = "mfr"        # math formula recognition
    OCR = "ocr"        # text recognition
    Table = "table"    # table structure recognition
vendor/mineru/mineru/backend/pipeline/para_split.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from loguru import logger
3
+ from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
4
+ from mineru.utils.language import detect_lang
5
+
6
+
7
# Characters that can terminate a sentence/line (ASCII and full-width forms);
# a paragraph whose last span ends with one of these is treated as complete
# and is not merged with the following block.
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
# Subset of terminators that mark the end of a list item.
LIST_END_FLAG = ('.', '。', ';', ';')
9
+
10
+
11
class ListLineTag:
    """Marker keys set on line dicts to flag list-item boundaries."""
    IS_LIST_START_LINE = 'is_list_start_line'
    IS_LIST_END_LINE = 'is_list_end_line'
14
+
15
+
16
def __process_blocks(blocks):
    """Pre-process blocks before paragraph merging.

    Splits the block sequence into groups at every ``title`` /
    ``interline_equation`` boundary, and tightens each text block's bbox
    (stored as ``bbox_fs``) to the extent of its lines.  Non-text blocks are
    not added to any group but still trigger group splits, so an empty group
    may be recorded between consecutive boundary blocks.
    """
    groups = []
    pending = []

    for idx, block in enumerate(blocks):
        if block['type'] == 'text':
            block['bbox_fs'] = copy.deepcopy(block['bbox'])
            if 'lines' in block and len(block['lines']) > 0:
                # Shrink-wrap the bbox around the actual line extents.
                block['bbox_fs'] = [
                    min(ln['bbox'][0] for ln in block['lines']),
                    min(ln['bbox'][1] for ln in block['lines']),
                    max(ln['bbox'][2] for ln in block['lines']),
                    max(ln['bbox'][3] for ln in block['lines']),
                ]
            pending.append(block)

        # A following title/equation block closes the current group
        # (possibly recording an empty one).
        if idx + 1 < len(blocks) and blocks[idx + 1]['type'] in ('title', 'interline_equation'):
            groups.append(pending)
            pending = []

    # Flush whatever remains after the last boundary.
    if pending:
        groups.append(pending)

    return groups
52
+
53
+
54
def __is_list_or_index_block(block):
    """Classify a text block as ``BlockType.LIST``, ``BlockType.INDEX`` or ``BlockType.TEXT``.

    A list block typically has multiple lines that are flush-left with a
    ragged ("saw-tooth") right edge, or whose lines mostly end with a list
    terminator, or whose lines are not flush-left at all.  An index
    (table-of-contents) block is a special list whose lines are flush on at
    least one side and mostly start or end with digits.  Blocks with fewer
    than two lines are always TEXT.  As a side effect, the function tags
    list start/end lines via ``ListLineTag``.
    """
    if len(block['lines']) >= 2:
        first_line = block['lines'][0]
        line_height = first_line['bbox'][3] - first_line['bbox'][1]
        block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]   # block width
        block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
        page_weight, page_height = block['page_size']

        left_close_num = 0
        left_not_close_num = 0
        right_not_close_num = 0
        right_close_num = 0
        lines_text_list = []
        center_close_num = 0
        external_sides_not_close_num = 0
        multiple_para_flag = False
        last_line = block['lines'][-1]

        if page_weight == 0:
            block_weight_radio = 0
        else:
            block_weight_radio = block_weight / page_weight
        # logger.info(f"block_weight_radio: {block_weight_radio}")

        # First line indented but not the last, and the last line has a
        # right-hand gap: looks like an ordinary multi-paragraph block
        # (the first line may legitimately not reach the right edge).
        if (
            first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
            and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
            and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
        ):
            multiple_para_flag = True

        block_text = ''

        for line in block['lines']:
            line_text = ''

            for span in line['spans']:
                span_type = span['type']
                if span_type == ContentType.TEXT:
                    line_text += span['content'].strip()
            # Record every line's text (including empty lines) so the list
            # stays index-aligned with block['lines'].
            lines_text_list.append(line_text)
        block_text = ''.join(lines_text_list)

        block_lang = detect_lang(block_text)
        # logger.info(f"block_lang: {block_lang}")

        for line in block['lines']:
            line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
            block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
            # Line indented away from both block edges (centred-looking line).
            if (
                line['bbox'][0] - block['bbox_fs'][0] > 0.7 * line_height
                and block['bbox_fs'][2] - line['bbox'][2] > 0.7 * line_height
            ):
                external_sides_not_close_num += 1
            if abs(line_mid_x - block_mid_x) < line_height / 2:
                center_close_num += 1

            # Flush-left when the gap to the block edge is below half a line
            # height; clearly indented beyond one line height counts as
            # "not close".
            if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
                left_close_num += 1
            elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
                left_not_close_num += 1

            # Flush-right test.
            if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height:
                right_close_num += 1
            else:
                # CJK text has no over-long words, so one threshold suffices.
                if block_lang in ['zh', 'ja', 'ko']:
                    closed_area = 0.26 * block_weight
                else:
                    # For other languages the ragged-right threshold depends
                    # on block width: wide blocks use a smaller fraction,
                    # narrow blocks a larger one.
                    if block_weight_radio >= 0.5:
                        closed_area = 0.26 * block_weight
                    else:
                        closed_area = 0.36 * block_weight
                if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
                    right_not_close_num += 1

        # Do >=80% of the lines end with a list terminator?
        line_end_flag = False
        # Do >=80% of the lines start, or end, with a digit?
        line_num_flag = False
        num_start_count = 0
        num_end_count = 0
        flag_end_count = 0

        if len(lines_text_list) > 0:
            for line_text in lines_text_list:
                if len(line_text) > 0:
                    if line_text[-1] in LIST_END_FLAG:
                        flag_end_count += 1
                    if line_text[0].isdigit():
                        num_start_count += 1
                    if line_text[-1].isdigit():
                        num_end_count += 1

            if (
                num_start_count / len(lines_text_list) >= 0.8
                or num_end_count / len(lines_text_list) >= 0.8
            ):
                line_num_flag = True
            if flag_end_count / len(lines_text_list) >= 0.8:
                line_end_flag = True

        # Some tables of contents are not flush right; treat as INDEX when at
        # least one side is fully flush and the digit rule holds.
        if (
            left_close_num / len(block['lines']) >= 0.8
            or right_close_num / len(block['lines']) >= 0.8
        ) and line_num_flag:
            for line in block['lines']:
                line[ListLineTag.IS_LIST_START_LINE] = True
            return BlockType.INDEX

        # Special fully-centred list: every line breaks; most lines are not
        # flush on either side, all are centred, and the block is tall
        # enough relative to its width.
        elif (
            external_sides_not_close_num >= 2
            and center_close_num == len(block['lines'])
            and external_sides_not_close_num / len(block['lines']) >= 0.5
            and block_height / block_weight > 0.4
        ):
            for line in block['lines']:
                line[ListLineTag.IS_LIST_START_LINE] = True
            return BlockType.LIST

        elif (
            left_close_num >= 2
            and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
            and not multiple_para_flag
            # and block_weight_radio > 0.27
        ):
            # Special un-indented list: every line flush left; item ends are
            # detected from the right-hand gap.
            if left_close_num / len(block['lines']) > 0.8:
                # Single-line items, all flush left, no terminators.
                if flag_end_count == 0 and right_close_num / len(block['lines']) < 0.5:
                    for line in block['lines']:
                        if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
                            line[ListLineTag.IS_LIST_START_LINE] = True
                # Most items end with a terminator: split items on it.
                elif line_end_flag:
                    for i, line in enumerate(block['lines']):
                        if (
                            len(lines_text_list[i]) > 0
                            and lines_text_list[i][-1] in LIST_END_FLAG
                        ):
                            line[ListLineTag.IS_LIST_END_LINE] = True
                            if i + 1 < len(block['lines']):
                                block['lines'][i + 1][
                                    ListLineTag.IS_LIST_START_LINE
                                ] = True
                # No terminators and no indentation: use the right-hand gap
                # to find item ends.
                else:
                    line_start_flag = False
                    for i, line in enumerate(block['lines']):
                        if line_start_flag:
                            line[ListLineTag.IS_LIST_START_LINE] = True
                            line_start_flag = False

                        if (
                            abs(block['bbox_fs'][2] - line['bbox'][2])
                            > 0.1 * block_weight
                        ):
                            line[ListLineTag.IS_LIST_END_LINE] = True
                            line_start_flag = True
            # Special indented ordered list: start lines are indented and
            # begin with a digit; end lines carry a terminator; counts match.
            elif num_start_count >= 2 and num_start_count == flag_end_count:
                for i, line in enumerate(block['lines']):
                    if len(lines_text_list[i]) > 0:
                        if lines_text_list[i][0].isdigit():
                            line[ListLineTag.IS_LIST_START_LINE] = True
                        if lines_text_list[i][-1] in LIST_END_FLAG:
                            line[ListLineTag.IS_LIST_END_LINE] = True
            else:
                # Regular indented list.
                for line in block['lines']:
                    if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
                        line[ListLineTag.IS_LIST_START_LINE] = True
                    if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
                        line[ListLineTag.IS_LIST_END_LINE] = True

            return BlockType.LIST
        else:
            return BlockType.TEXT
    else:
        return BlockType.TEXT
251
+
252
+
253
def __merge_2_text_blocks(block1, block2):
    """Maybe merge text block ``block1`` into the preceding block ``block2``.

    Called with (current_block, prev_block) while walking a group in
    reverse.  When ``block2`` looks like an unfinished paragraph (last line
    flush right, no sentence terminator) and ``block1``'s first line is
    flush left and does not start a new item (no leading digit or uppercase
    letter), ``block1``'s lines are appended to ``block2`` and ``block1`` is
    emptied and flagged.  Spans moved across a page boundary are tagged
    ``SplitFlag.CROSS_PAGE``.

    Returns:
        The (possibly mutated) pair ``(block1, block2)`` in all cases.
    """
    if len(block1['lines']) > 0:
        first_line = block1['lines'][0]
        line_height = first_line['bbox'][3] - first_line['bbox'][1]
        block1_weight = block1['bbox'][2] - block1['bbox'][0]
        block2_weight = block2['bbox'][2] - block2['bbox'][0]
        min_block_weight = min(block1_weight, block2_weight)
        if abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
            last_line = block2['lines'][-1]
            if len(last_line['spans']) > 0:
                last_span = last_line['spans'][-1]
                # Re-measure line height against block2's last line.
                line_height = last_line['bbox'][3] - last_line['bbox'][1]
                if len(first_line['spans']) > 0:
                    first_span = first_line['spans'][0]
                    if len(first_span['content']) > 0:
                        span_start_with_num = first_span['content'][0].isdigit()
                        span_start_with_big_char = first_span['content'][0].isupper()
                        if (
                            # prev block's last line reaches its right edge
                            # (gap under one line height)
                            abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
                            # prev block's last span does not end a sentence
                            and not last_span['content'].endswith(LINE_STOP_FLAG)
                            # block widths must be comparable (difference
                            # below the narrower block's width)
                            and abs(block1_weight - block2_weight) < min_block_weight
                            # current block must not start with a digit …
                            and not span_start_with_num
                            # … nor with an uppercase letter
                            and not span_start_with_big_char
                        ):
                            if block1['page_num'] != block2['page_num']:
                                # Mark spans that migrate across pages.
                                for line in block1['lines']:
                                    for span in line['spans']:
                                        span[SplitFlag.CROSS_PAGE] = True
                            block2['lines'].extend(block1['lines'])
                            block1['lines'] = []
                            block1[SplitFlag.LINES_DELETED] = True

    return block1, block2
291
+
292
+
293
def __merge_2_list_blocks(block1, block2):
    """Unconditionally merge list block ``block1`` into ``block2``.

    Spans moved across a page boundary are tagged ``SplitFlag.CROSS_PAGE``;
    ``block1`` is emptied and flagged ``SplitFlag.LINES_DELETED``.
    Returns the mutated pair ``(block1, block2)``.
    """
    crosses_page = block1['page_num'] != block2['page_num']
    if crosses_page:
        for line in block1['lines']:
            for span in line['spans']:
                span[SplitFlag.CROSS_PAGE] = True

    block2['lines'] += block1['lines']
    block1['lines'] = []
    block1[SplitFlag.LINES_DELETED] = True
    return block1, block2
303
+
304
+
305
def __is_list_group(text_blocks_group):
    """Return True when the group behaves like a list.

    Heuristic: every block in the group has at most 3 lines.  (A left-edge
    alignment check was deliberately left out to keep the rule simple.)
    """
    return all(len(block['lines']) <= 3 for block in text_blocks_group)
312
+
313
+
314
def __para_merge_page(blocks):
    """Merge paragraph fragments within one flat block sequence.

    Groups the blocks at title/interline-equation boundaries, classifies
    every block (TEXT / LIST / INDEX), then walks each group backwards and
    merges adjacent blocks of the same kind: text-with-text (unless the
    whole group behaves like a list), list-with-list, index-with-index.
    """
    page_text_blocks_groups = __process_blocks(blocks)
    for text_blocks_group in page_text_blocks_groups:
        if len(text_blocks_group) > 0:
            # Classify every block before any merging happens.
            for block in text_blocks_group:
                block_type = __is_list_or_index_block(block)
                block['type'] = block_type
                # logger.info(f"{block['type']}:{block}")

            if len(text_blocks_group) > 1:
                # Decide once whether the whole group behaves like a list.
                is_list_group = __is_list_group(text_blocks_group)

                # Walk the group in reverse so merges fold later blocks into
                # earlier ones.
                for i in range(len(text_blocks_group) - 1, -1, -1):
                    current_block = text_blocks_group[i]

                    # Only merge when a preceding block exists.
                    if i - 1 >= 0:
                        prev_block = text_blocks_group[i - 1]

                        if (
                            current_block['type'] == 'text'
                            and prev_block['type'] == 'text'
                            and not is_list_group
                        ):
                            __merge_2_text_blocks(current_block, prev_block)
                        elif (
                            current_block['type'] == BlockType.LIST
                            and prev_block['type'] == BlockType.LIST
                        ) or (
                            current_block['type'] == BlockType.INDEX
                            and prev_block['type'] == BlockType.INDEX
                        ):
                            __merge_2_list_blocks(current_block, prev_block)

        else:
            continue
354
+
355
def para_split(page_info_list):
    """Split/merge preprocessed blocks into paragraphs across all pages.

    Deep-copies every page's ``preproc_blocks`` (the originals are left
    untouched), tags each copy with its page number and size, runs the
    cross-page merge, then writes the merged blocks back into each page's
    ``para_blocks`` with the temporary tags removed.
    """
    tagged_blocks = []
    for page_info in page_info_list:
        page_blocks = copy.deepcopy(page_info['preproc_blocks'])
        for blk in page_blocks:
            blk['page_num'] = page_info['page_idx']
            blk['page_size'] = page_info['page_size']
        tagged_blocks.extend(page_blocks)

    __para_merge_page(tagged_blocks)

    for page_info in page_info_list:
        page_info['para_blocks'] = []
        for blk in tagged_blocks:
            if 'page_num' in blk and blk['page_num'] == page_info['page_idx']:
                page_info['para_blocks'].append(blk)
                # Drop the temporary bookkeeping fields.
                del blk['page_num']
                del blk['page_size']
375
+
376
if __name__ == '__main__':
    # Smoke test: run the grouping step on an empty input and dump the groups.
    input_blocks = []
    # Call the grouping function directly.
    groups = __process_blocks(input_blocks)
    for group_index, group in enumerate(groups):
        print(f'Group {group_index}: {group}')
vendor/mineru/mineru/backend/pipeline/pipeline_analyze.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import List, Tuple
4
+ import PIL.Image
5
+ from loguru import logger
6
+
7
+ from .model_init import MineruPipelineModel
8
+ from mineru.utils.config_reader import get_device
9
+ from ...utils.pdf_classify import classify
10
+ from ...utils.pdf_image_tools import load_images_from_pdf
11
+ from ...utils.model_utils import get_vram, clean_memory
12
+
13
+
14
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow PyTorch MPS ops to fall back to CPU when unsupported
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # suppress albumentations' update check on import
16
+
17
class ModelSingleton:
    """Process-wide cache of pipeline models.

    The class itself is a singleton; models are cached per
    ``(lang, formula_enable, table_enable)`` combination so repeated requests
    reuse the already-initialized model.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Create the shared instance lazily on first construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, lang=None, formula_enable=None, table_enable=None):
        """Return the cached model for these settings, building it on first use."""
        cache_key = (lang, formula_enable, table_enable)
        if cache_key not in self._models:
            self._models[cache_key] = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[cache_key]
40
+
41
+
42
def custom_model_init(
    lang=None,
    formula_enable=True,
    table_enable=True,
):
    """Build a :class:`MineruPipelineModel` for the configured device.

    Args:
        lang: OCR language hint passed through to the model.
        formula_enable: whether the formula-recognition model is enabled.
        table_enable: whether the table-recognition model is enabled.

    Returns:
        The initialized pipeline model.
    """
    model_init_start = time.time()
    # The target device is read from the config file.
    device = get_device()

    custom_model = MineruPipelineModel(
        device=device,
        table_config={'enable': table_enable},
        formula_config={'enable': formula_enable},
        lang=lang,
    )

    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')

    return custom_model
67
+
68
+
69
def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=True,
    table_enable=True,
):
    """Run batched layout/OCR inference over a list of PDFs.

    Args:
        pdf_bytes_list: raw bytes of each PDF document.
        lang_list: per-PDF OCR language hint (parallel to pdf_bytes_list).
        parse_method: 'auto' classifies each PDF to decide OCR; 'ocr' forces
            OCR; any other value leaves OCR disabled.
        formula_enable: toggle the formula-recognition model.
        table_enable: toggle the table-recognition model.

    Returns:
        Tuple of (infer_results, all_image_lists, all_pdf_docs, lang_list,
        ocr_enabled_list), where infer_results holds one list of per-page
        result dicts per input PDF.

    Increasing MIN_BATCH_INFERENCE_SIZE can improve throughput at the cost of
    more VRAM; it can be set via the MINERU_MIN_BATCH_INFERENCE_SIZE
    environment variable (default 128).
    """
    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 128))

    # Flat list of page descriptors: (pdf_idx, page_idx, PIL image, ocr_enable, lang)
    all_pages_info = []

    all_image_lists = []
    all_pdf_docs = []
    ocr_enabled_list = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether OCR is needed for this PDF.
        _ocr_enable = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr_enable = True
        elif parse_method == 'ocr':
            _ocr_enable = True

        ocr_enabled_list.append(_ocr_enable)
        _lang = lang_list[pdf_idx]

        # Render every page of this PDF to an image and record it.
        images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
        all_image_lists.append(images_list)
        all_pdf_docs.append(pdf_doc)
        for page_idx in range(len(images_list)):
            img_dict = images_list[page_idx]
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr_enable, _lang,
            ))

    # Prepare batches of (image, ocr_enable, lang) tuples.
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = min_batch_inference_size
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]

    # Run inference batch by batch; results stay in page order.
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)

    # Regroup the flat per-page results into one list per source PDF.
    infer_results = []

    for _ in range(len(pdf_bytes_list)):
        infer_results.append([])

    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _ = page_info
        result = results[i]

        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}

        infer_results[pdf_idx].append(page_dict)

    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
147
+
148
+
149
def batch_image_analyze(
        images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
        formula_enable=True,
        table_enable=True):
    """Analyze one batch of page images with the cached pipeline models.

    Chooses a batch_ratio from the available GPU/NPU memory (default 1 on CPU
    or when VRAM cannot be determined) and delegates to BatchAnalyze.
    """
    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()

    batch_ratio = 1
    device = get_device()
    device_str = str(device)

    if device_str.startswith('npu'):
        try:
            import torch_npu
            if torch_npu.npu.is_available():
                torch_npu.npu.set_compile_mode(jit_compile=False)
        except Exception as e:
            raise RuntimeError(
                "NPU is selected as device, but torch_npu is not available. "
                "Please ensure that the torch_npu package is installed correctly."
            ) from e

    if device_str.startswith(('npu', 'cuda')):
        vram = get_vram(device)
        if vram is not None:
            # MINERU_VIRTUAL_VRAM_SIZE lets the user cap/override detected VRAM.
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
            for threshold, ratio in ((16, 16), (12, 8), (8, 4), (6, 2)):
                if gpu_memory >= threshold:
                    batch_ratio = ratio
                    break
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')

    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    clean_memory(get_device())

    return results
vendor/mineru/mineru/backend/pipeline/pipeline_magic_model.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
2
+ from mineru.utils.enum_class import CategoryId, ContentType
3
+
4
+
5
class MagicModel:
    """Per-page view over the raw layout-model output.

    The constructor normalizes ``page_model_info['layout_dets']`` in place
    (scaled bboxes, confidence filtering, overlap resolution, footnote
    reclassification); the ``get_*`` accessors then read from the cleaned
    detections. Every accessor returns an empty list when nothing matches.
    """

    def __init__(self, page_model_info: dict, scale: float):
        self.__page_model_info = page_model_info
        self.__scale = scale
        # Add a 'bbox' to every detection (scale the model poly back down).
        self.__fix_axis()
        # Drop detections with very low confidence (score <= 0.05).
        self.__fix_by_remove_low_confidence()
        # For pairs with IoU > 0.9, drop the lower-confidence detection.
        self.__fix_by_remove_high_iou_and_low_confidence()
        # Reassign table footnotes that actually belong to images.
        self.__fix_footnote()
        # Merge/remove overlapping image bodies and table bodies.
        self.__fix_by_remove_overlap_image_table_body()

    def __fix_by_remove_overlap_image_table_body(self):
        """Merge image bodies (and table bodies) overlapping by >= 80%.

        The smaller block is removed and the larger block's bbox is expanded
        to the union of the two.
        """
        need_remove_list = []
        layout_dets = self.__page_model_info['layout_dets']
        image_blocks = [d for d in layout_dets if d['category_id'] == CategoryId.ImageBody]
        table_blocks = [d for d in layout_dets if d['category_id'] == CategoryId.TableBody]

        def add_need_remove_block(blocks):
            for i in range(len(blocks)):
                for j in range(i + 1, len(blocks)):
                    block1 = blocks[i]
                    block2 = blocks[j]
                    overlap_box = get_minbox_if_overlap_by_ratio(
                        block1['bbox'], block2['bbox'], 0.8
                    )
                    if overlap_box is None:
                        continue
                    # Keep the larger block, mark the smaller for removal.
                    area1 = (block1['bbox'][2] - block1['bbox'][0]) * (block1['bbox'][3] - block1['bbox'][1])
                    area2 = (block2['bbox'][2] - block2['bbox'][0]) * (block2['bbox'][3] - block2['bbox'][1])
                    if area1 <= area2:
                        block_to_remove, large_block = block1, block2
                    else:
                        block_to_remove, large_block = block2, block1

                    if block_to_remove not in need_remove_list:
                        # Grow the surviving block to cover the removed one.
                        x1, y1, x2, y2 = large_block['bbox']
                        sx1, sy1, sx2, sy2 = block_to_remove['bbox']
                        large_block['bbox'] = [
                            min(x1, sx1), min(y1, sy1), max(x2, sx2), max(y2, sy2)
                        ]
                        need_remove_list.append(block_to_remove)

        add_need_remove_block(image_blocks)  # image-image overlaps
        add_need_remove_block(table_blocks)  # table-table overlaps

        for need_remove in need_remove_list:
            if need_remove in layout_dets:
                layout_dets.remove(need_remove)

    def __fix_axis(self):
        """Derive an integer 'bbox' from each detection's 'poly', scaling back
        to PDF coordinates; drop boxes with non-positive width or height."""
        need_remove_list = []
        layout_dets = self.__page_model_info['layout_dets']
        for layout_det in layout_dets:
            x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
            bbox = [
                int(x0 / self.__scale),
                int(y0 / self.__scale),
                int(x1 / self.__scale),
                int(y1 / self.__scale),
            ]
            layout_det['bbox'] = bbox
            # Remove spans whose width or height collapses to <= 0.
            if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
                need_remove_list.append(layout_det)
        for need_remove in need_remove_list:
            layout_dets.remove(need_remove)

    def __fix_by_remove_low_confidence(self):
        """Drop detections whose confidence score is <= 0.05."""
        layout_dets = self.__page_model_info['layout_dets']
        need_remove_list = [d for d in layout_dets if d['score'] <= 0.05]
        for need_remove in need_remove_list:
            layout_dets.remove(need_remove)

    def __fix_by_remove_high_iou_and_low_confidence(self):
        """Among body/caption/title/text detections, when two boxes overlap
        with IoU > 0.9 keep only the higher-confidence one."""
        need_remove_list = []
        candidate_categories = (
            CategoryId.Title,
            CategoryId.Text,
            CategoryId.ImageBody,
            CategoryId.ImageCaption,
            CategoryId.TableBody,
            CategoryId.TableCaption,
            CategoryId.TableFootnote,
            CategoryId.InterlineEquation_Layout,
            CategoryId.InterlineEquationNumber_Layout,
        )
        layout_dets = [
            d for d in self.__page_model_info['layout_dets']
            if d['category_id'] in candidate_categories
        ]
        for i in range(len(layout_dets)):
            for j in range(i + 1, len(layout_dets)):
                layout_det1 = layout_dets[i]
                layout_det2 = layout_dets[j]

                if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
                    layout_det_need_remove = (
                        layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
                    )
                    if layout_det_need_remove not in need_remove_list:
                        need_remove_list.append(layout_det_need_remove)

        for need_remove in need_remove_list:
            self.__page_model_info['layout_dets'].remove(need_remove)

    def __fix_footnote(self):
        """Reclassify table footnotes that sit closer to an image than to any
        table as image footnotes."""
        footnotes = []
        figures = []
        tables = []

        for obj in self.__page_model_info['layout_dets']:
            if obj['category_id'] == CategoryId.TableFootnote:
                footnotes.append(obj)
            elif obj['category_id'] == CategoryId.ImageBody:
                figures.append(obj)
            elif obj['category_id'] == CategoryId.TableBody:
                tables.append(obj)
        # Nothing to reassign unless both footnotes and figures exist.
        # BUGFIX: this early exit used `continue`, which is invalid outside a
        # loop (leftover from a per-page-loop version); it must be `return`.
        if len(footnotes) * len(figures) == 0:
            return
        dis_figure_footnote = {}
        dis_table_footnote = {}

        for i in range(len(footnotes)):
            for j in range(len(figures)):
                # More than one relative-position flag means the boxes are
                # diagonal neighbours - not a caption/footnote layout; skip.
                pos_flag_count = sum(
                    1 if v else 0
                    for v in bbox_relative_pos(footnotes[i]['bbox'], figures[j]['bbox'])
                )
                if pos_flag_count > 1:
                    continue
                dis_figure_footnote[i] = min(
                    self._bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
                    dis_figure_footnote.get(i, float('inf')),
                )
        for i in range(len(footnotes)):
            for j in range(len(tables)):
                pos_flag_count = sum(
                    1 if v else 0
                    for v in bbox_relative_pos(footnotes[i]['bbox'], tables[j]['bbox'])
                )
                if pos_flag_count > 1:
                    continue

                dis_table_footnote[i] = min(
                    self._bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
                    dis_table_footnote.get(i, float('inf')),
                )
        for i in range(len(footnotes)):
            if i not in dis_figure_footnote:
                continue
            if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
                footnotes[i]['category_id'] = CategoryId.ImageFootnote

    def _bbox_distance(self, bbox1, bbox2):
        """Distance between two bboxes, or inf when they are diagonal
        neighbours or the relevant side lengths differ by more than 30%."""
        left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
        flags = [left, right, bottom, top]
        count = sum(1 if v else 0 for v in flags)
        if count > 1:
            return float('inf')
        if left or right:
            # Horizontally adjacent: compare heights.
            l1 = bbox1[3] - bbox1[1]
            l2 = bbox2[3] - bbox2[1]
        else:
            # Vertically adjacent: compare widths.
            l1 = bbox1[2] - bbox1[0]
            l2 = bbox2[2] - bbox2[0]

        if l2 > l1 and (l2 - l1) / l1 > 0.3:
            return float('inf')

        return bbox_distance(bbox1, bbox2)

    def __reduct_overlap(self, bboxes):
        """Drop entries whose bbox is fully contained in another entry's bbox."""
        total = len(bboxes)
        keep = [True] * total
        for i in range(total):
            for j in range(total):
                if i == j:
                    continue
                if is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
                    keep[i] = False
        return [bboxes[i] for i in range(total) if keep[i]]

    def __tie_up_category_by_distance_v3(
        self,
        subject_category_id: int,
        object_category_id: int,
    ):
        """Greedily pair subject blocks (e.g. image bodies) with nearby object
        blocks (e.g. captions).

        Returns a list of {'sub_bbox', 'obj_bboxes', 'sub_idx'} records; every
        subject appears exactly once, possibly with an empty object list.
        """
        subjects = self.__reduct_overlap(
            [
                {'bbox': d['bbox'], 'score': d['score']}
                for d in self.__page_model_info['layout_dets']
                if d['category_id'] == subject_category_id
            ]
        )
        objects = self.__reduct_overlap(
            [
                {'bbox': d['bbox'], 'score': d['score']}
                for d in self.__page_model_info['layout_dets']
                if d['category_id'] == object_category_id
            ]
        )

        ret = []
        N, M = len(subjects), len(objects)
        # Sort by squared distance from the page origin so matching starts
        # near the top-left corner.
        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

        # Object indices are offset so both kinds share one index space.
        OBJ_IDX_OFFSET = 10000
        SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

        all_boxes_with_idx = (
            [(i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1]) for i, sub in enumerate(subjects)]
            + [(i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1]) for i, obj in enumerate(objects)]
        )
        seen_idx = set()
        seen_sub_idx = set()

        while N > len(seen_sub_idx):
            candidates = []
            for idx, kind, x0, y0 in all_boxes_with_idx:
                if idx in seen_idx:
                    continue
                candidates.append((idx, kind, x0, y0))

            if len(candidates) == 0:
                break
            left_x = min(v[2] for v in candidates)
            top_y = min(v[3] for v in candidates)

            # Pick the candidate closest to the top-left of the remaining set,
            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)

            fst_idx, fst_kind, left_x, top_y = candidates[0]
            # then re-sort by distance to that candidate and find the nearest
            # candidate of the *opposite* kind.
            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
            nxt = None

            for i in range(1, len(candidates)):
                if candidates[i][1] ^ fst_kind == 1:
                    nxt = candidates[i]
                    break
            if nxt is None:
                break

            if fst_kind == SUB_BIT_KIND:
                sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
            else:
                sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

            # Reject the pairing when another unmatched subject is much
            # (3x) closer to this object.
            pair_dis = bbox_distance(subjects[sub_idx]['bbox'], objects[obj_idx]['bbox'])
            nearest_dis = float('inf')
            for i in range(N):
                if i in seen_idx or i == sub_idx:
                    continue
                nearest_dis = min(nearest_dis, bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']))

            if pair_dis >= 3 * nearest_dis:
                seen_idx.add(sub_idx)
                continue

            seen_idx.add(sub_idx)
            seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
            seen_sub_idx.add(sub_idx)

            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[sub_idx]['bbox'],
                        'score': subjects[sub_idx]['score'],
                    },
                    'obj_bboxes': [
                        {'score': objects[obj_idx]['score'], 'bbox': objects[obj_idx]['bbox']}
                    ],
                    'sub_idx': sub_idx,
                }
            )

        # Attach any leftover objects to their nearest subject.
        for i in range(len(objects)):
            j = i + OBJ_IDX_OFFSET
            if j in seen_idx:
                continue
            seen_idx.add(j)
            nearest_dis, nearest_sub_idx = float('inf'), -1
            for k in range(len(subjects)):
                dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
                if dis < nearest_dis:
                    nearest_dis = dis
                    nearest_sub_idx = k

            for k in range(len(subjects)):
                if k != nearest_sub_idx:
                    continue
                if k in seen_sub_idx:
                    # Subject already has a record - extend its object list.
                    for kk in range(len(ret)):
                        if ret[kk]['sub_idx'] == k:
                            ret[kk]['obj_bboxes'].append(
                                {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
                            )
                            break
                else:
                    ret.append(
                        {
                            'sub_bbox': {
                                'bbox': subjects[k]['bbox'],
                                'score': subjects[k]['score'],
                            },
                            'obj_bboxes': [
                                {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
                            ],
                            'sub_idx': k,
                        }
                    )
                    seen_sub_idx.add(k)
                    seen_idx.add(k)

        # Emit records for subjects that matched nothing.
        for i in range(len(subjects)):
            if i in seen_sub_idx:
                continue
            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[i]['bbox'],
                        'score': subjects[i]['score'],
                    },
                    'obj_bboxes': [],
                    'sub_idx': i,
                }
            )

        return ret

    def get_imgs(self):
        """Return image bodies paired with their caption and footnote lists."""
        with_captions = self.__tie_up_category_by_distance_v3(
            CategoryId.ImageBody, CategoryId.ImageCaption
        )
        with_footnotes = self.__tie_up_category_by_distance_v3(
            CategoryId.ImageBody, CategoryId.ImageFootnote
        )
        ret = []
        for v in with_captions:
            record = {
                'image_body': v['sub_bbox'],
                'image_caption_list': v['obj_bboxes'],
            }
            filter_idx = v['sub_idx']
            # Both calls emit one record per subject, so a match always exists.
            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
            record['image_footnote_list'] = d['obj_bboxes']
            ret.append(record)
        return ret

    def get_tables(self) -> list:
        """Return table bodies paired with their caption and footnote lists."""
        with_captions = self.__tie_up_category_by_distance_v3(
            CategoryId.TableBody, CategoryId.TableCaption
        )
        with_footnotes = self.__tie_up_category_by_distance_v3(
            CategoryId.TableBody, CategoryId.TableFootnote
        )
        ret = []
        for v in with_captions:
            record = {
                'table_body': v['sub_bbox'],
                'table_caption_list': v['obj_bboxes'],
            }
            filter_idx = v['sub_idx']
            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
            record['table_footnote_list'] = d['obj_bboxes']
            ret.append(record)
        return ret

    def get_equations(self) -> tuple[list, list, list]:
        """Return (inline, interline, interline-layout) equation blocks; the
        first two also carry recognized LaTeX content."""
        inline_equations = self.__get_blocks_by_type(
            CategoryId.InlineEquation, ['latex']
        )
        interline_equations = self.__get_blocks_by_type(
            CategoryId.InterlineEquation_YOLO, ['latex']
        )
        interline_equations_blocks = self.__get_blocks_by_type(
            CategoryId.InterlineEquation_Layout
        )
        return inline_equations, interline_equations, interline_equations_blocks

    def get_discarded(self) -> list:
        """Return discarded/abandoned blocks (coordinates only)."""
        return self.__get_blocks_by_type(CategoryId.Abandon)

    def get_text_blocks(self) -> list:
        """Return text blocks (coordinates only, no text content)."""
        return self.__get_blocks_by_type(CategoryId.Text)

    def get_title_blocks(self) -> list:
        """Return title blocks (coordinates only, no text content)."""
        return self.__get_blocks_by_type(CategoryId.Title)

    def get_all_spans(self) -> list:
        """Collect every span-like detection (images, tables, equations, OCR
        text) as content spans, de-duplicated by full equality."""

        def remove_duplicate_spans(spans):
            new_spans = []
            for span in spans:
                if not any(span == existing_span for existing_span in new_spans):
                    new_spans.append(span)
            return new_spans

        all_spans = []
        layout_dets = self.__page_model_info['layout_dets']
        allow_category_id_list = [
            CategoryId.ImageBody,
            CategoryId.TableBody,
            CategoryId.InlineEquation,
            CategoryId.InterlineEquation_YOLO,
            CategoryId.OcrText,
        ]
        for layout_det in layout_dets:
            category_id = layout_det['category_id']
            if category_id not in allow_category_id_list:
                continue
            span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
            if category_id == CategoryId.ImageBody:
                span['type'] = ContentType.IMAGE
            elif category_id == CategoryId.TableBody:
                # Attach table-model output; LaTeX takes precedence over HTML.
                latex = layout_det.get('latex', None)
                html = layout_det.get('html', None)
                if latex:
                    span['latex'] = latex
                elif html:
                    span['html'] = html
                span['type'] = ContentType.TABLE
            elif category_id == CategoryId.InlineEquation:
                span['content'] = layout_det['latex']
                span['type'] = ContentType.INLINE_EQUATION
            elif category_id == CategoryId.InterlineEquation_YOLO:
                span['content'] = layout_det['latex']
                span['type'] = ContentType.INTERLINE_EQUATION
            elif category_id == CategoryId.OcrText:
                span['content'] = layout_det['text']
                span['type'] = ContentType.TEXT
            all_spans.append(span)
        return remove_duplicate_spans(all_spans)

    def __get_blocks_by_type(
        self, category_type: int, extra_col=None
    ) -> list:
        """Return {'bbox', 'score', *extra_col} dicts for detections of the
        given category; extra_col names additional keys to copy over."""
        if extra_col is None:
            extra_col = []
        blocks = []
        layout_dets = self.__page_model_info.get('layout_dets', [])
        for item in layout_dets:
            category_id = item.get('category_id', -1)
            bbox = item.get('bbox', None)

            if category_id == category_type:
                block = {
                    'bbox': bbox,
                    'score': item.get('score'),
                }
                for col in extra_col:
                    block[col] = item.get(col, None)
                blocks.append(block)
        return blocks
+ return blocks
vendor/mineru/mineru/backend/pipeline/pipeline_middle_json_mkcontent.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from loguru import logger
3
+
4
+ from mineru.utils.config_reader import get_latex_delimiter_config
5
+ from mineru.backend.pipeline.para_split import ListLineTag
6
+ from mineru.utils.enum_class import BlockType, ContentType, MakeMode
7
+ from mineru.utils.language import detect_lang
8
+
9
+
10
def __is_hyphen_at_line_end(line):
    """Check whether a line ends in letters followed by a hyphen.

    Args:
        line (str): The line of text to check.

    Returns:
        bool: True when the line ends with one or more ASCII letters followed
        by a hyphen (optionally trailed by whitespace), False otherwise.
    """
    # A match object is truthy; compare against None to yield a plain bool.
    return re.search(r'[A-Za-z]+-\s*$', line) is not None
21
+
22
+
23
def make_blocks_to_markdown(paras_of_layout,
                            mode,
                            img_buket_path='',
                            ):
    """Render one page's paragraph blocks into a list of Markdown strings.

    Args:
        paras_of_layout: the page's ``para_blocks``.
        mode: MakeMode.MM_MD (with images/tables) or MakeMode.NLP_MD
            (text-only; image and table blocks are skipped).
        img_buket_path: prefix prepended to every emitted image path.

    Returns:
        List of non-empty Markdown fragments, one per rendered block.
    """
    page_markdown = []
    for para_block in paras_of_layout:
        para_text = ''
        para_type = para_block['type']
        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
            para_text = merge_para_with_text(para_block)
        elif para_type == BlockType.TITLE:
            title_level = get_title_level(para_block)
            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
        elif para_type == BlockType.INTERLINE_EQUATION:
            if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
                continue
            # Prefer recognized LaTeX; fall back to the cropped equation image.
            if para_block['lines'][0]['spans'][0].get('content', ''):
                para_text = merge_para_with_text(para_block)
            else:
                para_text += f"![]({img_buket_path}/{para_block['lines'][0]['spans'][0]['image_path']})"
        elif para_type == BlockType.IMAGE:
            if mode == MakeMode.NLP_MD:
                continue
            elif mode == MakeMode.MM_MD:
                # Check whether the image has a footnote.
                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
                # If a footnote exists, emit caption, then body, then footnote.
                if has_image_footnote:
                    for block in para_block['blocks']:  # 1st: image_caption
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += merge_para_with_text(block) + '  \n'
                    for block in para_block['blocks']:  # 2nd: image_body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 3rd: image_footnote
                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
                            para_text += '  \n' + merge_para_with_text(block)
                else:
                    # Without a footnote: body first, then caption.
                    for block in para_block['blocks']:  # 1st: image_body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 2nd: image_caption
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += '  \n' + merge_para_with_text(block)
        elif para_type == BlockType.TABLE:
            if mode == MakeMode.NLP_MD:
                continue
            elif mode == MakeMode.MM_MD:
                for block in para_block['blocks']:  # 1st: table_caption
                    if block['type'] == BlockType.TABLE_CAPTION:
                        para_text += merge_para_with_text(block) + '  \n'
                for block in para_block['blocks']:  # 2nd: table_body
                    if block['type'] == BlockType.TABLE_BODY:
                        for line in block['lines']:
                            for span in line['spans']:
                                if span['type'] == ContentType.TABLE:
                                    # if processed by table model
                                    if span.get('html', ''):
                                        para_text += f"\n{span['html']}\n"
                                    elif span.get('image_path', ''):
                                        para_text += f"![]({img_buket_path}/{span['image_path']})"
                for block in para_block['blocks']:  # 3rd: table_footnote
                    if block['type'] == BlockType.TABLE_FOOTNOTE:
                        para_text += '\n' + merge_para_with_text(block) + '  '

        if para_text.strip() == '':
            continue
        else:
            # page_markdown.append(para_text.strip() + '  ')
            page_markdown.append(para_text.strip())

    return page_markdown
103
+
104
+
105
def full_to_half(text: str) -> str:
    """Convert full-width letters and digits to their half-width forms.

    Args:
        text: String possibly containing full-width characters.

    Returns:
        String in which full-width A-Z (U+FF21..U+FF3A), a-z
        (U+FF41..U+FF5A) and 0-9 (U+FF10..U+FF19) are shifted into the
        ASCII range; every other character is left unchanged.
    """
    return ''.join(
        # Shifting by 0xFEE0 maps each full-width code point onto ASCII.
        chr(ord(ch) - 0xFEE0)
        if (0xFF21 <= ord(ch) <= 0xFF3A) or (0xFF41 <= ord(ch) <= 0xFF5A) or (0xFF10 <= ord(ch) <= 0xFF19)
        else ch
        for ch in text
    )
123
+
124
# Delimiters used to wrap LaTeX in the generated Markdown. A user override
# can come from the config file; otherwise the standard $$ / $ pair is used.
latex_delimiters_config = get_latex_delimiter_config()

default_delimiters = {
    'display': {'left': '$$', 'right': '$$'},
    'inline': {'left': '$', 'right': '$'}
}

# Fall back to the defaults when no (or an empty) config is supplied.
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters

display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
137
+
138
def merge_para_with_text(para_block):
    """Join all spans of a paragraph block into one text string.

    Spacing rules depend on the dominant language of the block: CJK text is
    joined without separators (except after inline equations), while Western
    text gets a space between spans and drops a trailing hyphen when a line
    ends in a hyphenated word break. List start lines are preceded by a
    Markdown hard line break.
    """
    # First pass: concatenate plain-text content to detect the block language.
    block_text = ''
    for line in para_block['lines']:
        for span in line['spans']:
            if span['type'] in [ContentType.TEXT]:
                span['content'] = full_to_half(span['content'])
                block_text += span['content']
    block_lang = detect_lang(block_text)

    para_text = ''
    for i, line in enumerate(para_block['lines']):

        # Markdown hard break before each new list item (except the first line).
        if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
            para_text += '  \n'

        for j, span in enumerate(line['spans']):

            span_type = span['type']
            content = ''
            if span_type == ContentType.TEXT:
                content = escape_special_markdown_char(span['content'])
            elif span_type == ContentType.INLINE_EQUATION:
                if span.get('content', ''):
                    content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
            elif span_type == ContentType.INTERLINE_EQUATION:
                if span.get('content', ''):
                    content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"

            content = content.strip()

            if content:
                langs = ['zh', 'ja', 'ko']
                # logger.info(f'block_lang: {block_lang}, content: {content}')
                if block_lang in langs:
                    # CJK: no space needed at line breaks, but keep a space
                    # after an inline equation at the end of a line.
                    if j == len(line['spans']) - 1 and span_type not in [ContentType.INLINE_EQUATION]:
                        para_text += content
                    else:
                        para_text += f'{content} '
                else:
                    if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
                        # If this span ends the line with a hyphenated word
                        # break, drop the hyphen and add no space.
                        if j == len(line['spans'])-1 and span_type == ContentType.TEXT and __is_hyphen_at_line_end(content):
                            para_text += content[:-1]
                        else:  # Western text: separate spans with a space.
                            para_text += f'{content} '
                    elif span_type == ContentType.INTERLINE_EQUATION:
                        para_text += content
            else:
                continue

    return para_text
189
+
190
+
191
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one paragraph block into a content-list entry dict.

    Args:
        para_block: a block from a page's ``para_blocks``.
        img_buket_path: prefix prepended to emitted image paths.
        page_idx: index of the page the block belongs to (recorded in the
            result under 'page_idx').

    Returns:
        A dict describing the block (text/equation/image/table), or None for
        an interline-equation block with no lines/spans.
    """
    para_type = para_block['type']
    para_content = {}
    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.TITLE:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
        title_level = get_title_level(para_block)
        # Level 0 means "unranked title": omit text_level entirely.
        if title_level != 0:
            para_content['text_level'] = title_level
    elif para_type == BlockType.INTERLINE_EQUATION:
        if len(para_block['lines']) == 0 or len(para_block['lines'][0]['spans']) == 0:
            return None
        para_content = {
            'type': ContentType.EQUATION,
            'img_path': f"{img_buket_path}/{para_block['lines'][0]['spans'][0].get('image_path', '')}",
        }
        # Include the LaTeX text only when recognition produced content.
        if para_block['lines'][0]['spans'][0].get('content', ''):
            para_content['text'] = merge_para_with_text(para_block)
            para_content['text_format'] = 'latex'
    elif para_type == BlockType.IMAGE:
        para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.IMAGE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.IMAGE:
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.IMAGE_CAPTION:
                para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
            if block['type'] == BlockType.IMAGE_FOOTNOTE:
                para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
    elif para_type == BlockType.TABLE:
        para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.TABLE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.TABLE:
                            # HTML from the table model takes the body slot;
                            # the cropped image path is recorded separately.
                            if span.get('html', ''):
                                para_content[BlockType.TABLE_BODY] = f"{span['html']}"

                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"

            if block['type'] == BlockType.TABLE_CAPTION:
                para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
            if block['type'] == BlockType.TABLE_FOOTNOTE:
                para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))

    para_content['page_idx'] = page_idx

    return para_content
251
+
252
+
253
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Assemble per-page ``para_blocks`` into the requested output.

    Args:
        pdf_info_dict: list of page-info dicts carrying 'para_blocks'.
        make_mode: MakeMode.MM_MD / MakeMode.NLP_MD for a Markdown string,
            or MakeMode.CONTENT_LIST for a list of content entries.
        img_buket_path: prefix for emitted image paths.

    Returns:
        A Markdown string, a content list, or None for an unknown mode.
    """
    collected = []
    for page in pdf_info_dict:
        para_blocks = page.get('para_blocks')
        if not para_blocks:
            continue  # pages without paragraphs contribute nothing
        page_idx = page.get('page_idx')
        if make_mode in (MakeMode.MM_MD, MakeMode.NLP_MD):
            collected.extend(make_blocks_to_markdown(para_blocks, make_mode, img_buket_path))
        elif make_mode == MakeMode.CONTENT_LIST:
            for para_block in para_blocks:
                entry = make_blocks_to_content_list(para_block, img_buket_path, page_idx)
                if entry:
                    collected.append(entry)

    if make_mode in (MakeMode.MM_MD, MakeMode.NLP_MD):
        return '\n\n'.join(collected)
    if make_mode == MakeMode.CONTENT_LIST:
        return collected
    logger.error(f"Unsupported make mode: {make_mode}")
    return None
279
+
280
+
281
def get_title_level(block):
    """Clamp a title block's 'level' into the range [0, 4].

    A missing level defaults to 1; values above 4 are capped at 4 and values
    below 1 collapse to 0 (meaning "no heading rank").
    """
    level = block.get('level', 1)
    return 0 if level < 1 else min(level, 4)
288
+
289
+
290
def escape_special_markdown_char(content):
    """Backslash-escape characters that carry Markdown meaning in body text.

    Escapes * ` ~ and $ so literal occurrences in extracted text do not get
    interpreted as emphasis, code, strikethrough or math delimiters.
    """
    # str.translate performs all four single-char substitutions in one pass.
    table = {ord(ch): '\\' + ch for ch in '*`~$'}
    return content.translate(table)
vendor/mineru/mineru/backend/vlm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/backend/vlm/base_predictor.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from abc import ABC, abstractmethod
3
+ from typing import AsyncIterable, Iterable, List, Optional, Union
4
+
5
# Default chat-template strings and sampling hyper-parameters shared by all
# predictor implementations in this package.
DEFAULT_SYSTEM_PROMPT = (
    "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
)
# User prompt used by build_prompt() when the caller passes an empty prompt.
DEFAULT_USER_PROMPT = "Document Parsing:"
DEFAULT_TEMPERATURE = 0.0
DEFAULT_TOP_P = 0.8
DEFAULT_TOP_K = 20
DEFAULT_REPETITION_PENALTY = 1.0  # 1.0 effectively disables the penalty
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
DEFAULT_MAX_NEW_TOKENS = 16384
16
+
17
+
18
class BasePredictor(ABC):
    """Abstract base class for VLM document-parsing predictors.

    Subclasses implement the synchronous ``predict`` / ``batch_predict`` /
    ``stream_predict`` entry points. The ``aio_*`` coroutines provided here
    wrap those synchronous implementations with worker threads, so
    subclasses get async variants for free (subclasses may override them
    with native async implementations).
    """

    # System prompt injected into the chat template by build_prompt().
    system_prompt = DEFAULT_SYSTEM_PROMPT

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        """Store default sampling parameters.

        These defaults are used by subclasses when the corresponding
        per-call argument is ``None``.
        """
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        self.presence_penalty = presence_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.max_new_tokens = max_new_tokens

    # Run inference on a single image (path/URL string or raw bytes) and
    # return the full generated text.
    @abstractmethod
    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str: ...

    # Run inference on a batch of images; `prompts` may be a single string
    # (applied to every image) or a list matching `images` in length.
    @abstractmethod
    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]: ...

    # Run inference on a single image, yielding output text incrementally.
    @abstractmethod
    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]: ...

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Async wrapper: run the synchronous predict() in a worker thread."""
        return await asyncio.to_thread(
            self.predict,
            image,
            prompt,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Async wrapper: run the synchronous batch_predict() in a worker thread."""
        return await asyncio.to_thread(
            self.batch_predict,
            images,
            prompts,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async wrapper around the synchronous stream_predict().

        Runs stream_predict() in a worker thread and forwards each chunk
        to the consumer through an asyncio.Queue; a ``None`` item is used
        as the end-of-stream sentinel.
        """
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def synced_predict():
            # Executes in the worker thread; run_coroutine_threadsafe is the
            # thread-safe way to push items onto the loop's queue.
            for chunk in self.stream_predict(
                image=image,
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
            ):
                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
            # Signal end of stream.
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)

        asyncio.create_task(
            asyncio.to_thread(synced_predict),
        )

        while True:
            chunk = await queue.get()
            if chunk is None:
                return
            assert isinstance(chunk, str)
            yield chunk

    def build_prompt(self, prompt: str) -> str:
        """Wrap a raw user prompt in the model's chat template.

        Prompts already starting with ``<|im_start|>`` are assumed to be
        fully templated and are returned unchanged; empty prompts fall back
        to DEFAULT_USER_PROMPT.
        """
        if prompt.startswith("<|im_start|>"):
            return prompt
        if not prompt:
            prompt = DEFAULT_USER_PROMPT

        return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
        # Modify here. We add <|box_start|> at the end of the prompt to force the model to generate bounding box.
        # if "Document OCR" in prompt:
        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
        # else:
        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"

    def close(self):
        """Release backend resources; default implementation is a no-op."""
        pass
vendor/mineru/mineru/backend/vlm/hf_predictor.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ from typing import Iterable, List, Optional, Union
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from transformers import AutoTokenizer, BitsAndBytesConfig
8
+
9
+ from ...model.vlm_hf_model import Mineru2QwenForCausalLM
10
+ from ...model.vlm_hf_model.image_processing_mineru2 import process_images
11
+ from .base_predictor import (
12
+ DEFAULT_MAX_NEW_TOKENS,
13
+ DEFAULT_NO_REPEAT_NGRAM_SIZE,
14
+ DEFAULT_PRESENCE_PENALTY,
15
+ DEFAULT_REPETITION_PENALTY,
16
+ DEFAULT_TEMPERATURE,
17
+ DEFAULT_TOP_K,
18
+ DEFAULT_TOP_P,
19
+ BasePredictor,
20
+ )
21
+ from .utils import load_resource
22
+
23
+
24
class HuggingfacePredictor(BasePredictor):
    """Predictor backed by a local HuggingFace ``Mineru2QwenForCausalLM`` model.

    Loads the tokenizer, model and vision tower at construction time and
    runs generation in-process with torch. Batch prediction is sequential
    (one predict() call per image); streaming is not implemented.
    """

    def __init__(
        self,
        model_path: str,
        device_map="auto",
        device="cuda",
        torch_dtype="auto",
        load_in_8bit=False,
        load_in_4bit=False,
        use_flash_attn=False,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        **kwargs,
    ):
        """Load model weights and tokenizer from ``model_path``.

        ``load_in_8bit`` / ``load_in_4bit`` select bitsandbytes quantization
        (4-bit uses NF4 with double quantization); otherwise ``torch_dtype``
        is forwarded to ``from_pretrained``. Extra ``kwargs`` are passed
        through to ``from_pretrained``.
        """
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        kwargs = {"device_map": device_map, **kwargs}

        # Pin the whole model to a single explicit device when not on CUDA.
        if device != "cuda":
            kwargs["device_map"] = {"": device}

        if load_in_8bit:
            kwargs["load_in_8bit"] = True
        elif load_in_4bit:
            kwargs["load_in_4bit"] = True
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        else:
            kwargs["torch_dtype"] = torch_dtype

        if use_flash_attn:
            kwargs["attn_implementation"] = "flash_attention_2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = Mineru2QwenForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        setattr(self.model.config, "_name_or_path", model_path)
        self.model.eval()

        vision_tower = self.model.get_model().vision_tower
        # NOTE(review): `device_map` (not `device`) is passed as the device
        # here — confirm this is intended when device_map is a string device.
        if device_map != "auto":
            vision_tower.to(device=device_map, dtype=self.model.dtype)

        self.image_processor = vision_tower.image_processor
        self.eos_token_id = self.model.config.eos_token_id

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> str:
        """Generate parsed text for one image.

        ``image`` may be a path/URL (loaded via ``load_resource``) or raw
        bytes. ``presence_penalty`` is accepted for interface parity but is
        not forwarded (not supported by HF ``generate``). Falls back to the
        instance defaults for any sampling argument left as ``None``.
        """
        prompt = self.build_prompt(prompt)

        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # Greedy decoding unless both temperature and top_k allow sampling.
        do_sample = (temperature > 0.0) and (top_k > 1)

        generate_kwargs = {
            "repetition_penalty": repetition_penalty,
            "no_repeat_ngram_size": no_repeat_ngram_size,
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        if isinstance(image, str):
            image = load_resource(image)

        image_obj = Image.open(BytesIO(image))
        image_tensor = process_images([image_obj], self.image_processor, self.model.config)
        image_tensor = image_tensor[0].unsqueeze(0)
        image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
        image_sizes = [[*image_obj.size]]

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(device=self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                use_cache=True,
                **generate_kwargs,
                **kwargs,
            )

        # Remove the last token if it is the eos_token_id
        if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
            output_ids = output_ids[:, :-1]

        # skip_special_tokens=False keeps structural markers in the output.
        output = self.tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=False,
        )[0].strip()

        return output

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,  # not supported by hf
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        **kwargs,
    ) -> List[str]:
        """Sequentially run predict() for each image; no true batching.

        A single string ``prompts`` is broadcast to every image.
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        outputs = []
        for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
            output = self.predict(
                image,
                prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                **kwargs,
            )
            outputs.append(output)
        return outputs

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Streaming is not implemented for the HF backend."""
        raise NotImplementedError("Streaming is not supported yet.")
vendor/mineru/mineru/backend/vlm/predictor.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+
3
+ import time
4
+
5
+ from loguru import logger
6
+
7
+ from .base_predictor import (
8
+ DEFAULT_MAX_NEW_TOKENS,
9
+ DEFAULT_NO_REPEAT_NGRAM_SIZE,
10
+ DEFAULT_PRESENCE_PENALTY,
11
+ DEFAULT_REPETITION_PENALTY,
12
+ DEFAULT_TEMPERATURE,
13
+ DEFAULT_TOP_K,
14
+ DEFAULT_TOP_P,
15
+ BasePredictor,
16
+ )
17
+ from .sglang_client_predictor import SglangClientPredictor
18
+
19
# Optional-backend import guards: each backend is imported lazily and a flag
# records availability so get_predictor() can raise a clear error instead of
# failing at import time when an optional dependency is missing.
hf_loaded = False
try:
    from .hf_predictor import HuggingfacePredictor

    hf_loaded = True
except ImportError as e:
    logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")

engine_loaded = False
try:
    from sglang.srt.server_args import ServerArgs

    from .sglang_engine_predictor import SglangEnginePredictor

    engine_loaded = True
# Broad except: sglang imports can fail with errors other than ImportError.
except Exception as e:
    logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
36
+
37
+
38
def get_predictor(
    backend: str = "sglang-client",
    model_path: str | None = None,
    server_url: str | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    top_k: int = DEFAULT_TOP_K,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    http_timeout: int = 600,
    **kwargs,
) -> BasePredictor:
    """Factory for VLM predictors.

    Selects an implementation by ``backend`` ("transformers",
    "sglang-engine" or "sglang-client"), validating that the required
    argument (``model_path`` or ``server_url``) and optional dependency are
    present. Raises ValueError for bad arguments/backends and ImportError
    when the chosen backend's dependency is unavailable.
    """
    started_at = time.time()

    # Sampling defaults shared by every backend constructor.
    sampling = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        presence_penalty=presence_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
    )

    if backend == "transformers":
        if not model_path:
            raise ValueError("model_path must be provided for transformers backend.")
        if not hf_loaded:
            raise ImportError(
                "transformers is not installed, so huggingface backend cannot be used. "
                "If you need to use huggingface backend, please install transformers first."
            )
        predictor = HuggingfacePredictor(model_path=model_path, **sampling, **kwargs)
    elif backend == "sglang-engine":
        if not model_path:
            raise ValueError("model_path must be provided for sglang-engine backend.")
        if not engine_loaded:
            raise ImportError(
                "sglang is not installed, so sglang-engine backend cannot be used. "
                "If you need to use sglang-engine backend for inference, "
                "please install sglang[all]==0.4.8 or a newer version."
            )
        predictor = SglangEnginePredictor(server_args=ServerArgs(model_path, **kwargs), **sampling)
    elif backend == "sglang-client":
        if not server_url:
            raise ValueError("server_url must be provided for sglang-client backend.")
        predictor = SglangClientPredictor(server_url=server_url, http_timeout=http_timeout, **sampling)
    else:
        raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")

    elapsed = round(time.time() - started_at, 2)
    logger.info(f"get_predictor cost: {elapsed}s")
    return predictor
vendor/mineru/mineru/backend/vlm/sglang_client_predictor.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import re
4
+ from base64 import b64encode
5
+ from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
6
+
7
+ import httpx
8
+
9
+ from .base_predictor import (
10
+ DEFAULT_MAX_NEW_TOKENS,
11
+ DEFAULT_NO_REPEAT_NGRAM_SIZE,
12
+ DEFAULT_PRESENCE_PENALTY,
13
+ DEFAULT_REPETITION_PENALTY,
14
+ DEFAULT_TEMPERATURE,
15
+ DEFAULT_TOP_K,
16
+ DEFAULT_TOP_P,
17
+ BasePredictor,
18
+ )
19
+ from .utils import aio_load_resource, load_resource
20
+
21
+
22
class SglangClientPredictor(BasePredictor):
    """Predictor that talks to a remote sglang inference server over HTTP.

    The constructor verifies server health and resolves the served model
    path; requests are sent to the server's ``/generate`` endpoint with the
    image base64-encoded in the body. Sync, async, streaming and batched
    variants are provided.
    """

    def __init__(
        self,
        server_url: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        http_timeout: int = 600,
    ) -> None:
        """Connect to the server at ``server_url`` and store sampling defaults.

        Raises RuntimeError when the server is unreachable or unhealthy.
        """
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.http_timeout = http_timeout

        base_url = self.get_base_url(server_url)
        self.check_server_health(base_url)
        self.model_path = self.get_model_path(base_url)
        self.server_url = f"{base_url}/generate"

    @staticmethod
    def get_base_url(server_url: str) -> str:
        """Extract ``scheme://host[:port]`` from a URL; raises ValueError otherwise."""
        matched = re.match(r"^(https?://[^/]+)", server_url)
        if not matched:
            raise ValueError(f"Invalid server URL: {server_url}")
        return matched.group(1)

    def check_server_health(self, base_url: str):
        """Probe the server's ``/health_generate`` endpoint; raise RuntimeError on failure."""
        try:
            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
            )

    def get_model_path(self, base_url: str) -> str:
        """Return the model path reported by the server's ``/get_model_info`` endpoint."""
        try:
            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
        except httpx.ConnectError:
            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
            )
        return response.json()["model_path"]

    def build_sampling_params(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        repetition_penalty: Optional[float],
        presence_penalty: Optional[float],
        no_repeat_ngram_size: Optional[int],
        max_new_tokens: Optional[int],
    ) -> dict:
        """Build the sampling-params dict for the request body.

        Any argument left as ``None`` falls back to the instance default.
        """
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens

        # see SamplingParams for more details
        return {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            "skip_special_tokens": False,
        }

    def build_request_body(
        self,
        image: bytes,
        prompt: str,
        sampling_params: dict,
    ) -> dict:
        """Assemble the JSON body for ``/generate`` with the image base64-encoded."""
        image_base64 = b64encode(image).decode("utf-8")
        return {
            "text": prompt,
            "image_data": image_base64,
            "sampling_params": sampling_params,
            "modalities": ["image"],
        }

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Synchronously generate text for one image via a blocking HTTP POST.

        ``image`` may be a path/URL (loaded via ``load_resource``) or raw bytes.
        """
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)
        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
        response_body = response.json()
        return response_body["text"]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        """Synchronous wrapper around aio_batch_predict().

        NOTE(review): when called from a thread whose event loop is already
        running, ``loop.run_until_complete`` raises RuntimeError ("This event
        loop is already running") — confirm this path is only hit from
        non-async callers.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        task = self.aio_batch_predict(
            images=images,
            prompts=prompts,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            max_concurrency=max_concurrency,
        )

        if loop is not None:
            return loop.run_until_complete(task)
        else:
            return asyncio.run(task)

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Yield generated text incrementally from the server's SSE stream.

        The server resends the full text so far in each event; ``pos`` tracks
        how much has already been yielded so only the delta is emitted.
        """
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True

        with httpx.stream(
            "POST",
            self.server_url,
            json=request_body,
            timeout=self.http_timeout,
        ) as response:
            pos = 0
            for chunk in response.iter_lines():
                # Only server-sent-event data lines carry payload.
                if not (chunk or "").startswith("data:"):
                    continue
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                # meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        async_client: Optional[httpx.AsyncClient] = None,
    ) -> str:
        """Async variant of predict().

        A caller-supplied ``async_client`` is reused (and left open);
        otherwise a short-lived client is created for this one request.
        """
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = await aio_load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)

        if async_client is None:
            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
                response = await client.post(self.server_url, json=request_body)
                response_body = response.json()
        else:
            response = await async_client.post(self.server_url, json=request_body)
            response_body = response.json()

        return response_body["text"]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> List[str]:
        """Run aio_predict() concurrently over all images.

        Concurrency is bounded by ``max_concurrency`` via a semaphore; one
        shared AsyncClient is used for all requests. Results are returned
        in input order.
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        semaphore = asyncio.Semaphore(max_concurrency)
        outputs = [""] * len(images)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                outputs[idx] = output

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            tasks = []
            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                tasks.append(predict_with_semaphore(idx, image, prompt, client))
            await asyncio.gather(*tasks)

        return outputs

    async def aio_batch_predict_as_iter(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        max_concurrency: int = 100,
    ) -> AsyncIterable[Tuple[int, str]]:
        """Like aio_batch_predict() but yields ``(index, output)`` pairs as
        each request completes (completion order, not input order).
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)

        assert len(prompts) == len(images), "Length of prompts and images must match."

        semaphore = asyncio.Semaphore(max_concurrency)

        async def predict_with_semaphore(
            idx: int,
            image: str | bytes,
            prompt: str,
            async_client: httpx.AsyncClient,
        ):
            async with semaphore:
                output = await self.aio_predict(
                    image=image,
                    prompt=prompt,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    max_new_tokens=max_new_tokens,
                    async_client=async_client,
                )
                return (idx, output)

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            pending: Set[asyncio.Task[Tuple[int, str]]] = set()

            for idx, (prompt, image) in enumerate(zip(prompts, images)):
                pending.add(
                    asyncio.create_task(
                        predict_with_semaphore(idx, image, prompt, client),
                    )
                )

            # Drain tasks as they finish so results stream out promptly.
            while len(pending) > 0:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in done:
                    yield task.result()

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async variant of stream_predict(): yields text deltas from the
        server's SSE stream.
        """
        prompt = self.build_prompt(prompt)

        sampling_params = self.build_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )

        if isinstance(image, str):
            image = await aio_load_resource(image)

        request_body = self.build_request_body(image, prompt, sampling_params)
        request_body["stream"] = True

        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
            async with client.stream(
                "POST",
                self.server_url,
                json=request_body,
            ) as response:
                pos = 0
                async for chunk in response.aiter_lines():
                    if not (chunk or "").startswith("data:"):
                        continue
                    if chunk == "data: [DONE]":
                        break
                    data = json.loads(chunk[5:].strip("\n"))
                    chunk_text = data["text"][pos:]
                    # meta_info = data["meta_info"]
                    pos += len(chunk_text)
                    yield chunk_text
vendor/mineru/mineru/backend/vlm/sglang_engine_predictor.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from base64 import b64encode
2
+ from typing import AsyncIterable, Iterable, List, Optional, Union
3
+
4
+ from sglang.srt.server_args import ServerArgs
5
+
6
+ from ...model.vlm_sglang_model.engine import BatchEngine
7
+ from .base_predictor import (
8
+ DEFAULT_MAX_NEW_TOKENS,
9
+ DEFAULT_NO_REPEAT_NGRAM_SIZE,
10
+ DEFAULT_PRESENCE_PENALTY,
11
+ DEFAULT_REPETITION_PENALTY,
12
+ DEFAULT_TEMPERATURE,
13
+ DEFAULT_TOP_K,
14
+ DEFAULT_TOP_P,
15
+ BasePredictor,
16
+ )
17
+
18
+
19
class SglangEnginePredictor(BasePredictor):
    """Predictor backed by an in-process sglang ``BatchEngine``.

    Sampling defaults come from :class:`BasePredictor`; every predict method
    accepts per-call overrides that fall back to those defaults when ``None``.
    The sync and async batch paths share the same parameter-resolution and
    input-preparation helpers so they cannot drift apart.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        self.engine = BatchEngine(server_args=server_args)

    def load_image_string(self, image: str | bytes) -> str:
        """Normalize an image argument into the string form sglang expects.

        Bytes are base64-encoded; ``file://`` URIs are stripped to a bare
        path; any other string (path, URL, base64) is passed through as-is.

        Raises:
            ValueError: if *image* is neither ``str`` nor ``bytes``.
        """
        if not isinstance(image, (str, bytes)):
            raise ValueError("Image must be a string or bytes.")
        if isinstance(image, bytes):
            return b64encode(image).decode("utf-8")
        if image.startswith("file://"):
            return image[len("file://") :]
        return image

    def _prepare_batch(self, images, prompts):
        """Broadcast a scalar prompt, apply the chat template, encode images.

        Returns the (prompts, image_strings) pair consumed by the engine.
        """
        if not isinstance(prompts, list):
            prompts = [prompts] * len(images)
        assert len(prompts) == len(images), "Length of prompts and images must match."
        prompts = [self.build_prompt(prompt) for prompt in prompts]
        image_strings = [self.load_image_string(img) for img in images]
        return prompts, image_strings

    def _resolve_sampling_params(
        self,
        temperature: Optional[float],
        top_p: Optional[float],
        top_k: Optional[int],
        repetition_penalty: Optional[float],
        presence_penalty: Optional[float],
        no_repeat_ngram_size: Optional[int],
        max_new_tokens: Optional[int],
    ) -> dict:
        """Merge per-call overrides with instance defaults into the dict
        consumed by sglang's SamplingParams (see SamplingParams for details).

        Previously duplicated verbatim in batch_predict/aio_batch_predict;
        centralized here so both paths stay consistent.
        """
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if presence_penalty is None:
            presence_penalty = self.presence_penalty
        if no_repeat_ngram_size is None:
            no_repeat_ngram_size = self.no_repeat_ngram_size
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        return {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
            "presence_penalty": presence_penalty,
            "custom_params": {
                "no_repeat_ngram_size": no_repeat_ngram_size,
            },
            "max_new_tokens": max_new_tokens,
            # Keep special tokens: downstream parsing relies on the <|...|> markers.
            "skip_special_tokens": False,
        }

    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Run inference on a single image; thin wrapper over batch_predict()."""
        return self.batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )[0]

    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Generate one answer per image with the in-process engine (blocking)."""
        prompts, image_strings = self._prepare_batch(images, prompts)
        sampling_params = self._resolve_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        output = self.engine.generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]

    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Streaming is not implemented for the engine backend."""
        raise NotImplementedError("Streaming is not supported yet.")

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Async inference on a single image; wrapper over aio_batch_predict()."""
        output = await self.aio_batch_predict(
            [image],  # type: ignore
            [prompt],
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        return output[0]

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Async variant of batch_predict(), using the engine's async API."""
        prompts, image_strings = self._prepare_batch(images, prompts)
        sampling_params = self._resolve_sampling_params(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
        output = await self.engine.async_generate(
            prompt=prompts,
            image_data=image_strings,
            sampling_params=sampling_params,
        )
        return [item["text"] for item in output]  # type: ignore

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Streaming is not implemented for the engine backend."""
        raise NotImplementedError("Streaming is not supported yet.")

    def close(self):
        """Shut down the underlying sglang engine and release its resources."""
        self.engine.shutdown()
vendor/mineru/mineru/backend/vlm/token_to_middle_json.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import cv2
3
+ import numpy as np
4
+ from loguru import logger
5
+
6
+ from mineru.backend.pipeline.model_init import AtomModelSingleton
7
+ from mineru.utils.config_reader import get_llm_aided_config
8
+ from mineru.utils.cut_image import cut_image_and_table
9
+ from mineru.utils.enum_class import ContentType
10
+ from mineru.utils.hash_utils import str_md5
11
+ from mineru.backend.vlm.vlm_magic_model import MagicModel
12
+ from mineru.utils.llm_aided import llm_aided_title
13
+ from mineru.utils.pdf_image_tools import get_crop_img
14
+ from mineru.version import __version__
15
+
16
+
17
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
    """Convert one page's VLM token output into a middle-json page-info dict.

    The token stream encodes blocks as
    ``<|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>``;
    MagicModel does the actual parsing. This function additionally crops
    image/table/equation spans to files and, when LLM-aided title tuning is
    enabled, estimates each title's average OCR line height.
    """
    scale = image_dict["scale"]
    page_pil_img = image_dict["img_pil"]
    page_img_md5 = str_md5(image_dict["img_base64"])
    width, height = map(int, page.get_size())

    magic_model = MagicModel(token, width, height)
    image_blocks = magic_model.get_image_blocks()
    table_blocks = magic_model.get_table_blocks()
    title_blocks = magic_model.get_title_blocks()

    # When LLM-aided title optimization is enabled, run OCR detection over each
    # title crop to estimate its average text-line height (consumed downstream).
    llm_aided_config = get_llm_aided_config()
    title_aided_enabled = bool(
        llm_aided_config
        and (llm_aided_config.get('title_aided') or {}).get('enable', False)
    )
    if title_aided_enabled:
        ocr_model = AtomModelSingleton().get_atom_model(
            atom_model_name='ocr',
            ocr_show_log=False,
            det_db_box_thresh=0.3,
            lang='ch_lite',
        )
        for title_block in title_blocks:
            crop = np.array(get_crop_img(title_block['bbox'], page_pil_img, scale))
            # Pad 50 white pixels on every side so detection is not clipped at the edges.
            crop = cv2.copyMakeBorder(
                crop, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
            )
            det_input = cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)
            det_boxes = ocr_model.ocr(det_input, rec=False)[0]
            if len(det_boxes) > 0:
                # Mean box height across all detected lines, mapped back to page scale.
                mean_height = np.mean([box[2][1] - box[0][1] for box in det_boxes])
                title_block['line_avg_height'] = round(mean_height / scale)

    text_blocks = magic_model.get_text_blocks()
    interline_equation_blocks = magic_model.get_interline_equation_blocks()

    # Crop image/table/interline-equation spans and persist them via image_writer.
    all_spans = magic_model.get_all_spans()
    croppable_types = (ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION)
    for span in all_spans:
        if span["type"] in croppable_types:
            span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)

    # Restore reading order: blocks carry the index assigned during parsing.
    page_blocks = sorted(
        [*image_blocks, *table_blocks, *title_blocks, *text_blocks, *interline_equation_blocks],
        key=lambda blk: blk["index"],
    )

    return {
        "para_blocks": page_blocks,
        "discarded_blocks": [],
        "page_size": [width, height],
        "page_idx": page_index,
    }
77
+
78
+
79
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
    """Assemble per-page VLM outputs into a whole-document middle-json dict.

    Args:
        token_list: one raw token string per page, in page order.
        images_list: per-page render dicts (img_pil/img_base64/scale).
        pdf_doc: the open PDF document; closed before returning.
        image_writer: sink for cropped image/table/equation files.
    """
    pdf_info = []
    for page_index, token in enumerate(token_list):
        pdf_info.append(
            token_to_page_info(token, images_list[page_index], pdf_doc[page_index], image_writer, page_index)
        )
    middle_json = {"pdf_info": pdf_info, "_backend": "vlm", "_version_name": __version__}

    # Optional LLM post-processing: title-level optimization.
    llm_aided_config = get_llm_aided_config()
    if llm_aided_config is not None:
        title_aided_config = llm_aided_config.get('title_aided', None)
        if title_aided_config is not None and title_aided_config.get('enable', False):
            started = time.time()
            llm_aided_title(middle_json["pdf_info"], title_aided_config)
            logger.info(f'llm aided title time: {round(time.time() - started, 2)}')

    # Close the PDF document now that every page has been consumed.
    pdf_doc.close()
    return middle_json
102
+
103
+
104
if __name__ == "__main__":

    # Sample raw VLM output for one page, used as a parsing demo fixture.
    output = r"<|box_start|>088 119 472 571<|box_end|><|ref_start|>image<|ref_end|><|md_start|>![]('img_url')<|md_end|>\n<|box_start|>079 582 482 608<|box_end|><|ref_start|>image_caption<|ref_end|><|md_start|>Fig. 2. (a) Schematic of the change in the FDC over time, and (b) definition of model parameters.<|md_end|>\n<|box_start|>079 624 285 638<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.2. Zero flow day analysis<|md_end|>\n<|box_start|>079 656 482 801<|box_end|><|ref_start|>text<|ref_end|><|md_start|>A notable feature of Fig. 1 is the increase in the number of zero flow days. A similar approach to Eq. (2), using an inverse sigmoidal function was employed to assess the impact of afforestation on the number of zero flow days per year \((N_{\mathrm{zero}})\). In this case, the left hand side of Eq. (2) is replaced by \(N_{\mathrm{zero}}\) and \(b\) and \(S\) are constrained to negative as \(N_{\mathrm{zero}}\) decreases as rainfall increases, and increases with plantation growth:<|md_end|>\n<|box_start|>076 813 368 853<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nN_{\mathrm{zero}}=a+b(\Delta P)+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}}{S}\right)}\n\]<|md_end|>\n<|box_start|>079 865 482 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For the average pre-treatment condition \(\Delta P=0\) and \(T=0\), \(N_{\mathrm{zero}}\) approximately equals \(a\). \(Y\) gives<|md_end|>\n<|box_start|>525 119 926 215<|box_end|><|ref_start|>text<|ref_end|><|md_start|>the magnitude of change in zero flow days due to afforestation, and \(S\) describes the shape of the response. For the average climate condition \(\Delta P=0\), \(a+Y\) becomes the number of zero flow days when the new equilibrium condition under afforestation is reached.<|md_end|>\n<|box_start|>525 240 704 253<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.3. Statistical analyses<|md_end|>\n<|box_start|>525 271 926 368<|box_end|><|ref_start|>text<|ref_end|><|md_start|>The coefficient of efficiency \((E)\) (Nash and Sutcliffe, 1970; Chiew and McMahon, 1993; Legates and McCabe, 1999) was used as the 'goodness of fit' measure to evaluate the fit between observed and predicted flow deciles (2) and zero flow days (3). \(E\) is given by:<|md_end|>\n<|box_start|>520 375 735 415<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nE=1.0-\frac{\sum_{i=1}^{N}(O_{i}-P_{i})^{2}}{\sum_{i=1}^{N}(O_{i}-\bar{O})^{2}}\n\]<|md_end|>\n<|box_start|>525 424 926 601<|box_end|><|ref_start|>text<|ref_end|><|md_start|>where \(O\) are observed data, \(P\) are predicted values, and \(\bar{O}\) is the mean for the entire period. \(E\) is unity minus the ratio of the mean square error to the variance in the observed data, and ranges from \(-\infty\) to 1.0. Higher values indicate greater agreement between observed and predicted data as per the coefficient of determination \((r^{2})\). \(E\) is used in preference to \(r^{2}\) in evaluating hydrologic modelling because it is a measure of the deviation from the 1:1 line. As \(E\) is always \(<r^{2}\) we have arbitrarily considered \(E>0.7\) to indicate adequate model fits.<|md_end|>\n<|box_start|>525 603 926 731<|box_end|><|ref_start|>text<|ref_end|><|md_start|>It is important to assess the significance of the model parameters to check the model assumptions that rainfall and forest age are driving changes in the FDC. The model (2) was split into simplified forms, where only the rainfall or time terms were included by setting \(b=0\), as shown in Eq. (5), or \(Y=0\) as shown in Eq. (6). The component models (5) and (6) were then tested against the complete model, (2).<|md_end|>\n<|box_start|>520 739 735 778<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}^{\prime}}{S}\right)}\n\]<|md_end|>\n<|box_start|>525 787 553 799<|box_end|><|ref_start|>text<|ref_end|><|md_start|>and<|md_end|>\n<|box_start|>520 807 646 825<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+b\Delta P\n\]<|md_end|>\n<|box_start|>525 833 926 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For both the flow duration curve analysis and zero flow days analysis, a \(t\)-test was then performed to test whether (5) and (6) were significantly different to (2). A critical value of \(t\) exceeding the calculated \(t\)-value<|md_end|><|im_end|>"

    # NOTE(review): token_to_page_info() requires (token, image_dict, page,
    # image_writer, page_index); calling it with only `output` raises TypeError.
    # This demo block looks stale — confirm the intended invocation before use.
    p_info = token_to_page_info(output)
    # Convert the blocks to JSON text.
    import json

    json_str = json.dumps(p_info, ensure_ascii=False, indent=4)
    print(json_str)
vendor/mineru/mineru/backend/vlm/utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from base64 import b64decode
4
+
5
+ import httpx
6
+
7
# Default HTTP timeout (seconds) for fetching remote resources.
_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
# Extensions treated as local file paths when the URI has no scheme.
_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
# Matches the "data:<mime>;base64," prefix of a data URI.
_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")


def load_resource(uri: str) -> bytes:
    """Load the resource identified by *uri* and return its raw bytes.

    Supported forms, checked in order: http(s) URL, ``file://`` URL, bare
    path with a known image/PDF extension, base64 data URI, and finally a
    plain base64 payload.

    Raises:
        httpx.HTTPStatusError: if an HTTP fetch returns a 4xx/5xx status.
        binascii.Error: if the fallback base64 payload is malformed.
    """
    if uri.startswith("http://") or uri.startswith("https://"):
        response = httpx.get(uri, timeout=_timeout)
        # Fail loudly on HTTP errors instead of silently returning the
        # error page's body as if it were the resource bytes.
        response.raise_for_status()
        return response.content
    if uri.startswith("file://"):
        with open(uri[len("file://") :], "rb") as file:
            return file.read()
    if uri.lower().endswith(_file_exts):
        with open(uri, "rb") as file:
            return file.read()
    if re.match(_data_uri_regex, uri):
        return b64decode(uri.split(",")[1])
    return b64decode(uri)
25
+
26
+
27
+ async def aio_load_resource(uri: str) -> bytes:
28
+ if uri.startswith("http://") or uri.startswith("https://"):
29
+ async with httpx.AsyncClient(timeout=_timeout) as client:
30
+ response = await client.get(uri)
31
+ return response.content
32
+ if uri.startswith("file://"):
33
+ with open(uri[len("file://") :], "rb") as file:
34
+ return file.read()
35
+ if uri.lower().endswith(_file_exts):
36
+ with open(uri, "rb") as file:
37
+ return file.read()
38
+ if re.match(_data_uri_regex, uri):
39
+ return b64decode(uri.split(",")[1])
40
+ return b64decode(uri)
vendor/mineru/mineru/backend/vlm/vlm_analyze.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import time
3
+
4
+ from loguru import logger
5
+
6
+ from ...data.data_reader_writer import DataWriter
7
+ from mineru.utils.pdf_image_tools import load_images_from_pdf
8
+ from .base_predictor import BasePredictor
9
+ from .predictor import get_predictor
10
+ from .token_to_middle_json import result_to_middle_json
11
+ from ...utils.models_download_utils import auto_download_and_get_model_root_path
12
+
13
+
14
class ModelSingleton:
    """Process-wide cache of VLM predictors, keyed by (backend, model_path, server_url)."""

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        # Classic singleton: every construction returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> BasePredictor:
        """Return a cached predictor for this configuration, creating it on first use."""
        cache_key = (backend, model_path, server_url)
        if cache_key in self._models:
            return self._models[cache_key]
        # Local backends need a concrete model path; resolve/download one when missing.
        if backend in ['transformers', 'sglang-engine'] and not model_path:
            model_path = auto_download_and_get_model_root_path("/", "vlm")
        predictor = get_predictor(
            backend=backend,
            model_path=model_path,
            server_url=server_url,
            **kwargs,
        )
        self._models[cache_key] = predictor
        return predictor
41
+
42
+
43
def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    """Run VLM analysis over a whole PDF (blocking).

    Renders every page to a base64 image, batch-predicts with the given (or
    lazily created singleton) predictor, and returns
    ``(middle_json, per_page_results)``.
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # Render every page once, then feed the whole batch to the predictor.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    page_images = [page_dict["img_base64"] for page_dict in images_list]

    results = predictor.batch_predict(images=page_images)

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
68
+
69
+
70
async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: BasePredictor | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
    """Async variant of doc_analyze(); same inputs and return value.

    Uses the predictor's async batch API so page inference can overlap I/O.
    """
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    # Render every page once, then feed the whole batch to the predictor.
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
    page_images = [page_dict["img_base64"] for page_dict in images_list]

    results = await predictor.aio_batch_predict(images=page_images)

    middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
vendor/mineru/mineru/backend/vlm/vlm_magic_model.py ADDED
@@ -0,0 +1,521 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Literal
3
+
4
+ from loguru import logger
5
+
6
+ from mineru.utils.boxbase import bbox_distance, is_in
7
+ from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
8
+ from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
9
+ from mineru.utils.format_utils import convert_otsl_to_html
10
+
11
+
12
class MagicModel:
    """Parses one page's tagged VLM token output into block/span structures.

    The token stream is a sequence of segments shaped like
    ``<|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>``
    with coordinates normalized to a 0-1000 grid; they are rescaled here to
    the page pixel size (``width`` x ``height``).
    """

    def __init__(self, token: str, width, height):
        self.token = token

        # Find every block with a regex; a block's content ends at either
        # <|md_end|> or the final <|im_end|> marker.
        pattern = (
            r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)"
        )
        block_infos = re.findall(pattern, token, re.DOTALL)

        blocks = []
        self.all_spans = []
        # Parse each matched block.
        for index, block_info in enumerate(block_infos):
            block_bbox = block_info[0].strip()
            try:
                # Rescale 0-1000 normalized coordinates to page pixels.
                x1, y1, x2, y2 = map(int, block_bbox.split())
                x_1, y_1, x_2, y_2 = (
                    int(x1 * width / 1000),
                    int(y1 * height / 1000),
                    int(x2 * width / 1000),
                    int(y2 * height / 1000),
                )
                # Normalize inverted boxes so (x_1, y_1) is the top-left corner.
                if x_2 < x_1:
                    x_1, x_2 = x_2, x_1
                if y_2 < y_1:
                    y_1, y_2 = y_2, y_1
                block_bbox = (x_1, y_1, x_2, y_2)
                block_type = block_info[1].strip()
                block_content = block_info[2].strip()
            except Exception as e:
                # Parsing failed (malformed coordinates) — skip this block
                # rather than failing the whole page.
                logger.warning(f"Invalid block format: {block_info}, error: {e}")
                continue

            # Map the raw ref type to a span content type; block_type is
            # rewritten to the matching body block type for image/table/equation.
            span_type = "unknown"
            if block_type in [
                "text",
                "title",
                "image_caption",
                "image_footnote",
                "table_caption",
                "table_footnote",
                "list",
                "index",
            ]:
                span_type = ContentType.TEXT
            elif block_type in ["image"]:
                block_type = BlockType.IMAGE_BODY
                span_type = ContentType.IMAGE
            elif block_type in ["table"]:
                block_type = BlockType.TABLE_BODY
                span_type = ContentType.TABLE
            elif block_type in ["equation"]:
                block_type = BlockType.INTERLINE_EQUATION
                span_type = ContentType.INTERLINE_EQUATION

            # NOTE(review): this comparison assumes ContentType.IMAGE/TABLE
            # compare equal to the strings "image"/"table" — confirm enum values.
            if span_type in ["image", "table"]:
                span = {
                    "bbox": block_bbox,
                    "type": span_type,
                }
                if span_type == ContentType.TABLE:
                    # OTSL-encoded tables (marked by <fcel>/<ecel>) are converted
                    # to HTML; other segments of the content pass through as-is.
                    if "<fcel>" in block_content or "<ecel>" in block_content:
                        lines = block_content.split("\n\n")
                        new_lines = []
                        for line in lines:
                            if "<fcel>" in line or "<ecel>" in line:
                                line = convert_otsl_to_html(line)
                            new_lines.append(line)
                        span["html"] = "\n\n".join(new_lines)
                    else:
                        span["html"] = block_content
            elif span_type in [ContentType.INTERLINE_EQUATION]:
                span = {
                    "bbox": block_bbox,
                    "type": span_type,
                    # Strip \[ \] delimiters and repair unbalanced \left/\right.
                    "content": isolated_formula_clean(block_content),
                }
            else:
                # Text-like content: if it embeds balanced \( ... \) inline
                # formulas, split it into alternating text / inline-equation spans.
                if block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
                    # Build a span list containing both text and formulas.
                    spans = []
                    last_end = 0

                    # Locate every inline formula.
                    for match in re.finditer(r'\\\((.+?)\\\)', block_content):
                        start, end = match.span()

                        # Add the text preceding the formula.
                        if start > last_end:
                            text_before = block_content[last_end:start]
                            if text_before.strip():
                                spans.append({
                                    "bbox": block_bbox,
                                    "type": ContentType.TEXT,
                                    "content": text_before
                                })

                        # Add the formula itself (without the \( and \) delimiters).
                        formula = match.group(1)
                        spans.append({
                            "bbox": block_bbox,
                            "type": ContentType.INLINE_EQUATION,
                            "content": formula.strip()
                        })

                        last_end = end

                    # Add any text after the last formula.
                    if last_end < len(block_content):
                        text_after = block_content[last_end:]
                        if text_after.strip():
                            spans.append({
                                "bbox": block_bbox,
                                "type": ContentType.TEXT,
                                "content": text_after
                            })

                    span = spans
                else:
                    span = {
                        "bbox": block_bbox,
                        "type": span_type,
                        "content": block_content,
                    }

            # A block always holds a single line; that line holds either the
            # one span (dict case) or the split text/formula spans (list case).
            if isinstance(span, dict) and "bbox" in span:
                self.all_spans.append(span)
                line = {
                    "bbox": block_bbox,
                    "spans": [span],
                }
            elif isinstance(span, list):
                self.all_spans.extend(span)
                line = {
                    "bbox": block_bbox,
                    "spans": span,
                }
            else:
                raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")

            # `index` preserves the reading order emitted by the model.
            blocks.append(
                {
                    "bbox": block_bbox,
                    "type": block_type,
                    "lines": [line],
                    "index": index,
                }
            )

        # Bucket the parsed blocks by category for the typed getters below.
        self.image_blocks = []
        self.table_blocks = []
        self.interline_equation_blocks = []
        self.text_blocks = []
        self.title_blocks = []
        for block in blocks:
            if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
                self.image_blocks.append(block)
            elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
                self.table_blocks.append(block)
            elif block["type"] == BlockType.INTERLINE_EQUATION:
                self.interline_equation_blocks.append(block)
            elif block["type"] == BlockType.TEXT:
                self.text_blocks.append(block)
            elif block["type"] == BlockType.TITLE:
                self.title_blocks.append(block)
            else:
                # Unrecognized block types are silently dropped.
                continue

    def get_image_blocks(self):
        """Image blocks grouped with their captions/footnotes (via module helper)."""
        return fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)

    def get_table_blocks(self):
        """Table blocks grouped with their captions/footnotes (via module helper)."""
        return fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)

    def get_title_blocks(self):
        """Title blocks after the module's title fix-up pass."""
        return fix_title_blocks(self.title_blocks)

    def get_text_blocks(self):
        """Text blocks after the module's text fix-up pass."""
        return fix_text_blocks(self.text_blocks)

    def get_interline_equation_blocks(self):
        """Display-equation blocks, unmodified."""
        return self.interline_equation_blocks

    def get_all_spans(self):
        """Flat list of every span parsed from the page, in parse order."""
        return self.all_spans
204
+
205
+
206
def isolated_formula_clean(txt):
    """Strip display-math delimiters ``\\[`` / ``\\]`` from a formula and
    repair unbalanced ``\\left``/``\\right`` pairs via latex_fix()."""
    body = txt
    if body.startswith("\\["):
        body = body[2:]
    if body.endswith("\\]"):
        body = body[:-2]
    return latex_fix(body.strip())
212
+
213
+
214
def latex_fix(latex):
    """Best-effort repair of unbalanced ``\\left``/``\\right`` in a LaTeX string.

    Valid pairings are ``\\left\\{ ... \\right\\}``, ``\\left( ... \\right)``,
    ``\\left| ... \\right|``, ``\\left\\| ... \\right\\|`` and
    ``\\left[ ... \\right]``. If the ``\\left`` and ``\\right`` counts agree
    the string is returned untouched; otherwise every ``\\left``/``\\right``
    prefix is stripped down to its bare delimiter so the formula still renders.
    """
    # Negative lookahead keeps \lefteqn, \rightarrow etc. out of the counts.
    left_pat = re.compile(r'\\left(?![a-zA-Z])')
    right_pat = re.compile(r'\\right(?![a-zA-Z])')

    if len(left_pat.findall(latex)) == len(right_pat.findall(latex)):
        return latex

    # (pattern, replacement) pairs in the exact order of the original code:
    # the well-formed delimiter spellings first, then the malformed mixtures
    # such as \left{ ... \right} or \left\( ... \right\).
    substitutions = (
        (r'\\left\\\{', "{"),
        (r"\\left\|", "|"),
        (r"\\left\\\|", "|"),
        (r"\\left\(", "("),
        (r"\\left\[", "["),
        (r"\\right\\\}", "}"),
        (r"\\right\|", "|"),
        (r"\\right\\\|", "|"),
        (r"\\right\)", ")"),
        (r"\\right\]", "]"),
        (r"\\right\.", ""),
        (r'\\left\{', "{"),
        (r'\\right\}', "}"),
        (r'\\left\\\(', "("),
        (r'\\right\\\)', ")"),
        (r'\\left\\\[', "["),
        (r'\\right\\\]', "]"),
    )
    # Two passes, mirroring the original's `for _ in range(2)` loop.
    for _ in range(2):
        for pattern, replacement in substitutions:
            latex = re.sub(pattern, replacement, latex)

    return latex
252
+
253
+
254
def __reduct_overlap(bboxes):
    """Drop every box fully contained inside another box; order is preserved."""
    survivors = []
    for i, candidate in enumerate(bboxes):
        swallowed = any(
            is_in(candidate["bbox"], other["bbox"])
            for j, other in enumerate(bboxes)
            if j != i
        )
        if not swallowed:
            survivors.append(candidate)
    return survivors
264
+
265
+
266
def __tie_up_category_by_distance_v3(
    blocks: list,
    subject_block_type: str,
    object_block_type: str,
):
    """Greedily pair subject blocks (e.g. image/table bodies) with nearby
    object blocks (e.g. captions or footnotes) by spatial distance.

    Returns a list of dicts, one per subject, each with keys:
      - "sub_bbox": the subject's bbox/lines/index,
      - "obj_bboxes": list of matched objects (possibly empty),
      - "sub_idx": the subject's index into the deduplicated subject list.
    Every subject appears exactly once; each object is attached to at most
    one subject.
    """
    # Deduplicate nested blocks of each type before matching.
    subjects = __reduct_overlap(
        list(
            map(
                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
                filter(
                    lambda x: x["type"] == subject_block_type,
                    blocks,
                ),
            )
        )
    )
    objects = __reduct_overlap(
        list(
            map(
                lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
                filter(
                    lambda x: x["type"] == object_block_type,
                    blocks,
                ),
            )
        )
    )

    ret = []
    N, M = len(subjects), len(objects)
    # Sort both lists by squared distance of the top-left corner from the
    # page origin so scanning starts near the top-left of the page.
    subjects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
    objects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)

    # Object indices are shifted by this offset so subjects and objects can
    # share one candidate pool and one `seen_idx` set without colliding.
    OBJ_IDX_OFFSET = 10000
    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

    all_boxes_with_idx = [(i, SUB_BIT_KIND, sub["bbox"][0], sub["bbox"][1]) for i, sub in enumerate(subjects)] + [
        (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj["bbox"][0], obj["bbox"][1]) for i, obj in enumerate(objects)
    ]
    seen_idx = set()
    seen_sub_idx = set()

    # Phase 1: repeatedly take the unmatched box closest to the top-left
    # frontier and pair it with the nearest box of the opposite kind.
    while N > len(seen_sub_idx):
        candidates = []
        for idx, kind, x0, y0 in all_boxes_with_idx:
            if idx in seen_idx:
                continue
            candidates.append((idx, kind, x0, y0))

        if len(candidates) == 0:
            break
        left_x = min([v[2] for v in candidates])
        top_y = min([v[3] for v in candidates])

        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)

        # Re-sort around the chosen anchor, then find the closest candidate
        # of the opposite kind (XOR of the kind bits equals 1).
        fst_idx, fst_kind, left_x, top_y = candidates[0]
        candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
        nxt = None

        for i in range(1, len(candidates)):
            if candidates[i][1] ^ fst_kind == 1:
                nxt = candidates[i]
                break
        if nxt is None:
            break

        if fst_kind == SUB_BIT_KIND:
            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET

        else:
            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

        # Reject the pair when some other free subject is at least 3x closer
        # to this object; the subject is retired unmatched in that case.
        pair_dis = bbox_distance(subjects[sub_idx]["bbox"], objects[obj_idx]["bbox"])
        nearest_dis = float("inf")
        for i in range(N):
            if i in seen_idx or i == sub_idx:
                continue
            nearest_dis = min(nearest_dis, bbox_distance(subjects[i]["bbox"], objects[obj_idx]["bbox"]))

        if pair_dis >= 3 * nearest_dis:
            seen_idx.add(sub_idx)
            continue

        seen_idx.add(sub_idx)
        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
        seen_sub_idx.add(sub_idx)

        ret.append(
            {
                "sub_bbox": {
                    "bbox": subjects[sub_idx]["bbox"],
                    "lines": subjects[sub_idx]["lines"],
                    "index": subjects[sub_idx]["index"],
                },
                "obj_bboxes": [
                    {"bbox": objects[obj_idx]["bbox"], "lines": objects[obj_idx]["lines"], "index": objects[obj_idx]["index"]}
                ],
                "sub_idx": sub_idx,
            }
        )

    # Phase 2: attach every still-unmatched object to its nearest subject,
    # appending to an existing pairing when that subject already has one.
    for i in range(len(objects)):
        j = i + OBJ_IDX_OFFSET
        if j in seen_idx:
            continue
        seen_idx.add(j)
        nearest_dis, nearest_sub_idx = float("inf"), -1
        for k in range(len(subjects)):
            dis = bbox_distance(objects[i]["bbox"], subjects[k]["bbox"])
            if dis < nearest_dis:
                nearest_dis = dis
                nearest_sub_idx = k

        for k in range(len(subjects)):
            if k != nearest_sub_idx:
                continue
            if k in seen_sub_idx:
                for kk in range(len(ret)):
                    if ret[kk]["sub_idx"] == k:
                        ret[kk]["obj_bboxes"].append(
                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
                        )
                        break
            else:
                ret.append(
                    {
                        "sub_bbox": {
                            "bbox": subjects[k]["bbox"],
                            "lines": subjects[k]["lines"],
                            "index": subjects[k]["index"],
                        },
                        "obj_bboxes": [
                            {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
                        ],
                        "sub_idx": k,
                    }
                )
                seen_sub_idx.add(k)
                seen_idx.add(k)

    # Phase 3: emit every subject that never received an object, with an
    # empty obj_bboxes list, so callers see all subjects exactly once.
    for i in range(len(subjects)):
        if i in seen_sub_idx:
            continue
        ret.append(
            {
                "sub_bbox": {
                    "bbox": subjects[i]["bbox"],
                    "lines": subjects[i]["lines"],
                    "index": subjects[i]["index"],
                },
                "obj_bboxes": [],
                "sub_idx": i,
            }
        )

    return ret
423
+
424
+
425
def get_type_blocks(blocks, block_type: Literal["image", "table"]):
    """Group each ``*_body`` block with its matched caption and footnote blocks.

    Returns one record per body with keys ``{block_type}_body``,
    ``{block_type}_caption_list`` and ``{block_type}_footnote_list``.
    """
    body_type = f"{block_type}_body"
    caption_pairs = __tie_up_category_by_distance_v3(blocks, body_type, f"{block_type}_caption")
    footnote_pairs = __tie_up_category_by_distance_v3(blocks, body_type, f"{block_type}_footnote")

    grouped = []
    for cap in caption_pairs:
        # Both pairings cover the same subjects, so a matching sub_idx exists.
        foot = next(f for f in footnote_pairs if f["sub_idx"] == cap["sub_idx"])
        grouped.append(
            {
                body_type: cap["sub_bbox"],
                f"{block_type}_caption_list": cap["obj_bboxes"],
                f"{block_type}_footnote_list": foot["obj_bboxes"],
            }
        )
    return grouped
439
+
440
+
441
def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
    """Fold flat body/caption/footnote blocks into nested two-layer blocks.

    Each result has the outer type (``image``/``table``), the body's bbox and
    index, and a ``blocks`` list holding the typed body, captions and
    footnotes (the inner dicts are tagged in place).
    """
    fixed = []
    for group in get_type_blocks(blocks, fix_type):
        body = group[f"{fix_type}_body"]
        captions = group[f"{fix_type}_caption_list"]
        footnotes = group[f"{fix_type}_footnote_list"]

        # Tag every inner block with its concrete sub-type.
        body["type"] = f"{fix_type}_body"
        for item in captions:
            item["type"] = f"{fix_type}_caption"
        for item in footnotes:
            item["type"] = f"{fix_type}_footnote"

        fixed.append(
            {
                "type": fix_type,
                "bbox": body["bbox"],
                "blocks": [body, *captions, *footnotes],
                "index": body["index"],
            }
        )
    return fixed
468
+
469
+
470
def fix_title_blocks(blocks):
    """Derive heading levels for title blocks from leading '#' characters.

    Stores the count of leading hashes as ``block['level']`` and removes the
    hashes from the title text.  Mutates *blocks* in place and returns it.
    """
    for block in blocks:
        if block["type"] == BlockType.TITLE:
            title_content = merge_para_with_text(block)
            title_level = count_leading_hashes(title_content)
            block['level'] = title_level
            # The hash marks can only sit at the very start of the title, so
            # only the first span of the first line is stripped — the double
            # `break` below is intentional.
            for line in block['lines']:
                for span in line['spans']:
                    span['content'] = strip_leading_hashes(span['content'])
                    break
                break
    return blocks
482
+
483
+
484
def count_leading_hashes(text):
    """Return how many '#' characters *text* starts with (0 if none)."""
    count = 0
    for ch in text:
        if ch != '#':
            break
        count += 1
    return count
487
+
488
+
489
def strip_leading_hashes(text):
    """Remove leading '#' characters and the whitespace immediately after them."""
    stripped = text.lstrip('#')
    if stripped != text:
        # Only consume whitespace when hashes were actually removed,
        # mirroring the r'^#+\s*' behaviour.
        stripped = stripped.lstrip()
    return stripped
492
+
493
+
494
def fix_text_blocks(blocks):
    """Merge text blocks split across pages/columns via the '<|txt_contd|>' marker.

    A block whose last span ends with the marker absorbs the lines of the next
    non-emptied block; the donor keeps an empty ``lines`` list and is flagged
    with ``SplitFlag.LINES_DELETED``.  Mutates *blocks* in place and returns it.
    """
    marker = '<|txt_contd|>'
    idx = 0
    while idx < len(blocks):
        block = blocks[idx]
        lines = block["lines"]
        tail_line = lines[-1] if lines else None
        if tail_line:
            tail_span = tail_line["spans"][-1] if tail_line["spans"] else None
            if tail_span and tail_span['content'].endswith(marker):
                tail_span['content'] = tail_span['content'][:-len(marker)]

                # Find the next block whose lines were not already merged away.
                donor_idx = idx + 1
                while donor_idx < len(blocks) and blocks[donor_idx].get(SplitFlag.LINES_DELETED, False):
                    donor_idx += 1

                if donor_idx < len(blocks):
                    donor = blocks[donor_idx]
                    # Pull the donor's lines into the current block and mark
                    # the donor as emptied.
                    block["lines"].extend(donor["lines"])
                    donor["lines"] = []
                    donor[SplitFlag.LINES_DELETED] = True
                    # Re-examine the (now longer) current block: its new tail
                    # may carry another continuation marker.
                    continue
        idx += 1
    return blocks
vendor/mineru/mineru/backend/vlm/vlm_middle_json_mkcontent.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
4
+ from mineru.utils.enum_class import MakeMode, BlockType, ContentType
5
+
6
+
7
# Resolve the LaTeX delimiter configuration once at import time.
latex_delimiters_config = get_latex_delimiter_config()

# Fallback delimiters: "$$ ... $$" for display math, "$ ... $" for inline math.
default_delimiters = {
    'display': {'left': '$$', 'right': '$$'},
    'inline': {'left': '$', 'right': '$'}
}

# A user-provided configuration wins; otherwise fall back to the defaults.
delimiters = latex_delimiters_config if latex_delimiters_config else default_delimiters

# Module-level delimiter strings consumed by merge_para_with_text().
display_left_delimiter = delimiters['display']['left']
display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
20
+
21
def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
    """Flatten a paragraph block's spans into one markdown string.

    Text and inline equations are joined with a single space between spans
    (no trailing space after the last span of a line); interline equations
    are wrapped in display delimiters, or rendered as an image link when
    formula parsing is disabled and an image is available.
    """
    parts = []
    for line in para_block['lines']:
        spans = line['spans']
        last = len(spans) - 1
        for j, span in enumerate(spans):
            kind = span['type']
            if kind == ContentType.TEXT:
                text = span['content']
            elif kind == ContentType.INLINE_EQUATION:
                text = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
            elif kind == ContentType.INTERLINE_EQUATION:
                if formula_enable:
                    text = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
                elif span.get('image_path', ''):
                    # Formula parsing disabled: fall back to the cropped image.
                    text = f"![]({img_buket_path}/{span['image_path']})"
                else:
                    text = ''
            else:
                text = ''
            if not text:
                continue
            if kind == ContentType.INTERLINE_EQUATION:
                parts.append(text)
            else:
                parts.append(text if j == last else f'{text} ')
    return ''.join(parts)
47
+
48
def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
    """Render one page's para_blocks to a list of markdown fragments.

    In ``MakeMode.NLP_MD`` image and table blocks are skipped entirely;
    in ``MakeMode.MM_MD`` they are emitted as image links / HTML tables with
    captions and footnotes attached.  Empty fragments are dropped.
    """
    page_markdown = []
    for para_block in para_blocks:
        para_text = ''
        para_type = para_block['type']
        if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
            para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
        elif para_type == BlockType.TITLE:
            title_level = get_title_level(para_block)
            para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
        elif para_type == BlockType.IMAGE:
            if make_mode == MakeMode.NLP_MD:
                continue
            elif make_mode == MakeMode.MM_MD:
                # Check whether the image has any footnote blocks.
                has_image_footnote = any(block['type'] == BlockType.IMAGE_FOOTNOTE for block in para_block['blocks'])
                # If footnotes exist, emit caption -> body -> footnote;
                # otherwise emit body -> caption.
                if has_image_footnote:
                    for block in para_block['blocks']:  # 1st: append image_caption
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += merge_para_with_text(block) + ' \n'
                    for block in para_block['blocks']:  # 2nd: append image_body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 3rd: append image_footnote
                        if block['type'] == BlockType.IMAGE_FOOTNOTE:
                            para_text += ' \n' + merge_para_with_text(block)
                else:
                    for block in para_block['blocks']:  # 1st: append image_body
                        if block['type'] == BlockType.IMAGE_BODY:
                            for line in block['lines']:
                                for span in line['spans']:
                                    if span['type'] == ContentType.IMAGE:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                    for block in para_block['blocks']:  # 2nd: append image_caption
                        if block['type'] == BlockType.IMAGE_CAPTION:
                            para_text += ' \n' + merge_para_with_text(block)

        elif para_type == BlockType.TABLE:
            if make_mode == MakeMode.NLP_MD:
                continue
            elif make_mode == MakeMode.MM_MD:
                for block in para_block['blocks']:  # 1st: append table_caption
                    if block['type'] == BlockType.TABLE_CAPTION:
                        para_text += merge_para_with_text(block) + ' \n'
                for block in para_block['blocks']:  # 2nd: append table_body
                    if block['type'] == BlockType.TABLE_BODY:
                        for line in block['lines']:
                            for span in line['spans']:
                                if span['type'] == ContentType.TABLE:
                                    # if processed by table model
                                    if table_enable:
                                        # Prefer recognized HTML; fall back to the image crop.
                                        if span.get('html', ''):
                                            para_text += f"\n{span['html']}\n"
                                        elif span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                                    else:
                                        if span.get('image_path', ''):
                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                for block in para_block['blocks']:  # 3rd: append table_footnote
                    if block['type'] == BlockType.TABLE_FOOTNOTE:
                        para_text += '\n' + merge_para_with_text(block) + ' '

        if para_text.strip() == '':
            continue
        else:
            page_markdown.append(para_text.strip())

    return page_markdown
123
+
124
+
125
+
126
+
127
+
128
def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
    """Convert one para_block into a content-list record (a plain dict).

    Text-like blocks become ``{'type': 'text', 'text': ...}`` (titles gain a
    ``text_level``), equations carry ``text_format='latex'``, and image/table
    blocks collect their image path, captions and footnotes.  Every record is
    stamped with ``page_idx``.
    """
    para_type = para_block['type']
    para_content = {}
    if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
    elif para_type == BlockType.TITLE:
        title_level = get_title_level(para_block)
        para_content = {
            'type': ContentType.TEXT,
            'text': merge_para_with_text(para_block),
        }
        # Level 0 means "no heading level": omit the key entirely.
        if title_level != 0:
            para_content['text_level'] = title_level
    elif para_type == BlockType.INTERLINE_EQUATION:
        para_content = {
            'type': ContentType.EQUATION,
            'text': merge_para_with_text(para_block),
            'text_format': 'latex',
        }
    elif para_type == BlockType.IMAGE:
        para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.IMAGE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.IMAGE:
                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
            if block['type'] == BlockType.IMAGE_CAPTION:
                para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
            if block['type'] == BlockType.IMAGE_FOOTNOTE:
                para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
    elif para_type == BlockType.TABLE:
        para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
        for block in para_block['blocks']:
            if block['type'] == BlockType.TABLE_BODY:
                for line in block['lines']:
                    for span in line['spans']:
                        if span['type'] == ContentType.TABLE:
                            # Recognized HTML (if any) is stored under the
                            # table-body key; the crop path is kept either way.
                            if span.get('html', ''):
                                para_content[BlockType.TABLE_BODY] = f"{span['html']}"

                            if span.get('image_path', ''):
                                para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"

            if block['type'] == BlockType.TABLE_CAPTION:
                para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
            if block['type'] == BlockType.TABLE_FOOTNOTE:
                para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))

    para_content['page_idx'] = page_idx

    return para_content
185
+
186
def union_make(pdf_info_dict: list,
               make_mode: str,
               img_buket_path: str = '',
               ):
    """Assemble per-page ``para_blocks`` into markdown text or a content list.

    Returns a single markdown string for the MD modes, a list of content
    records for ``MakeMode.CONTENT_LIST``, and ``None`` for unknown modes.
    Formula/table rendering can be toggled via the MINERU_VLM_* env vars.
    """
    formula_enable = get_formula_enable(os.getenv('MINERU_VLM_FORMULA_ENABLE', 'True').lower() == 'true')
    table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')

    md_modes = (MakeMode.MM_MD, MakeMode.NLP_MD)
    output_content = []
    for page_info in pdf_info_dict:
        para_blocks = page_info.get('para_blocks')
        page_idx = page_info.get('page_idx')
        if not para_blocks:
            # Pages without layout results contribute nothing.
            continue
        if make_mode in md_modes:
            output_content.extend(
                mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path)
            )
        elif make_mode == MakeMode.CONTENT_LIST:
            output_content.extend(
                make_blocks_to_content_list(para_block, img_buket_path, page_idx)
                for para_block in para_blocks
            )

    if make_mode in md_modes:
        return '\n\n'.join(output_content)
    if make_mode == MakeMode.CONTENT_LIST:
        return output_content
    return None
213
+
214
+
215
def get_title_level(block):
    """Return the block's heading level clamped to [1, 4].

    Missing levels default to 1; levels below 1 collapse to 0, which callers
    treat as "not a heading".
    """
    level = block.get('level', 1)
    return 0 if level < 1 else min(level, 4)
vendor/mineru/mineru/cli/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/cli/client.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import os
3
+ import click
4
+ from pathlib import Path
5
+ from loguru import logger
6
+
7
+ from mineru.utils.cli_parser import arg_parse
8
+ from mineru.utils.config_reader import get_device
9
+ from mineru.utils.model_utils import get_vram
10
+ from ..version import __version__
11
+ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
12
+
13
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.version_option(__version__,
                      '--version',
                      '-v',
                      help='display the version and exit')
@click.option(
    '-p',
    '--path',
    'input_path',
    type=click.Path(exists=True),
    required=True,
    help='local filepath or directory. support pdf, png, jpg, jpeg files',
)
@click.option(
    '-o',
    '--output',
    'output_dir',
    type=click.Path(),
    required=True,
    help='output local directory',
)
@click.option(
    '-m',
    '--method',
    'method',
    type=click.Choice(['auto', 'txt', 'ocr']),
    help="""the method for parsing pdf:
    auto: Automatically determine the method based on the file type.
    txt: Use text extraction method.
    ocr: Use OCR method for image-based PDFs.
    Without method specified, 'auto' will be used by default.
    Adapted only for the case where the backend is set to "pipeline".""",
    default='auto',
)
@click.option(
    '-b',
    '--backend',
    'backend',
    type=click.Choice(['pipeline', 'vlm-transformers', 'vlm-sglang-engine', 'vlm-sglang-client']),
    help="""the backend for parsing pdf:
    pipeline: More general.
    vlm-transformers: More general.
    vlm-sglang-engine: Faster(engine).
    vlm-sglang-client: Faster(client).
    without method specified, pipeline will be used by default.""",
    default='pipeline',
)
@click.option(
    '-l',
    '--lang',
    'lang',
    type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka',
                       'latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']),
    help="""
    Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
    Without languages specified, 'ch' will be used by default.
    Adapted only for the case where the backend is set to "pipeline".
    """,
    default='ch',
)
@click.option(
    '-u',
    '--url',
    'server_url',
    type=str,
    help="""
    When the backend is `sglang-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`
    """,
    default=None,
)
@click.option(
    '-s',
    '--start',
    'start_page_id',
    type=int,
    help='The starting page for PDF parsing, beginning from 0.',
    default=0,
)
@click.option(
    '-e',
    '--end',
    'end_page_id',
    type=int,
    help='The ending page for PDF parsing, beginning from 0.',
    default=None,
)
@click.option(
    '-f',
    '--formula',
    'formula_enable',
    type=bool,
    help='Enable formula parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
    default=True,
)
@click.option(
    '-t',
    '--table',
    'table_enable',
    type=bool,
    help='Enable table parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
    default=True,
)
@click.option(
    '-d',
    '--device',
    'device_mode',
    type=str,
    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Adapted only for the case where the backend is set to "pipeline". ',
    default=None,
)
@click.option(
    '--vram',
    'virtual_vram',
    type=int,
    help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
    default=None,
)
@click.option(
    '--source',
    'model_source',
    type=click.Choice(['huggingface', 'modelscope', 'local']),
    help="""
    The source of the model repository. Default is 'huggingface'.
    """,
    default='huggingface',
)


def main(
        ctx,
        input_path, output_dir, method, backend, lang, server_url,
        start_page_id, end_page_id, formula_enable, table_enable,
        device_mode, virtual_vram, model_source, **kwargs
):
    """CLI entry point: parse one file or every supported file in a directory
    and write the extraction results under *output_dir*.

    Extra unknown CLI options are collected via the click context and passed
    through to ``do_parse`` as keyword arguments.
    """

    kwargs.update(arg_parse(ctx))

    # Device/VRAM setup is only needed when inference runs locally
    # (i.e. for any backend other than *-client).
    if not backend.endswith('-client'):
        def get_device_mode() -> str:
            # Explicit -d/--device wins; otherwise auto-detect.
            if device_mode is not None:
                return device_mode
            else:
                return get_device()
        if os.getenv('MINERU_DEVICE_MODE', None) is None:
            os.environ['MINERU_DEVICE_MODE'] = get_device_mode()

        def get_virtual_vram_size() -> int:
            # Explicit --vram wins; GPU/NPU devices are probed, others get 1 GB.
            if virtual_vram is not None:
                return virtual_vram
            if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
                return round(get_vram(get_device_mode()))
            return 1
        if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
            os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())

    # Environment variables take precedence over the CLI model source.
    if os.getenv('MINERU_MODEL_SOURCE', None) is None:
        os.environ['MINERU_MODEL_SOURCE'] = model_source

    os.makedirs(output_dir, exist_ok=True)

    def parse_doc(path_list: list[Path]):
        # Read every input into memory, then hand the whole batch to do_parse;
        # failures are logged rather than aborting the CLI.
        try:
            file_name_list = []
            pdf_bytes_list = []
            lang_list = []
            for path in path_list:
                file_name = str(Path(path).stem)
                pdf_bytes = read_fn(path)
                file_name_list.append(file_name)
                pdf_bytes_list.append(pdf_bytes)
                lang_list.append(lang)
            do_parse(
                output_dir=output_dir,
                pdf_file_names=file_name_list,
                pdf_bytes_list=pdf_bytes_list,
                p_lang_list=lang_list,
                backend=backend,
                parse_method=method,
                formula_enable=formula_enable,
                table_enable=table_enable,
                server_url=server_url,
                start_page_id=start_page_id,
                end_page_id=end_page_id,
                **kwargs,
            )
        except Exception as e:
            logger.exception(e)

    if os.path.isdir(input_path):
        # Directory input: batch every file with a supported suffix.
        doc_path_list = []
        for doc_path in Path(input_path).glob('*'):
            if doc_path.suffix in pdf_suffixes + image_suffixes:
                doc_path_list.append(doc_path)
        parse_doc(doc_path_list)
    else:
        parse_doc([Path(input_path)])

if __name__ == '__main__':
    main()
vendor/mineru/mineru/cli/common.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+ import io
3
+ import json
4
+ import os
5
+ import copy
6
+ from pathlib import Path
7
+
8
+ import pypdfium2 as pdfium
9
+ from loguru import logger
10
+
11
+ from mineru.data.data_reader_writer import FileBasedDataWriter
12
+ from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
13
+ from mineru.utils.enum_class import MakeMode
14
+ from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
15
+ from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
16
+ from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
17
+ from mineru.backend.vlm.vlm_analyze import aio_doc_analyze as aio_vlm_doc_analyze
18
+
19
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg", ".webp", ".gif"]


def read_fn(path):
    """Read a local PDF or image file and return its content as PDF bytes.

    Image files are wrapped into a single-page PDF; PDF files are returned
    verbatim; any other suffix raises ``Exception``.
    """
    if not isinstance(path, Path):
        path = Path(path)
    file_bytes = path.read_bytes()
    suffix = path.suffix
    if suffix in image_suffixes:
        return images_bytes_to_pdf_bytes(file_bytes)
    if suffix in pdf_suffixes:
        return file_bytes
    raise Exception(f"Unknown file suffix: {path.suffix}")
34
+
35
+
36
def prepare_env(output_dir, pdf_file_name, parse_method):
    """Create (if missing) and return the output directories for one document.

    Layout: ``output_dir/<pdf_file_name>/<parse_method>/`` for markdown and
    JSON artifacts, with an ``images/`` subdirectory for extracted crops.
    Returns ``(image_dir, md_dir)``.
    """
    md_dir = os.path.join(output_dir, pdf_file_name, parse_method)
    image_dir = os.path.join(md_dir, "images")
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(md_dir, exist_ok=True)
    return image_dir, md_dir
42
+
43
+
44
def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
    """Return new PDF bytes containing pages [start_page_id, end_page_id].

    ``end_page_id`` of ``None`` or a negative value means "last page"; an
    out-of-range end is clamped to the last page with a warning.
    """
    # Load the source document from memory.
    src_pdf = pdfium.PdfDocument(pdf_bytes)

    last_page = len(src_pdf) - 1
    if end_page_id is None or end_page_id < 0:
        end_page_id = last_page
    elif end_page_id > last_page:
        logger.warning("end_page_id is out of range, use pdf_docs length")
        end_page_id = last_page

    # Copy the requested page range into a fresh document.
    dst_pdf = pdfium.PdfDocument.new()
    dst_pdf.import_pages(src_pdf, list(range(start_page_id, end_page_id + 1)))

    # Serialize the new document to an in-memory buffer.
    buffer = io.BytesIO()
    dst_pdf.save(buffer)
    result = buffer.getvalue()

    # Release native resources held by both documents.
    src_pdf.close()
    dst_pdf.close()

    return result
75
+
76
+
77
def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
    """Trim every PDF in *pdf_bytes_list* to the requested page range."""
    return [
        convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
        for pdf_bytes in pdf_bytes_list
    ]
84
+
85
+
86
def _process_output(
        pdf_info,
        pdf_bytes,
        pdf_file_name,
        local_md_dir,
        local_image_dir,
        md_writer,
        f_draw_layout_bbox,
        f_draw_span_bbox,
        f_dump_orig_pdf,
        f_dump_md,
        f_dump_content_list,
        f_dump_middle_json,
        f_dump_model_output,
        f_make_md_mode,
        middle_json,
        model_output=None,
        is_pipeline=True
):
    """Write the requested output artifacts (debug PDFs, markdown, JSON dumps)
    for one parsed document; each ``f_*`` flag toggles one artifact.

    ``is_pipeline`` selects which union_make implementation renders markdown /
    content lists and how ``model_output`` is serialized (JSON for the
    pipeline backend, joined raw text for VLM backends).
    """
    # Imported lazily to avoid loading pipeline modules for VLM-only runs.
    from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
    """处理输出文件"""
    if f_draw_layout_bbox:
        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")

    if f_draw_span_bbox:
        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")

    if f_dump_orig_pdf:
        md_writer.write(
            f"{pdf_file_name}_origin.pdf",
            pdf_bytes,
        )

    # Markdown image links are relative to the md directory.
    image_dir = str(os.path.basename(local_image_dir))

    if f_dump_md:
        make_func = pipeline_union_make if is_pipeline else vlm_union_make
        md_content_str = make_func(pdf_info, f_make_md_mode, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}.md",
            md_content_str,
        )

    if f_dump_content_list:
        make_func = pipeline_union_make if is_pipeline else vlm_union_make
        content_list = make_func(pdf_info, MakeMode.CONTENT_LIST, image_dir)
        md_writer.write_string(
            f"{pdf_file_name}_content_list.json",
            json.dumps(content_list, ensure_ascii=False, indent=4),
        )

    if f_dump_middle_json:
        md_writer.write_string(
            f"{pdf_file_name}_middle.json",
            json.dumps(middle_json, ensure_ascii=False, indent=4),
        )

    if f_dump_model_output:
        if is_pipeline:
            md_writer.write_string(
                f"{pdf_file_name}_model.json",
                json.dumps(model_output, ensure_ascii=False, indent=4),
            )
        else:
            # VLM output is a list of raw strings; join them with a separator.
            output_text = ("\n" + "-" * 50 + "\n").join(model_output)
            md_writer.write_string(
                f"{pdf_file_name}_model_output.txt",
                output_text,
            )

    logger.info(f"local output dir is {local_md_dir}")
157
+
158
+
159
def _process_pipeline(
        output_dir,
        pdf_file_names,
        pdf_bytes_list,
        p_lang_list,
        parse_method,
        p_formula_enable,
        p_table_enable,
        f_draw_layout_bbox,
        f_draw_span_bbox,
        f_dump_md,
        f_dump_middle_json,
        f_dump_model_output,
        f_dump_orig_pdf,
        f_dump_content_list,
        f_make_md_mode,
):
    """Run the pipeline backend over a batch of PDFs and write all outputs.

    Documents are analyzed together by ``pipeline_doc_analyze``; each result
    is then converted to middle JSON and dumped via ``_process_output``.
    """
    # Imported lazily so the pipeline's model stack is only loaded on demand.
    from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
    from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze

    infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
        pipeline_doc_analyze(
            pdf_bytes_list, p_lang_list, parse_method=parse_method,
            formula_enable=p_formula_enable, table_enable=p_table_enable
        )
    )

    for idx, model_list in enumerate(infer_results):
        # Deep-copy the raw model output: middle-JSON conversion mutates it.
        model_json = copy.deepcopy(model_list)
        pdf_file_name = pdf_file_names[idx]
        local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

        images_list = all_image_lists[idx]
        pdf_doc = all_pdf_docs[idx]
        _lang = lang_list[idx]
        _ocr_enable = ocr_enabled_list[idx]

        middle_json = pipeline_result_to_middle_json(
            model_list, images_list, pdf_doc, image_writer,
            _lang, _ocr_enable, p_formula_enable
        )

        pdf_info = middle_json["pdf_info"]
        pdf_bytes = pdf_bytes_list[idx]

        _process_output(
            pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
            md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
            f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
            f_make_md_mode, middle_json, model_json, is_pipeline=True
        )
212
+
213
+
214
async def _async_process_vlm(
    output_dir,
    pdf_file_names,
    pdf_bytes_list,
    backend,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_md,
    f_dump_middle_json,
    f_dump_model_output,
    f_dump_orig_pdf,
    f_dump_content_list,
    f_make_md_mode,
    server_url=None,
    **kwargs,
):
    """Asynchronously process documents with a VLM backend.

    Async counterpart of ``_process_vlm``: analyzes each PDF with
    ``aio_vlm_doc_analyze`` and dumps the requested artifacts via
    ``_process_output``.
    """
    parse_method = "vlm"
    # Span-bbox drawing is unconditionally disabled for VLM output.
    f_draw_span_bbox = False
    # Only the "*-client" backends talk to a remote server; otherwise the
    # URL is irrelevant and is dropped.
    if not backend.endswith("client"):
        server_url = None

    for idx, pdf_bytes in enumerate(pdf_bytes_list):
        pdf_file_name = pdf_file_names[idx]
        local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

        middle_json, infer_result = await aio_vlm_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
        )

        pdf_info = middle_json["pdf_info"]

        _process_output(
            pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
            md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
            f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
            f_make_md_mode, middle_json, infer_result, is_pipeline=False
        )
253
+
254
+
255
def _process_vlm(
    output_dir,
    pdf_file_names,
    pdf_bytes_list,
    backend,
    f_draw_layout_bbox,
    f_draw_span_bbox,
    f_dump_md,
    f_dump_middle_json,
    f_dump_model_output,
    f_dump_orig_pdf,
    f_dump_content_list,
    f_make_md_mode,
    server_url=None,
    **kwargs,
):
    """Synchronously process documents with a VLM backend.

    Mirrors ``_async_process_vlm`` exactly, but uses the blocking
    ``vlm_doc_analyze`` so it can run outside an event loop.
    """
    parse_method = "vlm"
    # Span-bbox drawing is unconditionally disabled for VLM output.
    f_draw_span_bbox = False
    # Only the "*-client" backends talk to a remote server; otherwise the
    # URL is irrelevant and is dropped.
    if not backend.endswith("client"):
        server_url = None

    for idx, pdf_bytes in enumerate(pdf_bytes_list):
        pdf_file_name = pdf_file_names[idx]
        local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
        image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

        middle_json, infer_result = vlm_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
        )

        pdf_info = middle_json["pdf_info"]

        _process_output(
            pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
            md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
            f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
            f_make_md_mode, middle_json, infer_result, is_pipeline=False
        )
294
+
295
+
296
def do_parse(
    output_dir,
    pdf_file_names: list[str],
    pdf_bytes_list: list[bytes],
    p_lang_list: list[str],
    backend="pipeline",
    parse_method="auto",
    formula_enable=True,
    table_enable=True,
    server_url=None,
    f_draw_layout_bbox=True,
    f_draw_span_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_output=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    start_page_id=0,
    end_page_id=None,
    **kwargs,
):
    """Synchronous entry point: parse a batch of PDFs and dump the artifacts.

    Args:
        output_dir: root directory that receives one sub-directory per PDF.
        pdf_file_names: output name (stem) for each document.
        pdf_bytes_list: raw bytes of each document, parallel to the names.
        p_lang_list: language hint per document (used by the pipeline backend).
        backend: "pipeline" or a VLM backend name, optionally prefixed "vlm-".
        parse_method: page-parsing strategy forwarded to the pipeline backend.
        server_url: remote server URL (only meaningful for "*-client" backends).
        f_*: flags selecting which artifacts to draw/dump.
        start_page_id / end_page_id: page window applied before analysis.
    """
    # Trim every document to the requested page range up front.
    pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)

    if backend == "pipeline":
        _process_pipeline(
            output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
            parse_method, formula_enable, table_enable,
            f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
            f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
        )
    else:
        # Accept both "vlm-xxx" and bare "xxx" backend names.
        if backend.startswith("vlm-"):
            backend = backend[4:]

        # VLM feature toggles are communicated via environment variables.
        os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
        os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)

        _process_vlm(
            output_dir, pdf_file_names, pdf_bytes_list, backend,
            f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
            f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
            server_url, **kwargs,
        )
341
+
342
+
343
async def aio_do_parse(
    output_dir,
    pdf_file_names: list[str],
    pdf_bytes_list: list[bytes],
    p_lang_list: list[str],
    backend="pipeline",
    parse_method="auto",
    formula_enable=True,
    table_enable=True,
    server_url=None,
    f_draw_layout_bbox=True,
    f_draw_span_bbox=True,
    f_dump_md=True,
    f_dump_middle_json=True,
    f_dump_model_output=True,
    f_dump_orig_pdf=True,
    f_dump_content_list=True,
    f_make_md_mode=MakeMode.MM_MD,
    start_page_id=0,
    end_page_id=None,
    **kwargs,
):
    """Async entry point mirroring :func:`do_parse` (same parameters).

    VLM backends run truly asynchronously; the pipeline backend has no async
    implementation yet and runs synchronously (blocking the event loop).
    """
    # Trim every document to the requested page range up front.
    pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)

    if backend == "pipeline":
        # Pipeline mode does not support async yet; fall back to the
        # synchronous path (this blocks the running event loop).
        _process_pipeline(
            output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
            parse_method, formula_enable, table_enable,
            f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
            f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
        )
    else:
        # Accept both "vlm-xxx" and bare "xxx" backend names.
        if backend.startswith("vlm-"):
            backend = backend[4:]

        # VLM feature toggles are communicated via environment variables.
        os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
        os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)

        await _async_process_vlm(
            output_dir, pdf_file_names, pdf_bytes_list, backend,
            f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
            f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
            server_url, **kwargs,
        )
389
+
390
+
391
+
392
if __name__ == "__main__":
    # Ad-hoc manual test entry point for local development.
    # pdf_path = "../../demo/pdfs/demo3.pdf"
    # NOTE(review): hard-coded developer-local Windows path — adjust before running.
    pdf_path = "C:/Users/zhaoxiaomeng/Downloads/4546d0e2-ba60-40a5-a17e-b68555cec741.pdf"

    try:
        do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"],
            end_page_id=10,
            backend='vlm-huggingface'
            # backend = 'pipeline'
        )
    except Exception as e:
        logger.exception(e)
vendor/mineru/mineru/cli/fast_api.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import os
3
+ import uvicorn
4
+ import click
5
+ from pathlib import Path
6
+ from glob import glob
7
+ from fastapi import FastAPI, UploadFile, File, Form
8
+ from fastapi.middleware.gzip import GZipMiddleware
9
+ from fastapi.responses import JSONResponse
10
+ from typing import List, Optional
11
+ from loguru import logger
12
+ from base64 import b64encode
13
+
14
+ from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
15
+ from mineru.utils.cli_parser import arg_parse
16
+ from mineru.version import __version__
17
+
18
+ app = FastAPI()
19
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
20
+
21
def encode_image(image_path: str) -> str:
    """Return the contents of the file at *image_path*, base64-encoded."""
    raw = Path(image_path).read_bytes()
    return b64encode(raw).decode()
25
+
26
+
27
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
    """Read one inference artifact (``<pdf_name><suffix>``) from *parse_dir*.

    Returns the artifact's text, or ``None`` when that output file was not
    produced for this run.
    """
    candidate = Path(parse_dir) / f"{pdf_name}{file_suffix_identifier}"
    if not candidate.exists():
        return None
    return candidate.read_text(encoding="utf-8")
34
+
35
+
36
@app.post(path="/file_parse",)
async def parse_pdf(
    files: List[UploadFile] = File(...),
    output_dir: str = Form("./output"),
    lang_list: List[str] = Form(["ch"]),
    backend: str = Form("pipeline"),
    parse_method: str = Form("auto"),
    formula_enable: bool = Form(True),
    table_enable: bool = Form(True),
    server_url: Optional[str] = Form(None),
    return_md: bool = Form(True),
    return_middle_json: bool = Form(False),
    return_model_output: bool = Form(False),
    return_content_list: bool = Form(False),
    return_images: bool = Form(False),
    start_page_id: int = Form(0),
    end_page_id: int = Form(99999),
):
    """Parse uploaded PDFs/images and return the requested artifacts as JSON.

    Each upload is converted with ``aio_do_parse`` inside a per-request
    unique directory under *output_dir*. The ``return_*`` flags choose which
    outputs are read back into the response; images are inlined as base64
    data URIs. Returns 400 on unsupported/unreadable uploads, 500 on
    processing failure.
    """

    # Extra configuration captured from the CLI at server start-up (see main()).
    config = getattr(app.state, "config", {})

    try:
        # Per-request unique output directory so concurrent calls don't collide.
        unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
        os.makedirs(unique_dir, exist_ok=True)

        # Collect the uploaded documents.
        pdf_file_names = []
        pdf_bytes_list = []

        for file in files:
            content = await file.read()
            file_path = Path(file.filename)

            # PDFs and images are both normalized to PDF bytes via read_fn.
            if file_path.suffix.lower() in pdf_suffixes + image_suffixes:
                # read_fn expects a path, so spill the upload to a temp file.
                temp_path = Path(unique_dir) / file_path.name
                with open(temp_path, "wb") as f:
                    f.write(content)

                try:
                    pdf_bytes = read_fn(temp_path)
                    pdf_bytes_list.append(pdf_bytes)
                    pdf_file_names.append(file_path.stem)
                    os.remove(temp_path)  # the temp copy is no longer needed
                except Exception as e:
                    return JSONResponse(
                        status_code=400,
                        content={"error": f"Failed to load file: {str(e)}"}
                    )
            else:
                return JSONResponse(
                    status_code=400,
                    content={"error": f"Unsupported file type: {file_path.suffix}"}
                )


        # Keep the language list the same length as the file list.
        actual_lang_list = lang_list
        if len(actual_lang_list) != len(pdf_file_names):
            # On mismatch, repeat the first language (or default "ch") for every file.
            actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)

        # Run the actual conversion (async entry point).
        await aio_do_parse(
            output_dir=unique_dir,
            pdf_file_names=pdf_file_names,
            pdf_bytes_list=pdf_bytes_list,
            p_lang_list=actual_lang_list,
            backend=backend,
            parse_method=parse_method,
            formula_enable=formula_enable,
            table_enable=table_enable,
            server_url=server_url,
            f_draw_layout_bbox=False,
            f_draw_span_bbox=False,
            f_dump_md=return_md,
            f_dump_middle_json=return_middle_json,
            f_dump_model_output=return_model_output,
            f_dump_orig_pdf=False,
            f_dump_content_list=return_content_list,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            **config
        )

        # Gather the requested artifacts for each document.
        result_dict = {}
        for pdf_name in pdf_file_names:
            result_dict[pdf_name] = {}
            data = result_dict[pdf_name]

            # Output layout differs by backend: <dir>/<name>/<parse_method> vs <dir>/<name>/vlm.
            if backend.startswith("pipeline"):
                parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
            else:
                parse_dir = os.path.join(unique_dir, pdf_name, "vlm")

            if os.path.exists(parse_dir):
                if return_md:
                    data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
                if return_middle_json:
                    data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
                if return_model_output:
                    # Pipeline dumps JSON; VLM backends dump plain text.
                    if backend.startswith("pipeline"):
                        data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
                    else:
                        data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
                if return_content_list:
                    data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
                if return_images:
                    # Inline every extracted image as a base64 data URI keyed by file name.
                    image_paths = glob(f"{parse_dir}/images/*.jpg")
                    data["images"] = {
                        os.path.basename(
                            image_path
                        ): f"data:image/jpeg;base64,{encode_image(image_path)}"
                        for image_path in image_paths
                    }
        return JSONResponse(
            status_code=200,
            content={
                "backend": backend,
                "version": __version__,
                "results": result_dict
            }
        )
    except Exception as e:
        logger.exception(e)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to process file: {str(e)}"}
        )
169
+
170
+
171
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
    """Command-line entry point that starts the MinerU FastAPI server.

    Extra, unrecognized CLI options are collected via ``arg_parse`` and
    stored on ``app.state.config`` so the request handler can forward them
    to the parsing backend.
    """
    # Fix: this docstring was previously placed *after* the first statements,
    # making it a no-op string expression (and leaving the click command
    # without help text). It now sits at the top of the function.
    kwargs.update(arg_parse(ctx))

    # Stash the configuration on the app so /file_parse can pick it up.
    app.state.config = kwargs

    print(f"Start MinerU FastAPI Service: http://{host}:{port}")
    print("The API documentation can be accessed at the following address:")
    print(f"- Swagger UI: http://{host}:{port}/docs")
    print(f"- ReDoc: http://{host}:{port}/redoc")

    # Run the ASGI app; the import string (not the object) is required for
    # --reload to work.
    uvicorn.run(
        "mineru.cli.fast_api:app",
        host=host,
        port=port,
        reload=reload
    )
195
+
196
+
197
+ if __name__ == "__main__":
198
+ main()
vendor/mineru/mineru/cli/gradio_app.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+
3
+ import base64
4
+ import os
5
+ import re
6
+ import time
7
+ import zipfile
8
+ from pathlib import Path
9
+
10
+ import click
11
+ import gradio as gr
12
+ from gradio_pdf import PDF
13
+ from loguru import logger
14
+
15
+ from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
16
+ from mineru.utils.cli_parser import arg_parse
17
+ from mineru.utils.hash_utils import str_sha256
18
+
19
+
20
async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
    """Parse one document via ``aio_do_parse`` and report where output landed.

    Returns:
        ``(local_md_dir, file_name)`` on success; ``None`` on failure (the
        exception is logged).
        NOTE(review): callers that tuple-unpack the result will raise
        ``TypeError`` on the ``None`` failure path — confirm this is intended.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Timestamped, sanitized name keeps repeated runs from colliding.
        file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
        pdf_data = read_fn(doc_path)
        if is_ocr:
            parse_method = 'ocr'
        else:
            parse_method = 'auto'

        # VLM backends use their own fixed parse method.
        if backend.startswith("vlm"):
            parse_method = "vlm"

        local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
        await aio_do_parse(
            output_dir=output_dir,
            pdf_file_names=[file_name],
            pdf_bytes_list=[pdf_data],
            p_lang_list=[language],
            parse_method=parse_method,
            end_page_id=end_page_id,
            formula_enable=formula_enable,
            table_enable=table_enable,
            backend=backend,
            server_url=url,
        )
        return local_md_dir, file_name
    except Exception as e:
        logger.exception(e)
        return None
51
+
52
+
53
def compress_directory_to_zip(directory_path, output_zip_path):
    """Zip the whole *directory_path* tree into *output_zip_path*.

    Entry names inside the archive are relative to *directory_path*.

    :param directory_path: directory to compress
    :param output_zip_path: path of the ZIP file to create
    :returns: 0 on success, -1 on any failure (the error is logged)
    """
    try:
        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            # Walk every file under the directory and store it with a
            # directory-relative archive name.
            for root, _dirs, filenames in os.walk(directory_path):
                for name in filenames:
                    absolute_path = os.path.join(root, name)
                    archive.write(absolute_path, os.path.relpath(absolute_path, directory_path))
        return 0
    except Exception as e:
        logger.exception(e)
        return -1
75
+
76
+
77
def image_to_base64(image_path):
    """Return the file at *image_path* base64-encoded as a UTF-8 string."""
    data = Path(image_path).read_bytes()
    return base64.b64encode(data).decode('utf-8')
80
+
81
+
82
def replace_image_with_base64(markdown_text, image_dir_path):
    """Inline every Markdown image reference as a base64 data URI.

    Each ``![alt](relative/path)`` occurrence is rewritten to
    ``![relative/path](data:image/jpeg;base64,...)`` using the file found
    under *image_dir_path*, so the rendered Markdown needs no file access.
    """
    # Matches a Markdown image tag and captures its (relative) target path.
    pattern = r'\!\[(?:[^\]]*)\]\(([^)]+)\)'

    def inline(match):
        relative = match.group(1)
        encoded = image_to_base64(os.path.join(image_dir_path, relative))
        return f'![{relative}](data:image/jpeg;base64,{encoded})'

    return re.sub(pattern, inline, markdown_text)
95
+
96
+
97
async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
    """Convert an uploaded file to Markdown for the Gradio UI.

    Returns ``(rendered_md, raw_md_text, zip_archive_path, layout_pdf_path)``.
    NOTE(review): if ``parse_pdf`` fails it returns ``None``, and the tuple
    unpack below raises ``TypeError`` — confirm desired error handling.
    """
    file_path = to_pdf(file_path)
    # Parse the document and locate the generated markdown directory.
    local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
    archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
    zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
    if zip_archive_success == 0:
        logger.info('Compression successful')
    else:
        logger.error('Compression failed')
    md_path = os.path.join(local_md_dir, file_name + '.md')
    with open(md_path, 'r', encoding='utf-8') as f:
        txt_content = f.read()
    # Inline images so the Markdown preview renders without file access.
    md_content = replace_image_with_base64(txt_content, local_md_dir)
    # Path of the layout-annotated PDF shown in the preview pane.
    new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')

    return md_content, txt_content, archive_zip_path, new_pdf_path
115
+
116
+
117
# Delimiter pairs that Gradio's Markdown component treats as LaTeX math.
latex_delimiters = [
    {'left': '$$', 'right': '$$', 'display': True},
    {'left': '$', 'right': '$', 'display': False},
    {'left': '\\(', 'right': '\\)', 'display': False},
    {'left': '\\[', 'right': '\\]', 'display': True},
]

# Static HTML banner injected at the top of the UI (read once at import).
header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
with open(header_path, 'r') as header_file:
    header = header_file.read()


# OCR language groups (presumably PaddleOCR-style codes — confirm).
latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
    'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
    'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
    'sa', 'bgc'
]
other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']

# Languages offered in the UI dropdown: the common set plus the grouped
# aliases above (the fine-grained per-language lists are deliberately
# collapsed into their group names).
# all_lang = ['', 'auto']
all_lang = []
# all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
all_lang.extend([*other_lang, *add_lang])
152
+
153
+
154
def safe_stem(file_path):
    """Return the file stem with every unsafe character replaced by '_'.

    Only word characters (letters, digits, underscore) and dots survive,
    so the result is safe to embed in generated file names.
    """
    raw_stem = Path(file_path).stem
    sanitized = re.sub(r'[^\w.]', '_', raw_stem)
    return sanitized
158
+
159
+
160
def to_pdf(file_path):
    """Normalize an uploaded file (PDF or image) to a PDF on disk.

    ``read_fn`` converts the input to PDF bytes; the result is written next
    to the original file under a sanitized name and that path is returned.
    Returns ``None`` when no file was provided.
    """

    if file_path is None:
        return None

    pdf_bytes = read_fn(file_path)

    # unique_filename = f'{uuid.uuid4()}.pdf'
    unique_filename = f'{safe_stem(file_path)}.pdf'

    # Write alongside the uploaded file (Gradio's temp directory).
    tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)

    # Persist the converted bytes.
    with open(tmp_file_path, 'wb') as tmp_pdf_file:
        tmp_pdf_file.write(pdf_bytes)

    return tmp_file_path
178
+
179
+
180
+ # 更新界面函数
181
def update_interface(backend_choice):
    """Toggle backend-specific option rows when the backend dropdown changes.

    Returns Gradio visibility updates for ``(client_options, ocr_options)``.
    NOTE(review): an unknown backend falls through and implicitly returns
    ``None`` (no update) — confirm that is intended.
    """
    if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]:
        # Local VLM backends need neither a server URL nor OCR options.
        return gr.update(visible=False), gr.update(visible=False)
    elif backend_choice in ["vlm-sglang-client"]:
        # Remote client needs the server-URL row only.
        return gr.update(visible=True), gr.update(visible=False)
    elif backend_choice in ["pipeline"]:
        # Pipeline backend exposes the OCR/language options only.
        return gr.update(visible=False), gr.update(visible=True)
    else:
        pass
190
+
191
+
192
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
    '--enable-example',
    'example_enable',
    type=bool,
    help="Enable example files for input."
    "The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
    default=True,
)
@click.option(
    '--enable-sglang-engine',
    'sglang_engine_enable',
    type=bool,
    help="Enable SgLang engine backend for faster processing.",
    default=False,
)
@click.option(
    '--enable-api',
    'api_enable',
    type=bool,
    help="Enable gradio API for serving the application.",
    default=True,
)
@click.option(
    '--max-convert-pages',
    'max_convert_pages',
    type=int,
    help="Set the maximum number of pages to convert from PDF to Markdown.",
    default=1000,
)
@click.option(
    '--server-name',
    'server_name',
    type=str,
    help="Set the server name for the Gradio app.",
    default=None,
)
@click.option(
    '--server-port',
    'server_port',
    type=int,
    help="Set the server port for the Gradio app.",
    default=None,
)
def main(ctx,
         example_enable, sglang_engine_enable, api_enable, max_convert_pages,
         server_name, server_port, **kwargs
         ):
    """Build and launch the MinerU Gradio demo UI."""
    # Collect any extra, unrecognized CLI options for the backend.
    kwargs.update(arg_parse(ctx))

    if sglang_engine_enable:
        # Eagerly initialize the SgLang engine so the first request is fast;
        # failures are logged and the UI still starts.
        try:
            print("Start init SgLang engine...")
            from mineru.backend.vlm.vlm_analyze import ModelSingleton
            model_singleton = ModelSingleton()
            predictor = model_singleton.get_model(
                "sglang-engine",
                None,
                None,
                **kwargs
            )
            print("SgLang engine init successfully.")
        except Exception as e:
            logger.exception(e)

    suffixes = pdf_suffixes + image_suffixes
    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            # Left column: inputs and conversion options.
            with gr.Column(variant='panel', scale=5):
                with gr.Row():
                    input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
                with gr.Row():
                    max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages/2), step=1, label='Max convert pages')
                with gr.Row():
                    # Backend choices depend on whether the local engine is enabled.
                    if sglang_engine_enable:
                        drop_list = ["pipeline", "vlm-sglang-engine"]
                        preferred_option = "vlm-sglang-engine"
                    else:
                        drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
                        preferred_option = "pipeline"
                    backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
                # Shown only for the sglang client backend (see update_interface).
                with gr.Row(visible=False) as client_options:
                    url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
                with gr.Row(equal_height=True):
                    with gr.Column():
                        gr.Markdown("**Recognition Options:**")
                        formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
                        table_enable = gr.Checkbox(label='Enable table recognition', value=True)
                    # Shown only for the pipeline backend (see update_interface).
                    with gr.Column(visible=False) as ocr_options:
                        language = gr.Dropdown(all_lang, label='Language', value='ch')
                        is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
                with gr.Row():
                    change_bu = gr.Button('Convert')
                    clear_bu = gr.ClearButton(value='Clear')
                pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
                if example_enable:
                    # Offer sample documents from ./examples when present.
                    example_root = os.path.join(os.getcwd(), 'examples')
                    if os.path.exists(example_root):
                        with gr.Accordion('Examples:'):
                            gr.Examples(
                                examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                                          _.endswith(tuple(suffixes))],
                                inputs=input_file
                            )

            # Right column: conversion results.
            with gr.Column(variant='panel', scale=5):
                output_file = gr.File(label='convert result', interactive=False)
                with gr.Tabs():
                    with gr.Tab('Markdown rendering'):
                        md = gr.Markdown(label='Markdown rendering', height=1100, show_copy_button=True,
                                         latex_delimiters=latex_delimiters,
                                         line_breaks=True)
                    with gr.Tab('Markdown text'):
                        md_text = gr.TextArea(lines=45, show_copy_button=True)

        # Wire event handlers: show/hide backend-specific rows on change.
        backend.change(
            fn=update_interface,
            inputs=[backend],
            outputs=[client_options, ocr_options],
            api_name=False
        )
        # Also run once on page load so initial visibility matches the default backend.
        demo.load(
            fn=update_interface,
            inputs=[backend],
            outputs=[client_options, ocr_options],
            api_name=False
        )
        clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])

        # api_name=None exposes the endpoints via the Gradio API; False hides them.
        if api_enable:
            api_name = None
        else:
            api_name = False

        input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
        change_bu.click(
            fn=to_markdown,
            inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
            outputs=[md, md_text, output_file, pdf_show],
            api_name=api_name
        )

    demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable)
340
+
341
+
342
+ if __name__ == '__main__':
343
+ main()
vendor/mineru/mineru/cli/models_download.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import click
5
+ import requests
6
+ from loguru import logger
7
+
8
+ from mineru.utils.enum_class import ModelPath
9
+ from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
10
+
11
+
12
+ def download_json(url):
13
+ """下载JSON文件"""
14
+ response = requests.get(url)
15
+ response.raise_for_status()
16
+ return response.json()
17
+
18
+
19
def download_and_modify_json(url, local_filename, modifications):
    """Ensure *local_filename* holds an up-to-date config with *modifications* applied.

    The local file is reused when its ``config_version`` is at least 1.3.0;
    otherwise (or when the file is missing) a fresh copy is downloaded from
    *url*. ``modifications`` maps keys to replacement values: dict values are
    merged into the existing dict, anything else replaces the old value, and
    keys absent from the config are ignored (the template defines the schema).
    """

    def _version_tuple(version: str) -> tuple:
        # Fix: compare versions numerically. The previous plain string
        # comparison was lexicographic, so e.g. '1.10.0' < '1.3.0' was True
        # and an up-to-date config was needlessly re-downloaded.
        try:
            return tuple(int(part) for part in version.split('.'))
        except ValueError:
            # Unparseable version: treat as oldest so a fresh copy is fetched.
            return (0,)

    if os.path.exists(local_filename):
        # Fix: close the file deterministically (the old json.load(open(...))
        # leaked the handle) and read with an explicit encoding.
        with open(local_filename, encoding='utf-8') as f:
            data = json.load(f)
        config_version = data.get('config_version', '0.0.0')
        if _version_tuple(config_version) < _version_tuple('1.3.0'):
            data = download_json(url)
    else:
        data = download_json(url)

    # Apply the requested overrides.
    for key, value in modifications.items():
        if key in data:
            if isinstance(data[key], dict):
                # Merge into the existing mapping.
                data[key].update(value)
            else:
                # Otherwise replace the value outright.
                data[key] = value

    # Persist the (possibly refreshed and) modified configuration.
    with open(local_filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
+ json.dump(data, f, ensure_ascii=False, indent=4)
42
+
43
+
44
def configure_model(model_dir, model_type):
    """Record *model_dir* for *model_type* in the user's mineru.json config.

    Downloads/refreshes the config template when needed (see
    ``download_and_modify_json``) and merges the new models-dir entry in.
    """
    json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/mineru.template.json'
    # The config file name can be overridden via environment; it lives in $HOME.
    config_file_name = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
    home_dir = os.path.expanduser('~')
    config_file = os.path.join(home_dir, config_file_name)

    json_mods = {
        'models-dir': {
            f'{model_type}': model_dir
        }
    }

    download_and_modify_json(json_url, config_file, json_mods)
    logger.info(f'The configuration file has been successfully configured, the path is: {config_file}')
59
+
60
+
61
def download_pipeline_models():
    """Download every model used by the pipeline backend and register the path.

    NOTE(review): only the root path returned for the *last* model is passed
    to ``configure_model`` — this assumes all pipeline models share one root;
    confirm.
    """
    model_paths = [
        ModelPath.doclayout_yolo,
        ModelPath.yolo_v8_mfd,
        ModelPath.unimernet_small,
        ModelPath.pytorch_paddle,
        ModelPath.layout_reader,
        ModelPath.slanet_plus
    ]
    download_finish_path = ""
    for model_path in model_paths:
        logger.info(f"Downloading model: {model_path}")
        download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
    logger.info(f"Pipeline models downloaded successfully to: {download_finish_path}")
    configure_model(download_finish_path, "pipeline")
77
+
78
+
79
def download_vlm_models():
    """Download the VLM model repository and register its root path in the config."""
    download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
    logger.info(f"VLM models downloaded successfully to: {download_finish_path}")
    configure_model(download_finish_path, "vlm")
84
+
85
+
86
@click.command()
@click.option(
    '-s',
    '--source',
    'model_source',
    type=click.Choice(['huggingface', 'modelscope']),
    help="""
    The source of the model repository.
    """,
    default=None,
)
@click.option(
    '-m',
    '--model_type',
    'model_type',
    type=click.Choice(['pipeline', 'vlm', 'all']),
    help="""
    The type of the model to download.
    """,
    default=None,
)
def download_models(model_source, model_type):
    """Download MinerU model files.

    Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
    """
    # Prompt interactively for the download source when not given on the CLI.
    if model_source is None:
        model_source = click.prompt(
            "Please select the model download source: ",
            type=click.Choice(['huggingface', 'modelscope']),
            default='huggingface'
        )

    # An explicitly set MINERU_MODEL_SOURCE environment variable wins.
    if os.getenv('MINERU_MODEL_SOURCE', None) is None:
        os.environ['MINERU_MODEL_SOURCE'] = model_source

    # Prompt interactively for the model type when not given on the CLI.
    if model_type is None:
        model_type = click.prompt(
            "Please select the model type to download: ",
            type=click.Choice(['pipeline', 'vlm', 'all']),
            default='all'
        )

    logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")

    try:
        if model_type == 'pipeline':
            download_pipeline_models()
        elif model_type == 'vlm':
            download_vlm_models()
        elif model_type == 'all':
            download_pipeline_models()
            download_vlm_models()
        else:
            # click.Choice should prevent this, but fail loudly just in case.
            click.echo(f"Unsupported model type: {model_type}", err=True)
            sys.exit(1)

    except Exception as e:
        logger.exception(f"An error occurred while downloading models: {str(e)}")
        sys.exit(1)
148
+
149
+ if __name__ == '__main__':
150
+ download_models()
vendor/mineru/mineru/cli/vlm_sglang_server.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from ..model.vlm_sglang_model.server import main
2
+
3
+ if __name__ == "__main__":
4
+ main()
vendor/mineru/mineru/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/data/data_reader_writer/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import DataReader, DataWriter
2
+ from .dummy import DummyDataWriter
3
+ from .filebase import FileBasedDataReader, FileBasedDataWriter
4
+ from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
5
+ from .s3 import S3DataReader, S3DataWriter
6
+
7
+ __all__ = [
8
+ "DataReader",
9
+ "DataWriter",
10
+ "FileBasedDataReader",
11
+ "FileBasedDataWriter",
12
+ "S3DataReader",
13
+ "S3DataWriter",
14
+ "MultiBucketS3DataReader",
15
+ "MultiBucketS3DataWriter",
16
+ "DummyDataWriter",
17
+ ]
vendor/mineru/mineru/data/data_reader_writer/base.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from abc import ABC, abstractmethod
3
+
4
+
5
class DataReader(ABC):
    """Abstract byte reader: subclasses implement ``read_at``; ``read`` is a
    convenience wrapper that fetches a whole file."""

    def read(self, path: str) -> bytes:
        """Read the file.

        Args:
            path (str): file path to read

        Returns:
            bytes: the content of the file
        """
        # Whole-file read == ranged read with default offset/limit.
        return self.read_at(path)

    @abstractmethod
    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read the file at offset and limit.

        Args:
            path (str): the file path
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the length of bytes want to read. Defaults to -1
                (read to end of file).

        Returns:
            bytes: the content of the file
        """
        pass
31
+
32
+
33
class DataWriter(ABC):
    """Abstract byte writer: subclasses implement ``write``; ``write_string``
    adds str support on top of it."""

    @abstractmethod
    def write(self, path: str, data: bytes) -> None:
        """Write the data to the file.

        Args:
            path (str): the target file where to write
            data (bytes): the data want to write
        """
        pass

    def write_string(self, path: str, data: str) -> None:
        """Encode *data* as UTF-8 (replacing unencodable characters) and write it.

        Args:
            path (str): the target file where to write
            data (str): the data want to write
        """
        # Bugfix: the previous utf-8 -> ascii fallback loop was dead code,
        # because str.encode(..., errors='replace') never raises for utf-8;
        # worse, had every codec somehow failed, the write was silently
        # skipped. Encode once and always perform the write.
        self.write(path, data.encode('utf-8', errors='replace'))
vendor/mineru/mineru/data/data_reader_writer/dummy.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import DataWriter
2
+
3
+
4
class DummyDataWriter(DataWriter):
    """A no-op writer: accepts data and discards it (useful to skip output)."""

    def write(self, path: str, data: bytes) -> None:
        """Discard *data*; nothing is persisted."""

    def write_string(self, path: str, data: str) -> None:
        """Discard *data*; nothing is persisted."""
vendor/mineru/mineru/data/data_reader_writer/filebase.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from .base import DataReader, DataWriter
4
+
5
+
6
class FileBasedDataReader(DataReader):
    """Reader backed by the local filesystem."""

    def __init__(self, parent_dir: str = ''):
        """Remember *parent_dir*; relative paths are resolved against it.

        Args:
            parent_dir (str, optional): base directory for relative paths.
                Defaults to ''.
        """
        self._parent_dir = parent_dir

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read *limit* bytes of *path* starting at *offset*.

        Args:
            path (str): file path; relative paths are joined with parent_dir.
            offset (int, optional): number of bytes to skip. Defaults to 0.
            limit (int, optional): number of bytes to read; -1 means to EOF.

        Returns:
            bytes: the requested slice of the file.
        """
        if self._parent_dir and not os.path.isabs(path):
            resolved = os.path.join(self._parent_dir, path)
        else:
            resolved = path

        with open(resolved, 'rb') as fh:
            fh.seek(offset)
            return fh.read() if limit == -1 else fh.read(limit)
36
+
37
+
38
class FileBasedDataWriter(DataWriter):
    """Writer backed by the local filesystem."""

    def __init__(self, parent_dir: str = '') -> None:
        """Remember *parent_dir*; relative paths are resolved against it.

        Args:
            parent_dir (str, optional): base directory for relative paths.
                Defaults to ''.
        """
        self._parent_dir = parent_dir

    def write(self, path: str, data: bytes) -> None:
        """Write *data* to *path*, creating parent directories when needed.

        Args:
            path (str): target file; relative paths are joined with parent_dir.
            data (bytes): the bytes to write.
        """
        if self._parent_dir and not os.path.isabs(path):
            target = os.path.join(self._parent_dir, path)
        else:
            target = path

        parent = os.path.dirname(target)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        with open(target, 'wb') as fh:
            fh.write(data)
vendor/mineru/mineru/data/data_reader_writer/multi_bucket_s3.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from ..utils.exceptions import InvalidConfig, InvalidParams
3
+ from .base import DataReader, DataWriter
4
+ from ..io.s3 import S3Reader, S3Writer
5
+ from ..utils.schemas import S3Config
6
+ from ..utils.path_utils import parse_s3_range_params, parse_s3path, remove_non_official_s3_args
7
+
8
+
9
class MultiS3Mixin:
    def __init__(self, default_prefix: str, s3_configs: list[S3Config]):
        """Hold several per-bucket s3 configs plus a default bucket/prefix.

        Args:
            default_prefix (str): default relative-path prefix, in the form
                ``{bucket}/{prefix}`` or just ``{bucket}``.
            s3_configs (list[S3Config]): one config per bucket; bucket names
                must be unique and must include the default bucket.

        Raises:
            InvalidConfig: empty default_prefix, missing default-bucket
                config, or duplicate bucket names.
        """
        if not default_prefix:
            raise InvalidConfig('default_prefix must be provided')

        segments = default_prefix.strip('/').split('/')
        self.default_bucket = segments[0]
        self.default_prefix = '/'.join(segments[1:])

        bucket_names = [conf.bucket_name for conf in s3_configs]
        if self.default_bucket not in bucket_names:
            raise InvalidConfig(
                f'default_bucket: {self.default_bucket} config must be provided in s3_configs: {s3_configs}'
            )

        if len(set(bucket_names)) != len(s3_configs):
            raise InvalidConfig(
                f'the bucket_name in s3_configs: {s3_configs} must be unique'
            )

        self.s3_configs = s3_configs
        # bucket_name -> lazily created client; filled by subclasses.
        self._s3_clients_h: dict = {}
48
+
49
+
50
class MultiBucketS3DataReader(DataReader, MultiS3Mixin):
    """s3-backed reader that routes each request to the right bucket client."""

    def read(self, path: str) -> bytes:
        """Read an s3 object; an optional ``?bytes=start,len`` suffix selects a range.

        Args:
            path (str): s3 path such as ``s3://bucket/key`` or
                ``s3://bucket/key?bytes=0,100``.

        Returns:
            bytes: the (possibly ranged) object content.
        """
        range_params = parse_s3_range_params(path)
        if range_params is not None and len(range_params) == 2:
            byte_start, byte_len = int(range_params[0]), int(range_params[1])
        else:
            byte_start, byte_len = 0, -1
        return self.read_at(remove_non_official_s3_args(path), byte_start, byte_len)

    def __get_s3_client(self, bucket_name: str):
        # Lazily build and cache one S3Reader per bucket.
        if bucket_name not in {conf.bucket_name for conf in self.s3_configs}:
            raise InvalidParams(
                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
            )
        client = self._s3_clients_h.get(bucket_name)
        if client is None:
            conf = next(c for c in self.s3_configs if c.bucket_name == bucket_name)
            client = S3Reader(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
            self._s3_clients_h[bucket_name] = client
        return client

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read *limit* bytes starting at *offset*, picking the bucket client.

        A full ``s3://bucket/key`` path selects that bucket; a bare key is
        resolved against the default bucket and prefix.

        Args:
            path (str): s3 path or bare key.
            offset (int, optional): number of bytes to skip. Defaults to 0.
            limit (int, optional): number of bytes to read; -1 means to end.

        Returns:
            bytes: the object content.
        """
        if path.startswith('s3://'):
            bucket_name, key = parse_s3path(path)
            reader = self.__get_s3_client(bucket_name)
            return reader.read_at(key, offset, limit)

        reader = self.__get_s3_client(self.default_bucket)
        key = f'{self.default_prefix}/{path}' if self.default_prefix else path
        return reader.read_at(key, offset, limit)
108
+
109
+
110
class MultiBucketS3DataWriter(DataWriter, MultiS3Mixin):
    """s3-backed writer that routes each request to the right bucket client."""

    def __get_s3_client(self, bucket_name: str):
        # Lazily build and cache one S3Writer per bucket.
        if bucket_name not in {conf.bucket_name for conf in self.s3_configs}:
            raise InvalidParams(
                f'bucket name: {bucket_name} not found in s3_configs: {self.s3_configs}'
            )
        client = self._s3_clients_h.get(bucket_name)
        if client is None:
            conf = next(c for c in self.s3_configs if c.bucket_name == bucket_name)
            client = S3Writer(
                bucket_name,
                conf.access_key,
                conf.secret_key,
                conf.endpoint_url,
                conf.addressing_style,
            )
            self._s3_clients_h[bucket_name] = client
        return client

    def write(self, path: str, data: bytes) -> None:
        """Write *data* to an s3 object.

        A full ``s3://bucket/key`` path selects that bucket; a bare key is
        resolved against the default bucket and prefix.

        Args:
            path (str): s3 path or bare key.
            data (bytes): the data want to write.
        """
        if path.startswith('s3://'):
            bucket_name, key = parse_s3path(path)
            writer = self.__get_s3_client(bucket_name)
            return writer.write(key, data)

        writer = self.__get_s3_client(self.default_bucket)
        key = f'{self.default_prefix}/{path}' if self.default_prefix else path
        return writer.write(key, data)
vendor/mineru/mineru/data/data_reader_writer/s3.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .multi_bucket_s3 import MultiBucketS3DataReader, MultiBucketS3DataWriter
2
+ from ..utils.schemas import S3Config
3
+
4
+
5
class S3DataReader(MultiBucketS3DataReader):
    def __init__(
        self,
        default_prefix_without_bucket: str,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """Single-bucket s3 reader (convenience wrapper over the multi-bucket one).

        Args:
            default_prefix_without_bucket: prefix that does not contain the bucket
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid
                options are 'path' and 'virtual'; see
                https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        config = S3Config(
            bucket_name=bucket,
            access_key=ak,
            secret_key=sk,
            endpoint_url=endpoint_url,
            addressing_style=addressing_style,
        )
        super().__init__(f'{bucket}/{default_prefix_without_bucket}', [config])
38
+
39
+
40
class S3DataWriter(MultiBucketS3DataWriter):
    def __init__(
        self,
        default_prefix_without_bucket: str,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """Single-bucket s3 writer (convenience wrapper over the multi-bucket one).

        Args:
            default_prefix_without_bucket: prefix that does not contain the bucket
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid
                options are 'path' and 'virtual'; see
                https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        config = S3Config(
            bucket_name=bucket,
            access_key=ak,
            secret_key=sk,
            endpoint_url=endpoint_url,
            addressing_style=addressing_style,
        )
        super().__init__(f'{bucket}/{default_prefix_without_bucket}', [config])
vendor/mineru/mineru/data/io/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ from .base import IOReader, IOWriter
3
+ from .http import HttpReader, HttpWriter
4
+ from .s3 import S3Reader, S3Writer
5
+
6
+ __all__ = ['IOReader', 'IOWriter', 'HttpReader', 'HttpWriter', 'S3Reader', 'S3Writer']
vendor/mineru/mineru/data/io/base.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class IOReader(ABC):
    """Abstract low-level byte-reader interface (implemented by http/s3 backends)."""

    @abstractmethod
    def read(self, path: str) -> bytes:
        """Read the file.

        Args:
            path (str): file path to read

        Returns:
            bytes: the content of the file
        """
        pass

    @abstractmethod
    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read at offset and limit.

        Args:
            path (str): the path of the file (backend-specific: a key, URL, ...).
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the length of bytes want to read. Defaults to -1
                (read to the end).

        Returns:
            bytes: the content of file
        """
        pass
30
+
31
+
32
class IOWriter(ABC):
    """Abstract low-level byte-writer interface (implemented by http/s3 backends)."""

    @abstractmethod
    def write(self, path: str, data: bytes) -> None:
        """Write file with data.

        Args:
            path (str): the path of the file (backend-specific: a key, URL, ...).
            data (bytes): the data want to write
        """
        pass
vendor/mineru/mineru/data/io/http.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import io
3
+
4
+ import requests
5
+
6
+ from .base import IOReader, IOWriter
7
+
8
+
9
class HttpReader(IOReader):
    """Reader that fetches content over HTTP via ``requests``."""

    def read(self, url: str, timeout=None) -> bytes:
        """Fetch *url* with HTTP GET and return the response body.

        Args:
            url (str): the URL to fetch.
            timeout (float | tuple | None, optional): requests timeout in
                seconds; ``None`` keeps the original wait-indefinitely
                behavior. Passing a value avoids hanging on a stalled server.

        Returns:
            bytes: the raw response content.
        """
        # NOTE(review): no status-code check here — non-2xx bodies are
        # returned as-is, matching the original behavior.
        return requests.get(url, timeout=timeout).content

    def read_at(self, path: str, offset: int = 0, limit: int = -1) -> bytes:
        """Ranged reads are not supported for HTTP sources."""
        raise NotImplementedError
25
+
26
+
27
class HttpWriter(IOWriter):
    """Writer that uploads content over HTTP via ``requests``."""

    def write(self, url: str, data: bytes) -> None:
        """Upload *data* to *url* as a multipart POST (field name ``file``).

        Args:
            url (str): the endpoint to POST to.
            data (bytes): the data want to write.

        Raises:
            requests.HTTPError: if the server responds with a non-2xx status.
        """
        files = {'file': io.BytesIO(data)}
        response = requests.post(url, files=files)
        # Bugfix: the original used `assert` for the status check, which is
        # silently stripped under `python -O`; raise explicitly instead so a
        # failed upload always surfaces. Same accepted range (200-299).
        if not (200 <= response.status_code < 300):
            raise requests.HTTPError(
                f'upload to {url} failed with status {response.status_code}',
                response=response,
            )
vendor/mineru/mineru/data/io/s3.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import boto3
2
+ from botocore.config import Config
3
+
4
+ from ..io.base import IOReader, IOWriter
5
+
6
+
7
class S3Reader(IOReader):
    def __init__(
        self,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 reader client.

        Args:
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid
                options are 'path' and 'virtual'; see
                https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        self._bucket = bucket
        self._ak = ak
        self._sk = sk
        client_config = Config(
            s3={'addressing_style': addressing_style},
            retries={'max_attempts': 5, 'mode': 'standard'},
        )
        self._s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=client_config,
        )

    def read(self, key: str) -> bytes:
        """Read the whole object stored under *key*.

        Returns:
            bytes: the content of the object.
        """
        return self.read_at(key)

    def read_at(self, key: str, offset: int = 0, limit: int = -1) -> bytes:
        """Read *limit* bytes of the object *key* starting at *offset*.

        Args:
            key (str): object key within the bucket.
            offset (int, optional): the number of bytes skipped. Defaults to 0.
            limit (int, optional): the length of bytes to read. Defaults to -1
                (read to the end).

        Returns:
            bytes: the requested slice of the object.
        """
        # Map offset/limit onto an HTTP Range header; open-ended when limit < 0.
        if limit > -1:
            byte_range = f'bytes={offset}-{offset + limit - 1}'
        else:
            byte_range = f'bytes={offset}-'
        res = self._s3_client.get_object(
            Bucket=self._bucket, Key=key, Range=byte_range
        )
        return res['Body'].read()
72
+
73
+
74
class S3Writer(IOWriter):
    def __init__(
        self,
        bucket: str,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = 'auto',
    ):
        """s3 writer client.

        Args:
            bucket (str): bucket name
            ak (str): access key
            sk (str): secret key
            endpoint_url (str): endpoint url of s3
            addressing_style (str, optional): Defaults to 'auto'. Other valid
                options are 'path' and 'virtual'; see
                https://boto3.amazonaws.com/v1/documentation/api/1.9.42/guide/s3.html
        """
        self._bucket = bucket
        self._ak = ak
        self._sk = sk
        client_config = Config(
            s3={'addressing_style': addressing_style},
            retries={'max_attempts': 5, 'mode': 'standard'},
        )
        self._s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=client_config,
        )

    def write(self, key: str, data: bytes):
        """Store *data* under *key* in the bucket.

        Args:
            key (str): object key within the bucket.
            data (bytes): the data want to write.
        """
        self._s3_client.put_object(Bucket=self._bucket, Key=key, Body=data)
vendor/mineru/mineru/data/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/data/utils/exceptions.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+
3
class FileNotExisted(Exception):
    """Raised when a referenced file is missing on disk."""

    def __init__(self, path):
        # Bugfix: call Exception.__init__ so args/pickling/logging carry the
        # path instead of an empty tuple.
        super().__init__(path)
        self.path = path

    def __str__(self):
        return f'File {self.path} does not exist.'
10
+
11
+
12
class InvalidConfig(Exception):
    """Raised when a configuration value is malformed or missing."""

    def __init__(self, msg):
        # Bugfix: call Exception.__init__ so args/pickling/logging carry the
        # message instead of an empty tuple.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'Invalid config: {self.msg}'
18
+
19
+
20
class InvalidParams(Exception):
    """Raised when a caller passes invalid parameters."""

    def __init__(self, msg):
        # Bugfix: call Exception.__init__ so args/pickling/logging carry the
        # message instead of an empty tuple.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'Invalid params: {self.msg}'
26
+
27
+
28
class EmptyData(Exception):
    """Raised when required data is unexpectedly empty."""

    def __init__(self, msg):
        # Bugfix: call Exception.__init__ so args/pickling/logging carry the
        # message instead of an empty tuple.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'Empty data: {self.msg}'
34
+
35
class CUDA_NOT_AVAILABLE(Exception):
    """Raised when CUDA is required but not available.

    Note: the UPPER_SNAKE_CASE name violates PEP 8 for exception classes, but
    it is kept for backward compatibility with existing callers.
    """

    def __init__(self, msg):
        # Bugfix: call Exception.__init__ so args/pickling/logging carry the
        # message instead of an empty tuple.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return f'CUDA not available: {self.msg}'
vendor/mineru/mineru/data/utils/path_utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+
3
+
4
def remove_non_official_s3_args(s3path):
    """Strip any trailing ``?...`` query suffix from an s3 path.

    example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
    """
    return s3path.split("?", 1)[0]
10
+
11
def parse_s3path(s3path: str):
    """Split an s3 path into ``(bucket_name, key)``.

    Any ``?...`` suffix is stripped first. A path with no key part
    (e.g. ``s3://bucket``) yields an empty key instead of crashing.

    Args:
        s3path (str): path in ``s3://bucket/key`` or ``s3a://bucket/key`` form.

    Returns:
        tuple[str, str]: bucket name and object key.

    Raises:
        ValueError: if the path is not in a recognized s3 form.
    """
    s3path = remove_non_official_s3_args(s3path).strip()
    if s3path.startswith(('s3://', 's3a://')):
        _, _, remainder = s3path.partition('://')
        # Bugfix: the original `path.split('/', 1)` raised an opaque unpack
        # ValueError for bucket-only paths; partition yields key == ''.
        bucket_name, _, key = remainder.partition('/')
        return bucket_name, key
    elif s3path.startswith('/'):
        raise ValueError("The provided path starts with '/'. This does not conform to a valid S3 path format.")
    else:
        raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")
24
+
25
+
26
def parse_s3_range_params(s3path: str):
    """Extract the byte-range parameters from an s3 path, if present.

    example: s3://abc/xxxx.json?bytes=0,81350 ==> ['0', '81350']
    Returns None when the path carries no ``?bytes=`` suffix.
    """
    parts = s3path.split("?bytes=")
    return None if len(parts) == 1 else parts[1].split(",")
vendor/mineru/mineru/data/utils/schemas.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class S3Config(BaseModel):
    """Connection settings for a single s3 bucket (pydantic-validated;
    every field must be a non-empty string)."""
    bucket_name: str = Field(description='s3 bucket name', min_length=1)
    access_key: str = Field(description='s3 access key', min_length=1)
    secret_key: str = Field(description='s3 secret key', min_length=1)
    endpoint_url: str = Field(description='s3 endpoint url', min_length=1)
    # 'auto' | 'path' | 'virtual' — forwarded to boto3's s3 addressing_style.
    addressing_style: str = Field(description='s3 addressing style', default='auto', min_length=1)
14
+
15
+
16
class PageInfo(BaseModel):
    """The width and height of a page.

    NOTE(review): units are not stated here — presumably PDF points or
    pixels depending on the producer; confirm against callers.
    """
    w: float = Field(description='the width of page')
    h: float = Field(description='the height of page')
vendor/mineru/mineru/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.
vendor/mineru/mineru/model/layout/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Opendatalab. All rights reserved.