
Return loss values as float values, not PyTorch objects #332

@srinify

Description


Duplicate of sdv-dev/SDV#1811; close both at the same time.

Problem Description

After fitting a model, the loss_values attribute is a DataFrame whose loss values are PyTorch tensor objects instead of plain float values.

[Screenshot, 2024-02-22 4:20 PM]

As a result, plotting these values requires an extra step to extract them with apply(), which I feel adds unnecessary friction.

[Screenshot, 2024-02-22 4:58 PM]
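For reference, the current workaround looks roughly like the sketch below. It is a minimal example, assuming pandas and the column names used in this issue ('Generator Loss', 'Discriminator Loss'); the helper name loss_values_as_floats is hypothetical. It relies on float(), which unwraps any single-element PyTorch tensor.

```python
import pandas as pd

# Column names as used in this issue's plotting example (assumed to match ctgan.loss_values)
LOSS_COLUMNS = ['Generator Loss', 'Discriminator Loss']


def loss_values_as_floats(loss_df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: return a copy of loss_df with tensor entries converted to floats."""
    out = loss_df.copy()  # leave the synthesizer's own DataFrame untouched
    for col in LOSS_COLUMNS:
        # float() unwraps a single-element torch.Tensor; plain floats pass through unchanged
        out[col] = out[col].apply(float)
    return out
```

Usage would then be loss_values_as_floats(ctgan.loss_values).plot(x='Epoch', y=LOSS_COLUMNS). This extra call is exactly the step the issue asks to make unnecessary.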

Expected behavior

Ideally, the returned DataFrame would contain plain float values for the Generator and Discriminator losses. This lowers the friction for plotting the loss values:

loss_df = ctgan.loss_values
loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])

Additional context

Relevant code is here:

https://github.com/sdv-dev/CTGAN/blob/main/ctgan/synthesizers/ctgan.py#L426
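One possible fix is to convert each epoch's losses to Python floats with Tensor.item() before they are stored. The sketch below is a hypothetical helper, not the actual CTGAN code: it assumes the training loop has the per-epoch generator and discriminator losses as single-element tensors (called loss_g and loss_d here) and keeps the column names used in loss_values.

```python
import pandas as pd
import torch

LOSS_COLUMNS = ['Epoch', 'Generator Loss', 'Discriminator Loss']


def append_epoch_loss(loss_values: pd.DataFrame, epoch: int,
                      loss_g: torch.Tensor, loss_d: torch.Tensor) -> pd.DataFrame:
    """Hypothetical helper: append one epoch's losses as plain Python floats."""
    row = pd.DataFrame({
        'Epoch': [epoch],
        # Tensor.item() returns a one-element tensor's value as a Python number
        # (a float for floating-point tensors), even on GPU or with requires_grad
        'Generator Loss': [loss_g.item()],
        'Discriminator Loss': [loss_d.item()],
    })
    if loss_values.empty:
        # Avoid concatenating with an empty frame on the first epoch
        return row
    return pd.concat([loss_values, row], ignore_index=True)
```

With values stored this way, the loss columns are ordinary float64 columns, so loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss']) works directly.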
