Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,12 @@ def __init__(

self.public = public

self._using_trackio = False
if use_wandb is True:
try:
import trackio as wandb

self._using_trackio = True
logger.warning("Trackio was found available in your environment, using it instead of wandb")
self.wandb_project = os.environ.get("WANDB_PROJECT", None)
self.space_id = os.environ.get("WANDB_SPACE_ID", None)
Expand All @@ -207,6 +209,7 @@ def __init__(
resume="allow",
**wandb_kwargs,
)
self._wandb_module = wandb

@property
def results(self):
Expand Down Expand Up @@ -292,13 +295,64 @@ def save(self) -> None:

def push_to_wandb(self, results_dict: dict, details_datasets: dict) -> None:
    """Push aggregated results to the wandb/trackio run, then close it.

    The keys under ``results_dict["results"]`` use ``:`` as a separator;
    the dashboard groups metrics on ``/``, so keys are rewritten before
    logging. On the Trackio backend, per-sample details are additionally
    logged as traces; plain wandb users still get details via the Hub.
    """
    metrics_by_slash_key = {}
    for metric_key, metric_value in results_dict["results"].items():
        metrics_by_slash_key[metric_key.replace(":", "/")] = metric_value

    self.wandb_run.log(dict(metrics_by_slash_key))

    # Trace logging only exists on the Trackio backend; for wandb this is
    # a no-op and details are pushed to the Hub elsewhere.
    if self._using_trackio:
        self._log_details_as_traces(details_datasets)

    self.wandb_run.finish()

def _log_details_as_traces(self, details_datasets: dict) -> None:
"""Log each sample of each task as a ``trackio.Trace``.

Pulls (prompt, completion, metrics, gold) out of the per-sample detail
dataset and emits one Trace per row, with the task name as a top-level
log key so samples are grouped by task on the Trackio dashboard.
"""
trackio_module = self._wandb_module
step = 0
for task_name, dataset in details_datasets.items():
for row in dataset:
model_response = row.get("model_response") or {}
doc = row.get("doc") or {}
metric = row.get("metric") or {}

prompt = model_response.get("input")
if isinstance(prompt, list):
prompt = "\n".join(str(p) for p in prompt)
generations = model_response.get("text") or []
completion = generations[0] if generations else ""

messages = []
instruction = doc.get("instruction")
if instruction:
messages.append({"role": "system", "content": str(instruction)})
messages.append({"role": "user", "content": str(prompt or doc.get("query", ""))})
if completion:
messages.append({"role": "assistant", "content": str(completion)})

metadata = {"task": task_name, **{k: v for k, v in metric.items() if v is not None}}
gold_index = doc.get("gold_index")
choices = doc.get("choices")
if gold_index is not None:
metadata["gold_index"] = gold_index
if isinstance(choices, list) and isinstance(gold_index, int) and 0 <= gold_index < len(choices):
metadata["gold"] = choices[gold_index]

try:
trace = trackio_module.Trace(messages=messages, metadata=metadata)
self.wandb_run.log({f"{task_name.replace(':', '/')}/sample": trace}, step=step)
except Exception as e:
logger.warning(f"Failed to log Trackio trace for {task_name} sample {step}: {e}")
step += 1

def save_results(self, date_id: str, results_dict: dict):
if self.results_path_template is not None:
org_model_parts = self.general_config_logger.model_name.split("/")
Expand Down
Loading