# Author: Rohan Kataria
# Commit 3b75bce: "new change" (file size 1.74 kB)
import os
from typing import Any, Optional, Tuple
from langchain.chains import ConversationChain
from langchain.llms import HuggingFaceHub
from langchain.llms import OpenAI
from threading import Lock
def load_chain_openai(api_key: str):
    """Build a ConversationChain backed by an OpenAI completion model.

    The key is handed straight to the ``OpenAI`` wrapper rather than routed
    through ``os.environ["OPENAI_API_KEY"]``. The environment is shared by
    the whole process, so setting and then blanking it would clobber any
    pre-existing value. It would also leave the key exposed if construction
    raised.

    Args:
        api_key: OpenAI API key used by this chain.

    Returns:
        A ``ConversationChain`` wrapping ``OpenAI(temperature=0)``.
    """
    llm = OpenAI(temperature=0, openai_api_key=api_key)
    return ConversationChain(llm=llm)
def load_chain_falcon(api_key: str):
    """Build a ConversationChain backed by Falcon-7B-Instruct on the HF Hub.

    The token is passed directly to ``HuggingFaceHub`` instead of being
    written to ``os.environ["HUGGINGFACEHUB_API_TOKEN"]``. That avoids
    overwriting a pre-existing value and avoids leaking the token if
    construction raises.

    Args:
        api_key: Hugging Face Hub API token used by this chain.

    Returns:
        A ``ConversationChain`` wrapping ``tiiuae/falcon-7b-instruct`` with
        ``temperature=0.9``.
    """
    llm = HuggingFaceHub(
        repo_id="tiiuae/falcon-7b-instruct",
        model_kwargs={"temperature": 0.9},
        huggingfacehub_api_token=api_key,
    )
    return ConversationChain(llm=llm)
class ChatWrapper:
    """Callable chat session that forwards user input to a LangChain chain.

    Each call appends an ``(input, reply)`` pair to ``self.history`` and
    returns that same list, so the caller always receives the full
    transcript. Calls are serialized with ``self.lock``.

    Args:
        chain_type: Backend to use, either ``'openai'`` or ``'falcon'``.
        api_key: Credential for the chosen backend. When empty, no chain is
            built and every call replies with a prompt to add a key.

    Raises:
        ValueError: If ``chain_type`` is not a supported backend.
    """

    # Backends accepted for ``chain_type``; keep in sync with the dispatch
    # in ``__init__``.
    _CHAIN_TYPES = ('openai', 'falcon')

    def __init__(self, chain_type: str, api_key: str = ''):
        # Validate up front so a bad chain_type fails immediately, not only
        # once an API key is eventually supplied.
        if chain_type not in self._CHAIN_TYPES:
            raise ValueError(f'Invalid chain_type: {chain_type}')
        self.api_key = api_key
        self.chain_type = chain_type
        self.history = []  # list of (user input, reply) tuples
        self.lock = Lock()
        if not api_key:
            # No key yet: __call__ answers with a prompt to add one.
            self.chain = None
        elif chain_type == 'openai':
            self.chain = load_chain_openai(api_key)
        else:
            self.chain = load_chain_falcon(api_key)

    def __call__(self, inp: str):
        """Run ``inp`` through the chain and return the updated history.

        Exceptions raised by the chain are not propagated. They are recorded
        in the history as ``"An error occurred: <error>"``, so the failure
        appears in the transcript instead.

        Args:
            inp: The user's message.

        Returns:
            ``self.history``, the same list object on every call.
        """
        # ``with`` releases the lock on every path, including the early
        # "no API key" reply. Without it, that reply returns while still
        # holding the lock and the next call deadlocks.
        with self.lock:
            if self.chain is None:
                reply = "Please add your API key to proceed."
            else:
                try:
                    reply = self.chain.run(input=inp)
                except Exception as e:  # UI boundary: report, don't crash
                    reply = f"An error occurred: {e}"
            self.history.append((inp, reply))
            return self.history