# tiyam-chatbot / app.py
import os
import pinecone
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import gradio as gr
# Load environment variables
HF_TOKEN = os.environ.get("HF_TOKEN")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME")
assert HF_TOKEN is not None, "❌ HF token is missing!"
assert PINECONE_API_KEY is not None, "❌ Pinecone API key is missing!"
assert PINECONE_INDEX_NAME is not None, "❌ Pinecone index name is missing!"
# Load embedding model
embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", use_auth_token=HF_TOKEN)
# Load tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small", token=HF_TOKEN)
model = T5ForConditionalGeneration.from_pretrained("google/mt5-small", token=HF_TOKEN)
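# Note: google/mt5-small is a base (not instruction-tuned) checkpoint, so its raw
# generations may be weak; in this setup the retrieved Pinecone text is assumed to
# carry most of the answer content.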
# Initialize Pinecone client
pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(PINECONE_INDEX_NAME)
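# Assumption: the Pinecone index is populated with 384-dimensional vectors produced
# by the same paraphrase-multilingual-MiniLM-L12-v2 embedder, and each vector stores
# its source passage under a "text" metadata field.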
def query_index(question):
    # Embed the question with the multilingual MiniLM model
    question_embedding = embedder.encode(question).tolist()

    # Query Pinecone for the single closest match, including its stored metadata
    results = index.query(vector=question_embedding, top_k=1, include_metadata=True)

    if results.matches:
        retrieved_text = results.matches[0].metadata.get("text", "")
    else:
        # "Sorry, I could not find a suitable answer."
        retrieved_text = "متاسفم، پاسخ مناسبی پیدا نکردم."

    # Build the prompt: "Question: {question} \n Answer based on knowledge: {retrieved_text}"
    input_text = f"پرسش: {question} \n پاسخ بر اساس دانش: {retrieved_text}"
    input_ids = tokenizer(input_text, return_tensors="pt", truncation=True).input_ids
    output_ids = model.generate(input_ids, max_length=100)
    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer
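# Quick local sanity check (hypothetical example question), e.g.:
#   print(query_index("خدمات شما چیست؟"))  # "What are your services?"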
# Gradio UI
iface = gr.Interface(
    fn=query_index,
    inputs=gr.Textbox(label="question", placeholder="سوال خود را وارد کنید"),  # placeholder: "Enter your question"
    outputs=gr.Textbox(label="output"),
    title="چت‌بات هوشمند تیام",  # "Tiam smart chatbot"
    description="سوالات خود درباره خدمات دیجیتال مارکتینگ تیام را بپرسید.",  # "Ask your questions about Tiam's digital marketing services."
)

iface.launch(server_name="0.0.0.0", server_port=7860)