Context metadata adjusted for the transformed datetime sdtype (#2127)
lajohn4747 authored Jul 17, 2024
1 parent 2abacb1 commit fa01804
Showing 3 changed files with 83 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sdv/sequential/par.py
@@ -344,6 +344,12 @@ def _fit_context_model(self, transformed):
            context[constant_column] = 0
            context_metadata.add_column(constant_column, sdtype='numerical')

        for column in self.context_columns:
            # Context datetime SDTypes for PAR have already been converted to float timestamp
            if context_metadata.columns[column]['sdtype'] == 'datetime':
                if pd.api.types.is_numeric_dtype(context[column]):
                    context_metadata.update_column(column, sdtype='numerical')

        self._context_synthesizer = GaussianCopulaSynthesizer(
            context_metadata,
            enforce_min_max_values=self._context_synthesizer.enforce_min_max_values,
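
For context (not part of the diff): PAR converts datetime context columns to float Unix timestamps before the context model is fit, which is why the new check treats an already-numeric 'datetime' column as 'numerical'. A minimal standalone sketch of that distinction, using a made-up column name:

import pandas as pd

# A hypothetical datetime context column and its float-timestamp transform,
# mirroring the .timestamp() conversion used in the unit test further down.
birthdate = pd.Series(pd.to_datetime(['1995-05-06', '1982-01-21']), name='birthdate')
as_timestamp = birthdate.apply(lambda x: x.timestamp())

# The raw datetimes are not numeric, but the transformed values are; this is the
# condition the added pd.api.types.is_numeric_dtype check keys on before calling
# update_column(..., sdtype='numerical').
assert not pd.api.types.is_numeric_dtype(birthdate)
assert pd.api.types.is_numeric_dtype(as_timestamp)
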
34 changes: 34 additions & 0 deletions tests/integration/sequential/test_par.py
@@ -406,6 +406,40 @@ def test_init_error_sequence_key_in_context():
        PARSynthesizer(metadata, context_columns=['A'])


def test_par_with_datetime_context():
    """Test PARSynthesizer with a datetime as a context column"""
    # Setup
    data = pd.DataFrame(
        data={
            'user_id': ['ID_00'] * 5 + ['ID_01'] * 5,
            'birthdate': ['1995-05-06'] * 5 + ['1982-01-21'] * 5,
            'timestamp': ['2023-06-21', '2023-06-22', '2023-06-23', '2023-06-24', '2023-06-25'] * 2,
            'heartrate': [67, 66, 68, 65, 64, 80, 82, 91, 88, 84],
        }
    )

    metadata = SingleTableMetadata.load_from_dict({
        'columns': {
            'user_id': {'sdtype': 'id', 'regex_format': 'ID_[0-9]{2}'},
            'birthdate': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
            'timestamp': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
            'heartrate': {'sdtype': 'numerical'},
        },
        'sequence_key': 'user_id',
        'sequence_index': 'timestamp',
    })

    # Run
    synth = PARSynthesizer(metadata, epochs=50, verbose=True, context_columns=['birthdate'])

    synth.fit(data)
    sample = synth.sample(num_sequences=1)
    expected_birthdate = pd.Series(['1984-02-23'] * 5, name='birthdate')

    # Assert
    pd.testing.assert_series_equal(sample['birthdate'], expected_birthdate)


def test_par_categorical_column_represented_by_floats():
    """Test to see if categorical columns work fine with float representation."""
    # Setup
44 changes: 43 additions & 1 deletion tests/unit/sequential/test_par.py
@@ -444,22 +444,64 @@ def test__fit_context_model_with_context_columns(self, gaussian_copula_mock):
            'columns': {'gender': {'sdtype': 'categorical'}, 'name': {'sdtype': 'id'}}
        })
        par._context_synthesizer = initial_synthesizer
        par._get_context_metadata = Mock(return_value=context_metadata)

        # Run
        par._fit_context_model(data)

        # Assert
        gaussian_copula_mock.assert_called_with(
            context_metadata,
            enforce_min_max_values=initial_synthesizer.enforce_min_max_values,
            enforce_rounding=initial_synthesizer.enforce_rounding,
        )
        fitted_data = gaussian_copula_mock().fit.mock_calls[0][1][0]
        expected_fitted_data = pd.DataFrame({
            'name': ['Doe', 'Jane', 'John'],
            'gender': ['M', 'F', 'M'],
        })
        pd.testing.assert_frame_equal(fitted_data.sort_values(by='name'), expected_fitted_data)

    @patch('sdv.sequential.par.GaussianCopulaSynthesizer')
    def test__fit_context_model_with_datetime_context_column(self, gaussian_copula_mock):
        """Test that the method fits a synthesizer to the context columns.

        If there are context columns, the method should create a new DataFrame that groups
        the data by the sequence_key and only contains the context columns. Then a synthesizer
        should be fit to this new data.
        """
        # Setup
        metadata = self.get_metadata()
        data = self.get_data()
        data['time'] = pd.to_datetime(data['time'])
        data['time'] = data['time'].apply(lambda x: x.timestamp())
        par = PARSynthesizer(metadata, context_columns=['time'])
        initial_synthesizer = Mock()
        context_metadata = SingleTableMetadata.load_from_dict({
            'columns': {'time': {'sdtype': 'datetime'}, 'name': {'sdtype': 'id'}}
        })
        par._context_synthesizer = initial_synthesizer
        par._get_context_metadata = Mock()
        par._get_context_metadata.return_value = context_metadata

        # Run
        par._fit_context_model(data)

        converted_context_metadata = SingleTableMetadata.load_from_dict({
            'columns': {'time': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}}
        })

        # Assert
        gaussian_copula_mock.assert_called_with(
            context_metadata,
            enforce_min_max_values=initial_synthesizer.enforce_min_max_values,
            enforce_rounding=initial_synthesizer.enforce_rounding,
        )
        assert converted_context_metadata.columns == context_metadata.columns
        fitted_data = gaussian_copula_mock().fit.mock_calls[0][1][0]
        expected_fitted_data = pd.DataFrame({
            'name': ['Doe', 'Jane', 'John'],
            'gender': ['M', 'F', 'M'],
            'time': [1.578010e09, 1.577837e09, 1.577923e09],
        })
        pd.testing.assert_frame_equal(fitted_data.sort_values(by='name'), expected_fitted_data)

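
Aside (not part of the diff): the unit-test docstring above describes fitting the context model on one row per sequence, containing only the context columns. A rough sketch of that grouping step, under the assumption that a "first value per sequence key" reduction is used (the column names and data below are illustrative, not SDV internals):

import pandas as pd

# Long-format sequences: 'name' is the sequence key, 'time' is a context column
# already transformed to a float timestamp, 'heartrate' varies within a sequence.
data = pd.DataFrame({
    'name': ['Jane', 'Jane', 'John', 'John', 'Doe', 'Doe'],
    'time': [1.577837e09] * 2 + [1.577923e09] * 2 + [1.578010e09] * 2,
    'heartrate': [65, 66, 70, 71, 80, 81],
})

# Keep only the context column and collapse to one row per sequence.
context = data[['name', 'time']].groupby('name', as_index=False).first()

expected = pd.DataFrame({
    'name': ['Doe', 'Jane', 'John'],
    'time': [1.578010e09, 1.577837e09, 1.577923e09],
})
pd.testing.assert_frame_equal(context, expected)
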
