yakine committed on
Commit
65871c9
·
verified ·
1 Parent(s): dbb92e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -26
app.py CHANGED
@@ -87,63 +87,59 @@ generation_params = {
87
  def generate_synthetic_data(description, columns):
88
  formatted_prompt = format_prompt(description, columns)
89
  payload = {"inputs": formatted_prompt, "parameters": generation_params}
90
- response = requests.post(API_URL, headers={"Authorization": f"Bearer {token}"}, json=payload)
91
- response_data = response.json()
 
 
 
 
92
 
93
  if 'error' in response_data:
94
- return f"Error: {response_data['error']}"
 
 
 
95
 
96
  return response_data[0]["generated_text"]
97
 
98
  def process_generated_data(csv_data, expected_columns):
99
  try:
100
- # Ensure the data is cleaned and correctly formatted
101
  cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
102
  data = StringIO(cleaned_data)
103
-
104
- # Read the CSV data
105
  df = pd.read_csv(data, delimiter=',')
106
 
107
- # Check if the DataFrame has the expected columns
108
  if set(df.columns) != set(expected_columns):
109
- return f"Unexpected columns in the generated data: {df.columns}"
110
 
111
  return df
112
  except pd.errors.ParserError as e:
113
- return f"Failed to parse CSV data: {e}"
114
 
115
  def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
116
- csv_data_all = ""
117
 
118
  for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
119
  generated_data = generate_synthetic_data(description, columns)
120
- if "Error" in generated_data:
121
- return generated_data # Return the error message
122
-
123
  df_synthetic = process_generated_data(generated_data, columns)
 
124
  if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
125
- csv_data_all += df_synthetic.to_csv(index=False, header=False)
126
- else:
127
- print("Skipping invalid generation.")
128
-
129
- if csv_data_all:
130
  return csv_data_all
131
  else:
132
- return "No valid data frames to concatenate."
133
 
134
  @app.post("/generate/")
135
  def generate_data(request: DataGenerationRequest):
136
  description = request.description.strip()
137
  columns = [col.strip() for col in request.columns]
138
- generated_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
139
-
140
- if isinstance(generated_data, str) and "Error" in generated_data:
141
- return JSONResponse(content={"error": generated_data}, status_code=500)
142
 
143
- # Create a streaming response to return the CSV data
144
- csv_buffer = StringIO(generated_data)
145
  return StreamingResponse(
146
- csv_buffer,
147
  media_type="text/csv",
148
  headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
149
  )
 
def generate_synthetic_data(description, columns):
    """Request one batch of generated text from the inference API.

    Builds a prompt with ``format_prompt(description, columns)``, POSTs it
    together with ``generation_params`` to ``API_URL`` (bearer-authenticated
    with ``hf_token``) and returns the ``generated_text`` field of the first
    element of the JSON response.

    Args:
        description: Passed through to ``format_prompt``.
        columns: Passed through to ``format_prompt``.

    Returns:
        The value of ``response_data[0]["generated_text"]``.

    Raises:
        HTTPException: (status 500) if the response body is not JSON, if it
            is a JSON object carrying an ``error`` key, or if it does not
            have the ``[{"generated_text": ...}, ...]`` shape.
    """
    payload = {
        "inputs": format_prompt(description, columns),
        "parameters": generation_params,
    }
    response = requests.post(
        API_URL,
        headers={"Authorization": f"Bearer {hf_token}"},
        json=payload,
    )

    try:
        response_data = response.json()
    except ValueError as err:
        raise HTTPException(
            status_code=500,
            detail="Failed to parse response from the API.",
        ) from err

    # Only a JSON object can carry an "error" key; on a list the ``in`` test
    # would be an element-membership check instead.
    if isinstance(response_data, dict) and 'error' in response_data:
        raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")

    # Indexing can fail in several ways (dict without key 0, empty list,
    # null, element without "generated_text"); report all of them uniformly
    # instead of leaking KeyError/IndexError/TypeError to the caller.
    try:
        return response_data[0]["generated_text"]
    except (KeyError, IndexError, TypeError) as err:
        raise HTTPException(
            status_code=500,
            detail="Unexpected API response format.",
        ) from err

def process_generated_data(csv_data, expected_columns):
    """Parse generated CSV text into a DataFrame with the expected columns.

    Line endings are normalised to ``\\n`` before parsing. The parsed header
    must contain exactly the names in ``expected_columns`` (in any order);
    the returned frame is reordered to follow ``expected_columns`` so that
    header-less output from different batches lines up column by column.

    Args:
        csv_data: Comma-separated text whose first row is the header.
        expected_columns: Column names the header must match.

    Returns:
        A ``pandas.DataFrame`` whose columns follow ``expected_columns``.

    Raises:
        ValueError: If the parsed columns differ from ``expected_columns``.
        HTTPException: (status 500) if pandas raises ``ParserError``.
    """
    cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')

    # Keep the try body limited to the call that can raise ParserError.
    try:
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
    except pd.errors.ParserError as e:
        raise HTTPException(status_code=500, detail=f"Failed to parse CSV data: {e}") from e

    if set(df.columns) != set(expected_columns):
        raise ValueError("Unexpected columns in the generated data.")

    # The set comparison ignores order; select explicitly (dropping repeated
    # names) so every batch is returned in the same column order.
    ordered_columns = list(dict.fromkeys(expected_columns))
    return df[ordered_columns]

def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Generate several batches of synthetic rows and collect them as CSV.

    Calls ``generate_synthetic_data`` ``num_rows // rows_per_generation``
    times, parses each result with ``process_generated_data`` and appends
    the rows (no index, no header) to a single in-memory buffer.

    Args:
        description: Forwarded to ``generate_synthetic_data``.
        columns: Expected column names; also fixes the column order written.
        num_rows: Together with ``rows_per_generation`` determines how many
            generation requests are made.
        rows_per_generation: Divisor used to compute the request count.

    Returns:
        An ``io.StringIO`` positioned at the start of the collected CSV text.

    Raises:
        HTTPException: (status 500) if no batch produced any rows; also
            propagated unchanged from the helper functions.
    """
    import logging

    logger = logging.getLogger(__name__)
    csv_data_all = StringIO()

    for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)

        # A batch the helper rejects with ValueError (e.g. unexpected
        # columns) is skipped rather than aborting the whole request; the
        # check after the loop still fails if nothing usable was produced.
        try:
            df_synthetic = process_generated_data(generated_data, columns)
        except ValueError as err:
            logger.warning("Skipping invalid generation: %s", err)
            continue

        if df_synthetic.empty:
            continue

        # Rows are written without a header, so pin the column order
        # explicitly to keep every batch aligned.
        df_synthetic.to_csv(csv_data_all, index=False, header=False, columns=columns)

    if csv_data_all.tell() == 0:
        raise HTTPException(status_code=500, detail="No valid data frames generated.")

    csv_data_all.seek(0)  # Rewind so the consumer reads from the beginning.
    return csv_data_all

@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Handle ``POST /generate/`` and return the generated CSV as a download.

    The request's description and each of its column names are stripped of
    surrounding whitespace and handed to ``generate_large_synthetic_data``
    with ``num_rows=1000`` and ``rows_per_generation=100``. Whatever that
    call returns is wrapped in a ``StreamingResponse`` served as ``text/csv``
    with an attachment filename of ``generated_data.csv``.
    """
    cleaned_description = request.description.strip()
    cleaned_columns = [name.strip() for name in request.columns]

    csv_stream = generate_large_synthetic_data(
        cleaned_description,
        cleaned_columns,
        num_rows=1000,
        rows_per_generation=100,
    )

    # Content-Disposition makes clients treat the body as a file download.
    download_headers = {"Content-Disposition": "attachment; filename=generated_data.csv"}
    return StreamingResponse(
        csv_stream,
        media_type="text/csv",
        headers=download_headers,
    )