Commit af7f081

refactor: use Lux primitives

1 parent ce9bc4c · commit af7f081

File tree

1 file changed, +90 −90 lines changed

examples/NanoGPT/main.jl

Lines changed: 90 additions & 90 deletions
@@ -1,75 +1,63 @@
-# Taken from https://github.com/FluxML/model-zoo/pull/410
-using ConcreteStructs, MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib,
-    DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme, BytePairEncoding
+using ConcreteStructs,
+    MLUtils,
+    Lux,
+    Random,
+    Optimisers,
+    Printf,
+    Statistics,
+    DataDeps,
+    OneHotArrays,
+    Reactant,
+    Enzyme,
+    BytePairEncoding
 using Comonicon: @main
 
-if !haskey(DataDeps.registry, "nanogpt")
-    register(DataDep(
-        "nanogpt",
-        "Shakespeare Input Text for training NanoGPT",
-        "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
-        "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca"
-    ))
+if !haskey(DataDeps.registry, "nanogpt_shakespeare_input")
+    register(
+        DataDep(
+            "nanogpt_shakespeare_input",
+            "Shakespeare Input Text for training NanoGPT",
+            "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
+            "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca",
+        ),
+    )
 end
 
 # Setup the model definition
-@concrete struct CausalSelfAttention <:
-                 AbstractLuxContainerLayer{(:causal_attn, :proj, :attn_drop)}
-    causal_attn
-    proj
-    attn_drop
-    n_embed::Int
-    n_heads::Int
-end
 
-function CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
-    causal_attn = Dense(n_embed, 3 * n_embed; use_bias)
-    proj = Chain(
-        Dense(n_embed, n_embed; use_bias),
-        Dropout(dropout_rate)
-    )
-    attn_drop = Dropout(dropout_rate)
-    return CausalSelfAttention(causal_attn, proj, attn_drop, n_embed, n_heads)
+@concrete struct CausalSelfAttention <: AbstractLuxWrapperLayer{:mha}
+    mha
 end
 
-function (attn::CausalSelfAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
-    qkv, qkv_st = attn.causal_attn(x, ps.causal_attn, st.causal_attn)
-    q, k, v = (
-        selectdim(qkv, 1, 1:(attn.n_heads)),
-        selectdim(qkv, 1, (attn.n_heads + 1):(2 * attn.n_heads)),
-        selectdim(qkv, 1, (2 * attn.n_heads + 1):(3 * attn.n_heads))
-    )
-    dp = StatefulLuxLayer{true}(attn.attn_drop, ps.attn_drop, st.attn_drop)
-    mha, _ = NNlib.dot_product_attention(
-        q, k, v, nothing; mask=NNlib.make_causal_mask(x), fdrop=dp, nheads=attn.n_heads
-    )
-    proj, proj_st = attn.proj(mha, ps.proj, st.proj)
-    return proj, (; causal_attn=qkv_st, proj=proj_st, attn_drop=dp.attn_drop)
+function (attn::CausalSelfAttention)(x::AbstractArray{T,3}, ps, st) where {T}
+    return attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
 end
 
 @concrete struct GPTBlock <: AbstractLuxWrapperLayer{:block}
     block
 end
 
 function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
-    return GPTBlock(Chain(
-        SkipConnection(
-            Chain(
-                LayerNorm((n_embed, 1)),
-                CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
+    return GPTBlock(
+        Chain(
+            SkipConnection(
+                Chain(
+                    LayerNorm((n_embed, 1)),
+                    CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias),
+                ),
+                +,
             ),
-            +
-        ),
-        SkipConnection(
-            Chain(
-                LayerNorm((n_embed, 1)),
-                Dense(n_embed => 4 * n_embed, gelu; use_bias),
-                Dense(4 * n_embed => n_embed; use_bias),
-                Dropout(dropout_rate)
+            SkipConnection(
+                Chain(
+                    LayerNorm((n_embed, 1)),
+                    Dense(n_embed => 4 * n_embed, gelu; use_bias),
+                    Dense(4 * n_embed => n_embed; use_bias),
+                    Dropout(dropout_rate),
+                ),
+                +,
             ),
-            +
-        )
-    ))
+        ),
+    )
 end
 
 struct PositionalEmbedding{E} <: AbstractLuxWrapperLayer{:embedding}
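This hunk is the core of the refactor: the hand-rolled QKV projection and the manual NNlib.dot_product_attention call are collapsed into a single wrapped attention layer, with AbstractLuxWrapperLayer{:mha} transparently forwarding ps and st to the mha field. The updated constructor falls outside the hunk; the sketch below is a hypothetical reconstruction assuming it wraps Lux's MultiHeadAttention, and the dense_kwargs / attention_dropout_probability keyword names are assumptions to be checked against the MultiHeadAttention docstring in your Lux version.

# Hypothetical reconstruction -- not part of this commit. The keyword names
# for bias and dropout are assumed; consult the Lux docs before relying on them.
function CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
    mha = MultiHeadAttention(
        n_embed;
        nheads=n_heads,
        dense_kwargs=(; use_bias),                    # assumed kwarg name
        attention_dropout_probability=dropout_rate,   # assumed kwarg name
    )
    return CausalSelfAttention(mha)
end

Called as attn.mha((x, x, x, mask), ps, st) with NNlib.make_causal_mask(x), the wrapped layer reproduces the causal masking behavior of the removed manual implementation.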
@@ -88,19 +76,21 @@ end
 end
 
 function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
-    return GPT(Chain(
-        Parallel(
-            +,
-            Embedding(n_vocab => n_embed),
-            PositionalEmbedding(block_size => n_embed)
+    return GPT(
+        Chain(
+            Parallel(
+                +, Embedding(n_vocab => n_embed), PositionalEmbedding(block_size => n_embed)
+            ),
+            Dropout(dropout_rate),
+            Chain(
+                ntuple(
+                    Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
+                )...,
+            ),
+            LayerNorm((n_embed, 1)),
+            Dense(n_embed => n_vocab; use_bias),
         ),
-        Dropout(dropout_rate),
-        Chain(ntuple(
-            Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
-        )...),
-        LayerNorm((n_embed, 1)),
-        Dense(n_embed => n_vocab; use_bias)
-    ))
+    )
 end
 
 #=
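The GPT constructor is pure Lux composition: Parallel(+, ...) sums the token and positional embeddings, and ntuple(Returns(...), n_layers) repeats the same block spec n_layers times, with Lux.setup still allocating separate parameters for each position in the Chain. A short usage sketch with illustrative values follows; the input/output shapes are inferred from the layer definitions above, so treat them as assumptions.

# Build a tiny model and run one forward pass on random token ids.
rng = Random.default_rng()
model = GPT(;
    n_vocab=65, n_embed=64, block_size=64, n_layers=2,
    dropout_rate=0.0f0, n_heads=4, use_bias=false,
)
ps, st = Lux.setup(rng, model)
tokens = rand(rng, 1:65, 64, 8)     # assumed (sequence_length, batchsize) integer ids
logits, st = model(tokens, ps, st)  # expected (n_vocab, sequence_length, batchsize)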
@@ -117,9 +107,7 @@ ps, st = Lux.setup(rng, model) |> dev
 =#
 
 # Use the model to generate some text.
-function generate_text(
-    model, ps, st, seed; alphabet, output_length, sequence_length
-)
+function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
     dev = get_device((ps, st))
     @assert !(dev isa ReactantDevice) "Currently we don't support running inference of \
                                        dynamically sized tensors."
@@ -192,14 +180,23 @@ function get_nanogpt_data(; sequence_length, test_split)
 end
 
 @main function main(;
-    n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16,
-    v_dim::Int=16, n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
-    dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
-    epochs::Int=100,
-    # Only inference options
-    inference::Bool=false, model_path::String="",
-    seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
-    output_length::Int=1024
+    n_embed::Int=64,
+    n_hidden::Int=256,
+    n_heads::Int=4,
+    qk_dim::Int=16,
+    v_dim::Int=16,
+    n_layers::Int=6,
+    sequence_length::Int=64,
+    batchsize::Int=128,
+    dropout_rate::Float32=0.0f0,
+    test_split::Float64=0.1,
+    lr::Float64=1e-2,
+    epochs::Int=100,
+    # Only inference options
+    inference::Bool=false,
+    model_path::String="",
+    seed::Union{String,Vector{String}}=["_", "The", "Julia", "Lux.jl"],
+    output_length::Int=1024,
 )
     rng = Random.default_rng()
     Random.seed!(rng, 1234)
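Comonicon's @main turns each of these keyword arguments into a CLI option while leaving main callable as an ordinary Julia function. A hedged inference invocation from the REPL, where "model.jld2" is a placeholder checkpoint path:

# Direct call; the equivalent CLI flags follow Comonicon's kwarg-to-option mapping.
main(; inference=true, model_path="model.jld2", seed="The", output_length=256)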
@@ -220,16 +217,14 @@ end
         alphabet = JLD2.load(model_path, "alphabet")
         sequence_length = model_config.sequence_length
 
-        texts = generate_text(
-            model, ps, st, seed; alphabet, output_length, sequence_length
-        )
+        texts = generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
 
         for (i, (text, s)) in enumerate(zip(texts, seed))
             @printf "[Info] Seed [%d]: %s\n" i s
             @printf "[Generated Text] %s\n\n" text
        end
 
-        return
+        return nothing
     end
 
     alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
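The checkpoint-loading lines that precede this hunk are not shown in the diff. Pairing the JLD2.load call above with the fields saved in the final hunk (parameters, states, alphabet, model_config), they presumably look like the following sketch; the field names come from the save site, the rest is a hypothetical reconstruction.

# Hypothetical reconstruction of the inference branch's checkpoint loading.
model_config = JLD2.load(model_path, "model_config")
ps = JLD2.load(model_path, "parameters")
st = JLD2.load(model_path, "states")
model = GPT(; model_config...)   # mirrors the training path later in the file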
@@ -238,13 +233,19 @@ end
     @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
     @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
 
-    train_loader = DataLoader(
-        (trainX, trainY); batchsize, shuffle=true, parallel=true
-    ) |> dev
+    train_loader =
+        DataLoader((trainX, trainY); batchsize, shuffle=true, parallel=true) |> dev
 
     model_config = (;
-        n_vocab=length(alphabet), n_embed, sequence_length, n_hidden,
-        n_layers, dropout_rate, n_heads, qk_dim, v_dim
+        n_vocab=length(alphabet),
+        n_embed,
+        sequence_length,
+        n_hidden,
+        n_layers,
+        dropout_rate,
+        n_heads,
+        qk_dim,
+        v_dim,
     )
     model = GPT(; model_config...)
     ps, st = Lux.setup(rng, model) |> dev
@@ -290,8 +291,7 @@ end
 
     # Generate some text here...
     texts = generate_text(
-        model, ps |> cdev, st |> cdev, seed;
-        alphabet, output_length, sequence_length
+        model, ps |> cdev, st |> cdev, seed; alphabet, output_length, sequence_length
     )
     for (i, (text, s)) in enumerate(zip(texts, seed))
         @printf "[Info] Seed [%d]: %s\n" i s
@@ -307,7 +307,7 @@ end
             parameters=train_state.parameters |> cdev,
             states=train_state.states |> cdev,
             alphabet=alphabet,
-            model_config=model_config
+            model_config=model_config,
         )
     end
 end
