@@ -66,20 +66,29 @@ def classes(self) -> np.ndarray:
 
     def construct_head(self) -> nn.Sequential:
         """Constructs a simple classifier head."""
+        modules: list[nn.Module] = []
         if self.n_layers == 0:
-            return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
-        modules = [
-            nn.Linear(self.embed_dim, self.hidden_dim),
-            nn.ReLU(),
-        ]
-        for _ in range(self.n_layers - 1):
-            modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
-        modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])
-
-        for module in modules:
-            if isinstance(module, nn.Linear):
-                nn.init.kaiming_uniform_(module.weight)
+            modules.append(nn.Linear(self.embed_dim, self.out_dim))
+        else:
+            # If we have a hidden layer, we should first project to hidden_dim.
+            modules = [
+                nn.Linear(self.embed_dim, self.hidden_dim),
+                nn.ReLU(),
+            ]
+            for _ in range(self.n_layers - 1):
+                modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
+            # We always have a layer mapping from hidden to out.
+            modules.append(nn.Linear(self.hidden_dim, self.out_dim))
+
+        linear_modules = [module for module in modules if isinstance(module, nn.Linear)]
+        if linear_modules:
+            *initial, last = linear_modules
+            for module in initial:
+                nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
                 nn.init.zeros_(module.bias)
+            # The final layer feeds no ReLU, so use Xavier instead of Kaiming.
+            nn.init.xavier_uniform_(last.weight)
+            nn.init.zeros_(last.bias)
 
         return nn.Sequential(*modules)
 
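For reference, a minimal standalone sketch of what the patched method builds and initializes. The class attributes (`embed_dim`, `hidden_dim`, `out_dim`, `n_layers`) are passed here as plain arguments, and the concrete sizes in the usage lines are illustrative, not taken from this repo:

```python
import torch
from torch import nn


def construct_head(embed_dim: int, hidden_dim: int, out_dim: int, n_layers: int) -> nn.Sequential:
    """Builds the classifier head the same way the patched method does."""
    modules: list[nn.Module] = []
    if n_layers == 0:
        modules.append(nn.Linear(embed_dim, out_dim))
    else:
        # Project to hidden_dim, stack n_layers - 1 hidden blocks, then map to out_dim.
        modules = [nn.Linear(embed_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            modules.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        modules.append(nn.Linear(hidden_dim, out_dim))

    # There is always at least one Linear, so the unpack below is safe.
    *initial, last = [m for m in modules if isinstance(m, nn.Linear)]
    for m in initial:
        # Kaiming suits the ReLU that follows each hidden linear layer.
        nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
        nn.init.zeros_(m.bias)
    # The output layer is not followed by a ReLU, so Xavier is used instead.
    nn.init.xavier_uniform_(last.weight)
    nn.init.zeros_(last.bias)
    return nn.Sequential(*modules)


head = construct_head(embed_dim=384, hidden_dim=256, out_dim=10, n_layers=2)
logits = head(torch.randn(4, 384))  # -> shape (4, 10)
```

Splitting the initialization this way keeps forward variance roughly stable through the ReLU-fed hidden stack while avoiding the ReLU gain on the final logit layer.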