Conversation
I can likely pick up some of the linalg ones if you like
I'll keep my progress committed and pushed, feel free to push if you have something. If not I'll just gradually keep adding some whenever I'm waiting for other tests, so it also shouldn't be a huge issue.
(force-pushed from 130f031 to 7a6adcc)
kshyatt left a comment:
One comment: it might be nice to put some of these pullbacks into shared files, like we did with MAK and TO, so that if/when we add Enzyme support, we can do so with a light touch.
I definitely agree that it would be nicer to put this in a better form, but would you be okay with leaving that for a follow-up PR? What is preventing me from actually pulling this through is that I am now also altering the primal computation in some places, specifically for constructions involving […]. Without actually having the Enzyme code next to it, it's a bit hard to already come up with the correct abstractions that work for both engines, and I want to avoid having to do that work twice. Additionally, it would be nice to immediately overload the TensorOperations functions, but these haven't been released yet (and I would like to play a similar trick there, but haven't gotten around to that yet).
Yeah that sounds fine, just separating things into discrete functions is great already |
(force-pushed from 0dd1456 to bd3cc11)
(force-pushed from 6037b99 to e2dc00e)
Small update here:
This last part still really confuses me (I remember it also did with the ChainRules), but I'm just going to assume we figured this out correctly last time and copy that here.
(force-pushed from e2dc00e to 3bf4288)
Your PR no longer requires formatting changes. Thank you for your contribution!
(force-pushed from 3bf4288 to fe3b92d)
(force-pushed from 2331978 to 93c8ff1)
(force-pushed from 74abd9d to 4f64074)
```julia
    return Mooncake._rdata(Δβ)
# TODO: this result might be easier to compute as:
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
At = TO.tensortrace(A, p, q, false, One(), backend)
```
Can we follow a similar strategy for `blas_contract!` and `trace_permute!` as for `mul!`, i.e. if `_needs_tangent(α)`, compute the result of the trace/contraction separately instead of directly adding it to `C`, and then reuse that result in `pullback_Δα`?
I changed this here, but in the future we might have to experiment further with the trade-off between memory and computation cost.
I have a feeling that if we were to really write `blas_contract` as permute-permute-gemm-permute, store the intermediates from that codepath, and then carry out the reverse pass accordingly, this might be faster, as it should avoid having to repermute some of the tensors in the reverse pass. (E.g., you can already see that the combination `ΔC, pΔC, false` appears in both the pullback for `A` and the one for `B`, so effectively we are permuting this object twice.)
However, as this is not what we were using before, I didn't want to get into that in this initial implementation, as I think it requires some careful consideration. I've left a TODO comment to elaborate on this.
ext/TensorKitMooncakeExt/tangent.jl (outdated)
```julia
Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: TensorMap} = T
Mooncake.@foldable Mooncake.tangent_type(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A} =
    TK.tensormaptype(S, N₁, N₂, Mooncake.tangent_type(A))
```
I didn't quite get the point of the two-arg version of this function from the Mooncake manual, as it is only mentioned in the "full interface" section without any details. Why does the two-arg version just use `T <: TensorMap`, whereas the one-arg version tries to be smart about the tangent type of the storage type?
The signatures are a little different:

```julia
tangent_type(fdata_type, rdata_type) -> ttype
tangent_type(primal_type) -> ttype
```

so for the first we are given the fdata type, which has already been converted, while the second still takes a primal type, so the storage has to be converted to the tangent type.
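As a hypothetical illustration of the two signatures (using plain `Vector`, for which the behaviour comes from Mooncake itself rather than this PR; treat the exact calls as an assumption):

```julia
using Mooncake

# One-arg form: takes a primal type, so the storage type still has
# to be mapped to its tangent type.
Mooncake.tangent_type(Vector{Float64})

# Two-arg form: takes the already-converted fdata and rdata types.
Mooncake.tangent_type(Vector{Float64}, Mooncake.NoRData)
```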
```julia
Mooncake.zero_tangent_internal(t::TensorMap, c::Mooncake.MaybeCache) =
    TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t))
```
Should this include something like
```julia
function Mooncake.zero_tangent_internal(t::TensorMap{T}, c::Mooncake.MaybeCache) where {T}
    Tx = Mooncake.tangent_type(T)
    Tx == Mooncake.NoTangent && return Mooncake.NoTangent()
    return TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t))
end
```
to account for e.g. the T == Int case?
Although I guess that yields a `TensorMap` filled with `NoTangent()`, which is also what seems to happen for `Vector{Int}`. I got confused (similarly for `tangent_type`) by the examples in the "full implementation appendix".
It might actually be reasonable to detect this already in the tangent type generation case, and simply bypass everything and make that NoTangent to begin with. I can't actually make a Vector{NoTangent}, since that does not "fit" inside the TensorMap{<:Number} type restriction.
Will this now be fixed by catching the NoTangent case in the Mooncake.tangent_type definition? I don't know the internal structure of Mooncake, but it seems like Mooncake.zero_tangent_internal(primal, cache) could be called without needing to first call tangent_type(primal), since zero_tangent_internal could be expected to produce tangents of the correct type?
ext/TensorKitMooncakeExt/tangent.jl (outdated)
```julia
getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3))

return if FieldName === 1 || FieldName === :data
    dval = tangent(t).data
    Dual(val, dval)
else # cannot be invalid fieldname since already called `getfield`
    Dual(val, NoFData()), getfield_pullback
```
Is this pullback appearing here in `frule!` correct? It looks off.
Indeed not; I'm now starting to think that these things aren't actually checked in the tangent test suite, so I'll still try and explicitly add some tests.
ext/TensorKitMooncakeExt/tangent.jl (outdated)
```julia
    ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:DiagOrTensorMap}, ::CoDual{Val{FieldName}}
) where {FieldName}
    val = getfield(primal(t), FieldName)
    getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3))
```
Can you briefly explain what `NoPullback` does? I wasn't really able to understand the Mooncake docstring at this late hour.
I think I was just using this wrong, and have now resolved this.
Ok, I think I now went through everything but the test files. I think there are some remaining comments that need to be addressed.
(force-pushed from 7d44198 to 80c1333)
```julia
AB = if _needs_tangent(α)
    AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator)
    add!(C, AB, α, β)
else
    TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
    nothing
end
```
This doesn't seem correct: in the `if` case, doesn't `AB` get assigned to the output of `add!`, which is `C`?
Suggested change:

```julia
if _needs_tangent(α)
    AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator)
    add!(C, AB, α, β)
else
    TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
    AB = nothing
end
```
```julia
At = if _needs_tangent(α)
    At = TO.tensortrace(A, p, q, false, One(), backend)
    add!(C, A, α, β)
else
    TensorKit.trace_permute!(C, A, p, q, α, β, backend)
    nothing
end
```
Same comment:
Suggested change:

```julia
if _needs_tangent(α)
    At = TO.tensortrace(A, p, q, false, One(), backend)
    add!(C, A, α, β)
else
    TensorKit.trace_permute!(C, A, p, q, α, β, backend)
    At = nothing
end
```
```julia
function Mooncake.primal_to_tangent_internal!!(t::TensorMap, p::TensorMap, c::Mooncake.MaybeCache)
    data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c)
    data === t.data || copy!(t.data, data)
    return p
end
```
This seems identical to the function on lines 102-106:

```julia
function Mooncake.primal_to_tangent_internal!!(t::TensorMap, p::TensorMap, c::Mooncake.MaybeCache)
    data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c)
    data === t.data || copy!(t.data, data)
    return p
end
```
```julia
Mooncake._dot_internal(::Mooncake.MaybeCache, t::TensorMap, s::TensorMap) = Float64(real(inner(t, s)))
Mooncake._dot_internal(::Mooncake.MaybeCache, t::DiagonalTensorMap, s::DiagonalTensorMap) = Float64(real(inner(t, s)))
```
Is the `Float64` requirement a Mooncake-specific thing?
```julia
_field_symbol(f::Symbol) = f
_field_symbol(i::Int) = i == 1 ? :x : i == 2 ? :a : throw(ArgumentError("Invalid field index '$i' for type A."))
_field_symbol(::Type{Val{F}}) where {F} = _field_symbol(F)
_field_symbol(::Val{F}) where {F} = _field_symbol(F)
```
Is this used? It seems to come straight out of the Mooncake manual, including the type `A`?
```julia
ddata′ = Mooncake.increment_rdata!!(ddata, Δt_rdata.data)
return NoRData(), NoRData(), ddata′, NoRData()
```
Can it ever happen that `Δt_rdata` is not a `NoRData`, given how `Mooncake.rdata_type` was defined?
Here I am porting over a bunch of our ChainRules rules to Mooncake ones.
In particular, I am trying to identify the core computational routines and write the rules for those, rather than blindly porting the same methods.
For example, in ChainRules we overload rules for `*(::Number, ::AbstractTensorMap)`, while in Mooncake we simply define a rule for `scale!(::AbstractTensorMap, ::Number)`.

To do:
Requires #360 to be merged first!
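To make the scalar-multiplication rule concrete, here is a rough ChainRules-style sketch for plain real arrays (a hypothetical `scale_rrule`, not TensorKit's actual rule; the real/array restrictions are simplifications for illustration). The Mooncake counterpart would instead target the in-place `scale!`.

```julia
using LinearAlgebra

# Sketch of a reverse rule for y = α * t: the pullback maps the output
# cotangent Δy back to cotangents for the scalar α and the array t.
function scale_rrule(α::Real, t::AbstractArray{<:Real})
    y = α * t
    function scale_pullback(Δy)
        Δα = dot(t, Δy)   # ⟨t, Δy⟩: sensitivity of y with respect to the scalar
        Δt = α * Δy       # sensitivity of y with respect to the array
        return Δα, Δt
    end
    return y, scale_pullback
end
```

For instance, `scale_rrule(2.0, [1.0, 2.0])` yields `[2.0, 4.0]`, and its pullback applied to `[1.0, 1.0]` gives `(3.0, [2.0, 2.0])`.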