
fix: restore actor weights after loading OPD teacher checkpoint#1903

Open
canlin03 wants to merge 1 commit into THUDM:main from canlin03:fix/opd-teacher-model-state


@canlin03 canlin03 commented May 12, 2026

Problem

When using Megatron-based OPD (--opd-type megatron) without --offload-train, load_other_checkpoint("teacher") switches self.model to the teacher weights. The generic model-state recovery path only runs when offload_train is enabled, so it is skipped here — leaving self.model in teacher state at the end of __init__.

As a result, step-0 rollouts and evaluations are silently run with the teacher model instead of the student (actor).
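For reference, here is a condensed sketch of the init-time flow described above. load_other_checkpoint and _switch_model are the methods named in this PR; the surrounding structure is illustrative only, and the offload-path restore is shown as a single _switch_model call for brevity, so it may not match the actual slime source line for line.

# Illustrative sketch of the __init__ tail described above, not the actual slime code.
if with_opd_teacher:
    # repoints self.model at the teacher weights
    self.load_other_checkpoint("teacher", args.opd_teacher_load)

if self.args.offload_train:
    # generic recovery path: the actor weights come back as part of offloading
    self._switch_model("actor")
# else: nothing restores the actor, so self.model still holds the teacher
# weights when __init__ returns, and update_weights() then pushes the teacher
# to sglang for step-0 rollouts and evals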

Reproduction

Full training script
#!/bin/bash

source "/ossfs/workspace/slime/slime/scripts/models/qwen3-4B.sh"

CKPT_ARGS=(
   --hf-checkpoint /ossfs/workspace/Qwen3-4B
   --ref-load /ossfs/workspace/Qwen3-4B_torch_dist
   --load /root/Qwen3-4B_slime_0510_new_data/
   --save /root/Qwen3-4B_slime_0510_new_data/
   --save-interval 20
)

ROLLOUT_ARGS=(
   --prompt-data DeepMath-103K/train_filtered_level6.jsonl
   --input-key prompt
   --apply-chat-template
   --apply-chat-template-kwargs "{\"enable_thinking\":false}"
   --rollout-shuffle
   --num-rollout 40
   --rm-type math
   --label-key label
   --rollout-batch-size 512
   --n-samples-per-prompt 1
   --rollout-max-response-len 16384
   --rollout-temperature 1
   --global-batch-size 512
   --balance-data
)

EVAL_ARGS=(
   --eval-interval 5
   --eval-prompt-data aime-2024/aime-2024.jsonl
   --n-samples-per-eval-prompt 2
   --eval-max-response-len 16384
   --eval-top-p 1
)

PERF_ARGS=(
   --tensor-model-parallel-size 4
   --sequence-parallel
   --pipeline-model-parallel-size 1
   --context-parallel-size 1
   --expert-model-parallel-size 1
   --expert-tensor-parallel-size 1
   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1
   --use-dynamic-batch-size
   --max-tokens-per-gpu 16384
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --use-opd
   --opd-type megatron
   --opd-kl-coef 1.0
   --opd-teacher-load Qwen3-4B-Non-Thinking-RL-Math-Step500_torch_dist
   --entropy-coef 0.00
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 3e-6
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 1
   --sglang-mem-fraction-static 0.7
)

MISC_ARGS=(
   --attention-dropout 0.0
   --hidden-dropout 0.0
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   --attention-backend flash
)

ray start --head --node-ip-address ${MASTER_ADDR:-"127.0.0.1"} --num-gpus 8 --disable-usage-stats

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json='{"env_vars": {"PYTHONPATH": "/root/Megatron-LM/", "CUDA_DEVICE_MAX_CONNECTIONS": "1"}}' \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node 4 \
   --rollout-num-gpus 4 \
   ${MODEL_ARGS[@]} \
   ${CKPT_ARGS[@]} \
   ${ROLLOUT_ARGS[@]} \
   ${OPTIMIZER_ARGS[@]} \
   ${GRPO_ARGS[@]} \
   ${PERF_ARGS[@]} \
   ${EVAL_ARGS[@]} \
   ${SGLANG_ARGS[@]} \
   ${MISC_ARGS[@]}

Key flags — the bug triggers when --offload-train is NOT set:

--use-opd
--opd-type megatron
--opd-teacher-load <stronger_teacher_checkpoint>
# --offload-train is NOT set

Observed (before fix): Step-0 eval accuracy on AIME-2024 is abnormally high — matching the teacher model's performance rather than the untrained student. In our run the teacher was a math RL-trained Qwen3-4B checkpoint (Step500); the student was the base Qwen3-4B. sglang was silently serving teacher weights from the very first step.

Expected (after fix): Step-0 eval accuracy drops to ~0.25, consistent with the base student model before any RL training.

If the maintainers need the teacher checkpoint or the dataset to reproduce this, please reach out; I am happy to provide them.

Fix

Explicitly call _switch_model("actor") after loading the teacher checkpoint when offload_train is not set, ensuring self.model is always restored to actor state before update_weights() pushes weights to sglang.

if with_opd_teacher:
    self.load_other_checkpoint("teacher", args.opd_teacher_load)
    if not self.args.offload_train:
        self._switch_model("actor")  # restore student weights before update_weights()

When offload_train is disabled, the generic model-state recovery path
is skipped after load_other_checkpoint("teacher"). This leaves self.model
in teacher state, causing step-0 rollouts and evals to run with the
teacher model instead of the student (actor).

Fix: explicitly call _switch_model("actor") after the teacher checkpoint
is loaded when offload_train is not set.
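As a possible follow-up, the restore could also be made exception-safe by wrapping temporary teacher loads in a small try/finally helper. This is purely a sketch with a hypothetical restored_actor helper, not existing slime code; it also drops the offload_train guard on the assumption that an extra switch back to the actor is harmless, and if that assumption does not hold, the guard from the fix above should stay.

from contextlib import contextmanager

@contextmanager
def restored_actor(trainer):
    # Hypothetical helper: run a block that may leave trainer.model in a
    # non-actor state (such as a teacher checkpoint load), then always
    # switch back to the actor weights, even if the block raises.
    # _switch_model is the method used in the fix above.
    try:
        yield
    finally:
        trainer._switch_model("actor")

# Usage sketch:
# with restored_actor(self):
#     self.load_other_checkpoint("teacher", args.opd_teacher_load)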