Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -6,6 +6,7 @@ from fastapi.staticfiles import StaticFiles
|
|
| 6 |
from fastapi.templating import Jinja2Templates
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
import os
|
|
|
|
| 9 |
# Load environment variables
|
| 10 |
load_dotenv()
|
| 11 |
from gliner import GLiNER
|
|
@@ -19,14 +20,19 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
| 19 |
# Load models
|
| 20 |
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
|
| 21 |
os.makedirs(cache_dir, exist_ok=True)
|
|
|
|
| 22 |
gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
|
| 23 |
groq_client = Groq(api_key=GROQ_API_KEY)
|
| 24 |
|
| 25 |
init_qdrant_collection()
|
| 26 |
|
| 27 |
def extract_entities(text):
|
|
|
|
|
|
|
| 28 |
labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def validate_answer(user_query, retrieved_answer):
|
| 32 |
prompt = f"""
|
|
|
|
| 6 |
from fastapi.templating import Jinja2Templates
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
import os
|
| 9 |
+
from transformers import AutoTokenizer
|
| 10 |
# Load environment variables
|
| 11 |
load_dotenv()
|
| 12 |
from gliner import GLiNER
|
|
|
|
| 20 |
# Load models
|
| 21 |
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
|
| 22 |
os.makedirs(cache_dir, exist_ok=True)
|
| 23 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", cache_dir=cache_dir) # Replace with appropriate tokenizer for GLiNER
|
| 24 |
gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
|
| 25 |
groq_client = Groq(api_key=GROQ_API_KEY)
|
| 26 |
|
| 27 |
init_qdrant_collection()
|
| 28 |
|
| 29 |
def extract_entities(text):
|
| 30 |
+
# Tokenize the input text first
|
| 31 |
+
inputs = tokenizer(text, return_tensors="pt") # Assuming PyTorch backend
|
| 32 |
labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
|
| 33 |
+
|
| 34 |
+
# Predict entities
|
| 35 |
+
return gliner_model.predict_entities(inputs['input_ids'], labels)
|
| 36 |
|
| 37 |
def validate_answer(user_query, retrieved_answer):
|
| 38 |
prompt = f"""
|