import os

import gradio as gr
import numpy as np
import torch
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
# Load the NASA-specific bi-encoder model and tokenizer
bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
# Set up OpenAI client
api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)

def encode_text(text):
    """Mean-pool the bi-encoder's token embeddings into one fixed-size vector."""
    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    with torch.no_grad():  # inference only; no gradient tracking needed
        outputs = bi_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy().flatten()  # 1D embedding vector
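
# Illustrative shape check (the 768 hidden size is an assumption typical of
# BERT-style encoders, not confirmed for this particular model):
#
#   vec = encode_text("lunar regolith composition")
#   vec.shape  # -> (768,)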

def retrieve_relevant_context(user_input, context_texts):
    user_embedding = encode_text(user_input).reshape(1, -1)
    context_embeddings = np.array([encode_text(text) for text in context_texts])
    context_embeddings = context_embeddings.reshape(len(context_embeddings), -1)  # Flatten each embedding
    similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
    most_relevant_idx = np.argmax(similarities)
    return context_texts[most_relevant_idx]
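
# Illustrative retrieval (hypothetical strings, not from any real dataset):
#
#   retrieve_relevant_context(
#       "When did Perseverance land on Mars?",
#       ["Perseverance touched down in Jezero Crater in February 2021.",
#        "The Hubble Space Telescope launched in 1990."],
#   )
#   # -> the Perseverance line, whose embedding is closest to the query's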

def generate_response(user_input, relevant_context):
    combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "user", "content": combined_input}
        ],
        max_tokens=150,
        temperature=0.7,
        top_p=0.9,
        frequency_penalty=0.5,
        presence_penalty=0.0
    )
    return response.choices[0].message.content.strip()
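
# The prompt sent to GPT-4 has this shape (angle-bracket values are placeholders):
#
#   Context: <most relevant context line>
#   Question: <user's message>
#   Answer: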

def chatbot(user_input, context=""):
    # Drop blank lines so empty rows in the context box are never embedded.
    context_texts = [line for line in context.split("\n") if line.strip()]
    relevant_context = retrieve_relevant_context(user_input, context_texts) if context_texts else ""
    return generate_response(user_input, relevant_context)
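
# End-to-end sketch (hypothetical inputs; requires OPENAI_API_KEY to be set):
#
#   chatbot("What instruments does the rover carry?",
#           context="A line about rover instruments\nA line about orbital mechanics")
#   # -> GPT-4's answer, grounded in the most similar context line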

# Create the Gradio interface
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your message here..."),
        gr.Textbox(lines=5, placeholder="Enter context here, separated by new lines...")
    ],
    outputs="text",
    title="Context-Aware Dynamic Response Chatbot",
    description="A chatbot that uses a NASA-specific bi-encoder to pick the most relevant line of the supplied context and GPT-4 to generate the response. Provide context to get more refined, relevant answers."
)
# Launch the interface
iface.launch(share=True)
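
# To run this locally (assumes the file is saved as app.py, the usual entry
# point for a Gradio Space, though the filename is not confirmed here):
#
#   export OPENAI_API_KEY=sk-...   # your OpenAI key
#   python app.py                  # share=True also prints a public gradio.live link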