diff --git a/finetune_csv/finetune_base_model.py b/finetune_csv/finetune_base_model.py index d21c22db..915d81a8 100644 --- a/finetune_csv/finetune_base_model.py +++ b/finetune_csv/finetune_base_model.py @@ -121,8 +121,11 @@ def __getitem__(self, idx): x = window_data[self.feature_list].values.astype(np.float32) x_stamp = window_data[self.time_feature_list].values.astype(np.float32) - - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) + + # Compute normalization stats only on the lookback portion to prevent + # future data leakage into the prediction window statistics. + past_x = x[:self.lookback_window] + x_mean, x_std = np.mean(past_x, axis=0), np.std(past_x, axis=0) x = (x - x_mean) / (x_std + 1e-5) x = np.clip(x, -self.clip, self.clip)