import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from openai import OpenAI
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load the NASA-specific bi-encoder model and tokenizer
bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
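# (from_pretrained downloads and caches the model weights on first run)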

# Set up OpenAI client
api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)

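# Embed a string by mean-pooling the bi-encoder's final hidden states over all
# tokens, yielding one fixed-size sentence vector regardless of input length.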
def encode_text(text):
    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    with torch.no_grad():  # inference only, so skip gradient tracking
        outputs = bi_model(**inputs)
    # Mean-pool token embeddings into a single 1D vector
    return outputs.last_hidden_state.mean(dim=1).numpy().flatten()

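# Retrieve the single context line whose embedding has the highest cosine
# similarity to the user's query embedding.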
def retrieve_relevant_context(user_input, context_texts):
    user_embedding = encode_text(user_input).reshape(1, -1)
    # Stack the 1D embeddings into a 2D array of shape (n_texts, embedding_dim)
    context_embeddings = np.array([encode_text(text) for text in context_texts])
    similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
    most_relevant_idx = np.argmax(similarities)
    return context_texts[most_relevant_idx]

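# RAG-style generation: prepend the retrieved context to the question so the
# LLM can ground its answer in it.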
def generate_response(user_input, relevant_context):
    combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
    
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "user", "content": combined_input}
        ],
        max_tokens=150,
        temperature=0.7,
        top_p=0.9,
        frequency_penalty=0.5,
        presence_penalty=0.0
    )
    return response.choices[0].message.content.strip()

def chatbot(user_input, context=""):
    # Drop blank lines so empty strings are never embedded
    context_texts = [line for line in context.split("\n") if line.strip()]
    relevant_context = retrieve_relevant_context(user_input, context_texts) if context_texts else ""
    return generate_response(user_input, relevant_context)

# Create the Gradio interface
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Textbox(lines=2, label="Message", placeholder="Enter your message here..."),
        gr.Textbox(lines=5, label="Context", placeholder="Enter context passages here, one per line...")
    ],
    outputs="text",
    title="Context-Aware Dynamic Response Chatbot",
    description="A chatbot that uses a NASA-specific bi-encoder to pick the most relevant line of the supplied context, then asks GPT-4 to answer using that context. Provide context to get more grounded, relevant responses."
)

# Launch the interface
iface.launch(share=True)
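
# Quick sanity check of the retrieval step (illustrative; the exact embedding
# size and similarity scores depend on the model):
#   >>> encode_text("solar wind").shape
#   (768,)                     # 1D vector; 768 assumes a BERT-sized encoder
#   >>> retrieve_relevant_context(
#   ...     "Which mission studies the solar wind?",
#   ...     ["Parker Solar Probe samples the solar wind.",
#   ...      "JWST observes in the infrared."])
#   'Parker Solar Probe samples the solar wind.'  # expected best match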