yakine committed on
Commit be8d2d5 · verified · 1 Parent(s): 4b17ebb

Update app.py

Files changed (1)
  1. app.py +40 -23
app.py CHANGED
@@ -24,25 +24,40 @@ hf_token = os.getenv('HF_API_TOKEN')
 if not hf_token:
     raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
 
-# Load GPT-2 model and tokenizer
 tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
 model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
 
-# Prompt template
 prompt_template = """\
 You are an expert in generating synthetic data for machine learning models.
+
 Your task is to generate a synthetic tabular dataset based on the description provided below.
+
 Description: {description}
+
 The dataset should include the following columns: {columns}
+
 Please provide the data in CSV format.
-"""
 
-# Set up the Mixtral model and tokenizer
+Example Description:
+Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
+
+Example Output:
+Size,Location,Number of Bedrooms,Price
+1200,Suburban,3,250000
+900,Urban,2,200000
+1500,Rural,4,300000
+...
+
+Description:
+{description}
+Columns:
+{columns}
+Output: """
+
 tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
 
 def preprocess_user_prompt(user_prompt):
-    # Generate a structured prompt based on the user input
     generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
     return generated_text
 
@@ -64,28 +79,32 @@ generation_params = {
 def generate_synthetic_data(description, columns):
     formatted_prompt = format_prompt(description, columns)
     payload = {"inputs": formatted_prompt, "parameters": generation_params}
-    response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
-
-    if response.status_code != 200:
-        raise HTTPException(status_code=response.status_code, detail="Error from Hugging Face API")
-
-    return response.json()[0]["generated_text"]
+    try:
+        response = requests.post(API_URL, headers={"Authorization": f"Bearer {hf_token}"}, json=payload)
+        response.raise_for_status()
+        data = response.json()
+        if 'generated_text' in data[0]:
+            return data[0]['generated_text']
+        else:
+            raise ValueError("Invalid response format from Hugging Face API.")
+    except (requests.RequestException, ValueError) as e:
+        print(f"Error during API request or response processing: {e}")
+        return ""
 
 def process_generated_data(csv_data, expected_columns):
     try:
-        # Ensure the data is cleaned and correctly formatted
+        # Replace inconsistent line endings
         cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
-        data = StringIO(cleaned_data)
 
-        # Read the CSV data
-        df = pd.read_csv(data, delimiter=',')
+        # Check for common CSV formatting issues and apply corrections
+        cleaned_data = cleaned_data.strip().replace('|', ',').replace('  ', ' ').replace(' ,', ',')
 
-        # Check if the DataFrame has the expected columns
-        if set(df.columns) != set(expected_columns):
-            print(f"Unexpected columns in the generated data: {df.columns}")
-            return None
+        # Load the cleaned data into a DataFrame
+        data = StringIO(cleaned_data)
+        df = pd.read_csv(data, delimiter=',')
 
         return df
+
     except pd.errors.ParserError as e:
         print(f"Failed to parse CSV data: {e}")
         return None
@@ -101,7 +120,7 @@ def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_
             data_frames.append(df_synthetic)
         else:
             print("Skipping invalid generation.")
-
+
     if data_frames:
         return pd.concat(data_frames, ignore_index=True)
     else:
@@ -133,6 +152,4 @@ def generate_data(request: DataGenerationRequest):
 
 @app.get("/")
 def greet_json():
-    return {"Hello": "World!"}
-
-
+    return {"Hello": "World!"}
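
Review note: a minimal sketch of how the revised process_generated_data cleanup could be exercised in isolation, assuming app.py from this commit is importable (importing it also loads the GPT-2 pipeline and requires HF_API_TOKEN to be set). The messy sample string and column names are illustrative only, not part of the commit.

# Minimal sketch (assumption: app.py from this commit is on the import path).
from app import process_generated_data

# Deliberately messy model output: mixed \r\n / \r line endings and '|' delimiters,
# which the revised cleanup normalizes before pandas parses it.
raw_output = (
    "Size,Location,Number of Bedrooms,Price\r\n"
    "1200,Suburban,3,250000\r"
    "900|Urban|2|200000\n"
    "1500,Rural,4,300000\n"
)

df = process_generated_data(raw_output, ["Size", "Location", "Number of Bedrooms", "Price"])
if df is not None:
    print(df.to_string(index=False))  # expect three parsed rows
else:
    print("Parsing failed; see the logged ParserError.")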