Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Functors = "0.5"
GPUArraysCore = "0.2"
LinearAlgebra = "1.10"
LossFunctions = "0.11.1, 1"
LuxCore = "1.4"
LuxCore = "1.4" # XXX: bump to 1.5 before merge
LuxLib = "1.11.0"
MLDataDevices = "1.10.0"
MLUtils = "0.4.4"
Expand Down
12 changes: 7 additions & 5 deletions ext/LuxReactantExt/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ function Reactant.make_tracer(
seen, @nospecialize(model::StatefulLuxLayer), @nospecialize(path), mode; kwargs...
)
return StatefulLuxLayer(
model.model,
Reactant.make_tracer(seen, model.ps, (path..., :ps), mode; kwargs...),
Reactant.make_tracer(seen, model.st, (path..., :st), mode; kwargs...),
Reactant.make_tracer(seen, model.st_any, (path..., :st_any), mode; kwargs...),
model.fixed_state_type,
getfield(model, :model),
Reactant.make_tracer(seen, getfield(model, :ps), (path..., :ps), mode; kwargs...),
Reactant.make_tracer(seen, getfield(model, :st), (path..., :st), mode; kwargs...),
Reactant.make_tracer(
seen, getfield(model, :st_any), (path..., :st_any), mode; kwargs...
),
getfield(model, :fixed_state_type),
)
end
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxCore"
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.1"
version = "1.5.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
16 changes: 14 additions & 2 deletions lib/LuxCore/ext/LuxCoreFunctorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,22 @@ function Functors.functor(
)
recon = let ft = x.fixed_state_type
nt -> LuxCore.StatefulLuxLayerImpl.StatefulLuxLayer(
nt.model, nt.ps, nt.st, nt.st_any, ft
getfield(nt, :model),
getfield(nt, :ps),
getfield(nt, :st),
getfield(nt, :st_any),
ft,
)
end
return (; x.model, x.ps, x.st, x.st_any), recon
return (
(;
model=getfield(x, :model),
ps=getfield(x, :ps),
st=getfield(x, :st),
st_any=getfield(x, :st_any),
),
recon,
)
end

end
13 changes: 3 additions & 10 deletions lib/LuxCore/src/LuxCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ this include:
type stability. By default this is "disable"d. For more information, see the
[documentation](https://github.com/MilesCranmer/DispatchDoctor.jl).
"""
@stable default_mode = "disable" function apply(model::AbstractLuxLayer, x, ps, st)
return model(x, ps, st)
end
function apply end

"""
stateless_apply(model, x, ps)
Expand All @@ -162,9 +160,7 @@ Calls `apply` and only returns the first argument. This function requires that `
an empty state of `NamedTuple()`. Behavior of other kinds of models are undefined and it is
the responsibility of the user to ensure that the model has an empty state.
"""
function stateless_apply(model::AbstractLuxLayer, x, ps)
return first(apply(model, x, ps, Internal.get_empty_state(model)))
end
function stateless_apply end

"""
display_name(layer::AbstractLuxLayer)
Expand Down Expand Up @@ -265,10 +261,6 @@ function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer}
return statelength(getfield(l, layer))
end

function (l::AbstractLuxWrapperLayer{layer})(x, ps, st) where {layer}
return apply(getfield(l, layer), x, ps, st)
end

# Test Mode
"""
testmode(st::NamedTuple)
Expand Down Expand Up @@ -357,6 +349,7 @@ preserves_state_type(l::Tuple) = all(preserves_state_type, l)
preserves_state_type(l::NamedTuple) = all(preserves_state_type, values(l))

include("stateful.jl")
include("apply.jl")

module Internal

Expand Down
24 changes: 24 additions & 0 deletions lib/LuxCore/src/apply.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
@stable default_mode = "disable" function apply(model::AbstractLuxLayer, x, ps, st)
return model(x, ps, st)
end

function stateless_apply(model::AbstractLuxLayer, x, ps)
return first(apply(model, x, ps, Internal.get_empty_state(model)))
end

# New Interface that circumvents having to manage the state manually
function (model::AbstractLuxLayer)(x, ps, st)
xs = x isa Tuple ? x : (x,)
smodel = StatefulLuxLayerImpl.NamedTupleStatefulLuxLayer(model, ps, st)
output = apply(typeof(model), smodel, xs...)
return output, StatefulLuxLayerImpl.get_states_as_namedtuple(smodel)
end

# fallback for wrapped layers
function apply(::Type{<:AbstractLuxWrapperLayer}, model, x)
return apply(only(getfield(model, :smodels)), x)
end

function apply(::Type{<:AbstractLuxWrapperLayer}, model, xs...)
return apply(only(getfield(model, :smodels)), xs)
end
Loading
Loading