diff --git a/examples/CIFAR10/conv_mixer.jl b/examples/CIFAR10/conv_mixer.jl index e7cd6da75f..81ac636207 100644 --- a/examples/CIFAR10/conv_mixer.jl +++ b/examples/CIFAR10/conv_mixer.jl @@ -15,6 +15,7 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) [ Chain( SkipConnection( + +, Chain( Conv( (kernel_size, kernel_size), @@ -24,8 +25,7 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) pad=SamePad(), ), BatchNorm(dim), - ), - +, + ) ), Conv((1, 1), dim => dim, relu), BatchNorm(dim), diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index ddf8c7e778..2f28ca6097 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -67,7 +67,7 @@ function Lux.convert_flux_model(l::Flux.SkipConnection; kwargs...) else Lux.convert_flux_model(l.connection; kwargs...) end - return Lux.SkipConnection(Lux.convert_flux_model(l.layers; kwargs...), connection) + return Lux.SkipConnection(connection, Lux.convert_flux_model(l.layers; kwargs...)) end function Lux.convert_flux_model(l::Flux.Bilinear; preserve_ps_st::Bool=false, kwargs...) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index c500b995db..ace7fd058f 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -1,5 +1,5 @@ """ - SkipConnection(layers, connection; name=nothing) + SkipConnection(connection, layers; name=nothing) SkipConnection(; layers, connection, name=nothing) Create a skip connection which consists of a layer or [`Chain`](@ref) of consecutive layers @@ -7,7 +7,7 @@ and a shortcut connection linking the block's input to the output through a user 2-argument callable. The first argument to the callable will be propagated through the given `layer` while the second is the unchanged, "skipped" input. -The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. +The simplest "ResNet"-type connection is just `SkipConnection(+, layer)`. ## Arguments @@ -51,7 +51,7 @@ end PrettyPrinting.printable_children(l::SkipConnection) = (; l.connection, l.layers) -function SkipConnection(layers, connection; name::NAME_TYPE=nothing) +function SkipConnection(connection, layers; name::NAME_TYPE=nothing) return SkipConnection(; layers, connection, name) end diff --git a/test/downstream/Flux/flux_integrationtest.jl b/test/downstream/Flux/flux_integrationtest.jl index 3d9dd6e7a4..a435ff9a67 100644 --- a/test/downstream/Flux/flux_integrationtest.jl +++ b/test/downstream/Flux/flux_integrationtest.jl @@ -44,7 +44,7 @@ toluxforce = FromFluxAdaptor(; force_preserve=true, preserve_ps_st=true) end @testset "Skip Connection" begin - model = dev(Flux.SkipConnection(Flux.Dense(2 => 2), +)) + model = dev(Flux.SkipConnection(+, Flux.Dense(2 => 2))) x = aType(rand(Float32, 2, 1)) model_lux = toluxpsst(model) diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index 719ea2d828..8163593d56 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -57,7 +57,7 @@ const MODELS_LIST = Any[ ), (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), - (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + (SkipConnection(vcat, Dense(2 => 2)), randn(Float32, 2, 3)), (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index b73c13523f..6fda8e2bcc 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -289,7 +289,7 @@ end display(d) b = Bilinear((2, 2) => 3) display(b) - layer = SkipConnection(d, b) + layer = SkipConnection(b, d) display(layer) ps, st = dev(Lux.setup(rng, layer)) x = aType(randn(rng, Float32, 2, 1)) @@ -304,7 +304,7 @@ end display(d) b = Bilinear((2, 2) => 3; use_bias=false) display(b) - layer = SkipConnection(d, b) + layer = SkipConnection(b, d) display(layer) ps, st = dev(Lux.setup(rng, layer)) x = aType(randn(rng, Float32, 2, 1)) @@ -319,7 +319,7 @@ end display(d) b = Bilinear((3, 2) => 5) display(b) - layer = SkipConnection(d, b) + layer = SkipConnection(b, d) display(layer) ps, st = dev(Lux.setup(rng, layer)) x = aType(randn(rng, Float32, 2, 7, 11)) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index 22aa5b7cf6..6fed0cb7ad 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -3,7 +3,7 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "zero sum" begin - layer = SkipConnection(WrappedFunction(Broadcast.BroadcastFunction(zero)), .+) + layer = SkipConnection(+, WrappedFunction(Broadcast.BroadcastFunction(zero))) display(layer) ps, st = dev(Lux.setup(rng, layer)) x = aType(randn(rng, Float32, 10, 10, 10, 10)) @@ -14,7 +14,7 @@ end @testset "concat size" begin - layer = SkipConnection(Dense(10, 10), hcat) + layer = SkipConnection(hcat, Dense(10, 10)) display(layer) ps, st = dev(Lux.setup(rng, layer)) x = aType(randn(rng, Float32, 10, 2))