learning4 committed on
Commit bd6ab92 · verified · parent: 9905320

Update app.py

Files changed (1):
  app.py +208 -205
app.py CHANGED
@@ -1,205 +1,208 @@
-import logging
-from flask import Flask, request, render_template, send_file
-import pandas as pd
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
-import torch
-import os
-from datetime import datetime
-from datasets import load_dataset
-from huggingface_hub import login
-
-# Load Hugging Face token from environment variable
-HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
-
-# Authenticate with Hugging Face
-if HUGGING_FACE_TOKEN:
-    login(token=HUGGING_FACE_TOKEN)
-else:
-    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
-
-# Initialize the Flask application
-app = Flask(__name__)
-
-# Set up the device (CUDA or CPU)
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-# Optional: Set up logging for debugging
-logging.basicConfig(level=logging.DEBUG)
-
-# Define a function to classify user persona based on the selected model
-def classify_persona(text, model, tokenizer):
-    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
-    outputs = model(**inputs)
-    logits = outputs.logits
-
-    # Convert logits to probabilities
-    probabilities = torch.nn.functional.softmax(logits, dim=1)
-
-    # Print logits and probabilities for debugging
-    print(f"Logits: {logits}")
-    print(f"Probabilities: {probabilities}")
-
-    # Get the predicted classes
-    predictions = torch.argmax(probabilities, dim=1)
-
-    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}
-
-    # If there are multiple predictions, return the first one (or handle them as needed)
-    predicted_personas = [persona_mapping.get(pred.item(), 'Unknown') for pred in predictions]
-
-    # For now, let's assume you want the first prediction
-    return predicted_personas[0]
-
-# Define the function to determine if a message is polarized
-def is_polarized(message):
-    # If message is a list, join it into a single string
-    if isinstance(message, list):
-        message = ' '.join(message)
-
-    polarized_keywords = ["always", "never", "everyone", "nobody", "worst", "best"]
-    return any(keyword in message.lower() for keyword in polarized_keywords)
-
-
-# Define the function to generate AI-based nudges using the selected transformer model
-def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
-    # Ensure min_length is less than or equal to max_length
-    min_length = min(min_length, max_length)
-
-    if model_type == "seq2seq":
-        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
-        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
-        generated_ids = model.generate(
-            inputs['input_ids'],
-            max_length=max_length,
-            min_length=min_length,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            do_sample=True,
-            num_beams=4,
-            early_stopping=True
-        )
-        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-    elif model_type == "causal":
-        prompt = f"{message} [AI Nudge]:"
-        inputs = tokenizer(prompt, return_tensors='pt').to(device)
-        generated_ids = model.generate(
-            inputs['input_ids'],
-            max_length=max_length,
-            min_length=min_length,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            do_sample=True,
-        )
-        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-    else:
-        nudge = "This model is not suitable for generating text."
-
-    return nudge
-
-
-@app.route('/', methods=['GET', 'POST'])
-def home():
-    logging.debug("Home route accessed.")
-    if request.method == 'POST':
-        logging.debug("POST request received.")
-        try:
-            # Get the model names from the form
-            persona_model_name = request.form.get('persona_model_name', 'roberta-base')
-            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
-            logging.debug(f"Selected persona model: {persona_model_name}")
-            logging.debug(f"Selected nudge model: {nudge_model_name}")
-
-            # Load persona classification model
-            persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
-            persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)
-
-            # Load nudge generation model
-            if "bart" in nudge_model_name or "t5" in nudge_model_name:
-                model_type = "seq2seq"
-                nudge_model = AutoModelForSeq2SeqLM.from_pretrained(nudge_model_name).to(device)
-            elif "gpt2" in nudge_model_name:
-                model_type = "causal"
-                nudge_model = AutoModelForCausalLM.from_pretrained(nudge_model_name).to(device)
-            else:
-                logging.error("Unsupported model selected.")
-                return "Selected model is not supported for text generation tasks.", 400
-
-            nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
-            logging.debug("Models and tokenizers loaded.")
-
-            use_online_dataset = request.form.get('use_online_dataset') == 'yes'
-
-            if use_online_dataset:
-                # Attempt to load the specified online dataset
-                dataset_name = request.form.get('dataset_name')
-                logging.debug(f"Selected online dataset: {dataset_name}")
-
-                if dataset_name == 'personachat':
-                    # Use AlekseyKorshuk/persona-chat if 'personachat' is selected
-                    dataset_name = 'AlekseyKorshuk/persona-chat'
-
-                dataset = load_dataset(dataset_name)
-                df = pd.DataFrame(dataset['train'])  # Use the training split for processing
-                df = df.rename(columns=lambda x: x.strip().lower())
-                df = df[['utterances', 'personality']]  # Modify this according to the dataset structure
-                df.columns = ['topic', 'post_reply']  # Standardize column names for processing
-
-            else:
-                uploaded_file = request.files['file']
-                if uploaded_file.filename != '':
-                    logging.debug(f"File uploaded: {uploaded_file.filename}")
-
-                    df = pd.read_csv(uploaded_file)
-                    df.columns = df.columns.str.strip().str.lower()
-
-                    if 'post_reply' not in df.columns:
-                        logging.error("Required column 'post_reply' is missing in the CSV.")
-                        return "The uploaded CSV file must contain 'post_reply' column.", 400
-
-            augmented_rows = []
-            for _, row in df.iterrows():
-                if 'user_persona' not in row or pd.isna(row['user_persona']):
-                    # Classify user persona if not provided
-                    row['user_persona'] = classify_persona(row['post_reply'], persona_model, persona_tokenizer)
-                augmented_rows.append(row.to_dict())
-
-                if is_polarized(row['post_reply']):
-                    nudge = generate_nudge(row['post_reply'], row['user_persona'], row['topic'], nudge_model, nudge_tokenizer, model_type)
-                    augmented_rows.append({
-                        'topic': row['topic'],
-                        'user_persona': 'AI Nudge',
-                        'post_reply': nudge
-                    })
-
-            augmented_df = pd.DataFrame(augmented_rows)
-            logging.debug("Processing completed.")
-
-            # Generate the output filename
-            persona_model_name = request.form.get('persona_model_name', 'roberta-base').split('/')[-1].replace('-', '_')
-            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn').split('/')[-1].replace('-', '_')
-            current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
-            output_filename = f"DepolNudge_{persona_model_name}_{nudge_model_name}_{current_time}.csv"
-
-            # Instead of saving to a directory, create the CSV in memory
-            csv_buffer = io.BytesIO()
-            augmented_df.to_csv(csv_buffer, index=False)
-            csv_buffer.seek(0)  # Reset buffer position to the start
-
-            # Directly send the file for download without saving to a specific folder
-            return send_file(
-                csv_buffer,
-                as_attachment=True,
-                download_name=output_filename,
-                mimetype='text/csv'
-            )
-        except Exception as e:
-            logging.error(f"Error processing the request: {e}", exc_info=True)
-            return "There was an error processing your request.", 500

-    logging.debug("Rendering index.html")
-    return render_template('index.html')
-
-if __name__ == '__main__':
-    app.run(debug=True)
+import logging
+from flask import Flask, request, render_template, send_file
+import pandas as pd
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+import torch
+import os
+from datetime import datetime
+from datasets import load_dataset
+from huggingface_hub import login
+
+
+# Load Hugging Face token from environment variable
+HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
+
+# Authenticate with Hugging Face
+if HUGGING_FACE_TOKEN:
+    login(token=HUGGING_FACE_TOKEN)
+else:
+    raise ValueError("Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable.")
+
+
+
+# Initialize the Flask application
+app = Flask(__name__)
+
+# Set up the device (CUDA or CPU)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Optional: Set up logging for debugging
+logging.basicConfig(level=logging.DEBUG)
+
+# Define a function to classify user persona based on the selected model
+def classify_persona(text, model, tokenizer):
+    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
+    outputs = model(**inputs)
+    logits = outputs.logits
+
+    # Convert logits to probabilities
+    probabilities = torch.nn.functional.softmax(logits, dim=1)
+
+    # Print logits and probabilities for debugging
+    print(f"Logits: {logits}")
+    print(f"Probabilities: {probabilities}")
+
+    # Get the predicted classes
+    predictions = torch.argmax(probabilities, dim=1)
+
+    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}
+
+    # If there are multiple predictions, return the first one (or handle them as needed)
+    predicted_personas = [persona_mapping.get(pred.item(), 'Unknown') for pred in predictions]
+
+    # For now, let's assume you want the first prediction
+    return predicted_personas[0]
+
+# Define the function to determine if a message is polarized
+def is_polarized(message):
+    # If message is a list, join it into a single string
+    if isinstance(message, list):
+        message = ' '.join(message)
+
+    polarized_keywords = ["always", "never", "everyone", "nobody", "worst", "best"]
+    return any(keyword in message.lower() for keyword in polarized_keywords)
+
+
+# Define the function to generate AI-based nudges using the selected transformer model
+def generate_nudge(message, persona, topic, model, tokenizer, model_type, max_length=50, min_length=30, temperature=0.7, top_p=0.9, repetition_penalty=1.1):
+    # Ensure min_length is less than or equal to max_length
+    min_length = min(min_length, max_length)
+
+    if model_type == "seq2seq":
+        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
+        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
+        generated_ids = model.generate(
+            inputs['input_ids'],
+            max_length=max_length,
+            min_length=min_length,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            do_sample=True,
+            num_beams=4,
+            early_stopping=True
+        )
+        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    elif model_type == "causal":
+        prompt = f"{message} [AI Nudge]:"
+        inputs = tokenizer(prompt, return_tensors='pt').to(device)
+        generated_ids = model.generate(
+            inputs['input_ids'],
+            max_length=max_length,
+            min_length=min_length,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            do_sample=True,
+        )
+        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    else:
+        nudge = "This model is not suitable for generating text."
+
+    return nudge
+
+
+@app.route('/', methods=['GET', 'POST'])
+def home():
+    logging.debug("Home route accessed.")
+    if request.method == 'POST':
+        logging.debug("POST request received.")
+        try:
+            # Get the model names from the form
+            persona_model_name = request.form.get('persona_model_name', 'roberta-base')
+            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
+            logging.debug(f"Selected persona model: {persona_model_name}")
+            logging.debug(f"Selected nudge model: {nudge_model_name}")
+
+            # Load persona classification model
+            persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
+            persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)
+
+            # Load nudge generation model
+            if "bart" in nudge_model_name or "t5" in nudge_model_name:
+                model_type = "seq2seq"
+                nudge_model = AutoModelForSeq2SeqLM.from_pretrained(nudge_model_name).to(device)
+            elif "gpt2" in nudge_model_name:
+                model_type = "causal"
+                nudge_model = AutoModelForCausalLM.from_pretrained(nudge_model_name).to(device)
+            else:
+                logging.error("Unsupported model selected.")
+                return "Selected model is not supported for text generation tasks.", 400
+
+            nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
+            logging.debug("Models and tokenizers loaded.")
+
+            use_online_dataset = request.form.get('use_online_dataset') == 'yes'
+
+            if use_online_dataset:
+                # Attempt to load the specified online dataset
+                dataset_name = request.form.get('dataset_name')
+                logging.debug(f"Selected online dataset: {dataset_name}")
+
+                if dataset_name == 'personachat':
+                    # Use AlekseyKorshuk/persona-chat if 'personachat' is selected
+                    dataset_name = 'AlekseyKorshuk/persona-chat'
+
+                dataset = load_dataset(dataset_name)
+                df = pd.DataFrame(dataset['train'])  # Use the training split for processing
+                df = df.rename(columns=lambda x: x.strip().lower())
+                df = df[['utterances', 'personality']]  # Modify this according to the dataset structure
+                df.columns = ['topic', 'post_reply']  # Standardize column names for processing
+
+            else:
+                uploaded_file = request.files['file']
+                if uploaded_file.filename != '':
+                    logging.debug(f"File uploaded: {uploaded_file.filename}")
+
+                    df = pd.read_csv(uploaded_file)
+                    df.columns = df.columns.str.strip().str.lower()
+
+                    if 'post_reply' not in df.columns:
+                        logging.error("Required column 'post_reply' is missing in the CSV.")
+                        return "The uploaded CSV file must contain 'post_reply' column.", 400
+
+            augmented_rows = []
+            for _, row in df.iterrows():
+                if 'user_persona' not in row or pd.isna(row['user_persona']):
+                    # Classify user persona if not provided
+                    row['user_persona'] = classify_persona(row['post_reply'], persona_model, persona_tokenizer)
+                augmented_rows.append(row.to_dict())
+
+                if is_polarized(row['post_reply']):
+                    nudge = generate_nudge(row['post_reply'], row['user_persona'], row['topic'], nudge_model, nudge_tokenizer, model_type)
+                    augmented_rows.append({
+                        'topic': row['topic'],
+                        'user_persona': 'AI Nudge',
+                        'post_reply': nudge
+                    })
+
+            augmented_df = pd.DataFrame(augmented_rows)
+            logging.debug("Processing completed.")
+
+            # Generate the output filename
+            persona_model_name = request.form.get('persona_model_name', 'roberta-base').split('/')[-1].replace('-', '_')
+            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn').split('/')[-1].replace('-', '_')
+            current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_filename = f"DepolNudge_{persona_model_name}_{nudge_model_name}_{current_time}.csv"
+
+            # Instead of saving to a directory, create the CSV in memory
+            csv_buffer = io.BytesIO()
+            augmented_df.to_csv(csv_buffer, index=False)
+            csv_buffer.seek(0)  # Reset buffer position to the start
+
+            # Directly send the file for download without saving to a specific folder
+            return send_file(
+                csv_buffer,
+                as_attachment=True,
+                download_name=output_filename,
+                mimetype='text/csv'
+            )
+        except Exception as e:
+            logging.error(f"Error processing the request: {e}", exc_info=True)
+            return "There was an error processing your request.", 500

+    logging.debug("Rendering index.html")
+    return render_template('index.html')
+
+if __name__ == '__main__':
+    app.run(debug=True)
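
Review note: both sides of this diff build the in-memory CSV with io.BytesIO() inside home(), but io is never imported, so even an otherwise-successful POST raises NameError and is converted into the generic 500 response by the broad except. A minimal sketch of the fix, assuming everything else stays as committed; the block below is just the committed import list reordered PEP 8-style, with import io added:

import io  # missing in both versions; home() calls io.BytesIO()
import logging
import os
from datetime import datetime

import pandas as pd
import torch
from datasets import load_dataset
from flask import Flask, request, render_template, send_file
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

With that import in place the existing buffer logic works as written: io.BytesIO() gives send_file a binary file-like object, and csv_buffer.seek(0) rewinds it before streaming.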
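
A second weak spot in both versions: the upload branch validates only the post_reply column. If the form is submitted with no file selected, df is never assigned and the loop below raises NameError; and the nudge rows read row['topic'], a column the uploaded CSV is never checked for. A hypothetical hardening of that branch, reusing the committed names (the error messages are illustrative, not from the commit):

uploaded_file = request.files.get('file')
if uploaded_file is None or uploaded_file.filename == '':
    # Fail fast instead of letting df stay undefined
    return "Please upload a CSV file or select an online dataset.", 400

df = pd.read_csv(uploaded_file)
df.columns = df.columns.str.strip().str.lower()

# Both columns are read later: post_reply by the classifier, topic by the nudge rows
missing = [col for col in ('post_reply', 'topic') if col not in df.columns]
if missing:
    logging.error(f"Required column(s) missing in the CSV: {missing}")
    return f"The uploaded CSV file must contain the column(s): {', '.join(missing)}.", 400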
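
Finally, a hypothetical local smoke test for the endpoint using Flask's built-in test client. It assumes the import-io fix above has been applied, HUGGING_FACE_TOKEN is set in the environment, a templates/index.html exists, and sample.csv has post_reply and topic columns; none of those file names come from this commit.

# Hypothetical smoke test; run from the directory containing app.py
from app import app  # importing app.py triggers the token check, so set HUGGING_FACE_TOKEN first

with app.test_client() as client:
    with open('sample.csv', 'rb') as f:
        resp = client.post(
            '/',
            data={
                'persona_model_name': 'roberta-base',
                'nudge_model_name': 'facebook/bart-large-cnn',
                'use_online_dataset': 'no',
                'file': (f, 'sample.csv'),  # (file object, filename) tuple for multipart upload
            },
            content_type='multipart/form-data',
        )
    print(resp.status_code)  # expect 200 on success
    with open('nudges.csv', 'wb') as out:
        out.write(resp.data)  # the generated DepolNudge_*.csv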