fixed bug for lc_serve
Browse files
 - app_modules/init.py +2 -2
 - app_modules/llm_loader.py +4 -2
 - server.py +1 -1
 
    	
        app_modules/init.py
    CHANGED
    
    | 
         @@ -23,7 +23,7 @@ load_dotenv(found_dotenv, override=False) 
     | 
|
| 23 | 
         
             
            init_settings()
         
     | 
| 24 | 
         | 
| 25 | 
         | 
| 26 | 
         
            -
            def app_init():
         
     | 
| 27 | 
         
             
                # https://github.com/huggingface/transformers/issues/17611
         
     | 
| 28 | 
         
             
                os.environ["CURL_CA_BUNDLE"] = ""
         
     | 
| 29 | 
         | 
| 
         @@ -69,7 +69,7 @@ def app_init(): 
     | 
|
| 69 | 
         
             
                print(f"Completed in {end - start:.3f}s")
         
     | 
| 70 | 
         | 
| 71 | 
         
             
                start = timer()
         
     | 
| 72 | 
         
            -
                llm_loader = LLMLoader(llm_model_type)
         
     | 
| 73 | 
         
             
                llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
         
     | 
| 74 | 
         
             
                qa_chain = QAChain(vectorstore, llm_loader)
         
     | 
| 75 | 
         
             
                end = timer()
         
     | 
| 
         | 
|
| 23 | 
         
             
            init_settings()
         
     | 
| 24 | 
         | 
| 25 | 
         | 
| 26 | 
         
            +
            def app_init(lc_serve: bool = False):
         
     | 
| 27 | 
         
             
                # https://github.com/huggingface/transformers/issues/17611
         
     | 
| 28 | 
         
             
                os.environ["CURL_CA_BUNDLE"] = ""
         
     | 
| 29 | 
         | 
| 
         | 
|
| 69 | 
         
             
                print(f"Completed in {end - start:.3f}s")
         
     | 
| 70 | 
         | 
| 71 | 
         
             
                start = timer()
         
     | 
| 72 | 
         
            +
                llm_loader = LLMLoader(llm_model_type, lc_serve)
         
     | 
| 73 | 
         
             
                llm_loader.init(n_threds=n_threds, hf_pipeline_device_type=hf_pipeline_device_type)
         
     | 
| 74 | 
         
             
                qa_chain = QAChain(vectorstore, llm_loader)
         
     | 
| 75 | 
         
             
                end = timer()
         
     | 
    	
        app_modules/llm_loader.py
    CHANGED
    
    | 
         @@ -90,10 +90,12 @@ class LLMLoader: 
     | 
|
| 90 | 
         
             
                streamer: any
         
     | 
| 91 | 
         
             
                max_tokens_limit: int
         
     | 
| 92 | 
         | 
| 93 | 
         
            -
                def __init__(self, llm_model_type, max_tokens_limit: int = 2048):
     | 
| 
         | 
|
| 
         | 
|
| 94 | 
         
             
                    self.llm_model_type = llm_model_type
         
     | 
| 95 | 
         
             
                    self.llm = None
         
     | 
| 96 | 
         
            -
                    self.streamer = None
         
     | 
| 97 | 
         
             
                    self.max_tokens_limit = max_tokens_limit
         
     | 
| 98 | 
         
             
                    self.search_kwargs = {"k": 4}
         
     | 
| 99 | 
         | 
| 
         | 
|
| 90 | 
         
             
                streamer: any
         
     | 
| 91 | 
         
             
                max_tokens_limit: int
         
     | 
| 92 | 
         | 
| 93 | 
         
            +
                def __init__(
         
     | 
| 94 | 
         
            +
                    self, llm_model_type, max_tokens_limit: int = 2048, lc_serve: bool = False
         
     | 
| 95 | 
         
            +
                ):
         
     | 
| 96 | 
         
             
                    self.llm_model_type = llm_model_type
         
     | 
| 97 | 
         
             
                    self.llm = None
         
     | 
| 98 | 
         
            +
                    self.streamer = None if lc_serve else TextIteratorStreamer("")
         
     | 
| 99 | 
         
             
                    self.max_tokens_limit = max_tokens_limit
         
     | 
| 100 | 
         
             
                    self.search_kwargs = {"k": 4}
         
     | 
| 101 | 
         | 
    	
        server.py
    CHANGED
    
    | 
         @@ -11,7 +11,7 @@ from app_modules.init import app_init 
     | 
|
| 11 | 
         
             
            from app_modules.llm_chat_chain import ChatChain
         
     | 
| 12 | 
         
             
            from app_modules.utils import print_llm_response
         
     | 
| 13 | 
         | 
| 14 | 
         
            -
            llm_loader, qa_chain = app_init()
         
     | 
| 15 | 
         | 
| 16 | 
         
             
            chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
         
     | 
| 17 | 
         | 
| 
         | 
|
| 11 | 
         
             
            from app_modules.llm_chat_chain import ChatChain
         
     | 
| 12 | 
         
             
            from app_modules.utils import print_llm_response
         
     | 
| 13 | 
         | 
| 14 | 
         
            +
            llm_loader, qa_chain = app_init(True)
         
     | 
| 15 | 
         | 
| 16 | 
         
             
            chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
         
     | 
| 17 | 
         |