diginoron committed on
Commit
ccf44c8
·
verified ·
1 Parent(s): 342ff5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -48
app.py CHANGED
@@ -1,57 +1,58 @@
1
- import os
2
- import pinecone
3
  from sentence_transformers import SentenceTransformer
4
- from transformers import T5Tokenizer, T5ForConditionalGeneration
 
5
  import torch
6
- import gradio as gr
7
 
8
- # Load environment variables
9
- HF_TOKEN = os.environ.get("HF_TOKEN")
10
  PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
11
  PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
 
12
 
13
- assert HF_TOKEN is not None, "❌ HF token is missing!"
14
- assert PINECONE_API_KEY is not None, "❌ Pinecone API key is missing!"
15
- assert PINECONE_INDEX_NAME is not None, "❌ Pinecone index name is missing!"
16
-
17
- # Load embedding model
18
- embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", use_auth_token=HF_TOKEN)
19
-
20
- # Load tokenizer and model
21
- tokenizer = T5Tokenizer.from_pretrained("google/mt5-small", token=HF_TOKEN)
22
- model = T5ForConditionalGeneration.from_pretrained("google/mt5-small", token=HF_TOKEN)
23
-
24
- # Initialize Pinecone client
25
  pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
26
  index = pc.Index(PINECONE_INDEX_NAME)
27
 
28
- def query_index(question):
29
- # Embed question
30
- question_embedding = embedder.encode(question).tolist()
31
-
32
- # Query Pinecone
33
- results = index.query(vector=question_embedding, top_k=1, include_metadata=True)
34
-
35
- if results.matches:
36
- retrieved_text = results.matches[0].metadata.get("text", "")
37
- else:
38
- retrieved_text = "متاسفم، پاسخ مناسبی پیدا نکردم."
39
-
40
- # Generate answer
41
- input_text = f"پرسش: {question} \n پاسخ بر اساس دانش: {retrieved_text}"
42
- input_ids = tokenizer(input_text, return_tensors="pt", truncation=True).input_ids
43
- output_ids = model.generate(input_ids, max_length=100)
44
- answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
-
46
- return answer
47
-
48
- # Gradio UI
49
- iface = gr.Interface(
50
- fn=query_index,
51
- inputs=gr.Textbox(label="question", placeholder="سوال خود را وارد کنید"),
52
- outputs=gr.Textbox(label="output"),
53
- title="چت‌بات هوشمند تیام",
54
- description="سوالات خود درباره خدمات دیجیتال مارکتینگ تیام را بپرسید."
55
- )
56
-
57
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
2
  from sentence_transformers import SentenceTransformer
3
+ import pinecone
4
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
5
  import torch
6
+ import os
7
 
8
# Load secrets and environment variables
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Fail fast with a readable message instead of an opaque client error later.
if not PINECONE_API_KEY or not PINECONE_INDEX_NAME:
    raise RuntimeError(
        "PINECONE_API_KEY and PINECONE_INDEX_NAME must be set in the environment."
    )

# Step 1: Load embedding model and Pinecone
# The sentence-transformers encoder turns user queries into vectors for the
# similarity search against the index.
embedding_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# NOTE: there is deliberately no module-level ``pinecone.init(...)`` call here.
# ``init`` is the legacy (v2) client entry point; in the v3+ client, which is
# the one that provides ``pinecone.Pinecone``, calling it raises an error.
# Creating the client object below is the only initialisation required.
pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX_NAME)
18
 
19
# Step 2: Load GPT-2 language model
model_name = "HooshvareLab/gpt2-fa"
# ``token`` is the current keyword for passing a Hugging Face access token;
# ``use_auth_token`` is deprecated in transformers.
tokenizer = GPT2Tokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained(model_name, token=HF_TOKEN)
# Inference only: put the model in evaluation mode (disables dropout).
model.eval()
24
+
25
+ # Function: Embed input and search in Pinecone
26
def retrieve_context(query, top_k=1):
    """Return the stored text of the first Pinecone match for *query*.

    The query is embedded with ``embedding_model`` and the resulting vector
    is used to query ``index``.

    Args:
        query: User question to embed and search for.
        top_k: Number of matches to request from Pinecone. Only the first
            returned match is used.

    Returns:
        The ``"text"`` metadata value of the first match, or ``""`` when the
        query returns no matches or the match has no ``"text"`` entry.
    """
    xq = embedding_model.encode(query).tolist()
    res = index.query(vector=xq, top_k=top_k, include_metadata=True)
    if not res.matches:
        return ""
    # A record may have been upserted without metadata or without a "text"
    # field; treat that as "no context" instead of raising KeyError/TypeError.
    metadata = res.matches[0].metadata or {}
    return metadata.get("text", "")
32
+
33
+ # Function: Generate response using GPT-2
34
def generate_response(query, context):
    """Generate an answer to *query* with the GPT-2 model, given *context*.

    Args:
        query: The user's question.
        context: Retrieved passage inserted into the prompt; may be empty.

    Returns:
        The text generated after the prompt, with surrounding whitespace
        stripped.
    """
    prompt = f"پرسش: {query}\nپاسخ با توجه به اطلاعات زیر: {context}\nپاسخ:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=512)
    output_ids = model.generate(
        input_ids,
        # ``max_length`` counts the prompt tokens as well, so a prompt longer
        # than the limit would leave no room for an answer. Budget the newly
        # generated tokens instead.
        max_new_tokens=256,
        num_beams=4,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    # The returned sequence starts with the prompt tokens. Slice them off
    # rather than splitting the decoded text on the answer marker, which
    # breaks when the marker was truncated away or is generated again.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
40
+
41
+ # Gradio interface
42
def chat(query):
    """Answer *query*: look up supporting context, then generate a reply from it."""
    return generate_response(query, retrieve_context(query))
46
+
47
+ # UI: Gradio Blocks layout with a Markdown header, a question textbox, an output textbox, and a Submit button whose click calls chat(). NOTE(review): indentation is lost in this diff view, so the exact nesting under gr.Row() cannot be confirmed here.
48
+ with gr.Blocks() as demo:
49
+ gr.Markdown("## چت‌بات هوشمند تیام\nسوالات خود درباره خدمات دیجیتال مارکتینگ تیام را بپرسید.")
50
+ with gr.Row():
51
+ inp = gr.Textbox(label="question", placeholder="سوال خود را وارد کنید")
52
+ out = gr.Textbox(label="output")
53
+ submit = gr.Button("Submit")
54
+ submit.click(chat, inputs=inp, outputs=out)
55
+
56
+ # Launch the Gradio app only when this file is executed as a script, not when it is imported.
57
+ if __name__ == "__main__":
58
+ demo.launch()