diff --git a/xtuner/v1/datasets/config.py b/xtuner/v1/datasets/config.py index 20047e8a8..a3b68902c 100644 --- a/xtuner/v1/datasets/config.py +++ b/xtuner/v1/datasets/config.py @@ -440,9 +440,12 @@ def build( ) elif self.group_by_length: assert shuffle, "Currently only shuffling is supported for LengthGroupedSampler." - assert isinstance(dataset, (ExpandSoftPackDataset, _LegacySoftPackDataset, HardPackDataset)), ( - "Internal Error, LengthGroupedSampler requires ExpandSoftPackDataset or _LegacySoftPackDataset, " - f"but got {type(dataset)}" + assert isinstance( + dataset, + (ExpandSoftPackDataset, _LegacySoftPackDataset, HardPackDataset, MLLMPretrainHybridPackDataset), + ), ( + "Internal Error, LengthGroupedSampler requires ExpandSoftPackDataset, _LegacySoftPackDataset, " + f"HardPackDataset, or MLLMPretrainHybridPackDataset, but got {type(dataset)}" ) sampler = LengthGroupedSampler( dataset=dataset, dp_mesh=dp_mesh, global_batch_size=global_batch_size, seed=seed diff --git a/xtuner/v1/datasets/packing.py b/xtuner/v1/datasets/packing.py index 9d3d8945e..c672eee03 100644 --- a/xtuner/v1/datasets/packing.py +++ b/xtuner/v1/datasets/packing.py @@ -4,11 +4,12 @@ import os import random import tempfile +from collections.abc import Sequence from concurrent.futures import ProcessPoolExecutor from functools import cached_property, partial from multiprocessing import shared_memory from pathlib import Path -from typing import Sized +from typing import Sized, cast import numpy as np import torch @@ -16,6 +17,7 @@ from datasets import Dataset, concatenate_datasets from torch import distributed as dist from torch.utils.data import ConcatDataset +from torch.utils.data import Dataset as TorchDataset from tqdm import tqdm from xtuner.v1.utils import get_logger, is_local_rank0 @@ -309,7 +311,7 @@ def get_pack_infos_by_expand_soft_split( class ExpandSoftPackDataset(_LegacySoftPackDataset): def __init__( self, - datasets: list[JsonlDataset], + datasets: Sequence[JsonlDataset], 
pack_max_length: int = 2048, global_pack: bool = False, pack_extra_buffer_size: int = 1000, @@ -642,7 +644,7 @@ def get_state_dict(self): def load_state_dict(self, state_dict): ... -class MLLMPretrainHybridPackDataset(_LegacySoftPackDataset): +class MLLMPretrainHybridPackDataset(TorchDataset): def __init__( self, datasets: list[JsonlDataset], @@ -653,17 +655,13 @@ def __init__( pack_extra_buffer_size: int = 1000, # for ExpandSoftPackDataset pack_chunk_size: int = 10000, # for ExpandSoftPackDataset ): - self.pack_extra_buffer_size = pack_extra_buffer_size - self.pack_workers = pack_workers - self.torch_random_generator = torch.Generator() - self.pack_chunk_size = pack_chunk_size - if seed is not None: - self.torch_random_generator.manual_seed(seed) - logger.info(f"Using {self.pack_workers} pack workers for packing datasets.") - self.seed = seed - self.global_pack = global_pack self.pack_max_length = pack_max_length + self.global_pack = global_pack + self.pack_workers = pack_workers + self.pack_extra_buffer_size = pack_extra_buffer_size + self.pack_chunk_size = pack_chunk_size + self.seed = seed hard_pack_groups = [] soft_pack_groups = [] @@ -673,100 +671,81 @@ def __init__( elif isinstance(dset, JsonlDataset): hard_pack_groups.append(dset) - if global_pack: - hard_pack_datasets: list[Sized] = [] - if len(hard_pack_groups) > 0: - num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in hard_pack_groups]))] - hard_pack_datasets = [ConcatDataset(hard_pack_groups)] - - pack_infos_list = [] - for i, dataset in enumerate(hard_pack_datasets): - _infos = self.get_hard_pack_infos(dataset, i, num_tokens[i]) - pack_infos_list.extend(_infos) - hard_pack_len = len(pack_infos_list) - - soft_pack_datasets: list[Sized] = [] - if len(soft_pack_groups) > 0: - num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in soft_pack_groups]))] - proxy_attn_flops = [ - ndarray_to_mmap(np.concatenate([dset.proxy_attn_flops for dset in soft_pack_groups])) - ] - - 
soft_pack_datasets = [ConcatDataset(soft_pack_groups)] - for i, dataset in enumerate(soft_pack_datasets): - _infos = self.get_soft_pack_infos(dataset, i, num_tokens[i], proxy_attn_flops[i]) - pack_infos_list.extend(_infos) - pack_infos = Dataset.from_list(pack_infos_list) + dataset_list: list[HardPackDataset | ExpandSoftPackDataset] = [] - else: - raise NotImplementedError + if hard_pack_groups: + hard_pack_dataset = HardPackDataset( + datasets=hard_pack_groups, + pack_max_length=pack_max_length, + global_pack=global_pack, + seed=seed, + pack_workers=pack_workers, + ) + dataset_list.append(hard_pack_dataset) - self.hard_pack_datasets = hard_pack_datasets - self.datasets = soft_pack_datasets - self.hard_pack_len = hard_pack_len - self.pack_infos = pack_infos + if soft_pack_groups: + soft_pack_dataset = ExpandSoftPackDataset( + datasets=soft_pack_groups, + pack_max_length=pack_max_length, + global_pack=global_pack, + pack_extra_buffer_size=pack_extra_buffer_size, + pack_chunk_size=pack_chunk_size, + pack_workers=pack_workers, + seed=seed, + ) + dataset_list.append(soft_pack_dataset) - def get_hard_pack_item(self, item: int): - info = self.pack_infos[item] - dataset_id = info["dataset_id"] - ds = self.hard_pack_datasets[dataset_id] + assert dataset_list, "No datasets provided for packing." 
+ self.datasets: ConcatDataset[HardPackDataset | ExpandSoftPackDataset] = ConcatDataset(dataset_list) - indices = info["indices"] - s_off = info["start_offset"] - e_off = info["end_offset"] + @cached_property + def longest(self): + longest_list = [] + for dataset in self.datasets.datasets: + longest_list.extend(cast(HardPackDataset | ExpandSoftPackDataset, dataset).longest) + return longest_list - packed_list: list[dict] = [] + def __getitem__(self, item: int): + return self.datasets[item] - for i in range(len(indices)): - idx = indices[i] - sample = ds[idx] - ids = sample["input_ids"] - labs = sample.get("labels", None) + def __len__(self) -> int: + return len(self.datasets) - st = 0 if i != 0 else s_off - ed = len(ids) if i != len(indices) - 1 else e_off + def get_state_dict(self): + return { + "pack_max_length": self.pack_max_length, + "seed": self.seed, + "global_pack": self.global_pack, + "pack_extra_buffer_size": self.pack_extra_buffer_size, + "pack_chunk_size": self.pack_chunk_size, + } - packed_list.append( - { - "input_ids": ids[st:ed], - "labels": labs[st:ed] if labs is not None else None, - "num_tokens": ed - st, - } + def load_state_dict(self, state_dict): + if self.seed != state_dict["seed"]: + raise ValueError( + f"Cannot load state dict with different seed . Origin: {state_dict['seed']}, New: {self.seed}" ) - assert (total_num_tokens := sum(i["num_tokens"] for i in packed_list)) == self.pack_max_length, ( - f"Internal Error! Found size: {total_num_tokens} mismatch after hard packing." 
- ) - return packed_list - - def __getitem__(self, item: int): - if item < self.hard_pack_len: - return self.get_hard_pack_item(item) - else: - return super().__getitem__(item) - def get_hard_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray): - # shuffled indices - inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist() + if self.pack_max_length != state_dict["pack_max_length"]: + raise ValueError( + "Cannot load state dict with different pack_max_length " + f". Origin: {state_dict['pack_max_length']}, New: {self.pack_max_length}" + ) - pack_infos_list = get_pack_infos_by_hard_split( - inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers - ) - return pack_infos_list + if self.global_pack != state_dict["global_pack"]: + raise ValueError( + "Cannot load state dict with different global_pack " + f". Origin: {state_dict['global_pack']}, New: {self.global_pack}" + ) - def get_soft_pack_infos( - self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray, proxy_attn_flops: np.ndarray - ): - # shuffled indices - inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist() + if self.pack_extra_buffer_size != state_dict["pack_extra_buffer_size"]: + raise ValueError( + "Cannot load state dict with different pack_extra_buffer_size " + f". Origin: {state_dict['pack_extra_buffer_size']}, New: {self.pack_extra_buffer_size}" + ) - pack_infos_list = get_pack_infos_by_expand_soft_split( - inds, - dataset_id, - num_tokens, - proxy_attn_flops, - pack_max_length=self.pack_max_length, - pack_workers=self.pack_workers, - pack_chunk_size=self.pack_chunk_size, - pack_extra_buffer_size=self.pack_extra_buffer_size, - ) - return pack_infos_list + if self.pack_chunk_size != state_dict["pack_chunk_size"]: + raise ValueError( + "Cannot load state dict with different pack_chunk_size " + f". 
Origin: {state_dict['pack_chunk_size']}, New: {self.pack_chunk_size}" + ) diff --git a/xtuner/v1/datasets/sampler.py b/xtuner/v1/datasets/sampler.py index a907f3e1e..bb9495aa4 100644 --- a/xtuner/v1/datasets/sampler.py +++ b/xtuner/v1/datasets/sampler.py @@ -12,7 +12,7 @@ from xtuner.v1.utils import get_logger from .jsonl import JsonlDataset -from .packing import _LegacySoftPackDataset +from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset try: @@ -54,7 +54,7 @@ class ParallelSampler(Sampler): def __init__( self, - dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset, + dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset | MLLMPretrainHybridPackDataset, global_batch_size: int, dp_mesh: DeviceMesh | None = None, shuffle: bool = True, @@ -178,7 +178,7 @@ class LengthGroupedSampler(Sampler): def __init__( self, - dataset: _LegacySoftPackDataset, + dataset: _LegacySoftPackDataset | MLLMPretrainHybridPackDataset, global_batch_size: int, dp_mesh: DeviceMesh | None = None, seed: Optional[int] = None,