11# Taken from https://github.com/FluxML/model-zoo/pull/410
22using ConcreteStructs, MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib,
3- DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme
3+ DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme, BytePairEncoding
44using Comonicon: @main
55
66if ! haskey (DataDeps. registry, " nanogpt" )
5151 block
5252end
5353
54- function GPTBlock (; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate, use_bias)
54+ function GPTBlock (; n_embed, n_heads, dropout_rate, use_bias)
5555 return GPTBlock (Chain (
5656 SkipConnection (
5757 Chain (
@@ -63,8 +63,8 @@ function GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate, use
6363 SkipConnection (
6464 Chain (
6565 LayerNorm ((n_embed, 1 )),
66- Dense (n_embed => n_hidden , gelu),
67- Dense (n_hidden => n_embed),
66+ Dense (n_embed => 4 * n_embed , gelu; use_bias ),
67+ Dense (4 * n_embed => n_embed; use_bias ),
6868 Dropout (dropout_rate)
6969 ),
7070 +
8787 layer
8888end
8989
90- function GPT (;
91- n_vocab, n_embed, sequence_length, n_hidden, n_layers, dropout_rate,
92- n_heads, qk_dim, v_dim
93- )
90+ function GPT (; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
9491 return GPT (Chain (
9592 Parallel (
9693 + ,
9794 Embedding (n_vocab => n_embed),
98- PositionalEmbedding (sequence_length => n_embed)
95+ PositionalEmbedding (block_size => n_embed)
9996 ),
10097 Dropout (dropout_rate),
101- Chain (ntuple (n_layers) do i
102- return GPTBlock (; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
103- end ... ),
98+ Chain (ntuple (
99+ Returns ( GPTBlock (; n_embed, n_heads, dropout_rate, use_bias)), n_layers
100+ ) ... ),
104101 LayerNorm ((n_embed, 1 )),
105- Dense (n_embed => n_vocab)
102+ Dense (n_embed => n_vocab; use_bias )
106103 ))
107104end
108105
106+ #=
107+
108+ dev = reactant_device(; force=true)
109+ rng = Random.default_rng()
110+
111+ model = GPT(;
112+ n_vocab=50304, n_embed=768, block_size=1024, n_layers=12, dropout_rate=0.0, n_heads=12,
113+ use_bias=true
114+ )
115+ ps, st = Lux.setup(rng, model) |> dev
116+
117+ =#
118+
109119# Use the model to generate some text.
110120function generate_text (
111121 model, ps, st, seed; alphabet, output_length, sequence_length
@@ -152,13 +162,14 @@ function get_nanogpt_data(; sequence_length, test_split)
152162 data_file = joinpath (datadep " nanogpt" , " shakespeare_input.txt" )
153163 text = String (read (data_file))
154164
155- # For aesthetic reasons, replace newlines with strings. This is not necessary, but makes
156- # strings print nicer.
157- text = replace (text, r" \r ?\n " => " " )
165+ idx = ceil (Int, length (text) * (1 - test_split))
166+ train_text = text[1 : idx]
167+ test_text = text[(idx + 1 ): end ]
168+
169+ tokenizer = BytePairEncoding. load_gpt2 ()
158170
159- # # an array of all unique characters
160- alphabet = [unique (text)... , ' _' ]
161- stop = alphabet[end ]
171+ train_tokens = tokenizer (train_text)
172+ test_tokens = tokenizer (test_text)
162173
163174 B = (length (text) - 1 ) ÷ sequence_length
164175 # We must collect() before indexing, because String indexing does strange things with multi-byte
0 commit comments