diff --git a/sim/binding.c b/sim/binding.c index df5841ba1..83a354f3d 100644 --- a/sim/binding.c +++ b/sim/binding.c @@ -182,7 +182,20 @@ void my_init(Env *env, Dict *kwargs) { } } -void my_log(Log *log, Dict *out) { +void my_log(Log *log, Dict *out, float n) { + // static_vec_aggregate_logs divides every Log field by total agent count n, + // so log->total_* are per-agent rates here. We want the fleet ratio + // sum_distance / sum_infractions; the n cancels in that ratio so we + // could compute it directly from the rates. We multiply by n anyway + // to recover the raw fleet totals: this makes the fmaxf(1.0f, ...) + // clamp meaningful (it floors the denominator at "1 infraction across + // the whole fleet"), so a zero-infraction window reports total fleet + // distance instead of distance / epsilon = absurd value. + float total_distance_travelled = log->total_distance_travelled * n; + float total_infractions = log->total_infractions * n; + float avg_distance_per_infraction = + total_distance_travelled / fmaxf(1.0f, total_infractions); + dict_set(out, "score", log->score); dict_set(out, "episode_return", log->episode_return); dict_set(out, "episode_length", log->episode_length); @@ -191,6 +204,7 @@ void my_log(Log *log, Dict *out) { dict_set(out, "num_goals_reached", log->num_goals_reached); dict_set(out, "avg_speed_per_agent", log->avg_speed_per_agent); dict_set(out, "dnf_rate", log->dnf_rate); + dict_set(out, "avg_distance_per_infraction", avg_distance_per_infraction); dict_set(out, "n", log->n); } diff --git a/src/vecenv.h b/src/vecenv.h index 067fda802..c88fe79b2 100644 --- a/src/vecenv.h +++ b/src/vecenv.h @@ -238,7 +238,7 @@ extern const char* cudaGetErrorString(cudaError_t); // Forward declare env-provided functions (defined in binding.c after this include) void my_init(Env* env, Dict* kwargs); -void my_log(Log* log, Dict* out); +void my_log(Log* log, Dict* out, float n); void my_env_constants(void* env, Dict* out); struct StaticThreading { @@ -643,7 +643,7 @@ void static_vec_log(StaticVec* vec, Dict* out) { for (int i = 0; i < vec->size; i++) { memset(&envs[i].log, 0, sizeof(Log)); } - my_log(&aggregate, out); + my_log(&aggregate, out, n); dict_set(out, "n", n); } @@ -653,7 +653,7 @@ void static_vec_eval_log(StaticVec* vec, Dict* out) { if (n == 0) { return; } - my_log(&aggregate, out); + my_log(&aggregate, out, n); dict_set(out, "n", n); }