-
Notifications
You must be signed in to change notification settings - Fork 16
Loosen types for Reactant #162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
@tpapp can this get a quick review? |
tpapp
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
questions, see comments
| dimension(::ScalarTransform) = 1 | ||
|
|
||
| function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int) | ||
| function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit surprised that you need this, since transform_with is called internally and index is an integer.
Can you explain what the actual type is that you need here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sometimes in Reactant you can get a TracedNumber{Integer} if I loop through this and the array x is a AbstractTraced array.
| end | ||
|
|
||
| function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int) | ||
| function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
| struct Identity <: ScalarTransform end | ||
|
|
||
| transform(::Identity, x::Real) = x | ||
| transform(::Identity, x::Number) = x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In principle I am OK with widening Real here and in other places, the intent is to exclude Complex. It unfortunate that Base does not have an intermediate type for this purpose, but we can use Number.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ya I agree it is very sad.
|
There is one final thing. For some of these operations, we will induce scalar indexing. I am not really sure what the best way to deal with that is, because I can't add |
|
For example function TV.transform_with(::TV.LogJac, t::TV.ScalarTransform, x::Reactant.AnyTracedRVector, index)
return TV.transform_and_logjac(t, @allowscalar x[index])..., index + 1
endrequires the |
|
I could wrap this in a Reactant extension, but it would result in a lot of code duplication. |
|
I am not familiar with the internals of Reactant, so this may be a silly suggestion, but that index could be any other type (eg a wrapper) if that would help. The entry point is |
|
But I am open to any other reasonable solution too. Eg have a |
Thanks for the great package. I use it for a bunch of research. Right now, I am working to adapt my code to use
Reactantso I can take advantage of various accelerators. One of Reactant's quirks is that it requires pretty generous type annotations. The base structures areTracedRNumbers <: NumberandTracedRArray <: AbstractArray. If I want to useTransformVariableswithReactant. I sadly need to widen some of the type annotations.This is a draft PR that uses one method to make it compatible with Reactant, but I am very open to suggestions!
There are still some issues with this PR, including the scalar indexing, which will require some additional thought.