Generating Synthetic Tabular Data with TabGAN

Limited, imbalanced, or missing data in tabular datasets can lead to poor model performance and biased predictions.

TabGAN provides a solution by generating synthetic tabular data that maintains the statistical properties and relationships of the original dataset.

In this example, we will demonstrate how to use TabGAN to generate high-quality synthetic data using different generators (GAN, Diffusion, or LLM-based).
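
All of these generators share the same generate_data_pipe(train, target, test) interface, so swapping the underlying model is a one-line change. Below is a minimal sketch of the imports, assuming the ForestDiffusionGenerator and LLMGenerator class names from recent tabgan releases (older versions may not include them):

from tabgan.sampler import (
    OriginalGenerator,         # resamples and filters the original training data
    GANGenerator,              # CTGAN-based generator
    ForestDiffusionGenerator,  # diffusion-based generator (assumed name; recent releases)
    LLMGenerator,              # LLM-based generator (assumed name; recent releases)
)

# Any of these can be used the same way later in the example:
# new_train, new_target = GANGenerator().generate_data_pipe(train, target, test)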

Creating Random Input Data

First, we create random input data:

from tabgan.sampler import OriginalGenerator, GANGenerator
import pandas as pd
import numpy as np

train = pd.DataFrame(np.random.randint(-10, 150, size=(150, 4)), columns=list("ABCD"))
target = pd.DataFrame(np.random.randint(0, 2, size=(150, 1)), columns=list("Y"))
test = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))

print("Training Data:")
print(train.head())

print("\nTarget Data:")
print(target.head())

print("\nTest Data:")
print(test.head())

Output:

Training Data:
     A    B    C    D
0   29  135   29  113
1   88   77   28  126
2  135   31   37  137
3   25  138   16   79
4   97  141  120   55

Target Data:
   Y
0  0
1  1
2  1
3  0
4  0

Test Data:
    A   B   C   D
0  77  98  56  15
1  80  16  38   7
2  63  87  57  50
3  17  11  64  27
4   9  29  85  75

Generating Synthetic Data with OriginalGenerator

Next, we use the OriginalGenerator to generate synthetic data:

new_train1, new_target1 = OriginalGenerator().generate_data_pipe(train, target, test)

print("Training Data:")
print(new_train1.head())

print("\nTarget Data:")
print(new_target1.head())

Output:

Training Data:
    A   B   C   D
0  38  46  34  69
1  38  46  34  69
2  38  46  34  69
3  38  46  34  69
4  38  46  34  69

Target Data:
0    1
1    1
2    1
3    1
4    1
Name: Y, dtype: int64

As we can see, the OriginalGenerator resamples the original training data and keeps the rows that most closely resemble the test distribution, which is why repeated rows can appear in its output. The resulting dataset preserves the statistical properties and relationships of the original data.
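
To verify that the generated rows follow a similar distribution, a quick sanity check is to compare per-column summary statistics with pandas (a minimal sketch):

# Compare summary statistics of the original and generated training data
print(train.describe())
print(new_train1.describe())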

The generate_data_pipe method takes in the following parameters:

  • train_df: The training dataframe.
  • target: The target variable for the training dataset.
  • test_df: The testing dataframe that is used as a target distribution – the newly generated train dataframe should be similar to it in terms of statistical properties.
  • deep_copy: A boolean indicating whether to make a deep copy of the input dataframes. Default is True.
  • only_adversarial: A boolean indicating whether to only perform adversarial filtering on the training dataframe. Default is False.
  • use_adversarial: A boolean indicating whether to perform adversarial filtering on the generated data. Default is True.
  • only_generated_data: A boolean indicating whether to return only the generated data. Default is False.
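
For example, to keep only the original training rows that survive adversarial filtering, the flags can be set explicitly (a hedged sketch; the values shown are illustrative):

# Return only the original training rows that look most like the test distribution
filtered_train, filtered_target = OriginalGenerator().generate_data_pipe(
    train,
    target,
    test,
    deep_copy=True,             # work on copies so the input dataframes stay untouched
    only_adversarial=True,      # filter the original data instead of adding generated rows
    use_adversarial=True,       # rank rows by similarity to the test distribution
    only_generated_data=False,  # keep the (filtered) original rows in the result
)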

Generating Synthetic Data with GANGenerator

Alternatively, we can use the GANGenerator to generate synthetic data:

new_train2, new_target2 = GANGenerator(
    gen_params={
        "batch_size": 500,  # Process data in batches of 500 samples at a time
        "epochs": 10,       # Train for a maximum of 10 epochs
        "patience": 5       # Stop early if there is no improvement for 5 epochs
    }
).generate_data_pipe(train, target, test)

print("Training Data:")
print(new_train2.head())

print("\nTarget Data:")
print(new_target2.head())

Output:

Fitting CTGAN transformers for each column:   0%|          | 0/5 [00:00<?, ?it/s]

Training CTGAN, epochs::   0%|          | 0/10 [00:00<?, ?it/s]

Training Data:
    A   B   C   D
0  72  33  87  83
1  80  16  81  84
2  95  36  92  89
3  88  39  68  91
4   3   0  74  98

Target Data:
0    1
1    0
2    1
3    0
4    0
Name: Y, dtype: int64

The GANGenerator takes in the following parameters:

  • gen_params: A dictionary of parameters for the GAN training process. Default is None.
  • gen_x_times: A float indicating how many times more data to generate relative to the input training data. Default is 1.1.
  • cat_cols: A list of categorical columns. Default is None.
  • bot_filter_quantile: A float indicating the bottom quantile for post-processing filtering. Default is 0.001.
  • top_filter_quantile: A float indicating the top quantile for post-processing filtering. Default is 0.999.
  • is_post_process: A boolean indicating whether to perform post-processing filtering. Default is True.
  • adversarial_model_params: A dictionary of parameters for the adversarial filtering model. Default is None.
  • pregeneration_frac: A float indicating how many times more data to pre-generate before post-processing filtering trims it down. Default is 2.
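
Putting several of these parameters together looks like this (a hedged sketch; the specific values are illustrative, not recommendations):

# Pre-generate extra rows, filter them against the test distribution,
# and train the underlying CTGAN with a small budget.
new_train3, new_target3 = GANGenerator(
    gen_x_times=5,              # generate roughly 5x the original amount of data
    cat_cols=None,              # all columns here are numeric; pass column names if categorical
    bot_filter_quantile=0.01,   # drop generated values below the 1st percentile
    top_filter_quantile=0.99,   # drop generated values above the 99th percentile
    is_post_process=True,       # apply the quantile filtering above
    pregeneration_frac=2,       # over-generate by 2x before filtering trims it down
    gen_params={"batch_size": 500, "epochs": 10, "patience": 5},
).generate_data_pipe(train, target, test)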

Link to TabGan.
