diff --git a/csrc/config/config_factory.cpp b/csrc/config/config_factory.cpp
index 09e21e93..2beaa29d 100644
--- a/csrc/config/config_factory.cpp
+++ b/csrc/config/config_factory.cpp
@@ -17,7 +17,7 @@ std::shared_ptr ConfigFactory::createConfig(const
     if (it != config_map.end()) {
         it->second(model_config);
     } else {
-        std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b"};
+        std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b", "glm4"};
         const std::string &model_type = model_config->get("model_type");
         if (std::find(classic_models.begin(), classic_models.end(), model_type) == classic_models.end()) {
             throw std::invalid_argument("infinilm::config::ConfigFactory::createConfig: Unsupported model config type: " + model_type);
diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp
index 8a94c441..f600d113 100644
--- a/csrc/engine/rank_worker.cpp
+++ b/csrc/engine/rank_worker.cpp
@@ -281,7 +281,13 @@ void RankWorker::thread_loop() {
     const std::string &model_type = model_config_->get("model_type");
     const auto &model_map = models::get_causal_lm_model_map();
     auto it = model_map.find(model_type);
-    if (it != model_map.end()) {
+    if (model_type == "glm4") {
+        model_ = InfinilmModelFactory::createModel(
+            model_config_,
+            rank_info_,
+            pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr,
+            attention_backend_);
+    } else if (it != model_map.end()) {
         model_ = InfinilmModelFactory::createModel(
             model_config_,
             rank_info_.device,
diff --git a/csrc/models/glm4/glm4_for_causal_lm.cpp b/csrc/models/glm4/glm4_for_causal_lm.cpp
new file mode 100644
index 00000000..1bcbaf7d
--- /dev/null
+++ b/csrc/models/glm4/glm4_for_causal_lm.cpp
@@ -0,0 +1,42 @@
+#include "glm4_for_causal_lm.hpp"
+#include "../llama/llama_for_causal_lm.hpp"
+#include "../models_registry.hpp"
+
+namespace infinilm::models::glm4 {
+
+std::shared_ptr create_glm4_model_config(
+    std::shared_ptr model_config) {
+    const std::string &model_type = model_config->get("model_type");
+    if ("glm4" != model_type) {
+        throw std::runtime_error(
+            "infinilm::models::glm4::create_glm4_model_config: model_type is not glm4");
+    }
+
+    nlohmann::json &config_json = model_config->get_config_json();
+
+    if (!config_json.contains("head_dim")) {
+        config_json["head_dim"] = model_config->get("hidden_size")
+                                / model_config->get("num_attention_heads");
+    }
+
+    if (!config_json.contains("attention_bias")) {
+        config_json["attention_bias"] = true;
+    }
+
+    return model_config;
+}
+
+} // namespace infinilm::models::glm4
+
+namespace {
+
+#ifndef USE_CLASSIC_LLAMA
+
+INFINILM_REGISTER_CAUSAL_LM_MODEL(
+    glm4,
+    infinilm::models::llama::LlamaForCausalLM,
+    infinilm::models::glm4::create_glm4_model_config);
+
+#endif
+
+} // namespace
diff --git a/csrc/models/glm4/glm4_for_causal_lm.hpp b/csrc/models/glm4/glm4_for_causal_lm.hpp
new file mode 100644
index 00000000..8a412be0
--- /dev/null
+++ b/csrc/models/glm4/glm4_for_causal_lm.hpp
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "../../layers/common_modules.hpp"
+#include <memory>
+
+namespace infinilm::models::glm4 {
+
+std::shared_ptr create_glm4_model_config(
+    std::shared_ptr model_config);
+
+} // namespace infinilm::models::glm4
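Note on the two new glm4 files above: create_glm4_model_config only normalizes the config before the registry hands it to the reused LlamaForCausalLM. A minimal Python sketch of that defaulting, assuming a plain dict in place of the C++ ModelConfig (field names follow the HF config.json):

```python
def apply_glm4_defaults(config: dict) -> dict:
    # Mirrors create_glm4_model_config: derive head_dim when the checkpoint
    # omits it, and default attention_bias to True (GLM4 uses biased QKV).
    assert config["model_type"] == "glm4"
    config.setdefault("head_dim",
                      config["hidden_size"] // config["num_attention_heads"])
    config.setdefault("attention_bias", True)
    return config

print(apply_glm4_defaults(
    {"model_type": "glm4", "hidden_size": 4096, "num_attention_heads": 32}))
# head_dim becomes 128; attention_bias becomes True
```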
"llama_utils.hpp" #include "infinicore/nn/linear.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" @@ -42,6 +43,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, num_attention_heads_(config.num_attention_heads), num_key_value_heads_(config.num_key_value_heads), head_dim_(config.head_dim), + rotary_dim_(config.head_dim), kv_dim_(config.kv_dim()), use_bias_(config.attention_bias), use_output_bias_(config.attention_output_bias), @@ -90,6 +92,7 @@ LlamaAttention::LlamaAttention(std::shared_ptr mo num_attention_heads_(model_config->get("num_attention_heads")), num_key_value_heads_(model_config->get("num_key_value_heads")), head_dim_(model_config->get_head_dim()), + rotary_dim_(get_rotary_dim(model_config->get_head_dim(), model_config->get_or("partial_rotary_factor", 1.0))), kv_dim_(model_config->get_kv_dim()), use_bias_(model_config->get_or("attention_bias", true)), use_output_bias_(model_config->get_or("attention_output_bias", false)), @@ -204,8 +207,21 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta // 4. Apply RoPE to Q and K auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3}); - rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim] - rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] + q_rope->copy_from(q_reshaped); + + auto k_rope = infinicore::Tensor::empty({batch_size, seq_len, num_key_value_heads_, head_dim_}, k_reshaped->dtype(), k_reshaped->device()); + k_rope->copy_from(k_reshaped); + + if (rotary_dim_ == head_dim_) { + rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim] + rotary_emb_->forward(k_rope, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] + } else { + rotary_emb_->forward( + q_rope->narrow({{3, 0, rotary_dim_}}), + q_reshaped->narrow({{3, 0, rotary_dim_}}), + pos_ids_for_rope); + rotary_emb_->forward(k_rope->narrow({{3, 0, rotary_dim_}}), pos_ids_for_rope, true); + } infinilm::KVQuantUtils::quantize( k_reshaped, v_reshaped, @@ -217,7 +233,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta // Convert to [batch, n_head, seq_len, head_dim] for cache // Ensure contiguous after permute for F16 compatibility with cache operations q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] - auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] + auto k_permuted = k_rope->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] infinicore::Tensor k_total; // [bs, n_kv_head, max_seq_len, head_dim] infinicore::Tensor v_total; // [bs, n_kv_head, max_seq_len, head_dim] @@ -330,8 +346,19 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd } // 4. 
diff --git a/csrc/models/llama_legacy/llama_attention.hpp b/csrc/models/llama_legacy/llama_attention.hpp
index c1fb4871..046c6c4e 100644
--- a/csrc/models/llama_legacy/llama_attention.hpp
+++ b/csrc/models/llama_legacy/llama_attention.hpp
@@ -135,6 +135,7 @@ class LlamaAttention : public infinicore::nn::Module {
     size_t num_attention_heads_;
     size_t num_key_value_heads_;
     size_t head_dim_;
+    size_t rotary_dim_;
     size_t kv_dim_;
     bool use_bias_;        // Bias for Q/K/V projections
     bool use_output_bias_; // Bias for output projection (o_proj)
diff --git a/csrc/models/llama_legacy/llama_decoder_layer.cpp b/csrc/models/llama_legacy/llama_decoder_layer.cpp
index 6ea5215e..6c61c54d 100644
--- a/csrc/models/llama_legacy/llama_decoder_layer.cpp
+++ b/csrc/models/llama_legacy/llama_decoder_layer.cpp
@@ -40,11 +40,18 @@ LlamaDecoderLayer::LlamaDecoderLayer(std::shared_ptr
       dtype{model_config_->get_dtype()};
+    use_glm4_post_norms_ = model_config_->get("model_type") == "glm4";
 
     // Initialize layer normalization layers
     INFINICORE_NN_MODULE_INIT(input_layernorm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"),
                               dtype, device);
     INFINICORE_NN_MODULE_INIT(post_attention_layernorm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"),
                               dtype, device);
+    if (use_glm4_post_norms_) {
+        INFINICORE_NN_MODULE_INIT(post_self_attn_layernorm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"),
+                                  dtype, device);
+        INFINICORE_NN_MODULE_INIT(post_mlp_layernorm, model_config_->get("hidden_size"), model_config_->get("rms_norm_eps"),
+                                  dtype, device);
+    }
 
     // Initialize attention and MLP modules
     INFINICORE_NN_MODULE_INIT(self_attn, model_config_, device, layer_idx, rank_info_, attention_backend);
@@ -62,6 +69,20 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
                            std::optional cu_seqlens,
                            std::optional block_tables,
                            std::optional slot_mapping) const {
+    if (use_glm4_post_norms_) {
+        hidden_states = forward_naive(
+            hidden_states,
+            position_ids,
+            kv_cache,
+            past_sequence_lengths,
+            total_sequence_lengths,
+            input_offsets,
+            cu_seqlens,
+            block_tables,
+            slot_mapping);
+        return std::make_tuple(hidden_states, residual);
+    }
+
     // 1. Attention layer normalization
     input_layernorm_->forward_inplace(hidden_states, residual);
@@ -78,4 +99,34 @@ LlamaDecoderLayer::forward(infinicore::Tensor &hidden_states,
     return std::make_tuple(hidden_states, residual);
 }
 
+infinicore::Tensor LlamaDecoderLayer::forward_naive(
+    infinicore::Tensor &hidden_states,
+    const infinicore::Tensor &position_ids,
+    std::shared_ptr kv_cache,
+    std::optional past_sequence_lengths,
+    std::optional total_sequence_lengths,
+    std::optional input_offsets,
+    std::optional cu_seqlens,
+    std::optional block_tables,
+    std::optional slot_mapping) const {
+    auto residual = hidden_states;
+    hidden_states = input_layernorm_->forward(hidden_states);
+    hidden_states = self_attn_->forward(
+        hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
+
+    if (use_glm4_post_norms_) {
+        hidden_states = post_self_attn_layernorm_->forward(hidden_states);
+    }
+    hidden_states = infinicore::op::add(residual, hidden_states);
+
+    residual = hidden_states;
+    hidden_states = post_attention_layernorm_->forward(hidden_states);
+    hidden_states = mlp_->forward(hidden_states);
+    if (use_glm4_post_norms_) {
+        hidden_states = post_mlp_layernorm_->forward(hidden_states);
+    }
+    hidden_states = infinicore::op::add(residual, hidden_states);
+    return hidden_states;
+}
+
 } // namespace infinilm::models::llama_legacy
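forward_naive implements GLM4's "sandwich" residual layout: an extra RMSNorm is applied to the attention output and to the MLP output before each residual add, which is why the fused forward_inplace norm path cannot be reused. The ordering, sketched in Python with hypothetical stand-ins for the C++ modules:

```python
def glm4_decoder_block(x, input_ln, attn, post_self_attn_ln,
                       post_attn_ln, mlp, post_mlp_ln):
    # Ordering sketch only; each argument stands in for a module above.
    residual = x
    h = post_self_attn_ln(attn(input_ln(x)))  # extra norm on attention output
    x = residual + h
    residual = x
    h = post_mlp_ln(mlp(post_attn_ln(x)))     # extra norm on MLP output
    return residual + h
```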
diff --git a/csrc/models/llama_legacy/llama_decoder_layer.hpp b/csrc/models/llama_legacy/llama_decoder_layer.hpp
index 3ea152bf..a0f58c26 100644
--- a/csrc/models/llama_legacy/llama_decoder_layer.hpp
+++ b/csrc/models/llama_legacy/llama_decoder_layer.hpp
@@ -79,6 +79,17 @@ class LlamaDecoderLayer : public infinicore::nn::Module {
                 std::optional block_tables,
                 std::optional slot_mappin) const;
 
+    infinicore::Tensor forward_naive(
+        infinicore::Tensor &hidden_states,
+        const infinicore::Tensor &position_ids,
+        std::shared_ptr kv_cache,
+        std::optional past_sequence_lengths,
+        std::optional total_sequence_lengths,
+        std::optional input_offsets,
+        std::optional cu_seqlens,
+        std::optional block_tables,
+        std::optional slot_mapping) const;
+
     /**
      * @brief Get the layer index
      */
@@ -94,6 +105,8 @@ class LlamaDecoderLayer : public infinicore::nn::Module {
     // Layer normalization
     INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
     INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_self_attn_layernorm);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_mlp_layernorm);
 
     // Attention and MLP
     INFINICORE_NN_MODULE(LlamaAttention, self_attn);
@@ -103,6 +116,7 @@ class LlamaDecoderLayer : public infinicore::nn::Module {
 
 private:
     size_t layer_idx_; // Layer index for cache management and debugging
+    bool use_glm4_post_norms_ = false;
 };
 
 } // namespace infinilm::models::llama_legacy
diff --git a/csrc/models/llama_legacy/llama_model.cpp b/csrc/models/llama_legacy/llama_model.cpp
index 20724135..e31ab640 100644
--- a/csrc/models/llama_legacy/llama_model.cpp
+++ b/csrc/models/llama_legacy/llama_model.cpp
@@ -1,4 +1,5 @@
 #include "llama_model.hpp"
+#include "llama_utils.hpp"
 #include "infinicore/nn/embedding.hpp"
 #include "infinicore/nn/rmsnorm.hpp"
 #include "infinicore/nn/rope.hpp"
@@ -6,6 +7,22 @@
 #include
 
 namespace infinilm::models::llama_legacy {
+
+namespace {
+
+infinicore::nn::RoPE::Algo get_rope_algo(const std::string &model_type) {
+    if (model_type == "glm4") {
+        return infinicore::nn::RoPE::Algo::GPT_J;
+    }
+    return infinicore::nn::RoPE::Algo::GPT_NEOX;
+}
+
+bool uses_naive_residual_path(const std::shared_ptr &model_config) {
+    return model_config != nullptr && model_config->get("model_type") == "glm4";
+}
+
+} // namespace
+
 /**
  * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
  *
@@ -45,7 +62,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
     // Initialize Rotary Position Embeddings (shared across all layers)
     // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
     INFINICORE_NN_MODULE_INIT(rotary_emb, config.head_dim, config.max_position_embeddings,
-                              config.rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX,
+                              config.rope_theta, get_rope_algo(config.model_type),
                               dtype, device, config.rope_scaling);
 
     for (auto &layer : layers_) {
@@ -78,8 +95,8 @@ LlamaModel::LlamaModel(std::shared_ptr model_conf
                               dtype, device);
     // Initialize Rotary Position Embeddings (shared across all layers)
     // Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
-    INFINICORE_NN_MODULE_INIT(rotary_emb, model_config_->get_head_dim(), model_config_->get("max_position_embeddings"),
-                              model_config_->get("rope_theta"), infinicore::nn::RoPE::Algo::GPT_NEOX,
+    INFINICORE_NN_MODULE_INIT(rotary_emb, get_rotary_dim(model_config_->get_head_dim(), model_config_->get_or("partial_rotary_factor", 1.0)), model_config_->get("max_position_embeddings"),
+                              model_config_->get("rope_theta"), get_rope_algo(model_config_->get("model_type")),
                               dtype, device, model_config_->get_rope_scaling());
 
     for (auto &layer : layers_) {
@@ -117,6 +134,10 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
                             slot_mapping);
     }
 
+    if (uses_naive_residual_path(model_config_)) {
+        return norm_->forward(hidden_states);
+    }
+
     norm_->forward_inplace(hidden_states, residual);
 
     return hidden_states;
@@ -136,6 +157,11 @@ infinicore::Tensor LlamaModel::forward_embeds(const infinicore::Tensor &inputs_e
     for (size_t i = 0; i < num_layers; ++i) {
         layers_.at(i)->forward(hidden_states, residual, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, cu_seqlens, block_tables, slot_mapping);
     }
+
+    if (uses_naive_residual_path(model_config_)) {
+        return norm_->forward(hidden_states);
+    }
+
     norm_->forward_inplace(hidden_states, residual);
 
     return hidden_states;
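The rotary embedding is now sized to the rotary dimension rather than the full head: for GLM4-style configs with partial_rotary_factor < 1, cos/sin tables cover only that many channels. A Python mirror of the get_rotary_dim helper defined in llama_utils.hpp below (llround-vs-round edge cases aside):

```python
def get_rotary_dim(head_dim: int, partial_rotary_factor: float = 1.0) -> int:
    # Factors outside (0, 1) mean "rotate the full head"; otherwise round,
    # clamp to [2, head_dim], and force the result even for cos/sin pairing.
    if partial_rotary_factor <= 0.0 or partial_rotary_factor >= 1.0:
        return head_dim
    rotary_dim = min(max(round(head_dim * partial_rotary_factor), 2), head_dim)
    if rotary_dim % 2 != 0:
        rotary_dim -= 1
    return max(rotary_dim, 2)

assert get_rotary_dim(128, 0.5) == 64  # e.g. a GLM4-style partial factor
assert get_rotary_dim(128) == 128      # default: full rotation
```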
diff --git a/csrc/models/llama_legacy/llama_utils.hpp b/csrc/models/llama_legacy/llama_utils.hpp
new file mode 100644
index 00000000..3c582e55
--- /dev/null
+++ b/csrc/models/llama_legacy/llama_utils.hpp
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+
+namespace infinilm::models::llama_legacy {
+
+inline size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor) {
+    if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) {
+        return head_dim;
+    }
+
+    size_t rotary_dim = static_cast<size_t>(std::llround(
+        static_cast<double>(head_dim) * partial_rotary_factor));
+    rotary_dim = std::clamp(rotary_dim, static_cast<size_t>(2), head_dim);
+    if (rotary_dim % 2 != 0) {
+        rotary_dim -= 1;
+    }
+    return std::max(rotary_dim, static_cast<size_t>(2));
+}
+
+} // namespace infinilm::models::llama_legacy
diff --git a/python/infinilm/auto_config.py b/python/infinilm/auto_config.py
index c1d8b32f..6c0a8a0f 100644
--- a/python/infinilm/auto_config.py
+++ b/python/infinilm/auto_config.py
@@ -46,5 +46,7 @@ def from_pretrained(model_path):
         cfg = LlamaConfig(**config_dict)
         cfg.model_type = "minicpmv"
         return cfg
+    elif config_dict["model_type"] == "glm4":
+        return LlamaConfig(**config_dict)
 
     raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
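Hypothetical usage of the extended from_pretrained (the checkpoint path is illustrative):

```python
from infinilm.auto_config import from_pretrained

# Any directory whose config.json carries model_type == "glm4" now resolves
# to a LlamaConfig instead of raising ValueError.
cfg = from_pretrained("/models/glm-4-9b-chat")
assert cfg.model_type == "glm4"
```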
ValueError(f"Unsupported model type `{config_dict['model_type']}`.") diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 827923f3..48142c03 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -1,4 +1,5 @@ import os +import re from typing import Dict, Union import time import torch @@ -41,6 +42,43 @@ def parse_dtype(dtype_str: str): } +def _split_first_dim(tensor, sizes, name): + if tensor.dim() not in (1, 2): + raise ValueError(f"Cannot split {name} with shape {tensor.shape}") + return torch.split(tensor, sizes, dim=0) + + +def _remap_glm4_weights(state_dict): + new_sd = {} + + for key, tensor in state_dict.items(): + if "gate_up_proj" not in key: + new_sd[key] = tensor + continue + + base_key = key.replace(".gate_up_proj.weight", "") + intermediate = tensor.shape[0] // 2 + gate, up = _split_first_dim( + tensor, + [intermediate, tensor.shape[0] - intermediate], + "gate_up_proj", + ) + new_sd[f"{base_key}.gate_proj.weight"] = gate + new_sd[f"{base_key}.up_proj.weight"] = up + return new_sd + + +def maybe_remap_weights(state_dict, model): + if not hasattr(model, "hf_config"): + return state_dict + + hf_config = model.hf_config + model_type = hf_config.get("model_type", "") + if model_type == "glm4": + return _remap_glm4_weights(state_dict) + return state_dict + + def check_parameters(model_keys: list, already_loaded_keys: list): model_keys = set(model_keys) already_loaded_keys = set(already_loaded_keys) @@ -165,6 +203,7 @@ def load_model_state_dict_by_file( model_param = load_state_dict( file_path, device=torch_device, dtype=torch_dtype ) + model_param = maybe_remap_weights(model_param, model) already_loaded_keys.extend(model_param.keys()) # --------------------------------------------------------- # @@ -181,6 +220,8 @@ def load_model_state_dict_by_file( file_path = os.path.join(model_path, "pytorch_model.bin") model_params = torch.load(file_path, weights_only=True, map_location="cpu") + model_params = maybe_remap_weights(model_params, model) + model_param_infini = {} for key in model_params.keys(): model_param_infini[key] = infinicore.from_torch(