Update app.py
app.py CHANGED
@@ -1,61 +1,64 @@
 import os
 import json
-import torch
 import gradio as gr
 from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 from pinecone import Pinecone
-from transformers import T5Tokenizer, T5ForConditionalGeneration

-#
+# Read tokens and environment variables
 HF_TOKEN = os.environ.get("HF_TOKEN")
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
 PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")

+assert HF_TOKEN is not None, "❌ HF_TOKEN is missing!"
+assert PINECONE_API_KEY is not None, "❌ Pinecone API key is missing!"
+assert PINECONE_INDEX_NAME is not None, "❌ Pinecone index name is missing!"
+
 # Load the embedding model
 embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", use_auth_token=HF_TOKEN)

+# Load the language model (e.g. MT5 or Gemma)
+model_name = "google/mt5-small"
+tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
+language_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)
+
 # Connect to Pinecone
 pc = Pinecone(api_key=PINECONE_API_KEY)
 index = pc.Index(PINECONE_INDEX_NAME)

-#
-tokenizer = T5Tokenizer.from_pretrained("google/mt5-small", token=HF_TOKEN)
-language_model = T5ForConditionalGeneration.from_pretrained("google/mt5-small", token=HF_TOKEN)
-
-# Pinecone search function
+# Fetch the initial answer from Pinecone
 def retrieve_answer(query, threshold=0.65, top_k=3):
     query_embedding = embedding_model.encode([query])[0]
     result = index.query(vector=query_embedding.tolist(), top_k=top_k, include_metadata=True)

     if result['matches'] and result['matches'][0]['score'] > threshold:
-
-        metadata = result['matches'][0]['metadata']
-        return metadata.get('answer', 'پاسخی یافت نشد.')
+        return result['matches'][0]['metadata'].get('answer', "پاسخی یافت نشد.")
     else:
-        return "متأسفم، …
-
-# Generate a natural answer with MT5
-def generate_natural_answer(question, raw_answer):
-    prompt = f"پرسش: {question}\nپاسخ دقیق: {raw_answer}\nپاسخ طبیعی:"
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(language_model.device)
+        return "متأسفم، پاسخی برای این سوال در پایگاه داده یافت نشد."

+# Rewrite the answer in natural language with the language model
+def rewrite_answer(question, raw_answer):
+    input_text = f"پرسش: {question}\nپاسخ: {raw_answer}"
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
     with torch.no_grad():
-        outputs = language_model.generate(**inputs, max_new_tokens=…
-
+        outputs = language_model.generate(**inputs, max_new_tokens=100)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)

-# Wire everything together in the Gradio interface
-def chat_interface(user_question):
-    raw_answer = retrieve_answer(user_question)
-    return generate_natural_answer(user_question, raw_answer)
-
 # Gradio interface
+def chat_interface(question):
+    raw_answer = retrieve_answer(question)
+    if "متأسفم" in raw_answer:
+        return raw_answer
+    final_answer = rewrite_answer(question, raw_answer)
+    return final_answer
+
 demo = gr.Interface(
     fn=chat_interface,
     inputs="text",
     outputs="text",
     title="چت‌بات تیام",
-    description="…
+    description="پرسش خود را درباره خدمات دیجیتال مارکتینگ تیام بپرسید."
 )

 demo.launch()
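
The app assumes a Pinecone index that already holds question embeddings with the reference answer stored under the "answer" metadata key, which is what retrieve_answer reads back. The commit does not show how that index is populated; the following is a minimal ingestion sketch compatible with this app, not part of the commit. The faq_pairs data and the "faq-{i}" ID scheme are made up for illustration, and the index is assumed to have been created with dimension 384 (the output size of paraphrase-multilingual-MiniLM-L12-v2) and a cosine metric.

# Hypothetical ingestion sketch: populate the index this app queries.
import os
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer

pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
index = pc.Index(os.environ["PINECONE_INDEX_NAME"])
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Made-up Q&A pairs; the real data source is not shown in the commit.
faq_pairs = [
    ("چه خدماتی ارائه می‌دهید؟", "تیام خدمات سئو، تبلیغات کلیکی و تولید محتوا ارائه می‌دهد."),
]

# Embed each question and store its answer as metadata, mirroring what
# retrieve_answer expects to find under metadata['answer'].
index.upsert(vectors=[
    (f"faq-{i}", model.encode(q).tolist(), {"answer": a})
    for i, (q, a) in enumerate(faq_pairs)
])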
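
One small caveat: recent sentence-transformers releases (2.3 and later, if memory serves) deprecate the use_auth_token argument in favor of token, matching the from_pretrained calls added in this commit. If the Space logs a deprecation warning, the equivalent load would be:

# Equivalent load under newer sentence-transformers (assumption: a version
# where use_auth_token is deprecated in favor of token).
embedding_model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    token=HF_TOKEN,
)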
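
Also worth flagging: google/mt5-small is pretrained only with a span-corruption objective and is not instruction-tuned, so without fine-tuning, rewrite_answer may return empty text or output littered with <extra_id_*> sentinel tokens. A defensive wrapper, sketched here as an optional addition rather than part of the commit, would fall back to the retrieved answer:

# Optional guard (not in the commit): fall back to the raw retrieved answer
# when the untuned mt5-small output looks degenerate.
def safe_rewrite(question, raw_answer):
    text = rewrite_answer(question, raw_answer)
    if not text.strip() or "<extra_id" in text:
        return raw_answer
    return text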
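
Finally, the fallback logic is easy to exercise end to end: retrieve_answer only returns a stored answer when the top match scores above the 0.65 similarity threshold; otherwise it returns the apology string, which chat_interface detects by substring and passes through unchanged instead of sending it to the language model. A quick local smoke test, assuming the three environment variables are set and using placeholder queries:

# Smoke test (hypothetical queries): run the pipeline from a shell,
# bypassing the Gradio UI. The second query should miss the index and
# exercise the fallback branch.
if __name__ == "__main__":
    for q in ["چه خدماتی ارائه می‌دهید؟", "پایتخت فرانسه کجاست؟"]:
        print(q, "->", chat_interface(q))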