File size: 2,966 Bytes
f0e607c
416c352
ccf2073
c4efeea
ccf2073
c4efeea
bf2110c
c4efeea
416c352
c4efeea
ccf2073
c4efeea
416c352
c4efeea
 
bf2110c
c4efeea
 
bf2110c
c4efeea
 
 
ccf2073
c4efeea
 
 
 
 
 
 
f0e607c
ccf2073
c4efeea
 
 
bf2110c
c4efeea
bf2110c
c4efeea
 
 
f0e607c
c4efeea
 
 
 
 
 
f0e607c
c4efeea
f0e607c
c4efeea
be18c9a
c4efeea
f0e607c
be18c9a
ccf2073
be18c9a
 
 
 
f0e607c
c4efeea
f0e607c
be18c9a
c4efeea
be18c9a
c4efeea
f0e607c
c4efeea
f0e607c
 
 
c4efeea
f0e607c
cdd8645
f0e607c
 
 
 
 
bf2110c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import os
import json
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient, hf_hub_download

# 🔹 Hugging Face Credentials
# Repo used both for downloading the FAISS index file and as the model id for
# the inference client below.
HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"
# Read from the environment instead of hard-coding the secret; None if unset.
HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# 🔹 FAISS Index Path (filename of the index file inside HF_REPO)
FAISS_PATH = "asa_faiss.index"

# 🔹 Load Sentence Transformer for Embeddings
# NOTE(review): presumably the same model that built asa_faiss.index — confirm,
# since query vectors must have the index's dimensionality.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# 🔹 Load FAISS Index from Hugging Face
# Use FAISS_PATH rather than repeating the filename literal so the two cannot drift.
faiss_local_path = hf_hub_download(HF_REPO, FAISS_PATH, token=HF_TOKEN)
faiss_index = faiss.read_index(faiss_local_path)

# 🔹 Initialize Hugging Face Model Client
# NOTE(review): HF_REPO looks like a GGUF repository; confirm the inference
# backend can actually serve it for text_generation.
client = InferenceClient(model=HF_REPO, token=HF_TOKEN)

# 🔹 Retrieve Relevant FAISS Context
def retrieve_relevant_context(user_query, top_k=3, *, texts=None, encoder=None, index=None):
    """Return the FAISS neighbours of *user_query* as newline-separated example lines.

    Args:
        user_query: Text to embed and look up.
        top_k: Maximum number of neighbours to request from the index.
        texts: Optional sequence mapping FAISS row ids to the stored example
            text. A FAISS index holds only vectors, so without this the lines
            can only show the numeric ids (the previous behaviour).
        encoder: Embedding model; defaults to the module-level ``embedder``.
        index: FAISS index; defaults to the module-level ``faiss_index``.

    Returns:
        One ``"Example: ..."`` line per valid hit, or the string
        ``"No relevant data found."`` when there are none (or ``top_k <= 0``).
    """
    if top_k <= 0:
        return "No relevant data found."
    encoder = embedder if encoder is None else encoder
    index = faiss_index if index is None else index

    # FAISS wants a contiguous float32 (n_queries, dim) matrix; encoding straight
    # to NumPy avoids a needless tensor -> .cpu() -> numpy round trip.
    query_embedding = np.ascontiguousarray(
        encoder.encode([user_query], convert_to_numpy=True), dtype=np.float32
    )
    _distances, indices = index.search(query_embedding, top_k)

    retrieved_texts = []
    for raw_idx in indices[0]:  # single query -> single row of ids
        idx = int(raw_idx)
        if idx < 0:  # FAISS pads with -1 when it has fewer than top_k hits
            continue
        if texts is None:
            retrieved_texts.append(f"Example: {idx} → {idx}")
        else:
            retrieved_texts.append(f"Example: {texts[idx]}")

    return "\n".join(retrieved_texts) if retrieved_texts else "No relevant data found."

# 🔹 Format Model Prompt with FAISS Guidance
def format_prompt(user_input, system_prompt, history):
    """Build the instruction-style prompt sent to the model.

    Layout: the system message (if non-blank), the FAISS-retrieved examples as
    a style guide, any earlier chat turns, then the new user input left open
    under ``### Response:`` for the model to complete.

    Args:
        user_input: The latest user message.
        system_prompt: System message text from the UI; omitted when blank.
        history: Earlier turns, either ``(user, bot)`` pairs or
            ``{"role": ..., "content": ...}`` dicts; may be empty or None.

    Returns:
        The complete prompt string.
    """
    retrieved_context = retrieve_relevant_context(user_input)

    parts = []
    if system_prompt and system_prompt.strip():
        parts.append(f"{system_prompt.strip()}\n\n")
    parts.append(
        "Use the following example responses as a guide for formatting and writing style:\n"
        f"{retrieved_context}\n\n"
    )
    parts.extend(_render_history(history))
    parts.append(f"### Instruction:\n{user_input}\n\n### Response:")
    return "".join(parts)


def _render_history(history):
    """Render earlier chat turns in the same ### Instruction / ### Response layout."""
    rendered = []
    for turn in history or ():
        if isinstance(turn, dict):  # "messages"-style history entry
            content = turn.get("content")
            if not isinstance(content, str) or not content:
                continue  # skip non-text (e.g. file) messages
            if turn.get("role") == "user":
                rendered.append(f"### Instruction:\n{content}\n\n")
            elif turn.get("role") == "assistant":
                rendered.append(f"### Response:\n{content}\n\n")
        else:  # "tuples"-style history entry: (user, bot)
            user_msg, bot_msg = turn[0], turn[1]
            if isinstance(user_msg, str) and user_msg:
                rendered.append(f"### Instruction:\n{user_msg}\n\n")
            if isinstance(bot_msg, str) and bot_msg:
                rendered.append(f"### Response:\n{bot_msg}\n\n")
    return rendered

# 🔹 Chatbot Response Function
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio chat UI.

    Args:
        message: The new user message.
        history: Prior turns supplied by Gradio; only read, never modified.
        system_message: System prompt text from the UI.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The cleaned model reply, exactly once.
    """
    full_prompt = format_prompt(message, system_message, history)

    response = client.text_generation(
        full_prompt,
        max_new_tokens=int(max_tokens),  # guard against a float slider value
        temperature=temperature,
        top_p=top_p,
    )

    # ✅ Keep only text after the last "### Response:" marker (in case the prompt
    # is echoed back), and cut off any follow-up turn the model invents.
    cleaned_response = response.split("### Response:")[-1]
    cleaned_response = cleaned_response.split("### Instruction:")[0].strip()

    # History is not appended to here: ChatInterface builds the displayed history
    # from the yielded value, and mutating the caller's list is a side effect.
    yield cleaned_response  # ✅ Output the response

# 🔹 Gradio Chat Interface
# Extra UI controls. Their values are passed to respond() after (message, history)
# in list order, so the order must match respond's remaining parameters.
_DEFAULT_SYSTEM_MESSAGE = "You are a helpful AI trained to follow FAISS-based writing styles."
_extra_controls = [
    gr.Textbox(label="System message", value=_DEFAULT_SYSTEM_MESSAGE),
    gr.Slider(label="Max new tokens", minimum=1, maximum=250, step=1, value=128),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.9),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.01, value=0.99),
]
demo = gr.ChatInterface(fn=respond, additional_inputs=_extra_controls)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()