diginoron committed on
Commit
a741062
·
verified ·
1 Parent(s): ff0c373

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -45
app.py CHANGED
@@ -1,58 +1,64 @@
 
1
  import gradio as gr
 
2
  from sentence_transformers import SentenceTransformer
3
- import pinecone
4
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
- import torch
6
- import os
7
 
8
- # Load secrets and environment variables
 
9
  PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
10
  PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
11
- HF_TOKEN = os.environ.get("HF_TOKEN")
12
 
13
- # Step 1: Load embedding model and Pinecone
14
- embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
15
- pinecone.init(api_key=PINECONE_API_KEY)
16
- pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
 
 
 
 
 
 
 
17
  index = pc.Index(PINECONE_INDEX_NAME)
18
 
19
- # Step 2: Load GPT-2 language model
20
- model_name = "HooshvareLab/gpt2-fa"
21
- tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
22
- model = GPT2LMHeadModel.from_pretrained(model_name, use_auth_token=HF_TOKEN)
23
- model.eval()
24
-
25
- # Function: Embed input and search in Pinecone
26
- def retrieve_context(query, top_k=1):
27
- xq = embedding_model.encode(query).tolist()
28
- res = index.query(vector=xq, top_k=top_k, include_metadata=True)
29
- if res.matches:
30
- return res.matches[0].metadata['text']
31
- return ""
32
-
33
- # Function: Generate response using GPT-2
34
- def generate_response(query, context):
35
- prompt = f"پرسش: {query}\nپاسخ با توجه به اطلاعات زیر: {context}\nپاسخ:"
36
- input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
37
- output_ids = model.generate(input_ids, max_length=256, num_beams=4, no_repeat_ngram_size=2, early_stopping=True)
38
- output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
- return output.split("پاسخ:")[-1].strip()
40
-
41
- # Gradio interface
42
  def chat(query):
43
- context = retrieve_context(query)
44
- response = generate_response(query, context)
45
- return response
46
-
47
- # UI
48
- with gr.Blocks() as demo:
49
- gr.Markdown("## چت‌بات هوشمند تیام\nسوالات خود درباره خدمات دیجیتال مارکتینگ تیام را بپرسید.")
50
- with gr.Row():
51
- inp = gr.Textbox(label="question", placeholder="سوال خود را وارد کنید")
52
- out = gr.Textbox(label="output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  submit = gr.Button("Submit")
54
- submit.click(chat, inputs=inp, outputs=out)
 
55
 
56
- # Launch
57
  if __name__ == "__main__":
58
  demo.launch()
 
1
+ import os
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from sentence_transformers import SentenceTransformer
5
+ from pinecone import Pinecone, ServerlessSpec
 
 
 
6
 
7
# --- Load environment variables ---
# Secrets are read from the process environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")

# Fail fast with real exceptions instead of `assert`: assertions are removed
# when Python runs with -O, which would let a missing secret slip through to
# a much less readable error further down. Empty strings count as missing.
if not HF_TOKEN:
    raise RuntimeError("❌ HF_TOKEN is missing!")
if not PINECONE_API_KEY:
    raise RuntimeError("❌ PINECONE_API_KEY is missing!")
if not PINECONE_INDEX_NAME:
    raise RuntimeError("❌ Pinecone index name is missing!")

# --- Load models ---
# Sentence embedder used to turn the user's question into a query vector.
embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# Causal language model and its tokenizer used to generate the answer text.
tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/gpt2-fa", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained("HooshvareLab/gpt2-fa", token=HF_TOKEN)

# --- Connect to Pinecone ---
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX_NAME)
24
 
25
+ # --- Inference pipeline ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def chat(query):
    """Answer a user question from the best Pinecone match using the language model.

    The question is embedded with ``embedder``, the single nearest record is
    fetched from ``index``, and that record's ``metadata['text']`` is placed
    into a Persian prompt which ``model`` continues for up to 100 new tokens.

    Args:
        query: The question text entered by the user.

    Returns:
        The generated continuation with the prompt removed and surrounding
        whitespace stripped, or a fixed Persian fallback message when the
        question is blank, nothing matches, or the matched record carries no
        usable text.
    """
    not_found = "پاسخی برای سوال شما پیدا نشد. لطفا تماس بگیرید."

    # A blank question has nothing to retrieve against; skip the model calls.
    if not query or not query.strip():
        return not_found

    # Embed the question and fetch the single nearest record with its metadata.
    xq = embedder.encode(query).tolist()
    res = index.query(vector=xq, top_k=1, include_metadata=True)
    matches = res.get("matches", [])
    if not matches:
        return not_found

    # A record stored without metadata or without a 'text' field cannot ground
    # an answer, so it is treated the same as "no match".
    try:
        context = matches[0]['metadata']['text']
    except (KeyError, TypeError, AttributeError):
        return not_found
    if not context:
        return not_found

    prompt = f"سوال: {query}\nپاسخ بر اساس اطلاعات زیر بده: {context}\nپاسخ:"

    # Only one sequence is encoded, so padding would add nothing; requesting it
    # would only fail on a tokenizer that has no pad token configured.
    inputs = tokenizer(prompt, return_tensors="pt")

    # NOTE(review): the prompt is not truncated here; a very long context could
    # exceed the model's position limit — confirm the size of indexed chunks.
    output_ids = model.generate(**inputs, max_new_tokens=100)

    # For a causal LM, generate() returns the prompt tokens followed by the
    # continuation. Slicing by prompt length removes the prompt exactly, even
    # when the question or context itself contains the "پاسخ:" marker.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

    return answer.strip()
53
+
54
# --- Gradio UI ---
# One question box, one answer box, a submit button wired to `chat`, and a
# clear button that resets both boxes.
with gr.Blocks(title="چت‌بات هوشمند تیام") as demo:
    gr.Markdown(
        """## چت‌بات هوشمند تیام\nسوالات خود درباره خدمات دیجیتال مارکتینگ تیام را بپرسید"""
    )
    question = gr.Textbox(
        label="question",
        placeholder="سوال خود را وارد کنید",
    )
    output = gr.Textbox(label="output")
    submit = gr.Button("Submit")
    # Clicking "Submit" sends the question text to `chat` and shows its result.
    submit.click(
        fn=chat,
        inputs=question,
        outputs=output,
    )
    gr.ClearButton([question, output])


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()