Update app.py
Browse files
app.py
CHANGED
@@ -24,25 +24,40 @@ hf_token = os.getenv('HF_API_TOKEN')
|
|
24 |
if not hf_token:
|
25 |
raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
26 |
|
27 |
-
# Load GPT-2 model and tokenizer
|
28 |
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
|
29 |
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
|
30 |
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
31 |
|
32 |
-
# Prompt template
|
33 |
prompt_template = """\
|
34 |
You are an expert in generating synthetic data for machine learning models.
|
|
|
35 |
Your task is to generate a synthetic tabular dataset based on the description provided below.
|
|
|
36 |
Description: {description}
|
|
|
37 |
The dataset should include the following columns: {columns}
|
|
|
38 |
Please provide the data in CSV format.
|
39 |
-
"""
|
40 |
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
43 |
|
44 |
def preprocess_user_prompt(user_prompt):
|
45 |
-
# Generate a structured prompt based on the user input
|
46 |
generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
|
47 |
return generated_text
|
48 |
|
@@ -64,28 +79,32 @@ generation_params = {
|
|
64 |
def generate_synthetic_data(description, columns):
|
65 |
formatted_prompt = format_prompt(description, columns)
|
66 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
def process_generated_data(csv_data, expected_columns):
|
75 |
try:
|
76 |
-
#
|
77 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
78 |
-
data = StringIO(cleaned_data)
|
79 |
|
80 |
-
#
|
81 |
-
|
82 |
|
83 |
-
#
|
84 |
-
|
85 |
-
|
86 |
-
return None
|
87 |
|
88 |
return df
|
|
|
89 |
except pd.errors.ParserError as e:
|
90 |
print(f"Failed to parse CSV data: {e}")
|
91 |
return None
|
@@ -101,7 +120,7 @@ def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_
|
|
101 |
data_frames.append(df_synthetic)
|
102 |
else:
|
103 |
print("Skipping invalid generation.")
|
104 |
-
|
105 |
if data_frames:
|
106 |
return pd.concat(data_frames, ignore_index=True)
|
107 |
else:
|
@@ -133,6 +152,4 @@ def generate_data(request: DataGenerationRequest):
|
|
133 |
|
134 |
@app.get("/")
|
135 |
def greet_json():
|
136 |
-
return {"Hello": "World!"}
|
137 |
-
|
138 |
-
|
|
|
24 |
if not hf_token:
|
25 |
raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
|
26 |
|
|
|
27 |
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
|
28 |
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
|
29 |
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
|
30 |
|
|
|
31 |
prompt_template = """\
|
32 |
You are an expert in generating synthetic data for machine learning models.
|
33 |
+
|
34 |
Your task is to generate a synthetic tabular dataset based on the description provided below.
|
35 |
+
|
36 |
Description: {description}
|
37 |
+
|
38 |
The dataset should include the following columns: {columns}
|
39 |
+
|
40 |
Please provide the data in CSV format.
|
|
|
41 |
|
42 |
+
Example Description:
|
43 |
+
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
|
44 |
+
|
45 |
+
Example Output:
|
46 |
+
Size,Location,Number of Bedrooms,Price
|
47 |
+
1200,Suburban,3,250000
|
48 |
+
900,Urban,2,200000
|
49 |
+
1500,Rural,4,300000
|
50 |
+
...
|
51 |
+
|
52 |
+
Description:
|
53 |
+
{description}
|
54 |
+
Columns:
|
55 |
+
{columns}
|
56 |
+
Output: """
|
57 |
+
|
58 |
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
|
59 |
|
60 |
def preprocess_user_prompt(user_prompt):
|
|
|
61 |
generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
|
62 |
return generated_text
|
63 |
|
|
|
79 |
def generate_synthetic_data(description, columns):
|
80 |
formatted_prompt = format_prompt(description, columns)
|
81 |
payload = {"inputs": formatted_prompt, "parameters": generation_params}
|
82 |
+
try:
|
83 |
+
response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
|
84 |
+
response.raise_for_status()
|
85 |
+
data = response.json()
|
86 |
+
if 'generated_text' in data[0]:
|
87 |
+
return data[0]['generated_text']
|
88 |
+
else:
|
89 |
+
raise ValueError("Invalid response format from Hugging Face API.")
|
90 |
+
except (requests.RequestException, ValueError) as e:
|
91 |
+
print(f"Error during API request or response processing: {e}")
|
92 |
+
return ""
|
93 |
|
94 |
def process_generated_data(csv_data, expected_columns):
|
95 |
try:
|
96 |
+
# Replace inconsistent line endings
|
97 |
cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
|
|
|
98 |
|
99 |
+
# Check for common CSV formatting issues and apply corrections
|
100 |
+
cleaned_data = cleaned_data.strip().replace('|', ',').replace(' ', ' ').replace(' ,', ',')
|
101 |
|
102 |
+
# Load the cleaned data into a DataFrame
|
103 |
+
data = StringIO(cleaned_data)
|
104 |
+
df = pd.read_csv(data, delimiter=',')
|
|
|
105 |
|
106 |
return df
|
107 |
+
|
108 |
except pd.errors.ParserError as e:
|
109 |
print(f"Failed to parse CSV data: {e}")
|
110 |
return None
|
|
|
120 |
data_frames.append(df_synthetic)
|
121 |
else:
|
122 |
print("Skipping invalid generation.")
|
123 |
+
|
124 |
if data_frames:
|
125 |
return pd.concat(data_frames, ignore_index=True)
|
126 |
else:
|
|
|
152 |
|
153 |
@app.get("/")
|
154 |
def greet_json():
|
155 |
+
return {"Hello": "World!"}
|
|
|
|