isayahc committed on
Commit
1a34160
·
verified ·
1 Parent(s): 5d7a6c0

changed embedding model

Browse files
Files changed (2) hide show
  1. app.py +90 -23
  2. requirements.txt +257 -0
app.py CHANGED
@@ -14,10 +14,22 @@ from langchain.vectorstores import Chroma
14
  from langchain.chains import RetrievalQA
15
 
16
  from langchain.document_loaders import PyPDFLoader
 
 
 
 
 
17
 
18
 
19
  from langchain.embeddings import HuggingFaceHubEmbeddings, OpenAIEmbeddings
20
 
 
 
 
 
 
 
 
21
 
22
  text_splitter = CharacterTextSplitter(chunk_size=350, chunk_overlap=0)
23
 
@@ -27,6 +39,12 @@ flan_ul2 = OpenAI()
27
  global qa
28
 
29
  # embeddings = HuggingFaceHubEmbeddings()
 
 
 
 
 
 
30
 
31
 
32
 
@@ -34,48 +52,97 @@ global qa
def loading_pdf():
    """Return the fixed ``"Loading..."`` status string."""
    status_message = "Loading..."
    return status_message
def pdf_changes(pdf_doc):
    """Index a PDF and (re)build the global ``qa`` retrieval chain.

    The PDF is loaded with ``PyPDFLoader``, split with the module-level
    ``text_splitter``, embedded with ``OpenAIEmbeddings`` into a ``Chroma``
    store, and wrapped in a ``RetrievalQA`` "stuff" chain driven by the
    module-level ``flan_ul2`` LLM.

    Args:
        pdf_doc: Object whose ``.name`` attribute is handed to
            ``PyPDFLoader`` (presumably the uploaded file's path on disk;
            confirm against the caller).

    Returns:
        The status string ``"Ready"`` once the global ``qa`` has been set.
    """
    global qa

    embeddings = OpenAIEmbeddings()
    documents = PyPDFLoader(pdf_doc.name).load()
    texts = text_splitter.split_documents(documents)
    db = Chroma.from_documents(texts, embeddings)
    retriever = db.as_retriever()

    # The placeholders must be exactly the names listed in input_variables.
    # The original template used {query}, which is not declared below, so the
    # user's question could never be substituted into the prompt.
    prompt_template = (
        "You have been given a pdf or pdfs. You must search these pdfs.\n"
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer.\n"
        "Only answer the question.\n"
        "\n"
        "{context}\n"
        "\n"
        "Question: {question}\n"
        "Answer:"
    )
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )

    qa = RetrievalQA.from_chain_type(
        llm=flan_ul2,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )
    return "Ready"
65
 
def add_text(history, text):
    """Append *text* as a new, not-yet-answered chat turn.

    Args:
        history: List of ``(user_message, bot_reply)`` pairs so far.
        text: The message to append.

    Returns:
        A ``(new_history, "")`` tuple: a new list ending with
        ``(text, None)``, plus an empty string. The input list is not
        modified.
    """
    pending_turn = (text, None)
    extended = history + [pending_turn]
    return extended, ""
69
 
 
 
 
 
 
def bot(history):
    """Answer the most recent user message and record the reply.

    Args:
        history: Non-empty list of ``(user_message, bot_reply)`` pairs; the
            last pair holds the question to answer.

    Returns:
        The same ``history`` list, with its last turn replaced by
        ``[user_message, answer]`` where ``answer`` is the ``'result'``
        entry of the value returned by ``infer``.
    """
    question = history[-1][0]
    response = infer(question)
    # Replace the whole turn rather than assigning to index 1: add_text()
    # appends tuples, and item assignment on a tuple raises TypeError.
    history[-1] = [question, response["result"]]
    return history
 
74
 
def infer(question):
    """Call the global ``qa`` chain with *question* under the ``"query"`` key.

    Returns whatever ``qa`` returns, unmodified.
    """
    chain_inputs = {"query": question}
    return qa(chain_inputs)
81
 
 
14
  from langchain.chains import RetrievalQA
15
 
16
  from langchain.document_loaders import PyPDFLoader
17
+ from langchain.memory import VectorStoreRetrieverMemory
18
+
19
+ from langchain.chains import RetrievalQAWithSourcesChain
20
+ from langchain.memory import ConversationBufferMemory
21
+ from langchain.embeddings import CohereEmbeddings
22
 
23
 
24
  from langchain.embeddings import HuggingFaceHubEmbeddings, OpenAIEmbeddings
25
 
26
+ import dotenv
27
+
28
+ import os
29
+
30
+ dotenv.load_dotenv()
31
+
32
+
33
 
34
  text_splitter = CharacterTextSplitter(chunk_size=350, chunk_overlap=0)
35
 
 
39
  global qa
40
 
41
  # embeddings = HuggingFaceHubEmbeddings()
42
+ COHERE_API_KEY = os.getenv("COHERE_API_KEY")
43
+ embeddings = CohereEmbeddings(
44
+ model="embed-english-light-v3.0",
45
+ cohere_api_key=COHERE_API_KEY
46
+ )
47
+
48
 
49
 
50
 
 
def loading_pdf():
    """Return the fixed ``"Loading..."`` status string."""
    status_message = "Loading..."
    return status_message
def pdf_changes(pdf_doc):
    """Index a PDF and (re)build the global ``qa`` chain with chat memory.

    The PDF is loaded with ``PyPDFLoader``, split with the module-level
    ``text_splitter``, embedded into a ``Chroma`` store, and wrapped in a
    ``RetrievalQAWithSourcesChain`` driven by the module-level ``flan_ul2``
    LLM. A ``ConversationBufferMemory`` is attached to the inner documents
    chain so earlier turns can be rendered into the prompt.

    Args:
        pdf_doc: Object whose ``.name`` attribute is handed to
            ``PyPDFLoader`` (presumably the uploaded file's path on disk;
            confirm against the caller).

    Returns:
        The status string ``"Ready"`` once the global ``qa`` has been set.
    """
    global qa

    documents = PyPDFLoader(pdf_doc.name).load()
    texts = text_splitter.split_documents(documents)
    # Use the module-level Cohere ``embeddings`` (same model, already given
    # the API key) instead of building a second, key-less client per upload.
    db = Chroma.from_documents(texts, embeddings)
    retriever = db.as_retriever()

    # Placeholder names are tied together in three places: the
    # PromptTemplate input_variables, the memory's memory_key/input_key, and
    # the chain's document_variable_name. They must all agree.
    template = (
        "\n"
        "You are the friendly documentation buddy Arti, who helps the Human "
        "in using RAY, the open-source unified framework for scaling AI and "
        "Python applications. "
        "Use the following context (delimited by <ctx></ctx>) and the chat "
        "history (delimited by <hs></hs>) to answer the question :\n"
        "------\n"
        "<ctx>\n"
        "{context}\n"
        "</ctx>\n"
        "------\n"
        "<hs>\n"
        "{history}\n"
        "</hs>\n"
        "------\n"
        "{question}\n"
        "Answer:\n"
    )
    prompt = PromptTemplate(
        input_variables=["history", "context", "question"],
        template=template,
    )
    memory = ConversationBufferMemory(memory_key="history", input_key="question")

    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=flan_ul2,
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
        chain_type_kwargs={
            "verbose": True,
            "memory": memory,
            "prompt": prompt,
            "document_variable_name": "context",
        },
    )
    return "Ready"
124
 
def add_text(history, text):
    """Append *text* as a new, not-yet-answered chat turn.

    Args:
        history: List of ``(user_message, bot_reply)`` pairs so far.
        text: The message to append.

    Returns:
        A ``(new_history, "")`` tuple: a new list ending with
        ``(text, None)``, plus an empty string. The input list is not
        modified.
    """
    pending_turn = (text, None)
    extended = history + [pending_turn]
    return extended, ""
128
 
129
+ # def bot(history):
130
+ # response = infer(history[-1][0])
131
+ # history[-1][1] = response['result']
132
+ # return history
133
+
def bot(history):
    """Answer the most recent user message and record the reply with sources.

    Args:
        history: Non-empty list of ``(user_message, bot_reply)`` pairs; the
            last pair holds the question to answer.

    Returns:
        The same ``history`` list, with its last turn replaced by
        ``[user_message, reply]``. ``reply`` is the ``'answer'`` entry of
        the value returned by ``infer`` followed by a "Sources" section
        built from the ``source`` metadata of ``'source_documents'``.
    """
    question = history[-1][0]
    response = infer(question, history)

    # Skip documents without a "source" entry (str.join would raise on None)
    # and list each source once, in first-seen order.
    sources = dict.fromkeys(
        str(doc.metadata.get("source"))
        for doc in response["source_documents"]
        if doc.metadata.get("source") is not None
    )
    src_list = "\n".join(sources)
    print_this = response["answer"] + "\n\n\n Sources: \n\n\n" + src_list

    # The original built print_this but never stored or returned it, so the
    # caller received None. Replace the whole turn (it may be a tuple, which
    # does not support item assignment) and hand the history back.
    history[-1] = [question, print_this]
    return history
139
 
def infer(question, history):
    """Call the global ``qa`` chain for *question* and return its raw result.

    The chain receives *question* under both the ``"query"`` and
    ``"question"`` keys, and *history* under ``"history"``.
    """
    chain_inputs = {"query": question, "history": history, "question": question}
    return qa(chain_inputs)
148
 
requirements.txt CHANGED
@@ -500,3 +500,260 @@ xkit==0.0.0
500
  yarl==1.9.2
501
  yaspin==3.0.1
502
  zipp==1.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  yarl==1.9.2
501
  yaspin==3.0.1
502
  zipp==1.0.0
503
+ aiofiles==23.2.1
504
+ aiohttp==3.9.1
505
+ aiosignal==1.3.1
506
+ aiostream==0.5.2
507
+ altair==5.1.2
508
+ annotated-types==0.5.0
509
+ anyio==3.7.1
510
+ appdirs==1.4.4
511
+ argcomplete==1.8.1
512
+ astor==0.8.1
513
+ asttokens==2.4.1
514
+ async-timeout==4.0.3
515
+ asyncer==0.0.2
516
+ attrs==23.1.0
517
+ auth0-python==4.4.2
518
+ Babel==2.8.0
519
+ backoff==2.2.1
520
+ beautiful-date==2.2.1
521
+ beautifulsoup4==4.12.2
522
+ bidict==0.22.1
523
+ blessed==1.20.0
524
+ blinker==1.4
525
+ Brotli==1.0.9
526
+ CacheControl==0.12.10
527
+ cachetools==5.3.1
528
+ cachy==0.3.0
529
+ certifi==2023.7.22
530
+ cffi==1.15.1
531
+ chardet==4.0.0
532
+ charset-normalizer==3.2.0
533
+ cleo==0.8.1
534
+ click==8.1.7
535
+ clikit==0.6.2
536
+ cohere==4.37
537
+ colorama==0.4.4
538
+ comm==0.2.0
539
+ command-not-found==0.3
540
+ contourpy==1.2.0
541
+ crashtest==0.3.1
542
+ cryptography==41.0.3
543
+ cycler==0.12.1
544
+ dataclasses-json==0.5.14
545
+ dbus-python==1.2.18
546
+ debugpy==1.8.0
547
+ decorator==5.1.1
548
+ Deprecated==1.2.14
549
+ distlib==0.3.4
550
+ distro==1.7.0
551
+ distro-info==1.1+ubuntu0.1
552
+ exceptiongroup==1.1.3
553
+ executing==2.0.1
554
+ fastapi==0.104.1
555
+ fastapi-socketio==0.0.10
556
+ fastavro==1.9.1
557
+ ffmpy==0.3.1
558
+ filelock==3.6.0
559
+ filetype==1.2.0
560
+ fonttools==4.44.3
561
+ frozenlist==1.4.0
562
+ fsspec==2023.10.0
563
+ gcsa==2.1.0
564
+ gdown==4.7.1
565
+ git-python==1.0.3
566
+ gitdb==4.0.11
567
+ GitPython==3.1.40
568
+ google-api-core==2.11.1
569
+ google-api-python-client==2.99.0
570
+ google-auth==2.23.0
571
+ google-auth-httplib2==0.1.1
572
+ google-auth-oauthlib==0.8.0
573
+ googleapis-common-protos==1.60.0
574
+ gradio==4.4.1
575
+ gradio_client==0.7.0
576
+ graphviz==0.14.2
577
+ greenlet==2.0.2
578
+ grpcio==1.58.0
579
+ gyp==0.1
580
+ h11==0.14.0
581
+ html2image==2.0.4.3
582
+ html5lib==1.1
583
+ httpcore==0.18.0
584
+ httplib2==0.20.2
585
+ httpx==0.25.0
586
+ huggingface-hub==0.19.4
587
+ idna==3.4
588
+ importlib-metadata==6.8.0
589
+ importlib-resources==6.1.1
590
+ inquirer==3.1.4
591
+ ipykernel==6.26.0
592
+ ipython==8.18.0
593
+ jedi==0.19.1
594
+ jeepney==0.7.1
595
+ Jinja2==3.1.2
596
+ joblib==1.3.2
597
+ jsonschema==4.19.2
598
+ jsonschema-specifications==2023.7.1
599
+ jupyter_client==8.6.0
600
+ jupyter_core==5.5.0
601
+ keyring==21.8.0
602
+ kiwisolver==1.4.5
603
+ langchain==0.0.281
604
+ langsmith==0.0.33
605
+ launchpadlib==1.10.16
606
+ Lazify==0.4.0
607
+ lazr.restfulclient==0.14.4
608
+ lazr.uri==1.0.6
609
+ litellm==0.13.2
610
+ livereload==2.6.3
611
+ llama-index==0.9.13
612
+ lockfile==0.12.2
613
+ Markdown==3.3.6
614
+ markdown-it-py==3.0.0
615
+ MarkupSafe==2.0.1
616
+ marshmallow==3.20.1
617
+ matplotlib==3.8.2
618
+ matplotlib-inline==0.1.6
619
+ mdurl==0.1.2
620
+ mkdocs==1.1.2
621
+ more-itertools==8.10.0
622
+ msgpack==1.0.3
623
+ multidict==6.0.4
624
+ mutagen==1.45.1
625
+ mypy-extensions==1.0.0
626
+ nest-asyncio==1.5.8
627
+ netifaces==0.11.0
628
+ nltk==3.8.1
629
+ nodeenv==1.8.0
630
+ numexpr==2.8.5
631
+ numpy==1.25.2
632
+ oauthlib==3.2.0
633
+ open-interpreter==0.1.15
634
+ openai==1.3.8
635
+ openapi-schema-pydantic==1.2.4
636
+ opentelemetry-api==1.20.0
637
+ opentelemetry-exporter-otlp==1.20.0
638
+ opentelemetry-exporter-otlp-proto-common==1.20.0
639
+ opentelemetry-exporter-otlp-proto-grpc==1.20.0
640
+ opentelemetry-exporter-otlp-proto-http==1.20.0
641
+ opentelemetry-instrumentation==0.40b0
642
+ opentelemetry-proto==1.20.0
643
+ opentelemetry-sdk==1.20.0
644
+ opentelemetry-semantic-conventions==0.41b0
645
+ orjson==3.9.10
646
+ packaging==20.9
647
+ pandas==2.1.3
648
+ parso==0.8.3
649
+ pastel==0.2.1
650
+ pexpect==4.8.0
651
+ Pillow==10.1.0
652
+ pipdeptree==2.2.0
653
+ pkginfo==1.8.2
654
+ platformdirs==2.5.1
655
+ poetry==1.1.12
656
+ poetry-core==1.0.7
657
+ prisma==0.10.0
658
+ prompt-toolkit==3.0.41
659
+ protobuf==4.24.3
660
+ psutil==5.9.6
661
+ ptyprocess==0.7.0
662
+ pure-eval==0.2.2
663
+ pyarrow==14.0.1
664
+ pyasn1==0.5.0
665
+ pyasn1-modules==0.3.0
666
+ pycparser==2.21
667
+ pycryptodomex==3.11.0
668
+ pydantic==2.5.2
669
+ pydantic_core==2.14.5
670
+ pydeck==0.8.1b0
671
+ pydub==0.25.1
672
+ PyGithub==2.1.1
673
+ Pygments==2.16.1
674
+ PyGObject==3.42.1
675
+ pyinotify==0.9.6
676
+ PyJWT==2.8.0
677
+ pylev==1.2.0
678
+ PyNaCl==1.5.0
679
+ pyOpenSSL==23.2.0
680
+ pyparsing==2.4.7
681
+ PySocks==1.7.1
682
+ python-apt==2.4.0+ubuntu2
683
+ python-dateutil==2.8.2
684
+ python-dotenv==1.0.0
685
+ python-editor==1.0.4
686
+ python-engineio==4.7.0
687
+ python-graphql-client==0.4.3
688
+ python-multipart==0.0.6
689
+ python-socketio==5.9.0
690
+ pytz==2022.1
691
+ pytz-deprecation-shim==0.1.0.post0
692
+ pyxattr==0.7.2
693
+ PyYAML==6.0.1
694
+ pyzmq==25.1.1
695
+ readchar==4.0.5
696
+ referencing==0.30.2
697
+ regex==2023.10.3
698
+ requests==2.31.0
699
+ requests-oauthlib==1.3.1
700
+ requests-toolbelt==0.9.1
701
+ rich==13.6.0
702
+ rpds-py==0.12.0
703
+ rsa==4.9
704
+ screen-resolution-extra==0.0.0
705
+ SecretStorage==3.3.1
706
+ semantic-version==2.10.0
707
+ shellingham==1.4.0
708
+ six==1.16.0
709
+ smmap==5.0.1
710
+ sniffio==1.3.0
711
+ soupsieve==2.5
712
+ speedtest-cli==2.1.3
713
+ SQLAlchemy==2.0.20
714
+ stack-data==0.6.3
715
+ starlette==0.27.0
716
+ streamlit==1.28.1
717
+ syncer==2.0.3
718
+ systemd-python==234
719
+ tenacity==8.2.3
720
+ termcolor==2.3.0
721
+ tiktoken==0.4.0
722
+ tokenizers==0.15.0
723
+ tokentrim==0.1.13
724
+ toml==0.10.2
725
+ tomli==2.0.1
726
+ tomlkit==0.12.0
727
+ toolz==0.12.0
728
+ tornado==6.3.3
729
+ tqdm==4.66.1
730
+ traitlets==5.13.0
731
+ typer==0.9.0
732
+ typing-inspect==0.9.0
733
+ typing_extensions==4.8.0
734
+ tzdata==2023.3
735
+ tzlocal==4.3.1
736
+ ubuntu-advantage-tools==8001
737
+ ubuntu-drivers-common==0.0.0
738
+ ufw==0.36.1
739
+ unattended-upgrades==0.1
740
+ uptrace==1.20.0
741
+ uritemplate==4.1.1
742
+ urllib3==2.0.4
743
+ userpath==1.8.0
744
+ uvicorn==0.23.2
745
+ validators==0.22.0
746
+ virtualenv==20.13.0+ds
747
+ wadllib==1.3.6
748
+ watchdog==3.0.0
749
+ watchfiles==0.20.0
750
+ wcwidth==0.2.12
751
+ webencodings==0.5.1
752
+ websocket-client==1.6.4
753
+ websockets==11.0.3
754
+ wget==3.2
755
+ wrapt==1.15.0
756
+ xkit==0.0.0
757
+ yarl==1.9.2
758
+ yaspin==3.0.1
759
+ zipp==1.0.0