File size: 4,064 Bytes
9d615c0
 
 
 
 
 
 
 
71b8ea3
 
9d615c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8cabdd
9d615c0
 
 
 
 
 
 
 
 
ce65deb
 
 
 
 
 
befaba8
ea855b6
ad2034b
 
 
 
 
ce65deb
9d615c0
 
 
 
 
ea855b6
7afd0f0
f15f01c
 
 
 
 
 
 
 
 
 
 
 
 
7afd0f0
 
 
ea855b6
 
9d615c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""This module contains functions for loading a ConversationalRetrievalChain"""

import logging

import wandb
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
## deprecated: from langchain.vectorstores import Chroma
from langchain_community.vectorstores import Chroma
from prompts import load_chat_prompt


logger = logging.getLogger(__name__)


def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
    """Build a Chroma vector store from an index stored as a W&B artifact

    The artifact named by ``wandb_run.config.vector_store_artifact`` (type
    ``search_index``) is downloaded and its local directory is used as the
    Chroma persist directory.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        openai_api_key (str): The OpenAI API key passed to the embedding function
    Returns:
        Chroma: A Chroma vector store backed by the downloaded artifact directory
    """
    # Fetch the persisted index from Weights & Biases.
    index_artifact = wandb_run.use_artifact(
        wandb_run.config.vector_store_artifact, type="search_index"
    )
    index_dir = index_artifact.download()

    # Open the store with the OpenAI embedding function.
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    return Chroma(persist_directory=index_dir, embedding_function=embeddings)


def load_chain(
    wandb_run: wandb.run, vector_store: Chroma, openai_api_key: str
) -> ConversationalRetrievalChain:
    """Load a ConversationalQA chain from a config and a vector store

    The chat model is configured from ``wandb_run.config`` (``model_name``,
    ``chat_temperature``, ``max_fallback_retries``). The prompt is read from
    ``chat_prompt_massa.json`` inside the downloaded artifact named by
    ``wandb_run.config.chat_prompt_artifact`` (type ``prompt``).

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        vector_store (Chroma): A Chroma vector store object
        openai_api_key (str): The OpenAI API key passed to the chat model
    Returns:
        ConversationalRetrievalChain: A "stuff"-type chain that also returns
            its source documents
    """
    retriever = vector_store.as_retriever()
    llm = ChatOpenAI(
        openai_api_key=openai_api_key,
        model_name=wandb_run.config.model_name,
        temperature=wandb_run.config.chat_temperature,
        max_retries=wandb_run.config.max_fallback_retries,
    )
    chat_prompt_dir = wandb_run.use_artifact(
        wandb_run.config.chat_prompt_artifact, type="prompt"
    ).download()
    qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/chat_prompt_massa.json")

    # The original used '\\n', which printed a literal backslash-n instead of
    # a line break around the separator.
    print("\n===================\nqa_prompt = ", qa_prompt)

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )
    return qa_chain


def get_answer(
    chain: ConversationalRetrievalChain,
    question: str,
    chat_history: list[tuple[str, str]],
    wandb_run: wandb.run
) -> str:
    """Get an answer from a ConversationalRetrievalChain

    The question is logged to ``user_input.log`` and to Weights & Biases.
    When ``wandb_run.config.log_file`` is a non-empty string, the question
    and the answer are also appended to that file.

    Args:
        chain (ConversationalRetrievalChain): A ConversationalRetrievalChain object
        question (str): The question to ask
        chat_history (list[tuple[str, str]]): A list of tuples of (question, answer)
        wandb_run (wandb.run): An active Weights & Biases run whose config
            provides ``log_file``
    Returns:
        str: The answer to the question, prefixed with an ``Answer:`` label
            and a tab
    """
    # basicConfig only configures the root logger when it has no handlers yet,
    # so repeating the call on every question is a no-op after the first one.
    logging.basicConfig(
        filename="user_input.log",
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Log user question
    logger.info("User question: %s", question)
    wandb.log({"question": question})

    result = chain(
        inputs={"question": question, "chat_history": chat_history},
        return_only_outputs=True,
    )
    answer = result["answer"]
    response = f"Answer:\t{answer}"

    log_file = wandb_run.config.log_file
    # f-strings keep these prints from raising TypeError when a value is not
    # a str (the original concatenated with "+").
    print(f"file name{log_file}")

    if isinstance(log_file, str) and log_file:
        # NOTE(review): the original opened this file with "w" and wrote a
        # placeholder line plus an undefined name `L`; appending the
        # question/answer pair is assumed to be the intent - confirm.
        with open(log_file, "a", encoding="utf-8") as log_handle:
            log_handle.write(f"Question:\t{question}\n{response}\n")

    print(f"File writing complete.quest = {question} answer : {answer}")

    return response