diff --git a/gigl/src/common/models/pyg/nn/models/feature_embedding.py b/gigl/src/common/models/pyg/nn/models/feature_embedding.py
index 26bc3d645..f98d04021 100644
--- a/gigl/src/common/models/pyg/nn/models/feature_embedding.py
+++ b/gigl/src/common/models/pyg/nn/models/feature_embedding.py
@@ -68,13 +68,12 @@ def __init__(
         self.__padding_idx = padding_idx
         self.__feature_padding_value_map = feature_padding_value_map
 
-        # whether to add 1 to the whole tensor, so all elements in tensor is >= 0 for nn.Embedding input
-        # Since tft.compute_and_apply_vocabulary will be 0 based and use -1 as OOV padding
-        self.__plus_one = False
-        self.__oov_idx: Optional[int] = None
         assert oov_idx is None or oov_idx >= -1, "oov_idx has to be >= -1"
-        if oov_idx and oov_idx == -1:
-            self.__plus_one = True
+        # Per-feature flag: shift indices by +1 when int_domain.min == -1, which is the
+        # single-OOV-bucket case where tft.compute_and_apply_vocabulary uses -1 for OOV.
+        # When num_oov_buckets > 1, TFT assigns OOV indices starting at vocab_size (min == 0),
+        # so no shift is needed — applying +1 would push the max OOV index out of bounds.
+        self.__feature_plus_one: dict[str, bool] = {}
 
         for feature_name, emb_dim in features_to_embed.items():
             feat_dim = get_feature_len_from_fixed_len_feature(
@@ -99,6 +98,7 @@ def __init__(
                     "If int_domain.min_value is -1, oov_idx must also be -1"
                 )
             vocab_size = feat_schema.int_domain.max - feat_schema.int_domain.min + 1
+            self.__feature_plus_one[feature_name] = feat_schema.int_domain.min == -1
 
             feature_padding_idx: Optional[int]
             if (
@@ -113,7 +113,7 @@ def __init__(
                 ].index(feature_padding_value)
             else:
                 feature_padding_idx = self.__padding_idx
-            if self.__plus_one and feature_padding_idx is not None:
+            if self.__feature_plus_one[feature_name] and feature_padding_idx is not None:
                 feature_padding_idx = feature_padding_idx + 1
             self.__feature_embedding_layers[feature_name] = nn.Embedding(
                 num_embeddings=vocab_size,
@@ -133,7 +133,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_to_emb = filter_features(
             feature_schema=self.__feature_schema, feature_names=[feature], x=x
         ).long()  # embedding layer takes LongTensor
-        if self.__plus_one:
+        if self.__feature_plus_one[feature]:
             x_to_emb = x_to_emb + 1
         emb_layer = self.__feature_embedding_layers[feature]
         emb = emb_layer(x_to_emb)
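
For reviewers, here is a minimal sketch of the indexing behavior the per-feature flag is meant to handle. It is illustrative only and uses plain PyTorch with made-up feature names and values; it does not call tft, filter_features, or the GiGL schema helpers.

```python
import torch
import torch.nn as nn

# Hypothetical feature "genre": TFT vocabulary of 5 in-vocab ids with a single
# OOV bucket encoded as -1, so int_domain.min == -1 and int_domain.max == 4.
int_min, int_max = -1, 4
num_embeddings = int_max - int_min + 1  # 6, same formula as vocab_size in the diff
plus_one = int_min == -1                # per-feature flag, as in __feature_plus_one

emb = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=8)

# Raw TFT-style output: in-vocab ids 0..4, OOV encoded as -1.
ids = torch.tensor([3, 0, -1, 4])
if plus_one:
    ids = ids + 1  # -1 -> 0 (OOV row), 4 -> 5; all ids now in [0, num_embeddings)
out = emb(ids)
print(out.shape)  # torch.Size([4, 8])

# Hypothetical feature "country": multiple OOV buckets, so TFT already emits
# non-negative ids and int_domain.min == 0. No shift is applied there, because
# adding 1 would map int_domain.max to int_domain.max + 1, out of range for
# an nn.Embedding sized max - min + 1.
```

The sketch assumes the same vocab_size arithmetic as the diff; the only behavioral change being illustrated is that the shift decision is now made per feature from its int_domain rather than once from the constructor's oov_idx.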