"""FastAPI service that generates synthetic tabular datasets with Llama-3.1.

The user's description is first expanded with GPT-2, inserted into a prompt
template, and sent to Llama. The model's CSV answer is returned as JSON rows.
"""

import json
import logging
import os
from io import StringIO
from typing import List

import accelerate  # noqa: F401  # must be installed for device_map="auto"
import pandas as pd
import torch
from accelerate import disk_offload, init_empty_weights  # noqa: F401  # currently unused
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from tqdm import tqdm  # noqa: F401  # currently unused
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    pipeline,
)

logger = logging.getLogger(__name__)

app = FastAPI()

# Access the Hugging Face API token from environment variables
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Load the GPT-2 tokenizer and model (used to expand the user's description)
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)

# Load the Llama-3.1 model and tokenizer once during startup
tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B", token=hf_token)
model_llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    torch_dtype='auto',
    device_map='auto',
    token=hf_token,
)

# Define your prompt template
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.

Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format with a minimum of 100 rows per generation.
Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.

Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...

Description:
{description}

Columns:
{columns}

Output:
"""

# Maximum number of prompt tokens fed to Llama; longer prompts are truncated.
MAX_PROMPT_TOKENS = 512


class DataGenerationRequest(BaseModel):
    """Request body for ``POST /generate/``."""

    description: str  # free-text description of the dataset to synthesise
    columns: List[str]  # column names the generated CSV should contain


def preprocess_user_prompt(user_prompt):
    """Expand ``user_prompt`` with a short GPT-2 continuation and return the full text."""
    generated_text = text_generator(
        user_prompt, max_length=60, num_return_sequences=1, truncation=True
    )[0]["generated_text"]
    return generated_text


def format_prompt(description, columns):
    """Fill ``prompt_template`` with the GPT-2-expanded description and comma-joined columns."""
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt


# Keyword arguments forwarded to ``model_llama.generate``.
generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    # Counts only generated tokens. The old ``max_length`` also counted the
    # prompt, which left almost no room for output.
    "max_new_tokens": 512,
    # top_p / temperature only apply when sampling is enabled.
    "do_sample": True,
}


def _generate_csv_text(description, columns):
    """Run Llama on the formatted prompt and return only the newly generated text.

    Exceptions from GPT-2 preprocessing, tokenization or generation propagate
    to the caller.
    """
    formatted_prompt = format_prompt(description, columns)

    # Tokenize the prompt with truncation enabled
    inputs = tokenizer_llama(
        formatted_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
    ).to(model_llama.device)

    with torch.no_grad():
        outputs = model_llama.generate(
            **inputs,
            **generation_params,
            num_return_sequences=1,
            # The Llama tokenizer may define no pad token; reuse EOS explicitly
            # so generate() does not have to guess.
            pad_token_id=tokenizer_llama.eos_token_id,
        )

    # generate() returns prompt + continuation for decoder-only models. Drop the
    # prompt tokens so the result is the model's answer only, not the
    # instructions and few-shot example.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer_llama.decode(outputs[0, prompt_length:], skip_special_tokens=True)


def generate_synthetic_data(description, columns):
    """Return the generated CSV text, or ``"Error: <message>"`` if generation fails.

    Kept for callers that expect the string-based error convention; the HTTP
    endpoint uses ``_generate_csv_text`` directly so errors are not detected by
    substring matching on model output.
    """
    try:
        return _generate_csv_text(description, columns)
    except Exception as e:
        return f"Error: {e}"


def _extract_csv_block(text):
    """Return the leading CSV table of ``text``.

    Stops at the first blank line or at a new ``Description:`` section. A
    base model may continue the few-shot pattern after the table. Lines that
    are just ``...`` are dropped; the prompt's example uses that as a
    placeholder, and the model may copy it.
    """
    lines = []
    for line in text.strip().splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("Description:"):
            break
        if stripped == "...":
            continue
        lines.append(line)
    return "\n".join(lines)


def process_generated_data(csv_data):
    """Parse the model's CSV output into a DataFrame.

    Raises:
        pandas.errors.EmptyDataError: if no CSV content could be extracted.
        pandas.errors.ParserError: if the content cannot be parsed as CSV.
    """
    csv_block = _extract_csv_block(csv_data)
    # Skip rows with more fields than the header instead of failing the whole request.
    return pd.read_csv(StringIO(csv_block), on_bad_lines="skip")


@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as ``{"data": [row, ...]}``."""
    description = request.description.strip()
    columns = [col.strip() for col in request.columns if col.strip()]
    if not description or not columns:
        return JSONResponse(
            content={"error": "A non-empty description and at least one column name are required."},
            status_code=422,
        )

    try:
        generated_data = _generate_csv_text(description, columns)
    except Exception as e:  # top-level boundary: report model failures to the client
        logger.exception("Synthetic data generation failed")
        return JSONResponse(content={"error": f"Error: {e}"}, status_code=500)

    try:
        df_synthetic = process_generated_data(generated_data)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        logger.warning("Model output could not be parsed as CSV: %s", e)
        return JSONResponse(
            content={"error": f"Could not parse generated data as CSV: {e}"},
            status_code=500,
        )

    # Round-trip through to_json so NaN becomes null and numpy scalars become
    # plain JSON values. Starlette's JSONResponse rejects NaN.
    records = json.loads(df_synthetic.to_json(orient="records"))
    return JSONResponse(content={"data": records})


@app.get("/")
def greet_json():
    """Health-check style root endpoint."""
    return {"Hello": "World!"}