diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py
index 976b21c86..80c1715f8 100644
--- a/src/lighteval/logging/evaluation_tracker.py
+++ b/src/lighteval/logging/evaluation_tracker.py
@@ -180,10 +180,12 @@ def __init__(
 
         self.public = public
 
+        self._using_trackio = False
         if use_wandb is True:
             try:
                 import trackio as wandb
 
+                self._using_trackio = True
                 logger.warning("Trackio was found available in your environment, using it instead of wandb")
                 self.wandb_project = os.environ.get("WANDB_PROJECT", None)
                 self.space_id = os.environ.get("WANDB_SPACE_ID", None)
@@ -207,6 +209,7 @@ def __init__(
                 resume="allow",
                 **wandb_kwargs,
             )
+            self._wandb_module = wandb
 
     @property
     def results(self):
@@ -292,13 +295,64 @@ def save(self) -> None:
 
     def push_to_wandb(self, results_dict: dict, details_datasets: dict) -> None:
         # reformat the results key to replace ':' with '/'
-        results_dict = {k.replace(":", "/"): v for k, v in results_dict["results"].items()}
+        flat_results = {k.replace(":", "/"): v for k, v in results_dict["results"].items()}
 
         self.wandb_run.log(
-            {**results_dict},
+            {**flat_results},
         )
+
+        # When the backend is Trackio, log per-sample details as trackio.Trace
+        # so each sample is inspectable on the dashboard. (For wandb users this
+        # path stays a no-op; details are still pushed to the Hub via push_to_hub.)
+        if self._using_trackio:
+            self._log_details_as_traces(details_datasets)
+
        self.wandb_run.finish()
 
+    def _log_details_as_traces(self, details_datasets: dict) -> None:
+        """Log each sample of each task as a ``trackio.Trace``.
+
+        Pulls (prompt, completion, metrics, gold) out of the per-sample detail
+        dataset and emits one Trace per row, with the task name as a top-level
+        log key so samples are grouped by task on the Trackio dashboard.
+        """
+        trackio_module = self._wandb_module
+        step = 0
+        for task_name, dataset in details_datasets.items():
+            for row in dataset:
+                model_response = row.get("model_response") or {}
+                doc = row.get("doc") or {}
+                metric = row.get("metric") or {}
+
+                prompt = model_response.get("input")
+                if isinstance(prompt, list):
+                    prompt = "\n".join(str(p) for p in prompt)
+                generations = model_response.get("text") or []
+                completion = generations[0] if generations else ""
+
+                messages = []
+                instruction = doc.get("instruction")
+                if instruction:
+                    messages.append({"role": "system", "content": str(instruction)})
+                messages.append({"role": "user", "content": str(prompt or doc.get("query", ""))})
+                if completion:
+                    messages.append({"role": "assistant", "content": str(completion)})
+
+                metadata = {"task": task_name, **{k: v for k, v in metric.items() if v is not None}}
+                gold_index = doc.get("gold_index")
+                choices = doc.get("choices")
+                if gold_index is not None:
+                    metadata["gold_index"] = gold_index
+                if isinstance(choices, list) and isinstance(gold_index, int) and 0 <= gold_index < len(choices):
+                    metadata["gold"] = choices[gold_index]
+
+                try:
+                    trace = trackio_module.Trace(messages=messages, metadata=metadata)
+                    self.wandb_run.log({f"{task_name.replace(':', '/')}/sample": trace}, step=step)
+                except Exception as e:
+                    logger.warning(f"Failed to log Trackio trace for {task_name} sample {step}: {e}")
+                step += 1
+
     def save_results(self, date_id: str, results_dict: dict):
         if self.results_path_template is not None:
             org_model_parts = self.general_config_logger.model_name.split("/")
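
For reference, a minimal standalone sketch of the per-sample payload that _log_details_as_traces emits. It assumes trackio is installed and exposes the Trace(messages=..., metadata=...) type, the run-object init/log/finish API, and the step= keyword that the patch relies on; the project name, task name, and metric values below are made up for illustration, not taken from lighteval.

import trackio

# Project name mirrors what the tracker reads from WANDB_PROJECT in __init__.
run = trackio.init(project="lighteval-demo")

# One evaluation sample, shaped like the rows handled in _log_details_as_traces:
# the prompt/completion become chat messages, metrics and gold go into metadata.
trace = trackio.Trace(
    messages=[
        {"role": "user", "content": "Question: What is the capital of France?\nAnswer:"},
        {"role": "assistant", "content": " Paris"},
    ],
    metadata={"task": "demo|capitals|0", "acc": 1.0, "gold": "Paris", "gold_index": 0},
)

# The task name is used as the log key so samples group by task on the dashboard.
run.log({"demo|capitals|0/sample": trace}, step=0)
run.finish()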