Merged
Commits
26 commits
5f9f50e
fix: Initialize different weights across TP ranks
rrutmann Dec 8, 2025
8c8c5ab
feat: Consider pp rank for model seed
rrutmann Dec 9, 2025
ab3daa0
fix: Only consider PP rank for seeding
rrutmann Dec 10, 2025
62a1743
test: Add test for different parameters on tp/pp ranks
rrutmann Dec 12, 2025
00a595b
test: Check for equal parameters across data parallel processes
rrutmann Dec 12, 2025
bf06da7
feat: Integrate seeding to model initialization
rrutmann Dec 19, 2025
b137701
refactor: Move seeding logic to model initialization component
rrutmann Dec 19, 2025
bff99f3
chore: Add seed and device_mesh to ComposedModelInitializationConfig
rrutmann Dec 19, 2025
98ff9db
test: Adapt test to latest changes
rrutmann Dec 19, 2025
2e248ed
chore: Remove old code
rrutmann Dec 19, 2025
093fa33
chore: Merge branch 'main' into seed
rrutmann May 4, 2026
5a9e89e
fix: Use local-generator weight init
rrutmann May 5, 2026
13e7a82
refactor: Do not set seed in NNModel
rrutmann May 5, 2026
dc11bbb
docs: Add documentation and warning for topology-dependent weight ini…
rrutmann May 5, 2026
999cb65
fix: Fix transformers version mismatch
rrutmann May 5, 2026
b02275f
test: Fix test by removing dependency on global RNG state for seed=None
rrutmann May 5, 2026
ddfbe47
test: Adapt test to latest changes in main
rrutmann May 5, 2026
76762d9
chore: Use consistent typing for optional parameters
rrutmann May 5, 2026
dea2eef
chore: Remove outdated seed parameter
rrutmann May 5, 2026
adf11f0
fix: Use correct type for parameter_name_regexes
rrutmann May 7, 2026
4cf0032
test: Add option for reliable vscode debugging
rrutmann May 7, 2026
7541df2
test: Add test for seeded model reproducibility
rrutmann May 7, 2026
ede150e
chore: Change order of model initialization
rrutmann May 7, 2026
67bc596
feat: Add multi_device_generator_policy for handling seeding with mul…
rrutmann May 7, 2026
5172fc4
refactor: Use enum for multi_device_generator_policy
rrutmann May 8, 2026
326823e
chore: Update model seed initialization
rrutmann May 8, 2026
2 changes: 2 additions & 0 deletions config_files/training/config_lorem_ipsum_long_fsdp2.yaml
@@ -13,6 +13,7 @@ settings:
checkpoint_saving_path: data/checkpoints
train_dataset_path: ./data/lorem_ipsum_long.pbin
test_dataset_path: ./data/lorem_ipsum.pbin
experiments_root_path: ${modalities_env:experiments_root_path}
intervals:
training_log_interval_in_steps: 1
checkpointing_interval_in_steps: 32
@@ -221,6 +222,7 @@ initialized_model:
mean: 0.0
std: 0.02
num_layers: ${model_raw.config.n_layer}
multi_device_generator_policy: error

fsdp_model:
component_key: model
@@ -223,6 +223,10 @@ initialized_model:
mean: 0.0
std: 0.02
num_layers: ${model_raw.config.n_layer}
seed: 42
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE

scheduled_pipeline:
component_key: pipeline
@@ -315,7 +319,6 @@ model_raw:
component_key: model
variant_key: gpt2
config:
seed: 42
use_meta_device: true
use_weight_tying: false
sample_key: ${settings.referencing_keys.sample_key}
2 changes: 2 additions & 0 deletions docs/components/components.md
@@ -17,6 +17,8 @@
|---------------|--------------------|----------------|---------------|---------------------|-------------|
| model_initialization | composed | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py)| [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place |

The composed initializer supports seeded weight initialization for reproducibility within a fixed topology. When pipeline parallelism is active, Modalities offsets the initialization seed by pipeline stage rank to avoid identical stage-local weights. As a result, the same seed can produce different initialized weights for different pipeline-parallel topologies. For topology-independent reproducibility, create and reuse a distributed checkpoint directly after weight initialization.
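The stage-rank offset described here can be sketched in a few lines. Note the helper name `effective_init_seed` is hypothetical and used only for illustration; in this PR the offset is applied inside `ComposedInitializationRoutines.get_composed_model_initializer`:

```python
# Minimal sketch (hypothetical helper) of topology-dependent seeding:
# each pipeline-parallel (PP) stage offsets the base seed so stage-local
# weights differ, while replicas of the same stage share a seed.

def effective_init_seed(base_seed, pp_rank):
    """Offset the base seed by the PP stage rank; None disables seeding."""
    if base_seed is None:
        return None
    return base_seed + pp_rank

# Two PP stages derive different seeds from the same base seed, so their
# stage-local weights differ; replicas of the same stage (e.g. across data
# parallelism) share a seed and initialize identically.
print(effective_init_seed(42, 0))  # → 42
print(effective_init_seed(42, 1))  # → 43
```

Because the effective seed depends on the PP rank, the same base seed yields different weights when the PP degree changes — which is exactly why the paragraph above recommends a post-initialization checkpoint for topology-independent reproducibility.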

## Losses

|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
1 change: 1 addition & 0 deletions pyproject.toml
@@ -124,6 +124,7 @@ line-length = 120

[tool.pytest.ini_options]
addopts = "--cov=src --cov-report term --cov-report html"
#addopts = "-ra" # Enable this instead of line above for reliable VS Code test debugging (without coverage)

[tool.coverage.run]
branch = true
4 changes: 2 additions & 2 deletions src/modalities/config/config.py
@@ -7,8 +7,8 @@
from omegaconf import OmegaConf, Resolver
from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator
from torch.distributed.fsdp import ShardingStrategy
from transformers import GPT2TokenizerFast
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from transformers import GPT2Tokenizer as GPT2TokenizerFast
from transformers import LlamaTokenizer as LlamaTokenizerFast
from typing_extensions import deprecated

from modalities.config.lookup_enum import LookupEnum
9 changes: 8 additions & 1 deletion src/modalities/conversion/gpt2/modeling_gpt2.py
@@ -40,7 +40,14 @@
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.generic import check_model_inputs

try:
from transformers.utils.generic import check_model_inputs
except ImportError:

def check_model_inputs(func: Callable) -> Callable:
Member

was this removed in transformers?
If it is part of a legacy API I think we should also remove this on our end.
What do you think @BlueCrescent? I think you added it, right?

Collaborator Author

The function was removed in transformers version 5.2. In our pyproject.toml we specify the requirement "transformers>=4.57.4,<5.0.0", so I used an unsupported transformers version here. Should we remove it just to be on the safe side?

Member

Yes, I think, we should tackle the transformers 5.0.0+ support soon anyways.

return func


from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config

6 changes: 1 addition & 5 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -342,7 +342,6 @@ class GPT2LLMConfig(BaseModel):
ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
use_weight_tying (bool): Whether to use weight tying.
seed: Optional[int] = None: The random seed for reproducibility.
enforce_swiglu_hidden_dim_multiple_of (int): If specified, enforces the hidden dimension
in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the
activation_type is SwiGLU. Defaults to 256.
@@ -370,7 +369,6 @@ class GPT2LLMConfig(BaseModel):
ffn_norm_config: LayerNormWrapperConfig
lm_head_norm_config: LayerNormWrapperConfig
use_weight_tying: bool
seed: Optional[int] = None
enforce_swiglu_hidden_dim_multiple_of: int = 256

@model_validator(mode="after")
@@ -837,7 +835,6 @@ def __init__(
ffn_norm_config: LayerNormWrapperConfig,
lm_head_norm_config: LayerNormWrapperConfig,
use_weight_tying: bool,
seed: Optional[int] = None,
enforce_swiglu_hidden_dim_multiple_of: int = 256,
):
"""
@@ -862,7 +859,6 @@ def __init__(
attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module.
ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module.
lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
seed (int, optional): The random seed. Defaults to None.
use_weight_tying (bool): Whether to use weight tying.
enforce_swiglu_hidden_dim_multiple_of (int): Enforces
the hidden dimension in the SwiGLU layer to be a multiple of this value.
@@ -873,7 +869,7 @@ def __init__(
"embedding": [".wte", ".wpe"],
"layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"],
}
super().__init__(weight_decay_groups=weight_decay_groups, seed=seed)
super().__init__(weight_decay_groups=weight_decay_groups)
self.sample_key = sample_key
self.prediction_key = prediction_key
self.sequence_length = sequence_length
5 changes: 1 addition & 4 deletions src/modalities/models/model.py
@@ -26,16 +26,13 @@ class ActivationType(str, Enum):
class NNModel(nn.Module):
"""NNModel class to define a base model."""

def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
def __init__(self, weight_decay_groups: Optional[WeightDecayGroups] = None):
"""
Initializes an NNModel object.

Args:
seed (int, optional): The seed value for random number generation. Defaults to None.
weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.
"""
if seed is not None:
torch.manual_seed(seed)
self._weight_decay_groups = weight_decay_groups if weight_decay_groups is not None else {}
super(NNModel, self).__init__()

2 changes: 0 additions & 2 deletions src/modalities/models/model_factory.py
@@ -615,7 +615,6 @@ def get_gpt2_model(
lm_head_norm_config: LayerNormWrapperConfig,
use_weight_tying: bool,
use_meta_device: Optional[bool] = False,
seed: Optional[int] = None,
enforce_swiglu_hidden_dim_multiple_of: int = 256,
) -> GPT2LLM:
config = dict(
@@ -637,7 +636,6 @@
attention_norm_config=attention_norm_config,
ffn_norm_config=ffn_norm_config,
lm_head_norm_config=lm_head_norm_config,
seed=seed,
use_weight_tying=use_weight_tying,
enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
)
73 changes: 67 additions & 6 deletions src/modalities/nn/model_initialization/composed_initialization.py
@@ -1,17 +1,26 @@
from typing import Optional

import torch
import torch.nn as nn
from pydantic import BaseModel, ConfigDict, Field, model_validator
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import Annotated

from modalities.config.pydantic_if_types import PydanticModelInitializationIFType
from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.nn.model_initialization.initialization_routines import InitializationRoutines
from modalities.nn.model_initialization.initialization_routines import (
InitializationRoutines,
MultiDeviceGeneratorPolicy,
)
from modalities.nn.model_initialization.parameter_name_filters import (
NAMED_PARAMETER_INIT_GROUPS,
SupportWeightInitModels,
WeightInitTypes,
)
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method
from modalities.utils.logger_utils import get_logger

logger = get_logger(__name__)


class ModelInitializerWrapperConfig(BaseModel):
@@ -30,6 +39,9 @@ class ComposedModelInitializationConfig(BaseModel):
std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto"
hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
seed: int | None = None
multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN
device_mesh: Optional[PydanticDeviceMeshIFType] = None

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
@@ -87,6 +99,24 @@ def initialize_in_place(self, model: nn.Module):


class ComposedInitializationRoutines:
@staticmethod
def _warn_pp_topology_dependent_seed(device_mesh: Optional[DeviceMesh], seed: Optional[int]) -> None:
if seed is None or not has_parallelism_method(
device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
):
return

if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
return

logger.warning(
"Seeded weight initialization is topology-dependent when pipeline parallelism is active. "
"Modalities offsets the initialization seed by PP rank to avoid identical stage-local weights, "
"so the same seed can produce different initialized weights for different PP configurations. "
"For topology-independent reproducibility, create and reuse a distributed checkpoint directly "
"after weight initialization."
)

@staticmethod
def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF:
initializer_wrapper = ModelInitializerWrapper(model_initializers)
@@ -98,8 +128,11 @@ def get_composed_model_initializer(
weight_init_type: WeightInitTypes,
mean: float,
std: float | str,
hidden_dim: Optional[int] = None,
num_layers: int = None,
hidden_dim: int | None = None,
num_layers: int | None = None,
device_mesh: Optional[DeviceMesh] = None,
seed: int | None = None,
multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN,
) -> ModelInitializationIF:
"""This initialization allows initializing a model with plain, scaled or scaled_embed initialization.
Note that plain initialization is always performed in the beginning. In case of scaled_embed,
@@ -114,36 +147,64 @@
Defaults to None.
num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only).
Defaults to None.
device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization.
seed (Optional[int], optional): Seed for random initialization. Defaults to None. When pipeline
parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local
initialization, so the same seed does not guarantee identical initialized weights across different
PP topologies.
multi_device_generator_policy (MultiDeviceGeneratorPolicy, optional): Behavior when
initialization creates per-device RNG generators for more than one device in the same process.
Defaults to MultiDeviceGeneratorPolicy.WARN.

Returns:
ModelInitializationIF: The Weight Initializer performing the initialization as specified.
"""
ComposedInitializationRoutines._warn_pp_topology_dependent_seed(device_mesh=device_mesh, seed=seed)

# Set different random seed for each PP rank to ensure diversity
if seed is not None and has_parallelism_method(
device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
):
assert device_mesh is not None
seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)

model_initializers = []

# plain
plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN]
plain_init = InitializationRoutines.get_plain_initialization(
mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes
mean=mean,
std=std,
hidden_dim=hidden_dim,
parameter_name_regexes=plain_parameter_name_regexes,
seed=seed,
multi_device_generator_policy=multi_device_generator_policy,
)
working_std = plain_init.std
model_initializers.append(plain_init)

if weight_init_type in [WeightInitTypes.SCALED, WeightInitTypes.SCALED_EMBED]:
# scaled
assert num_layers is not None
scaled_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED]
scaled_init = InitializationRoutines.get_scaled_initialization(
mean=mean,
std=working_std,
num_layers=num_layers,
parameter_name_regexes=scaled_parameter_name_regexes,
seed=seed,
multi_device_generator_policy=multi_device_generator_policy,
)
model_initializers.append(scaled_init)

if weight_init_type == WeightInitTypes.SCALED_EMBED:
# scaled embed
scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED]
scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization(
mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes
mean=mean,
parameter_name_regexes=scaled_embed_parameter_name_regexes,
seed=seed,
multi_device_generator_policy=multi_device_generator_policy,
)
model_initializers.append(scaled_embed_init)
