diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2.yaml index 87db96381..a51a58f32 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2.yaml @@ -13,6 +13,7 @@ settings: checkpoint_saving_path: data/checkpoints train_dataset_path: ./data/lorem_ipsum_long.pbin test_dataset_path: ./data/lorem_ipsum.pbin + experiments_root_path: ${modalities_env:experiments_root_path} intervals: training_log_interval_in_steps: 1 checkpointing_interval_in_steps: 32 @@ -221,6 +222,7 @@ initialized_model: mean: 0.0 std: 0.02 num_layers: ${model_raw.config.n_layer} + multi_device_generator_policy: error fsdp_model: component_key: model diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml index b4982044c..8e44e38b8 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml @@ -223,6 +223,10 @@ initialized_model: mean: 0.0 std: 0.02 num_layers: ${model_raw.config.n_layer} + seed: 42 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE scheduled_pipeline: component_key: pipeline @@ -315,7 +319,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/docs/components/components.md b/docs/components/components.md index 22f45958a..81af02d34 100644 --- a/docs/components/components.md +++ b/docs/components/components.md @@ -17,6 +17,8 @@ |---------------|--------------------|----------------|---------------|---------------------|-------------| | model_initialization | composed | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py)| [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place | +The composed initializer supports seeded weight initialization for reproducibility within a fixed topology. When pipeline parallelism is active, Modalities offsets the initialization seed by pipeline stage rank to avoid identical stage-local weights. As a result, the same seed can produce different initialized weights for different pipeline-parallel topologies. For topology-independent reproducibility, create and reuse a distributed checkpoint directly after weight initialization. 
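+
+A minimal, illustrative configuration sketch (keys taken from the example training configs under `config_files/training/` in this change; the concrete values are placeholders):
+
+```yaml
+model_initializer:
+  component_key: model_initialization
+  variant_key: composed
+  config:
+    model_type: gpt2
+    weight_init_type: scaled
+    mean: 0.0
+    std: 0.02
+    num_layers: ${model_raw.config.n_layer}
+    seed: 42
+    device_mesh:            # optional; enables the per-PP-stage seed offset
+      instance_key: device_mesh
+      pass_type: BY_REFERENCE
+    multi_device_generator_policy: warn   # one of: ignore | warn | error
+```
+
+If `seed` is left unset, initialization falls back to the process-global RNG state (`torch.initial_seed()`).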
+ ## Losses |Component type | Component Version | Implementation | Configuration | Component Interface | Description | diff --git a/pyproject.toml b/pyproject.toml index 9ea55a4d4..2ddc23935 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ line-length = 120 [tool.pytest.ini_options] addopts = "--cov=src --cov-report term --cov-report html" +#addopts = "-ra" # Enable this instead of line above for reliable VS Code test debugging (without coverage) [tool.coverage.run] branch = true diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 2f45a5f22..42a19b99a 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -7,8 +7,8 @@ from omegaconf import OmegaConf, Resolver from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator from torch.distributed.fsdp import ShardingStrategy -from transformers import GPT2TokenizerFast -from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast +from transformers import GPT2Tokenizer as GPT2TokenizerFast +from transformers import LlamaTokenizer as LlamaTokenizerFast from typing_extensions import deprecated from modalities.config.lookup_enum import LookupEnum diff --git a/src/modalities/conversion/gpt2/modeling_gpt2.py b/src/modalities/conversion/gpt2/modeling_gpt2.py index dec0bf64c..f6aa77ab1 100644 --- a/src/modalities/conversion/gpt2/modeling_gpt2.py +++ b/src/modalities/conversion/gpt2/modeling_gpt2.py @@ -40,7 +40,14 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from transformers.utils.generic import check_model_inputs + +try: + from transformers.utils.generic import check_model_inputs +except ImportError: + + def check_model_inputs(func: Callable) -> Callable: + return func + from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 70f595e67..2da4979c0 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -342,7 +342,6 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network. lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head. use_weight_tying (bool): Whether to use weight tying. - seed: Optional[int] = None: The random seed for reproducibility. enforce_swiglu_hidden_dim_multiple_of (int): If specified, enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. @@ -370,7 +369,6 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config: LayerNormWrapperConfig lm_head_norm_config: LayerNormWrapperConfig use_weight_tying: bool - seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 @model_validator(mode="after") @@ -837,7 +835,6 @@ def __init__( ffn_norm_config: LayerNormWrapperConfig, lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, - seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, ): """ @@ -862,7 +859,6 @@ def __init__( attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module. 
ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module. lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module. - seed (int, optional): The random seed. Defaults to None. use_weight_tying (bool): Whether to use weight tying. enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. @@ -873,7 +869,7 @@ def __init__( "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) + super().__init__(weight_decay_groups=weight_decay_groups) self.sample_key = sample_key self.prediction_key = prediction_key self.sequence_length = sequence_length diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index ac3dca96b..f981f6117 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -26,16 +26,13 @@ class ActivationType(str, Enum): class NNModel(nn.Module): """NNModel class to define a base model.""" - def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None): + def __init__(self, weight_decay_groups: Optional[WeightDecayGroups] = None): """ Initializes an NNModel object. Args: - seed (int, optional): The seed value for random number generation. Defaults to None. weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None. """ - if seed is not None: - torch.manual_seed(seed) self._weight_decay_groups = weight_decay_groups if weight_decay_groups is not None else {} super(NNModel, self).__init__() diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 142aef920..62933794d 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -615,7 +615,6 @@ def get_gpt2_model( lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, use_meta_device: Optional[bool] = False, - seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, ) -> GPT2LLM: config = dict( @@ -637,7 +636,6 @@ def get_gpt2_model( attention_norm_config=attention_norm_config, ffn_norm_config=ffn_norm_config, lm_head_norm_config=lm_head_norm_config, - seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, ) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 190311cb6..e8e3e7114 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -1,17 +1,26 @@ from typing import Optional +import torch import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field, model_validator +from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Annotated -from modalities.config.pydantic_if_types import PydanticModelInitializationIFType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType from modalities.nn.model_initialization.initialization_if import ModelInitializationIF -from modalities.nn.model_initialization.initialization_routines import InitializationRoutines +from modalities.nn.model_initialization.initialization_routines import ( + InitializationRoutines, + MultiDeviceGeneratorPolicy, +) from 
modalities.nn.model_initialization.parameter_name_filters import ( NAMED_PARAMETER_INIT_GROUPS, SupportWeightInitModels, WeightInitTypes, ) +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) class ModelInitializerWrapperConfig(BaseModel): @@ -30,6 +39,9 @@ class ComposedModelInitializationConfig(BaseModel): std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None + seed: int | None = None + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN + device_mesh: Optional[PydanticDeviceMeshIFType] = None # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces @@ -87,6 +99,24 @@ def initialize_in_place(self, model: nn.Module): class ComposedInitializationRoutines: + @staticmethod + def _warn_pp_topology_dependent_seed(device_mesh: Optional[DeviceMesh], seed: Optional[int]) -> None: + if seed is None or not has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + return + + if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: + return + + logger.warning( + "Seeded weight initialization is topology-dependent when pipeline parallelism is active. " + "Modalities offsets the initialization seed by PP rank to avoid identical stage-local weights, " + "so the same seed can produce different initialized weights for different PP configurations. " + "For topology-independent reproducibility, create and reuse a distributed checkpoint directly " + "after weight initialization." + ) + @staticmethod def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF: initializer_wrapper = ModelInitializerWrapper(model_initializers) @@ -98,8 +128,11 @@ def get_composed_model_initializer( weight_init_type: WeightInitTypes, mean: float, std: float | str, - hidden_dim: Optional[int] = None, - num_layers: int = None, + hidden_dim: int | None = None, + num_layers: int | None = None, + device_mesh: Optional[DeviceMesh] = None, + seed: int | None = None, + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> ModelInitializationIF: """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization. Note that plain initialization is always performed in the beginning. In case of scaled_embed, @@ -114,28 +147,53 @@ def get_composed_model_initializer( Defaults to None. num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only). Defaults to None. + device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization. + seed (Optional[int], optional): Seed for random initialization. Defaults to None. When pipeline + parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local + initialization, so the same seed does not guarantee identical initialized weights across different + PP topologies. + multi_device_generator_policy (MultiDeviceGeneratorPolicy, optional): Behavior when + initialization creates per-device RNG generators for more than one device in the same process. + Defaults to MultiDeviceGeneratorPolicy.WARN. 
Returns: ModelInitializationIF: The Weight Initializer performing the initialization as specified. """ + ComposedInitializationRoutines._warn_pp_topology_dependent_seed(device_mesh=device_mesh, seed=seed) + + # Set different random seed for each PP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + assert device_mesh is not None + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) + model_initializers = [] # plain plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN] plain_init = InitializationRoutines.get_plain_initialization( - mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes + mean=mean, + std=std, + hidden_dim=hidden_dim, + parameter_name_regexes=plain_parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) working_std = plain_init.std model_initializers.append(plain_init) if weight_init_type in [WeightInitTypes.SCALED, WeightInitTypes.SCALED_EMBED]: # scaled + assert num_layers is not None scaled_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED] scaled_init = InitializationRoutines.get_scaled_initialization( mean=mean, std=working_std, num_layers=num_layers, parameter_name_regexes=scaled_parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) model_initializers.append(scaled_init) @@ -143,7 +201,10 @@ def get_composed_model_initializer( # scaled embed scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED] scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization( - mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes + mean=mean, + parameter_name_regexes=scaled_embed_parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) model_initializers.append(scaled_embed_init) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index deb6a2737..1f785f562 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -1,7 +1,10 @@ import math import re -from typing import Annotated, Optional +import warnings +from enum import Enum +from typing import Annotated +import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator @@ -9,11 +12,17 @@ from modalities.nn.model_initialization.parameter_name_filters import RegexFilter +class MultiDeviceGeneratorPolicy(str, Enum): + IGNORE = "ignore" + WARN = "warn" + ERROR = "error" + + class PlainInitializationConfig(BaseModel): mean: float std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" - hidden_dim: Optional[int] = None + hidden_dim: int | None = None @model_validator(mode="after") def check_std_and_hidden_dim(self): @@ -39,21 +48,49 @@ class ScaledEmbedInitializationConfig(BaseModel): class NamedParameterwiseNormalInitialization(ModelInitializationIF): - def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter): + def __init__( + self, + mean: float, + std: float, + parameter_name_regexes: RegexFilter, + seed: int 
| None = None, + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, + ): self.mean = mean self.std = std self.parameter_name_regexes = parameter_name_regexes + self.seed = torch.initial_seed() if seed is None else seed + self.multi_device_generator_policy = multi_device_generator_policy + self._generators: dict[str, torch.Generator] = {} + + def _get_generator(self, parameter: torch.Tensor) -> torch.Generator: + device_key = str(parameter.device) + generator = self._generators.get(device_key) + if generator is None: + if len(self._generators) > 0: + message = ( + "NamedParameterwiseNormalInitialization created generators for multiple devices in one process " + f"(existing={list(self._generators.keys())}, new={device_key})." + ) + if self.multi_device_generator_policy == MultiDeviceGeneratorPolicy.ERROR: + raise RuntimeError(message) + if self.multi_device_generator_policy == MultiDeviceGeneratorPolicy.WARN: + warnings.warn(message, stacklevel=2) + generator = torch.Generator(device=parameter.device) + generator.manual_seed(self.seed) + self._generators[device_key] = generator + return generator def initialize_in_place(self, model: nn.Module): weight_regexes = self.parameter_name_regexes.weights - bias_regexes = self.parameter_name_regexes.biases + bias_regexes = self.parameter_name_regexes.biases or [] for parameter_name, p in model.named_parameters(): parameter_name = parameter_name.replace( "_orig_mod.", "" ) # remove FQN modification from torch.compile if present for weight_regex in weight_regexes: if re.fullmatch(weight_regex, parameter_name): - nn.init.normal_(p, mean=self.mean, std=self.std) + nn.init.normal_(p, mean=self.mean, std=self.std, generator=self._get_generator(p)) for bias_regex in bias_regexes: if re.fullmatch(bias_regex, parameter_name): nn.init.zeros_(p) @@ -62,7 +99,12 @@ def initialize_in_place(self, model: nn.Module): class InitializationRoutines: @staticmethod def get_plain_initialization( - mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None + mean: float, + std: float | str, + parameter_name_regexes: RegexFilter, + hidden_dim: int | None = None, + seed: int | None = None, + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -73,23 +115,37 @@ def get_plain_initialization( std (float): standard deviation of the normal distribution. If set to "auto", appropiate value selected as per plain initialization described in https://arxiv.org/abs/2312.16903 hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None. + parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization + should be applied + seed (Optional[int]): Random seed for initialization. Defaults to None. + multi_device_generator_policy (MultiDeviceGeneratorPolicy): Behavior when more than one + device-local RNG generator is created in the same process. """ - # auto: choose std automatically if std == "auto": if hidden_dim is None: raise ValueError("ERROR! 
weight_init.std = auto not implemented") # as per https://arxiv.org/abs/2312.16903 std = math.sqrt(2 / (5 * hidden_dim)) + assert isinstance(std, float) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=std, parameter_name_regexes=parameter_name_regexes + mean=mean, + std=std, + parameter_name_regexes=parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) return initialization @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: list[str] + mean: float, + std: float, + num_layers: int, + parameter_name_regexes: RegexFilter, + seed: int | None = None, + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -97,8 +153,11 @@ def get_scaled_initialization( mean (float): Mean of the normal distribution std (float): Standard deviation of the normal distribution used to initialize the other weights num_layers (int): Number of layers in the model which we use to downscale std with - parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization + parameter_name_regexes (RegexFilter): List of parameter name regexes to which the initialization should be applied + seed (Optional[int]): Random seed for initialization. Defaults to None. + multi_device_generator_policy (MultiDeviceGeneratorPolicy): Behavior when more than one + device-local RNG generator is created in the same process. Returns: WeightInitializationIF: Weight initialization object @@ -107,12 +166,21 @@ def get_scaled_initialization( scaled_std = std / math.sqrt(2 * num_layers) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes + mean=mean, + std=scaled_std, + parameter_name_regexes=parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) return initialization @staticmethod - def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF: + def get_scaled_embed_initialization( + mean: float, + parameter_name_regexes: RegexFilter, + seed: int | None = None, + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, + ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). @@ -120,12 +188,19 @@ def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[st mean (float): Mean of the normal distribution parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. + seed (Optional[int]): Random seed for initialization. Defaults to None. + multi_device_generator_policy (MultiDeviceGeneratorPolicy): Behavior when more than one + device-local RNG generator is created in the same process. 
Returns: WeightInitializationIF: Weight initialization object """ std = math.sqrt(0.4) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=std, parameter_name_regexes=parameter_name_regexes + mean=mean, + std=std, + parameter_name_regexes=parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) return initialization diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml index 395c131a7..9a9c886d4 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml @@ -204,7 +204,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml index c784ae6bc..f31503a6f 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml @@ -270,7 +270,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml index eb6b5f490..a8f72ac2f 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml @@ -281,7 +281,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml index 2d0c8e2b5..579162709 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml @@ -216,7 +216,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml index b5378e05d..e85c6e93c 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml @@ -223,7 +223,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml index 8af01a926..f6b553b44 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml @@ -223,7 +223,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml 
b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml index c88c80922..4f073ec28 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml @@ -223,7 +223,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml index 5b687e2e4..2029c0323 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml @@ -300,7 +300,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml index cd822525c..33009818f 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml @@ -129,7 +129,11 @@ initialized_model: weight_init_type: scaled mean: 0.0 std: 0.02 + seed: 42 num_layers: ${model_raw.config.n_layer} + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE model_raw: component_key: model diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml new file mode 100644 index 000000000..ad6ed5954 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml @@ -0,0 +1,186 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + seed: 42 + num_layers: ${model_raw.config.n_layer} + device_mesh: + instance_key: 
device_mesh + pass_type: BY_REFERENCE + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: Interleaved1F1B + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} + num_layers_per_stage: 4 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py new file mode 100644 index 000000000..6a3333ace --- /dev/null +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -0,0 +1,238 @@ +import logging +import multiprocessing as py_mp +import os +import traceback +from pathlib import Path +from typing import Any, cast + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel +from torch.distributed._tensor.placement_types import Replicate +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict + +from modalities.__main__ import Main +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv +from tests.utility import monitor_child_processes + +working_dir = Path(os.path.dirname(__file__)) +tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp" +working_dir = working_dir / "configs" + + +@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This e2e test requires 8 GPUs.", +) +class TestParallelSeedInitialization: + WORLD_SIZE = 8 + RDVZ_PORT = 24574 + + def test_parameters_follow_parallelism(self, tmp_path: Path): + manager = py_mp.Manager() + error_queue = manager.Queue() + proc_ctx = mp.spawn( + self._seed_distribution_impl_wrapper, + args=(self.WORLD_SIZE, tmp_path, error_queue), + nprocs=self.WORLD_SIZE, + join=False, + ) + monitor_child_processes(manager, error_queue, proc_ctx) + + def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_path: Path, error_queue: Any): + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=TestParallelSeedInitialization.RDVZ_PORT, + ): + try: + self._seed_distribution_impl(world_size=world_size, tmp_path=tmp_path) + except Exception as exc: + tb = traceback.format_exc() + logging.error(f"Process {process_id} (seed distribution test) encountered an error:\n{exc}") + logging.error(tb) + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to put exception info into error queue (seed distribution test).") + os._exit(1) + + def _seed_distribution_impl(self, world_size: int, tmp_path: Path): + # initialize components + class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType] + device_mesh: PydanticDeviceMeshIFType + + config_file_path = 
self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path) + main_obj = Main(config_file_path, experiments_root_path=tmp_path) + components = cast( + ComponentsInstantiationModel, + main_obj.build_components(components_model_type=ComponentsInstantiationModel), + ) + model = cast( + Any, + components.initialized_model[0] + if isinstance(components.initialized_model, list) + else components.initialized_model, + ) + device_mesh = components.device_mesh + # for each pp stage get first transformer block's MLP weight parameter shards and full tensor + block_key = next(iter(model.transformer.h.keys())) + block = model.transformer.h[block_key] + placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape) + full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu() + payload = { + "tensor_full": full_weight, + "tensor_shard": block.mlp.W.weight.to_local().cpu(), + "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP), + "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP), + "dp_shard_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.DP_SHARD), + "block_key": block_key, + } + + gather_list = cast(list[dict[str, Any] | None] | None, [None] * world_size if dist.get_rank() == 0 else None) + dist.gather_object(payload, gather_list, dst=0) + + if dist.get_rank() == 0: + assert gather_list is not None + TestParallelSeedInitialization._assert_parameter_distribution(cast(list[dict[str, Any]], gather_list)) + dist.barrier() + + @staticmethod + def _assert_parameter_distribution(records: list[dict[str, Any]]): + combos: dict[tuple[int, int], list[dict[str, Any]]] = {} + for record in records: + key = (record["pp_rank"], record["tp_rank"]) + combos.setdefault(key, []).append(record) + + expected_combo_count = 4 + assert ( + len(combos) == expected_combo_count + ), f"Expected {expected_combo_count} PP/TP combinations, got {len(combos)}" + + combo_tensors: dict[tuple[int, int], torch.Tensor] = {} + for (pp_rank, tp_rank), entries in combos.items(): + # check that full tensors are the same across data parallel processes + reference = entries[0]["tensor_full"] + seen_dp_ranks: set[int] = set() + for entry in entries: + dp_rank = entry["dp_shard_rank"] + assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}" + seen_dp_ranks.add(dp_rank) + assert torch.equal(reference, entry["tensor_full"]), ( + "Tensors within the same TP/PP combo must be identical across DP ranks; " + f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})" + ) + # concatenate all shards for this pp/tp combo + shards = sorted(entries, key=lambda e: e["dp_shard_rank"]) + combo_tensors[(pp_rank, tp_rank)] = torch.cat( + [e["tensor_shard"] for e in shards], + dim=0, + ) + # check that tensor shards differ across different pp/tp combos + combo_items = list(combo_tensors.items()) + for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items): + for other_key, other_tensor in combo_items[idx + 1 :]: + tensors_equal = torch.equal(base_tensor, other_tensor) + assert not tensors_equal, ( + "Distinct TP/PP combinations should initialize with different weights; " + f"found match between (PP={pp_rank}, TP={tp_rank}) and (PP={other_key[0]}, TP={other_key[1]})" + ) + + def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degree: int, tmp_path: Path) -> Path: + temp_file_path = 
tmp_path / "pp_tp_sharding_config.yaml" + working_dir = Path(os.path.dirname(__file__)) + config_file_path = ( + working_dir + / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml" + ) + + with open(config_file_path, "r") as file: + config_string = file.read() + config_dict = yaml.safe_load(config_string) + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = dp_degree + config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree + config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree + + # save to temporary file + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path + + +@pytest.mark.skipif( + torch.cuda.device_count() < 1, + reason="This test requires at least 1 GPU.", +) +class TestSeededModelReproducibility: + RDVZ_PORT = 24575 + + def test_same_seed_same_weights(self, tmp_path: Path): + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=0, + local_rank=0, + world_size=1, + rdvz_port=TestSeededModelReproducibility.RDVZ_PORT, + ): + self._same_seed_same_weights_impl(tmp_path=tmp_path) + + def _same_seed_same_weights_impl(self, tmp_path: Path): + class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType + + config_file_path = self._get_tmp_seeded_config_path(tmp_path=tmp_path, seed=1234) + + main_obj_1 = Main(config_file_path, experiments_root_path=tmp_path) + components_1 = cast( + ComponentsInstantiationModel, + main_obj_1.build_components(components_model_type=ComponentsInstantiationModel), + ) + state_dict_1 = get_state_dict( + model=cast(Any, components_1.initialized_model), + optimizers=[], + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + )[0] + + main_obj_2 = Main(config_file_path, experiments_root_path=tmp_path) + components_2 = cast( + ComponentsInstantiationModel, + main_obj_2.build_components(components_model_type=ComponentsInstantiationModel), + ) + state_dict_2 = get_state_dict( + model=cast(Any, components_2.initialized_model), + optimizers=[], + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + )[0] + + assert set(state_dict_1.keys()) == set(state_dict_2.keys()), "State dict keys differ between initializations" + for key in state_dict_1: + tensor_1 = state_dict_1[key] + tensor_2 = state_dict_2[key] + assert isinstance(tensor_1, torch.Tensor), f"Expected Tensor in first state dict for key {key}" + assert isinstance(tensor_2, torch.Tensor), f"Expected Tensor in second state dict for key {key}" + assert torch.equal(tensor_1, tensor_2), f"Mismatch for parameter {key}" + + dist.barrier() + + def _get_tmp_seeded_config_path(self, tmp_path: Path, seed: int) -> Path: + temp_file_path = tmp_path / "seeded_reproducibility.yaml" + config_file_path = Path(os.path.dirname(__file__)) / "../checkpointing/fsdp2_gpt2_config.yaml" + + with open(config_file_path, "r") as file: + config_dict = yaml.safe_load(file.read()) + config_dict["initialized_model"]["config"]["model_initializer"]["config"]["seed"] = seed + + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path diff --git a/tests/nn/model_initialization/test_deferred_initialization.py b/tests/nn/model_initialization/test_deferred_initialization.py index 1c431abc4..eae9c0686 100644 --- a/tests/nn/model_initialization/test_deferred_initialization.py +++ b/tests/nn/model_initialization/test_deferred_initialization.py @@ 
-105,7 +105,6 @@ def _build_gpt2_model() -> GPT2LLM: ffn_norm_config=ln_cfg, lm_head_norm_config=ln_cfg, use_weight_tying=False, - seed=42, enforce_swiglu_hidden_dim_multiple_of=256, ) return model
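With the `seed` argument removed from `NNModel`, `GPT2LLM`, and the GPT-2 model factory, reproducible weights now come from the composed initializer rather than the model constructor. A minimal sketch based on the signatures introduced above (the stand-in module and the `SupportWeightInitModels` member name are assumptions):

```python
import torch.nn as nn

from modalities.nn.model_initialization.composed_initialization import ComposedInitializationRoutines
from modalities.nn.model_initialization.initialization_routines import MultiDeviceGeneratorPolicy
from modalities.nn.model_initialization.parameter_name_filters import SupportWeightInitModels, WeightInitTypes

# Stand-in for the real (meta-device) GPT2LLM or pipeline model part, built without a seed.
model = nn.Sequential(nn.Linear(8, 8))

initializer = ComposedInitializationRoutines.get_composed_model_initializer(
    model_type=SupportWeightInitModels.gpt2,  # assumed member name; the YAML configs use `model_type: gpt2`
    weight_init_type=WeightInitTypes.SCALED,
    mean=0.0,
    std=0.02,
    num_layers=12,
    seed=42,          # offset by the PP stage rank when `device_mesh` carries a pipeline-parallel dimension
    device_mesh=None,  # pass the training DeviceMesh to enable the PP-rank offset (and the topology warning)
    multi_device_generator_policy=MultiDeviceGeneratorPolicy.WARN,
)
initializer.initialize_in_place(model)  # seeded, in-place normal initialization of matching parameters
```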