diff --git a/nemo_curator/stages/video/io/clip_writer.py b/nemo_curator/stages/video/io/clip_writer.py index eb62d17774..f829b8feee 100644 --- a/nemo_curator/stages/video/io/clip_writer.py +++ b/nemo_curator/stages/video/io/clip_writer.py @@ -131,6 +131,20 @@ def _write_json_data( ) -> None: write_json(data, dest, desc, source_video, verbose=self.verbose) + @staticmethod + def _cleanup_video_data(video: Video) -> None: + for clip in video.clips: + clip.buffer = None + clip.cosmos_embed1_embedding = None + for window in clip.windows: + window.mp4_bytes = None + window.llm_inputs.clear() + window.caption.clear() + window.enhanced_caption.clear() + window.webp_bytes = None + for clip in video.filtered_clips: + clip.buffer = None + def process(self, task: VideoTask) -> VideoTask: video: Video = task.data with ThreadPoolExecutor(max_workers=self.max_workers) as executor: @@ -171,15 +185,7 @@ def process(self, task: VideoTask) -> VideoTask: for future_v in futures_videos: future_v.result() # clean up intermediate data - for clip in video.clips: - clip.buffer = None - clip.cosmos_embed1_embedding = None - for window in clip.windows: - window.mp4_bytes = None - window.llm_inputs.clear() - window.caption.clear() - window.enhanced_caption.clear() - window.webp_bytes = None + self._cleanup_video_data(video) if self.verbose: logger.info(f"Video {video.input_path} has {len(video.clips)} clips and wrote to {self.output_path}") diff --git a/tests/stages/video/io/test_clip_writer.py b/tests/stages/video/io/test_clip_writer.py index ad2829c96c..48cb106861 100644 --- a/tests/stages/video/io/test_clip_writer.py +++ b/tests/stages/video/io/test_clip_writer.py @@ -700,6 +700,8 @@ def test_process_success(self, mock_executor_class: MagicMock): assert window.caption == {} assert window.enhanced_caption == {} assert window.webp_bytes is None + for clip in result.data.filtered_clips: + assert clip.buffer is None mock_logger.info.assert_called() @@ -847,6 +849,32 @@ def test_edge_cases_clip_with_errors(self): assert "error1" in data["errors"] assert "error2" in data["errors"] + @patch("nemo_curator.stages.video.io.clip_writer.ThreadPoolExecutor") + def test_filtered_clips_buffer_cleared_after_process(self, mock_executor_class: MagicMock): + """Filtered (e.g. motion-filtered) clips must have buffer cleared after process(). + + Motion-filtered clips skip ClipFrameExtractionStage (which normally clears buffer) + so ClipWriterStage is responsible for clearing their buffers to avoid bloating + the serialized task objects (~1.8GB tasks.pkl in benchmarks before this fix). + """ + mock_executor = MagicMock() + mock_executor_class.return_value.__enter__.return_value = mock_executor + mock_future = MagicMock() + mock_future.result.return_value = None + mock_executor.submit.return_value = mock_future + + self.stage.setup() + assert self.mock_filtered_clip.buffer is not None # confirm buffer is set before process + + with ( + patch.object(self.stage, "_write_clip_embedding_to_buffer"), + patch.object(self.stage, "_write_video_embeddings_to_parquet"), + patch.object(self.stage, "_write_video_metadata"), + ): + result = self.stage.process(self.mock_task) + + assert result.data.filtered_clips[0].buffer is None + def test_multiple_embedding_algorithms(self): """Test with different embedding algorithms.""" algorithms = ["cosmos-embed1-224p", "cosmos-embed1-336p", "cosmos-embed1-448p"]