best / app.py
yakine's picture
Update app.py
7af38c7 verified
raw
history blame
4.76 kB
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
import pandas as pd
import os
import requests
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, pipeline
from io import StringIO
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfFolder
from tqdm import tqdm
app = FastAPI()
# Let browser clients from any origin call this API (all methods and headers).
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means
# Starlette answers cookie-bearing requests by echoing the caller's Origin,
# i.e. any site may make credentialed calls. Nothing in this file reads client
# cookies or auth headers, so allow_credentials=False is probably sufficient —
# confirm with the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Token used as the Bearer credential for the Hugging Face Inference API and
# for loading the Mixtral tokenizer; fail fast at import time if it is missing.
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
# Local GPT-2 model and tokenizer, loaded once at import. The text-generation
# pipeline built from them is used to expand user descriptions before they are
# inserted into the prompt sent to the hosted model.
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
# Instruction prompt for the hosted model. ``{description}`` and ``{columns}``
# are placeholders substituted via str.format by format_prompt.
prompt_template = (
    "You are an expert in generating synthetic data for machine learning models.\n"
    "Your task is to generate a synthetic tabular dataset based on the description provided below.\n"
    "Description: {description}\n"
    "The dataset should include the following columns: {columns}\n"
    "Please provide the data in CSV format.\n"
)
# Tokenizer for the hosted Mixtral instruct model, loaded at import with hf_token.
# NOTE(review): nothing else in this file references tokenizer_mixtral, yet it
# still costs a load (and needs the token) at startup. Confirm no other module
# imports it before removing it.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
def preprocess_user_prompt(user_prompt, max_new_tokens=50):
    """Expand a user-supplied description by letting GPT-2 continue it.

    Args:
        user_prompt: Free-text dataset description from the request.
        max_new_tokens: Upper bound on the number of tokens GPT-2 may append.

    Returns:
        The ``generated_text`` of the single pipeline result. With the
        text-generation pipeline's default settings this is the original
        prompt followed by the GPT-2 continuation.
    """
    # ``max_new_tokens`` bounds only the continuation. The former
    # ``max_length=50`` also counted the prompt's own tokens, so descriptions
    # of 50+ tokens got no continuation and triggered a transformers warning.
    outputs = text_generator(
        user_prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
    )
    return outputs[0]["generated_text"]
def format_prompt(description, columns):
    """Render the generation prompt for a description and column list.

    The description is first expanded with preprocess_user_prompt. The column
    names are joined with bare commas before being substituted into
    prompt_template.
    """
    expanded_description = preprocess_user_prompt(description)
    return prompt_template.format(
        description=expanded_description,
        columns=",".join(columns),
    )
# Hosted Inference API endpoint for the Mixtral instruct model.
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

# Generation settings sent as the "parameters" object of every request.
# NOTE(review): the Inference API documents use_cache under a separate
# "options" object rather than "parameters" — confirm it has any effect here.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
    return_full_text=False,
    use_cache=False,
)
def generate_synthetic_data(description, columns, *, timeout=120):
    """Ask the hosted Mixtral model for a CSV-formatted synthetic dataset.

    Args:
        description: Free-text description of the dataset.
        columns: Column names the dataset should contain.
        timeout: Seconds to wait for the Inference API (connect and read)
            before giving up.

    Returns:
        The model's ``generated_text`` (expected to be CSV). Returns ``""`` on
        any request failure, non-2xx status, undecodable JSON, or a response
        that is not a non-empty list whose first element is a dict containing
        ``generated_text``.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {"inputs": formatted_prompt, "parameters": generation_params}
    headers = {"Authorization": f"Bearer {hf_token}"}
    try:
        # Without a timeout, requests may block this worker thread indefinitely.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        # Expected shape: [{"generated_text": "..."}]. Check it explicitly so an
        # error dict or empty list is reported instead of escaping as an
        # uncaught KeyError/IndexError, and so a string element is not
        # substring-searched.
        first = data[0] if isinstance(data, list) and data else None
        if isinstance(first, dict) and 'generated_text' in first:
            return first['generated_text']
        raise ValueError("Invalid response format from Hugging Face API.")
    except (requests.RequestException, ValueError) as e:
        # Best-effort: callers treat "" as a skipped generation.
        print(f"Error during API request or response processing: {e}")
        return ""
def process_generated_data(csv_data, expected_columns):
    """Parse model output as CSV and keep it only if its header matches.

    Args:
        csv_data: Raw text returned by the model. May be empty, because
            generate_synthetic_data returns ``""`` on failure.
        expected_columns: Column names the header must contain. The comparison
            is order-insensitive.

    Returns:
        The parsed DataFrame, or ``None`` if the input is empty, cannot be
        parsed as CSV, or its (whitespace-stripped) header differs from
        ``expected_columns``.
    """
    # pandas raises EmptyDataError (not a ParserError) on empty input, which
    # previously escaped and crashed the whole request.
    if not csv_data or not csv_data.strip():
        print("Failed to parse CSV data: no data returned.")
        return None
    cleaned_data = csv_data.replace('\r\n', '\n').replace('\r', '\n')
    try:
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        print(f"Failed to parse CSV data: {e}")
        return None
    # Models often write headers like "name, age"; strip surrounding
    # whitespace so that formatting alone does not reject a valid batch.
    df = df.rename(columns=lambda c: c.strip() if isinstance(c, str) else c)
    if set(df.columns) != set(expected_columns):
        print(f"Unexpected columns in the generated data: {df.columns}")
        return None
    return df
def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Build a dataset by concatenating several independent model generations.

    ``ceil(num_rows / rows_per_generation)`` generations are requested. Each
    one is validated with process_generated_data, and invalid ones are
    skipped. prompt_template does not ask for a row count, so each generation
    yields however many rows the model produces. The combined result is
    truncated to at most ``num_rows`` rows.

    Args:
        description: Free-text description of the dataset.
        columns: Column names every generation must contain.
        num_rows: Target (maximum) number of rows.
        rows_per_generation: Assumed rows per model call. It is used only to
            decide how many calls to make.

    Returns:
        A DataFrame with the valid rows. If no generation was valid, an empty
        DataFrame with ``columns`` is returned instead.

    Raises:
        ValueError: If ``rows_per_generation`` is not positive.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer.")
    # Ceiling division: with floor division, num_rows < rows_per_generation
    # (e.g. 50 vs 100) produced zero generations.
    num_generations = -(-num_rows // rows_per_generation)
    data_frames = []
    for _ in tqdm(range(num_generations), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        df_synthetic = process_generated_data(generated_data, columns)
        if df_synthetic is not None and not df_synthetic.empty:
            data_frames.append(df_synthetic)
        else:
            print("Skipping invalid generation.")
    if not data_frames:
        print("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)
    combined = pd.concat(data_frames, ignore_index=True)
    # Never return more rows than were asked for.
    return combined.head(num_rows)
class DataGenerationRequest(BaseModel):
    """JSON request body for ``POST /generate/``."""

    description: str  # free-text description of the dataset to synthesize
    columns: list[str]  # column names the generated CSV should contain
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a downloadable CSV.

    The description and every column name are whitespace-stripped before use.

    Returns:
        A ``text/csv`` StreamingResponse with a ``generated_data.csv``
        attachment. If no generation produced usable data, a JSON error with
        status 500 is returned instead.

    Raises:
        HTTPException: 422 if the description is blank, or if the column list
            is empty, contains blank names, or contains duplicates after
            stripping.
    """
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]
    # Reject requests that can never succeed before spending GPT-2 time and
    # Inference API calls on them. pandas renames blank headers to
    # "Unnamed: N" and duplicates to "name.1", so such a header would never
    # match and every generation would be discarded.
    if not description:
        raise HTTPException(status_code=422, detail="description must not be empty.")
    if not columns or not all(columns):
        raise HTTPException(status_code=422, detail="columns must be a non-empty list of non-blank names.")
    if len(set(columns)) != len(columns):
        raise HTTPException(status_code=422, detail="column names must be unique.")
    df = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)
    if df.empty:
        return JSONResponse(content={"error": "No valid data generated"}, status_code=500)
    csv_buffer = StringIO()
    df.to_csv(csv_buffer, index=False)
    csv_buffer.seek(0)
    return StreamingResponse(
        csv_buffer,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"},
    )
@app.get("/")
def greet_json():
    """Root endpoint: reply with a fixed JSON greeting."""
    greeting = dict(Hello="World!")
    return greeting