from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import pandas as pd
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
from io import StringIO

app = FastAPI()

hf_token = os.getenv('HF_API_TOKEN')

if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# GPT-2 serves as a lightweight preprocessor that expands the user's raw
# description before it is inserted into the main prompt.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
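# Llama 3.1 8B does the actual tabular-data generation. The checkpoint is
# gated on the Hugging Face Hub, so the token must belong to an account that
# has been granted access; device_map='auto' additionally requires the
# `accelerate` package to be installed so the weights can be placed across
# the available devices automatically.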
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype='auto',
    device_map='auto',
    token=hf_token,
)
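# The prompt pairs a fixed instruction block with a one-shot example.
# str.format substitutes every occurrence of a placeholder, so the processed
# description is injected at both {description} sites below.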
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.
Your task is to generate a synthetic tabular dataset based on the description provided below.
Description: {description}
The dataset should include the following columns: {columns}
Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...
Description:
{description}
Columns:
{columns}
Output: """
class DataGenerationRequest(BaseModel):
    description: str
    columns: list[str]
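# Example request body (values are illustrative):
# {
#   "description": "Generate a dataset for predicting customer churn",
#   "columns": ["Age", "MonthlyCharges", "Tenure", "Churn"]
# }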
def preprocess_user_prompt(user_prompt):
    # Let GPT-2 expand the short user description into richer prompt text.
    generated_text = text_generator(user_prompt, max_length=60, num_return_sequences=1, truncation=True)[0]["generated_text"]
    return generated_text

def format_prompt(description, columns):
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt
generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
}
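# Note: top_p and temperature only take effect when sampling is enabled, so
# generate() below is called with do_sample=True; max_new_tokens bounds the
# length of the generated CSV rather than the combined prompt + output.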
def generate_synthetic_data(description, columns):
    try:
        formatted_prompt = format_prompt(description, columns)

        inputs = tokenizer_llama(formatted_prompt, return_tensors="pt", truncation=True, max_length=512).to(model_llama.device)

        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                max_new_tokens=generation_params["max_new_tokens"],
                do_sample=True,
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                num_return_sequences=1,
                pad_token_id=tokenizer_llama.eos_token_id,
            )

        # Decode only the newly generated tokens; decoding the full sequence
        # would echo the prompt back and break the CSV parsing downstream.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        generated_text = tokenizer_llama.decode(new_tokens, skip_special_tokens=True)

        return generated_text
    except Exception as e:
        return f"Error: {e}"
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]
    generated_data = generate_synthetic_data(description, columns)

    # Match only the sentinel prefix; a substring check would misfire whenever
    # the generated CSV itself happens to contain the word "Error".
    if generated_data.startswith("Error:"):
        return JSONResponse(content={"error": generated_data}, status_code=500)

    df_synthetic = process_generated_data(generated_data)
    return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
def process_generated_data(csv_data):
    data = StringIO(csv_data)
    # Skip malformed lines (e.g. a trailing "...") and drop duplicate rows,
    # mirroring the constraints stated in the prompt.
    df = pd.read_csv(data, on_bad_lines="skip")
    return df.drop_duplicates()
@app.get("/")
def greet_json():
    return {"Hello": "World!"}
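# Example invocation, assuming the module is named `main` and is served
# locally with `uvicorn main:app --port 8000`:
#   curl -X POST http://localhost:8000/generate/ \
#     -H "Content-Type: application/json" \
#     -d '{"description": "House prices", "columns": ["Size", "Location", "Price"]}'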