diff --git a/Project.toml b/Project.toml index 5a4dbc8b97..ecf75ab49c 100644 --- a/Project.toml +++ b/Project.toml @@ -97,7 +97,7 @@ Functors = "0.5" GPUArraysCore = "0.2" LinearAlgebra = "1.10" LossFunctions = "0.11.1, 1" -LuxCore = "1.4" +LuxCore = "1.4" # XXX: bump to 1.5 before merge LuxLib = "1.11.0" MLDataDevices = "1.10.0" MLUtils = "0.4.4" diff --git a/ext/LuxReactantExt/tracing.jl b/ext/LuxReactantExt/tracing.jl index 1c5b4bb7a2..2fbac4d706 100644 --- a/ext/LuxReactantExt/tracing.jl +++ b/ext/LuxReactantExt/tracing.jl @@ -16,10 +16,12 @@ function Reactant.make_tracer( seen, @nospecialize(model::StatefulLuxLayer), @nospecialize(path), mode; kwargs... ) return StatefulLuxLayer( - model.model, - Reactant.make_tracer(seen, model.ps, (path..., :ps), mode; kwargs...), - Reactant.make_tracer(seen, model.st, (path..., :st), mode; kwargs...), - Reactant.make_tracer(seen, model.st_any, (path..., :st_any), mode; kwargs...), - model.fixed_state_type, + getfield(model, :model), + Reactant.make_tracer(seen, getfield(model, :ps), (path..., :ps), mode; kwargs...), + Reactant.make_tracer(seen, getfield(model, :st), (path..., :st), mode; kwargs...), + Reactant.make_tracer( + seen, getfield(model, :st_any), (path..., :st_any), mode; kwargs... + ), + getfield(model, :fixed_state_type), ) end diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index f60121ab64..72b996ef87 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.4.1" +version = "1.5.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl index c431299523..eb905de1c4 100644 --- a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -41,10 +41,22 @@ function Functors.functor( ) recon = let ft = x.fixed_state_type nt -> LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer( - nt.model, nt.ps, nt.st, nt.st_any, ft + getfield(nt, :model), + getfield(nt, :ps), + getfield(nt, :st), + getfield(nt, :st_any), + ft, ) end - return (; x.model, x.ps, x.st, x.st_any), recon + return ( + (; + model=getfield(x, :model), + ps=getfield(x, :ps), + st=getfield(x, :st), + st_any=getfield(x, :st_any), + ), + recon, + ) end end diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl index e3ba51f40c..9d0026b932 100644 --- a/lib/LuxCore/src/LuxCore.jl +++ b/lib/LuxCore/src/LuxCore.jl @@ -151,9 +151,7 @@ this include: type stability. By default this is "disable"d. For more information, see the [documentation](https://github.com/MilesCranmer/DispatchDoctor.jl). """ -@stable default_mode = "disable" function apply(model::AbstractLuxLayer, x, ps, st) - return model(x, ps, st) -end +function apply end """ stateless_apply(model, x, ps) @@ -162,9 +160,7 @@ Calls `apply` and only returns the first argument. This function requires that ` an empty state of `NamedTuple()`. Behavior of other kinds of models are undefined and it is the responsibility of the user to ensure that the model has an empty state. 
""" -function stateless_apply(model::AbstractLuxLayer, x, ps) - return first(apply(model, x, ps, Internal.get_empty_state(model))) -end +function stateless_apply end """ display_name(layer::AbstractLuxLayer) @@ -265,10 +261,6 @@ function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer} return statelength(getfield(l, layer)) end -function (l::AbstractLuxWrapperLayer{layer})(x, ps, st) where {layer} - return apply(getfield(l, layer), x, ps, st) -end - # Test Mode """ testmode(st::NamedTuple) @@ -357,6 +349,7 @@ preserves_state_type(l::Tuple) = all(preserves_state_type, l) preserves_state_type(l::NamedTuple) = all(preserves_state_type, values(l)) include("stateful.jl") +include("apply.jl") module Internal diff --git a/lib/LuxCore/src/apply.jl b/lib/LuxCore/src/apply.jl new file mode 100644 index 0000000000..ff24c77712 --- /dev/null +++ b/lib/LuxCore/src/apply.jl @@ -0,0 +1,24 @@ +@stable default_mode = "disable" function apply(model::AbstractLuxLayer, x, ps, st) + return model(x, ps, st) +end + +function stateless_apply(model::AbstractLuxLayer, x, ps) + return first(apply(model, x, ps, Internal.get_empty_state(model))) +end + +# New Interface that circumvents having to manage the state manually +function (model::AbstractLuxLayer)(x, ps, st) + xs = x isa Tuple ? x : (x,) + smodel = StatefulLuxLayerImpl.NamedTupleStatefulLuxLayer(model, ps, st) + output = apply(typeof(model), smodel, xs...) + return output, StatefulLuxLayerImpl.get_states_as_namedtuple(smodel) +end + +# fallback for wrapped layers +function apply(::Type{<:AbstractLuxWrapperLayer}, model, x) + return apply(only(getfield(model, :smodels)), x) +end + +function apply(::Type{<:AbstractLuxWrapperLayer}, model, xs...) + return apply(only(getfield(model, :smodels)), xs) +end diff --git a/lib/LuxCore/src/stateful.jl b/lib/LuxCore/src/stateful.jl index 6468af19c9..e0adf962fc 100644 --- a/lib/LuxCore/src/stateful.jl +++ b/lib/LuxCore/src/stateful.jl @@ -1,6 +1,10 @@ module StatefulLuxLayerImpl -using ..LuxCore: AbstractLuxLayer, preserves_state_type +using ..LuxCore: + AbstractLuxLayer, + AbstractLuxContainerLayer, + AbstractLuxWrapperLayer, + preserves_state_type const StaticBool = Union{Val{true},Val{false}} @@ -27,6 +31,28 @@ mutable struct StatefulLuxLayer{ST,M<:AbstractLuxLayer,psType,stType} end end +function Base.getproperty(l::StatefulLuxLayer, s::Symbol) + s === :st && return get_state(l) + s === :st_any && throw( + ArgumentError( + "No property `st_any` for `StatefulLuxLayer`. To access the `st_any` field use \ + `getfield(l, :st_any)` instead.", + ), + ) + return getfield(l, s) +end + +function Base.setproperty!(l::StatefulLuxLayer, s::Symbol, v) + s === :st && return set_state!(l, v) + s === :st_any && throw( + ArgumentError( + "No property `st_any` for `StatefulLuxLayer`. 
To set the `st_any` field use \ + `setfield!(l, :st_any, v)` instead.", + ), + ) + return setfield!(l, s, v) +end + function StatefulLuxLayer{ST}(model, ps, st, st_any) where {ST} return StatefulLuxLayer(model, ps, st, st_any, static(ST)) end @@ -43,8 +69,8 @@ function StatefulLuxLayer(model::AbstractLuxLayer, ps, st) return StatefulLuxLayer{preserves_state_type(model)}(model, ps, st) end -get_state(l::StatefulLuxLayer{Val{true}}) = l.st -get_state(l::StatefulLuxLayer{Val{false}}) = l.st_any +get_state(l::StatefulLuxLayer{Val{true}}) = getfield(l, :st) +get_state(l::StatefulLuxLayer{Val{false}}) = getfield(l, :st_any) # Needed for compact macro implementation get_state(st::AbstractArray{<:Number}) = st get_state(st::Union{AbstractArray,Tuple,NamedTuple}) = map(get_state, st) @@ -53,7 +79,7 @@ get_state(st) = st function set_state!( s::StatefulLuxLayer{Val{true},<:Any,<:Any,stType}, st::stType ) where {stType} - return s.st = st + return setfield!(s, :st, st) end function set_state!( ::StatefulLuxLayer{Val{true},<:Any,<:Any,stType}, ::stType2 @@ -66,9 +92,203 @@ function set_state!( defined for all the layers in the model.") ) end -set_state!(s::StatefulLuxLayer{Val{false}}, st) = (s.st_any = st) +set_state!(s::StatefulLuxLayer{Val{false}}, st) = setfield!(s, :st_any, st) + +# This is an internal implementation detail for bypassing manual state management +struct DualFieldNamedTuple{F1,F2,DM,D<:Tuple} + is_data_mutable::DM + data::D +end + +function get_backing_data(l::DualFieldNamedTuple{F1,F2,Val{false}}) where {F1,F2} + return getfield(l, :data) +end +function get_backing_data(l::DualFieldNamedTuple{F1,F2,Val{true}}) where {F1,F2} + return getindex.(getfield(l, :data)) +end + +@generated function prefixed_fieldnames( + ::NamedTuple{fields}, ::Val{prefix} +) where {fields,prefix} + new_fields = Tuple(QuoteNode.(Symbol.(Ref(prefix), fields))) + return :($(Tuple(new_fields)...),) +end + +function DualFieldNamedTuple( + data::NamedTuple{fields}, + ::Val{prefix}, + is_data_mutable::StaticBool=Val(false), + is_data_type_fixed::StaticBool=Val(true), +) where {fields,prefix} + @assert length(fields) == length(data) + + if dynamic(is_data_mutable) + if dynamic(is_data_type_fixed) + backing = Ref.(values(data)) + else + backing = Ref{Any}.(values(data)) + end + BT = typeof(backing) + else + @assert dynamic(is_data_type_fixed) + backing = values(data) + BT = typeof(backing) + end + + return DualFieldNamedTuple{ + fields,prefixed_fieldnames(data, Val(prefix)),typeof(is_data_mutable),BT + }( + is_data_mutable, backing + ) +end + +function DualFieldNamedTuple(data, args...) + fnames = fieldnames(typeof(data)) + return DualFieldNamedTuple(NamedTuple{fnames}(getfield.(Ref(data), fnames)), args...) +end + +first_property_names(::DualFieldNamedTuple{F1}) where {F1} = F1 + +Base.propertynames(::DualFieldNamedTuple{F1,F2}) where {F1,F2} = (F1..., F2...) 
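+
+# Illustrative access pattern (hypothetical values; nothing in this diff depends on it):
+# for `dfnt = DualFieldNamedTuple((; weight=w, bias=b), Val(:ps_))`, both `dfnt.weight`
+# and `dfnt.ps_weight` resolve to `w`, so layer code can keep using the original field
+# names while the prefixed aliases keep parameters, model fields, and states from
+# colliding inside `NamedTupleStatefulLuxLayer`.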
+ +function findfieldidx(::DualFieldNamedTuple{F1,F2}, s::Symbol) where {F1,F2} + idx = findfirst(==(s), F1) + idx !== nothing && return idx + return findfirst(==(s), F2) +end + +maybe_unwrap_ref(::Val{false}, x) = x +maybe_unwrap_ref(::Val{true}, x::Ref) = x[] + +function Base.getproperty(l::DualFieldNamedTuple{F1,F2}, s::Symbol) where {F1,F2} + idx = findfieldidx(l, s) + idx !== nothing && return maybe_unwrap_ref( + getfield(l, :is_data_mutable), getfield(getfield(l, :data), idx) + ) + throw(ArgumentError("No property $s for `DualFieldNamedTuple`")) +end + +function Base.setproperty!( + ::DualFieldNamedTuple{F1,F2,Val{false}}, s::Symbol, v +) where {F1,F2} + throw(ArgumentError("Cannot set property $s for `DualFieldNamedTuple` since the data \ + was constructed as immutable.")) +end + +function Base.setproperty!( + l::DualFieldNamedTuple{F1,F2,Val{true}}, s::Symbol, v +) where {F1,F2} + idx = findfieldidx(l, s) + idx === nothing && throw(ArgumentError("No property $s for `DualFieldNamedTuple`")) + getfield(l, :data)[idx][] = v + return l +end + +struct NamedTupleStatefulLuxLayer{layers,NT<:NamedTuple,M,NTPS,NTST} + smodels::NT + model::M + ps_extra::NTPS + st_extra::NTST +end + +function get_states_as_namedtuple(l::NamedTupleStatefulLuxLayer{layers}) where {layers} + smodels = getfield(l, :smodels) + return NamedTuple{(layers..., first_property_names(getfield(l, :st_extra))...)}(( + (getfield(smodels, layer).st for layer in layers)..., + get_backing_data(getfield(l, :st_extra))..., + ),) +end + +function NamedTupleStatefulLuxLayer{fields}( + smodels, model::AbstractLuxLayer, ps::NamedTuple, st::NamedTuple +) where {fields} + model_dfnt = DualFieldNamedTuple(model, Val(:model_)) + ps_extra = DualFieldNamedTuple(ps, Val(:ps_)) + st_extra = DualFieldNamedTuple( + st, Val(:st_), Val(true), Val(preserves_state_type(model)) + ) + return NamedTupleStatefulLuxLayer{ + fields,typeof(smodels),typeof(model_dfnt),typeof(ps_extra),typeof(st_extra) + }( + smodels, model_dfnt, ps_extra, st_extra + ) +end + +function NamedTupleStatefulLuxLayer(model::AbstractLuxLayer, ps, st) + return NamedTupleStatefulLuxLayer{()}((;), model, ps, st) +end + +@generated function NamedTupleStatefulLuxLayer( + model::AbstractLuxWrapperLayer{layer}, ps, st +) where {layer} + layers = (layer,) + smodel_expr = [:(StatefulLuxLayer(model.$(layer), ps, st))] + return quote + return NamedTupleStatefulLuxLayer{$(layers)}( + NamedTuple{$(layers)}(($(Tuple(smodel_expr)...),)), + model, + NamedTuple(), + NamedTuple(), + ) + end +end + +@generated function NamedTupleStatefulLuxLayer( + model::AbstractLuxContainerLayer{layers}, ps::PS, st::ST +) where {layers,PS,ST} + ps_extra_fields = Tuple(setdiff(fieldnames(PS), layers)) + st_extra_fields = Tuple(setdiff(fieldnames(ST), layers)) + + smodel_expr = [ + :(StatefulLuxLayer(model.$(layer), ps.$(layer), st.$(layer))) for layer in layers + ] + + ps_get_expr = [:(getfield(ps, $(QuoteNode(f)))) for f in ps_extra_fields] + st_get_expr = [:(getfield(st, $(QuoteNode(f)))) for f in st_extra_fields] + + return quote + smodels = NamedTuple{$(layers)}(($(Tuple(smodel_expr)...),)) + ps_extra = NamedTuple{$(ps_extra_fields)}(($(Tuple(ps_get_expr)...),)) + st_extra = NamedTuple{$(st_extra_fields)}(($(Tuple(st_get_expr)...),)) + return NamedTupleStatefulLuxLayer{$(layers)}(smodels, model, ps_extra, st_extra) + end +end + +function Base.propertynames(l::NamedTupleStatefulLuxLayer{layers}) where {layers} + return ( + layers..., + propertynames(getfield(l, :ps_extra))..., + propertynames(getfield(l, 
:st_extra))..., + propertynames(getfield(l, :model))..., + ) +end + +function Base.getproperty(l::NamedTupleStatefulLuxLayer{layers}, s::Symbol) where {layers} + s in layers && return getfield(getfield(l, :smodels), s) + hasproperty(getfield(l, :model), s) && return getproperty(getfield(l, :model), s) + hasproperty(getfield(l, :ps_extra), s) && return getproperty(getfield(l, :ps_extra), s) + hasproperty(getfield(l, :st_extra), s) && return getproperty(getfield(l, :st_extra), s) + throw(ArgumentError("No property $(s) for `NamedTupleStatefulLuxLayer`")) +end + +function Base.setproperty!( + l::NamedTupleStatefulLuxLayer{layers}, s::Symbol, v +) where {layers} + s in layers && return throw(ArgumentError("Cannot set property $(s) for \ + `NamedTupleStatefulLuxLayer` since it is a \ + container layer.")) + hasproperty(getfield(l, :model), s) && + throw(ArgumentError("Cannot set property $(s) for `NamedTupleStatefulLuxLayer` \ + since it is a model field.")) + hasproperty(getfield(l, :ps_extra), s) && + throw(ArgumentError("Cannot set property $(s) for `NamedTupleStatefulLuxLayer` \ + since it is a parameter.")) + hasproperty(getfield(l, :st_extra), s) && + return setproperty!(getfield(l, :st_extra), s, v) + throw(ArgumentError("No property $(s) for NamedTupleStatefulLuxLayer")) +end -export StatefulLuxLayer +export StatefulLuxLayer, NamedTupleStatefulLuxLayer end @@ -155,7 +375,7 @@ preserves_state_type(m::StatefulLuxLayer) = StatefulLuxLayerImpl.dynamic(m.fixed function (m::StatefulLuxLayer)(x, p=m.ps) @assert p !== nothing "Model parameters are not set in constructor. Pass in `ps` \ explicitly." - y, st = apply(m.model, x, p, StatefulLuxLayerImpl.get_state(m)) + y, st = apply(m.model, x, p, m.st) StatefulLuxLayerImpl.set_state!(m, st) return y end diff --git a/src/Lux.jl b/src/Lux.jl index 0b672bd6ba..0df34c012f 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -41,7 +41,8 @@ import LuxCore: setup, apply, replicate, - preserves_state_type + preserves_state_type, + display_name @reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers using NNlib: diff --git a/src/extended_ops.jl b/src/extended_ops.jl index 7ceb588f23..3204fcafb6 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -266,7 +266,7 @@ for (op, field) in ( :track_stats => :track_stats, :train_state => :train_state, ) - @eval function $(Symbol(:has_, op))(l::AbstractLuxLayer) + @eval function $(Symbol(:has_, op))(l) res = known(safe_getproperty(l, Val($(Meta.quot(field))))) return ifelse(res === nothing, false, res) end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 05b18d93f9..5aeed1d6bc 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -162,29 +162,28 @@ function MultiHeadAttention( ) end -(mha::MultiHeadAttention)(x, ps, st::NamedTuple) = apply_multiheadattention(mha, ps, st, x) -function (mha::MultiHeadAttention)(x::Tuple, ps, st::NamedTuple) - return apply_multiheadattention(mha, ps, st, x...) 
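+# With the boilerplate-free interface, one positional input means self-attention
+# (q == k == v), two mean (q, kv), and three mean (q, k, v), with an optional fourth
+# argument as the mask. Illustrative call (hypothetical `smha` wrapping `mha` with its
+# ps/st): `apply(typeof(mha), smha, q, kv)`.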
+function apply(T::Type{<:MultiHeadAttention}, mha, qkv::AbstractArray) + return apply(T, mha, qkv, qkv, qkv, nothing) end -function apply_multiheadattention(mha::MultiHeadAttention, ps, st, qkv) - return apply_multiheadattention(mha, ps, st, qkv, qkv, qkv, nothing) +function apply(T::Type{<:MultiHeadAttention}, mha, q::AbstractArray, kv::AbstractArray) + return apply(T, mha, q, kv, kv, nothing) end -function apply_multiheadattention(mha::MultiHeadAttention, ps, st, q, kv) - return apply_multiheadattention(mha, ps, st, q, kv, kv, nothing) -end - -function apply_multiheadattention(mha::MultiHeadAttention, ps, st, q, k, v, mask=nothing) - q, k, v = match_eltype(mha, ps, st, q, k, v) +function apply( + ::Type{<:MultiHeadAttention}, + mha, + q::AbstractArray, + k::AbstractArray, + v::AbstractArray, + mask=nothing, +) + # XXX: restore `match_eltype` support + # q, k, v = match_eltype(mha, ps, st, q, k, v) - q, q_st = mha.q_proj(q, ps.q_proj, st.q_proj) - k, k_st = mha.k_proj(k, ps.k_proj, st.k_proj) - v, v_st = mha.v_proj(v, ps.v_proj, st.v_proj) - - dropout = StatefulLuxLayer( - mha.attention_dropout, ps.attention_dropout, st.attention_dropout - ) + q = mha.q_proj(q) + k = mha.k_proj(k) + v = mha.v_proj(v) x, α = scaled_dot_product_attention( reshape(q, size(q, 1) ÷ mha.nheads, mha.nheads, size(q)[2:end]...), @@ -192,22 +191,10 @@ function apply_multiheadattention(mha::MultiHeadAttention, ps, st, q, k, v, mask reshape(v, size(v, 1) ÷ mha.nheads, mha.nheads, size(v)[2:end]...); head_dim=1, token_dim=3, - fdrop=dropout, mask, + fdrop=mha.attention_dropout, mha.is_causal, ) - x = reshape(x, size(x, 1) * mha.nheads, size(x)[3:end]...) - - y, out_st = mha.out_proj(x, ps.out_proj, st.out_proj) - - return ( - (y, α), - (; - q_proj=q_st, - k_proj=k_st, - v_proj=v_st, - attention_dropout=dropout.st, - out_proj=out_st, - ), - ) + + return mha.out_proj(reshape(x, size(x, 1) * mha.nheads, size(x)[3:end]...)), α end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index c6b3bad7c6..c79eed776e 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -48,16 +48,16 @@ struct ReshapeLayer{N} <: AbstractLuxLayer dims::NTuple{N,Int} end -outputsize(r::ReshapeLayer, _, ::AbstractRNG) = r.dims - -function (r::ReshapeLayer)(x::AbstractArray, _, st::NamedTuple) - return reshape(x, r.dims..., size(x, ndims(x))), st -end - function Base.show(io::IO, r::ReshapeLayer) return print(io, "ReshapeLayer(output_dims = (", join(r.dims, ", "), ", :))") end +outputsize(r::ReshapeLayer, _, ::AbstractRNG) = r.dims + +function apply(::Type{<:ReshapeLayer}, layer, x::AbstractArray) + return reshape(x, layer.dims..., size(x, ndims(x))) +end + """ ReverseSequence(dim = nothing) @@ -100,20 +100,18 @@ end ReverseSequence(dim) = ReverseSequence(static(dim)) ReverseSequence(; dim=nothing) = ReverseSequence(static(dim)) -function (r::ReverseSequence{Nothing})(x::AbstractArray, _, st::NamedTuple) - return safe_reverse(x; dims=max(ndims(x) - 1, 1)), st +function apply(::Type{ReverseSequence{Nothing}}, layer, x::AbstractArray) + return safe_reverse(x; dims=max(ndims(x) - 1, 1)) end -function (r::ReverseSequence{StaticInt{1}})(x::AbstractVector, _, st::NamedTuple) - return safe_reverse(x), st -end +apply(::Type{ReverseSequence{StaticInt{1}}}, layer, x::AbstractVector) = safe_reverse(x) -function (r::ReverseSequence{StaticInt{N}})(::AbstractVector, _, st::NamedTuple) where {N} +function apply(::Type{ReverseSequence{StaticInt{N}}}, layer, ::AbstractVector) where {N} throw(ArgumentError("Cannot specify a dimension ($(N) != 1) for 
AbstractVector")) end -function (r::ReverseSequence{StaticInt{N}})(x::AbstractArray, _, st::NamedTuple) where {N} - return safe_reverse(x; dims=N), st +function apply(::Type{ReverseSequence{StaticInt{N}}}, layer, x::AbstractArray) where {N} + return safe_reverse(x; dims=N) end """ @@ -160,13 +158,13 @@ end FlattenLayer(N) = FlattenLayer(static(N)) FlattenLayer(; N=nothing) = FlattenLayer(static(N)) -function (::FlattenLayer{Nothing})(x::AbstractArray{T,N}, _, st::NamedTuple) where {T,N} - return reshape(x, :, size(x, N)), st +function apply(::Type{FlattenLayer{Nothing}}, layer, x::AbstractArray{T,N}) where {T,N} + return reshape(x, :, size(x, N)) end -function (f::FlattenLayer)(x::AbstractArray{T,N}, _, st::NamedTuple) where {T,N} - @argcheck f.N < N - return reshape(x, :, size(x)[(f.N + 1):end]...), st +function apply(::Type{FlattenLayer}, layer, x::AbstractArray{T,N}) where {T,N} + @argcheck layer.N < N + return reshape(x, :, size(x)[(layer.N + 1):end]...) end """ @@ -196,17 +194,17 @@ views. index <: Union{StaticInt,AbstractVector} end +function Base.show(io::IO, s::SelectDim) + return print(io, "SelectDim(dim = ", s.dim, ", index = ", s.index, ")") +end + SelectDim(dim::Integer, index::Integer) = SelectDim(static(dim), static(index)) SelectDim(dim::Integer, index::AbstractVector) = SelectDim(static(dim), index) -function (s::SelectDim{D,<:StaticInt})(x, _, st::NamedTuple) where {D} - return selectdim(x, known(s.dim), known(s.index)), st -end -(s::SelectDim)(x, _, st::NamedTuple) = selectdim(x, known(s.dim), s.index), st - -function Base.show(io::IO, s::SelectDim) - return print(io, "SelectDim(dim = ", s.dim, ", index = ", s.index, ")") +function apply(::Type{SelectDim{D,<:StaticInt}}, layer, x) where {D} + return selectdim(x, known(layer.dim), known(layer.index)) end +apply(::Type{<:SelectDim}, layer, x) = selectdim(x, known(layer.dim), layer.index) """ NoOpLayer() @@ -232,7 +230,8 @@ julia> y, st_new = model(x, ps, st) """ struct NoOpLayer <: AbstractLuxLayer end -(noop::NoOpLayer)(x, _, st::NamedTuple) = x, st +apply(::Type{NoOpLayer}, layer, x) = x +apply(::Type{NoOpLayer}, layer, x, xs...) = (x, xs...) """ WrappedFunction(f) @@ -259,10 +258,11 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be func <: Function end -(wf::WrappedFunction)(x, ps, st::NamedTuple{}) = wf.func(x), st - Base.show(io::IO, w::WrappedFunction) = print(io, "WrappedFunction(", w.func, ")") +apply(::Type{WrappedFunction{F}}, layer, x) where {F} = layer.func(x) +apply(::Type{WrappedFunction{F}}, layer, x, xs...) where {F} = layer.func((x, xs...)) + """ Dense(in_dims => out_dims, activation=identity; init_weight=nothing, init_bias=nothing, use_bias=True()) @@ -356,14 +356,14 @@ function outputsize(d::Dense, x::AbstractArray, ::AbstractRNG) return (d.out_dims, size(x)[2:(end - 1)]...) 
end

-function (d::Dense)(x::AbstractArray, ps, st::NamedTuple)
-    y = match_eltype(d, ps, st, x)
-    bias = safe_getproperty(ps, Val(:bias))
-    σ = NNlib.fast_act(d.activation, x)
-    z = matrix_to_array(
-        fused_dense_bias_activation(σ, ps.weight, make_abstract_matrix(y), bias), y
+function apply(::Type{<:Dense}, layer, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # y = match_eltype(d, ps, st, x)
+    bias = safe_getproperty(layer, Val(:bias))
+    σ = NNlib.fast_act(layer.activation, x)
+    return matrix_to_array(
+        fused_dense_bias_activation(σ, layer.weight, make_abstract_matrix(x), bias), x
    )
-    return z, st
end

"""
@@ -443,15 +443,18 @@ statelength(d::Scale) = 0

outputsize(d::Scale, _, ::AbstractRNG) = d.dims

-function (d::Scale{False})(x::AbstractArray, ps, st::NamedTuple)
-    y = match_eltype(d, ps, st, x)
-    σ = NNlib.fast_act(d.activation, y)
-    return @.(σ(y .* ps.weight)), st
+function apply(::Type{<:Scale{False}}, layer, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # y = match_eltype(d, ps, st, x)
+    σ = NNlib.fast_act(layer.activation, x)
+    return @.(σ(x .* layer.weight))
end
-function (d::Scale{True})(x::AbstractArray, ps, st::NamedTuple)
-    y = match_eltype(d, ps, st, x)
-    σ = NNlib.fast_act(d.activation, y)
-    return @.(σ(y * ps.weight + ps.bias)), st
+
+function apply(::Type{<:Scale{True}}, layer, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # y = match_eltype(d, ps, st, x)
+    σ = NNlib.fast_act(layer.activation, x)
+    return @.(σ(x * layer.weight + layer.bias))
end

"""
@@ -562,33 +565,26 @@ statelength(b::Bilinear) = 0

outputsize(b::Bilinear, _, ::AbstractRNG) = (b.out_dims,)

-function (b::Bilinear)(
-    (x, y)::Tuple{<:AbstractVecOrMat,<:AbstractVecOrMat}, ps, st::NamedTuple
-)
-    s₁, s₂, s₃ = size(ps.weight)
+function apply(::Type{<:Bilinear}, layer, x::AbstractVecOrMat, y::AbstractVecOrMat)
+    s₁, s₂, s₃ = size(layer.weight)
    @argcheck s₂ == size(x, 1) && s₃ == size(y, 1)
    @argcheck size(x, 2) == size(y, 2)

-    Wy = reshape(reshape(ps.weight, (:, s₃)) * y, (s₁, s₂, :))
+    Wy = reshape(reshape(layer.weight, (:, s₃)) * y, (s₁, s₂, :))
    Wyx = reshape(batched_matmul(Wy, reshape(x, (s₂, 1, :))), (s₁, :))

-    σ = NNlib.fast_act(b.activation, Wyx)
-    return bias_activation!!(σ, Wyx, safe_getproperty(ps, Val(:bias))), st
+    σ = NNlib.fast_act(layer.activation, Wyx)
+    return bias_activation!!(σ, Wyx, safe_getproperty(layer, Val(:bias)))
end

-function (b::Bilinear)((x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st::NamedTuple)
+function apply(T::Type{<:Bilinear}, layer, x::AbstractArray, y::AbstractArray)
    @argcheck size(x)[2:end] == size(y)[2:end]
-
-    s₁, s₂, s₃ = size(ps.weight)
-    x′ = reshape(x, s₂, :)
-    y′ = reshape(y, s₃, :)
-
-    z, stₙ = b((x′, y′), ps, st)
-
-    return reshape(z, s₁, size(x)[2:end]...), stₙ
+    s₁, s₂, s₃ = size(layer.weight)
+    z = apply(T, layer, reshape(x, s₂, :), reshape(y, s₃, :))
+    return reshape(z, s₁, size(x)[2:end]...)
end -(b::Bilinear)(x::AbstractArray, ps, st::NamedTuple) = b((x, x), ps, st) +apply(T::Type{<:Bilinear}, layer, x::AbstractArray) = apply(T, layer, x, x) """ AlternatePrecision{T}(layer) @@ -617,9 +613,12 @@ end AlternatePrecision(::Type{T}, layer) where {T} = AlternatePrecision{T}(layer) -LuxCore.display_name(::AlternatePrecision{T}) where {T} = "AlternatePrecision{$T}" +display_name(::AlternatePrecision{T}) where {T} = "AlternatePrecision{$T}" + +function apply(::Type{<:AlternatePrecision{T}}, model, x::AbstractArray{T}) where {T} + return model.layer(x) +end -function (model::AlternatePrecision{T})(x::AbstractArray{T2}, ps, st) where {T,T2} - y, stₙ = model.layer(T.(x), ps, st) - return T2.(y), stₙ +function apply(::Type{<:AlternatePrecision{T}}, model, x::AbstractArray{T2}) where {T,T2} + return T2.(model.layer(T.(x))) end diff --git a/src/layers/containers.jl b/src/layers/containers.jl index c500b995db..01feee2fd5 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -1,3 +1,7 @@ +# NOTE: layers in this file can't use the low-boilerplate apply API since these layers +# often don't conform to the `AbstractLuxWrapperLayer` / `AbstractLuxContainerLayer` +# interface. Instead they define all the interface functions themselves. + """ SkipConnection(layers, connection; name=nothing) SkipConnection(; layers, connection, name=nothing) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index c58e7acd1f..cac88a1ade 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -250,15 +250,17 @@ function parameterlength(c::Conv) return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs end -function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) - y = match_eltype(c, ps, st, x) +function apply(::Type{<:Conv}, c, x::AbstractArray) + # XXX: restore `match_eltype` support + # y = match_eltype(c, ps, st, x) + cdims = construct_crosscor_convdims( c.cross_correlation, - DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups), + DenseConvDims(x, c.weight; c.stride, padding=c.pad, c.dilation, c.groups), + ) + return fused_conv_bias_activation( + NNlib.fast_act(c.activation, x), c.weight, x, safe_getproperty(c, Val(:bias)), cdims ) - bias = safe_getproperty(ps, Val(:bias)) - σ = NNlib.fast_act(c.activation, y) - return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st end function Base.show(io::IO, l::Conv) @@ -424,17 +426,20 @@ function parameterlength(c::ConvTranspose) return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs end -function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) - y = match_eltype(c, ps, st, x) +function apply(::Type{<:ConvTranspose}, c, x::AbstractArray) + # XXX: restore `match_eltype` support + # y = match_eltype(c, ps, st, x) cdims = construct_crosscor_convdims( c.cross_correlation, conv_transpose_dims( - y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups, c.outpad + x, c.weight; c.stride, padding=c.pad, c.dilation, c.groups, c.outpad ), ) - bias = safe_getproperty(ps, Val(:bias)) - σ = NNlib.fast_act(c.activation, y) - return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st + return bias_activation!!( + NNlib.fast_act(c.activation, x), + conv_transpose(x, c.weight, cdims), + safe_getproperty(c, Val(:bias)), + ) end function Base.show(io::IO, l::ConvTranspose) @@ -528,11 +533,11 @@ end Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) -function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) - return 
lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st
+function apply(::Type{<:Upsample}, m, x::AbstractArray)
+    return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners)
end
-function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple)
-    return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st
+function apply(::Type{<:Upsample{Nothing}}, m, x::AbstractArray)
+    return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners)
end

for interp in (:bilinear, :trilinear)
@@ -541,12 +546,12 @@ for interp in (:bilinear, :trilinear)
        function lux_upsample_scale_dispatch(
            ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners
        )
-            return $(nnlib_interp_func)(x, scale)
+            return $(nnlib_interp_func)(x, scale; align_corners)
        end
        function lux_upsample_size_dispatch(
            ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners
        )
-            return $(nnlib_interp_func)(x; size)
+            return $(nnlib_interp_func)(x; size, align_corners)
        end
    end
end
diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl
index d4c74e518c..ac7d1d4299 100644
--- a/src/layers/dropout.jl
+++ b/src/layers/dropout.jl
@@ -50,9 +50,10 @@ function AlphaDropout(p::T) where {T<:Real}
    return AlphaDropout(p, α, γ, β)
end

-function (d::AlphaDropout)(x, _, st::NamedTuple)
-    y, rng = alpha_dropout(st.rng, x, d.p, st.training, d.alpha, d.scale, d.bias)
-    return y, (; rng, st.training)
+function apply(::Type{<:AlphaDropout}, d, x::AbstractArray)
+    y, rng = alpha_dropout(d.rng, x, d.p, d.training, d.alpha, d.scale, d.bias)
+    d.rng = rng
+    return y
end

Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")")
@@ -106,9 +107,10 @@ function Dropout(p; dims=:)
    return Dropout(p, 1 / (1 - p), dims)
end

-function (d::Dropout)(x, _, st::NamedTuple)
-    y, _, rng = dropout(st.rng, x, d.p, st.training, d.q, d.dims)
-    return y, (; rng, st.training)
+function apply(::Type{<:Dropout}, d, x::AbstractArray)
+    y, _, rng = dropout(d.rng, x, d.p, d.training, d.q, d.dims)
+    d.rng = rng
+    return y
end

function Base.show(io::IO, d::Dropout)
@@ -177,6 +179,8 @@ function VariationalHiddenDropout(p; dims=:)
    return VariationalHiddenDropout(p, 1 / (1 - p), dims)
end

+# Note that we don't use `apply` here. While we support non-fixed state types, that
+# API is inherently type-unstable.
function (d::VariationalHiddenDropout)(x, _, st::NamedTuple)
    maskₒ = st.mask === nothing ? x : st.mask
    y, mask, rng = dropout(st.rng, x, maskₒ, d.p, st.training, st.update_mask, d.q, d.dims)
diff --git a/src/layers/embedding.jl b/src/layers/embedding.jl
index dc99eedcba..33daaf866d 100644
--- a/src/layers/embedding.jl
+++ b/src/layers/embedding.jl
@@ -57,32 +57,29 @@ end

outputsize(e::Embedding, _, ::AbstractRNG) = (e.out_dims,)

-function (e::Embedding)(x::Union{Number,AbstractVector}, ps, st::NamedTuple)
+function apply(::Type{<:Embedding}, e, x::Union{Number,AbstractVector})
    @argcheck Utils.eltype(x) <: Integer
-    return ps.weight[:, x], st
+    return e.weight[:, x]
end
-function (e::Embedding)(x::AbstractArray, ps, st::NamedTuple)
+function apply(T::Type{<:Embedding}, e, x::AbstractArray)
    @argcheck Utils.eltype(x) <: Integer
-    y, stₙ = e(Utils.vec(x), ps, st)
-    return reshape(y, :, size(x)...), stₙ
+    y = apply(T, e, Utils.vec(x))
+    return reshape(y, :, size(x)...)
end
-function (e::Embedding)(x::NTuple{N,T}, ps, st::NamedTuple) where {N,T}
+function apply(::Type{<:Embedding}, e, x::T...) where {T}
    @argcheck Utils.eltype(T) <: Integer
-    return ps.weight[:, x...], st
+    return e.weight[:, x...]
end
-function (e::Embedding)(x::NTuple{N,<:AbstractVector{T}}, ps, st::NamedTuple) where {N,T}
+
+function apply(::Type{<:Embedding}, e, x::AbstractVector{T}...) where {T}
    @argcheck Utils.eltype(T) <: Integer
    @argcheck allequal(size, x) DimensionMismatch("Input vectors must have the same shape")
-    return NNlib.gather(ps.weight, x...), st
+    return NNlib.gather(e.weight, x...)
end
-function (e::Embedding)(x::NTuple{N,<:AbstractArray{T}}, ps, st::NamedTuple) where {N,T}
-    @argcheck Utils.eltype(T) <: Integer
+function apply(T::Type{<:Embedding}, e, x::AbstractArray...)
    @argcheck allequal(size, x) DimensionMismatch("Input arrays must have the same shape")
-    y, stₙ = e(vec.(x), ps, st)
-    return reshape(y, :, size(first(x))...), stₙ
-end
-function (e::Embedding)(::Tuple{}, _, ::NamedTuple)
-    throw(ArgumentError("Input tuple must contain at least one element"))
+    y = apply(T, e, vec.(x)...)
+    return reshape(y, :, size(first(x))...)
end

@doc doc"""
@@ -145,10 +142,11 @@ function initialstates(::AbstractRNG, spe::SinusoidalPositionalEmbedding{T}) whe
    return (; sigmas)
end

-function (spe::SinusoidalPositionalEmbedding)(x::AbstractArray, ps, st::NamedTuple)
-    y = reshape(match_eltype(spe, ps, st, x), 1, size(x)...) .* st.sigmas
-    z = vcat(sin.(y), cos.(y)) .* spe.scale
-    return z, st
+function apply(::Type{<:SinusoidalPositionalEmbedding}, spe, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(spe, x)
+    y = reshape(x, 1, size(x)...) .* spe.sigmas
+    return vcat(sin.(y), cos.(y)) .* spe.scale
end

"""
@@ -220,16 +218,14 @@ function initialstates(::AbstractRNG, rope::RotaryPositionalEmbedding)
    )
end

-function (rope::RotaryPositionalEmbedding)(
-    x::AbstractArray{T,4}, ps, st::NamedTuple
-) where {T}
-    y = apply_rotary_embedding(x, st.cos_cache, st.sin_cache; seq_dim=3)
-    return y, st
+function apply(::Type{<:RotaryPositionalEmbedding}, rope, x::AbstractArray{T,4}) where {T}
+    return apply_rotary_embedding(x, rope.cos_cache, rope.sin_cache; seq_dim=3)
end

-function (rope::RotaryPositionalEmbedding)((x, input_pos)::Tuple, ps, st::NamedTuple)
-    y = apply_rotary_embedding(x, input_pos, st.cos_cache, st.sin_cache; seq_dim=3)
-    return y, st
+function apply(
+    ::Type{<:RotaryPositionalEmbedding}, rope, x::AbstractArray{T,4}, input_pos
+) where {T}
+    return apply_rotary_embedding(x, input_pos, rope.cos_cache, rope.sin_cache; seq_dim=3)
end

## Functional variants since Qwen3 like models tend to share the same rotary embedding
diff --git a/src/layers/extension.jl b/src/layers/extension.jl
index faeccc7e99..03cd2c002c 100644
--- a/src/layers/extension.jl
+++ b/src/layers/extension.jl
@@ -76,14 +76,11 @@ function Base.show(io::IO, ::MIME"text/plain", s::SimpleChainsLayer)
    return PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer", s.lux_layer)
end

-function (sc::SimpleChainsLayer)(x, ps, st)
-    y = match_eltype(sc, ps, st, x)
-    return (
-        to_array(
-            sc.to_array,
-            apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x)),
-        ),
-        st,
+function apply(::Type{<:SimpleChainsLayer}, sc, x)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(sc, x)
+    return to_array(
+        sc.to_array, apply_simple_chain(sc.layer, x, sc.params, MLDataDevices.get_device(x))
    )
end

diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index 433b34fb61..714e6f137e 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -134,41 +134,40 @@
parameterlength(l::BatchNorm) = ifelse(has_affine(l), l.chs * 2, 0) statelength(l::BatchNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1 -function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) +function apply(::Type{<:BatchNorm}, bn, x::AbstractArray) CRC.ignore_derivatives() do - if st.training isa Val{true} + if bn.training isa Val{true} @argcheck size(x, ndims(x)) != 1 "Batch size for BatchNorm cannot be 1 during training" end end - x′ = match_eltype(BN, ps, st, x) - σ = NNlib.fast_act(BN.activation, x′) + # XXX: restore `match_eltype` support + # x′ = match_eltype(bn, ps, st, x) + x′ = x + σ = NNlib.fast_act(bn.activation, x′) y, stats = batchnorm( x′, - safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), - safe_getproperty(st, Val(:running_mean)), - safe_getproperty(st, Val(:running_var)), - st.training, + safe_getproperty(bn, Val(:scale)), + safe_getproperty(bn, Val(:bias)), + safe_getproperty(bn, Val(:running_mean)), + safe_getproperty(bn, Val(:running_var)), + bn.training, σ, - convert(unwrapped_eltype(x′), BN.momentum), - convert(unwrapped_eltype(x′), BN.epsilon), + convert(unwrapped_eltype(x′), bn.momentum), + convert(unwrapped_eltype(x′), bn.epsilon), ) - return y, update_batchnorm_state(BN, st, stats) + update_batchnorm_state!(bn, stats) + return y end -function update_batchnorm_state(BN::BatchNorm, st::NamedTuple, stats) - has_track_stats(BN) && return merge( - st, - (; - running_mean=Utils.vec(stats.running_mean), - running_var=Utils.vec(stats.running_var), - ), - ) - return st +function update_batchnorm_state!(bn, stats) + has_track_stats(bn) || return nothing + bn.running_mean = Utils.vec(stats.running_mean) + bn.running_var = Utils.vec(stats.running_var) + return nothing end -CRC.@non_differentiable update_batchnorm_state(::Any...) +CRC.@non_differentiable update_batchnorm_state!(::Any...) function Base.show(io::IO, l::BatchNorm) print(io, "BatchNorm($(l.chs)") @@ -280,18 +279,20 @@ end parameterlength(l::GroupNorm) = has_affine(l) ? 
(l.chs * 2) : 0

-function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(GN, ps, st, x)
-    σ = NNlib.fast_act(GN.activation, x′)
+function apply(::Type{<:GroupNorm}, gn, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(GN, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(gn.activation, x′)
    y = groupnorm(
        x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        GN.groups,
+        safe_getproperty(gn, Val(:scale)),
+        safe_getproperty(gn, Val(:bias)),
+        gn.groups,
        σ,
-        convert(unwrapped_eltype(x′), GN.epsilon),
+        convert(unwrapped_eltype(x′), gn.epsilon),
    )
-    return y, st
+    return y
end

function Base.show(io::IO, l::GroupNorm)
@@ -434,35 +435,34 @@ end
parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0)
statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1

-function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(IN, ps, st, x)
-    σ = NNlib.fast_act(IN.activation, x′)
+function apply(::Type{<:InstanceNorm}, in, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(IN, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(in.activation, x′)
    y, stats = instancenorm(
        x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        safe_getproperty(st, Val(:running_mean)),
-        safe_getproperty(st, Val(:running_var)),
-        st.training,
+        safe_getproperty(in, Val(:scale)),
+        safe_getproperty(in, Val(:bias)),
+        safe_getproperty(in, Val(:running_mean)),
+        safe_getproperty(in, Val(:running_var)),
+        in.training,
        σ,
-        convert(unwrapped_eltype(x′), IN.momentum),
-        convert(unwrapped_eltype(x′), IN.epsilon),
+        convert(unwrapped_eltype(x′), in.momentum),
+        convert(unwrapped_eltype(x′), in.epsilon),
    )
-    return y, update_instancenorm_state(IN, st, stats)
+    update_instancenorm_state!(in, stats)
+    return y
end

-function update_instancenorm_state(IN::InstanceNorm, st::NamedTuple, stats)
-    has_track_stats(IN) && return merge(
-        st,
-        (;
-            running_mean=Utils.vec(stats.running_mean),
-            running_var=Utils.vec(stats.running_var),
-        ),
-    )
-    return st
+function update_instancenorm_state!(in, stats)
+    has_track_stats(in) || return nothing
+    in.running_mean = Utils.vec(stats.running_mean)
+    in.running_var = Utils.vec(stats.running_var)
+    return nothing
end

-CRC.@non_differentiable update_instancenorm_state(::Any...)
+CRC.@non_differentiable update_instancenorm_state!(::Any...)

function Base.show(io::IO, l::InstanceNorm)
    print(io, "InstanceNorm($(l.chs)")
@@ -554,18 +554,20 @@ function initialparameters(rng::AbstractRNG, ln::LayerNorm)
    return (;)
end

-function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(l, ps, st, x)
+function apply(::Type{<:LayerNorm}, l, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(l, ps, st, x)
+    x′ = x
    σ = NNlib.fast_act(l.activation, x′)
    y = layernorm(
        x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
+        safe_getproperty(l, Val(:scale)),
+        safe_getproperty(l, Val(:bias)),
        σ,
        l.dims,
        convert(unwrapped_eltype(x′), l.epsilon),
    )
-    return y, st
+    return y
end

function Base.show(io::IO, l::LayerNorm)
@@ -781,17 +783,14 @@ parameterlength(l::RMSNorm) = has_affine(l) ? prod(l.normalized_shape) : 0

# specialization on `NT` is important here, else we won't be able to infer the
# correct eltype of the output.
-function (rms::RMSNorm)(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
+function apply(::Type{<:RMSNorm}, rms, x::AbstractArray{T}) where {T}
    # Don't use `match_eltype` here, since often times the eltypes are intentionally
    # different.
    ϵ = T(rms.epsilon)
    mean_sq = mean(abs2, x; dims=1:length(rms.normalized_shape))
    if has_affine(rms)
-        norm_x = @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * ps.scale
-    else
-        norm_x = @. x * LuxOps.rsqrt(mean_sq + ϵ)
+        return @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * rms.scale
    end
-
-    return norm_x, st
+    return @. x * LuxOps.rsqrt(mean_sq + ϵ)
end
diff --git a/test/helpers/stateful_tests.jl b/test/helpers/stateful_tests.jl
index b15a736e1d..039179c6c4 100644
--- a/test/helpers/stateful_tests.jl
+++ b/test/helpers/stateful_tests.jl
@@ -63,15 +63,15 @@
        @test smodel.st.training == 2

        smodel = StatefulLuxLayer{false}(model, ps, st)
-        @test smodel.st_any.training isa Val{true}
+        @test smodel.st.training isa Val{true}
        smodel = LuxCore.testmode(smodel)
-        @test smodel.st_any.training isa Val{false}
+        @test smodel.st.training isa Val{false}
        smodel = LuxCore.trainmode(smodel)
-        @test smodel.st_any.training isa Val{true}
+        @test smodel.st.training isa Val{true}
        smodel = LuxCore.update_state(smodel, :training, 2)
-        @test smodel.st_any.training == 2
+        @test smodel.st.training == 2
    end
end
diff --git a/test/reactant/tracing_tests.jl b/test/reactant/tracing_tests.jl
index 8c77b41edb..7813f29a81 100644
--- a/test/reactant/tracing_tests.jl
+++ b/test/reactant/tracing_tests.jl
@@ -9,14 +9,14 @@
        @test get_device_type(smodel_ra.ps) <: ReactantDevice
        @test get_device_type(smodel_ra.st) <: ReactantDevice
-        @test smodel_ra.st_any === nothing
+        @test getfield(smodel_ra, :st_any) === nothing
        @test smodel_ra.fixed_state_type == smodel.fixed_state_type

        smodel = StatefulLuxLayer{false}(model, ps, st)
        smodel_ra = Reactant.to_rarray(smodel)
        @test get_device_type(smodel_ra.ps) <: ReactantDevice
-        @test get_device_type(smodel_ra.st_any) <: ReactantDevice
-        @test smodel_ra.st === nothing
+        @test get_device_type(getfield(smodel_ra, :st_any)) <: ReactantDevice
+        @test getfield(smodel_ra, :st) === nothing
        @test smodel_ra.fixed_state_type == smodel.fixed_state_type
    end
end
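
# Illustrative end-to-end sketch (not part of the diff; `Dense`, `relu`, and `Lux.setup`
# are existing Lux exports, while the dispatch path is the one introduced above). With
# this change, calling a layer as `model(x, ps, st)` wraps it in a
# `NamedTupleStatefulLuxLayer` and routes through `apply(typeof(model), smodel, x)`:
#
#     using Lux, Random
#
#     model = Dense(2 => 3, relu)
#     ps, st = Lux.setup(Random.default_rng(), model)  # `st` is empty for `Dense`
#     y, stₙ = model(randn(Float32, 2, 4), ps, st)     # hits the new generic call path
#     @assert size(y) == (3, 4)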