KevinHuSh
commited on
Commit
·
f666f56
1
Parent(s):
4873964
fix user login issue (#85)
Browse files- api/apps/user_app.py +47 -60
- api/db/__init__.py +0 -1
- api/db/db_models.py +1 -1
- api/db/init_data.py +2 -2
- api/db/services/user_service.py +10 -3
- api/settings.py +1 -1
- deepdoc/parser/pdf_parser.py +1 -1
- deepdoc/vision/layout_recognizer.py +1 -2
- deepdoc/vision/table_structure_recognizer.py +1 -2
- rag/app/manual.py +6 -0
- rag/app/naive.py +29 -6
- rag/svr/task_executor.py +2 -2
api/apps/user_app.py
CHANGED
|
@@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse
|
|
| 33 |
|
| 34 |
@manager.route('/login', methods=['POST', 'GET'])
|
| 35 |
def login():
|
| 36 |
-
userinfo = None
|
| 37 |
login_channel = "password"
|
| 38 |
-
if
|
| 39 |
-
login_channel = session["access_token_from"]
|
| 40 |
-
if session["access_token_from"] == "github":
|
| 41 |
-
userinfo = user_info_from_github(session["access_token"])
|
| 42 |
-
elif not request.json:
|
| 43 |
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
| 44 |
retmsg='Unautherized!')
|
| 45 |
|
| 46 |
-
email = request.json.get('email'
|
| 47 |
users = UserService.query(email=email)
|
| 48 |
-
if not users:
|
| 49 |
-
if request.json is not None:
|
| 50 |
-
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
| 51 |
-
avatar = ""
|
| 52 |
-
try:
|
| 53 |
-
avatar = download_img(userinfo["avatar_url"])
|
| 54 |
-
except Exception as e:
|
| 55 |
-
stat_logger.exception(e)
|
| 56 |
-
user_id = get_uuid()
|
| 57 |
-
try:
|
| 58 |
-
users = user_register(user_id, {
|
| 59 |
-
"access_token": session["access_token"],
|
| 60 |
-
"email": userinfo["email"],
|
| 61 |
-
"avatar": avatar,
|
| 62 |
-
"nickname": userinfo["login"],
|
| 63 |
-
"login_channel": login_channel,
|
| 64 |
-
"last_login_time": get_format_time(),
|
| 65 |
-
"is_superuser": False,
|
| 66 |
-
})
|
| 67 |
-
if not users: raise Exception('Register user failure.')
|
| 68 |
-
if len(users) > 1: raise Exception('Same E-mail exist!')
|
| 69 |
-
user = users[0]
|
| 70 |
-
login_user(user)
|
| 71 |
-
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
| 72 |
-
except Exception as e:
|
| 73 |
-
rollback_user_registration(user_id)
|
| 74 |
-
stat_logger.exception(e)
|
| 75 |
-
return server_error_response(e)
|
| 76 |
-
elif not request.json:
|
| 77 |
-
login_user(users[0])
|
| 78 |
-
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
|
| 79 |
|
| 80 |
password = request.json.get('password')
|
| 81 |
try:
|
|
@@ -97,28 +62,50 @@ def login():
|
|
| 97 |
|
| 98 |
@manager.route('/github_callback', methods=['GET'])
|
| 99 |
def github_callback():
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
retmsg=res["error_description"])
|
| 111 |
-
|
| 112 |
-
if "user:email" not in res["scope"].split(","):
|
| 113 |
-
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
| 114 |
-
|
| 115 |
-
session["access_token"] = res["access_token"]
|
| 116 |
-
session["access_token_from"] = "github"
|
| 117 |
-
return redirect(url_for("user.login"), code=307)
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
|
| 124 |
def user_info_from_github(access_token):
|
|
@@ -208,7 +195,7 @@ def user_register(user_id, user):
|
|
| 208 |
for llm in LLMService.query(fid=LLM_FACTORY):
|
| 209 |
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
| 210 |
|
| 211 |
-
if not UserService.
|
| 212 |
TenantService.insert(**tenant)
|
| 213 |
UserTenantService.insert(**usr_tenant)
|
| 214 |
TenantLLMService.insert_many(tenant_llm)
|
|
|
|
| 33 |
|
| 34 |
@manager.route('/login', methods=['POST', 'GET'])
|
| 35 |
def login():
|
|
|
|
| 36 |
login_channel = "password"
|
| 37 |
+
if not request.json:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
| 39 |
retmsg='Unautherized!')
|
| 40 |
|
| 41 |
+
email = request.json.get('email', "")
|
| 42 |
users = UserService.query(email=email)
|
| 43 |
+
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
password = request.json.get('password')
|
| 46 |
try:
|
|
|
|
| 62 |
|
| 63 |
@manager.route('/github_callback', methods=['GET'])
|
| 64 |
def github_callback():
|
| 65 |
+
import requests
|
| 66 |
+
res = requests.post(GITHUB_OAUTH.get("url"), data={
|
| 67 |
+
"client_id": GITHUB_OAUTH.get("client_id"),
|
| 68 |
+
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
| 69 |
+
"code": request.args.get('code')
|
| 70 |
+
}, headers={"Accept": "application/json"})
|
| 71 |
+
res = res.json()
|
| 72 |
+
if "error" in res:
|
| 73 |
+
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
|
| 74 |
+
retmsg=res["error_description"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
if "user:email" not in res["scope"].split(","):
|
| 77 |
+
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
|
| 78 |
+
|
| 79 |
+
session["access_token"] = res["access_token"]
|
| 80 |
+
session["access_token_from"] = "github"
|
| 81 |
+
userinfo = user_info_from_github(session["access_token"])
|
| 82 |
+
users = UserService.query(email=userinfo["email"])
|
| 83 |
+
user_id = get_uuid()
|
| 84 |
+
if not users:
|
| 85 |
+
try:
|
| 86 |
+
try:
|
| 87 |
+
avatar = download_img(userinfo["avatar_url"])
|
| 88 |
+
except Exception as e:
|
| 89 |
+
stat_logger.exception(e)
|
| 90 |
+
avatar = ""
|
| 91 |
+
users = user_register(user_id, {
|
| 92 |
+
"access_token": session["access_token"],
|
| 93 |
+
"email": userinfo["email"],
|
| 94 |
+
"avatar": avatar,
|
| 95 |
+
"nickname": userinfo["login"],
|
| 96 |
+
"login_channel": "github",
|
| 97 |
+
"last_login_time": get_format_time(),
|
| 98 |
+
"is_superuser": False,
|
| 99 |
+
})
|
| 100 |
+
if not users: raise Exception('Register user failure.')
|
| 101 |
+
if len(users) > 1: raise Exception('Same E-mail exist!')
|
| 102 |
+
user = users[0]
|
| 103 |
+
login_user(user)
|
| 104 |
+
except Exception as e:
|
| 105 |
+
rollback_user_registration(user_id)
|
| 106 |
+
stat_logger.exception(e)
|
| 107 |
+
|
| 108 |
+
return redirect("/knowledge")
|
| 109 |
|
| 110 |
|
| 111 |
def user_info_from_github(access_token):
|
|
|
|
| 195 |
for llm in LLMService.query(fid=LLM_FACTORY):
|
| 196 |
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
|
| 197 |
|
| 198 |
+
if not UserService.save(**user):return
|
| 199 |
TenantService.insert(**tenant)
|
| 200 |
UserTenantService.insert(**usr_tenant)
|
| 201 |
TenantLLMService.insert_many(tenant_llm)
|
api/db/__init__.py
CHANGED
|
@@ -69,7 +69,6 @@ class TaskStatus(StrEnum):
|
|
| 69 |
|
| 70 |
|
| 71 |
class ParserType(StrEnum):
|
| 72 |
-
GENERAL = "general"
|
| 73 |
PRESENTATION = "presentation"
|
| 74 |
LAWS = "laws"
|
| 75 |
MANUAL = "manual"
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
class ParserType(StrEnum):
|
|
|
|
| 72 |
PRESENTATION = "presentation"
|
| 73 |
LAWS = "laws"
|
| 74 |
MANUAL = "manual"
|
api/db/db_models.py
CHANGED
|
@@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel):
|
|
| 475 |
similarity_threshold = FloatField(default=0.2)
|
| 476 |
vector_similarity_weight = FloatField(default=0.3)
|
| 477 |
|
| 478 |
-
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.
|
| 479 |
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
| 480 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
| 481 |
|
|
|
|
| 475 |
similarity_threshold = FloatField(default=0.2)
|
| 476 |
vector_similarity_weight = FloatField(default=0.3)
|
| 477 |
|
| 478 |
+
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
|
| 479 |
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
| 480 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
| 481 |
|
api/db/init_data.py
CHANGED
|
@@ -30,7 +30,7 @@ def init_superuser():
|
|
| 30 |
"password": "admin",
|
| 31 |
"nickname": "admin",
|
| 32 |
"is_superuser": True,
|
| 33 |
-
"email": "
|
| 34 |
"creator": "system",
|
| 35 |
"status": "1",
|
| 36 |
}
|
|
@@ -61,7 +61,7 @@ def init_superuser():
|
|
| 61 |
TenantService.insert(**tenant)
|
| 62 |
UserTenantService.insert(**usr_tenant)
|
| 63 |
TenantLLMService.insert_many(tenant_llm)
|
| 64 |
-
print("【INFO】Super user initialized. \033[
|
| 65 |
|
| 66 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
| 67 |
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
|
|
|
| 30 |
"password": "admin",
|
| 31 |
"nickname": "admin",
|
| 32 |
"is_superuser": True,
|
| 33 |
+
"email": "admin@ragflow.io",
|
| 34 |
"creator": "system",
|
| 35 |
"status": "1",
|
| 36 |
}
|
|
|
|
| 61 |
TenantService.insert(**tenant)
|
| 62 |
UserTenantService.insert(**usr_tenant)
|
| 63 |
TenantLLMService.insert_many(tenant_llm)
|
| 64 |
+
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
|
| 65 |
|
| 66 |
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
| 67 |
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
api/db/services/user_service.py
CHANGED
|
@@ -13,6 +13,8 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
|
|
|
|
|
|
| 16 |
import peewee
|
| 17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
| 18 |
|
|
@@ -20,7 +22,7 @@ from api.db import UserTenantRole
|
|
| 20 |
from api.db.db_models import DB, UserTenant
|
| 21 |
from api.db.db_models import User, Tenant
|
| 22 |
from api.db.services.common_service import CommonService
|
| 23 |
-
from api.utils import get_uuid, get_format_time
|
| 24 |
from api.db import StatusEnum
|
| 25 |
|
| 26 |
|
|
@@ -53,6 +55,11 @@ class UserService(CommonService):
|
|
| 53 |
kwargs["id"] = get_uuid()
|
| 54 |
if "password" in kwargs:
|
| 55 |
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
obj = cls.model(**kwargs).save(force_insert=True)
|
| 57 |
return obj
|
| 58 |
|
|
@@ -66,10 +73,10 @@ class UserService(CommonService):
|
|
| 66 |
@classmethod
|
| 67 |
@DB.connection_context()
|
| 68 |
def update_user(cls, user_id, user_dict):
|
| 69 |
-
date_time = get_format_time()
|
| 70 |
with DB.atomic():
|
| 71 |
if user_dict:
|
| 72 |
-
user_dict["update_time"] =
|
|
|
|
| 73 |
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
| 74 |
|
| 75 |
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
import peewee
|
| 19 |
from werkzeug.security import generate_password_hash, check_password_hash
|
| 20 |
|
|
|
|
| 22 |
from api.db.db_models import DB, UserTenant
|
| 23 |
from api.db.db_models import User, Tenant
|
| 24 |
from api.db.services.common_service import CommonService
|
| 25 |
+
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
| 26 |
from api.db import StatusEnum
|
| 27 |
|
| 28 |
|
|
|
|
| 55 |
kwargs["id"] = get_uuid()
|
| 56 |
if "password" in kwargs:
|
| 57 |
kwargs["password"] = generate_password_hash(str(kwargs["password"]))
|
| 58 |
+
|
| 59 |
+
kwargs["create_time"] = current_timestamp()
|
| 60 |
+
kwargs["create_date"] = datetime_format(datetime.now())
|
| 61 |
+
kwargs["update_time"] = current_timestamp()
|
| 62 |
+
kwargs["update_date"] = datetime_format(datetime.now())
|
| 63 |
obj = cls.model(**kwargs).save(force_insert=True)
|
| 64 |
return obj
|
| 65 |
|
|
|
|
| 73 |
@classmethod
|
| 74 |
@DB.connection_context()
|
| 75 |
def update_user(cls, user_id, user_dict):
|
|
|
|
| 76 |
with DB.atomic():
|
| 77 |
if user_dict:
|
| 78 |
+
user_dict["update_time"] = current_timestamp()
|
| 79 |
+
user_dict["update_date"] = datetime_format(datetime.now())
|
| 80 |
cls.model.update(user_dict).where(cls.model.id == user_id).execute()
|
| 81 |
|
| 82 |
|
api/settings.py
CHANGED
|
@@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
|
|
| 76 |
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
| 77 |
|
| 78 |
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
| 79 |
-
PARSERS = LLM.get("parsers", "
|
| 80 |
|
| 81 |
# distribution
|
| 82 |
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
|
|
|
| 76 |
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
|
| 77 |
|
| 78 |
API_KEY = LLM.get("api_key", "infiniflow API Key")
|
| 79 |
+
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
|
| 80 |
|
| 81 |
# distribution
|
| 82 |
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
|
deepdoc/parser/pdf_parser.py
CHANGED
|
@@ -25,7 +25,7 @@ class HuParser:
|
|
| 25 |
def __init__(self):
|
| 26 |
self.ocr = OCR()
|
| 27 |
if not hasattr(self, "model_speciess"):
|
| 28 |
-
self.model_speciess = ParserType.
|
| 29 |
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
|
| 30 |
self.tbl_det = TableStructureRecognizer()
|
| 31 |
|
|
|
|
| 25 |
def __init__(self):
|
| 26 |
self.ocr = OCR()
|
| 27 |
if not hasattr(self, "model_speciess"):
|
| 28 |
+
self.model_speciess = ParserType.NAIVE.value
|
| 29 |
self.layouter = LayoutRecognizer("layout."+self.model_speciess)
|
| 30 |
self.tbl_det = TableStructureRecognizer()
|
| 31 |
|
deepdoc/vision/layout_recognizer.py
CHANGED
|
@@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer):
|
|
| 34 |
"Equation",
|
| 35 |
]
|
| 36 |
def __init__(self, domain):
|
| 37 |
-
super().__init__(self.labels, domain,
|
| 38 |
-
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 39 |
|
| 40 |
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
| 41 |
def __is_garbage(b):
|
|
|
|
| 34 |
"Equation",
|
| 35 |
]
|
| 36 |
def __init__(self, domain):
|
| 37 |
+
super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
|
|
|
| 38 |
|
| 39 |
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
| 40 |
def __is_garbage(b):
|
deepdoc/vision/table_structure_recognizer.py
CHANGED
|
@@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer):
|
|
| 33 |
]
|
| 34 |
|
| 35 |
def __init__(self):
|
| 36 |
-
super().__init__(self.labels, "tsr",
|
| 37 |
-
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
| 38 |
|
| 39 |
def __call__(self, images, thr=0.2):
|
| 40 |
tbls = super().__call__(images, thr)
|
|
|
|
| 33 |
]
|
| 34 |
|
| 35 |
def __init__(self):
|
| 36 |
+
super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
|
|
|
| 37 |
|
| 38 |
def __call__(self, images, thr=0.2):
|
| 39 |
tbls = super().__call__(images, thr)
|
rag/app/manual.py
CHANGED
|
@@ -1,11 +1,17 @@
|
|
| 1 |
import copy
|
| 2 |
import re
|
|
|
|
|
|
|
| 3 |
from rag.nlp import huqie, tokenize
|
| 4 |
from deepdoc.parser import PdfParser
|
| 5 |
from rag.utils import num_tokens_from_string
|
| 6 |
|
| 7 |
|
| 8 |
class Pdf(PdfParser):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def __call__(self, filename, binary=None, from_page=0,
|
| 10 |
to_page=100000, zoomin=3, callback=None):
|
| 11 |
self.__images__(
|
|
|
|
| 1 |
import copy
|
| 2 |
import re
|
| 3 |
+
|
| 4 |
+
from api.db import ParserType
|
| 5 |
from rag.nlp import huqie, tokenize
|
| 6 |
from deepdoc.parser import PdfParser
|
| 7 |
from rag.utils import num_tokens_from_string
|
| 8 |
|
| 9 |
|
| 10 |
class Pdf(PdfParser):
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.model_speciess = ParserType.MANUAL.value
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
def __call__(self, filename, binary=None, from_page=0,
|
| 16 |
to_page=100000, zoomin=3, callback=None):
|
| 17 |
self.__images__(
|
rag/app/naive.py
CHANGED
|
@@ -30,11 +30,21 @@ class Pdf(PdfParser):
|
|
| 30 |
|
| 31 |
from timeit import default_timer as timer
|
| 32 |
start = timer()
|
|
|
|
| 33 |
self._layouts_rec(zoomin)
|
| 34 |
-
callback(0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
| 36 |
-
self._naive_vertical_merge()
|
| 37 |
-
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
|
| 38 |
|
| 39 |
|
| 40 |
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
|
@@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
| 44 |
Successive text will be sliced into pieces using 'delimiter'.
|
| 45 |
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
| 46 |
"""
|
|
|
|
|
|
|
| 47 |
doc = {
|
| 48 |
"docnm_kwd": filename,
|
| 49 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
| 50 |
}
|
| 51 |
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
|
|
|
| 52 |
pdf_parser = None
|
| 53 |
sections = []
|
| 54 |
if re.search(r"\.docx?$", filename, re.IGNORECASE):
|
|
@@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
| 58 |
callback(0.8, "Finish parsing.")
|
| 59 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
| 60 |
pdf_parser = Pdf()
|
| 61 |
-
sections = pdf_parser(filename if not binary else binary,
|
| 62 |
from_page=from_page, to_page=to_page, callback=callback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
| 64 |
callback(0.1, "Start to parse.")
|
| 65 |
txt = ""
|
|
@@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
|
|
| 79 |
|
| 80 |
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
| 81 |
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
| 82 |
-
|
| 83 |
-
res = []
|
| 84 |
# wrap up to es documents
|
| 85 |
for ck in cks:
|
| 86 |
print("--", ck)
|
|
|
|
| 30 |
|
| 31 |
from timeit import default_timer as timer
|
| 32 |
start = timer()
|
| 33 |
+
start = timer()
|
| 34 |
self._layouts_rec(zoomin)
|
| 35 |
+
callback(0.5, "Layout analysis finished.")
|
| 36 |
+
print("paddle layouts:", timer() - start)
|
| 37 |
+
self._table_transformer_job(zoomin)
|
| 38 |
+
callback(0.7, "Table analysis finished.")
|
| 39 |
+
self._text_merge()
|
| 40 |
+
self._concat_downward(concat_between_pages=False)
|
| 41 |
+
self._filter_forpages()
|
| 42 |
+
callback(0.77, "Text merging finished")
|
| 43 |
+
tbls = self._extract_table_figure(True, zoomin, False)
|
| 44 |
+
|
| 45 |
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
|
| 46 |
+
#self._naive_vertical_merge()
|
| 47 |
+
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
|
| 48 |
|
| 49 |
|
| 50 |
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
|
|
|
|
| 54 |
Successive text will be sliced into pieces using 'delimiter'.
|
| 55 |
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
|
| 56 |
"""
|
| 57 |
+
|
| 58 |
+
eng = lang.lower() == "english"#is_english(cks)
|
| 59 |
doc = {
|
| 60 |
"docnm_kwd": filename,
|
| 61 |
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
| 62 |
}
|
| 63 |
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
|
| 64 |
+
res = []
|
| 65 |
pdf_parser = None
|
| 66 |
sections = []
|
| 67 |
if re.search(r"\.docx?$", filename, re.IGNORECASE):
|
|
|
|
| 71 |
callback(0.8, "Finish parsing.")
|
| 72 |
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
| 73 |
pdf_parser = Pdf()
|
| 74 |
+
sections, tbls = pdf_parser(filename if not binary else binary,
|
| 75 |
from_page=from_page, to_page=to_page, callback=callback)
|
| 76 |
+
# add tables
|
| 77 |
+
for img, rows in tbls:
|
| 78 |
+
bs = 10
|
| 79 |
+
de = ";" if eng else ";"
|
| 80 |
+
for i in range(0, len(rows), bs):
|
| 81 |
+
d = copy.deepcopy(doc)
|
| 82 |
+
r = de.join(rows[i:i + bs])
|
| 83 |
+
r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
|
| 84 |
+
tokenize(d, r, eng)
|
| 85 |
+
d["image"] = img
|
| 86 |
+
res.append(d)
|
| 87 |
elif re.search(r"\.txt$", filename, re.IGNORECASE):
|
| 88 |
callback(0.1, "Start to parse.")
|
| 89 |
txt = ""
|
|
|
|
| 103 |
|
| 104 |
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
|
| 105 |
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
|
| 106 |
+
|
|
|
|
| 107 |
# wrap up to es documents
|
| 108 |
for ck in cks:
|
| 109 |
print("--", ck)
|
rag/svr/task_executor.py
CHANGED
|
@@ -37,7 +37,7 @@ from rag.nlp import search
|
|
| 37 |
from io import BytesIO
|
| 38 |
import pandas as pd
|
| 39 |
|
| 40 |
-
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
|
| 41 |
|
| 42 |
from api.db import LLMType, ParserType
|
| 43 |
from api.db.services.document_service import DocumentService
|
|
@@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory
|
|
| 48 |
BATCH_SIZE = 64
|
| 49 |
|
| 50 |
FACTORY = {
|
| 51 |
-
ParserType.
|
| 52 |
ParserType.PAPER.value: paper,
|
| 53 |
ParserType.BOOK.value: book,
|
| 54 |
ParserType.PRESENTATION.value: presentation,
|
|
|
|
| 37 |
from io import BytesIO
|
| 38 |
import pandas as pd
|
| 39 |
|
| 40 |
+
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive
|
| 41 |
|
| 42 |
from api.db import LLMType, ParserType
|
| 43 |
from api.db.services.document_service import DocumentService
|
|
|
|
| 48 |
BATCH_SIZE = 64
|
| 49 |
|
| 50 |
FACTORY = {
|
| 51 |
+
ParserType.NAIVE.value: naive,
|
| 52 |
ParserType.PAPER.value: paper,
|
| 53 |
ParserType.BOOK.value: book,
|
| 54 |
ParserType.PRESENTATION.value: presentation,
|