
Commit 8f46692

fix: if layers == 0, layers were not initialized (#306)

* fix: if layers == 0, layers were not initialized
* typing
* redo initialization
* remove check for linear
1 parent 86f062e commit 8f46692
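
To make the bug concrete, here is a hedged sketch of the pre-fix control flow. The free-function form, the argument names, and the example dimensions are illustrative assumptions; the actual method is shown in the diff below.

import torch.nn as nn


def construct_head_before_fix(embed_dim: int, out_dim: int, n_layers: int) -> nn.Sequential:
    # Hypothetical free-function reduction of the old method; the real code is
    # a method on the classifier in model2vec/train/classifier.py.
    if n_layers == 0:
        # Early return: the weight-initialization loop at the bottom of the old
        # method never ran on this path, so the single Linear head kept
        # PyTorch's default initialization and its bias was never zeroed.
        return nn.Sequential(nn.Linear(embed_dim, out_dim))
    raise NotImplementedError("hidden-layer path elided in this sketch")


head = construct_head_before_fix(256, 10, n_layers=0)
print((head[0].bias == 0).all())  # tensor(False): the zero-layer path skipped init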

1 file changed

Lines changed: 21 additions & 12 deletions


model2vec/train/classifier.py

@@ -66,20 +66,29 @@ def classes(self) -> np.ndarray:
 
     def construct_head(self) -> nn.Sequential:
         """Constructs a simple classifier head."""
+        modules: list[nn.Module] = []
         if self.n_layers == 0:
-            return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
-        modules = [
-            nn.Linear(self.embed_dim, self.hidden_dim),
-            nn.ReLU(),
-        ]
-        for _ in range(self.n_layers - 1):
-            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
-        modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])
-
-        for module in modules:
-            if isinstance(module, nn.Linear):
-                nn.init.kaiming_uniform_(module.weight)
+            modules.append(nn.Linear(self.embed_dim, self.out_dim))
+        else:
+            # If we have a hidden layer, we should first project to hidden_dim.
+            modules = [
+                nn.Linear(self.embed_dim, self.hidden_dim),
+                nn.ReLU(),
+            ]
+            for _ in range(self.n_layers - 1):
+                modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
+            # We always have a layer mapping from hidden to out.
+            modules.append(nn.Linear(self.hidden_dim, self.out_dim))
+
+        linear_modules = [module for module in modules if isinstance(module, nn.Linear)]
+        if linear_modules:
+            *initial, last = linear_modules
+            for module in initial:
+                nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
                 nn.init.zeros_(module.bias)
+            # The final layer does not use Kaiming initialization.
+            nn.init.xavier_uniform_(last.weight)
+            nn.init.zeros_(last.bias)
 
         return nn.Sequential(*modules)
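
For experimenting with the fixed behavior outside the class, here is a standalone sketch that mirrors the new method body; the free-function signature and the example dimensions are assumptions for illustration, not the repository's API.

import torch.nn as nn


def construct_head(embed_dim: int, hidden_dim: int, out_dim: int, n_layers: int) -> nn.Sequential:
    """Mirror of the fixed method, in hypothetical free-function form."""
    modules: list[nn.Module] = []
    if n_layers == 0:
        modules.append(nn.Linear(embed_dim, out_dim))
    else:
        # Project to hidden_dim first, then stack n_layers - 1 hidden blocks.
        modules = [nn.Linear(embed_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            modules.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        modules.append(nn.Linear(hidden_dim, out_dim))

    # Every path now flows through one shared initialization pass.
    linear_modules = [m for m in modules if isinstance(m, nn.Linear)]
    if linear_modules:
        *initial, last = linear_modules
        for module in initial:
            nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
            nn.init.zeros_(module.bias)
        nn.init.xavier_uniform_(last.weight)
        nn.init.zeros_(last.bias)
    return nn.Sequential(*modules)


head0 = construct_head(256, 512, 10, n_layers=0)  # Linear
head2 = construct_head(256, 512, 10, n_layers=2)  # Linear-ReLU-Linear-ReLU-Linear
assert len(head0) == 1 and len(head2) == 5
assert bool((head0[0].bias == 0).all())  # the zero-layer head is now initialized too

Routing the zero-layer case through the shared initialization pass is what closes the bug; Kaiming-uniform matches the ReLU-activated hidden layers, while the Xavier-initialized final layer keeps the logits' scale independent of the ReLU gain.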

0 commit comments
