
Expand set of Mooncake rules#356

Open
lkdvos wants to merge 65 commits into main from ld-mooncakerules

Conversation

lkdvos (Member) commented Jan 20, 2026

Here I am porting a bunch of our ChainRules rules over to Mooncake.

In particular, I am trying to identify the core computational routines and write rules for those, rather than blindly porting the same set of methods.
For example, in ChainRules we overload rules for *(::Number, ::AbstractTensorMap), whereas in Mooncake we simply define a rule for scale!(::AbstractTensorMap, ::Number).
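To illustrate the idea behind a scale!-level rule, here is a minimal sketch on plain arrays (hypothetical names scale_fwd/scale_pullback; not the actual Mooncake rule, which operates on tensor maps):

```julia
# Minimal sketch of the reverse-mode rule behind scaling: for y = α * x,
#   dx = conj(α) * dy,   dα = real(⟨x, dy⟩).
scale_fwd(x, α) = α .* x

function scale_pullback(dy, x, α)
    dx = conj(α) .* dy                # cotangent w.r.t. x
    dα = real(sum(conj(x) .* dy))     # cotangent w.r.t. the scalar α
    return dx, dα
end
```

Defining the rule at this level covers every caller that eventually lowers to a scaling, instead of overloading each `*(::Number, ::AbstractTensorMap)`-style method separately.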

To do:

  • Index manipulations
  • VectorInterface
  • LinearAlgebra
  • TensorOperations
  • PlanarOperations

Requires #360 to be merged first!

codecov bot commented Jan 20, 2026

Codecov Report

❌ Patch coverage is 83.30206% with 89 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
ext/TensorKitMooncakeExt/tangent.jl 64.66% 47 Missing ⚠️
ext/TensorKitMooncakeExt/planaroperations.jl 0.00% 32 Missing ⚠️
ext/TensorKitMooncakeExt/indexmanipulations.jl 96.13% 7 Missing ⚠️
ext/TensorKitMooncakeExt/factorizations.jl 88.88% 3 Missing ⚠️
Files with missing lines Coverage Δ
ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl 100.00% <ø> (ø)
ext/TensorKitMooncakeExt/linalg.jl 100.00% <100.00%> (+100.00%) ⬆️
ext/TensorKitMooncakeExt/tensoroperations.jl 100.00% <100.00%> (+1.92%) ⬆️
ext/TensorKitMooncakeExt/utility.jl 85.71% <100.00%> (+42.85%) ⬆️
ext/TensorKitMooncakeExt/vectorinterface.jl 100.00% <100.00%> (ø)
src/factorizations/matrixalgebrakit.jl 97.05% <100.00%> (+0.02%) ⬆️
src/fusiontrees/manipulations.jl 86.30% <100.00%> (ø)
src/tensors/diagonal.jl 92.19% <100.00%> (+0.11%) ⬆️
src/tensors/indexmanipulations.jl 76.92% <100.00%> (+3.58%) ⬆️
ext/TensorKitMooncakeExt/factorizations.jl 88.88% <88.88%> (ø)
... and 3 more

... and 1 file with indirect coverage changes


kshyatt (Member) commented Jan 21, 2026

I can likely pick up some of the linalg ones if you like

lkdvos (Member, Author) commented Jan 21, 2026

I'll keep my progress committed and pushed, so feel free to push if you have something. If not, I'll just gradually keep adding rules whenever I'm waiting for other tests, so it shouldn't be a huge issue either way.

@lkdvos lkdvos marked this pull request as ready for review January 22, 2026 13:26
@lkdvos lkdvos requested review from Jutho and kshyatt January 22, 2026 13:26
kshyatt (Member) left a comment

One comment: it might be nice to put some of these pullbacks into shared files, like we did with MAK and TO, so that if/when we add Enzyme support, we can do so with a light touch

lkdvos (Member, Author) commented Jan 22, 2026

I definitely agree that it would be nicer to put this in a better form, but would you be okay with leaving that for a follow-up PR?
I already tried separating out some of the functions, so it should become easier to migrate this in the future.

What is preventing me from actually pulling this through, though, is that I am now also altering the primal computation in some places, specifically for constructions involving alpha.
The idea is that for a computation f!(out, args..., alpha, beta) = beta * out + alpha * f(args...), the pullback with respect to alpha is derived from f(args...) alone, so I can change the primal computation to f!_mod(out, args..., alpha, beta) = add!(out, f(args...), alpha, beta) and store the intermediate result. I.e., at the cost of an additional allocation and an in-place add!, I avoid having to recompute f in the reverse pass, but only when dalpha is required. (See e.g. the rule for mul!.)
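The modification above can be sketched on plain arrays (hypothetical names f_mod! and pullback_dα; the real rules work on tensor maps via add!):

```julia
# Sketch of the modified primal: compute f(args...) separately, fold it into
# out with (α, β), and keep the intermediate so the reverse pass can form
#   dα = real(⟨ΔC, f(args...)⟩)
# without recomputing f.
function f_mod!(f, out, args, α, β)
    tmp = f(args...)                  # intermediate, stored for the pullback
    out .= β .* out .+ α .* tmp       # the add!(out, tmp, α, β) step
    return out, tmp
end

pullback_dα(ΔC, tmp) = real(sum(conj(tmp) .* ΔC))
```

The extra allocation for tmp only needs to happen when the tangent of α is actually required; otherwise the fused in-place path can be kept.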

Without actually having the Enzyme code next to it, it's a bit hard to already come up with the correct abstractions to make sure this works for both engines, and I want to avoid having to do that work twice.

Additionally, it would be nice to immediately overload the TensorOperations functions, but these haven't been released yet (and I would like to play a similar trick there too, but haven't gotten around to that yet)

kshyatt (Member) commented Jan 22, 2026

I definitely agree that it would be nicer to put this in a better form, but would you be okay with leaving that for a follow-up PR?

Yeah that sounds fine, just separating things into discrete functions is great already

@lkdvos lkdvos enabled auto-merge (squash) January 22, 2026 15:53
@lkdvos lkdvos force-pushed the ld-mooncakerules branch 2 times, most recently from 0dd1456 to bd3cc11 Compare January 23, 2026 16:14
@lkdvos lkdvos requested a review from kshyatt January 23, 2026 16:14
lkdvos (Member, Author) commented Jan 29, 2026

Small update here:

  • This requires another MatrixAlgebraKit release to satisfy the Mooncake 0.5 compat
  • I'm adding a custom tangent type here because it turns out my test tolerances were a bit stupid. The finite-difference tests that Mooncake performs give wrong answers for non-abelian symmetries, in the same way that the ChainRules ones required overloading FiniteDifferences.to_vec: the inner product would just be the one on the raw data, rather than the one of the actual tensors.

This last part still really confuses me (I remember it also did with ChainRules), but I'm just going to assume we figured it out correctly last time and copy that approach here.
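For reference, the metric mismatch can be sketched as follows (a hedged illustration; dims stands in for the quantum dimensions that weight each block in the true tensor inner product):

```julia
# A plain dot on the raw block data treats every block with weight 1:
data_dot(blocks_t, blocks_s) =
    sum(sum(conj.(a) .* b) for (a, b) in zip(blocks_t, blocks_s))

# The tensor inner product weights each block by its quantum dimension dᵢ:
#   ⟨t, s⟩ = Σᵢ dᵢ ⟨tᵢ, sᵢ⟩
weighted_dot(dims, blocks_t, blocks_s) =
    sum(d * sum(conj.(a) .* b) for (d, (a, b)) in zip(dims, zip(blocks_t, blocks_s)))
```

For abelian symmetries all dᵢ = 1 and the two agree, which is why the finite-difference tests only break down in the non-abelian case.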

github-actions bot (Contributor) commented Jan 29, 2026

Your PR no longer requires formatting changes. Thank you for your contribution!

lkdvos (Member, Author) commented Jan 29, 2026

@lkdvos lkdvos force-pushed the ld-mooncakerules branch 4 times, most recently from 2331978 to 93c8ff1 Compare February 2, 2026 21:21
return Mooncake._rdata(Δβ)
# TODO: this result might be easier to compute as:
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
At = TO.tensortrace(A, p, q, false, One(), backend)
Jutho (Member) commented Feb 13, 2026

Can we follow a similar strategy for blas_contract! and trace_permute! as for mul!, i.e. if _needs_tangent(α), we compute the result of the trace/contraction separately, instead of directly adding it to C, and then reuse that result in pullback_Δα?

lkdvos (Member, Author):

I changed this here, but in the future we might want to experiment with the trade-off between memory and computation cost.
I have a feeling that if we were to really write blas_contract as permute-permute-gemm-permute, store the intermediates from that codepath, and then correctly carry out the reverse pass, this might be faster, as it should avoid re-permuting some of the tensors in the reverse pass. (E.g., you can already see that the combination ΔC, pΔC, false appears in both the pullback for A and the pullback for B, so effectively we are permuting this object twice.)

However, as this is not what we were using before, I didn't want to get into that in this initial implementation, as I think it requires some careful consideration. I've left a TODO comment to elaborate on this though.
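The repeated permutation noted above can be illustrated with a minimal matrix sketch (hypothetical name contract_pullbacks; the identity permutation stands in for pΔC):

```julia
# Both pullbacks of C = A * B need the same permuted copy of ΔC; computing it
# once and reusing it avoids permuting the cotangent twice.
function contract_pullbacks(ΔC, A, B, pΔC)
    ΔC′ = permutedims(ΔC, pΔC)   # computed once, shared by both pullbacks
    ΔA = ΔC′ * B'                # cotangent of A
    ΔB = A' * ΔC′                # cotangent of B
    return ΔA, ΔB
end
```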

Comment on lines 25 to 27
Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: TensorMap} = T
Mooncake.@foldable Mooncake.tangent_type(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A} =
TK.tensormaptype(S, N₁, N₂, Mooncake.tangent_type(A))
Member:

I didn't quite get the point of the two-arg version of this function from the Mooncake manual, as it is only mentioned in the "full interface" section without any details. Why does the two-arg version just use T <: TensorMap, whereas the one-arg version tries to be smart about the tangent type of the storage type?

lkdvos (Member, Author):

The signatures are a little different:

tangent_type(fdata_type, rdata_type) -> ttype
tangent_type(primal_type) -> ttype

For the first, we are using the fdata_type, which has already been converted, while for the second the argument is still a primal type, so the storage type has to be converted to a tangent type.

Comment on lines +45 to +46
Mooncake.zero_tangent_internal(t::TensorMap, c::Mooncake.MaybeCache) =
TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t))
Member:

Should this include something like

Suggested change
Mooncake.zero_tangent_internal(t::TensorMap, c::Mooncake.MaybeCache) =
TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t))
function Mooncake.zero_tangent_internal(t::TensorMap{T}, c::Mooncake.MaybeCache) where {T}
Tx = Mooncake.tangent_type(T)
Tx == Mooncake.NoTangent && return Mooncake.NoTangent()
return TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t))
end

to account for e.g. the T == Int case?

Member:

Although I guess that yields a TensorMap filled with NoTangent(), which is also what seems to happen for Vector{Int}. I got confused (similarly for tangent_type) by the examples in the "full implementation appendix".

lkdvos (Member, Author):

It might actually be reasonable to detect this already when generating the tangent type, and simply bypass everything and make that NoTangent to begin with. I can't actually make a Vector{NoTangent}, since that does not "fit" inside the TensorMap{<:Number} type restriction.
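The bypass suggested here could look roughly like this (a self-contained sketch with stand-in types, not Mooncake's actual API):

```julia
# Stand-in for Mooncake.NoTangent, to keep the sketch self-contained.
struct NoTangentSketch end

# Scalar types with no derivative information collapse to "no tangent"...
scalar_tangent_type(::Type{<:Integer}) = NoTangentSketch
scalar_tangent_type(::Type{T}) where {T <: AbstractFloat} = T

# ...and the container-level tangent type short-circuits accordingly, instead
# of trying to build a TensorMap-like container holding NoTangent entries.
tensormap_tangent_type(::Type{T}) where {T} =
    scalar_tangent_type(T) === NoTangentSketch ? NoTangentSketch : Vector{T}
```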

Member:

Will this now be fixed by catching the NoTangent case in the Mooncake.tangent_type definition? I don't know the internal structure of Mooncake, but it seems like Mooncake.zero_tangent_internal(primal, cache) could be called without needing to first call tangent_type(primal), since zero_tangent_internal could be expected to produce tangents of the correct type?

Comment on lines 150 to 156
getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3))

return if FieldName === 1 || FieldName === :data
dval = tangent(t).data
Dual(val, dval)
else # cannot be invalid fieldname since already called `getfield`
Dual(val, NoFData()), getfield_pullback
Member:

Is this pullback appearing here in frule! correct? This looks off.

lkdvos (Member, Author):

Indeed not; I'm now starting to think that these things aren't actually checked in the tangent test suite, so I'll still try to explicitly add some tests.

::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:DiagOrTensorMap}, ::CoDual{Val{FieldName}}
) where {FieldName}
val = getfield(primal(t), FieldName)
getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3))
Member:

Can you briefly explain what NoPullback does? I wasn't really able to understand the Mooncake docstring at this late hour.

lkdvos (Member, Author):

I think I was just using this wrong, and have now resolved this

Jutho (Member) commented Feb 14, 2026

Ok, I think I now went through everything but the test files. I think there are some remaining comments that need to be addressed.

Comment on lines +40 to +46
AB = if _needs_tangent(α)
AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator)
add!(C, AB, α, β)
else
TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
nothing
end
Member:

This doesn't seem correct. In the if case, doesn't AB get assigned to the output of add!, which is C?

Suggested change
AB = if _needs_tangent(α)
AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator)
add!(C, AB, α, β)
else
TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
nothing
end
if _needs_tangent(α)
AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator)
add!(C, AB, α, β)
else
TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
AB = nothing
end

Comment on lines +153 to +159
At = if _needs_tangent(α)
At = TO.tensortrace(A, p, q, false, One(), backend)
add!(C, A, α, β)
else
TensorKit.trace_permute!(C, A, p, q, α, β, backend)
nothing
end
Member:

Same comment:

Suggested change
At = if _needs_tangent(α)
At = TO.tensortrace(A, p, q, false, One(), backend)
add!(C, A, α, β)
else
TensorKit.trace_permute!(C, A, p, q, α, β, backend)
nothing
end
if _needs_tangent(α)
At = TO.tensortrace(A, p, q, false, One(), backend)
add!(C, A, α, β)
else
TensorKit.trace_permute!(C, A, p, q, α, β, backend)
At = nothing
end

Comment on lines +114 to +118
function Mooncake.primal_to_tangent_internal!!(t::TensorMap, p::TensorMap, c::Mooncake.MaybeCache)
data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c)
data === t.data || copy!(t.data, data)
return p
end
Member:

This seems identical to the function on line 102-106:

Suggested change
function Mooncake.primal_to_tangent_internal!!(t::TensorMap, p::TensorMap, c::Mooncake.MaybeCache)
data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c)
data === t.data || copy!(t.data, data)
return p
end

Comment on lines +132 to +133
Mooncake._dot_internal(::Mooncake.MaybeCache, t::TensorMap, s::TensorMap) = Float64(real(inner(t, s)))
Mooncake._dot_internal(::Mooncake.MaybeCache, t::DiagonalTensorMap, s::DiagonalTensorMap) = Float64(real(inner(t, s)))
Member:

Is the Float64 requirement a Mooncake specific thing?

Comment on lines +167 to +170
_field_symbol(f::Symbol) = f
_field_symbol(i::Int) = i == 1 ? :x : i == 2 ? :a : throw(ArgumentError("Invalid field index '$i' for type A."))
_field_symbol(::Type{Val{F}}) where {F} = _field_symbol(F)
_field_symbol(::Val{F}) where {F} = _field_symbol(F)
Member:

Is this used? This seems to come out of the Mooncake manual, including the type A?

Comment on lines +279 to +280
ddata′ = Mooncake.increment_rdata!!(ddata, Δt_rdata.data)
return NoRData(), NoRData(), ddata′, NoRData()
Member:

Can it ever happen that Δt_rdata is not a NoRData, as this is how Mooncake.rdata_type was defined?

Jutho (Member) left a comment

Left some more (fewer) comments and questions. Once these are addressed, I will go over the tests in some more detail, but I expect that this will be ready. Thanks for this massive PR, and for your patience 😄



Development

Successfully merging this pull request may close these issues.

Support for AD with Mooncake

3 participants