Spaces:
Running
Running
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import torch | |
import gradio as gr | |
# Hugging Face Hub id of the checkpoint used for personality-trait scoring.
_MODEL_ID = "KevSun/Personality_LM"

# Load the classifier and its matching tokenizer once at import time so every
# chat request reuses the same objects.
# NOTE(review): ignore_mismatched_sizes=True lets loading succeed even when a
# checkpoint weight's shape differs from the model config; such weights are
# freshly initialized instead of loaded. Confirm the classification head loads
# cleanly, otherwise predictions are not meaningful.
model = AutoModelForSequenceClassification.from_pretrained(
    _MODEL_ID,
    ignore_mismatched_sizes=True,
)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
def analyze_personality(text):
    """Score the personality traits expressed in *text*.

    The text is tokenized (truncated to at most 512 tokens) and fed through
    the module-level classifier with gradient tracking disabled. The logits
    are then normalized with a softmax over the last dimension.

    Args:
        text: Free-form sample text to analyze.

    Returns:
        dict mapping trait name -> probability for the single input sequence.
        Because of the softmax, the scores across all model outputs sum to 1.
        If the model has exactly five outputs, at most one trait can exceed 0.5.
    """
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Put the model in evaluation mode before running inference.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    # Row 0: only one text is passed in, so the batch has a single row.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()
    # NOTE(review): this order is assumed to match the model's output labels;
    # confirm against the checkpoint's config (id2label).
    labels = ("agreeableness", "openness", "conscientiousness", "extraversion", "neuroticism")
    return dict(zip(labels, probabilities))
def adjust_response(response, traits, threshold=0.5):
    """Decorate a chatbot reply according to strongly expressed traits.

    Args:
        response: Base reply text.
        traits: Mapping of trait name -> score (e.g. the dict returned by
            ``analyze_personality``). A trait missing from the mapping is
            treated as 0.0 rather than raising KeyError.
        threshold: A trait adds its phrase only when its score is strictly
            greater than this value. Defaults to 0.5, the original cut-off.

    Returns:
        The reply with zero or more trait phrases appended, always in the
        order agreeableness, neuroticism, extraversion.

    NOTE(review): if the scores are softmax probabilities summing to 1 (as
    ``analyze_personality`` produces), at most one trait can exceed 0.5. With
    the default threshold, the phrases therefore never combine.
    """
    trait_phrases = (
        ("agreeableness", " 😊 I'm so glad we get along well!"),
        ("neuroticism", " But I'm feeling a bit worried about what might happen..."),
        ("extraversion", " Let's keep chatting! I love interacting with you."),
    )
    for trait, phrase in trait_phrases:
        # .get(): zip() in analyze_personality silently drops labels if the
        # model emits fewer outputs, so a trait may be absent.
        if traits.get(trait, 0.0) > threshold:
            response += phrase
    return response
def respond(user_message, history, personality_text):
    """Produce one chatbot turn and append it to the chat history.

    Args:
        user_message: Text the user just submitted.
        history: Existing chat history as a list of (user, bot) pairs. None
            is accepted and treated as an empty chat.
        personality_text: Sample text whose predicted traits shape the reply.
            None is treated as an empty string.

    Returns:
        (history, history): a new list holding the previous turns plus the
        new one (if any), returned twice. The caller's list is never mutated.
    """
    # Copy instead of appending in place; also tolerate None for an empty chat.
    history = list(history) if history else []
    if not (user_message and user_message.strip()):
        # Blank submission: nothing to answer, so skip the model call.
        return history, history
    traits = analyze_personality(personality_text or "")
    base_response = f"Hi! You said: {user_message}"
    final_response = adjust_response(base_response, traits)
    history.append((user_message, final_response))
    return history, history
def personality_demo():
    """Build the Gradio Blocks UI for the personality-conditioned chatbot.

    Components:
        - a textbox holding the personality/sample text,
        - the chat transcript,
        - a message box,
        - a button that clears the transcript.

    Submitting the message box calls ``respond`` with
    (message, chat history, personality text).

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks() as demo:
        gr.Markdown("### Personality-Based Chatbot")
        # The UI offers no file upload, so the label no longer mentions one.
        personality_textbox = gr.Textbox(
            label="Define Personality Text",
            placeholder="Type personality description or paste a sample text here.",
        )
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="User Input", placeholder="Say something to the chatbot...")
        clear = gr.Button("Clear Chat")
        # respond returns the updated history twice, one value per output slot.
        msg.submit(respond, [msg, chatbot, personality_textbox], [chatbot, chatbot])
        # Reset the transcript to an empty history.
        clear.click(lambda: ([], []), None, [chatbot, chatbot])
    return demo
if __name__ == "__main__":
    # Script entry point (e.g. `python app.py`): build the UI, then serve it.
    # The result is bound to the module-level name `demo`.
    demo = personality_demo()
    demo.launch()