diginoron commited on
Commit
7279042
·
verified ·
1 Parent(s): 012badc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -1,61 +1,64 @@
1
  import os
2
  import json
3
- import torch
4
  import gradio as gr
5
  from sentence_transformers import SentenceTransformer
 
 
6
  from pinecone import Pinecone
7
- from transformers import T5Tokenizer, T5ForConditionalGeneration
8
 
9
- # بارگذاری توکن‌ها از محیط امن
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
11
  PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
12
  PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
13
 
 
 
 
 
14
  # بارگذاری مدل embedding
15
  embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", use_auth_token=HF_TOKEN)
16
 
 
 
 
 
 
17
  # اتصال به Pinecone
18
  pc = Pinecone(api_key=PINECONE_API_KEY)
19
  index = pc.Index(PINECONE_INDEX_NAME)
20
 
21
- # بارگذاری مدل زبانی MT5
22
- tokenizer = T5Tokenizer.from_pretrained("google/mt5-small", token=HF_TOKEN)
23
- language_model = T5ForConditionalGeneration.from_pretrained("google/mt5-small", token=HF_TOKEN)
24
-
25
- # تابع جستجو در Pinecone
26
  def retrieve_answer(query, threshold=0.65, top_k=3):
27
  query_embedding = embedding_model.encode([query])[0]
28
  result = index.query(vector=query_embedding.tolist(), top_k=top_k, include_metadata=True)
29
 
30
  if result['matches'] and result['matches'][0]['score'] > threshold:
31
- print(f"📊 Similarity: {result['matches'][0]['score']:.3f}")
32
- metadata = result['matches'][0]['metadata']
33
- return metadata.get('answer', 'پاسخی یافت نشد.')
34
  else:
35
- return "متأسفم، پاسخ دقیقی برای این سوال نداریم. لطفاً با ما تماس بگیرید."
36
-
37
- # تابع تولید پاسخ طبیعی با MT5
38
- def generate_natural_answer(question, raw_answer):
39
- prompt = f"پرسش: {question}\nپاسخ دقیق: {raw_answer}\nپاسخ طبیعی:"
40
- inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(language_model.device)
41
 
 
 
 
 
42
  with torch.no_grad():
43
- outputs = language_model.generate(**inputs, max_new_tokens=128, do_sample=False)
44
-
45
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
46
 
47
- # اتصال همه‌چیز در رابط Gradio
48
- def chat_interface(user_question):
49
- raw_answer = retrieve_answer(user_question)
50
- return generate_natural_answer(user_question, raw_answer)
51
-
52
  # رابط Gradio
 
 
 
 
 
 
 
53
  demo = gr.Interface(
54
  fn=chat_interface,
55
  inputs="text",
56
  outputs="text",
57
  title="چت‌بات تیام",
58
- description="سؤالات خود را از آژانس دیجیتال مارکتینگ تیام بپرسید."
59
  )
60
 
61
  demo.launch()
 
1
  import os
2
  import json
 
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+ import torch
7
  from pinecone import Pinecone
 
8
 
9
+ # دریافت توکن‌ها و متغیرهای محیطی
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
11
  PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
12
  PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
13
 
14
+ assert HF_TOKEN is not None, "❌ HF_TOKEN is missing!"
15
+ assert PINECONE_API_KEY is not None, "❌ Pinecone API key is missing!"
16
+ assert PINECONE_INDEX_NAME is not None, "❌ Pinecone index name is missing!"
17
+
18
  # بارگذاری مدل embedding
19
  embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", use_auth_token=HF_TOKEN)
20
 
21
+ # بارگذاری مدل زبانی (مثل MT5 یا Gemma)
22
+ model_name = "google/mt5-small"
23
+ tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
24
+ language_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)
25
+
26
  # اتصال به Pinecone
27
  pc = Pinecone(api_key=PINECONE_API_KEY)
28
  index = pc.Index(PINECONE_INDEX_NAME)
29
 
30
+ # دریافت پاسخ اولیه از Pinecone
 
 
 
 
31
  def retrieve_answer(query, threshold=0.65, top_k=3):
32
  query_embedding = embedding_model.encode([query])[0]
33
  result = index.query(vector=query_embedding.tolist(), top_k=top_k, include_metadata=True)
34
 
35
  if result['matches'] and result['matches'][0]['score'] > threshold:
36
+ return result['matches'][0]['metadata'].get('answer', "پاسخی یافت نشد.")
 
 
37
  else:
38
+ return "متأسفم، پاسخی برای این سوال در پایگاه داده یافت نشد."
 
 
 
 
 
39
 
40
+ # بازنویسی پاسخ به صورت طبیعی با مدل زبانی
41
+ def rewrite_answer(question, raw_answer):
42
+ input_text = f"پرسش: {question}\nپاسخ: {raw_answer}"
43
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
44
  with torch.no_grad():
45
+ outputs = language_model.generate(**inputs, max_new_tokens=100)
 
46
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
47
 
 
 
 
 
 
48
  # رابط Gradio
49
+ def chat_interface(question):
50
+ raw_answer = retrieve_answer(question)
51
+ if "متأسفم" in raw_answer:
52
+ return raw_answer
53
+ final_answer = rewrite_answer(question, raw_answer)
54
+ return final_answer
55
+
56
  demo = gr.Interface(
57
  fn=chat_interface,
58
  inputs="text",
59
  outputs="text",
60
  title="چت‌بات تیام",
61
+ description="پرسش خود را درباره خدمات دیجیتال مارکتینگ تیام بپرسید."
62
  )
63
 
64
  demo.launch()