Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion brax/training/acting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
"""Brax training acting functions."""

import time
from typing import Callable, Sequence, Tuple
from typing import Callable, Optional, Sequence, Tuple

from brax import envs
from brax.training.types import Metrics
from brax.training.types import Policy
from brax.training.types import PolicyParams
from brax.training.types import PRNGKey
from brax.training.types import Transition
from jax.experimental import io_callback
import jax
import jax.numpy as jnp
import numpy as np

State = envs.State
Expand Down Expand Up @@ -58,6 +60,8 @@ def generate_unroll(
key: PRNGKey,
unroll_length: int,
extra_fields: Sequence[str] = (),
render_fn: Optional[Callable[[State], None]] = None,
should_render: jax.Array = jnp.array(False, dtype=bool),
) -> Tuple[State, Transition]:
"""Collect trajectories of given unroll_length."""

Expand All @@ -68,6 +72,14 @@ def f(carry, unused_t):
nstate, transition = actor_step(
env, state, policy, current_key, extra_fields=extra_fields
)

if render_fn is not None:

def render(state: State):
io_callback(render_fn, None, state)

jax.lax.cond(should_render, render, lambda s: None, nstate)

return (nstate, next_key), transition

(final_state, _), data = jax.lax.scan(
Expand Down Expand Up @@ -115,6 +127,7 @@ def generate_eval_unroll(
eval_policy_fn(policy_params),
key,
unroll_length=episode_length // action_repeat,
should_render=jnp.array(False, dtype=bool), # No rendering during eval
)[0]

self._generate_eval_unroll = jax.jit(generate_eval_unroll)
Expand Down
32 changes: 25 additions & 7 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ def train(
Returns:
Tuple of (make_policy function, network params, metrics)
"""
# If the environment exposes a `render_fn`, use it for real-time rendering during training.
render_fn = None
if hasattr(environment, 'render_fn'):
render_fn = environment.render_fn

assert batch_size * num_minibatches % num_envs == 0
_validate_madrona_args(
madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env
Expand Down Expand Up @@ -483,7 +488,7 @@ def convert_data(x: jnp.ndarray):
return (optimizer_state, params, key), metrics

def training_step(
carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t
carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t, should_render: jax.Array,
) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
training_state, state, key = carry
key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
Expand All @@ -504,6 +509,8 @@ def f(carry, unused_t):
current_key,
unroll_length,
extra_fields=('truncation', 'episode_metrics', 'episode_done'),
render_fn=render_fn,
should_render=should_render,
)
return (next_state, next_key), data

Expand Down Expand Up @@ -552,10 +559,13 @@ def f(carry, unused_t):
return (new_training_state, state, new_key), metrics

def training_epoch(
training_state: TrainingState, state: envs.State, key: PRNGKey
training_state: TrainingState, state: envs.State, key: PRNGKey, should_render: jax.Array,
) -> Tuple[TrainingState, envs.State, Metrics]:
training_step_partial = functools.partial(
training_step, should_render=should_render
)
(training_state, state, _), loss_metrics = jax.lax.scan(
training_step,
training_step_partial,
(training_state, state, key),
(),
length=num_training_steps_per_epoch,
Expand All @@ -567,12 +577,12 @@ def training_epoch(

# Note that this is NOT a pure jittable method.
def training_epoch_with_timing(
training_state: TrainingState, env_state: envs.State, key: PRNGKey
training_state: TrainingState, env_state: envs.State, key: PRNGKey, should_render: jax.Array,
) -> Tuple[TrainingState, envs.State, Metrics]:
nonlocal training_walltime
t = time.time()
training_state, env_state = _strip_weak_type((training_state, env_state))
result = training_epoch(training_state, env_state, key)
result = training_epoch(training_state, env_state, key, should_render)
training_state, env_state, metrics = _strip_weak_type(result)

metrics = jax.tree_util.tree_map(jnp.mean, metrics)
Expand Down Expand Up @@ -695,11 +705,19 @@ def training_epoch_with_timing(
logging.info('starting iteration %s %s', it, time.time() - xt)

for _ in range(max(num_resets_per_eval, 1)):
# Read the environment's render toggle on the host, then replicate it across local devices.
should_render_py = False
if hasattr(environment, 'should_render'):
should_render_py = bool(environment.should_render)

should_render_jax = jnp.array(should_render_py, dtype=bool)
should_render_replicated = jax.device_put_replicated(
should_render_jax, jax.local_devices()[:local_devices_to_use]
)

epoch_key, local_key = jax.random.split(local_key)
epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
(training_state, env_state, training_metrics) = (
training_epoch_with_timing(training_state, env_state, epoch_keys)
training_epoch_with_timing(training_state, env_state, epoch_keys, should_render_replicated)
)
current_step = int(_unpmap(training_state.env_steps))

Expand Down