MoraxCheng committed on
Commit
7c1376e
·
1 Parent(s): 9fff9fd

Refactor scoring matrix visualization to simplify font and figure size adjustments, ensuring consistent label formatting and improved clarity.

Browse files
Files changed (1) hide show
  1. app.py +13 -51
app.py CHANGED
@@ -145,21 +145,9 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
145
  filtered_scores=filtered_scores[filtered_scores.position.isin(range(mutation_range_start,mutation_range_end+1))]
146
  piv=filtered_scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
147
 
148
- # Calculate dynamic font size based on matrix dimensions
149
  mutation_range_len = mutation_range_end - mutation_range_start + 1
150
 
151
- # Adjust font size based on number of positions
152
- if mutation_range_len > 30:
153
- fontsize = 8
154
- elif mutation_range_len > 20:
155
- fontsize = 10
156
- elif mutation_range_len > 15:
157
- fontsize = 12
158
- elif mutation_range_len > 10:
159
- fontsize = 14
160
- else:
161
- fontsize = 16
162
-
163
  # Save CSV file
164
  csv_path = 'fitness_scoring_substitution_matrix_{}_{}.csv'.format(unique_id, image_index)
165
 
@@ -189,22 +177,8 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
189
  csv_df.to_csv(csv_path, index=False)
190
 
191
  # Continue with visualization
192
- # Adjust figure size based on content
193
- if mutation_range_len <= 10:
194
- fig_width = max(12, len(AA_vocab) * 0.6)
195
- fig_height = max(8, mutation_range_len * 0.8)
196
- elif mutation_range_len <= 20:
197
- fig_width = max(14, len(AA_vocab) * 0.5)
198
- fig_height = max(10, mutation_range_len * 0.6)
199
- else:
200
- fig_width = max(16, len(AA_vocab) * 0.4)
201
- fig_height = max(12, mutation_range_len * 0.5)
202
-
203
- # Limit maximum size
204
- fig_width = min(fig_width, 30)
205
- fig_height = min(fig_height, 40)
206
-
207
- _, ax = plt.subplots(figsize=(fig_width, fig_height))
208
  scores_dict = {}
209
  valid_mutant_set=set(filtered_scores.mutant)
210
  ax.tick_params(bottom=True, top=True, left=True, right=True)
@@ -221,42 +195,30 @@ def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_st
221
  scores_dict[mutant] = float(score_value)
222
  else:
223
  scores_dict[mutant]=0.0
224
- # Format labels based on available space
225
- if fontsize <= 10:
226
- # For small fonts, show only score
227
- labels = (np.asarray(["{:.2f}".format(value) for _, value in scores_dict.items() ])).reshape(mutation_range_len,len(AA_vocab))
228
- else:
229
- # For larger fonts, show mutation and score
230
- labels = (np.asarray(["{} \n{:.3f}".format(symb,value) for symb, value in scores_dict.items() ])).reshape(mutation_range_len,len(AA_vocab))
231
 
232
  heat = sns.heatmap(piv,annot=labels,fmt="",cmap='RdYlGn',linewidths=0.30,ax=ax,vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),\
233
  cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
234
  else:
235
  heat = sns.heatmap(piv,cmap='RdYlGn',linewidths=0.30,ax=ax,vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),\
236
  cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
237
- # Adjust label sizes proportionally
238
- cbar_label_size = max(10, fontsize * 1.2)
239
- title_size = max(14, fontsize * 1.5)
240
- axis_label_size = max(12, fontsize * 1.3)
241
-
242
- heat.figure.axes[-1].yaxis.label.set_size(fontsize=cbar_label_size)
243
- heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=title_size, pad=20)
244
- heat.set_ylabel("Sequence position", fontsize = axis_label_size)
245
- heat.set_xlabel("Amino Acid mutation", fontsize = axis_label_size)
246
 
247
  # Set y-axis labels (positions)
248
- tick_label_size = max(8, fontsize * 0.8)
249
  yticklabels = [str(pos)+' ('+sequence[pos-1]+')' for pos in range(mutation_range_start,mutation_range_end+1)]
250
- heat.set_yticklabels(yticklabels, fontsize=tick_label_size, rotation=0)
251
 
252
  # Set x-axis labels (amino acids) - ensuring correct number
253
- heat.set_xticklabels(list(AA_vocab), fontsize=tick_label_size)
254
  try:
255
  plt.tight_layout()
256
  image_path = 'fitness_scoring_substitution_matrix_{}_{}.png'.format(unique_id, image_index)
257
- # Increase DPI for better quality when font is small
258
- dpi = 150 if fontsize <= 10 else 100
259
- plt.savefig(image_path, dpi=dpi, bbox_inches='tight')
260
  return image_path, csv_path
261
  finally:
262
  plt.close('all') # Ensure all figures are closed
 
145
  filtered_scores=filtered_scores[filtered_scores.position.isin(range(mutation_range_start,mutation_range_end+1))]
146
  piv=filtered_scores.pivot(index='position',columns='target_AA',values='avg_score').round(4)
147
 
148
+ # Calculate mutation range length
149
  mutation_range_len = mutation_range_end - mutation_range_start + 1
150
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  # Save CSV file
152
  csv_path = 'fitness_scoring_substitution_matrix_{}_{}.csv'.format(unique_id, image_index)
153
 
 
177
  csv_df.to_csv(csv_path, index=False)
178
 
179
  # Continue with visualization
180
+ # Use large fixed width for clarity, height scales with positions (as in reference)
181
+ fig, ax = plt.subplots(figsize=(50, mutation_range_len))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  scores_dict = {}
183
  valid_mutant_set=set(filtered_scores.mutant)
184
  ax.tick_params(bottom=True, top=True, left=True, right=True)
 
195
  scores_dict[mutant] = float(score_value)
196
  else:
197
  scores_dict[mutant]=0.0
198
+ # Format labels as in reference - always show mutation and score with 4 decimal places
199
+ labels = (np.asarray(["{} \n {:.4f}".format(symb,value) for symb, value in scores_dict.items() ])).reshape(mutation_range_len,len(AA_vocab))
 
 
 
 
 
200
 
201
  heat = sns.heatmap(piv,annot=labels,fmt="",cmap='RdYlGn',linewidths=0.30,ax=ax,vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),\
202
  cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
203
  else:
204
  heat = sns.heatmap(piv,cmap='RdYlGn',linewidths=0.30,ax=ax,vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),\
205
  cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},annot_kws={"size": fontsize})
206
+ # Use label sizes from reference
207
+ heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
208
+ heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
209
+ heat.set_ylabel("Sequence position", fontsize = fontsize*2)
210
+ heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)
 
 
 
 
211
 
212
  # Set y-axis labels (positions)
 
213
  yticklabels = [str(pos)+' ('+sequence[pos-1]+')' for pos in range(mutation_range_start,mutation_range_end+1)]
214
+ heat.set_yticklabels(yticklabels, fontsize=fontsize, rotation=0)
215
 
216
  # Set x-axis labels (amino acids) - ensuring correct number
217
+ heat.set_xticklabels(list(AA_vocab), fontsize=fontsize)
218
  try:
219
  plt.tight_layout()
220
  image_path = 'fitness_scoring_substitution_matrix_{}_{}.png'.format(unique_id, image_index)
221
+ plt.savefig(image_path, dpi=100)
 
 
222
  return image_path, csv_path
223
  finally:
224
  plt.close('all') # Ensure all figures are closed