diff --git a/docs/source/metric-list.mdx b/docs/source/metric-list.mdx index 06d3dd069..514148832 100644 --- a/docs/source/metric-list.mdx +++ b/docs/source/metric-list.mdx @@ -26,6 +26,7 @@ These metrics need the model to generate an output. They are therefore slower. - normalization on string pre-comparision on whitespace, articles, capitalization, .... - comparing the full string, or only subsets (prefix, suffix, ...) - `maj_at_k`: Model majority vote. Samples k generations from the model and assumes the most frequent is the actual prediction. + - `bayes_at_n`: Corpus-level Bayes@N. Samples n generations per prompt and reports `bayes@n` (the Bayesian posterior mean of the corpus score) and `bayes@n_sigma` (its posterior standard deviation); multi-category outcomes require explicit category weights. - `f1_score`: Average F1 score in terms of word overlap between the model output and gold (normalisation optional). - `f1_score_macro`: Corpus level macro F1 score. - `f1_score_macro`: Corpus level micro F1 score. @@ -54,7 +55,7 @@ These metrics need the model to generate an output. They are therefore slower. - `edit_distance`: Average Levenshtein edit distance between model generation and reference, - `edit_similarity`: Average Levenshtein edit similarity (normalized by the length of longer sequence) between model generation and reference. - Math: - - Both `exact_match` and `maj_at_k` can be used to evaluate mathematics tasks with math specific normalization to remove and filter latex. + - `exact_match`, `maj_at_k`, and `bayes_at_n` can be used to evaluate mathematics tasks with math specific normalization to remove and filter latex. ## LLM-as-Judge - `llm_judge_gpt3p5`: Can be used for any generative task, the model will be scored by a GPT3.5 model using the OpenAI API. 
diff --git a/src/lighteval/metrics/bayes_at_n.py b/src/lighteval/metrics/bayes_at_n.py new file mode 100644 index 000000000..fe3fc0715 --- /dev/null +++ b/src/lighteval/metrics/bayes_at_n.py @@ -0,0 +1,160 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# DEALINGS IN THE SOFTWARE. 
+ +"""Bayes@N posterior moments for repeated categorical outcomes.""" + +from collections.abc import Sequence + +import numpy as np + + +def _as_2d_int_matrix(values: Sequence[Sequence[int]] | np.ndarray, name: str) -> np.ndarray: + try: + matrix = np.asarray(values) + except ValueError as exc: + raise ValueError(f"{name} must be a rectangular 1D or 2D array.") from exc + + if matrix.ndim == 1: + matrix = matrix.reshape(1, -1) + elif matrix.ndim != 2: + raise ValueError(f"{name} must be a 1D or 2D array.") + + if matrix.shape[0] == 0: + raise ValueError(f"{name} must contain at least one row.") + + if matrix.dtype == np.dtype("bool"): + return matrix.astype(int) + if np.issubdtype(matrix.dtype, np.integer): + return matrix.astype(int, copy=False) + if not np.issubdtype(matrix.dtype, np.number): + raise ValueError(f"{name} entries must be integer category ids.") + + float_matrix = matrix.astype(float) + if not np.all(np.isfinite(float_matrix)): + raise ValueError(f"{name} entries must be finite integer category ids.") + if not np.all(float_matrix == np.floor(float_matrix)): + raise ValueError(f"{name} entries must be integer category ids.") + return float_matrix.astype(int) + + +def _as_weights(weights: Sequence[float] | np.ndarray | None, R: np.ndarray) -> np.ndarray: + if weights is None: + unique_values = np.unique(R) + if np.all(np.isin(unique_values, [0, 1])): + return np.array([0.0, 1.0]) + + unique_str = ", ".join(str(value) for value in unique_values) + raise ValueError( + f"R contains non-binary category ids ({unique_str}); pass weights to score multi-category outcomes." 
+ ) + + weight_array = np.asarray(weights, dtype=float) + if weight_array.ndim != 1: + raise ValueError("weights must be a 1D array.") + if weight_array.size == 0: + raise ValueError("weights must contain at least one value.") + if not np.all(np.isfinite(weight_array)): + raise ValueError("weights must contain only finite values.") + return weight_array + + +def _validate_matrix_range(matrix: np.ndarray, low: int, high: int, name: str) -> None: + if matrix.size == 0: + return + if matrix.min() < low or matrix.max() > high: + raise ValueError(f"{name} entries must be integers in [{low}, {high}].") + + +def _row_bincount(matrix: np.ndarray, length: int) -> np.ndarray: + if matrix.shape[1] == 0: + return np.zeros((matrix.shape[0], length), dtype=int) + + counts = np.zeros((matrix.shape[0], length), dtype=int) + rows = np.repeat(np.arange(matrix.shape[0]), matrix.shape[1]) + np.add.at(counts, (rows, matrix.ravel()), 1) + return counts + + +def _as_prior_matrix( + prior: Sequence[Sequence[int]] | np.ndarray | None, + num_rows: int, +) -> np.ndarray: + if prior is None: + return np.zeros((num_rows, 0), dtype=int) + + prior_matrix = _as_2d_int_matrix(prior, "prior") + if prior_matrix.ndim == 1: + prior_matrix = prior_matrix.reshape(1, -1) + if prior_matrix.shape[0] != num_rows: + if prior_matrix.size % num_rows != 0: + raise ValueError("prior must have the same number of rows as R.") + prior_matrix = prior_matrix.reshape(num_rows, -1) + return prior_matrix + + +def bayes_at_n( + R: Sequence[Sequence[int]] | np.ndarray, + weights: Sequence[float] | np.ndarray | None = None, + prior: Sequence[Sequence[int]] | np.ndarray | None = None, +) -> tuple[float, float]: + """Return the Bayes@N posterior mean and standard deviation. + + Args: + R: ``M x N`` matrix of integer category ids. A 1D array is treated as + one row. + weights: Category score weights. If omitted, ``R`` must be binary and + weights ``[0.0, 1.0]`` are used. 
+ prior: Optional ``M x D`` matrix of row-aligned prior observations. + + Returns: + ``(mu, sigma)``, where ``mu`` is the posterior mean and ``sigma`` is the + posterior standard deviation. + """ + outcome_matrix = _as_2d_int_matrix(R, "R") + if outcome_matrix.shape[1] == 0: + raise ValueError("R must contain at least one outcome per row.") + + weight_array = _as_weights(weights, outcome_matrix) + num_rows, num_samples = outcome_matrix.shape + max_category = weight_array.size - 1 + prior_matrix = _as_prior_matrix(prior, num_rows) + + _validate_matrix_range(outcome_matrix, 0, max_category, "R") + _validate_matrix_range(prior_matrix, 0, max_category, "prior") + + prior_samples = prior_matrix.shape[1] + total_count = 1 + max_category + prior_samples + num_samples + + outcome_counts = _row_bincount(outcome_matrix, max_category + 1) + prior_counts = _row_bincount(prior_matrix, max_category + 1) + 1 + posterior_counts = outcome_counts + prior_counts + + delta_weights = weight_array - weight_array[0] + mu = weight_array[0] + (posterior_counts @ delta_weights).sum() / (num_rows * total_count) + + posterior_probs = posterior_counts / total_count + second_moment = (posterior_probs * (delta_weights**2)).sum(axis=1) + squared_mean = (posterior_probs @ delta_weights) ** 2 + sigma = np.sqrt((second_moment - squared_mean).sum() / (num_rows**2 * (total_count + 1))) + + return float(mu), float(sigma) + + +__all__ = ["bayes_at_n"] diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 82cfbb706..c89cccff4 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -32,6 +32,7 @@ from lighteval.metrics.harness_compatibility.drop import DropMetrics from lighteval.metrics.harness_compatibility.truthful_qa import TruthfulqaMCMetrics from lighteval.metrics.metrics_corpus import ( + BayesAtNCorpus, CorpusLevelF1Score, CorpusLevelPerplexityMetric, CorpusLevelTranslationMetric, @@ -44,6 +45,7 @@ ROUGE, AccGoldLikelihood, AvgAtN, + 
BayesAtN, BertScore, ExactMatches, Extractiveness, @@ -172,6 +174,43 @@ class Metrics(Enum): corpus_level_fn=np.mean, higher_is_better=True, ) + bayes_at_n = CorpusLevelMetricGrouping( + metric_name=["bayes@n", "bayes@n_sigma"], + sample_level_fn=BayesAtN(strip_strings=True), + category=SamplingMethod.GENERATIVE, + corpus_level_fn={ + "bayes@n": BayesAtNCorpus("mu"), + "bayes@n_sigma": BayesAtNCorpus("sigma"), + }, + higher_is_better={ + "bayes@n": True, + "bayes@n_sigma": False, + }, + ) + bayes_at_n_math = CorpusLevelMetricGrouping( + metric_name=["math-bayes@n", "math-bayes@n_sigma"], + sample_level_fn=BayesAtN( + name_prefix="math", + strip_strings=True, + sample_scoring_function=MultilingualExtractiveMatchMetric( + language=Language.ENGLISH, + fallback_mode="first_match", + precision=5, + gold_extraction_target=(ExprExtractionConfig(),), + pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), + aggregation_function=max, + ), + ), + category=SamplingMethod.GENERATIVE, + corpus_level_fn={ + "math-bayes@n": BayesAtNCorpus("mu"), + "math-bayes@n_sigma": BayesAtNCorpus("sigma"), + }, + higher_is_better={ + "math-bayes@n": True, + "math-bayes@n_sigma": False, + }, + ) bert_score = SampleLevelMetricGrouping( metric_name=["BERTScore-P", "BERTScore-R", "BERTScore-F"], sample_level_fn=BertScore(normalize_gold=remove_braces, normalize_pred=remove_braces_and_strip), diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 92c2c574a..ccbfa8931 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -34,6 +34,7 @@ import sacrebleu import sklearn.metrics +from lighteval.metrics.bayes_at_n import bayes_at_n from lighteval.metrics.sample_preparator import ( GenerativeCorpusMetricInput, LogprobCorpusMetricInput, @@ -62,6 +63,96 @@ def __str__(self): return f"{self.__class__.__name__}({', '.join(attr_strs)})" +def 
_is_repeated_full_bayes_prior(non_null_priors: list[object], first_prior: np.ndarray, num_rows: int) -> bool: + if not all(np.array_equal(np.asarray(prior), first_prior) for prior in non_null_priors): + return False + return (first_prior.ndim == 2 and first_prior.shape[0] == num_rows) or ( + first_prior.ndim == 1 and num_rows == 1 + ) + + +def _coerce_bayes_prior_row(prior: object) -> list[int]: + prior_array = np.asarray(prior) + if prior_array.ndim == 0: + raise ValueError("Bayes@N prior rows must be 1D arrays.") + if prior_array.ndim == 2: + if prior_array.shape[0] != 1: + raise ValueError("Bayes@N row-level prior payloads must contain exactly one row.") + prior_array = prior_array.reshape(-1) + elif prior_array.ndim != 1: + raise ValueError("Bayes@N row-level prior payloads must be 1D arrays.") + return prior_array.tolist() + + +def _coerce_bayes_prior(priors: list[object | None], num_rows: int) -> list[list[int]] | object | None: + non_null_priors = [prior for prior in priors if prior is not None] + if not non_null_priors: + return None + if len(non_null_priors) != len(priors): + raise ValueError("Bayes@N prior observations must be provided for every row or omitted for every row.") + + first_prior = np.asarray(non_null_priors[0]) + if _is_repeated_full_bayes_prior(non_null_priors, first_prior, num_rows): + return non_null_priors[0] + + prior_rows = [_coerce_bayes_prior_row(prior) for prior in non_null_priors] + prior_lengths = {len(row) for row in prior_rows} + if len(prior_lengths) != 1: + raise ValueError("Bayes@N prior rows must all have the same number of observations.") + return prior_rows + + +def _coerce_bayes_items(items: list[dict | list[int]]) -> tuple[list[list[int]], list[float] | None, object | None]: + if len(items) == 0: + raise ValueError("Bayes@N needs at least one row.") + + rows = [] + weights = None + priors = [] + for item in items: + if isinstance(item, dict): + if "scores" not in item: + raise ValueError("Bayes@N payloads must contain a 
'scores' row.") + row = item["scores"] + item_weights = item.get("weights") + priors.append(item.get("prior")) + else: + row = item + item_weights = None + priors.append(None) + + row = list(row) + if len(row) == 0: + raise ValueError("Bayes@N rows must contain at least one score.") + rows.append(row) + + if item_weights is not None: + item_weights = np.asarray(item_weights, dtype=float) + if weights is None: + weights = item_weights + elif not np.array_equal(weights, item_weights): + raise ValueError("Bayes@N received inconsistent weights across rows.") + + row_lengths = {len(row) for row in rows} + if len(row_lengths) != 1: + raise ValueError("Bayes@N requires every row to have the same number of scores.") + + weights_list = weights.tolist() if weights is not None else None + return rows, weights_list, _coerce_bayes_prior(priors, len(rows)) + + +class BayesAtNCorpus(CorpusLevelComputation): + def __init__(self, statistic: Literal["mu", "sigma"]): + if statistic not in {"mu", "sigma"}: + raise ValueError("BayesAtNCorpus statistic must be either 'mu' or 'sigma'.") + self.statistic = statistic + + def compute_corpus(self, items: list[dict | list[int]]) -> float: + rows, weights, prior = _coerce_bayes_items(items) + mu, sigma = bayes_at_n(rows, weights=weights, prior=prior) + return mu if self.statistic == "mu" else sigma + + # General aggregations class MatthewsCorrCoef(CorpusLevelComputation): def compute_corpus(self, items: list[GenerativeCorpusMetricInput]) -> float: diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index db14b9bf6..a0d82dfdd 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -28,6 +28,7 @@ import logging import os from abc import ABC, abstractmethod +from dataclasses import replace from typing import Callable, Literal, Union import nltk @@ -1208,6 +1209,88 @@ def num_samples(self): return self.n +class BayesAtN(SamplingMetric, SampleLevelComputation): 
+ def __init__( + self, + n: int | None = None, + weights: list[float] | None = None, + prior: list[list[int]] | None = None, + confidence: float | None = None, + name_prefix: str | None = None, + **kwargs, + ): + """Collect repeated sample scores for corpus-level Bayes@N aggregation. + + Args: + n (int | None): Number of generated samples to score. If omitted, + all available samples are used. + weights (list[float] | None): Optional score for each integer + category id. Binary scores default to ``[0.0, 1.0]``. + prior (list[list[int]] | None): Optional row-aligned prior + observations used by the corpus-level aggregator. + confidence (float | None): Reserved for future interval outputs. + name_prefix (str | None): Optional prefix for metric names. + **kwargs: Additional keyword arguments. + """ + super().__init__(**kwargs) + self.n = n + self.weights = weights + self.prior = prior + self.confidence = confidence + self.name = f"{name_prefix}-bayes@n" if name_prefix else "bayes@n" + + @property + def metric_names(self): + return [self.name, f"{self.name}_sigma"] + + def _coerce_score(self, score: float | int | bool) -> int: + if isinstance(score, (bool, np.bool_)): + return int(score) + if isinstance(score, (int, np.integer)): + category = int(score) + elif isinstance(score, (float, np.floating)): + if not np.isfinite(score) or not float(score).is_integer(): + raise ValueError( + "Bayes@N sample scores must be integer category ids. " + "Continuous scores require an explicit category mapping before aggregation." 
+ ) + category = int(score) + else: + raise ValueError("Bayes@N sample scores must be integer category ids.") + + if category < 0: + raise ValueError("Bayes@N sample scores must be non-negative integer category ids.") + return category + + def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict: + """Collect the row of repeated outcomes for one document.""" + predictions = model_response.final_text + if self.n is None: + n = len(predictions) + logger.warning( + "n undefined in Bayes@N. We assume it's the same as the sample's number of predictions." + ) + else: + n = self.n + if len(predictions) < n: + logger.warning(f"Number of predictions is less than {self.n} for Bayes@N.") + if n <= 0: + raise ValueError("Bayes@N requires at least one prediction.") + + processed_doc = replace(doc, choices=[self.preprocess(text=choice) for choice in doc.choices]) + + row = [] + for pred in predictions[:n]: + processed_response = ModelResponse(text=[self.preprocess(text=pred)]) + row.append(self._coerce_score(self.compute_score(processed_doc, processed_response))) + + payload = {"scores": row, "weights": self.weights, "prior": self.prior} + return dict.fromkeys(self.metric_names, payload) + + def num_samples(self): + return self.n if self.n is not None else 1 + + class MajAtN(SamplingMetric, SampleLevelComputation): def __init__(self, n: int | None = None, **kwargs): """An exact match class. 
diff --git a/src/lighteval/pipeline.py b/src/lighteval/pipeline.py index 1f5da9c14..e3ac809e5 100644 --- a/src/lighteval/pipeline.py +++ b/src/lighteval/pipeline.py @@ -245,7 +245,12 @@ def _update_num_samples(self, tasks: list[LightevalTask]): """ for task in tasks: for metric in task.metrics: - if metric_data := self._metric_options.get(metric.metric_name, None): + metric_names = metric.metric_name if isinstance(metric.metric_name, list) else [metric.metric_name] + metric_data = next( + (self._metric_options[name] for name in metric_names if name in self._metric_options), + None, + ) + if metric_data: num_samples = metric_data.get("num_samples", None) if num_samples: task.num_samples = [num_samples] diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py index e7c4e9eb6..748356fc2 100644 --- a/src/lighteval/tasks/registry.py +++ b/src/lighteval/tasks/registry.py @@ -109,6 +109,12 @@ def load_community_tasks(): DEFAULT_SUITES = CORE_SUITES + OPTIONAL_SUITES +def _metric_name_contains_at(metric_name: str | list[str]) -> bool: + if isinstance(metric_name, list): + return any("@" in name for name in metric_name) + return "@" in metric_name + + class Registry: """The Registry class is used to manage the task registry and get task classes.""" @@ -229,7 +235,7 @@ def _update_task_configs(self) -> dict[str, LightevalTaskConfig]: # noqa: C901 config.num_fewshots = few_shot config.full_name = f"{expanded_task}|{config.num_fewshots}" # If some tasks are parametrizable and in cli, we set attributes here - for metric in [m for m in config.metrics if "@" in m.metric_name]: # parametrizable metric + for metric in [m for m in config.metrics if _metric_name_contains_at(m.metric_name)]: for attribute, value in metric_params_dict.items(): setattr(metric.sample_level_fn, attribute, value) required = getattr(metric.sample_level_fn, "attribute_must_be_set", []) diff --git a/tests/unit/metrics/test_bayes_at_n.py b/tests/unit/metrics/test_bayes_at_n.py new file 
mode 100644 index 000000000..87d37a1b1 --- /dev/null +++ b/tests/unit/metrics/test_bayes_at_n.py @@ -0,0 +1,114 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# DEALINGS IN THE SOFTWARE. 
+ +import numpy as np +import pytest + +from lighteval.metrics.bayes_at_n import bayes_at_n +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.metrics_corpus import BayesAtNCorpus +from lighteval.metrics.metrics_sample import BayesAtN +from lighteval.models.model_output import ModelResponse +from lighteval.tasks.requests import Doc + + +def test_bayes_at_n_multicategory_with_prior(): + R = np.array([[0, 1, 2, 2, 1], [1, 1, 0, 2, 2]]) + weights = np.array([0.0, 0.5, 1.0]) + prior = np.array([[0, 2], [1, 2]]) + + mu, sigma = bayes_at_n(R, weights=weights, prior=prior) + + assert mu == pytest.approx(0.575) + assert sigma == pytest.approx(0.084275, abs=1e-6) + + +def test_bayes_at_n_multicategory_without_prior(): + R = np.array([[0, 1, 2, 2, 1], [1, 1, 0, 2, 2]]) + weights = np.array([0.0, 0.5, 1.0]) + + mu, sigma = bayes_at_n(R, weights=weights) + + assert mu == pytest.approx(0.5625) + assert sigma == pytest.approx(0.091998, abs=1e-6) + + +def test_bayes_at_n_binary_defaults_weights(): + mu, sigma = bayes_at_n([[0, 1, 1], [1, 1, 0]]) + + assert mu == pytest.approx(0.6) + assert sigma == pytest.approx(0.1414213562373095) + + +def test_bayes_at_n_requires_weights_for_multicategory(): + with pytest.raises(ValueError, match="pass weights"): + bayes_at_n([[0, 1, 2]]) + + +def test_bayes_at_n_validates_categories_and_prior_shape(): + with pytest.raises(ValueError, match="R entries"): + bayes_at_n([[0, 2]], weights=[0.0, 1.0]) + + with pytest.raises(ValueError, match="integer category ids"): + bayes_at_n([[0.5, 1.0]]) + + with pytest.raises(ValueError, match="same number of rows"): + bayes_at_n([[0, 1], [1, 0]], prior=[[0, 1, 0]]) + + +def test_bayes_at_n_corpus_aggregator_uses_all_rows(): + items = [ + {"scores": [0, 1, 2, 2, 1], "weights": [0.0, 0.5, 1.0], "prior": [[0, 2], [1, 2]]}, + {"scores": [1, 1, 0, 2, 2], "weights": [0.0, 0.5, 1.0], "prior": [[0, 2], [1, 2]]}, + ] + + assert BayesAtNCorpus("mu").compute_corpus(items) == pytest.approx(0.575) + 
assert BayesAtNCorpus("sigma").compute_corpus(items) == pytest.approx(0.084275, abs=1e-6) + + +def test_bayes_at_n_sample_metric_and_registration(): + metric = Metrics.bayes_at_n(sample_params={"n": 5}) + metric_name, sigma_name = metric.metric_name + docs = [ + Doc(query="q1", choices=["A"], gold_index=0), + Doc(query="q2", choices=["A"], gold_index=0), + ] + responses = [ + ModelResponse(text=["A", "B", "A", "A", "B"]), + ModelResponse(text=["B", "A", "B", "B", "A"]), + ] + + sample_outputs = [metric.compute_sample(doc=doc, model_response=response) for doc, response in zip(docs, responses)] + expected_rows = [[1, 0, 1, 1, 0], [0, 1, 0, 0, 1]] + expected_mu, expected_sigma = bayes_at_n(expected_rows) + + assert sample_outputs[0][metric_name]["scores"] == expected_rows[0] + aggregations = metric.get_corpus_aggregations() + assert aggregations[metric_name]([output[metric_name] for output in sample_outputs]) == pytest.approx(expected_mu) + assert aggregations[sigma_name]([output[sigma_name] for output in sample_outputs]) == pytest.approx(expected_sigma) + + +def test_bayes_at_n_rejects_continuous_sample_scores(): + metric = BayesAtN(n=1, sample_scoring_function=lambda doc, response: 0.5) + doc = Doc(query="q", choices=["A"], gold_index=0) + response = ModelResponse(text=["A"]) + + with pytest.raises(ValueError, match="integer category ids"): + metric.compute(doc=doc, model_response=response)