diff --git a/tests/test_bulk_mode_ids.py b/tests/test_bulk_mode_ids.py index 78d436e..79ea66f 100644 --- a/tests/test_bulk_mode_ids.py +++ b/tests/test_bulk_mode_ids.py @@ -51,3 +51,25 @@ def test_invalid_bulk_id_errors(): _, state = watcher.analyze_traps(randomized_model=randomized_model, trap_state=state, return_artifacts=True, return_bulk_ids=True, plot=False) with pytest.raises(ValueError): watcher.remove_modes(mode_ids_by_layer={999:[1]}, mode_type='bulk', randomized_model=randomized_model, trap_state=state, plot=False) + + +def test_bulk_only_has_mp_edges_and_nonempty(): + model = OneLayer() + watcher = ww.WeightWatcher(model=model) + randomized_model, state = watcher.randomize_model(model=model, rng=123, return_state=True) + bulk_df, out_state = watcher.analyze_traps( + randomized_model=randomized_model, + trap_state=state, + return_artifacts=True, + return_bulk_ids=True, + bulk_only=True, + max_bulk_modes_per_layer=5, + bulk_sampling_seed=123, + plot=False, + ) + assert len(bulk_df) > 0 + assert set(bulk_df["mode_type"]) == {"bulk"} + for _lid, layer_state in out_state["layers"].items(): + assert np.isfinite(layer_state["mp_bulk_min"]) + assert np.isfinite(layer_state["mp_bulk_max"]) + assert len(layer_state["bulk_svd_indices"]) > 0 diff --git a/weightwatcher/trap_analysis.py b/weightwatcher/trap_analysis.py index 08bc1cb..be31743 100644 --- a/weightwatcher/trap_analysis.py +++ b/weightwatcher/trap_analysis.py @@ -57,16 +57,73 @@ def _sample_bulk_modes(svd_indices, eigenvalues, max_modes=None, seed=None, stra return sorted(int(x) for x in out) raise ValueError("bulk_sampling_strategy must be one of: all, uniform, stratified") - -def _build_trap_bulk_rows(layer_state, layer_rows, return_bulk_ids=False, bulk_only=False, trap_only=False, max_bulk_modes_per_layer=None, bulk_sampling_seed=None, bulk_sampling_strategy='all'): +def _extract_mp_bulk_edges(layer_result=None, details=None, bulk_stats=None): + candidates_min = ["mp_bulk_min", "bulk_min", "lambda_min", "mp_lambda_min", "lambda_minus", "xmin"] + candidates_max = ["mp_bulk_max", "bulk_max", "lambda_max", "mp_lambda_max", "lambda_plus", "xmax"] + + def read(obj, keys): + if obj is None: + return None + if isinstance(obj, dict): + for k in keys: + if k in obj and obj[k] is not None: + return obj[k] + else: + for k in keys: + if hasattr(obj, k): + v = getattr(obj, k) + if v is not None: + return v + return None + + mp_min = read(bulk_stats, candidates_min) + if mp_min is None: + mp_min = read(details, candidates_min) + if mp_min is None: + mp_min = read(layer_result, candidates_min) + mp_max = read(bulk_stats, candidates_max) + if mp_max is None: + mp_max = read(details, candidates_max) + if mp_max is None: + mp_max = read(layer_result, candidates_max) + mp_min = 0.0 if mp_min is None else float(mp_min) + mp_max = np.nan if mp_max is None else float(mp_max) + return mp_min, mp_max + + +def _build_trap_bulk_rows(layer_state, layer_rows, return_bulk_ids=False, bulk_only=False, trap_only=False, max_bulk_modes_per_layer=None, bulk_sampling_seed=None, bulk_sampling_strategy='all', allow_bulk_without_mp_edges=False): trap_svd = [int(i) for i in layer_state.get('trap_mode_indices_0based', [])] S = np.asarray(layer_state.get('S_perm', []), dtype=float) evals = S*S - mp_max=float(layer_state.get('bulk_stats',{}).get('mp_bulk_max', np.nan)) - mp_min=float(layer_state.get('bulk_stats',{}).get('mp_bulk_min', 0.0)) + bulk_stats = layer_state.get('bulk_stats', {}) or {} + mp_min, mp_max = _extract_mp_bulk_edges(layer_result=layer_state, details=layer_rows[0] if layer_rows else None, bulk_stats=bulk_stats) + if not np.isfinite(mp_max): + if allow_bulk_without_mp_edges: + wwcore.logger.warning("Missing MP bulk upper edge for layer_id=%s; falling back to non-trap finite modes.", layer_state.get("layer_id")) + else: + raise ValueError("Cannot build bulk IDs because MP bulk upper edge is missing. Pass allow_bulk_without_mp_edges=True to fall back.") + bulk_stats["mp_bulk_min"] = float(mp_min) + bulk_stats["mp_bulk_max"] = float(mp_max) if np.isfinite(mp_max) else np.nan + layer_state["bulk_stats"] = bulk_stats + layer_state["mp_bulk_min"] = float(mp_min) + layer_state["mp_bulk_max"] = float(mp_max) if np.isfinite(mp_max) else np.nan + layer_state["evals"] = np.asarray(evals, dtype=float) + layer_state["singular_values"] = np.asarray(S, dtype=float) trap_set=set(trap_svd) - inside=[i for i,e in enumerate(evals) if np.isfinite(mp_max) and e>=mp_min and e<=mp_max] + if np.isfinite(mp_max): + inside=[i for i,e in enumerate(evals) if np.isfinite(e) and e>=mp_min and e<=mp_max] + else: + inside=[i for i,e in enumerate(evals) if np.isfinite(e) and i not in trap_set] bulk=[i for i in inside if i not in trap_set] + if return_bulk_ids and len(bulk) == 0: + raise ValueError( + f"No eligible bulk modes found for layer_id={layer_state.get('layer_id')} " + f"name={layer_state.get('name')} len_evals={len(evals)} mp_bulk_min={mp_min} " + f"mp_bulk_max={mp_max} min_eval={np.nanmin(evals) if len(evals) else np.nan} " + f"max_eval={np.nanmax(evals) if len(evals) else np.nan} n_traps={len(trap_svd)}. " + "This is an invalid state for bulk ID generation and indicates a broken MP fit " + "or a corrupted trap/bulk classification path." + ) bulk=_sample_bulk_modes(bulk, evals, max_bulk_modes_per_layer, bulk_sampling_seed, bulk_sampling_strategy) layer_state['trap_svd_indices']=trap_svd layer_state['bulk_svd_indices']=bulk @@ -84,7 +141,7 @@ def _build_trap_bulk_rows(layer_state, layer_rows, return_bulk_ids=False, bulk_o if return_bulk_ids: for bi,svd_i in enumerate(bulk, start=1): ev=float(evals[svd_i]) - bulk_rows.append({'layer_id':int(layer_state['layer_id']),'name':layer_state.get('name'),'longname':layer_state.get('longname'),'mode_type':'bulk','ablation_type':'bulk','mode_id':bi,'trap_id':np.nan,'trap_index':np.nan,'bulk_id':bi,'bulk_index':bi,'is_trap':False,'is_bulk':True,'svd_mode_index':svd_i,'mode_index':svd_i,'singular_value':float(S[svd_i]),'eigenvalue':ev,'eval_perm':ev,'mp_lambda_min':mp_min,'mp_lambda_max':mp_max,'is_inside_mp_bulk':True,'is_above_mp_edge':False,'is_below_mp_edge':False}) + bulk_rows.append({'layer_id':int(layer_state['layer_id']),'name':layer_state.get('name'),'longname':layer_state.get('longname'),'mode_type':'bulk','ablation_type':'bulk','mode_id':bi,'trap_id':np.nan,'trap_index':np.nan,'bulk_id':bi,'bulk_index':bi,'is_trap':False,'is_bulk':True,'svd_mode_index':svd_i,'mode_index':svd_i,'singular_value':float(S[svd_i]),'eigenvalue':ev,'eval_perm':ev,'mp_lambda_min':mp_min,'mp_lambda_max':mp_max,'mp_bulk_min':mp_min,'mp_bulk_max':mp_max,'bulk_quantile':float(bi)/float(max(1,len(bulk))),'is_inside_mp_bulk':True,'is_above_mp_edge':False,'is_below_mp_edge':False}) if bulk_only: return bulk_rows if trap_only: return layer_rows return layer_rows + bulk_rows @@ -130,6 +187,7 @@ def analyze_traps( max_bulk_modes_per_layer=None, bulk_sampling_seed=None, bulk_sampling_strategy="all", + allow_bulk_without_mp_edges=False, ): """Externalized implementation for WeightWatcher.analyze_traps().""" if layers is None: @@ -229,6 +287,11 @@ def analyze_traps( layer_out = watcher.apply_analyze_traps(ww_layer, params=layer_params) if return_artifacts: layer_rows, layer_state = layer_out + if isinstance(layer_state, dict): + if layer_state.get("mp_bulk_max") is None and hasattr(ww_layer, "bulk_max"): + layer_state["mp_bulk_max"] = float(ww_layer.bulk_max) if ww_layer.bulk_max is not None else None + if layer_state.get("mp_bulk_min") is None and hasattr(ww_layer, "bulk_min"): + layer_state["mp_bulk_min"] = float(ww_layer.bulk_min) if ww_layer.bulk_min is not None else 0.0 if trap_state is None: trap_state = {"already_randomized": bool(already_randomized), "permuted_ids": {}, "layers": {}} trap_state.setdefault("layers", {})[int(ww_layer.layer_id)] = layer_state @@ -236,7 +299,7 @@ def analyze_traps( else: layer_rows = layer_out if layer_rows or return_bulk_ids: - layer_rows = _build_trap_bulk_rows(layer_state if return_artifacts else {"layer_id": int(ww_layer.layer_id), "name": ww_layer.name, "longname": ww_layer.longname, "S_perm": np.array([]), "trap_mode_indices_0based": []}, layer_rows or [], return_bulk_ids=return_bulk_ids, bulk_only=bulk_only, trap_only=trap_only, max_bulk_modes_per_layer=max_bulk_modes_per_layer, bulk_sampling_seed=bulk_sampling_seed, bulk_sampling_strategy=bulk_sampling_strategy) + layer_rows = _build_trap_bulk_rows(layer_state if return_artifacts else {"layer_id": int(ww_layer.layer_id), "name": ww_layer.name, "longname": ww_layer.longname, "S_perm": np.array([]), "trap_mode_indices_0based": []}, layer_rows or [], return_bulk_ids=return_bulk_ids, bulk_only=bulk_only, trap_only=trap_only, max_bulk_modes_per_layer=max_bulk_modes_per_layer, bulk_sampling_seed=bulk_sampling_seed, bulk_sampling_strategy=bulk_sampling_strategy, allow_bulk_without_mp_edges=allow_bulk_without_mp_edges) if params.get(wwcore.PLOT, False): trap_infos = [] for row in layer_rows: @@ -282,11 +345,13 @@ def analyze_traps( if len(details) > 0: if "perm_mode_index" in details.columns: - details["perm_mode_index_0based"] = details["perm_mode_index"].astype(int) - details["perm_mode_index"] = details["perm_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api) + mask = details["perm_mode_index"].notna() + details.loc[mask, "perm_mode_index_0based"] = details.loc[mask, "perm_mode_index"].astype(int) + details.loc[mask, "perm_mode_index"] = details.loc[mask, "perm_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api) if "trap_mode_index" in details.columns: - details["trap_mode_index_0based"] = details["trap_mode_index"].astype(int) - details["trap_mode_index"] = details["trap_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api) + mask = details["trap_mode_index"].notna() + details.loc[mask, "trap_mode_index_0based"] = details.loc[mask, "trap_mode_index"].astype(int) + details.loc[mask, "trap_mode_index"] = details.loc[mask, "trap_mode_index_0based"].apply(remove_traps_ops._internal_trap_index_to_api) lead_cols = ["layer_id", "name"] details = details[lead_cols + [c for c in details.columns if c not in lead_cols]] diff --git a/weightwatcher/weightwatcher.py b/weightwatcher/weightwatcher.py index c1370e8..0c7fb5d 100644 --- a/weightwatcher/weightwatcher.py +++ b/weightwatcher/weightwatcher.py @@ -3770,7 +3770,8 @@ def analyze_traps(self, model=None, layers=[], trap_only=False, max_bulk_modes_per_layer=None, bulk_sampling_seed=None, - bulk_sampling_strategy="all"): + bulk_sampling_strategy="all", + allow_bulk_without_mp_edges=False): """Analyze randomized correlation traps and return one row per trap. This method follows the randomized/permuted trap workflow: @@ -3838,6 +3839,7 @@ def analyze_traps(self, model=None, layers=[], max_bulk_modes_per_layer=max_bulk_modes_per_layer, bulk_sampling_seed=bulk_sampling_seed, bulk_sampling_strategy=bulk_sampling_strategy, + allow_bulk_without_mp_edges=allow_bulk_without_mp_edges, ) def analyze_bulk_modes(self, bulk_ids_by_layer=None, layers=None, randomized_model=None, trap_state=None, **kwargs):