-# Taken from https://github.com/FluxML/model-zoo/pull/410
-using ConcreteStructs, MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib,
-    DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme, BytePairEncoding
+using ConcreteStructs,
+    MLUtils,
+    Lux,
+    Random,
+    Optimisers,
+    Printf,
+    Statistics,
+    NNlib,
+    DataDeps,
+    OneHotArrays,
+    JLD2,
+    Reactant,
+    Enzyme,
+    BytePairEncoding
 using Comonicon: @main
 
-if !haskey(DataDeps.registry, "nanogpt")
-    register(DataDep(
-        "nanogpt",
-        "Shakespeare Input Text for training NanoGPT",
-        "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
-        "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca"
-    ))
+if !haskey(DataDeps.registry, "nanogpt_shakespeare_input")
+    register(
+        DataDep(
+            "nanogpt_shakespeare_input",
+            "Shakespeare Input Text for training NanoGPT",
+            "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
+            "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca",
+        ),
+    )
 end
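# --- Editorial sketch (not part of the patch): once registered, the dataset is
# --- resolved lazily through DataDeps' string macro; the script's own consuming
# --- code lives in get_nanogpt_data further down.
using DataDeps
input_file = joinpath(datadep"nanogpt_shakespeare_input", "shakespeare_input.txt")
text = read(input_file, String)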
 
 # Setup the model definition
-@concrete struct CausalSelfAttention <:
-        AbstractLuxContainerLayer{(:causal_attn, :proj, :attn_drop)}
-    causal_attn
-    proj
-    attn_drop
-    n_embed::Int
-    n_heads::Int
-end
 
-function CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
-    causal_attn = Dense(n_embed, 3 * n_embed; use_bias)
-    proj = Chain(
-        Dense(n_embed, n_embed; use_bias),
-        Dropout(dropout_rate)
-    )
-    attn_drop = Dropout(dropout_rate)
-    return CausalSelfAttention(causal_attn, proj, attn_drop, n_embed, n_heads)
+@concrete struct CausalSelfAttention <: AbstractLuxWrapperLayer{:mha}
+    mha
 end
 
-function (attn::CausalSelfAttention)(x::AbstractArray{T, 3}, ps, st) where {T}
-    qkv, qkv_st = attn.causal_attn(x, ps.causal_attn, st.causal_attn)
-    q, k, v = (
-        selectdim(qkv, 1, 1:(attn.n_heads)),
-        selectdim(qkv, 1, (attn.n_heads + 1):(2 * attn.n_heads)),
-        selectdim(qkv, 1, (2 * attn.n_heads + 1):(3 * attn.n_heads))
-    )
-    dp = StatefulLuxLayer{true}(attn.attn_drop, ps.attn_drop, st.attn_drop)
-    mha, _ = NNlib.dot_product_attention(
-        q, k, v, nothing; mask=NNlib.make_causal_mask(x), fdrop=dp, nheads=attn.n_heads
-    )
-    proj, proj_st = attn.proj(mha, ps.proj, st.proj)
-    return proj, (; causal_attn=qkv_st, proj=proj_st, attn_drop=dp.attn_drop)
+function (attn::CausalSelfAttention)(x::AbstractArray{T,3}, ps, st) where {T}
+    return attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
 end
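# --- Editorial sketch (not part of the patch): NNlib.make_causal_mask builds a square
# --- Bool mask over the sequence dimension (size(x, 2)) so that each position attends
# --- only to itself and to earlier positions.
using NNlib
x = rand(Float32, 8, 16, 4)        # (n_embed, sequence_length, batchsize)
mask = NNlib.make_causal_mask(x)   # 16×16 Bool matrix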
 
 @concrete struct GPTBlock <: AbstractLuxWrapperLayer{:block}
     block
 end
 
 function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
-    return GPTBlock(Chain(
-        SkipConnection(
-            Chain(
-                LayerNorm((n_embed, 1)),
-                CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias)
+    return GPTBlock(
+        Chain(
+            SkipConnection(
+                Chain(
+                    LayerNorm((n_embed, 1)),
+                    CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias),
+                ),
+                +,
             ),
-            +
-        ),
-        SkipConnection(
-            Chain(
-                LayerNorm((n_embed, 1)),
-                Dense(n_embed => 4 * n_embed, gelu; use_bias),
-                Dense(4 * n_embed => n_embed; use_bias),
-                Dropout(dropout_rate)
+            SkipConnection(
+                Chain(
+                    LayerNorm((n_embed, 1)),
+                    Dense(n_embed => 4 * n_embed, gelu; use_bias),
+                    Dense(4 * n_embed => n_embed; use_bias),
+                    Dropout(dropout_rate),
+                ),
+                +,
             ),
-            +
-        )
-    ))
+        ),
+    )
 end
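# --- Editorial note (not part of the patch): SkipConnection(inner, +) computes
# --- inner(x) + x, so each GPTBlock is the standard pre-norm transformer pair:
# ---   x = x + Attention(LayerNorm(x));  x = x + MLP(LayerNorm(x))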
 
 struct PositionalEmbedding{E} <: AbstractLuxWrapperLayer{:embedding}
@@ -88,19 +76,21 @@
 end
 
 function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
-    return GPT(Chain(
-        Parallel(
-            +,
-            Embedding(n_vocab => n_embed),
-            PositionalEmbedding(block_size => n_embed)
+    return GPT(
+        Chain(
+            Parallel(
+                +, Embedding(n_vocab => n_embed), PositionalEmbedding(block_size => n_embed)
+            ),
+            Dropout(dropout_rate),
+            Chain(
+                ntuple(
+                    Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
+                )...,
+            ),
+            LayerNorm((n_embed, 1)),
+            Dense(n_embed => n_vocab; use_bias),
         ),
-        Dropout(dropout_rate),
-        Chain(ntuple(
-            Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
-        )...),
-        LayerNorm((n_embed, 1)),
-        Dense(n_embed => n_vocab; use_bias)
-    ))
+    )
 end
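# --- Editorial sketch (not part of the patch): a minimal forward pass with illustrative
# --- hyperparameters. It assumes the PositionalEmbedding forward (elided above) accepts
# --- the same integer token matrix as Embedding does.
rng = Random.default_rng()
model = GPT(;
    n_vocab=65, n_embed=64, block_size=64, n_layers=2, dropout_rate=0.0f0,
    n_heads=4, use_bias=false,
)
ps, st = Lux.setup(rng, model)
tokens = rand(rng, 1:65, 64, 8)      # (sequence_length, batchsize) token ids
logits, _ = model(tokens, ps, st)    # (n_vocab, sequence_length, batchsize) logits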
 
 #=
@@ -117,9 +107,7 @@ ps, st = Lux.setup(rng, model) |> dev
 =#
 
 # Use the model to generate some text.
-function generate_text(
-    model, ps, st, seed; alphabet, output_length, sequence_length
-)
+function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
     dev = get_device((ps, st))
     @assert !(dev isa ReactantDevice) "Currently we don't support running inference of \
         dynamically sized tensors."
@@ -192,14 +180,23 @@ function get_nanogpt_data(; sequence_length, test_split)
 end
 
 @main function main(;
-    n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16,
-    v_dim::Int=16, n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
-    dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
-    epochs::Int=100,
-    # Only inference options
-    inference::Bool=false, model_path::String="",
-    seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
-    output_length::Int=1024
+    n_embed::Int=64,
+    n_hidden::Int=256,
+    n_heads::Int=4,
+    qk_dim::Int=16,
+    v_dim::Int=16,
+    n_layers::Int=6,
+    sequence_length::Int=64,
+    batchsize::Int=128,
+    dropout_rate::Float32=0.0f0,
+    test_split::Float64=0.1,
+    lr::Float64=1e-2,
+    epochs::Int=100,
+    # Only inference options
+    inference::Bool=false,
+    model_path::String="",
+    seed::Union{String,Vector{String}}=["_", "The", "Julia", "Lux.jl"],
+    output_length::Int=1024,
 )
     rng = Random.default_rng()
     Random.seed!(rng, 1234)
@@ -220,16 +217,14 @@ end
         alphabet = JLD2.load(model_path, "alphabet")
         sequence_length = model_config.sequence_length
 
-        texts = generate_text(
-            model, ps, st, seed; alphabet, output_length, sequence_length
-        )
+        texts = generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
 
         for (i, (text, s)) in enumerate(zip(texts, seed))
             @printf "[Info] Seed [%d]: %s\n" i s
             @printf "[Generated Text] %s\n\n" text
         end
 
-        return
+        return nothing
     end
 
     alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
@@ -238,13 +233,19 @@ end
     @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
     @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
 
-    train_loader = DataLoader(
-        (trainX, trainY); batchsize, shuffle=true, parallel=true
-    ) |> dev
+    train_loader =
+        DataLoader((trainX, trainY); batchsize, shuffle=true, parallel=true) |> dev
 
     model_config = (;
-        n_vocab=length(alphabet), n_embed, sequence_length, n_hidden,
-        n_layers, dropout_rate, n_heads, qk_dim, v_dim
+        n_vocab=length(alphabet),
+        n_embed,
+        sequence_length,
+        n_hidden,
+        n_layers,
+        dropout_rate,
+        n_heads,
+        qk_dim,
+        v_dim,
     )
     model = GPT(; model_config...)
     ps, st = Lux.setup(rng, model) |> dev
@@ -290,8 +291,7 @@
 
     # Generate some text here...
     texts = generate_text(
-        model, ps |> cdev, st |> cdev, seed;
-        alphabet, output_length, sequence_length
+        model, ps |> cdev, st |> cdev, seed; alphabet, output_length, sequence_length
     )
     for (i, (text, s)) in enumerate(zip(texts, seed))
         @printf "[Info] Seed [%d]: %s\n" i s
@@ -307,7 +307,7 @@
             parameters=train_state.parameters |> cdev,
             states=train_state.states |> cdev,
             alphabet=alphabet,
-            model_config=model_config
+            model_config=model_config,
         )
     end
 end
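# --- Editorial note (not part of the patch): Comonicon's @main exposes the keyword
# --- arguments of `main` as CLI flags (underscores become dashes). Assuming the file
# --- is saved as nanogpt.jl, typical invocations would look like:
# ---   julia --project nanogpt.jl --n-embed 64 --n-layers 6 --epochs 100
# ---   julia --project nanogpt.jl --inference --model-path model.jld2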