Limited, imbalanced, or missing data in tabular datasets can lead to poor model performance and biased predictions.
TabGAN provides a solution by generating synthetic tabular data that maintains the statistical properties and relationships of the original dataset.
In this example, we will demonstrate how to use TabGAN to generate high-quality synthetic data using different generators (GAN, Diffusion, or LLM-based).
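Besides OriginalGenerator and GANGenerator, recent tabgan releases also expose diffusion- and LLM-based samplers from the same module. The import below is a sketch; ForestDiffusionGenerator and LLMGenerator availability depends on the installed tabgan version.

# Sketch: alternative generators in tabgan.sampler (availability depends on the installed version)
from tabgan.sampler import (
    OriginalGenerator,         # resamples rows from the original data (baseline)
    GANGenerator,              # CTGAN-based generator
    ForestDiffusionGenerator,  # diffusion-based generator
    LLMGenerator,              # LLM-based generator
)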
Creating Random Input Data
First, we create random input data:
from tabgan.sampler import OriginalGenerator, GANGenerator
import pandas as pd
import numpy as np
train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("ABCD"))
target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y"))
test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
print("Training Data:")
print(train.head())
print("\nTarget Data:")
print(target.head())
print("\nTest Data:")
print(test.head())
Output:
Training Data:
A B C D
0 29 135 29 113
1 88 77 28 126
2 135 31 37 137
3 25 138 16 79
4 97 141 120 55
Target Data:
Y
0 0
1 1
2 1
3 0
4 0
Test Data:
A B C D
0 77 98 56 15
1 80 16 38 7
2 63 87 57 50
3 17 11 64 27
4 9 29 85 75
Generating Synthetic Data with OriginalGenerator
Next, we use the OriginalGenerator to generate synthetic data:
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test)
print("Training Data:")
print(new_train1.head())
print("\nTarget Data:")
print(new_target1.head())
Output:
Training Data:
A B C D
0 38 46 34 69
1 38 46 34 69
2 38 46 34 69
3 38 46 34 69
4 38 46 34 69
Target Data:
0 1
1 1
2 1
3 1
4 1
Name: Y, dtype: int64
As we can see, the pipeline returns a new training dataframe and target with the same schema as the originals, and the generated data is intended to maintain the statistical properties and relationships of the original dataset; a quick way to sanity-check this is sketched below.
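One simple check is to compare per-column summary statistics and the target class balance of the original and generated data. The snippet below is a minimal sketch using pandas only; it assumes new_train1 is a dataframe and new_target1 is a series, as in the output above.

# Sanity check (sketch): compare summary statistics of original vs. generated data
print(train.describe().loc[["mean", "std"]])
print(new_train1.describe().loc[["mean", "std"]])
# Class balance of the generated target
print(new_target1.value_counts(normalize=True))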
The generate_data_pipe method takes in the following parameters (an explicit call is sketched after this list):
- train_df: The training dataframe.
- target: The target variable for the training dataset.
- test_df: The testing dataframe that is used as a target distribution; the newly generated train dataframe should be similar to it in terms of statistical properties.
- deep_copy: A boolean indicating whether to make a deep copy of the input dataframes. Default is True.
- only_adversarial: A boolean indicating whether to only perform adversarial filtering on the training dataframe. Default is False.
- use_adversarial: A boolean indicating whether to perform adversarial filtering on the generated data. Default is True.
- only_generated_data: A boolean indicating whether to return only the generated data. Default is False.
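For reference, the same call can be written with every flag passed explicitly. The snippet below is a sketch; the values shown are simply the documented defaults, not recommendations.

# Sketch: generate_data_pipe with the keyword arguments written out explicitly
new_train1, new_target1 = OriginalGenerator().generate_data_pipe(
    train,
    target,
    test,
    deep_copy=True,             # work on copies of the input dataframes
    only_adversarial=False,     # do not restrict the pipeline to adversarial filtering only
    use_adversarial=True,       # filter the generated rows with an adversarial model
    only_generated_data=False,  # return sampled plus generated rows, not only generated ones
)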
Generating Synthetic Data with GANGenerator
Alternatively, we can use the GANGenerator to generate synthetic data:
new_train2, new_target2 = GANGenerator(
gen_params={
"batch_size": 500, # Process data in batches of 500 samples at a time
"epochs": 10, # Train for a maximum of 10 epochs
"patience": 5 # Stop early if there is no improvement for 5 epochs
}
).generate_data_pipe(train, target, test)
print("Training Data:")
print(new_train2.head())
print("\nTarget Data:")
print(new_target2.head())
Output:
Fitting CTGAN transformers for each column: 0%| | 0/5 [00:00<?, ?it/s]
Training CTGAN, epochs:: 0%| | 0/10 [00:00<?, ?it/s]
Training Data:
A B C D
0 72 33 87 83
1 80 16 81 84
2 95 36 92 89
3 88 39 68 91
4 3 0 74 98
Target Data:
0 1
1 0
2 1
3 0
4 0
Name: Y, dtype: int64
The GANGenerator constructor takes in the following parameters (a sketch combining several of them appears after this list):
- gen_params: A dictionary of parameters for the GAN training process. Default is None.
- gen_x_times: A float indicating how much data to generate. Default is 1.1.
- cat_cols: A list of categorical columns. Default is None.
- bot_filter_quantile: A float indicating the bottom quantile for post-processing filtering. Default is 0.001.
- top_filter_quantile: A float indicating the top quantile for post-processing filtering. Default is 0.999.
- is_post_process: A boolean indicating whether to perform post-processing filtering. Default is True.
- adversarial_model_params: A dictionary of parameters for the adversarial filtering model. Default is None.
- pregeneration_frac: A float indicating the fraction of data to generate before post-processing filtering. Default is 2.
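Putting several of these together, the sketch below configures GANGenerator explicitly. The values are illustrative rather than tuned, and cat_cols is left as None because the toy data has no categorical columns.

# Sketch: GANGenerator with several constructor arguments set explicitly (illustrative values)
new_train3, new_target3 = GANGenerator(
    gen_x_times=1.5,                 # generate roughly 1.5x the original amount of data
    cat_cols=None,                   # e.g. ["D"] if column D were categorical
    bot_filter_quantile=0.001,       # drop generated values below this quantile
    top_filter_quantile=0.999,       # drop generated values above this quantile
    is_post_process=True,            # apply the quantile filtering above
    adversarial_model_params=None,   # use the library's default adversarial model
    pregeneration_frac=2,            # oversample before filtering, then trim back
    gen_params={"batch_size": 500, "epochs": 10, "patience": 5},
).generate_data_pipe(train, target, test)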