pquintero committed
Commit b2a1e67 · 1 Parent(s): 845443f

validate cv

Files changed (4)
  1. constants.py +11 -1
  2. data/example-predictions-cv.csv +0 -0
  3. submit.py +7 -2
  4. validation.py +105 -6
constants.py CHANGED

```diff
@@ -36,7 +36,17 @@ REQUIRED_COLUMNS: list[str] = [
     "vh_protein_sequence",
     "vl_protein_sequence",
 ]
-ANTIBODY_NAMES = pd.read_csv("data/example-predictions.csv")["antibody_name"].tolist()
+# Cross validation
+CV_COLUMN = "hierarchical_cluster_IgG_isotype_stratified_fold"
+# Example files
+EXAMPLE_FILE_DICT = {
+    "GDPa1": "data/example-predictions.csv",
+    "GDPa1_CV": "data/example-predictions-cv.csv",
+}
+ANTIBODY_NAMES_DICT = {
+    "GDPa1": pd.read_csv(EXAMPLE_FILE_DICT["GDPa1"])["antibody_name"].tolist(),
+    "GDPa1_CV": pd.read_csv(EXAMPLE_FILE_DICT["GDPa1_CV"])["antibody_name"].tolist(),
+}
 
 # Huggingface API
 TOKEN = os.environ.get("HF_TOKEN")
```
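The new constants key the example files and canonical antibody lists by submission type. A minimal lookup sketch, assuming the module is imported from the repository root where `data/` lives; the `"GDPa1_CV"` value below is just an example, the real value comes from the submit UI:

```python
# Minimal sketch: look up the resources for a given submission type.
from constants import CV_COLUMN, EXAMPLE_FILE_DICT, ANTIBODY_NAMES_DICT

submission_type = "GDPa1_CV"                                  # example value (assumed)
example_path = EXAMPLE_FILE_DICT[submission_type]             # "data/example-predictions-cv.csv"
known_antibodies = set(ANTIBODY_NAMES_DICT[submission_type])  # names a CV submission may contain
print(CV_COLUMN, example_path, len(known_antibodies))
```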
data/example-predictions-cv.csv ADDED
The diff for this file is too large to render. See raw diff
 
submit.py CHANGED

```diff
@@ -11,7 +11,12 @@ from constants import API, SUBMISSIONS_REPO
 from validation import validate_csv_file
 
 
-def make_submission(submitted_file: BinaryIO, user_state, anonymous_state):
+def make_submission(
+    submitted_file: BinaryIO,
+    user_state,
+    anonymous_state,
+    submission_type: str = "GDPa1",
+):
     if user_state is None:
         raise gr.Error("You must submit your username to submit a file.")
 
@@ -34,7 +39,7 @@ def make_submission(submitted_file: BinaryIO, user_state, anonymous_state):
     with path_obj.open("rb") as f_in:
         file_content = f_in.read().decode("utf-8")
 
-    validate_csv_file(file_content)
+    validate_csv_file(file_content, submission_type)
 
     # write to dataset
     filename = f"{submission_id}.json"
```
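For reference, a sketch of how the new `submission_type` argument might be wired from the app; the component layout below is an assumption for illustration and is not part of this commit, the only relevant point is that an extra radio value is passed as the fourth input:

```python
import gradio as gr

from submit import make_submission

# Hypothetical UI wiring (not part of this commit): a radio button selects the
# submission type and is forwarded to make_submission as its fourth argument.
with gr.Blocks() as demo:
    file_input = gr.File(label="Predictions CSV")  # exact component config lives in the real app
    user_state = gr.State()                        # populated elsewhere after login (assumed)
    anonymous_state = gr.State(False)
    submission_type = gr.Radio(
        ["GDPa1", "GDPa1_CV"], value="GDPa1", label="Submission type"
    )
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        make_submission,
        inputs=[file_input, user_state, anonymous_state, submission_type],
        outputs=None,  # outputs depend on what make_submission returns (not shown in this diff)
    )
```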
validation.py CHANGED

```diff
@@ -4,8 +4,10 @@ import gradio as gr
 from constants import (
     REQUIRED_COLUMNS,
     MINIMAL_NUMBER_OF_ROWS,
-    ANTIBODY_NAMES,
     ASSAY_LIST,
+    CV_COLUMN,
+    EXAMPLE_FILE_DICT,
+    ANTIBODY_NAMES_DICT,
 )
 
 
@@ -46,7 +48,90 @@ def validate_csv_can_be_read(file_content: str) -> pd.DataFrame:
         raise gr.Error(f"❌ Unexpected error reading CSV file: {str(e)}")
 
 
-def validate_dataframe(df: pd.DataFrame) -> None:
+def validate_cv_submission(df: pd.DataFrame, submission_type: str = "GDPa1_CV") -> None:
+    """Validate a cross-validation submission."""
+    # Must have CV_COLUMN for CV submissions
+    if CV_COLUMN not in df.columns:
+        raise gr.Error(f"❌ CV submissions must include a '{CV_COLUMN}' column")
+
+    # Load canonical fold assignments
+    expected_cv_df = pd.read_csv(EXAMPLE_FILE_DICT[submission_type])[
+        ["antibody_name", CV_COLUMN]
+    ]
+    antibody_check = expected_cv_df.merge(
+        df[["antibody_name", CV_COLUMN]],
+        on="antibody_name",
+        how="left",
+        suffixes=("_expected", "_submitted"),
+    )
+    # All antibodies should be present if using CV
+    missing_antibodies_mask = antibody_check[f"{CV_COLUMN}_submitted"].isna()
+    n_missing_antibodies = missing_antibodies_mask.sum()
+    if n_missing_antibodies > 0:
+        missing_antibodies = (
+            antibody_check[missing_antibodies_mask]["antibody_name"].head(5).tolist()
+        )
+        raise gr.Error(
+            f"❌ Missing predictions for {n_missing_antibodies} antibodies. Examples: {', '.join(missing_antibodies)}"
+        )
+    # CV fold assignments should match the canonical folds
+    fold_mismatches = antibody_check[
+        antibody_check[f"{CV_COLUMN}_expected"]
+        != antibody_check[f"{CV_COLUMN}_submitted"]
+    ]
+    if len(fold_mismatches) > 0:
+        examples = []
+        for _, row in fold_mismatches.head(3).iterrows():
+            examples.append(
+                f"{row['antibody_name']} (expected fold {row[f'{CV_COLUMN}_expected']}, got {row[f'{CV_COLUMN}_submitted']})"
+            )
+        raise gr.Error(
+            f"❌ Fold assignments don't match canonical CV folds: {'; '.join(examples)}"
+        )
+
+    # Merge on both columns for assay validation
+    merged_cv_df = expected_cv_df.merge(df, on=["antibody_name", CV_COLUMN], how="left")
+
+    # Check for missing assay predictions
+    assay_columns = get_assay_columns(merged_cv_df)
+    for assay_column in assay_columns:
+        missing_antibodies = merged_cv_df[merged_cv_df[assay_column].isna()][
+            "antibody_name"
+        ].unique()
+        if len(missing_antibodies) > 0:
+            raise gr.Error(
+                f"❌ Missing {assay_column} predictions for {len(missing_antibodies)} antibodies: {', '.join(missing_antibodies[:5])}"
+            )
+
+    # Check that the submission length matches the expected number of rows
+    if len(merged_cv_df) != len(expected_cv_df):
+        raise gr.Error(
+            f"❌ Expected {len(expected_cv_df)} rows, got {len(merged_cv_df)}"
+        )
+
+
+def validate_full_dataset_submission(df: pd.DataFrame) -> None:
+    """Validate a full-dataset submission."""
+    if CV_COLUMN in df.columns:
+        raise gr.Error(
+            f"❌ Your submission contains a '{CV_COLUMN}' column. "
+            "Please select 'Cross-Validation Predictions' if you want to submit CV results."
+        )
+
+    # All names should be unique (duplicates check from original validation)
+    n_duplicates = df["antibody_name"].duplicated().sum()
+    if n_duplicates > 0:
+        raise gr.Error(
+            f"❌ Standard submissions should have only one prediction per antibody. Found {n_duplicates} duplicates."
+        )
+
+
+def get_assay_columns(df: pd.DataFrame) -> list[str]:
+    """Get all assay columns from the DataFrame."""
+    return [col for col in df.columns if col in ASSAY_LIST]
+
+
+def validate_dataframe(df: pd.DataFrame, submission_type: str = "GDPa1") -> None:
     """
     Validate the DataFrame content and structure.
 
@@ -54,18 +139,23 @@ def validate_dataframe(df: pd.DataFrame) -> None:
     ----------
     df: pd.DataFrame
         The DataFrame to validate.
+    submission_type: str
+        Type of submission: "GDPa1" or "GDPa1_CV"
 
     Raises
     ------
     gr.Error: If validation fails
     """
+    if submission_type not in EXAMPLE_FILE_DICT.keys():
+        raise ValueError(f"Invalid submission type: {submission_type}")
+
     # Required columns should be present
     missing_columns = set(REQUIRED_COLUMNS) - set(df.columns)
     if missing_columns:
         raise gr.Error(f"❌ Missing required columns: {', '.join(missing_columns)}")
 
     # Should include at least 1 assay column
-    assay_columns = [col for col in df.columns if col in ASSAY_LIST]
+    assay_columns = get_assay_columns(df)
     if len(assay_columns) < 1:
         raise gr.Error(
             "❌ CSV should include at least one of the following assay columns: "
@@ -96,14 +186,21 @@ def validate_dataframe(df: pd.DataFrame) -> None:
         )
 
     # All antibody names should be recognizable
-    unrecognized_antibodies = set(df["antibody_name"]) - set(ANTIBODY_NAMES)
+    unrecognized_antibodies = set(df["antibody_name"]) - set(
+        ANTIBODY_NAMES_DICT[submission_type]
+    )
     if unrecognized_antibodies:
         raise gr.Error(
             f"❌ Found unrecognized antibody names: {', '.join(unrecognized_antibodies)}"
         )
+    # Submission-type specific validation
+    if submission_type.endswith("_CV"):
+        validate_cv_submission(df, submission_type)
+    else:  # full dataset submission
+        validate_full_dataset_submission(df)
 
 
-def validate_csv_file(file_content: str) -> None:
+def validate_csv_file(file_content: str, submission_type: str = "GDPa1") -> None:
     """
     Validate the uploaded CSV file.
 
@@ -111,10 +208,12 @@ def validate_csv_file(file_content: str) -> None:
     ----------
     file_content: str
        The content of the uploaded CSV file.
+    submission_type: str
+        Type of submission: "GDPa1" or "GDPa1_CV"
 
     Raises
     ------
     gr.Error: If validation fails
     """
     df = validate_csv_can_be_read(file_content)
-    validate_dataframe(df)
+    validate_dataframe(df, submission_type)
```
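The core of `validate_cv_submission` is a left merge of the submitted predictions onto the canonical fold table, followed by per-antibody checks for missing rows and fold mismatches. A self-contained toy illustration of that pattern; the antibody names and folds below are invented, not taken from the dataset:

```python
import pandas as pd

CV_COLUMN = "hierarchical_cluster_IgG_isotype_stratified_fold"

# Toy canonical folds (stand-in for data/example-predictions-cv.csv).
expected = pd.DataFrame({"antibody_name": ["ab1", "ab2", "ab3"], CV_COLUMN: [0, 1, 2]})
# Toy submission: ab3 is missing entirely, ab2 is assigned to the wrong fold.
submitted = pd.DataFrame({"antibody_name": ["ab1", "ab2"], CV_COLUMN: [0, 2]})

check = expected.merge(
    submitted, on="antibody_name", how="left", suffixes=("_expected", "_submitted")
)
missing = check[check[f"{CV_COLUMN}_submitted"].isna()]["antibody_name"].tolist()
mismatched = check[
    check[f"{CV_COLUMN}_expected"] != check[f"{CV_COLUMN}_submitted"]
]["antibody_name"].tolist()

print("missing:", missing)        # ['ab3']
print("mismatched:", mismatched)  # ['ab2', 'ab3'] (NaN never equals the expected fold)
```

In the real validator the missing-antibody check raises before the mismatch check runs, so a missing antibody is reported as missing rather than as a fold mismatch.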