Kieran Gookey committed on
Commit
6ad144b
·
1 Parent(s): 277b244

Set a different embedding model

Browse files
Files changed (1) hide show
  1. app.py +111 -69
app.py CHANGED
@@ -10,104 +10,146 @@ from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
10
 
11
  inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
12
 
13
- embed_model_name = st.text_input(
14
- 'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
15
 
16
- llm_model_name = st.text_input(
17
- 'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  query = st.text_input(
20
- 'Query', "What is the price of the product?")
 
21
 
22
  html_file = st.file_uploader("Upload a html file", type=["html"])
23
 
24
- if st.button('Start Pipeline'):
25
- if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
26
- st.write('Running Pipeline')
27
- llm = HuggingFaceInferenceAPI(
28
- model_name=llm_model_name, token=inference_api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- embed_model = HuggingFaceInferenceAPIEmbedding(
31
- model_name=embed_model_name,
32
- token=inference_api_key,
33
- model_kwargs={"device": ""},
34
- encode_kwargs={"normalize_embeddings": True},
35
- )
36
 
37
- service_context = ServiceContext.from_defaults(
38
- embed_model=embed_model, llm=llm)
39
 
40
- stringio = StringIO(html_file.getvalue().decode("utf-8"))
41
- string_data = stringio.read()
42
- with st.expander("Uploaded HTML"):
43
- st.write(string_data)
44
 
45
- document_id = str(uuid.uuid4())
46
 
47
- document = Document(text=string_data)
48
- document.metadata["id"] = document_id
49
- documents = [document]
50
 
51
- filters = MetadataFilters(
52
- filters=[ExactMatchFilter(key="id", value=document_id)])
53
 
54
- index = VectorStoreIndex.from_documents(
55
- documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
56
 
57
- retriever = index.as_retriever()
58
 
59
- ranked_nodes = retriever.retrieve(
60
- query)
61
 
62
- with st.expander("Ranked Nodes"):
63
- for node in ranked_nodes:
64
- st.write(node.node.get_content(), "-> Score:", node.score)
65
 
66
- query_engine = index.as_query_engine(
67
- filters=filters, service_context=service_context)
68
 
69
- response = query_engine.query(query)
70
 
71
- st.write(response.response)
72
 
73
- st.write(response.source_nodes)
74
 
75
- else:
76
- st.error('Please fill in all the fields')
77
- else:
78
- st.write('Press start to begin')
79
 
80
- # if html_file is not None:
81
- # stringio = StringIO(html_file.getvalue().decode("utf-8"))
82
- # string_data = stringio.read()
83
- # with st.expander("Uploaded HTML"):
84
- # st.write(string_data)
85
 
86
- # document_id = str(uuid.uuid4())
87
 
88
- # document = Document(text=string_data)
89
- # document.metadata["id"] = document_id
90
- # documents = [document]
91
 
92
- # filters = MetadataFilters(
93
- # filters=[ExactMatchFilter(key="id", value=document_id)])
94
 
95
- # index = VectorStoreIndex.from_documents(
96
- # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
97
 
98
- # retriever = index.as_retriever()
99
 
100
- # ranked_nodes = retriever.retrieve(
101
- # "Get me all the information about the product")
102
 
103
- # with st.expander("Ranked Nodes"):
104
- # for node in ranked_nodes:
105
- # st.write(node.node.get_content(), "-> Score:", node.score)
106
 
107
- # query_engine = index.as_query_engine(
108
- # filters=filters, service_context=service_context)
109
 
110
- # response = query_engine.query(
111
- # "Get me all the information about the product")
112
 
113
- # st.write(response)
 
10
 
11
  inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
12
 
13
+ # embed_model_name = st.text_input(
14
+ # 'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
15
 
16
+ # llm_model_name = st.text_input(
17
+ # 'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")
18
+
19
+ embed_model_name = "jinaai/jina-embedding-s-en-v1"
20
+ llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
21
+
22
+ llm = HuggingFaceInferenceAPI(
23
+ model_name=llm_model_name, token=inference_api_key)
24
+
25
+ embed_model = HuggingFaceInferenceAPIEmbedding(
26
+ model_name=embed_model_name,
27
+ token=inference_api_key,
28
+ model_kwargs={"device": ""},
29
+ encode_kwargs={"normalize_embeddings": True},
30
+ )
31
+
32
+ service_context = ServiceContext.from_defaults(
33
+ embed_model=embed_model, llm=llm)
34
 
35
  query = st.text_input(
36
+ 'Query', "What is the price of the product?"
37
+ )
38
 
39
  html_file = st.file_uploader("Upload a html file", type=["html"])
40
 
41
+ if html_file is not None:
42
+ stringio = StringIO(html_file.getvalue().decode("utf-8"))
43
+ string_data = stringio.read()
44
+ with st.expander("Uploaded HTML"):
45
+ st.write(string_data)
46
+
47
+ document_id = str(uuid.uuid4())
48
+
49
+ document = Document(text=string_data)
50
+ document.metadata["id"] = document_id
51
+ documents = [document]
52
+
53
+ filters = MetadataFilters(
54
+ filters=[ExactMatchFilter(key="id", value=document_id)])
55
+
56
+ index = VectorStoreIndex.from_documents(
57
+ documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
58
+
59
+ query_engine = index.as_query_engine(
60
+ filters=filters, service_context=service_context)
61
+
62
+ response = query_engine.query(query)
63
+
64
+ st.write(response.response)
65
+
66
+ # if st.button('Start Pipeline'):
67
+ # if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
68
+ # st.write('Running Pipeline')
69
+ # llm = HuggingFaceInferenceAPI(
70
+ # model_name=llm_model_name, token=inference_api_key)
71
 
72
+ # embed_model = HuggingFaceInferenceAPIEmbedding(
73
+ # model_name=embed_model_name,
74
+ # token=inference_api_key,
75
+ # model_kwargs={"device": ""},
76
+ # encode_kwargs={"normalize_embeddings": True},
77
+ # )
78
 
79
+ # service_context = ServiceContext.from_defaults(
80
+ # embed_model=embed_model, llm=llm)
81
 
82
+ # stringio = StringIO(html_file.getvalue().decode("utf-8"))
83
+ # string_data = stringio.read()
84
+ # with st.expander("Uploaded HTML"):
85
+ # st.write(string_data)
86
 
87
+ # document_id = str(uuid.uuid4())
88
 
89
+ # document = Document(text=string_data)
90
+ # document.metadata["id"] = document_id
91
+ # documents = [document]
92
 
93
+ # filters = MetadataFilters(
94
+ # filters=[ExactMatchFilter(key="id", value=document_id)])
95
 
96
+ # index = VectorStoreIndex.from_documents(
97
+ # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
98
 
99
+ # retriever = index.as_retriever()
100
 
101
+ # ranked_nodes = retriever.retrieve(
102
+ # query)
103
 
104
+ # with st.expander("Ranked Nodes"):
105
+ # for node in ranked_nodes:
106
+ # st.write(node.node.get_content(), "-> Score:", node.score)
107
 
108
+ # query_engine = index.as_query_engine(
109
+ # filters=filters, service_context=service_context)
110
 
111
+ # response = query_engine.query(query)
112
 
113
+ # st.write(response.response)
114
 
115
+ # st.write(response.source_nodes)
116
 
117
+ # else:
118
+ # st.error('Please fill in all the fields')
119
+ # else:
120
+ # st.write('Press start to begin')
121
 
122
+ # # if html_file is not None:
123
+ # # stringio = StringIO(html_file.getvalue().decode("utf-8"))
124
+ # # string_data = stringio.read()
125
+ # # with st.expander("Uploaded HTML"):
126
+ # # st.write(string_data)
127
 
128
+ # # document_id = str(uuid.uuid4())
129
 
130
+ # # document = Document(text=string_data)
131
+ # # document.metadata["id"] = document_id
132
+ # # documents = [document]
133
 
134
+ # # filters = MetadataFilters(
135
+ # # filters=[ExactMatchFilter(key="id", value=document_id)])
136
 
137
+ # # index = VectorStoreIndex.from_documents(
138
+ # # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
139
 
140
+ # # retriever = index.as_retriever()
141
 
142
+ # # ranked_nodes = retriever.retrieve(
143
+ # # "Get me all the information about the product")
144
 
145
+ # # with st.expander("Ranked Nodes"):
146
+ # # for node in ranked_nodes:
147
+ # # st.write(node.node.get_content(), "-> Score:", node.score)
148
 
149
+ # # query_engine = index.as_query_engine(
150
+ # # filters=filters, service_context=service_context)
151
 
152
+ # # response = query_engine.query(
153
+ # # "Get me all the information about the product")
154
 
155
+ # # st.write(response)