diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py
index 2da4979c0..905ad289b 100644
--- a/src/modalities/models/gpt2/gpt2_model.py
+++ b/src/modalities/models/gpt2/gpt2_model.py
@@ -120,7 +120,15 @@ class RotaryTransform(QueryKeyValueTransform):
     XFormers implementation and removed in this implementation.#
     """

-    def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq: int = 10000):
+    def __init__(
+        self,
+        n_embd: int,
+        n_head: int,
+        seq_length_dim: int = -2,
+        base_freq: int = 10000,
+        max_position_embeddings: int | None = None,
+        rope_scaling: dict[str, object] | None = None,
+    ):
         """
         Initializes the RotaryTransform object.

@@ -136,16 +144,128 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq
         self.dim_model = n_embd // n_head
         self.seq_length_dim = seq_length_dim
         self.base_freq = base_freq
+        self.max_position_embeddings = max_position_embeddings
+
+        self.rope_scaling = rope_scaling
+        self.attention_scaling = 1.0

         self.reset_parameters()

+    def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
+        """Compute YaRN inverse frequencies and the attention scaling factor."""
+        if self.rope_scaling is None:
+            raise ValueError("YaRN requires a rope_scaling config.")
+        if self.max_position_embeddings is None:
+            raise ValueError("YaRN requires max_position_embeddings to be set.")
+
+        original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings")
+        if (
+            original_max_position_embeddings is None
+            or not isinstance(original_max_position_embeddings, int)
+            or original_max_position_embeddings <= 0
+        ):
+            raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer")
+
+        factor = self.rope_scaling.get("factor")
+        if factor is None:
+            factor = self.max_position_embeddings / original_max_position_embeddings
+        if not isinstance(factor, (int, float)) or factor < 1.0:
+            raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0")
+        factor_float = float(factor)
+
+        attention_factor = self.rope_scaling.get("attention_factor")
+        mscale = self.rope_scaling.get("mscale")
+        mscale_all_dim = self.rope_scaling.get("mscale_all_dim")
+        beta_fast_raw = self.rope_scaling.get("beta_fast")
+        beta_slow_raw = self.rope_scaling.get("beta_slow")
+        beta_fast = float(beta_fast_raw) if isinstance(beta_fast_raw, (int, float)) else 32.0
+        beta_slow = float(beta_slow_raw) if isinstance(beta_slow_raw, (int, float)) else 1.0
+        truncate = self.rope_scaling.get("truncate", True)
+
+        def get_mscale(scale: float, mscale: float = 1.0) -> float:
+            """Return the YaRN mscale coefficient for a given scaling factor."""
+            if scale <= 1:
+                return 1.0
+            return 0.1 * mscale * math.log(scale) + 1.0
+
+        if attention_factor is None:
+            if isinstance(mscale, (int, float)) and isinstance(mscale_all_dim, (int, float)):
+                attention_factor = float(
+                    get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim))
+                )
+            else:
+                attention_factor = get_mscale(factor_float)
+        elif not isinstance(attention_factor, (int, float)) or attention_factor <= 0:
+            raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0")
+
+        def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float:
+            """Map a target number of rotations to a rotary dimension index."""
+            return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+        def find_correction_range(
+            low_rot: float,
+            high_rot: float,
+            dim: int,
+            base: int,
+            max_position_embeddings: int,
+            truncate: bool,
+        ) -> tuple[float, float]:
+            """Compute the lower and upper rotary-dimension correction bounds for YaRN."""
+            low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+            high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+            if truncate:
+                low = math.floor(low)
+                high = math.ceil(high)
+            return max(low, 0), min(high, dim - 1)
+
+        def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor:
+            """Create a clamped linear ramp used to blend interpolation and extrapolation."""
+            if min_value == max_value:
+                max_value += 0.001
+            linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value)
+            ramp_func = torch.clamp(linear_func, 0, 1)
+            return ramp_func
+
+        dim = self.dim_model
+        base = self.base_freq
+
+        pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (factor_float * pos_freqs)
+
+        low, high = find_correction_range(
+            beta_fast,
+            beta_slow,
+            dim,
+            base,
+            original_max_position_embeddings,
+            bool(truncate),
+        )
+        inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
+        inv_freq = (
+            inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+            + inv_freq_extrapolation * inv_freq_extrapolation_factor
+        )
+
+        return inv_freq, float(attention_factor)
+
     def reset_parameters(self):
         # If previously initialized on or moved to a device, reuse that device.
         # Otherwise, use the default device of the current environment.
-        device = self.inv_freq.device if hasattr(self, "inv_freq") else None
-        inv_freq = 1.0 / (
-            self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
-        )
+        device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None
+
+        rope_type = "default"
+        if self.rope_scaling is not None:
+            rope_type = str(self.rope_scaling.get("rope_type", "default"))
+
+        if rope_type == "yarn":
+            inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device)
+        else:
+            inv_freq = 1.0 / (
+                self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
+            )
+            self.attention_scaling = 1.0
+
         self.register_buffer("inv_freq", inv_freq)

         self._seq_len_cached = None
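
For orientation, below is a minimal standalone sketch of the frequency blending that `_compute_yarn_parameters` performs above. The function name, default values, and the simplified ramp denominator are illustrative and not part of this diff; it only mirrors the math of the added method.

```python
# Illustrative sketch of YaRN inverse-frequency blending (not part of the PR).
import math

import torch


def yarn_inv_freq_sketch(
    dim: int = 64,
    base: float = 10000.0,
    factor: float = 2.0,
    original_max_position_embeddings: int = 4096,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
) -> tuple[torch.Tensor, float]:
    # Per-pair rotary frequencies of the unscaled ("extrapolation") and
    # position-interpolated ("interpolation") RoPE variants.
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    # Dimension index whose wavelength completes `num_rotations` turns over the
    # original context window (same formula as find_correction_dim above).
    def correction_dim(num_rotations: float) -> float:
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)

    # Linear ramp over the dim//2 frequency pairs:
    # ramp = 0 -> keep the unscaled (extrapolated) frequency (high-frequency dims),
    # ramp = 1 -> use the interpolated frequency (low-frequency dims).
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1e-3), 0, 1)
    extrapolation_factor = 1 - ramp

    inv_freq = inv_freq_interpolation * (1 - extrapolation_factor) + inv_freq_extrapolation * extrapolation_factor

    # Default attention scaling when no mscale overrides are supplied.
    attention_scaling = 0.1 * math.log(factor) + 1.0 if factor > 1 else 1.0
    return inv_freq, attention_scaling
```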
@@ -172,15 +292,21 @@ def _update_cos_sin_tables(self, x):

         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
+        if (
+            seq_len != self._seq_len_cached
+            or self._cos_cached is None
+            or self._sin_cached is None
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
+        ):
             self._seq_len_cached = seq_len
             t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(
                 x.device
             )  # here, we combine the two matrices (not zipping them).
-            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
-            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
+            self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype)
+            self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype)

         return self._cos_cached, self._sin_cached

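
Downstream, the cached tables are consumed by the rotary application path. The sketch below mirrors the `_apply_rotary` helper added in the tests further down and illustrates why pre-multiplying the caches by `attention_scaling` is sufficient: both q and k are scaled by the factor, so the attention logits pick up `attention_scaling**2`, which is how the YaRN temperature correction is commonly folded into RoPE caches.

```python
# Illustrative sketch of applying the cached, pre-scaled cos/sin tables (not part of the PR).
import torch


def apply_rotary_sketch(x: torch.Tensor, cos_cached: torch.Tensor, sin_cached: torch.Tensor) -> torch.Tensor:
    # Slice the cached tables to the current sequence length; x is (bs, n_head, seq, head_dim).
    cos = cos_cached[:, :, : x.shape[-2], :]
    sin = sin_cached[:, :, : x.shape[-2], :]
    # rotate_half: split the head dimension, swap the halves, negate the second one.
    x1, x2 = x.chunk(2, dim=-1)
    x_rot = torch.cat((-x2, x1), dim=-1)
    # Since cos/sin were already multiplied by attention_scaling, no change to the
    # attention kernel itself is needed.
    return (x * cos) + (x_rot * sin)
```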
@@ -295,6 +421,46 @@ class RotaryTransformConfig(BaseModel):
             n_head: Annotated[int, Field(strict=True, ge=0)]
             seq_length_dim: Annotated[int, Field(strict=True)]
             base_freq: Annotated[int, Field(strict=True, ge=10000)]
+            max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
+            rope_scaling: Optional[dict[str, object]] = None
+
+            @model_validator(mode="after")
+            def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig":
+                """Validate and normalize rope_scaling, including YaRN-specific constraints."""
+                if self.rope_scaling is None:
+                    return self
+
+                if not isinstance(self.rope_scaling, dict):
+                    raise ValueError("rope_scaling must be a dictionary")
+
+                rope_scaling = dict(self.rope_scaling)
+                if "type" in rope_scaling and "rope_type" not in rope_scaling:
+                    rope_scaling["rope_type"] = rope_scaling["type"]
+
+                rope_type = rope_scaling.get("rope_type", "default")
+                if rope_type not in {"default", "yarn"}:
+                    raise ValueError(
+                        f"Unsupported rope_scaling.rope_type '{rope_type}'. Supported values are 'default' and 'yarn'."
+                    )
+
+                if rope_type == "yarn":
+                    if self.max_position_embeddings is None:
+                        raise ValueError("YaRN requires max_position_embeddings to be set")
+
+                    original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings")
+                    if (
+                        original_max_position_embeddings is None
+                        or not isinstance(original_max_position_embeddings, int)
+                        or original_max_position_embeddings <= 0
+                    ):
+                        raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer")
+
+                    factor = rope_scaling.get("factor")
+                    if factor is not None and (not isinstance(factor, (int, float)) or factor < 1.0):
+                        raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0")
+
+                self.rope_scaling = rope_scaling
+                return self

     @validator("type_hint", pre=True, always=True)
     def parse_sharding_strategy_by_name(cls, name):
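
As a usage sketch, the dictionaries below show payloads the new `validate_rope_scaling` validator accepts, normalizes, or rejects. All concrete values are made up for illustration; YaRN additionally requires `max_position_embeddings` to be set on the config.

```python
# Illustrative rope_scaling payloads for the new config fields (values are made up).
accepted = {
    "rope_type": "yarn",
    "factor": 4.0,                              # optional; must be >= 1.0 when given
    "original_max_position_embeddings": 4096,   # required positive int for YaRN
    "beta_fast": 32,
    "beta_slow": 1,
}

normalized = {"type": "yarn", "original_max_position_embeddings": 4096}
# -> the validator copies the legacy "type" key into "rope_type" before checking it

rejected = [
    {"rope_type": "linear"},                    # unsupported rope_type
    {"rope_type": "yarn"},                      # missing original_max_position_embeddings
    {"rope_type": "yarn", "factor": 0.5, "original_max_position_embeddings": 4096},  # factor < 1.0
]
```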
diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py
index d3ccd46c2..25abc686b 100644
--- a/tests/fsdp2_parallelization/test_tensor_parallelism.py
+++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py
@@ -20,14 +20,15 @@
 from tests.utility import find_free_port


-def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path:
+def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path, file_tag: str = "") -> Path:
     """Patches the original configuration file to set a custom activation type."""
     with original_config_path.open("r", encoding="utf-8") as f:
         config_dict = yaml.safe_load(f)

     config_dict["model_raw"]["config"]["activation_type"] = activation_type

-    tmp_file_path = tmp_dir / original_config_path.name
+    file_suffix = f"_{file_tag}" if file_tag else ""
+    tmp_file_path = tmp_dir / f"{original_config_path.stem}{file_suffix}{original_config_path.suffix}"
     with tmp_file_path.open("w", encoding="utf-8") as f:
         yaml.safe_dump(config_dict, f)

@@ -103,12 +104,16 @@ def _test_tp_sharding_impl(
     ):
         # Seed before FSDP2 instantiation
         torch.manual_seed(42)
-        fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir)
+        fsdp2_path = patch_config_file(
+            fsdp2_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_fsdp2"
+        )
         fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path, tmp_path)

         # Seed again before TP instantiation to match
         torch.manual_seed(42)
-        tp_path = patch_config_file(tp_config_path, activation_type, tmp_config_dir)
+        tp_path = patch_config_file(
+            tp_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_tp"
+        )
         tp_model, tp_mesh = self._get_components(tp_path, tmp_path)

         # Ensure models use the correct MLP
diff --git a/tests/test_rotary_qkv_transform.py b/tests/test_rotary_qkv_transform.py
index fa82715b1..9da3bd652 100644
--- a/tests/test_rotary_qkv_transform.py
+++ b/tests/test_rotary_qkv_transform.py
@@ -1,3 +1,4 @@
+import pytest
 import torch

 from modalities.models.gpt2.gpt2_model import RotaryTransform
@@ -41,3 +42,82 @@ def test_rotary_transform():
     comp_rot_h = torch.cat([-comp_h_2, comp_h_1], dim=-1)
     comp_rot_expected = comp * cos_m_theta + comp_rot_h * sin_m_theta
     assert torch.equal(comp_rot_expected, comp_rot)
+
+
+def _apply_rotary(x: torch.Tensor, cos_cached: torch.Tensor, sin_cached: torch.Tensor) -> torch.Tensor:
+    cos_local = cos_cached[:, :, : x.shape[-2], :]
+    sin_local = sin_cached[:, :, : x.shape[-2], :]
+    x1, x2 = x.chunk(2, dim=-1)
+    x_rot = torch.cat((-x2, x1), dim=-1)
+    return (x * cos_local) + (x_rot * sin_local)
+
+
+def _assert_yarn_outputs_match_reference(
+    rotary_transform: RotaryTransform,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    q_rot: torch.Tensor,
+    k_rot: torch.Tensor,
+    seq_length: int,
+) -> None:
+    t = torch.arange(seq_length, device=q.device, dtype=torch.float32)
+    freqs = torch.einsum("i,j->ij", t, rotary_transform.inv_freq.to(q.dtype))
+    emb = torch.cat((freqs, freqs), dim=-1)
+    cos = (emb.cos() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype)
+    sin = (emb.sin() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype)
+
+    q_expected = _apply_rotary(q, cos, sin)
+    k_expected = _apply_rotary(k, cos, sin)
+
+    assert torch.allclose(q_rot, q_expected, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(k_rot, k_expected, atol=1e-5, rtol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "rope_scaling",
+    [
+        {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "beta_fast": 32,
+            "beta_slow": 1,
+            "original_max_position_embeddings": 4,
+        },
+        {
+            "rope_type": "yarn",
+            "beta_fast": 32,
+            "beta_slow": 1,
+            "original_max_position_embeddings": 4,
+        },
+    ],
+)
+def test_rotary_transform_yarn_matches_reference(rope_scaling: dict):
+    bs = 1
+    n_heads = 2
+    embedding_dim = 8
+    seq_length = 8
+    head_dim = embedding_dim // n_heads
+
+    q = torch.randn(bs, n_heads, seq_length, head_dim)
+    k = torch.randn(bs, n_heads, seq_length, head_dim)
+    v = torch.randn(bs, n_heads, seq_length, head_dim)
+
+    rotary_transform = RotaryTransform(
+        n_embd=embedding_dim,
+        n_head=n_heads,
+        base_freq=10000,
+        max_position_embeddings=seq_length,
+        rope_scaling=rope_scaling,
+    )
+
+    q_rot, k_rot, v_rot = rotary_transform(q=q, k=k, v=v)
+    assert torch.equal(v, v_rot)
+
+    _assert_yarn_outputs_match_reference(
+        rotary_transform=rotary_transform,
+        q=q,
+        k=k,
+        q_rot=q_rot,
+        k_rot=k_rot,
+        seq_length=seq_length,
+    )
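
Finally, a quick sanity sketch (not part of the test suite, and assuming the PR is applied) of the two code paths the change introduces: the default path keeps `attention_scaling` at 1.0, while YaRN without mscale overrides uses `0.1 * ln(factor) + 1`. The concrete sizes are illustrative.

```python
# Illustrative sanity check of RotaryTransform with and without YaRN scaling.
import math

import torch

from modalities.models.gpt2.gpt2_model import RotaryTransform

default_rope = RotaryTransform(n_embd=8, n_head=2)
assert default_rope.attention_scaling == 1.0

yarn_rope = RotaryTransform(
    n_embd=8,
    n_head=2,
    max_position_embeddings=8,
    rope_scaling={"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 4},
)
# No attention_factor/mscale given, so the default YaRN scaling formula applies.
assert math.isclose(yarn_rope.attention_scaling, 0.1 * math.log(2.0) + 1.0)

# inv_freq keeps one entry per rotary pair (head_dim // 2) in both cases.
assert default_rope.inv_freq.shape == yarn_rope.inv_freq.shape == torch.Size([2])
```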