Conversation

@ALjone
Contributor

@ALjone ALjone commented Nov 8, 2024

Currently, when auto-reset is enabled, the env is reset on every single step:

if self.auto_reset:
    obs_re, state_re = self.reset_env(key_reset, params)
    # Use lax.cond to efficiently choose between obs_re and obs_st
    obs = jax.lax.cond(
        done,
        lambda: obs_re,
        lambda: obs_st
    )
    state = jax.lax.cond(
        done,
        lambda: state_re,
        lambda: state_st
    )

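This is fairly expensive, because reset_env runs before the cond and its result is discarded whenever done is false. It can be avoided by calling the reset function inside the cond so that it only runs when needed, which saves around 504 calls to the reset function per game: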

if self.auto_reset:
    # Reset the env only if done to avoid generating new state every step
    obs, state = jax.lax.cond(
        done,
        lambda: self.reset_env(key_reset, params),
        lambda: (obs_st, state_st),
    )

I'm not a JAX expert, but as far as I can tell, the above example should work.

With this change, I observe an increase in steps per second of more than 30%.
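
For anyone who wants to check that the two variants behave the same, here is a minimal sketch on a toy environment. Everything in it (`reset_env`, `step_env`, `rollout`, the counter state, the episode length of 3) is hypothetical and only illustrates the pattern; it is not the actual Lux S3 code.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy env: the state is just a step counter, episodes last 3 steps.
def reset_env(key):
    state = jnp.zeros((), dtype=jnp.int32)
    obs = jax.random.uniform(key, (4,))
    return obs, state


def step_env(state):
    state = state + 1
    obs = jnp.full((4,), state, dtype=jnp.float32)
    done = state >= 3
    return obs, state, done


def step_always_reset(key, state):
    # Old pattern: reset_env runs on every step, cond only picks the result.
    obs_st, state_st, done = step_env(state)
    obs_re, state_re = reset_env(key)
    obs = jax.lax.cond(done, lambda: obs_re, lambda: obs_st)
    state = jax.lax.cond(done, lambda: state_re, lambda: state_st)
    return obs, state, done


def step_reset_in_cond(key, state):
    # New pattern: reset_env is only called inside the branch taken when done.
    obs_st, state_st, done = step_env(state)
    obs, state = jax.lax.cond(
        done,
        lambda: reset_env(key),
        lambda: (obs_st, state_st),
    )
    return obs, state, done


def rollout(step_fn, num_steps=5):
    # Run a jitted step function for a few steps and record the results.
    key = jax.random.PRNGKey(0)
    state = jnp.zeros((), dtype=jnp.int32)
    step_fn = jax.jit(step_fn)
    history = []
    for i in range(num_steps):
        obs, state, done = step_fn(jax.random.fold_in(key, i), state)
        history.append((obs, int(state), bool(done)))
    return history
```

Calling `rollout(step_always_reset)` and `rollout(step_reset_in_cond)` should give matching observations, states and done flags. One caveat: if the step function is wrapped in `jax.vmap` and `done` is batched, JAX converts `lax.cond` into a select and evaluates both branches, so the saving may not carry over to that setting.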

@netlify

netlify bot commented Nov 8, 2024

Deploy Preview for lux-eye-s3 canceled.

Name Link
🔨 Latest commit 2368823
🔍 Latest deploy log https://app.netlify.com/sites/lux-eye-s3/deploys/672e43c4bbbc480008feee87

@StoneT2000
Member

I am surprised JAX can't optimize this part out. I'll take a look and benchmark as well.

