
Conversation

@coreyjadams
Collaborator

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@coreyjadams coreyjadams requested a review from ktangsali as a code owner January 8, 2026 19:46
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR updates the Transolver model to comply with PhysicsNeMo model implementation standards by adding comprehensive documentation, type annotations, and validation logic.

Major changes:

  • Added complete NumPy-style docstrings with proper sections (Parameters, Forward, Outputs, Examples) across all model classes and functions
  • Added jaxtyping type annotations for all tensor arguments following MOD-006
  • Added input validation with torch.compiler.is_compiling() guards in main forward methods following MOD-005
  • Added high-level comments explaining complex tensor operations following MOD-003k
  • Updated pyproject.toml to ignore F722 (allows jaxtyping syntax)
  • Changed docstring prefixes from """ to r""" for LaTeX compatibility following MOD-003b

Critical issues found:

  • MLP and Transolver_block classes inherit from nn.Module instead of physicsnemo.Module, violating MOD-001. These classes should be updated to inherit from physicsnemo.Module to ensure access to serialization, versioning, and registry features.
  • Missing input validation in MLP.forward() and Transolver_block.forward() methods (MOD-005 requirement)

Positive aspects:

  • Excellent documentation quality with clear examples and cross-references
  • Proper use of LaTeX math notation for tensor shapes
  • Good high-level comments in complex tensor operations
  • Consistent formatting and structure across all files

Important Files Changed

File Analysis

Filename Score Overview
physicsnemo/models/transolver/transolver.py 3/5 Added comprehensive docstrings with jaxtyping annotations, input validation, and high-level comments. Found critical issue: MLP and Transolver_block inherit from nn.Module instead of physicsnemo.Module (violates MOD-001).
physicsnemo/models/transolver/Physics_Attention.py 4/5 Added comprehensive docstrings with proper sections, jaxtyping annotations, and input validation. All classes correctly inherit from nn.Module (appropriate for reusable layers per MOD-000a).
physicsnemo/models/transolver/Embedding.py 4/5 Added complete docstrings with proper NumPy-style sections, jaxtyping annotations, LaTeX math notation for tensor shapes, and Examples sections. All classes correctly inherit from nn.Module.
pyproject.toml 5/5 Added F722 to ruff ignore list (allows jaxtyping syntax) and removed trailing whitespace. Changes are appropriate for supporting jaxtyping annotations.

Comment on lines +333 to +335
def forward(
    self, fx: Float[torch.Tensor, "batch tokens hidden"]
) -> Float[torch.Tensor, "batch tokens out"]:
Contributor

missing validation for Transolver_block.forward() - per MOD-005, forward methods must validate tensor shapes

add validation wrapped in if not torch.compiler.is_compiling(): to check input shape matches expected (B, N, hidden_dim)

Comment on lines +93 to +97
def forward(
    self,
    coordinates: Float[torch.Tensor, "batch seq"],
    device: torch.device,
) -> Float[torch.Tensor, "batch seq dim"]:
Contributor

missing validation for RotaryEmbedding.forward() - per MOD-005, forward methods should validate tensor shapes

consider adding validation to check that coordinates has expected 2D shape (B, N)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Collaborator

That's true, but shape validation for every single layer might be overkill. Shape validation as detailed in MOD-005 is more critical for larger models comprising multiple layers.

Collaborator

@CharlelieLrt CharlelieLrt left a comment

Mostly looks good, just a few minor things about formatting and such.
One bigger question though:

  • For Physics_Attention.py, could it be moved to physicsnemo/nn, either alongside the other attention layers in physicsnemo/nn/attention_layers.py, or into its own module?
  • Same question for Embedding.py

Comment on lines +40 to +43
References
----------
- `Transolver paper <https://arxiv.org/pdf/2402.02366>`_
- `Transolver++ paper <https://arxiv.org/pdf/2502.02414>`_
Collaborator

AFAIK this "References" section will not be recognized and rendered properly in the online docs. Also, IMO this big docstring with examples and so on should not live in __init__.py but either in the class docstring or in the docs/ directory (at least, that's what other models do).

Collaborator

Why is this file (and a few others) named with upper case (e.g. Embedding.py, and Physics_Attention.py looks even worse)?

Collaborator Author

Historical reasons, the files came in like this externally. Now is a good time to fix.

Collaborator

We already have some embeddings in physicsnemo/nn, so would it make sense to move the content of this module there?

    return torch.cat((freqs, freqs), dim=-1)  # (B, N, D)


def rotate_half(x: Float[torch.Tensor, "... dim"]) -> Float[torch.Tensor, "... dim"]:
Collaborator

Question about this rotate_half, and a few other functions defined in this module: are they meant to be exposed or are they just some internal helper functions for this module? If the former case, would it make sense to make them semi-private?

Collaborator Author

I honestly think they are not much used. I will make them semi-private, good idea.

@@ -71,75 +75,195 @@


class MLP(nn.Module):
Collaborator

Another MLP? Can it be replaced with FullyConnected or one that we already have? (maybe add a TE option in the existing FullyConnected?)

    return x


class Transolver_block(nn.Module):
Collaborator

Should be TransolverBlock. The mixture of Pascal case and snake case is not allowed

@@ -334,53 +515,56 @@ def __init__(
        self.__name__ = "Transolver"
Collaborator

Do you know what is this for?

Comment on lines +152 to +154
    t: Float[torch.Tensor, "... dim"],
    freqs: Float[torch.Tensor, "... dim"],
) -> Float[torch.Tensor, "... dim"]:
Collaborator

I am not sure the ellipsis ... is 100% correct in the jaxtyping there? The dimensions should be the same between all 3 tensors, right?
