xtuner/v1/datasets/config.py (9 changes: 6 additions & 3 deletions)
@@ -440,9 +440,12 @@ def build(
)
elif self.group_by_length:
assert shuffle, "Currently only shuffling is supported for LengthGroupedSampler."
assert isinstance(dataset, (ExpandSoftPackDataset, _LegacySoftPackDataset, HardPackDataset)), (
"Internal Error, LengthGroupedSampler requires ExpandSoftPackDataset or _LegacySoftPackDataset, "
f"but got {type(dataset)}"
assert isinstance(
dataset,
(ExpandSoftPackDataset, _LegacySoftPackDataset, HardPackDataset, MLLMPretrainHybridPackDataset),
), (
"Internal Error, LengthGroupedSampler requires ExpandSoftPackDataset, _LegacySoftPackDataset, "
f"HardPackDataset, or MLLMPretrainHybridPackDataset, but got {type(dataset)}"
)
sampler = LengthGroupedSampler(
dataset=dataset, dp_mesh=dp_mesh, global_batch_size=global_batch_size, seed=seed
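
For context when skimming the hunk above: the group_by_length branch widens its isinstance guard so that MLLMPretrainHybridPackDataset is also accepted before a LengthGroupedSampler is built. Below is a hedged paraphrase of that check; the class names come from the hunk, while the module path and the helper function are illustrative assumptions, not code from this PR.

# Hedged paraphrase of the widened guard, not library code from this PR.
# Assumption: all four packed dataset classes live in xtuner.v1.datasets.packing.
from xtuner.v1.datasets.packing import (
    ExpandSoftPackDataset,
    HardPackDataset,
    MLLMPretrainHybridPackDataset,
    _LegacySoftPackDataset,
)

_LENGTH_GROUPED_TYPES = (
    ExpandSoftPackDataset,
    _LegacySoftPackDataset,
    HardPackDataset,
    MLLMPretrainHybridPackDataset,
)


def supports_length_grouping(dataset) -> bool:
    # Mirrors the assertion in the hunk: LengthGroupedSampler only works with
    # packed datasets that expose per-pack lengths.
    return isinstance(dataset, _LENGTH_GROUPED_TYPES)
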
xtuner/v1/datasets/packing.py (170 changes: 74 additions & 96 deletions)
@@ -4,18 +4,20 @@
import os
import random
import tempfile
from collections.abc import Sequence
from concurrent.futures import ProcessPoolExecutor
from functools import cached_property, partial
from multiprocessing import shared_memory
from pathlib import Path
from typing import Sized
from typing import Sized, cast

import numpy as np
import torch
import xxhash
from datasets import Dataset, concatenate_datasets
from torch import distributed as dist
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

from xtuner.v1.utils import get_logger, is_local_rank0
@@ -309,7 +311,7 @@ def get_pack_infos_by_expand_soft_split(
class ExpandSoftPackDataset(_LegacySoftPackDataset):
def __init__(
self,
datasets: list[JsonlDataset],
datasets: Sequence[JsonlDataset],
pack_max_length: int = 2048,
global_pack: bool = False,
pack_extra_buffer_size: int = 1000,
@@ -642,7 +644,7 @@ def get_state_dict(self):
def load_state_dict(self, state_dict): ...


class MLLMPretrainHybridPackDataset(_LegacySoftPackDataset):
class MLLMPretrainHybridPackDataset(TorchDataset):
def __init__(
self,
datasets: list[JsonlDataset],
@@ -653,17 +655,12 @@ def __init__(
pack_extra_buffer_size: int = 1000, # for ExpandSoftPackDataset
pack_chunk_size: int = 10000, # for ExpandSoftPackDataset
):
self.pack_extra_buffer_size = pack_extra_buffer_size
self.pack_workers = pack_workers
self.torch_random_generator = torch.Generator()
self.pack_chunk_size = pack_chunk_size
if seed is not None:
self.torch_random_generator.manual_seed(seed)
logger.info(f"Using {self.pack_workers} pack workers for packing datasets.")

self.seed = seed
self.global_pack = global_pack
self.pack_max_length = pack_max_length
self.global_pack = global_pack
self.pack_workers = pack_workers
self.pack_extra_buffer_size = pack_extra_buffer_size
self.pack_chunk_size = pack_chunk_size

hard_pack_groups = []
soft_pack_groups = []
@@ -673,100 +670,81 @@ def __init__(
elif isinstance(dset, JsonlDataset):
hard_pack_groups.append(dset)

if global_pack:
hard_pack_datasets: list[Sized] = []
if len(hard_pack_groups) > 0:
num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in hard_pack_groups]))]
hard_pack_datasets = [ConcatDataset(hard_pack_groups)]

pack_infos_list = []
for i, dataset in enumerate(hard_pack_datasets):
_infos = self.get_hard_pack_infos(dataset, i, num_tokens[i])
pack_infos_list.extend(_infos)
hard_pack_len = len(pack_infos_list)

soft_pack_datasets: list[Sized] = []
if len(soft_pack_groups) > 0:
num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in soft_pack_groups]))]
proxy_attn_flops = [
ndarray_to_mmap(np.concatenate([dset.proxy_attn_flops for dset in soft_pack_groups]))
]

soft_pack_datasets = [ConcatDataset(soft_pack_groups)]
for i, dataset in enumerate(soft_pack_datasets):
_infos = self.get_soft_pack_infos(dataset, i, num_tokens[i], proxy_attn_flops[i])
pack_infos_list.extend(_infos)
pack_infos = Dataset.from_list(pack_infos_list)
dataset_list: list[HardPackDataset | ExpandSoftPackDataset] = []

else:
raise NotImplementedError
if hard_pack_groups:
hard_pack_dataset = HardPackDataset(
datasets=hard_pack_groups,
pack_max_length=pack_max_length,
global_pack=global_pack,
seed=seed,
pack_workers=pack_workers,
)
dataset_list.append(hard_pack_dataset)

self.hard_pack_datasets = hard_pack_datasets
self.datasets = soft_pack_datasets
self.hard_pack_len = hard_pack_len
self.pack_infos = pack_infos
if soft_pack_groups:
soft_pack_dataset = ExpandSoftPackDataset(
datasets=soft_pack_groups,
pack_max_length=pack_max_length,
global_pack=global_pack,
pack_extra_buffer_size=pack_extra_buffer_size,
pack_chunk_size=pack_chunk_size,
pack_workers=pack_workers,
seed=seed,
)
dataset_list.append(soft_pack_dataset)

def get_hard_pack_item(self, item: int):
info = self.pack_infos[item]
dataset_id = info["dataset_id"]
ds = self.hard_pack_datasets[dataset_id]
assert dataset_list, "No datasets provided for packing."
self.datasets: ConcatDataset[HardPackDataset | ExpandSoftPackDataset] = ConcatDataset(dataset_list)

indices = info["indices"]
s_off = info["start_offset"]
e_off = info["end_offset"]
@cached_property
def longest(self):
longest_list = []
for dataset in self.datasets.datasets:
longest_list.extend(cast(HardPackDataset | ExpandSoftPackDataset, dataset).longest)
return longest_list

packed_list: list[dict] = []
def __getitem__(self, item: int):
return self.datasets[item]

for i in range(len(indices)):
idx = indices[i]
sample = ds[idx]
ids = sample["input_ids"]
labs = sample.get("labels", None)
def __len__(self) -> int:
return len(self.datasets)

st = 0 if i != 0 else s_off
ed = len(ids) if i != len(indices) - 1 else e_off
def get_state_dict(self):
return {
"pack_max_length": self.pack_max_length,
"seed": self.seed,
"global_pack": self.global_pack,
"pack_extra_buffer_size": self.pack_extra_buffer_size,
"pack_chunk_size": self.pack_chunk_size,
}

packed_list.append(
{
"input_ids": ids[st:ed],
"labels": labs[st:ed] if labs is not None else None,
"num_tokens": ed - st,
}
def load_state_dict(self, state_dict):
if self.seed != state_dict["seed"]:
raise ValueError(
f"Cannot load state dict with different seed . Origin: {state_dict['seed']}, New: {self.seed}"
)
assert (total_num_tokens := sum(i["num_tokens"] for i in packed_list)) == self.pack_max_length, (
f"Internal Error! Found size: {total_num_tokens} mismatch after hard packing."
)
return packed_list

def __getitem__(self, item: int):
if item < self.hard_pack_len:
return self.get_hard_pack_item(item)
else:
return super().__getitem__(item)

def get_hard_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
# shuffled indices
inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
if self.pack_max_length != state_dict["pack_max_length"]:
raise ValueError(
"Cannot load state dict with different pack_max_length "
f". Origin: {state_dict['pack_max_length']}, New: {self.pack_max_length}"
)

pack_infos_list = get_pack_infos_by_hard_split(
inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers
)
return pack_infos_list
if self.global_pack != state_dict["global_pack"]:
raise ValueError(
"Cannot load state dict with different global_pack "
f". Origin: {state_dict['global_pack']}, New: {self.global_pack}"
)

def get_soft_pack_infos(
self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray, proxy_attn_flops: np.ndarray
):
# shuffled indices
inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
if self.pack_extra_buffer_size != state_dict["pack_extra_buffer_size"]:
raise ValueError(
"Cannot load state dict with different pack_extra_buffer_size "
f". Origin: {state_dict['pack_extra_buffer_size']}, New: {self.pack_extra_buffer_size}"
)

pack_infos_list = get_pack_infos_by_expand_soft_split(
inds,
dataset_id,
num_tokens,
proxy_attn_flops,
pack_max_length=self.pack_max_length,
pack_workers=self.pack_workers,
pack_chunk_size=self.pack_chunk_size,
pack_extra_buffer_size=self.pack_extra_buffer_size,
)
return pack_infos_list
if self.pack_chunk_size != state_dict["pack_chunk_size"]:
raise ValueError(
"Cannot load state dict with different pack_chunk_size "
f". Origin: {state_dict['pack_chunk_size']}, New: {self.pack_chunk_size}"
)
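
The net effect of the packing.py changes: MLLMPretrainHybridPackDataset no longer subclasses _LegacySoftPackDataset and re-implements the hard/soft pack bookkeeping. It now splits its input datasets into hard-pack and soft-pack groups, builds a HardPackDataset and an ExpandSoftPackDataset, wraps both in a torch ConcatDataset, forwards __getitem__ and __len__, and chains the children's longest lists. The snippet below is a runnable toy illustration of that delegation pattern; _ToyPacked and ToyHybrid are stand-ins invented for this example, not classes from the PR.

# Toy illustration of the ConcatDataset delegation pattern used by the refactor.
from torch.utils.data import ConcatDataset, Dataset


class _ToyPacked(Dataset):
    def __init__(self, packs, longest):
        self.packs = packs
        self.longest = longest  # per-pack lengths (the real classes expose `longest` similarly)

    def __getitem__(self, item):
        return self.packs[item]

    def __len__(self):
        return len(self.packs)


class ToyHybrid(Dataset):
    def __init__(self, hard: _ToyPacked, soft: _ToyPacked):
        # One concatenated view over the hard-packed and soft-packed halves.
        self.datasets = ConcatDataset([hard, soft])

    @property
    def longest(self):
        # Chain the per-pack lengths of every child dataset.
        return [n for d in self.datasets.datasets for n in d.longest]

    def __getitem__(self, item):
        return self.datasets[item]

    def __len__(self):
        return len(self.datasets)


hard = _ToyPacked(packs=["h0", "h1"], longest=[10, 12])
soft = _ToyPacked(packs=["s0"], longest=[7])
hybrid = ToyHybrid(hard, soft)
assert len(hybrid) == 3           # 2 hard packs + 1 soft pack
assert hybrid[2] == "s0"          # indices past the hard half fall through to the soft half
assert hybrid.longest == [10, 12, 7]
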
xtuner/v1/datasets/sampler.py (6 changes: 3 additions & 3 deletions)
@@ -12,7 +12,7 @@
from xtuner.v1.utils import get_logger

from .jsonl import JsonlDataset
from .packing import _LegacySoftPackDataset
from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset


try:
@@ -54,7 +54,7 @@ class ParallelSampler(Sampler):

def __init__(
self,
dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset,
dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset | MLLMPretrainHybridPackDataset,
global_batch_size: int,
dp_mesh: DeviceMesh | None = None,
shuffle: bool = True,
@@ -178,7 +178,7 @@ class LengthGroupedSampler(Sampler):

def __init__(
self,
dataset: _LegacySoftPackDataset,
dataset: _LegacySoftPackDataset | MLLMPretrainHybridPackDataset,
global_batch_size: int,
dp_mesh: DeviceMesh | None = None,
seed: Optional[int] = None,
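
The sampler.py change is annotation-only: ParallelSampler and LengthGroupedSampler now also admit MLLMPretrainHybridPackDataset for their dataset parameter. A hedged usage sketch follows, with keyword arguments taken from the signatures visible in the hunks; the values and the surrounding setup (an already-built hybrid_dataset) are illustrative and not verified against the rest of the repository.

# Hedged usage sketch; argument names follow the hunks above, values are made up.
from xtuner.v1.datasets.sampler import LengthGroupedSampler, ParallelSampler

parallel_sampler = ParallelSampler(
    dataset=hybrid_dataset,      # an MLLMPretrainHybridPackDataset built elsewhere
    global_batch_size=64,        # illustrative value
    shuffle=True,
)

length_grouped_sampler = LengthGroupedSampler(
    dataset=hybrid_dataset,
    global_batch_size=64,
    seed=42,                     # optional, per the signature shown above
)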