Update app.py
app.py
CHANGED
@@ -87,63 +87,59 @@ generation_params = {
 def generate_synthetic_data(description, columns):
     formatted_prompt = format_prompt(description, columns)
     payload = {"inputs": formatted_prompt, "parameters": generation_params}
-    response = requests.post(API_URL, headers={"Authorization": f"Bearer {
+    response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
+
+    try:
+        response_data = response.json()
+    except ValueError:
+        raise HTTPException(status_code=500, detail="Failed to parse response from the API.")
 
     if 'error' in response_data:
+        raise HTTPException(status_code=500, detail=f"API Error: {response_data['error']}")
+
+    if 'generated_text' not in response_data[0]:
+        raise HTTPException(status_code=500, detail="Unexpected API response format.")
 
     return response_data[0]["generated_text"]
 
 def process_generated_data(csv_data, expected_columns):
     try:
-        # Ensure the data is cleaned and correctly formatted
         cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
         data = StringIO(cleaned_data)
-        # Read the CSV data
         df = pd.read_csv(data, delimiter=',')
 
-        # Check if the DataFrame has the expected columns
         if set(df.columns) != set(expected_columns):
+            raise ValueError("Unexpected columns in the generated data.")
 
         return df
     except pd.errors.ParserError as e:
+        raise HTTPException(status_code=500, detail=f"Failed to parse CSV data: {e}")
 
 def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
-    csv_data_all =
+    csv_data_all = StringIO()
 
     for _ in tqdm(range(num_rows // rows_per_generation), desc="Generating Data"):
         generated_data = generate_synthetic_data(description, columns)
-        if "Error" in generated_data:
-            return generated_data # Return the error message
         df_synthetic = process_generated_data(generated_data, columns)
+
         if isinstance(df_synthetic, pd.DataFrame) and not df_synthetic.empty:
-    if csv_data_all:
+            df_synthetic.to_csv(csv_data_all, index=False, header=False)
+
+    if csv_data_all.tell() > 0: # Check if there's any data in the buffer
+        csv_data_all.seek(0) # Rewind the buffer to the beginning
         return csv_data_all
     else:
+        raise HTTPException(status_code=500, detail="No valid data frames generated.")
 
 @app.post("/generate/")
 def generate_data(request: DataGenerationRequest):
     description = request.description.strip()
     columns = [col.strip() for col in request.columns]
-    if isinstance(generated_data, str) and "Error" in generated_data:
-        return JSONResponse(content={"error": generated_data}, status_code=500)
+    csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
 
-    #
-    csv_buffer = StringIO(generated_data)
+    # Return the CSV data as a downloadable file
     return StreamingResponse(
+        csv_data,
         media_type="text/csv",
         headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
     )
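
For a quick end-to-end check of the new behavior, here is a minimal client sketch. Only the /generate/ route and the description/columns request fields come from the code above; the host, port, and example payload values are assumptions for illustration.

# Hypothetical client for the /generate/ endpoint; assumes the app is served
# locally at http://localhost:8000 (the host and port are not part of this commit).
import requests

payload = {
    "description": "Customers of an online bookstore",
    "columns": ["name", "age", "favorite_genre"],
}

resp = requests.post("http://localhost:8000/generate/", json=payload)
resp.raise_for_status()  # failures now surface as HTTP 500 responses instead of error strings in the CSV body

# The endpoint streams CSV, so the body can be written straight to disk.
with open("generated_data.csv", "wb") as f:
    f.write(resp.content)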
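
Note that generate_large_synthetic_data writes every chunk with header=False, so the streamed file contains no header row. A small read-back sketch, assuming the caller still has the same column list it sent in the request:

# Reading the downloaded file back; column names are supplied manually because
# the generated CSV has no header row. The column list is an assumption and
# must match the one used in the request.
import pandas as pd

columns = ["name", "age", "favorite_genre"]
df = pd.read_csv("generated_data.csv", names=columns)
print(df.head())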