luohoa97 committed on
Commit
d1cb6f6
·
verified ·
1 Parent(s): 78be6a9

Create chat.py

Browse files
Files changed (1) hide show
  1. chat.py +118 -0
chat.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+
7
# Shared model/tokenizer handles; both remain None until load_model() assigns them.
model, tokenizer = None, None

# Maps an exact user prompt to a canned reply that generate_response()
# returns instead of sampling from the model.
user_instructions = dict()
13
+
14
+ # Dummy dataset class for user feedback
15
class FeedbackDataset(Dataset):
    """Prompt/correction pairs formatted for causal-LM fine-tuning.

    Each item is the prompt tokens followed by the corrected-response tokens.
    ``labels`` has exactly the same length as ``input_ids``; prompt positions
    are set to ``-100`` so the loss is computed only on the corrected
    response. (Encoding prompt and target separately, as before, produced
    ``input_ids`` and ``labels`` of different lengths, which a causal-LM loss
    cannot consume.)

    Args:
        input_texts: Sequence of prompt strings.
        target_texts: Sequence of desired responses, parallel to
            ``input_texts``.
        tokenizer: Optional tokenizer exposing ``encode``. When omitted, the
            module-level ``tokenizer`` is looked up at item-access time.
    """

    # Label value skipped by the loss in Hugging Face causal-LM models.
    IGNORE_INDEX = -100

    def __init__(self, input_texts, target_texts, tokenizer=None):
        self.input_texts = input_texts
        self.target_texts = target_texts
        self._tokenizer = tokenizer

    def __len__(self):
        """Return the number of prompt/response pairs."""
        return len(self.input_texts)

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` as equal-length 1-D long tensors.

        Raises:
            RuntimeError: If no tokenizer was supplied and none is loaded.
        """
        tok = self._tokenizer if self._tokenizer is not None else tokenizer
        if tok is None:
            raise RuntimeError("Tokenizer is not loaded. Please load a model first.")

        prompt_ids = list(tok.encode(self.input_texts[idx]))
        # The target continues the prompt, so it must not get its own
        # leading special tokens.
        target_ids = list(tok.encode(self.target_texts[idx], add_special_tokens=False))

        # Build 1-D tensors directly: .squeeze() on a one-token encoding
        # would collapse it to a 0-d tensor.
        input_ids = torch.tensor(prompt_ids + target_ids, dtype=torch.long)
        labels = torch.tensor(
            [self.IGNORE_INDEX] * len(prompt_ids) + target_ids,
            dtype=torch.long,
        )
        return {"input_ids": input_ids, "labels": labels}
27
+
28
def load_model(model_name_or_path):
    """Load a causal-LM and its tokenizer into the module-level globals.

    Args:
        model_name_or_path: Hub model id or local directory accepted by
            ``from_pretrained``.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` or
        ``AutoModelForCausalLM.from_pretrained`` raises. In that case the
        previously loaded ``model``/``tokenizer`` pair is left untouched.
    """
    global model, tokenizer

    st.write(f"Loading model from {model_name_or_path}...")

    # Load into locals first and publish both together, so a failure while
    # loading the model cannot leave a new tokenizer paired with an old
    # (or missing) model.
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    tokenizer, model = loaded_tokenizer, loaded_model

    st.success("Model loaded successfully!")
38
+
39
def generate_response(input_text):
    """Return the reply for ``input_text``.

    A canned reply from ``user_instructions`` takes precedence; otherwise a
    continuation is sampled from the loaded model.

    Args:
        input_text: The user's message.

    Returns:
        The reply text, without the prompt echoed back. An empty string when
        no model is loaded or the prompt encodes to zero tokens.
    """
    # Ensure model and tokenizer are loaded
    if model is None or tokenizer is None:
        st.error("Model is not loaded. Please load a model first.")
        return ""

    # A user-defined response overrides the model
    if input_text in user_instructions:
        return user_instructions[input_text]

    # Keep the prompt on the model's device; after fine-tuning the model
    # may no longer be on the CPU.
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
    prompt_length = input_ids.shape[-1]
    if prompt_length == 0:
        # generate() cannot continue from an empty prompt.
        return ""

    # Single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Models such as GPT-2 define no pad token; reuse EOS.
        pad_token_id = tokenizer.eos_token_id

    # Training leaves dropout enabled; switch it off for inference.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            # max_new_tokens budgets only the reply; max_length also counted
            # the prompt, so long prompts left no room to generate.
            max_new_tokens=100,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=pad_token_id,
        )

    # The output starts with the prompt tokens; decode only what was added.
    reply_ids = outputs[0][prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
61
+
62
def train_on_feedback(input_text, correct_response):
    """Fine-tune the loaded model for one epoch on a single corrected example.

    Args:
        input_text: The prompt the model answered poorly.
        correct_response: The reply the user wanted instead.

    Raises:
        ValueError: If either text is empty or whitespace-only.
        RuntimeError: If no model has been loaded yet.
    """
    # Reject empty examples up front; they would tokenize to nothing.
    if not input_text or not input_text.strip():
        raise ValueError("input_text must be a non-empty string.")
    if not correct_response or not correct_response.strip():
        raise ValueError("correct_response must be a non-empty string.")
    if model is None:
        raise RuntimeError("Model is not loaded. Please load a model first.")

    # Prepare dataset
    dataset = FeedbackDataset([input_text], [correct_response])

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./feedback_model",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        learning_rate=1e-5,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=100,
    )

    # Trainer for the feedback loop
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )

    try:
        trainer.train()
    finally:
        # Training switches the model to train mode; restore eval mode so
        # later generation does not run with dropout enabled.
        model.eval()
86
+
87
def chat_interface():
    """Render the Streamlit chat page: model loader, chat box and feedback form.

    Streamlit re-executes the whole script on every widget interaction, which
    resets module-level globals and makes ``st.button`` return ``True`` only
    for the run triggered by that click. The loaded model/tokenizer and the
    last exchange are therefore kept in ``st.session_state`` so they survive
    reruns.
    """
    global model, tokenizer

    # Restore the model loaded on an earlier run of the script.
    if "model" in st.session_state and "tokenizer" in st.session_state:
        model = st.session_state["model"]
        tokenizer = st.session_state["tokenizer"]

    st.title("🤖 Chat with AI")

    # Input for model name or path
    model_name_or_path = st.text_input("Enter model name or local path:", "gpt2")

    # Button to load the model
    if st.button("Load Model"):
        load_model(model_name_or_path)
        st.session_state["model"] = model
        st.session_state["tokenizer"] = tokenizer

    st.write("---")

    # Chat input
    input_text = st.text_input("You:")

    if st.button("Send"):
        # Remember the exchange so the reply and the feedback form remain
        # visible on the reruns triggered by the feedback widgets.
        st.session_state["last_prompt"] = input_text
        st.session_state["last_response"] = generate_response(input_text)

    if "last_response" not in st.session_state:
        return

    last_prompt = st.session_state.get("last_prompt", input_text)
    st.write("AI:", st.session_state["last_response"])

    # Feedback section
    feedback = st.radio("Was this response helpful?", ("Yes", "No"))

    if feedback == "No":
        correct_response = st.text_input("What should the AI have said?")
        if st.button("Submit Feedback"):
            if not correct_response.strip():
                st.warning("Please enter the response the AI should have given.")
            else:
                # Train on the prompt that produced the reply, not on
                # whatever is currently typed in the chat box.
                train_on_feedback(last_prompt, correct_response)
                st.success("Feedback recorded. AI will improve based on this feedback.")
115
+
116
# Script entry point: build the chat page when this file is run directly.
if __name__ == "__main__":
    chat_interface()