# Ask-FashionDB / src/use_llm.py
# (Hugging Face Hub page header removed: author "traopia", commit 54ab71a, 1.61 kB)
from huggingface_hub import InferenceClient
import os
# Use a Hugging Face inference endpoint like "google/gemma-1.1-7b-it"
# You must have access to this model (either public or via token)
HUGGINGFACE_API_TOKEN = os.getenv("HF_TOKEN") # Add this in your HF Space's secret settings
DEFAULT_MODEL = "google/gemma-1.1-7b-it"
# NOTE(review): `client` is rebound further down in this file to an embedding
# client (thenlper/gte-large), so any code that reads the module-level `client`
# after that point gets the embedding model, not this chat model — confirm intent.
client = InferenceClient(DEFAULT_MODEL, token=HUGGINGFACE_API_TOKEN)
def send_chat_prompt(prompt: str, model: str, system_prompt: str) -> str:
    """Send a single-turn chat prompt to a Hugging Face text-generation model.

    Args:
        prompt: The user message.
        model: Hugging Face model id to generate with (e.g. "google/gemma-1.1-7b-it").
        system_prompt: System instruction prepended to the conversation.

    Returns:
        The generated assistant text, stripped of surrounding whitespace.
    """
    # Gemma-style turn markers; generation stops at the first end-of-turn tag.
    full_prompt = f"<|start_of_turn|>system\n{system_prompt}<|end_of_turn|>\n" \
                  f"<|start_of_turn|>user\n{prompt}<|end_of_turn|>\n" \
                  f"<|start_of_turn|>assistant\n"
    # BUG FIX: the original ignored the `model` argument and used the module-level
    # `client`, which is later rebound to the embedding client (thenlper/gte-large)
    # at the bottom of this file — so chat requests would hit the wrong model.
    # Bind a client to the requested model instead.
    chat_client = InferenceClient(model, token=HUGGINGFACE_API_TOKEN)
    response = chat_client.text_generation(
        prompt=full_prompt,
        max_new_tokens=500,
        temperature=0.5,
        stop_sequences=["<|end_of_turn|>"]
    )
    return response.strip()
def main_generate(prompt, model=DEFAULT_MODEL, system_prompt="You are a helpful assistant that generates SPARQL queries."):
response = send_chat_prompt(prompt, model, system_prompt)
response = response.replace('```', '').replace('json', '').strip()
return response
# Embedding-model client setup.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "thenlper/gte-large"  # embedding model
# NOTE(review): this rebinds the module-level `client` created for the chat
# model above — after this line, any chat helper reading `client` would talk to
# the embedding model instead. Consider a distinct name (e.g. embedding_client).
client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
def get_embeddings(texts):
    """Compute embedding vectors for one text or a list of texts.

    Args:
        texts: A single string or a list of strings.

    Returns:
        A list of embedding vectors (one per input text), as returned by the
        inference endpoint's feature-extraction task.
    """
    if isinstance(texts, str):
        texts = [texts]
    # BUG FIX: InferenceClient has no `text_to_vector` method (the original
    # raised AttributeError). The embedding task on InferenceClient is
    # `feature_extraction`, which returns the embedding for the given text.
    return [client.feature_extraction(text) for text in texts]