Conversation
fealho left a comment:
In general I think this looks good. @pvk-developer @amontanez24 what do you think?
    ctgan.sample(1, 'discrete', "d")
    ...
    def test_ctgan_data_transformer_params():
I think you should also add a performance test, something simple just to make sure that our results are not worse than before because of this change.
I'm not sure about this one. Do you mean a performance test of the Gaussian mixture model or of CTGAN? In terms of speed or accuracy?
Accuracy for CTGAN. Basically, just a test to make sure the changes don't break the code. So something like changing your continuous column to be a normal distribution, instead of random, then sample from the model (after you fit) and make sure the samples loosely follow a normal distribution.
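For reference, a minimal sketch of such a check, assuming the CTGANSynthesizer fit/sample API referenced elsewhere in this PR; the epoch count and the tolerances are made-up placeholders, not values from the codebase:

```python
import numpy as np
import pandas as pd

from ctgan import CTGANSynthesizer  # class name as referenced in this PR


def test_ctgan_on_normal_continuous_column():
    # Continuous column drawn from N(0, 1), plus a small discrete column.
    data = pd.DataFrame({
        'continuous': np.random.normal(size=1000),
        'discrete': np.random.choice(['a', 'b'], size=1000),
    })

    model = CTGANSynthesizer()
    model.fit(data, discrete_columns=['discrete'], epochs=300)

    sampled = model.sample(1000)

    # Loose sanity checks: the sampled column should roughly follow N(0, 1).
    assert abs(sampled['continuous'].mean()) < 0.3        # placeholder tolerance
    assert abs(sampled['continuous'].std() - 1.0) < 0.3   # placeholder tolerance
```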
ctgan/synthesizers/ctgan.py (outdated):

    - def fit(self, train_data, discrete_columns=tuple(), epochs=None):
    + def fit(self, train_data, discrete_columns=tuple(), epochs=None,
    +         data_transformer_params={}):
The data_transformer_params should be moved to the __init__ and assigned as self.data_transformer_params. (Use deepcopy if needed.)
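A minimal sketch of the suggested refactor, not the actual CTGAN code; only the pieces relevant to this comment are shown, and it assumes DataTransformer (from the ctgan/data_transformer.py module discussed below) accepts its options as keyword arguments:

```python
from copy import deepcopy

from ctgan.data_transformer import DataTransformer  # module path taken from the diff below


class CTGANSynthesizer:

    def __init__(self, data_transformer_params=None):
        # Default to None rather than {} to avoid a mutable default argument,
        # and keep our own copy so later mutations by the caller cannot leak in.
        self._data_transformer_params = deepcopy(data_transformer_params or {})

    def fit(self, train_data, discrete_columns=tuple(), epochs=None):
        # The transformer is now configured from the params stored at __init__ time,
        # so the fit signature stays unchanged.
        self._transformer = DataTransformer(**self._data_transformer_params)
        self._transformer.fit(train_data, discrete_columns)
        # ... rest of the training loop unchanged ...
```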
ctgan/data_transformer.py (outdated):

    def _fit_continuous(self, column_name, raw_column_data):
        """Train Bayesian GMM for continuous column."""
        if self._max_gm_samples <= raw_column_data.shape[0]:
            raw_column_data = np.random.choice(raw_column_data,
I think that for this kind of line breaking, this indentation is better:

    raw_column_data = np.random.choice(
        raw_column_data,
        size=self._max_gm_samples,
        replace=False
    )
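Putting that suggestion back into the surrounding method, the guarded subsampling would read roughly as below (a sketch only; self._max_gm_samples and the docstring come from the diff above, and the method is shown outside its class for brevity):

```python
import numpy as np


def _fit_continuous(self, column_name, raw_column_data):
    """Train Bayesian GMM for continuous column."""
    # Subsample only when the column has at least self._max_gm_samples rows,
    # so fitting the mixture model stays tractable on large datasets.
    if self._max_gm_samples <= raw_column_data.shape[0]:
        raw_column_data = np.random.choice(
            raw_column_data,
            size=self._max_gm_samples,
            replace=False
        )
    # ... fit the Bayesian GMM on raw_column_data as before ...
```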
@fealho @pvk-developer

@npatki not sure what you want to do with this?
Meanwhile, the library code has changed, so the PR should be updated (for example, the …).

Also, I wonder whether ClusterBasedNormalizer could be optionally replaced by a power transform, which might be faster (although it might impact the quality of the generated data); see sdv-dev/RDT#613.
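To illustrate the kind of drop-in alternative being suggested (this is scikit-learn's PowerTransformer, used here purely as an example; it is not the RDT implementation discussed in sdv-dev/RDT#613):

```python
import numpy as np
from sklearn.preprocessing import PowerTransformer

# Hypothetical skewed continuous column; sklearn expects a 2D array.
column = np.random.lognormal(size=1000).reshape(-1, 1)

# Yeo-Johnson (the default method) also accepts non-positive values, unlike Box-Cox.
pt = PowerTransformer()
normalized = pt.fit_transform(column)        # output is roughly standard normal
restored = pt.inverse_transform(normalized)  # back to the original scale
```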
This PR solves issue #7. It allows two things:

- Limiting the number of rows used to fit the Bayesian GMM for continuous columns (max_gm_samples).
- Passing parameters to the DataTransformer through CTGANSynthesizer.fit (and so changing other parameters such as max_clusters); see the usage sketch below.
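A hedged usage sketch of what the second point enables, based on the fit signature shown in the diff above; note that passing data_transformer_params at fit time is what this PR proposes (one reviewer suggests moving it to __init__ instead), and the column names and max_clusters value are made up for illustration:

```python
import numpy as np
import pandas as pd

from ctgan import CTGANSynthesizer  # class name as used in this PR's description

train_data = pd.DataFrame({
    'continuous': np.random.normal(size=500),
    'discrete': np.random.choice(['a', 'b'], size=500),
})

ctgan = CTGANSynthesizer()
ctgan.fit(
    train_data,
    discrete_columns=['discrete'],
    # Forwarded to DataTransformer; max_clusters is the example given in the description.
    data_transformer_params={'max_clusters': 5},
)
samples = ctgan.sample(100)
```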