@abhaybd abhaybd commented Jan 21, 2026

The JAX version of the pi0 modeling code correctly uses `config.action_dim` for the linear layer shapes. However, the corresponding PyTorch code hardcodes this as 32.

This works fine for pi0/pi05, since they use an action dim of 32. But if users want to, e.g., train a PaliGemma model with a different action dimension, the PyTorch code will error. This PR fixes the bug by taking the action dimension from the config.
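For illustration, here is a minimal shape-only sketch of the mismatch, written in plain Python so it runs without PyTorch. It is not the actual modeling code: `Config`, `width`, the layer names, and the helper functions are hypothetical stand-ins, and only `action_dim` and the hardcoded 32 come from the description above.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical stand-in for the model config; only the fields used here.
    action_dim: int = 32
    width: int = 1024  # assumed hidden size, chosen for illustration


def linear_shape(in_features: int, out_features: int) -> tuple[int, int]:
    # Weight shape of a linear layer in (out_features, in_features) order.
    return (out_features, in_features)


def shapes_hardcoded(config: Config) -> dict[str, tuple[int, int]]:
    # Buggy variant: 32 is baked in, so config.action_dim is ignored.
    return {
        "action_in_proj": linear_shape(32, config.width),
        "action_out_proj": linear_shape(config.width, 32),
    }


def shapes_from_config(config: Config) -> dict[str, tuple[int, int]]:
    # Fixed variant: both projections follow config.action_dim.
    return {
        "action_in_proj": linear_shape(config.action_dim, config.width),
        "action_out_proj": linear_shape(config.width, config.action_dim),
    }


def accepts(weight_shape: tuple[int, int], input_dim: int) -> bool:
    # A linear layer can consume a (..., input_dim) tensor only if
    # its in_features equals input_dim.
    return weight_shape[1] == input_dim
```

With the default `action_dim=32` both variants produce identical shapes, which is why pi0/pi05 are unaffected. With, say, `Config(action_dim=7)`, the hardcoded input projection no longer accepts a 7-dimensional action, while the config-driven one does.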

@jimmyt857 jimmyt857 removed their request for review January 21, 2026 23:28

abhaybd commented Jan 22, 2026

Fixes #714
