diginoron committed
Commit d95ee14 · verified · 1 Parent(s): 7279042

Update app.py

Files changed (1): app.py (+39, -50)
app.py CHANGED
@@ -1,12 +1,11 @@
  import os
- import json
- import gradio as gr
  from sentence_transformers import SentenceTransformer
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import torch
- from pinecone import Pinecone

- # Read tokens and environment variables
  HF_TOKEN = os.environ.get("HF_TOKEN")
  PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
  PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
@@ -15,50 +14,40 @@ assert HF_TOKEN is not None, "❌ HF_TOKEN is missing!"
  assert PINECONE_API_KEY is not None, "❌ Pinecone API key is missing!"
  assert PINECONE_INDEX_NAME is not None, "❌ Pinecone index name is missing!"

- # Load the embedding model
- embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", use_auth_token=HF_TOKEN)
-
- # Load the language model (e.g. MT5 or Gemma)
- model_name = "google/mt5-small"
- tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
- language_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)
-
- # Connect to Pinecone
- pc = Pinecone(api_key=PINECONE_API_KEY)
- index = pc.Index(PINECONE_INDEX_NAME)
-
- # Retrieve an initial answer from Pinecone
- def retrieve_answer(query, threshold=0.65, top_k=3):
-     query_embedding = embedding_model.encode([query])[0]
-     result = index.query(vector=query_embedding.tolist(), top_k=top_k, include_metadata=True)
-
-     if result['matches'] and result['matches'][0]['score'] > threshold:
-         return result['matches'][0]['metadata'].get('answer', "پاسخی یافت نشد.")  # "No answer found."
-     else:
-         return "متأسفم، پاسخی برای این سوال در پایگاه داده یافت نشد."  # "Sorry, no answer was found in the database."
-
- # Rewrite the answer into natural language with the language model
- def rewrite_answer(question, raw_answer):
-     input_text = f"پرسش: {question}\nپاسخ: {raw_answer}"  # "Question: ...\nAnswer: ..."
-     inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
-     with torch.no_grad():
-         outputs = language_model.generate(**inputs, max_new_tokens=100)
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # Gradio interface
- def chat_interface(question):
-     raw_answer = retrieve_answer(question)
-     if "متأسفم" in raw_answer:  # the "Sorry" prefix marks a retrieval miss
-         return raw_answer
-     final_answer = rewrite_answer(question, raw_answer)
-     return final_answer
-
- demo = gr.Interface(
-     fn=chat_interface,
-     inputs="text",
-     outputs="text",
-     title="چت‌بات تیام",  # "Tiam chatbot"
-     description="پرسش خود را درباره خدمات دیجیتال مارکتینگ تیام بپرسید."  # "Ask about Tiam's digital marketing services."
  )

- demo.launch()
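Note on the client API: the removed code used the class-based `Pinecone` client, while the added code below reverts to the module-level `pinecone.init(...)` / `pinecone.Index(...)` calls, which are only available in `pinecone-client` 2.x (the `init` entry point was removed in v3). A minimal sketch of the version-dependent setup, assuming the same environment variables as the app:

import os

PINECONE_API_KEY = os.environ["PINECONE_API_KEY"]
PINECONE_INDEX_NAME = os.environ["PINECONE_INDEX_NAME"]

try:
    # pinecone-client >= 3: class-based API, as in the removed code
    from pinecone import Pinecone
    pc = Pinecone(api_key=PINECONE_API_KEY)
    index = pc.Index(PINECONE_INDEX_NAME)
except ImportError:
    # pinecone-client 2.x: module-level API, as in the added code;
    # "gcp-starter" was the environment name of the legacy free tier
    import pinecone
    pinecone.init(api_key=PINECONE_API_KEY, environment="gcp-starter")
    index = pinecone.Index(PINECONE_INDEX_NAME)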
 
  import os
+ import pinecone
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
  from sentence_transformers import SentenceTransformer
  import torch
+ import gradio as gr

+ # --- Load environment variables ---
  HF_TOKEN = os.environ.get("HF_TOKEN")
  PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
  PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")

  assert HF_TOKEN is not None, "❌ HF_TOKEN is missing!"
  assert PINECONE_API_KEY is not None, "❌ Pinecone API key is missing!"
  assert PINECONE_INDEX_NAME is not None, "❌ Pinecone index name is missing!"

+ # --- Initialize Pinecone ---
+ pinecone.init(api_key=PINECONE_API_KEY, environment="gcp-starter")
+ index = pinecone.Index(PINECONE_INDEX_NAME)
+
+ # --- Load models ---
+ embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", token=HF_TOKEN)
+ tokenizer = T5Tokenizer.from_pretrained("google/mt5-small", token=HF_TOKEN)
+ model = T5ForConditionalGeneration.from_pretrained("google/mt5-small", token=HF_TOKEN)
+
+ def generate_answer(question):
+     # Embed the question
+     question_embedding = embedding_model.encode(question).tolist()
+
+     # Query Pinecone for similar content
+     response = index.query(vector=question_embedding, top_k=3, include_metadata=True)
+     contexts = [match['metadata']['text'] for match in response['matches'] if 'text' in match['metadata']]
+
+     # Concatenate the retrieved context
+     context = "\n".join(contexts)
+     input_text = f"پرسش: {question}\nاطلاعات: {context}\nپاسخ:"  # "Question: ...\nInformation: ...\nAnswer:"
+
+     # Tokenize and generate
+     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
+     outputs = model.generate(**inputs, max_new_tokens=200)
+     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return answer
+
+ # --- Gradio UI ---
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=gr.Textbox(label="question"),
+     outputs=gr.Textbox(label="output"),
+     title="💬 چت‌بات هوشمند تیام",  # "Tiam smart chatbot"
+     description="سؤالات خود درباره خدمات دیجیتال مارکتینگ تیام را بپرسید."  # "Ask your questions about Tiam's digital marketing services."
  )

+ iface.launch()
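Since google/mt5-small is a raw multilingual T5 checkpoint pretrained only with span corruption and not instruction-tuned, the prompt-completion style in generate_answer may produce weak answers without fine-tuning; a quick end-to-end check is worthwhile. A minimal smoke test, assuming the env vars are set and the index holds vectors with a 'text' metadata field (the sample question is hypothetical):

# Sanity-check retrieval + generation once before exposing the UI
if __name__ == "__main__":
    sample_question = "چه خدماتی ارائه می‌دهید؟"  # "What services do you offer?"
    print(generate_answer(sample_question))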