File size: 4,139 Bytes
170c5f1
 
 
 
 
 
 
 
c90ec2a
1ab421e
fa1c843
eaf6f50
 
 
 
1ab421e
 
c90ec2a
1ab421e
 
 
 
170c5f1
 
 
 
 
 
 
 
 
 
 
 
 
 
8628226
a686c13
24605c7
170c5f1
24605c7
8628226
170c5f1
 
a686c13
170c5f1
 
c90ec2a
170c5f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
021bce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170c5f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import csv
import os
import re
from io import StringIO
from typing import List

import pandas as pd
import torch
from accelerate import Accelerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline


app = FastAPI()

# Accept cross-origin calls from any origin, with any method and header.
# NOTE(review): combined with allow_credentials=True, Starlette reflects the
# caller's Origin for credentialed requests instead of sending "*", which
# effectively lets any site make credentialed calls. Narrow allow_origins
# before exposing this service publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# The Hugging Face token is read from HF_API_TOKEN and passed to the Llama
# tokenizer/model downloads below; fail fast at import time when it is
# missing or empty.
hf_token = os.environ.get("HF_API_TOKEN")
if hf_token is None or len(hf_token) == 0:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# GPT-2 tokenizer, model and text-generation pipeline, loaded once at import
# time. The pipeline is used by preprocess_user_prompt to expand the user's
# description before it goes into the Llama prompt.
_GPT2_CHECKPOINT = "gpt2"
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
model_gpt2 = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)
text_generator = pipeline(
    "text-generation",
    model=model_gpt2,
    tokenizer=tokenizer_gpt2,
)

# Load the Llama-3 tokenizer and model once at import time.
# `device` is kept as a module-level name, but the model's placement is
# decided by accelerate through device_map="auto" (see below).
device = "cuda" if torch.cuda.is_available() else "cpu"
_LLAMA_CHECKPOINT = "meta-llama/Meta-Llama-3-8B"
tokenizer_llama = AutoTokenizer.from_pretrained(_LLAMA_CHECKPOINT, token=hf_token)
# device_map="auto" already dispatches the weights, possibly across several
# GPUs or with CPU/disk offload. A trailing .to(device) conflicts with that
# placement (accelerate refuses to move offloaded models), so the model is
# left where it was dispatched. Inputs should be sent to model_llama.device.
model_llama = AutoModelForCausalLM.from_pretrained(
    _LLAMA_CHECKPOINT,
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token,
)

# Prompt sent to the Llama model. format_prompt() fills it via str.format
# with `description` and `columns` keyword arguments, so a real template
# should contain {description} and {columns} placeholders and double any
# literal braces. The value below is a placeholder for the actual template.
prompt_template = "..."

class DataGenerationRequest(BaseModel):
    """Request body for ``POST /generate/``."""

    # Free-text description of the dataset to synthesize.
    description: str
    # Column names for the generated table. Typed as strings so that pydantic
    # rejects non-string entries with a validation error (422) instead of the
    # endpoint crashing on ``col.strip()`` with a 500.
    columns: List[str]

def preprocess_user_prompt(user_prompt):
    """Expand *user_prompt* with the GPT-2 pipeline and return the text.

    ``max_length=60`` bounds the total length (prompt plus continuation) in
    tokens. Only the first of the single requested sequences is used, and by
    default the pipeline's ``generated_text`` includes the prompt itself.

    NOTE(review): with ``truncation=True`` a description longer than the
    60-token budget may be cut — confirm this is intended.
    """
    candidates = text_generator(
        user_prompt,
        max_length=60,
        num_return_sequences=1,
        truncation=True,
    )
    best = candidates[0]
    return best["generated_text"]

def format_prompt(description, columns):
    """Build the Llama prompt from a description and a list of column names.

    The description is first expanded by ``preprocess_user_prompt``; the
    columns are joined with commas (no spaces) and both are substituted into
    ``prompt_template``.
    """
    return prompt_template.format(
        description=preprocess_user_prompt(description),
        columns=",".join(columns),
    )

# Sampling and length settings for the Llama generation call in
# generate_synthetic_data.
generation_params = dict(
    top_p=0.90,
    temperature=0.8,
    max_new_tokens=512,
)

def generate_synthetic_data(description, columns):
    """Ask the Llama model for synthetic CSV data matching *description*.

    Args:
        description: Free-text description of the dataset to generate.
        columns: Column names to request; joined with commas in the prompt.

    Returns:
        The decoded model output. ``generate()`` returns the prompt tokens
        followed by the continuation, so the text still begins with the
        prompt. If anything in prompt building, tokenization or generation
        raises, a string of the form ``"Error: <exception>"`` is returned
        instead; the endpoint relies on that prefix to detect failures.
    """
    try:
        formatted_prompt = format_prompt(description, columns)

        # The prompt is capped at 512 tokens; anything longer is truncated.
        inputs = tokenizer_llama(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(model_llama.device)

        with torch.no_grad():
            outputs = model_llama.generate(
                **inputs,
                # Bound only the newly generated tokens. max_length=512
                # counted the prompt as well, so a prompt near the 512-token
                # input cap left little or no room for any output.
                max_new_tokens=generation_params["max_new_tokens"],
                # top_p and temperature only take effect when sampling.
                do_sample=True,
                top_p=generation_params["top_p"],
                temperature=generation_params["temperature"],
                num_return_sequences=1,
                # Pad with EOS explicitly; if the tokenizer has no pad token,
                # generate() would otherwise fall back to EOS with a warning.
                pad_token_id=tokenizer_llama.eos_token_id,
            )

        return tokenizer_llama.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure as a string so
        # the endpoint can turn it into an HTTP 500 response.
        return f"Error: {e}"
        
# A CSV header line: comma-separated identifiers with no spaces.
_CSV_HEADER_RE = re.compile(r"[A-Za-z0-9_]+(?:,[A-Za-z0-9_]+)*")


def _csv_field_count(line):
    """Return how many CSV fields *line* holds, or -1 if it cannot be parsed."""
    try:
        # csv.reader honours quoting, so '"Smith, John",Paris' is 2 fields.
        return len(next(csv.reader([line])))
    except (csv.Error, StopIteration):
        return -1


def clean_generated_text(generated_text):
    """Extract the first CSV table embedded in free-form model output.

    A table starts at a header line made only of comma-separated identifiers
    (``[A-Za-z0-9_]``, surrounding whitespace ignored). It continues through
    every following line that parses to the same number of CSV fields;
    blank lines inside the table are skipped. The table ends at the first
    line with a different field count (e.g. a closing markdown fence).

    Multi-column tables are preferred. A single-column table is returned only
    when no multi-column one exists, because otherwise any lone word followed
    by a comma-free line of prose would qualify.

    Args:
        generated_text: Raw text produced by the language model.

    Returns:
        The table as newline-terminated CSV text, header first, with each
        line stripped of surrounding whitespace.

    Raises:
        ValueError: If no header line followed by at least one matching
            data row is found.
    """
    lines = [line.strip() for line in generated_text.splitlines()]

    single_column_table = None
    for start, header in enumerate(lines):
        if not _CSV_HEADER_RE.fullmatch(header):
            continue
        width = _csv_field_count(header)
        table = [header]
        for line in lines[start + 1:]:
            if not line:
                continue
            if _csv_field_count(line) != width:
                break
            table.append(line)
        if len(table) < 2:
            # A header with no data rows is not a table.
            continue
        csv_text = "\n".join(table) + "\n"
        if width > 1:
            return csv_text
        if single_column_table is None:
            single_column_table = csv_text

    if single_column_table is not None:
        return single_column_table
    raise ValueError("No valid CSV data found in generated text.")

def process_generated_data(csv_data):
    """Parse the CSV table embedded in raw model output into a DataFrame.

    Raises:
        ValueError: Propagated from ``clean_generated_text`` when no table is
            found. Errors raised by ``pd.read_csv`` also propagate.
    """
    table_text = clean_generated_text(csv_data)
    return pd.read_csv(StringIO(table_text))

@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate synthetic rows for the requested columns.

    Returns ``{"data": [record, ...]}`` on success. When generation fails, or
    when the model output contains no parsable CSV table, it returns
    ``{"error": message}`` with HTTP 500.
    """
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]
    generated_data = generate_synthetic_data(description, columns)

    # generate_synthetic_data reports failures as "Error: <exception>".
    # Checking the prefix rather than a substring keeps legitimate output
    # that merely contains "Error" (e.g. an ErrorCode column) from failing.
    if generated_data.startswith("Error: "):
        return JSONResponse(content={"error": generated_data}, status_code=500)

    try:
        df_synthetic = process_generated_data(generated_data)
    except ValueError as e:
        # No table found, or pandas could not parse it (pandas' ParserError
        # and EmptyDataError subclass ValueError).
        return JSONResponse(content={"error": f"Error: {e}"}, status_code=500)

    # Empty cells are parsed as NaN, which JSONResponse refuses to encode
    # (NaN is not valid JSON); send them as null instead.
    records = df_synthetic.astype(object).where(df_synthetic.notna(), None)
    return JSONResponse(content={"data": records.to_dict(orient="records")})

@app.get("/")
def greet_json():
    """Return a fixed JSON greeting from the root path."""
    greeting = {"Hello": "World!"}
    return greeting