Update transolver to comply with model standards #1316
base: main
Conversation
Greptile Overview
Greptile Summary
This PR updates the Transolver model to comply with PhysicsNeMo model implementation standards by adding comprehensive documentation, type annotations, and validation logic.
Major changes:
- Added complete NumPy-style docstrings with proper sections (Parameters, Forward, Outputs, Examples) across all model classes and functions
- Added jaxtyping type annotations for all tensor arguments following MOD-006
- Added input validation with `torch.compiler.is_compiling()` guards in main forward methods following MOD-005
- Added high-level comments explaining complex tensor operations following MOD-003k
- Updated `pyproject.toml` to ignore F722 (allows jaxtyping syntax)
- Changed docstring prefixes from `"""` to `r"""` for LaTeX compatibility following MOD-003b
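The raw-docstring change matters because Python processes backslash escapes in ordinary string literals, which silently corrupts LaTeX commands. A minimal sketch of the effect (the function names are hypothetical and not taken from the PR):

```python
def plain_doc():
    """Output of shape :math:`(B, N, \theta)`."""


def raw_doc():
    r"""Output of shape :math:`(B, N, \theta)`."""


# In the plain docstring, "\t" is parsed as a TAB, so "\theta" is corrupted.
print(repr(plain_doc.__doc__))
# In the raw docstring, the backslash survives for the docs renderer.
print(repr(raw_doc.__doc__))
```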
Critical issues found:
- `MLP` and `Transolver_block` classes inherit from `nn.Module` instead of `physicsnemo.Module`, violating MOD-001. These classes should be updated to inherit from `physicsnemo.Module` to ensure access to serialization, versioning, and registry features.
- Missing input validation in `MLP.forward()` and `Transolver_block.forward()` methods (MOD-005 requirement)
Positive aspects:
- Excellent documentation quality with clear examples and cross-references
- Proper use of LaTeX math notation for tensor shapes
- Good high-level comments in complex tensor operations
- Consistent formatting and structure across all files
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| physicsnemo/models/transolver/transolver.py | 3/5 | Added comprehensive docstrings with jaxtyping annotations, input validation, and high-level comments. Found critical issue: MLP and Transolver_block inherit from nn.Module instead of physicsnemo.Module (violates MOD-001). |
| physicsnemo/models/transolver/Physics_Attention.py | 4/5 | Added comprehensive docstrings with proper sections, jaxtyping annotations, and input validation. All classes correctly inherit from nn.Module (appropriate for reusable layers per MOD-000a). |
| physicsnemo/models/transolver/Embedding.py | 4/5 | Added complete docstrings with proper NumPy-style sections, jaxtyping annotations, LaTeX math notation for tensor shapes, and Examples sections. All classes correctly inherit from nn.Module. |
| pyproject.toml | 5/5 | Added F722 to ruff ignore list (allows jaxtyping syntax) and removed trailing whitespace. Changes are appropriate for supporting jaxtyping annotations. |
    def forward(
        self, fx: Float[torch.Tensor, "batch tokens hidden"]
    ) -> Float[torch.Tensor, "batch tokens out"]:
missing validation for Transolver_block.forward() - per MOD-005, forward methods must validate tensor shapes
add validation wrapped in if not torch.compiler.is_compiling(): to check input shape matches expected (B, N, hidden_dim)
    def forward(
        self,
        coordinates: Float[torch.Tensor, "batch seq"],
        device: torch.device,
    ) -> Float[torch.Tensor, "batch seq dim"]:
missing validation for RotaryEmbedding.forward() - per MOD-005, forward methods should validate tensor shapes
consider adding validation to check that coordinates has expected 2D shape (B, N)
That's true, but shape validation for every single layer might be overkill. Shape validation as detailed in MOD-005 is more critical for larger models comprising multiple layers.
CharlelieLrt
left a comment
Mostly looks good, just a few minor things about formatting and such.
One bigger question though:
- For `Physics_Attention.py`, could it be moved to `physicsnemo/nn`, either with other attention layers in `physicsnemo/nn/attention_layers.py`, or just in its own module?
- Same question for `Embeddings.py`
    References
    ----------
    - `Transolver paper <https://arxiv.org/pdf/2402.02366>`_
    - `Transolver++ paper <https://arxiv.org/pdf/2502.02414>`_
AFAIK this "References" section will not be recognized and rendered properly in the online docs. Also, IMO this big docstring with examples and so on should not be in `__init__.py` but either in the class docstring or in the `docs/` directory (at least that's what other models do).
Why is this file (and a few others) named with upper case (e.g. Embedding.py, and Physics_Attention.py looks even worse)?
Historical reasons, the files came in like this externally. Now is a good time to fix.
We already have some embeddings in `physicsnemo/nn`, so would it make sense to move the content of this module there?
        return torch.cat((freqs, freqs), dim=-1)  # (B, N, D)


    def rotate_half(x: Float[torch.Tensor, "... dim"]) -> Float[torch.Tensor, "... dim"]:
Question about this `rotate_half`, and a few other functions defined in this module: are they meant to be exposed, or are they just internal helper functions for this module? If the latter, would it make sense to make them semi-private?
I honestly think they are not much used. I will make them semi-private, good idea.
    @@ -71,75 +75,195 @@


    class MLP(nn.Module):
Another MLP? Can it be replaced with FullyConnected or one that we already have? (maybe add a TE option in the existing FullyConnected?)
        return x


    class Transolver_block(nn.Module):
Should be TransolverBlock. The mixture of Pascal case and snake case is not allowed
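If existing user code imports the old name, one option is to keep a deprecated alias for a release or two. This is only a sketch of that idea, not something the PR does; the class body is a placeholder and the deprecation policy is an assumption:

```python
import warnings


class TransolverBlock:
    """Placeholder standing in for the renamed block."""

    def __init__(self, hidden_dim: int):
        self.hidden_dim = hidden_dim


class Transolver_block(TransolverBlock):
    """Deprecated alias kept only for backward compatibility."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "Transolver_block is deprecated; use TransolverBlock instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
```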
    @@ -334,53 +515,56 @@ def __init__(
            self.__name__ = "Transolver"
Do you know what this is for?
        t: Float[torch.Tensor, "... dim"],
        freqs: Float[torch.Tensor, "... dim"],
    ) -> Float[torch.Tensor, "... dim"]:
I am not sure the ellipsis ... is 100% correct in the jaxtyping there? The dimensions should be the same between all 3 tensors, right?
PhysicsNeMo Pull Request
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.