Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion sim/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,20 @@ void my_init(Env *env, Dict *kwargs) {
}
}

void my_log(Log *log, Dict *out) {
void my_log(Log *log, Dict *out, float n) {
// static_vec_aggregate_logs divides every Log field by the total agent
// count n, so log->total_* are per-agent averages here. We want the fleet
// ratio sum_distance / sum_infractions. The n cancels in that ratio, so we
// could compute it directly from the averages, but we multiply by n anyway
// to recover the raw fleet totals. That makes the fmaxf(1.0f, ...) clamp
// meaningful: it floors the denominator at "1 infraction across the whole
// fleet", so a zero-infraction window reports the total fleet distance
// instead of distance / epsilon, which would be an absurdly large value.
float total_distance_travelled = log->total_distance_travelled * n;
float total_infractions = log->total_infractions * n;
float avg_distance_per_infraction =
total_distance_travelled / fmaxf(1.0f, total_infractions);

dict_set(out, "score", log->score);
dict_set(out, "episode_return", log->episode_return);
dict_set(out, "episode_length", log->episode_length);
Expand All @@ -191,6 +204,7 @@ void my_log(Log *log, Dict *out) {
dict_set(out, "num_goals_reached", log->num_goals_reached);
dict_set(out, "avg_speed_per_agent", log->avg_speed_per_agent);
dict_set(out, "dnf_rate", log->dnf_rate);
dict_set(out, "avg_distance_per_infraction", avg_distance_per_infraction);
dict_set(out, "n", log->n);
}

Expand Down
6 changes: 3 additions & 3 deletions src/vecenv.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ extern const char* cudaGetErrorString(cudaError_t);

// Forward declare env-provided functions (defined in binding.c after this include)
void my_init(Env* env, Dict* kwargs);
void my_log(Log* log, Dict* out);
void my_log(Log* log, Dict* out, float n);
void my_env_constants(void* env, Dict* out);

struct StaticThreading {
Expand Down Expand Up @@ -643,7 +643,7 @@ void static_vec_log(StaticVec* vec, Dict* out) {
for (int i = 0; i < vec->size; i++) {
memset(&envs[i].log, 0, sizeof(Log));
}
my_log(&aggregate, out);
my_log(&aggregate, out, n);
dict_set(out, "n", n);
}

Expand All @@ -653,7 +653,7 @@ void static_vec_eval_log(StaticVec* vec, Dict* out) {
if (n == 0) {
return;
}
my_log(&aggregate, out);
my_log(&aggregate, out, n);
dict_set(out, "n", n);
}

Expand Down
Loading