diff --git a/CMakeLists.txt b/CMakeLists.txt index df636b27..bb879821 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,3 +204,29 @@ link_infini_train_exe(test_precision_check) add_executable(test_lora test/lora/test_lora.cc) link_infini_train_exe(test_lora) +add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc) +link_infini_train_exe(test_lr_scheduler) + +add_executable(test_constant_lr test/lr_scheduler/test_constant_lr.cc) +link_infini_train_exe(test_constant_lr) + +add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc) +link_infini_train_exe(test_step_lr) + +add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc) +link_infini_train_exe(test_linear_lr) + +add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc) +link_infini_train_exe(test_lambda_lr) + +add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc) +link_infini_train_exe(test_sequential_lr) + +add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc) +link_infini_train_exe(test_chained_lr) + +add_executable(test_training_lr_scheduler test/lr_scheduler/test_training_lr_scheduler.cc) +link_infini_train_exe(test_training_lr_scheduler) + +add_executable(test_lr_scheduler_validation test/lr_scheduler/test_lr_scheduler_validation.cc) +link_infini_train_exe(test_lr_scheduler_validation) diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 8e28af52..740fe2a1 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -13,6 +13,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" @@ -54,8 +55,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run"); DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization -DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); +DEFINE_double(learning_rate, 1e-4, "Peak learning rate."); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_double(min_lr, 0.0, "Minimum learning rate."); +DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root"); +DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations."); +DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup."); +DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration)."); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -98,6 +105,8 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; +const std::unordered_set kSupportedLRDecayStyles + = {"none", "constant", "linear", "cosine", "inverse-square-root"}; // const std::unordered_map kModelToConfigs = { @@ -118,6 +127,8 @@ const std::unordered_map kStrToModelType = { DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(lr_decay_style, + [](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -310,6 +321,16 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(params_to_optimize); } + const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration; + TrainingLRSchedulerConfig sched_config; + sched_config.lr = static_cast(FLAGS_learning_rate); + sched_config.min_lr = static_cast(FLAGS_min_lr); + sched_config.lr_decay_style = FLAGS_lr_decay_style; + sched_config.lr_decay_iters = lr_decay_iters; + sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters; + sched_config.lr_warmup_init = static_cast(FLAGS_lr_warmup_init); + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( @@ -353,6 +374,7 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); float lossf = 0.0f; // model->Train(); if (pp_world_size == 1) { @@ -396,6 +418,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -405,6 +430,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -420,11 +448,10 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index acc20ac4..83b3c5c4 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -11,6 +11,7 @@ #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" +#include "infini_train/include/lr_scheduler.h" #include "infini_train/include/nn/lora/lora_utils.h" #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" @@ -53,8 +54,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run"); DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization -DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); +DEFINE_double(learning_rate, 1e-5, "Peak learning rate."); DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); +// lr scheduler +DEFINE_double(min_lr, 0.0, "Minimum learning rate."); +DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root"); +DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations."); +DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup."); +DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration)."); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -93,11 +100,15 @@ constexpr char kDeviceCPU[] = "cpu"; constexpr char kDeviceCUDA[] = "cuda"; constexpr char kDtypeFP32[] = "float32"; constexpr char kDtypeBF16[] = "bfloat16"; +const std::unordered_set kSupportedLRDecayStyles + = {"none", "constant", "linear", "cosine", "inverse-square-root"}; } // namespace DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); }); DEFINE_validator(device, [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; }); +DEFINE_validator(lr_decay_style, + [](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); }); void Train(const nn::parallel::Rank &rank) { using namespace nn::parallel; @@ -282,6 +293,16 @@ void Train(const nn::parallel::Rank &rank) { optimizer = optimizer_creator(params_to_optimize); } + const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration; + TrainingLRSchedulerConfig sched_config; + sched_config.lr = static_cast(FLAGS_learning_rate); + sched_config.min_lr = static_cast(FLAGS_min_lr); + sched_config.lr_decay_style = FLAGS_lr_decay_style; + sched_config.lr_decay_iters = lr_decay_iters; + sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters; + sched_config.lr_warmup_init = static_cast(FLAGS_lr_warmup_init); + auto scheduler = CreateLRScheduler(optimizer, sched_config); + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) @@ -322,6 +343,7 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().SetTag("Step_" + std::to_string(step)); #endif + const float current_lr = scheduler ? scheduler->GetLR() : static_cast(FLAGS_learning_rate); float lossf = 0.0f; if (pp_world_size == 1) { // model->Train(); @@ -365,6 +387,9 @@ void Train(const nn::parallel::Rank &rank) { } optimizer->Step(); + if (scheduler) { + scheduler->Step(); + } } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -374,6 +399,9 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); + if (scheduler) { + scheduler->Step(); + } } if (ddp_world_size > 1) { @@ -389,11 +417,10 @@ void Train(const nn::parallel::Rank &rank) { if (rank.IsLastRank()) { size_t used_mb = 0, reserved_mb = 0; std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device); - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", - step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps, + used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { diff --git a/infini_train/include/lr_scheduler.h b/infini_train/include/lr_scheduler.h new file mode 100644 index 00000000..4b4428cf --- /dev/null +++ b/infini_train/include/lr_scheduler.h @@ -0,0 +1,175 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infini_train { + +class Optimizer; + +using StateValue = std::variant>; +using StateDict = std::unordered_map; + +struct TrainingLRSchedulerConfig { + std::string lr_decay_style = "constant"; + float lr = 0.0f; + float min_lr = 0.0f; + int64_t lr_decay_iters = 1; + int64_t lr_warmup_iters = 0; + float lr_warmup_init = 0.0f; +}; + +class LRScheduler { +public: + template static std::shared_ptr Create(Args &&...args) { + auto scheduler = std::make_shared(std::forward(args)...); + scheduler->InitialStep(); + return scheduler; + } + + explicit LRScheduler(std::shared_ptr optimizer, int64_t last_step = -1); + virtual ~LRScheduler() = default; + + LRScheduler(const LRScheduler &) = delete; + LRScheduler &operator=(const LRScheduler &) = delete; + + virtual void Step(); + virtual void Step(int64_t epoch); + virtual void InitialStep(); + + float GetLR() const; + float BaseLR() const; + int64_t LastStep() const; + + void ResetStep(int64_t step = -1); + virtual StateDict State() const; + virtual void LoadState(const StateDict &state); + + bool SharesOptimizerWith(const std::shared_ptr &opt) const; + +protected: + virtual float GetClosedFormLR() const = 0; + virtual float GetChainedFormLR() const; + void ApplyLR(float lr); + + std::shared_ptr optimizer_; + int64_t last_step_; + float recover_lr_; + float base_lr_; + bool is_initial_ = false; +}; + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, + const TrainingLRSchedulerConfig &config); + +namespace lr_schedulers { + +class ConstantLR : public LRScheduler { +public: + ConstantLR(std::shared_ptr optimizer, float factor = 1.0f / 3.0f, int total_iters = 5, + int64_t last_step = -1); + ~ConstantLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float factor_; + const int64_t total_iters_; +}; + +class StepLR : public LRScheduler { +public: + StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1); + ~StepLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const int64_t step_size_; + const float gamma_; +}; + +class LinearLR : public LRScheduler { +public: + LinearLR(std::shared_ptr optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f, + int64_t total_iters = 5, int64_t last_step = -1); + ~LinearLR() override = default; + +protected: + float GetChainedFormLR() const override; + float GetClosedFormLR() const override; + +private: + const float start_factor_; + const float end_factor_; + const int64_t total_iters_; +}; + +class LambdaLR : public LRScheduler { +public: + using LambdaFunc = std::function; + + LambdaLR(std::shared_ptr optimizer, LambdaFunc lr_lambda, int64_t last_step = -1); + ~LambdaLR() override = default; + +protected: + float GetClosedFormLR() const override; + +private: + const LambdaFunc lr_lambda_; +}; + +class SequentialLR : public LRScheduler { +public: + SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step = -1); + ~SequentialLR() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + // FIXME: SequentialLR should not have a closed-form LR, but we need to implement this pure virtual function. + float GetClosedFormLR() const override { return base_lr_; } + void UndoChildInitialSteps(); + +private: + std::vector> schedulers_; + std::vector milestones_; +}; + +class ChainedScheduler : public LRScheduler { +public: + ChainedScheduler(std::shared_ptr optimizer, std::vector> schedulers, + int64_t last_step = -1); + ~ChainedScheduler() override = default; + + void Step() override; + void InitialStep() override; + + StateDict State() const override; + void LoadState(const StateDict &state) override; + +protected: + // FIXME: ChainedScheduler should not have a closed-form LR, but we need to implement this pure virtual function. + float GetClosedFormLR() const override { return base_lr_; } + +private: + std::vector> schedulers_; +}; + +} // namespace lr_schedulers +} // namespace infini_train diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h index bc31442e..18947ec7 100644 --- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h +++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h @@ -34,6 +34,9 @@ class DistributedOptimizer final : public infini_train::Optimizer { void StartParamSync(bool force_sync = false); void FinishParamSync(bool skip_next_bucket_dispatch = false); + virtual void set_learning_rate(float lr) override; + virtual float learning_rate() const override; + private: void BuildShardParamsAndBindGrads(); diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index fb0ae2d5..942fd92a 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -15,14 +15,25 @@ using OptimizerCreator = std::function(const std::vec class Optimizer { public: - explicit Optimizer(const std::vector> ¶ms); + explicit Optimizer(const std::vector> ¶ms, float learning_rate = 0.0f); virtual void ZeroGrad(bool set_to_none = true); virtual void Step() = 0; + virtual void set_learning_rate(float lr); + + virtual float learning_rate() const; + + float initial_learning_rate() const; + + void set_initial_learning_rate(float lr); + protected: std::vector> params_; + float learning_rate_ = 0.0f; + float initial_learning_rate_ = 0.0f; + bool initial_lr_set_ = false; }; namespace optimizers { @@ -37,9 +48,6 @@ class SGD : public Optimizer { return std::make_shared(params, learning_rate); }; } - -private: - const float learning_rate_ = 0.0; }; class Adam : public Optimizer { @@ -58,7 +66,6 @@ class Adam : public Optimizer { private: int64_t t_; - const float learning_rate_; const float beta1_; const float beta2_; const float eps_; diff --git a/infini_train/src/lr_scheduler.cc b/infini_train/src/lr_scheduler.cc new file mode 100644 index 00000000..6e4e88e0 --- /dev/null +++ b/infini_train/src/lr_scheduler.cc @@ -0,0 +1,362 @@ +#include "infini_train/include/lr_scheduler.h" + +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" + +namespace infini_train { + +std::shared_ptr CreateLRScheduler(std::shared_ptr optimizer, + const TrainingLRSchedulerConfig &config) { + if (config.lr_decay_style == "none") { + return nullptr; + } + + CHECK(optimizer) << "CreateLRScheduler: optimizer must not be null."; + const float max_lr = config.lr != 0.0f ? config.lr : optimizer->learning_rate(); + CHECK_GT(max_lr, 0.0f) << "CreateLRScheduler: max_lr must be > 0."; + CHECK_GE(config.lr_warmup_init, 0.0f) << "CreateLRScheduler: lr_warmup_init must be >= 0."; + CHECK_GE(config.min_lr, 0.0f) << "CreateLRScheduler: min_lr must be >= 0."; + CHECK_GE(max_lr, config.min_lr) << "CreateLRScheduler: max_lr must be >= min_lr."; + CHECK_LE(config.lr_warmup_init, max_lr) << "CreateLRScheduler: lr_warmup_init must be <= max_lr."; + CHECK_GE(config.lr_warmup_iters, 0) << "CreateLRScheduler: lr_warmup_iters must be >= 0."; + CHECK_GT(config.lr_decay_iters, 0) << "CreateLRScheduler: lr_decay_iters must be > 0."; + CHECK_LT(config.lr_warmup_iters, config.lr_decay_iters) + << "CreateLRScheduler: lr_warmup_iters must be < lr_decay_iters."; + CHECK(config.lr_decay_style == "constant" || config.lr_decay_style == "linear" || config.lr_decay_style == "cosine" + || config.lr_decay_style == "inverse-square-root") + << "CreateLRScheduler: unsupported lr_decay_style: " << config.lr_decay_style; + + std::shared_ptr main_scheduler; + const int64_t decay_iters_after_warmup = config.lr_decay_iters - config.lr_warmup_iters; + if (config.lr_decay_style == "constant") { + main_scheduler = LRScheduler::Create(optimizer, [](int64_t) { return 1.0f; }); + } else if (config.lr_decay_style == "linear") { + main_scheduler = LRScheduler::Create(optimizer, 1.0f, config.min_lr / max_lr, + decay_iters_after_warmup); + } else if (config.lr_decay_style == "cosine") { + main_scheduler = LRScheduler::Create( + optimizer, [max_lr, min_lr = config.min_lr, decay_iters_after_warmup](int64_t step) { + if (step > decay_iters_after_warmup) { + return min_lr / max_lr; + } + const float decay_ratio = static_cast(step) / static_cast(decay_iters_after_warmup); + CHECK_GE(decay_ratio, 0.0f) << "CreateLRScheduler: decay " + "ratio must be >= 0."; + CHECK_LE(decay_ratio, 1.0f) << "CreateLRScheduler: decay " + "ratio must be <= 1."; + const float coeff = 0.5f * (std::cos(std::numbers::pi_v * decay_ratio) + 1.0f); + return (min_lr + coeff * (max_lr - min_lr)) / max_lr; + }); + } else if (config.lr_decay_style == "inverse-square-root") { + main_scheduler = LRScheduler::Create( + optimizer, [max_lr, min_lr = config.min_lr, lr_warmup_iters = config.lr_warmup_iters, + lr_decay_iters = config.lr_decay_iters](int64_t step) { + const int64_t global_step = step + lr_warmup_iters; + if (global_step > lr_decay_iters) { + return min_lr / max_lr; + } + const auto warmup = static_cast(std::max(lr_warmup_iters, 1)); + const auto current = static_cast(std::max(global_step, 1)); + return std::max(min_lr, max_lr * std::sqrt(warmup) / std::sqrt(current)) / max_lr; + }); + } + + CHECK(main_scheduler) << "CreateLRScheduler: failed to create scheduler."; + if (config.lr_warmup_iters == 0) { + return main_scheduler; + } + + auto warmup_scheduler = LRScheduler::Create( + optimizer, + [lr_warmup_init = config.lr_warmup_init, max_lr, lr_warmup_iters = config.lr_warmup_iters](int64_t step) { + const float warmup_ratio = static_cast(step) / static_cast(lr_warmup_iters); + return (lr_warmup_init + (max_lr - lr_warmup_init) * warmup_ratio) / max_lr; + }); + return LRScheduler::Create( + std::move(optimizer), std::vector>{warmup_scheduler, main_scheduler}, + std::vector{config.lr_warmup_iters}); +} + +LRScheduler::LRScheduler(std::shared_ptr optimizer, int64_t last_step) + : optimizer_(std::move(optimizer)), last_step_(last_step), base_lr_(0.0f) { + CHECK(optimizer_) << "LRScheduler: optimizer must not be null."; + optimizer_->set_initial_learning_rate(optimizer_->learning_rate()); + base_lr_ = optimizer_->initial_learning_rate(); +} + +void LRScheduler::Step() { + ++last_step_; + ApplyLR(GetChainedFormLR()); +} + +void LRScheduler::Step(int64_t epoch) { + last_step_ = epoch; + ApplyLR(GetClosedFormLR()); +} + +void LRScheduler::InitialStep() { + is_initial_ = true; + Step(); + is_initial_ = false; +} + +void LRScheduler::ApplyLR(float lr) { optimizer_->set_learning_rate(lr); } + +float LRScheduler::GetChainedFormLR() const { return GetClosedFormLR(); } + +float LRScheduler::GetLR() const { return optimizer_->learning_rate(); } + +float LRScheduler::BaseLR() const { return base_lr_; } + +int64_t LRScheduler::LastStep() const { return last_step_; } + +bool LRScheduler::SharesOptimizerWith(const std::shared_ptr &opt) const { return optimizer_ == opt; } + +void LRScheduler::ResetStep(int64_t step) { last_step_ = step; } + +StateDict LRScheduler::State() const { + return { + {"last_step", last_step_}, + {"recover_lr", optimizer_->learning_rate()}, + {"base_lr", base_lr_}, + }; +} + +void LRScheduler::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + recover_lr_ = std::get(state.at("recover_lr")); + base_lr_ = std::get(state.at("base_lr")); + optimizer_->set_learning_rate(recover_lr_); +} + +// Concrete LR Schedulers + +namespace lr_schedulers { + +// --- ConstantLR --- + +ConstantLR::ConstantLR(std::shared_ptr optimizer, float factor, int total_iters, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), factor_(factor), total_iters_(total_iters) { + CHECK_GE(factor_, 0.0f) << "ConstantLR: factor must be >= 0."; + CHECK_LE(factor_, 1.0f) << "ConstantLR: factor must be <= 1."; +} + +float ConstantLR::GetClosedFormLR() const { return last_step_ < total_iters_ ? base_lr_ * factor_ : base_lr_; } + +float ConstantLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0) { + return lr * factor_; + } else if (last_step_ < total_iters_) { + return lr; + } else if (last_step_ == total_iters_) { + return lr / factor_; + } + return lr; +} + +// --- StepLR --- + +StepLR::StepLR(std::shared_ptr optimizer, int64_t step_size, float gamma, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) { + CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0."; + CHECK_GT(gamma_, 0.0f) << "StepLR: gamma must be > 0."; +} + +float StepLR::GetClosedFormLR() const { + return base_lr_ + * static_cast(std::pow(static_cast(gamma_), static_cast(last_step_ / step_size_))); +} + +float StepLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0 || (last_step_ % step_size_) != 0) { + return lr; + } + return lr * gamma_; +} + +LinearLR::LinearLR(std::shared_ptr optimizer, float start_factor, float end_factor, int64_t total_iters, + int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), start_factor_(start_factor), end_factor_(end_factor), + total_iters_(total_iters) { + CHECK_GT(start_factor_, 0.0f) << "LinearLR: start_factor must be > 0."; + CHECK_LE(start_factor_, 1.0f) << "LinearLR: start_factor must be <= 1."; + CHECK_GE(end_factor_, 0.0f) << "LinearLR: end_factor must be >= 0."; + CHECK_LE(end_factor_, 1.0f) << "LinearLR: end_factor must be <= 1."; + CHECK_GT(total_iters_, 0) << "LinearLR: total_iters must be > 0."; +} + +float LinearLR::GetClosedFormLR() const { + if (last_step_ >= total_iters_) { + return base_lr_ * end_factor_; + } + return base_lr_ + * (start_factor_ + + (end_factor_ - start_factor_) * static_cast(last_step_) / static_cast(total_iters_)); +} + +float LinearLR::GetChainedFormLR() const { + const float lr = optimizer_->learning_rate(); + if (last_step_ == 0) { + return lr * start_factor_; + } + if (last_step_ > total_iters_ || is_initial_) { + return lr; + } + if (last_step_ == total_iters_) { + const float prev_factor + = start_factor_ + + (end_factor_ - start_factor_) * static_cast(total_iters_ - 1) / static_cast(total_iters_); + return lr * (end_factor_ / prev_factor); + } + + const float numerator = end_factor_ - start_factor_; + const float denominator + = start_factor_ * static_cast(total_iters_) + static_cast(last_step_ - 1) * numerator; + return lr * (1.0f + numerator / denominator); +} + +LambdaLR::LambdaLR(std::shared_ptr optimizer, std::function lr_lambda, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), lr_lambda_(std::move(lr_lambda)) { + CHECK(lr_lambda_) << "LambdaLR: lr_lambda must not be null."; +} + +float LambdaLR::GetClosedFormLR() const { return base_lr_ * lr_lambda_(last_step_); } + +SequentialLR::SequentialLR(std::shared_ptr optimizer, std::vector> schedulers, + std::vector milestones, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)), + milestones_(std::move(milestones)) { + CHECK(!schedulers_.empty()) << "SequentialLR requires at least one scheduler."; + + for (size_t i = 0; i < schedulers_.size(); ++i) { + CHECK(schedulers_[i]) << "SequentialLR: scheduler at index " << i << " must not be null."; + CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_)) + << "SequentialLR: scheduler at index " << i << " must share the same optimizer."; + } + + CHECK_EQ(milestones_.size(), schedulers_.size() - 1) + << "SequentialLR: milestones count must be schedulers count - 1."; + + for (size_t i = 1; i < milestones_.size(); ++i) { + CHECK_GT(milestones_[i], milestones_[i - 1]) << "Milestones must be strictly increasing."; + } +} + +void SequentialLR::InitialStep() { + + optimizer_->set_learning_rate(schedulers_[0]->BaseLR()); + + UndoChildInitialSteps(); + + ++last_step_; + schedulers_[0]->InitialStep(); +} + +void SequentialLR::UndoChildInitialSteps() { + for (auto &sched : schedulers_) { + if (auto nested = std::dynamic_pointer_cast(sched)) { + nested->UndoChildInitialSteps(); + } + sched->ResetStep(sched->LastStep() - 1); + } +} + +void SequentialLR::Step() { + ++last_step_; + size_t idx = std::upper_bound(milestones_.begin(), milestones_.end(), last_step_) - milestones_.begin(); + + auto &scheduler = schedulers_[idx]; + + if (idx > 0 && milestones_[idx - 1] == last_step_) { + scheduler->Step(0); + } else { + scheduler->Step(); + } +} + +StateDict SequentialLR::State() const { + StateDict state; + state["last_step"] = last_step_; + state["recover_lr"] = optimizer_->learning_rate(); + state["base_lr"] = base_lr_; + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void SequentialLR::LoadState(const StateDict &state) { + last_step_ = std::get(state.at("last_step")); + recover_lr_ = std::get(state.at("recover_lr")); + base_lr_ = std::get(state.at("base_lr")); + + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } + optimizer_->set_learning_rate(recover_lr_); +} + +ChainedScheduler::ChainedScheduler(std::shared_ptr optimizer, + std::vector> schedulers, int64_t last_step) + : LRScheduler(std::move(optimizer), last_step), schedulers_(std::move(schedulers)) { + CHECK(!schedulers_.empty()) << "ChainedScheduler requires at least one scheduler."; + + for (size_t i = 0; i < schedulers_.size(); ++i) { + CHECK(schedulers_[i]) << "ChainedScheduler: scheduler at index " << i << " must not be null."; + CHECK(schedulers_[i]->SharesOptimizerWith(optimizer_)) + << "ChainedScheduler: scheduler at index " << i << " must share the same optimizer."; + } +} + +void ChainedScheduler::InitialStep() { last_step_ = 0; } + +void ChainedScheduler::Step() { + ++last_step_; + for (auto &sched : schedulers_) { sched->Step(); } +} + +StateDict ChainedScheduler::State() const { + StateDict state = LRScheduler::State(); + for (size_t i = 0; i < schedulers_.size(); ++i) { + auto sub_state = schedulers_[i]->State(); + for (const auto &[key, value] : sub_state) { state["scheduler_" + std::to_string(i) + "." + key] = value; } + } + return state; +} + +void ChainedScheduler::LoadState(const StateDict &state) { + LRScheduler::LoadState(state); + for (size_t i = 0; i < schedulers_.size(); ++i) { + StateDict sub_state; + std::string prefix = "scheduler_" + std::to_string(i) + "."; + for (const auto &[key, value] : state) { + if (key.substr(0, prefix.size()) == prefix) { + sub_state[key.substr(prefix.size())] = value; + } + } + if (!sub_state.empty()) { + schedulers_[i]->LoadState(sub_state); + } + } +} + +} // namespace lr_schedulers +} // namespace infini_train diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc index 55e5800b..2531ca60 100644 --- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc +++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc @@ -114,6 +114,20 @@ void DistributedOptimizer::ZeroGrad(bool set_to_none) { } } +void DistributedOptimizer::set_learning_rate(float lr) { + Optimizer::set_learning_rate(lr); + if (base_optimizer_) { + base_optimizer_->set_learning_rate(lr); + } +} + +float DistributedOptimizer::learning_rate() const { + if (base_optimizer_) { + return base_optimizer_->learning_rate(); + } + return Optimizer::learning_rate(); +} + void DistributedOptimizer::Step() { // 1. Ensure grads are synced FinishGradSync(); diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 2c9b218a..30c8caeb 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -8,16 +8,32 @@ #include "infini_train/include/tensor.h" namespace infini_train { -Optimizer::Optimizer(const std::vector> ¶ms) : params_(params) {} +Optimizer::Optimizer(const std::vector> ¶ms, float learning_rate) + : params_(params), learning_rate_(learning_rate) {} void Optimizer::ZeroGrad(bool set_to_none) { for (auto param : params_) { param->ZeroGrad(set_to_none); } } +void Optimizer::set_learning_rate(float lr) { learning_rate_ = lr; } + +float Optimizer::learning_rate() const { return learning_rate_; } + +float Optimizer::initial_learning_rate() const { + CHECK(initial_lr_set_) << "Optimizer: initial_learning_rate not set. " + "Use with an LRScheduler first."; + return initial_learning_rate_; +} + +void Optimizer::set_initial_learning_rate(float lr) { + if (!initial_lr_set_) { + initial_learning_rate_ = lr; + initial_lr_set_ = true; + } +} namespace optimizers { -SGD::SGD(const std::vector> ¶ms, float learning_rate) - : Optimizer(params), learning_rate_(learning_rate) {} +SGD::SGD(const std::vector> ¶ms, float learning_rate) : Optimizer(params, learning_rate) {} void SGD::Step() { for (auto param : params_) { @@ -33,7 +49,7 @@ void SGD::Step() { } Adam::Adam(const std::vector> ¶ms, float learning_rate, float beta1, float beta2, float eps) - : Optimizer(params), t_(0), learning_rate_(learning_rate), beta1_(beta1), beta2_(beta2), eps_(eps) { + : Optimizer(params, learning_rate), t_(0), beta1_(beta1), beta2_(beta2), eps_(eps) { for (const auto ¶m : params_) { m_.emplace_back(std::make_shared(param->Dims(), param->Dtype(), param->GetDevice())); diff --git a/scripts/test_config.json b/scripts/test_config.json index 54332f70..81282289 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -304,6 +304,182 @@ } ] }, + { + "tag": "lr_scheduler", + "tests": [ + { + "id": "3_none_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "lr_decay_style": "none" + } + }, + { + "id": "4_constant_tp4", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "constant", + "lr_warmup_iters": 0, + "lr_decay_iters": 0 + } + }, + { + "id": "5_linear_tp4_sp_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "linear", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "6_cosine_pp8", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 8, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "cosine", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "7_inverse_sqrt_pp4_vpp2", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "inverse-square-root", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "8_cosine_all_parallel_distopt", + "args": { + "dtype": "float32", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "cosine", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "3_bfloat16_linear", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "linear", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 0 + } + }, + { + "id": "4_bfloat16_inverse_sqrt_tp4_distopt", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "use_distributed_optimizer": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "inverse-square-root", + "lr_warmup_iters": 2, + "lr_warmup_init": 0.0, + "lr_decay_iters": 10 + } + }, + { + "id": "5_bfloat16_constant_tp4_sp", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 4, + "sequence_parallel": true, + "learning_rate": 0.00001, + "min_lr": 0.000001, + "lr_decay_style": "constant", + "lr_warmup_iters": 0, + "lr_decay_iters": 10 + } + }, + { + "id": "8_bfloat16_none_all_parallel", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 8, + "num_iteration": 10, + "batch_size": 40, + "total_batch_size": 5120, + "tensor_parallel": 2, + "sequence_parallel": true, + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2, + "learning_rate": 0.00001, + "lr_decay_style": "none" + } + } + ] + }, { "tag": "lora", "tests": [ diff --git a/test/lr_scheduler/test_chained_lr.cc b/test/lr_scheduler/test_chained_lr.cc new file mode 100644 index 00000000..b22d0eea --- /dev/null +++ b/test/lr_scheduler/test_chained_lr.cc @@ -0,0 +1,166 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestSingleScheduler() { + std::cout << "[TC1] TestSingleScheduler" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + auto sched + = LRScheduler::Create(opt, /*schedulers=*/std::vector>{step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + sched->Step(); // step=1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); +} + +// TC2: StepLR + LambdaLR +void TestMultiplicativeChain() { + std::cout << "[TC2] TestMultiplicativeChain" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr = LRScheduler::Create(opt, /*lr_lambda=*/[](int64_t step) { return 1.0f - 0.1f * step; }); + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{step_lr, lambda_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); +} + +// TC3: ConstantLR + StepLR +void TestConstantPlusStep() { + std::cout << "[TC3] TestConstantPlusStep" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{constant, step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.01f, kEps); +} + +// TC4: ConstantLR + StepLR (with extra unused scheduler) +void TestConstantPlusStepDLC() { + std::cout << "[TC4] TestConstantPlusStepDLC" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + auto Lambda = LRScheduler::Create(opt, /*lr_lambda=*/[](int64_t step) { return 1.0f - 0.1f * step; }); + + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{constant, step_lr}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps); + + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.02f, kEps); +} + +// TC5: State/LoadState +void TestStateRoundTrip() { + std::cout << "[TC5] TestStateRoundTrip" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto step_lr = std::make_shared(opt, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr = std::make_shared(opt, /*lr_lambda=*/[](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{step_lr, lambda_lr}); + + for (int i = 0; i < 5; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto step_lr2 = std::make_shared(opt2, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr2 = std::make_shared(opt2, /*lr_lambda=*/[](int64_t step) { return 1.0f - 0.05f * step; }); + auto sched2 = LRScheduler::Create( + opt2, /*schedulers=*/std::vector>{step_lr2, lambda_lr2}); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +// TC6: resume consistency (load state at step K, then step N-K, should match directly stepping to N) +void TestResumeConsistency() { + std::cout << "[TC6] TestResumeConsistency" << std::endl; + constexpr int kN = 10, kK = 4; + auto lambda_fn = [](int64_t step) { return 1.0f - 0.05f * step; }; + + auto make_sched = [&](std::shared_ptr opt) { + auto step_lr = std::make_shared(opt, /*step_size=*/2, /*gamma=*/0.5f); + auto lambda_lr = std::make_shared(opt, /*lr_lambda=*/lambda_fn); + return LRScheduler::Create( + opt, /*schedulers=*/std::vector>{step_lr, lambda_lr}); + }; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = make_sched(opt_ref); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = make_sched(opt_a); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = make_sched(opt_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== ChainedScheduler Tests ===" << std::endl; + TestSingleScheduler(); + TestMultiplicativeChain(); + TestConstantPlusStep(); + TestConstantPlusStepDLC(); + TestStateRoundTrip(); + TestResumeConsistency(); + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_constant_lr.cc b/test/lr_scheduler/test_constant_lr.cc new file mode 100644 index 00000000..5ac9dea6 --- /dev/null +++ b/test/lr_scheduler/test_constant_lr.cc @@ -0,0 +1,123 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestInitialState() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_EQ(opt->learning_rate(), 0.05f); +} + +void TestFirstStepAppliesFactor() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + sched->Step(); // last_step_ = 0 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_FLOAT_EQ(opt->learning_rate(), 0.05f); + ASSERT_TRUE(sched->LastStep() == 1); +} + +void TestWithinTotalIters() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + for (int i = 0; i < 2; ++i) { sched->Step(); } + // last_step_ = 2, still < 3 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); +} + +void TestBeyondTotalIters() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + for (int i = 0; i < 10; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/3); + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), expected[i]); + } +} + +void TestStateRoundTrip() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/5); + for (int i = 0; i < 3; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto sched2 = LRScheduler::Create(opt2, /*factor=*/0.5f, /*total_iters=*/5); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR()); + ASSERT_FLOAT_EQ(opt2->learning_rate(), sched->GetLR()); +} + +void TestResumeConsistency() { + constexpr int kN = 8; + constexpr int kK = 3; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*factor=*/0.5f, /*total_iters=*/5); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*factor=*/0.5f, /*total_iters=*/5); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*factor=*/0.5f, /*total_iters=*/5); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_EQ(sched_b->GetLR(), sched_ref->GetLR()); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*factor=*/0.5f, /*total_iters=*/5); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*factor=*/0.5f, /*total_iters=*/5); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== ConstantLR Tests ===" << std::endl; + TestInitialState(); + TestFirstStepAppliesFactor(); + TestWithinTotalIters(); + TestBeyondTotalIters(); + TestPyTorchAlignment(); + TestStateRoundTrip(); + TestResumeConsistency(); + TestChainableAndClosedFormConsistency(); + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_helpers.h b/test/lr_scheduler/test_helpers.h new file mode 100644 index 00000000..f3fb1d23 --- /dev/null +++ b/test/lr_scheduler/test_helpers.h @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +namespace { + +constexpr float kEps = 1e-6f; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +bool FloatNear(float a, float b, float eps = kEps) { return std::fabs(a - b) < eps; } + +int g_fail_count = 0; + +void Check(bool cond, const char *expr, int line) { + if (!cond) { + std::cerr << "FAIL [line " << line << "]: " << expr << std::endl; + ++g_fail_count; + } +} + +#define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) Check(FloatNear((a), (b)), #a " == " #b, __LINE__) +#define ASSERT_FLOAT_NEAR(a, b, eps) Check(FloatNear((a), (b), (eps)), #a " ≈ " #b, __LINE__) + +} // namespace diff --git a/test/lr_scheduler/test_lambda_lr.cc b/test/lr_scheduler/test_lambda_lr.cc new file mode 100644 index 00000000..fddb7a1b --- /dev/null +++ b/test/lr_scheduler/test_lambda_lr.cc @@ -0,0 +1,103 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestIdentityLambda() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*lr_lambda=*/[](int64_t) { return 1.0f; }); + // Step() → last_step_=0, lr = 0.1 * 1.0 = 0.1 + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); + ASSERT_FLOAT_NEAR(opt->learning_rate(), kBaseLR, kEps); +} + +void TestLinearDecayLambda() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*lr_lambda=*/[](int64_t step) { return 1.0f - step * 0.1f; }); + // step=0, lambda(0)=1.0, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); // step=1, lambda(1)=0.9, lr=0.09 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.09f, kEps); + + sched->Step(); // step=2, lambda(2)=0.8, lr=0.08 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.08f, kEps); + + sched->Step(); // step=3, lambda(3)=0.7, lr=0.07 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.07f, kEps); +} + +void TestPyTorchAlignment() { + // PyTorch: LambdaLR(opt, lr_lambda=lambda epoch: 0.95**epoch) + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create( + opt, /*lr_lambda=*/[](int64_t step) { return static_cast(std::pow(0.95, step)); }); + // step=0, lr = 0.1 * 0.95^0 = 0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + std::vector expected = {0.095f, 0.09025f, 0.0857375f}; + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); + } +} + +void TestStateRoundTrip() { + auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*lr_lambda=*/lambda_fn); + for (int i = 0; i < 5; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto sched2 = LRScheduler::Create(opt2, /*lr_lambda=*/lambda_fn); // same lambda + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +void TestResumeConsistency() { + auto lambda_fn = [](int64_t step) { return 1.0f - step * 0.05f; }; + constexpr int kN = 10, kK = 4; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*lr_lambda=*/lambda_fn); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*lr_lambda=*/lambda_fn); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*lr_lambda=*/lambda_fn); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== LambdaLR Tests ===" << std::endl; + TestIdentityLambda(); + TestLinearDecayLambda(); + TestPyTorchAlignment(); + TestStateRoundTrip(); + TestResumeConsistency(); + std::cout << "======================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_linear_lr.cc b/test/lr_scheduler/test_linear_lr.cc new file mode 100644 index 00000000..cf3a5e74 --- /dev/null +++ b/test/lr_scheduler/test_linear_lr.cc @@ -0,0 +1,94 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestFirstStepFromZero() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.02f); +} + +void TestMidpointLR() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); + for (int i = 0; i < 3; ++i) { sched->Step(); } + // last_step_=3 -> 0.1*(0.2 + 0.8*3/5) = 0.068 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.068f); +} + +void TestWarmupEnd() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); + for (int i = 0; i < 5; ++i) { sched->Step(); } + // last_step_ >= total_iters -> base_lr * end_factor + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); +} + +void TestBeyondWarmup() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); + for (int i = 0; i < 20; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); +} + +void TestCustomStartFactor() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.25f, /*end_factor=*/1.0f, + /*total_iters=*/4); + sched->Step(); // last_step_=1, lr=0.1*(0.25+0.75*1/4)=0.04375 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.04375f, 1e-6f); + sched->Step(); // last_step_=2, lr=0.1*(0.25+0.75*2/4)=0.0625 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0625f, 1e-6f); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.036f, 0.052f, 0.068f, 0.084f, 0.1f, 0.1f, 0.1f}; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.2f, /*end_factor=*/1.0f, /*total_iters=*/5); + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); + } +} + +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*start_factor=*/0.2f, /*end_factor=*/1.0f, + /*total_iters=*/5); + + for (int epoch = 1; epoch <= 10; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-6f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== Linear Tests ===" << std::endl; + TestFirstStepFromZero(); + TestMidpointLR(); + TestWarmupEnd(); + TestBeyondWarmup(); + TestCustomStartFactor(); + TestPyTorchAlignment(); + TestChainableAndClosedFormConsistency(); + + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_lr_scheduler.cc b/test/lr_scheduler/test_lr_scheduler.cc new file mode 100644 index 00000000..3c3e8e03 --- /dev/null +++ b/test/lr_scheduler/test_lr_scheduler.cc @@ -0,0 +1,178 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/lr_scheduler.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; + +namespace { + +constexpr float kBaseLR = 0.1f; +constexpr float kEps = 1e-7f; + +class IdentityScheduler : public LRScheduler { +public: + IdentityScheduler(std::shared_ptr optimizer, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step) {} + ~IdentityScheduler() override = default; + +protected: + float GetClosedFormLR() const override { return base_lr_; } +}; + +class LinearDecayScheduler : public LRScheduler { +public: + LinearDecayScheduler(std::shared_ptr optimizer, int64_t total_steps, int64_t last_step = -1) + : LRScheduler(std::move(optimizer), last_step), total_steps_(total_steps) {} + +protected: + float GetClosedFormLR() const override { + if (last_step_ >= total_steps_) { + return 0.0f; + } + return base_lr_ * (1.0f - static_cast(last_step_) / static_cast(total_steps_)); + } + +private: + int64_t total_steps_; +}; + +std::shared_ptr MakeDummyOptimizer(float lr) { + std::vector> empty_params; + return std::make_shared(empty_params, lr); +} + +bool FloatEq(float a, float b) { return std::fabs(a - b) < kEps; } + +int g_fail_count = 0; + +void Check(bool cond, const char *expr, int line) { + if (!cond) { + std::cerr << "FAIL [line " << line << "]: " << expr << std::endl; + ++g_fail_count; + } +} + +#define ASSERT_TRUE(cond) Check((cond), #cond, __LINE__) +#define ASSERT_FLOAT_EQ(a, b) Check(FloatEq((a), (b)), #a " == " #b, __LINE__) + +// T1: Init +void TestInitialState() { + std::cout << "[T1] TestInitialState" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt); + + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_TRUE(sched->LastStep() == 0); + ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR); +} + +// T2: SingleStep +void TestSingleStep() { + std::cout << "[T2] TestSingleStep" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt); + + sched->Step(); + + ASSERT_TRUE(sched->LastStep() == 1); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR); +} + +// T3: ComputeLR +void TestLinearDecay() { + std::cout << "[T3] TestLinearDecay" << std::endl; + constexpr int64_t kTotalSteps = 10; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*total_steps=*/kTotalSteps); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + ASSERT_FLOAT_EQ(opt->learning_rate(), kBaseLR); + + sched->Step(); // last_step = 1 -> 0.09 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.09f); + ASSERT_FLOAT_EQ(opt->learning_rate(), 0.09f); + + for (int i = 0; i < 4; ++i) { sched->Step(); } // last_step = 5 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.05f); + ASSERT_FLOAT_EQ(opt->learning_rate(), 0.05f); +} + +// T4: State → LoadState +void TestStateRoundTrip() { + std::cout << "[T4] TestStateRoundTrip" << std::endl; + constexpr int64_t kTotalSteps = 20; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*total_steps=*/kTotalSteps); + + for (int i = 0; i < 7; ++i) { sched->Step(); } + + StateDict saved = sched->State(); + + ASSERT_TRUE(saved.count("last_step") == 1); + ASSERT_TRUE(saved.count("recover_lr") == 1); + ASSERT_TRUE(saved.count("base_lr") == 1); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto sched2 = LRScheduler::Create(opt2, /*total_steps=*/kTotalSteps); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == 7); + ASSERT_FLOAT_EQ(sched2->GetLR(), sched->GetLR()); + ASSERT_FLOAT_EQ(opt2->learning_rate(), sched->GetLR()); +} + +// T5: resume Step +void TestResumeAndContinue() { + std::cout << "[T5] TestResumeAndContinue" << std::endl; + constexpr int64_t kTotalSteps = 20; + + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = LRScheduler::Create(opt_ref, /*total_steps=*/kTotalSteps); + for (int i = 0; i < 10; ++i) { sched_ref->Step(); } + float lr_at_10 = sched_ref->GetLR(); + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = LRScheduler::Create(opt_a, /*total_steps=*/kTotalSteps); + for (int i = 0; i < 5; ++i) { sched_a->Step(); } + StateDict checkpoint = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = LRScheduler::Create(opt_b, /*total_steps=*/kTotalSteps); + sched_b->LoadState(checkpoint); + for (int i = 0; i < 5; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_EQ(sched_b->GetLR(), lr_at_10); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +} // namespace + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + std::cout << "========================================" << std::endl; + std::cout << " LRScheduler Base Class Tests" << std::endl; + std::cout << "========================================" << std::endl; + + TestInitialState(); + TestSingleStep(); + TestLinearDecay(); + TestStateRoundTrip(); + TestResumeAndContinue(); + + std::cout << "========================================" << std::endl; + if (g_fail_count == 0) { + std::cout << " All Tests PASSED" << std::endl; + } else { + std::cout << " " << g_fail_count << " test(s) FAILED" << std::endl; + } + std::cout << "========================================" << std::endl; + + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_lr_scheduler_validation.cc b/test/lr_scheduler/test_lr_scheduler_validation.cc new file mode 100644 index 00000000..5f7b99a8 --- /dev/null +++ b/test/lr_scheduler/test_lr_scheduler_validation.cc @@ -0,0 +1,140 @@ +#include +#include +#include +#include +#include +#include + +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { + +bool ExpectDeath(const std::function &fn) { + pid_t pid = fork(); + if (pid == -1) { + return false; + } + + if (pid == 0) { + fn(); + _exit(0); + } + + int status = 0; + if (waitpid(pid, &status, 0) == -1) { + return false; + } + + return !WIFEXITED(status) || WEXITSTATUS(status) != 0; +} + +void TestStepLRRejectsNonPositiveStepSize() { + ASSERT_TRUE(ExpectDeath([] { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*step_size=*/0, /*gamma=*/0.1f); + (void)sched; + })); +} + +void TestLinearLRRejectsNonPositiveTotalIters() { + ASSERT_TRUE(ExpectDeath([] { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*start_factor=*/0.5f, /*end_factor=*/1.0f, + /*total_iters=*/0); + (void)sched; + })); +} + +void TestLambdaLRRejectsNullLambda() { + ASSERT_TRUE(ExpectDeath([] { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create(opt, /*lr_lambda=*/LambdaLR::LambdaFunc{}); + (void)sched; + })); +} + +void TestSequentialLRRejectsMismatchedOptimizer() { + ASSERT_TRUE(ExpectDeath([] { + auto opt1 = MakeDummyOptimizer(0.1f); + auto opt2 = MakeDummyOptimizer(0.1f); + + auto s1 = LRScheduler::Create(opt1, /*start_factor=*/0.5f, /*end_factor=*/1.0f, + /*total_iters=*/2); + auto s2 = LRScheduler::Create(opt2, /*step_size=*/2, /*gamma=*/0.5f); + + auto sched + = LRScheduler::Create(opt1, /*schedulers=*/std::vector>{s1, s2}, + /*milestones=*/std::vector{1}); + (void)sched; + })); +} + +void TestSequentialLRRejectsNullChild() { + ASSERT_TRUE(ExpectDeath([] { + auto opt = MakeDummyOptimizer(0.1f); + auto sched + = LRScheduler::Create(opt, /*schedulers=*/std::vector>{nullptr}, + /*milestones=*/std::vector{}); + (void)sched; + })); +} + +void TestChainedSchedulerRejectsEmptyChildren() { + ASSERT_TRUE(ExpectDeath([] { + auto opt = MakeDummyOptimizer(0.1f); + auto sched + = LRScheduler::Create(opt, /*schedulers=*/std::vector>{}); + (void)sched; + })); +} + +void TestChainedSchedulerRejectsMismatchedOptimizer() { + ASSERT_TRUE(ExpectDeath([] { + auto opt1 = MakeDummyOptimizer(0.1f); + auto opt2 = MakeDummyOptimizer(0.1f); + + auto s1 = LRScheduler::Create(opt1, /*step_size=*/2, /*gamma=*/0.5f); + auto s2 = LRScheduler::Create(opt2, /*factor=*/0.5f, /*total_iters=*/2); + + auto sched = LRScheduler::Create( + opt1, /*schedulers=*/std::vector>{s1, s2}); + (void)sched; + })); +} + +void TestChainedSchedulerRejectsNullChild() { + ASSERT_TRUE(ExpectDeath([] { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{nullptr}); + (void)sched; + })); +} + +} // namespace + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + std::cout << "=== LR Scheduler Validation Tests ===" << std::endl; + TestStepLRRejectsNonPositiveStepSize(); + TestLinearLRRejectsNonPositiveTotalIters(); + TestLambdaLRRejectsNullLambda(); + TestSequentialLRRejectsMismatchedOptimizer(); + TestSequentialLRRejectsNullChild(); + TestChainedSchedulerRejectsEmptyChildren(); + TestChainedSchedulerRejectsMismatchedOptimizer(); + TestChainedSchedulerRejectsNullChild(); + + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_sequential_lr.cc b/test/lr_scheduler/test_sequential_lr.cc new file mode 100644 index 00000000..08d58dbb --- /dev/null +++ b/test/lr_scheduler/test_sequential_lr.cc @@ -0,0 +1,148 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; +namespace { +constexpr float kBaseLR = 0.1f; +} // namespace + +void TestLinearThenConstant() { + std::cout << "[TC1] TestLinearThenConstant" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto constant = LRScheduler::Create(opt, /*factor=*/1.0f, /*total_iters=*/100); + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{linear, constant}, + /*milestones=*/std::vector{3}); + + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.0f, kEps); + + sched->Step(); // global=1, warmup step=1, lr=0.1*(1/3) + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f / 3.0f, 1e-5f); + + sched->Step(); // global=2, warmup step=2, lr=0.1*(2/3) + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.2f / 3.0f, 1e-5f); + + sched->Step(); // global=3, constant step=0, lr=0.1*1.0=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); + + sched->Step(); // global=4, constant step=1, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), kBaseLR, kEps); +} + +void TestLinearThenStepLR() { + std::cout << "[TC2] TestLinearThenStepLR" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{3}); + + sched->Step(); // global=1 + sched->Step(); // global=2 + + sched->Step(); // global=3, StepLR step=0, lr=0.1 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.1f, kEps); + + sched->Step(); // global=4, StepLR step=1 + sched->Step(); // global=5, StepLR step=2 + sched->Step(); // global=6, StepLR step=3, 3//3=1, lr=0.1*0.5=0.05 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.05f, kEps); +} + +void TestLinearThenStepThenConstant() { + std::cout << "[TC3] TestLinearThenStepThenConstant" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + auto constant = LRScheduler::Create(opt, /*factor=*/0.5f, /*total_iters=*/2); + + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{linear, step_lr, constant}, + /*milestones=*/std::vector{3, 6}); + const std::vector expected = {0.033333f, 0.066667f, 0.1f, 0.1f, 0.1f, 0.05f, 0.05f, 0.1f, 0.1f, 0.1f}; + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-5f); + } +} + +void TestStateRoundTrip() { + std::cout << "[TC4] TestStateRoundTrip" << std::endl; + auto opt = MakeDummyOptimizer(kBaseLR); + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + auto sched = LRScheduler::Create( + opt, /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{3}); + for (int i = 0; i < 5; ++i) { sched->Step(); } + StateDict saved = sched->State(); + + auto opt2 = MakeDummyOptimizer(kBaseLR); + auto linear2 = LRScheduler::Create(opt2, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr2 = LRScheduler::Create(opt2, /*step_size=*/3, /*gamma=*/0.5f); + auto sched2 = LRScheduler::Create( + opt2, /*schedulers=*/std::vector>{linear2, step_lr2}, + /*milestones=*/std::vector{3}); + sched2->LoadState(saved); + + ASSERT_TRUE(sched2->LastStep() == sched->LastStep()); + ASSERT_FLOAT_NEAR(sched2->GetLR(), sched->GetLR(), kEps); +} + +void TestResumeConsistency() { + std::cout << "[TC5] TestResumeConsistency" << std::endl; + constexpr int kN = 10, kK = 4; + + auto make_sched = [](std::shared_ptr opt) { + auto linear = LRScheduler::Create(opt, /*start_factor=*/1e-8f, /*end_factor=*/1.0f, + /*total_iters=*/3); + auto step_lr = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.5f); + return LRScheduler::Create( + opt, /*schedulers=*/std::vector>{linear, step_lr}, + /*milestones=*/std::vector{3}); + }; + auto opt_ref = MakeDummyOptimizer(kBaseLR); + auto sched_ref = make_sched(opt_ref); + for (int i = 0; i < kN; ++i) { sched_ref->Step(); } + + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto sched_a = make_sched(opt_a); + for (int i = 0; i < kK; ++i) { sched_a->Step(); } + StateDict ckpt = sched_a->State(); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto sched_b = make_sched(opt_b); + sched_b->LoadState(ckpt); + for (int i = 0; i < kN - kK; ++i) { sched_b->Step(); } + + ASSERT_FLOAT_NEAR(sched_b->GetLR(), sched_ref->GetLR(), kEps); + ASSERT_TRUE(sched_b->LastStep() == sched_ref->LastStep()); +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== SequentialLR Tests ===" << std::endl; + TestLinearThenConstant(); + TestLinearThenStepLR(); + TestLinearThenStepThenConstant(); + TestStateRoundTrip(); + TestResumeConsistency(); + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_step_lr.cc b/test/lr_scheduler/test_step_lr.cc new file mode 100644 index 00000000..da0a4c22 --- /dev/null +++ b/test/lr_scheduler/test_step_lr.cc @@ -0,0 +1,86 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; +using namespace infini_train::lr_schedulers; + +namespace { +constexpr float kBaseLR = 0.1f; +} + +void TestWithinFirstPeriod() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + for (int i = 0; i < 2; ++i) { + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); // last_step 1,2 → 0.1^0=1 → lr=0.1 + } +} + +void TestFirstDecay() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + for (int i = 0; i < 3; ++i) { sched->Step(); } + // last_step=3, 3//3=1 → 0.1^1 = 0.1 → lr=0.01 + ASSERT_FLOAT_EQ(sched->GetLR(), 0.01f); +} + +void TestMultipleDecays() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + for (int i = 0; i < 6; ++i) { sched->Step(); } + // last_step=6, 6//3=2 → 0.1^2 = 0.01 → lr=0.001 + ASSERT_FLOAT_NEAR(sched->GetLR(), 0.001f, 1e-7f); +} + +void TestPyTorchAlignment() { + const std::vector expected = {0.1f, 0.1f, 0.01f, 0.01f, 0.01f, 0.001f, 0.001f}; + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/0.1f); + for (size_t i = 0; i < expected.size(); ++i) { + sched->Step(); + ASSERT_FLOAT_NEAR(sched->GetLR(), expected[i], 1e-7f); + } +} + +void TestGammaOne() { + auto opt = MakeDummyOptimizer(kBaseLR); + auto sched = LRScheduler::Create(opt, /*step_size=*/3, /*gamma=*/1.0f); + for (int i = 0; i < 20; ++i) { + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), kBaseLR); + } +} + +void TestChainableAndClosedFormConsistency() { + auto opt_a = MakeDummyOptimizer(kBaseLR); + auto chainable = LRScheduler::Create(opt_a, /*step_size=*/3, /*gamma=*/0.1f); + + auto opt_b = MakeDummyOptimizer(kBaseLR); + auto closed_form = LRScheduler::Create(opt_b, /*step_size=*/3, /*gamma=*/0.1f); + + for (int epoch = 1; epoch <= 12; ++epoch) { + chainable->Step(); + closed_form->Step(epoch); + ASSERT_FLOAT_NEAR(chainable->GetLR(), closed_form->GetLR(), 1e-7f); + } +} + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + std::cout << "=== Step Tests ===" << std::endl; + TestWithinFirstPeriod(); + TestFirstDecay(); + TestMultipleDecays(); + TestPyTorchAlignment(); + TestGammaOne(); + TestChainableAndClosedFormConsistency(); + + std::cout << "========================" << std::endl; + if (g_fail_count == 0) { + std::cout << "All Tests PASSED" << std::endl; + } else { + std::cout << g_fail_count << " test(s) FAILED" << std::endl; + } + return g_fail_count > 0 ? 1 : 0; +} diff --git a/test/lr_scheduler/test_training_lr_scheduler.cc b/test/lr_scheduler/test_training_lr_scheduler.cc new file mode 100644 index 00000000..f3ce4dfb --- /dev/null +++ b/test/lr_scheduler/test_training_lr_scheduler.cc @@ -0,0 +1,97 @@ +#include "infini_train/include/lr_scheduler.h" +#include "test/lr_scheduler/test_helpers.h" + +using namespace infini_train; + +namespace { + +void TestConstantLR() { + auto opt = MakeDummyOptimizer(0.1f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "constant", + .lr = 0.1f, + .min_lr = 0.0f, + .lr_decay_iters = 10, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + + ASSERT_FLOAT_EQ(sched->GetLR(), 0.1f); + for (int i = 0; i < 5; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), 0.1f); +} + +void TestLinearWarmupAndDecay() { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "linear", + .lr = 1.0f, + .min_lr = 0.1f, + .lr_decay_iters = 6, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }); + + ASSERT_FLOAT_EQ(sched->GetLR(), 0.0f); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.5f); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), 1.0f); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.775f); + for (int i = 0; i < 3; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), 0.1f); +} + +void TestCosineDecay() { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "cosine", + .lr = 1.0f, + .min_lr = 0.0f, + .lr_decay_iters = 4, + .lr_warmup_iters = 0, + .lr_warmup_init = 0.0f, + }); + + ASSERT_FLOAT_EQ(sched->GetLR(), 1.0f); + sched->Step(); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.5f); + sched->Step(); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), 0.0f); +} + +void TestInverseSquareRootDecay() { + auto opt = MakeDummyOptimizer(1.0f); + auto sched = CreateLRScheduler(opt, { + .lr_decay_style = "inverse-square-root", + .lr = 1.0f, + .min_lr = 0.1f, + .lr_decay_iters = 10, + .lr_warmup_iters = 2, + .lr_warmup_init = 0.0f, + }); + + sched->Step(); + sched->Step(); + ASSERT_FLOAT_EQ(sched->GetLR(), 1.0f); + for (int i = 0; i < 6; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), 0.5f); + for (int i = 0; i < 92; ++i) { sched->Step(); } + ASSERT_FLOAT_EQ(sched->GetLR(), 0.1f); +} + +} // namespace + +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + TestConstantLR(); + TestLinearWarmupAndDecay(); + TestCosineDecay(); + TestInverseSquareRootDecay(); + + return g_fail_count > 0 ? 1 : 0; +}