File size: 4,139 Bytes
170c5f1 c90ec2a 1ab421e fa1c843 eaf6f50 1ab421e c90ec2a 1ab421e 170c5f1 8628226 a686c13 24605c7 170c5f1 24605c7 8628226 170c5f1 a686c13 170c5f1 c90ec2a 170c5f1 021bce4 170c5f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import re
from io import StringIO
from typing import List

import pandas as pd
import torch
from accelerate import Accelerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Hugging Face access token, read from the HF_API_TOKEN environment variable
# and passed to the Llama-3 loaders. Startup aborts if it is missing or empty.
hf_token = os.environ.get('HF_API_TOKEN')
if hf_token is None or hf_token == "":
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")
# GPT-2 text-generation pipeline. In this file it is used by
# preprocess_user_prompt to expand the user's free-text description before it
# is placed into the Llama prompt.
_GPT2_CHECKPOINT = 'gpt2'
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
model_gpt2 = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)
text_generator = pipeline(
    "text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)
# Load the Llama-3 tokenizer and model once at import time; every request
# reuses them.
device = "cuda" if torch.cuda.is_available() else "cpu"
_LLAMA_CHECKPOINT = "meta-llama/Meta-Llama-3-8B"
tokenizer_llama = AutoTokenizer.from_pretrained(_LLAMA_CHECKPOINT, token=hf_token)
# device_map='auto' lets accelerate place the weights (possibly across several
# devices, or offloading part of them to CPU/disk). The model must NOT be moved
# again with .to(device): on a multi-device placement that collapses all
# weights onto one device, and when anything was offloaded accelerate raises
# "You can't move a model that has some modules offloaded to cpu or disk."
# Callers use model_llama.device to place inputs instead.
model_llama = AutoModelForCausalLM.from_pretrained(
    _LLAMA_CHECKPOINT,
    # A real torch.dtype is accepted by all transformers versions; the string
    # form 'float16' is not understood by older releases.
    torch_dtype=torch.float16,
    device_map='auto',
    token=hf_token,
)
# Prompt sent to Llama-3. format_prompt() fills {description} with the
# GPT-2-expanded description and {columns} with the comma-joined column names.
# The previous placeholder ("...") had no fields at all, so the user's
# description and columns never reached the model.
# Any literal braces added here must be doubled ({{ }}) because the template
# is rendered with str.format().
prompt_template = """\
You are a synthetic data generator. Create realistic, varied tabular data.

Dataset description:
{description}

Write the data as CSV. The first line must be this header, exactly:
{columns}
Then write one record per line with one value for every column, in the same order.
Separate values with commas, do not put commas inside values, and output only the CSV.

CSV:
"""
class DataGenerationRequest(BaseModel):
    """Request body for the POST /generate/ endpoint."""

    # Free-text description of the dataset to synthesize.
    description: str
    # Column names for the generated CSV. Typed as List[str] (instead of a
    # bare list) so pydantic validates every item: generate_data calls
    # col.strip() on each one, and an unvalidated non-string item used to
    # crash the request with a 500 instead of a 422 validation error.
    columns: List[str]
def preprocess_user_prompt(user_prompt):
    """Expand *user_prompt* with the GPT-2 text-generation pipeline.

    The pipeline is asked for a single sequence with ``max_length=60`` and
    ``truncation=True``; the ``"generated_text"`` of that sequence is returned.
    """
    completions = text_generator(
        user_prompt,
        max_length=60,
        num_return_sequences=1,
        truncation=True,
    )
    first_completion = completions[0]
    return first_completion["generated_text"]
def format_prompt(description, columns):
    """Render the Llama prompt for a dataset description and column names.

    The description is first expanded by ``preprocess_user_prompt``; the
    column names are joined with commas. Both are substituted into
    ``prompt_template`` via ``str.format``.
    """
    expanded_description = preprocess_user_prompt(description)
    return prompt_template.format(
        description=expanded_description,
        columns=",".join(columns),
    )
# Decoding settings for the Llama generate() call: nucleus-sampling
# probability mass, sampling temperature, and the cap on generated tokens.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
)
def generate_synthetic_data(description, columns):
    """Generate raw synthetic-data text with Llama-3.

    Args:
        description: Free-text description of the dataset.
        columns: Column names to request.

    Returns:
        The decoded model continuation (prompt tokens excluded) on success.
        On any failure the exception is caught and the string
        ``"Error: <exception>"`` is returned instead of raising.
    """
    try:
        formatted_prompt = format_prompt(description, columns)
        # Cap the prompt at 512 tokens and move it to the model's device.
        inputs = tokenizer_llama(
            formatted_prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(model_llama.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                # max_new_tokens bounds only the continuation. The former
                # max_length=512 also counted the (up to 512-token) prompt,
                # leaving little or no room to generate.
                max_new_tokens=generation_params["max_new_tokens"],
                # top_p / temperature only take effect when sampling is on.
                do_sample=True,
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                num_return_sequences=1,
                # Explicit pad id; the Llama tokenizer may not define one
                # (TODO confirm), which otherwise triggers a warning per call.
                pad_token_id=tokenizer_llama.eos_token_id,
            )

        # generate() returns prompt + continuation; decode only the new
        # tokens so the echoed prompt is not mistaken for CSV downstream.
        new_tokens = outputs[0][prompt_length:]
        return tokenizer_llama.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        return f"Error: {e}"
# A run of lines that looks like CSV: a header of comma-separated word tokens,
# then arbitrary comma-separated lines. Compiled once at import time.
_CSV_BLOCK_PATTERN = re.compile(
    r'(\n?([A-Za-z0-9_]+,)*[A-Za-z0-9_]+\n([^\n,]*,)*[^\n,]*\n*)+'
)


def clean_generated_text(generated_text):
    """Return the first CSV-looking span found in *generated_text*.

    Raises:
        ValueError: if no span matches the CSV pattern.
    """
    found = _CSV_BLOCK_PATTERN.search(generated_text)
    if found is None:
        raise ValueError("No valid CSV data found in generated text.")
    return found.group(0)
def process_generated_data(csv_data):
    """Extract the CSV span from raw model text and parse it with pandas.

    Raises:
        ValueError: propagated from ``clean_generated_text`` when no CSV span
            is found; pandas parsing errors also propagate.
    """
    csv_block = clean_generated_text(csv_data)
    return pd.read_csv(StringIO(csv_block))
@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate synthetic tabular data for a description and column list.

    Returns:
        ``{"data": [record, ...]}`` on success, where each record maps column
        name to value (missing cells become ``null``), or ``{"error": ...}``
        with HTTP 500 when generation fails or no parseable CSV is produced.
    """
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]

    generated_data = generate_synthetic_data(description, columns)
    # generate_synthetic_data reports failure by returning "Error: <exc>".
    # Check that exact prefix: a substring test also fired on legitimate
    # output that merely contained "Error" (e.g. an "ErrorCode" column).
    if generated_data.startswith("Error: "):
        return JSONResponse(content={"error": generated_data}, status_code=500)

    try:
        df_synthetic = process_generated_data(generated_data)
    except ValueError as exc:
        # clean_generated_text raises ValueError when no CSV is found, and
        # pandas' EmptyDataError / ParserError subclass ValueError. Previously
        # these escaped as a bare, non-JSON 500.
        return JSONResponse(content={"error": f"Error: {exc}"}, status_code=500)

    # JSONResponse serializes with allow_nan=False, so NaN from empty CSV
    # cells would raise; replace missing values with None (JSON null).
    records = df_synthetic.astype(object).where(df_synthetic.notna(), None)
    return JSONResponse(content={"data": records.to_dict(orient="records")})
@app.get("/")
def greet_json():
    """Return a fixed greeting payload for the root path."""
    payload = {"Hello": "World!"}
    return payload
|