diff --git a/csrc/config/config_factory.cpp b/csrc/config/config_factory.cpp
index 09e21e93..f651374a 100644
--- a/csrc/config/config_factory.cpp
+++ b/csrc/config/config_factory.cpp
@@ -17,7 +17,7 @@ std::shared_ptr ConfigFactory::createConfig(const
     if (it != config_map.end()) {
         it->second(model_config);
     } else {
-        std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b"};
+        std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b", "baichuan"};
         const std::string &model_type = model_config->get("model_type");
         if (std::find(classic_models.begin(), classic_models.end(), model_type) == classic_models.end()) {
             throw std::invalid_argument("infinilm::config::ConfigFactory::createConfig: Unsupported model config type: " + model_type);
diff --git a/csrc/models/baichuan/baichuan_for_causal_lm.cpp b/csrc/models/baichuan/baichuan_for_causal_lm.cpp
new file mode 100644
index 00000000..9894952b
--- /dev/null
+++ b/csrc/models/baichuan/baichuan_for_causal_lm.cpp
@@ -0,0 +1,50 @@
+#include "baichuan_for_causal_lm.hpp"
+#include "../llama/llama_for_causal_lm.hpp"
+#include "../models_registry.hpp"
+
+namespace infinilm::models::baichuan {
+
+std::shared_ptr create_baichuan_model_config(
+    std::shared_ptr model_config) {
+    const std::string &model_type = model_config->get("model_type");
+    if ("baichuan" != model_type) {
+        throw std::runtime_error(
+            "infinilm::models::baichuan::create_baichuan_model_config: model_type is not baichuan");
+    }
+
+    nlohmann::json &config_json = model_config->get_config_json();
+
+    if (!config_json.contains("num_key_value_heads")) {
+        config_json["num_key_value_heads"] = model_config->get("num_attention_heads");
+    }
+
+    if (!config_json.contains("head_dim")) {
+        config_json["head_dim"] = model_config->get("hidden_size")
+                                  / model_config->get("num_attention_heads");
+    }
+
+    if (!config_json.contains("rope_theta")) {
+        config_json["rope_theta"] = 10000.0;
+    }
+
+    if (!config_json.contains("attention_bias")) {
+        config_json["attention_bias"] = false;
+    }
+
+    return model_config;
+}
+
+} // namespace infinilm::models::baichuan
+
+namespace {
+
+#ifndef USE_CLASSIC_LLAMA
+
+INFINILM_REGISTER_CAUSAL_LM_MODEL(
+    baichuan,
+    infinilm::models::llama::LlamaForCausalLM,
+    infinilm::models::baichuan::create_baichuan_model_config);
+
+#endif
+
+} // namespace
diff --git a/csrc/models/baichuan/baichuan_for_causal_lm.hpp b/csrc/models/baichuan/baichuan_for_causal_lm.hpp
new file mode 100644
index 00000000..d879534b
--- /dev/null
+++ b/csrc/models/baichuan/baichuan_for_causal_lm.hpp
@@ -0,0 +1,11 @@
+#pragma once
+
+#include "../../layers/common_modules.hpp"
+#include <memory>
+
+namespace infinilm::models::baichuan {
+
+std::shared_ptr create_baichuan_model_config(
+    std::shared_ptr model_config);
+
+} // namespace infinilm::models::baichuan
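Not part of the patch: a minimal Python sketch of the defaulting that create_baichuan_model_config performs, assuming Baichuan-7B-like dimensions (hidden_size 4096, 32 attention heads); the dict below is a made-up stand-in for the parsed config.json.

# Illustration only (hypothetical values): mirrors the defaults filled in by the C++ code above.
config = {"model_type": "baichuan", "hidden_size": 4096, "num_attention_heads": 32}

config.setdefault("num_key_value_heads", config["num_attention_heads"])  # 32: full head count, no GQA
config.setdefault("head_dim", config["hidden_size"] // config["num_attention_heads"])  # 4096 // 32 = 128
config.setdefault("rope_theta", 10000.0)
config.setdefault("attention_bias", False)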
diff --git a/examples/test_infer.py b/examples/test_infer.py
index abec5d00..c6ab1cff 100644
--- a/examples/test_infer.py
+++ b/examples/test_infer.py
@@ -6,6 +6,7 @@
 from infinilm.distributed import DistConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
 import argparse
+import json
 import sys
 import time
 import os
@@ -22,6 +23,37 @@
 _PAGED_KV_BLOCK_SIZE = 256
 
 
+def _get_baichuan_role_token_ids(model_path: str) -> tuple[int, int]:
+    user_token_id = 195
+    assistant_token_id = 196
+    generation_config_path = os.path.join(model_path, "generation_config.json")
+    if os.path.exists(generation_config_path):
+        with open(generation_config_path, "r") as f:
+            generation_config = json.load(f)
+        user_token_id = int(generation_config.get("user_token_id", user_token_id))
+        assistant_token_id = int(
+            generation_config.get("assistant_token_id", assistant_token_id)
+        )
+    return user_token_id, assistant_token_id
+
+
+def _encode_baichuan_chat_prompts(
+    prompts: list[str],
+    tokenizer: AutoTokenizer,
+    model_path: str,
+    max_length: int,
+) -> list[list[int]]:
+    user_token_id, assistant_token_id = _get_baichuan_role_token_ids(model_path)
+    max_content_length = max(0, max_length - 2)
+    input_ids_list = []
+    for prompt in prompts:
+        content_ids = tokenizer.encode(prompt, add_special_tokens=False)
+        if len(content_ids) > max_content_length:
+            content_ids = content_ids[-max_content_length:]
+        input_ids_list.append([user_token_id, *content_ids, assistant_token_id])
+    return input_ids_list
+
+
 def test(
     prompts: str | list[str],
     model_path,
@@ -104,7 +136,10 @@ def test(
             updated_prompts.append(prompt)
         prompts = updated_prompts
 
-    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None:
+    used_chat_template = (
+        hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
+    )
+    if used_chat_template:
         input_contents = [
             tokenizer.apply_chat_template(
                 conversation=[{"role": "user", "content": prompt}],
@@ -139,20 +174,14 @@
         else:
             raise ValueError(f"Unsupported multimodal model_type: {model.model_type}")
     else:
-        if hasattr(tokenizer, "batch_encode_plus"):
-            input_ids_list = tokenizer.batch_encode_plus(input_contents)["input_ids"]
-        elif hasattr(tokenizer, "_encode_plus"):
-            input_ids_list = tokenizer._encode_plus(input_contents)["input_ids"]
-        else:
-            input_ids_list = tokenizer(input_contents)[
-                "input_ids"
-            ]  # List: [[1, 1128, 526, 366, 29892]]
-
-        # input_ids_list = tokenizer.batch_encode_plus(input_contents)[
-        #     "input_ids"
-        # ]  # List: [[1, 1128, 526, 366, 29892]]
-        if version.parse(transformers.__version__) < version.parse("5.0.0"):
-            # Ideally this is solved by upgrading transformers. However, doing so causes version mismatch between transformers and mlu pytorch on devices with Phytium CPU. So a branch is temporarily used.
+        if model.model_type == "baichuan" and not used_chat_template:
+            input_ids_list = _encode_baichuan_chat_prompts(
+                prompts,
+                tokenizer,
+                model_path,
+                max_length=2048,
+            )
+        elif version.parse(transformers.__version__) < version.parse("5.0.0"):
             input_ids_list = [
                 tokenizer.encode_plus(
                     text, truncation=True, max_length=2048, add_special_tokens=True
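Not part of the patch: a short sketch of the single-turn layout that _encode_baichuan_chat_prompts produces when the tokenizer ships no chat template. The content token ids below are placeholders; 195 and 196 are the user/assistant role markers read from generation_config.json, or defaulted when that file is absent.

# Illustration only: one prompt encoded without a chat template.
user_token_id, assistant_token_id = 195, 196
content_ids = [1128, 526, 366]  # stand-in for tokenizer.encode(prompt, add_special_tokens=False)
input_ids = [user_token_id, *content_ids, assistant_token_id]  # -> [195, 1128, 526, 366, 196]
# max_length - 2 bounds the content so both role tokens always fit.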
diff --git a/python/infinilm/auto_config.py b/python/infinilm/auto_config.py
index c1d8b32f..6fdc0d91 100644
--- a/python/infinilm/auto_config.py
+++ b/python/infinilm/auto_config.py
@@ -47,4 +47,7 @@ def from_pretrained(model_path):
         cfg.model_type = "minicpmv"
         return cfg
 
+    elif config_dict["model_type"] == "baichuan":
+        return LlamaConfig(**config_dict)
+
     raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py
index 827923f3..db208a31 100644
--- a/python/infinilm/modeling_utils.py
+++ b/python/infinilm/modeling_utils.py
@@ -1,4 +1,5 @@
 import os
+import re
 from typing import Dict, Union
 import time
 import torch
@@ -41,6 +42,48 @@
     }
 
 
+def _split_first_dim(tensor, sizes, name):
+    if tensor.dim() not in (1, 2):
+        raise ValueError(f"Cannot split {name} with shape {tensor.shape}")
+    return torch.split(tensor, sizes, dim=0)
+
+
+def _remap_baichuan_weights(state_dict, hf_config):
+    hidden_size = hf_config.get("hidden_size", 4096)
+    num_heads = hf_config.get("num_attention_heads", 32)
+    per_head_dim = num_heads * (hidden_size // num_heads)
+    new_sd = {}
+
+    for key, tensor in state_dict.items():
+        wpack_match = re.match(r"(.*\.)W_pack\.(weight|bias)", key)
+        if not wpack_match:
+            new_sd[key] = tensor
+            continue
+
+        prefix = wpack_match.group(1)
+        suffix = wpack_match.group(2)
+        q, k, v = _split_first_dim(
+            tensor,
+            [per_head_dim, per_head_dim, tensor.shape[0] - 2 * per_head_dim],
+            "W_pack",
+        )
+        new_sd[f"{prefix}q_proj.{suffix}"] = q
+        new_sd[f"{prefix}k_proj.{suffix}"] = k
+        new_sd[f"{prefix}v_proj.{suffix}"] = v
+    return new_sd
+
+
+def maybe_remap_weights(state_dict, model):
+    if not hasattr(model, "hf_config"):
+        return state_dict
+
+    hf_config = model.hf_config
+    model_type = hf_config.get("model_type", "")
+    if model_type == "baichuan":
+        return _remap_baichuan_weights(state_dict, hf_config)
+    return state_dict
+
+
 def check_parameters(model_keys: list, already_loaded_keys: list):
     model_keys = set(model_keys)
     already_loaded_keys = set(already_loaded_keys)
@@ -165,6 +208,7 @@ def load_model_state_dict_by_file(
         model_param = load_state_dict(
             file_path, device=torch_device, dtype=torch_dtype
         )
+        model_param = maybe_remap_weights(model_param, model)
         already_loaded_keys.extend(model_param.keys())
 
         # --------------------------------------------------------- #
@@ -181,6 +225,8 @@
         file_path = os.path.join(model_path, "pytorch_model.bin")
         model_params = torch.load(file_path, weights_only=True, map_location="cpu")
 
+        model_params = maybe_remap_weights(model_params, model)
+
        model_param_infini = {}
         for key in model_params.keys():
             model_param_infini[key] = infinicore.from_torch(
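Not part of the patch: a self-contained sanity check of the W_pack remapping idea, with made-up shapes. Baichuan checkpoints pack the Q/K/V projections along dim 0 of a single tensor; splitting that dimension into three hidden_size-sized chunks recovers q_proj / k_proj / v_proj, which is what _remap_baichuan_weights does for every matching key.

import torch

# Illustration only, assuming Baichuan-7B-like sizes (hidden_size 4096, full multi-head attention).
hidden_size = 4096
w_pack = torch.randn(3 * hidden_size, hidden_size)  # packed [q; k; v] weight

q, k, v = torch.split(w_pack, [hidden_size, hidden_size, hidden_size], dim=0)
assert q.shape == k.shape == v.shape == (hidden_size, hidden_size)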