Duplicate of sdv-dev/SDV#1811; close both at the same time.
Problem Description
After fitting a model, accessing loss_values returns a DataFrame in which the loss values are PyTorch tensor objects rather than plain floats.
Plotting these values therefore requires an extra step to extract them with apply(), which adds unnecessary friction.
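For reference, the extra step currently looks roughly like the sketch below. The helper names to_float and loss_values_as_floats are hypothetical and not part of CTGAN; the sketch assumes each cell holds a one-element tensor (or anything else exposing .item()) and that the column names match the ones used in this issue.

```python
def to_float(value):
    """Return a plain Python float for a tensor-like or numeric value."""
    # One-element torch tensors expose .item(); anything else falls back to float().
    item = getattr(value, "item", None)
    return float(item()) if callable(item) else float(value)


def loss_values_as_floats(loss_df, columns=("Generator Loss", "Discriminator Loss")):
    """Return a copy of loss_df with the given loss columns converted to floats.

    Hypothetical helper: loss_df is expected to be the pandas DataFrame
    returned by loss_values.
    """
    converted = loss_df.copy()
    for column in columns:
        converted[column] = converted[column].apply(to_float)
    return converted
```

With a fitted synthesizer this would be used as loss_values_as_floats(ctgan.loss_values).plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss']).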
Expected behavior
Ideally, the returned DataFrame would contain plain float values for the Generator and Discriminator losses, so that they can be plotted directly:
loss_df = ctgan.loss_values
loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
Additional context
Relevant code is here:
https://github.com/sdv-dev/CTGAN/blob/main/ctgan/synthesizers/ctgan.py#L426