
YaRN #445

182 changes: 174 additions & 8 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -120,7 +120,15 @@ class RotaryTransform(QueryKeyValueTransform):
XFormers implementation and removed in this implementation.#
"""

def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq: int = 10000):
def __init__(
self,
n_embd: int,
n_head: int,
seq_length_dim: int = -2,
base_freq: int = 10000,
max_position_embeddings: int | None = None,
rope_scaling: dict[str, object] | None = None,
):
"""
Initializes the RotaryTransform object.

@@ -136,16 +144,128 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq
self.dim_model = n_embd // n_head
self.seq_length_dim = seq_length_dim
self.base_freq = base_freq
self.max_position_embeddings = max_position_embeddings

self.rope_scaling = rope_scaling
self.attention_scaling = 1.0

self.reset_parameters()

def _compute_yarn_parameters(self, device: torch.device | None) -> tuple[torch.Tensor, float]:
"""Compute YaRN inverse frequencies and the attention scaling factor."""
if self.rope_scaling is None:
raise ValueError("YaRN requires a rope_scaling config.")
if self.max_position_embeddings is None:
raise ValueError("YaRN requires max_position_embeddings to be set.")

original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings")
if (
original_max_position_embeddings is None
or not isinstance(original_max_position_embeddings, int)
or original_max_position_embeddings <= 0
):
raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer")

factor = self.rope_scaling.get("factor")
if factor is None:
factor = self.max_position_embeddings / original_max_position_embeddings
if not isinstance(factor, (int, float)) or factor < 1.0:
raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0")
factor_float = float(factor)

attention_factor = self.rope_scaling.get("attention_factor")
mscale = self.rope_scaling.get("mscale")
mscale_all_dim = self.rope_scaling.get("mscale_all_dim")
beta_fast_raw = self.rope_scaling.get("beta_fast")
beta_slow_raw = self.rope_scaling.get("beta_slow")
beta_fast = float(beta_fast_raw) if isinstance(beta_fast_raw, (int, float)) else 32.0
beta_slow = float(beta_slow_raw) if isinstance(beta_slow_raw, (int, float)) else 1.0
truncate = self.rope_scaling.get("truncate", True)

def get_mscale(scale: float, mscale: float = 1.0) -> float:
"""Return the YaRN mscale coefficient for a given scaling factor."""
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0

if attention_factor is None:
if isinstance(mscale, (int, float)) and isinstance(mscale_all_dim, (int, float)):
attention_factor = float(
get_mscale(factor_float, float(mscale)) / get_mscale(factor_float, float(mscale_all_dim))
)
else:
attention_factor = get_mscale(factor_float)
elif not isinstance(attention_factor, (int, float)) or attention_factor <= 0:
raise ValueError("YaRN requires rope_scaling.attention_factor to be a float > 0")

def find_correction_dim(num_rotations: float, dim: int, base: int, max_position_embeddings: int) -> float:
"""Map a target number of rotations to a rotary dimension index."""
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(
low_rot: float,
high_rot: float,
dim: int,
base: int,
max_position_embeddings: int,
truncate: bool,
) -> tuple[float, float]:
"""Compute the lower and upper rotary-dimension correction bounds for YaRN."""
low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)

def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Tensor:
"""Create a clamped linear ramp used to blend interpolation and extrapolation."""
if min_value == max_value:
max_value += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

dim = self.dim_model
base = self.base_freq

pos_freqs = base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor_float * pos_freqs)

low, high = find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
bool(truncate),
)
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)

return inv_freq, float(attention_factor)

def reset_parameters(self):
# If previously initialized on or moved to a device, reuse that device.
# Otherwise, use the default device of the current environment.
device = self.inv_freq.device if hasattr(self, "inv_freq") else None
inv_freq = 1.0 / (
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
)
device = self.inv_freq.device if hasattr(self, "inv_freq") and isinstance(self.inv_freq, torch.Tensor) else None

rope_type = "default"
if self.rope_scaling is not None:
rope_type = str(self.rope_scaling.get("rope_type", "default"))

if rope_type == "yarn":
inv_freq, self.attention_scaling = self._compute_yarn_parameters(device=device)
else:
inv_freq = 1.0 / (
self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model)
)
self.attention_scaling = 1.0

self.register_buffer("inv_freq", inv_freq)

self._seq_len_cached = None
@@ -172,15 +292,21 @@ def _update_cos_sin_tables(self, x):

# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if seq_len != self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
if (
seq_len != self._seq_len_cached
or self._cos_cached is None
or self._sin_cached is None
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(
x.device
) # here, we combine the two matrices (not zipping them).
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
self._cos_cached = (emb.cos() * self.attention_scaling)[None, None, :, :].to(x.dtype)
self._sin_cached = (emb.sin() * self.attention_scaling)[None, None, :, :].to(x.dtype)

return self._cos_cached, self._sin_cached
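
As a sanity check on the blend implemented in _compute_yarn_parameters, the standalone sketch below (a hypothetical helper, using the same beta_fast=32 / beta_slow=1 defaults) reproduces the YaRN inverse-frequency mixing and the 0.1 * ln(factor) + 1.0 attention scaling for a toy head dimension; with factor=2.0 the scaling evaluates to roughly 1.069.

import math
import torch

def yarn_inv_freq_sketch(dim: int = 8, base: int = 10000, factor: float = 2.0,
                         original_max_pos: int = 4, beta_fast: float = 32.0, beta_slow: float = 1.0):
    # Plain RoPE frequencies and their position-interpolated counterpart.
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    # Dimensions completing at least beta_fast rotations inside the original context
    # keep extrapolation; dimensions below beta_slow rotations are fully interpolated;
    # the range in between is blended with a clamped linear ramp.
    def correction_dim(num_rotations: float) -> float:
        return (dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    ramp = torch.clamp((torch.arange(dim // 2, dtype=torch.float) - low) / max(high - low, 1e-3), 0, 1)

    inv_freq = inv_freq_interpolation * ramp + inv_freq_extrapolation * (1 - ramp)
    attention_scaling = 0.1 * math.log(factor) + 1.0 if factor > 1 else 1.0  # ~1.069 for factor=2.0
    return inv_freq, attention_scaling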

@@ -295,6 +421,46 @@ class RotaryTransformConfig(BaseModel):
n_head: Annotated[int, Field(strict=True, ge=0)]
seq_length_dim: Annotated[int, Field(strict=True)]
base_freq: Annotated[int, Field(strict=True, ge=10000)]
max_position_embeddings: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
rope_scaling: Optional[dict[str, object]] = None

@model_validator(mode="after")
def validate_rope_scaling(self) -> "AttentionConfig.QueryKeyValueTransformConfig.RotaryTransformConfig":
"""Validate and normalize rope_scaling, including YaRN-specific constraints."""
if self.rope_scaling is None:
return self

if not isinstance(self.rope_scaling, dict):
raise ValueError("rope_scaling must be a dictionary")

rope_scaling = dict(self.rope_scaling)
if "type" in rope_scaling and "rope_type" not in rope_scaling:
rope_scaling["rope_type"] = rope_scaling["type"]

rope_type = rope_scaling.get("rope_type", "default")
if rope_type not in {"default", "yarn"}:
raise ValueError(
f"Unsupported rope_scaling.rope_type '{rope_type}'. Supported values are 'default' and 'yarn'."
)

if rope_type == "yarn":
if self.max_position_embeddings is None:
raise ValueError("YaRN requires max_position_embeddings to be set")

original_max_position_embeddings = rope_scaling.get("original_max_position_embeddings")
if (
original_max_position_embeddings is None
or not isinstance(original_max_position_embeddings, int)
or original_max_position_embeddings <= 0
):
raise ValueError("YaRN requires original_max_position_embeddings to be a positive integer")

factor = rope_scaling.get("factor")
if factor is not None and (not isinstance(factor, (int, float)) or factor < 1.0):
raise ValueError("YaRN requires rope_scaling.factor to be a float >= 1.0")

self.rope_scaling = rope_scaling
return self

@validator("type_hint", pre=True, always=True)
def parse_sharding_strategy_by_name(cls, name):
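Mirroring the new unit test below, a YaRN-enabled transform can be built directly from a rope_scaling dict; per the validator above, a legacy "type" key is normalized to "rope_type", and any value other than "default" or "yarn" is rejected. A minimal usage sketch:

import torch
from modalities.models.gpt2.gpt2_model import RotaryTransform

rotary = RotaryTransform(
    n_embd=8,
    n_head=2,
    base_freq=10000,
    max_position_embeddings=8,
    rope_scaling={"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 4},
)
q = torch.randn(1, 2, 8, 4)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 8, 4)
v = torch.randn(1, 2, 8, 4)
q_rot, k_rot, v_rot = rotary(q=q, k=k, v=v)  # values are passed through unchanged
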
13 changes: 9 additions & 4 deletions tests/fsdp2_parallelization/test_tensor_parallelism.py
@@ -20,14 +20,15 @@
from tests.utility import find_free_port


def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path) -> Path:
def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: Path, file_tag: str = "") -> Path:
"""Patches the original configuration file to set a custom activation type."""
with original_config_path.open("r", encoding="utf-8") as f:
config_dict = yaml.safe_load(f)

config_dict["model_raw"]["config"]["activation_type"] = activation_type

tmp_file_path = tmp_dir / original_config_path.name
file_suffix = f"_{file_tag}" if file_tag else ""
tmp_file_path = tmp_dir / f"{original_config_path.stem}{file_suffix}{original_config_path.suffix}"
with tmp_file_path.open("w", encoding="utf-8") as f:
yaml.safe_dump(config_dict, f)

@@ -103,12 +104,16 @@ def _test_tp_sharding_impl(
):
# Seed before FSDP2 instantiation
torch.manual_seed(42)
fsdp2_path = patch_config_file(fsdp2_config_path, activation_type, tmp_config_dir)
fsdp2_path = patch_config_file(
fsdp2_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_fsdp2"
)
fsdp2_model, fsdp2_mesh = self._get_components(fsdp2_path, tmp_path)

# Seed again before TP instantiation to match
torch.manual_seed(42)
tp_path = patch_config_file(tp_config_path, activation_type, tmp_config_dir)
tp_path = patch_config_file(
tp_config_path, activation_type, tmp_config_dir, file_tag=f"{activation_type}_rank{process_id}_tp"
)
tp_model, tp_mesh = self._get_components(tp_path, tmp_path)

# Ensure models use the correct MLP
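The new file_tag only changes the name of the patched copy, so concurrent test variants (activation type, rank, FSDP2 vs. TP) no longer race on the same temporary config file. An illustrative call with placeholder paths and activation name:

from pathlib import Path
import tempfile

import yaml

# Assuming the tests package is importable outside the test harness.
from tests.fsdp2_parallelization.test_tensor_parallelism import patch_config_file

tmp_dir = Path(tempfile.mkdtemp())
original = tmp_dir / "tp_config.yaml"
original.write_text(yaml.safe_dump({"model_raw": {"config": {"activation_type": "gelu"}}}), encoding="utf-8")

# Written to tmp_dir / "tp_config_swiglu_rank0_tp.yaml" instead of overwriting other variants.
patched = patch_config_file(original, "swiglu", tmp_dir, file_tag="swiglu_rank0_tp")
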
80 changes: 80 additions & 0 deletions tests/test_rotary_qkv_transform.py
@@ -1,3 +1,4 @@
import pytest
import torch

from modalities.models.gpt2.gpt2_model import RotaryTransform
@@ -41,3 +42,82 @@ def test_rotary_transform():
comp_rot_h = torch.cat([-comp_h_2, comp_h_1], dim=-1)
comp_rot_expected = comp * cos_m_theta + comp_rot_h * sin_m_theta
assert torch.equal(comp_rot_expected, comp_rot)


def _apply_rotary(x: torch.Tensor, cos_cached: torch.Tensor, sin_cached: torch.Tensor) -> torch.Tensor:
cos_local = cos_cached[:, :, : x.shape[-2], :]
sin_local = sin_cached[:, :, : x.shape[-2], :]
x1, x2 = x.chunk(2, dim=-1)
x_rot = torch.cat((-x2, x1), dim=-1)
return (x * cos_local) + (x_rot * sin_local)


def _assert_yarn_outputs_match_reference(
rotary_transform: RotaryTransform,
q: torch.Tensor,
k: torch.Tensor,
q_rot: torch.Tensor,
k_rot: torch.Tensor,
seq_length: int,
) -> None:
t = torch.arange(seq_length, device=q.device, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, rotary_transform.inv_freq.to(q.dtype))
emb = torch.cat((freqs, freqs), dim=-1)
cos = (emb.cos() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype)
sin = (emb.sin() * rotary_transform.attention_scaling)[None, None, :, :].to(q.dtype)

q_expected = _apply_rotary(q, cos, sin)
k_expected = _apply_rotary(k, cos, sin)

assert torch.allclose(q_rot, q_expected, atol=1e-5, rtol=1e-5)
assert torch.allclose(k_rot, k_expected, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize(
"rope_scaling",
[
{
"rope_type": "yarn",
"factor": 2.0,
"beta_fast": 32,
"beta_slow": 1,
"original_max_position_embeddings": 4,
},
{
"rope_type": "yarn",
"beta_fast": 32,
"beta_slow": 1,
"original_max_position_embeddings": 4,
},
],
)
def test_rotary_transform_yarn_matches_reference(rope_scaling: dict):
bs = 1
n_heads = 2
embedding_dim = 8
seq_length = 8
head_dim = embedding_dim // n_heads

q = torch.randn(bs, n_heads, seq_length, head_dim)
k = torch.randn(bs, n_heads, seq_length, head_dim)
v = torch.randn(bs, n_heads, seq_length, head_dim)

rotary_transform = RotaryTransform(
n_embd=embedding_dim,
n_head=n_heads,
base_freq=10000,
max_position_embeddings=seq_length,
rope_scaling=rope_scaling,
)

q_rot, k_rot, v_rot = rotary_transform(q=q, k=k, v=v)
assert torch.equal(v, v_rot)

_assert_yarn_outputs_match_reference(
rotary_transform=rotary_transform,
q=q,
k=k,
q_rot=q_rot,
k_rot=k_rot,
seq_length=seq_length,
)
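
A complementary check one could add (not part of this PR, just a sketch): without rope_scaling the transform falls back to vanilla RoPE, so attention_scaling stays 1.0 and inv_freq keeps the usual base**(-2i/d) schedule.

import torch

from modalities.models.gpt2.gpt2_model import RotaryTransform

def test_rotary_transform_default_matches_plain_rope():
    plain = RotaryTransform(n_embd=8, n_head=2, base_freq=10000)
    expected_inv_freq = 1.0 / (10000 ** (torch.arange(0, 4, 2).float() / 4))  # head_dim = 8 // 2 = 4
    assert plain.attention_scaling == 1.0
    assert torch.allclose(plain.inv_freq, expected_inv_freq)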