@@ -9,13 +9,14 @@ using ConcreteStructs,
    OneHotArrays,
    Reactant,
    Enzyme,
-    BytePairEncoding
+    BytePairEncoding,
+    NNlib
using Comonicon: @main
 
-if !haskey(DataDeps.registry, "nanogpt_shakespeare_input")
+if !haskey(DataDeps.registry, "shakespeare_char")
    register(
        DataDep(
-            "nanogpt_shakespeare_input",
+            "shakespeare_char",
            "Shakespeare Input Text for training NanoGPT",
            "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
            "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca",
@@ -27,31 +28,51 @@
 
@concrete struct CausalSelfAttention <: AbstractLuxWrapperLayer{:mha}
    mha
+
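+    # Inner constructor: wrap Lux's MultiHeadAttention so the forward pass can
+    # inject a causal mask rather than exposing MHA's full call signature.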
+    function CausalSelfAttention(args...; kwargs...)
+        mha = MultiHeadAttention(args...; kwargs...)
+        return new{typeof(mha)}(mha)
+    end
end
 
function (attn::CausalSelfAttention)(x::AbstractArray{T,3}, ps, st) where {T}
-    return attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
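+    # Self-attention: x serves as query, key, and value. The causal mask lets each
+    # position attend only to itself and earlier positions; the attention weights α
+    # returned alongside the output are discarded.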
+    (y, α), stₙ = attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
+    return y, stₙ
end
 
-@concrete struct GPTBlock <: AbstractLuxWrapperLayer{:block}
+@concrete struct GPT2Block <: AbstractLuxWrapperLayer{:block}
    block
end
 
-function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
-    return GPTBlock(
+function GPT2Block(; embed_dim, num_heads, hidden_dim, dropout_rate)
+    return GPT2Block(
        Chain(
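+            # Sub-block 1: pre-LayerNorm causal self-attention with a residual connection.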
            SkipConnection(
                Chain(
-                    LayerNorm((n_embed, 1)),
-                    CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias),
+                    LayerNorm(embed_dim; dims=nothing),
+                    CausalSelfAttention(
+                        embed_dim;
+                        nheads=num_heads,
+                        attention_dropout_probability=dropout_rate,
+                        dense_kwargs=(; init_weight=glorot_uniform, init_bias=zeros32),
+                    ),
                ),
                +,
            ),
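+            # Sub-block 2: pre-LayerNorm position-wise MLP (embed_dim → hidden_dim →
+            # embed_dim) with GELU and dropout, plus a residual connection.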
            SkipConnection(
                Chain(
-                    LayerNorm((n_embed, 1)),
-                    Dense(n_embed => 4 * n_embed, gelu; use_bias),
-                    Dense(4 * n_embed => n_embed; use_bias),
+                    LayerNorm(embed_dim; dims=nothing),
+                    Dense(
+                        embed_dim => hidden_dim,
+                        gelu;
+                        init_weight=glorot_uniform,
+                        init_bias=zeros32,
+                    ),
+                    Dense(
+                        hidden_dim => embed_dim;
+                        init_weight=glorot_uniform,
+                        init_bias=zeros32,
+                    ),
                    Dropout(dropout_rate),
                ),
                +,
@@ -60,51 +81,62 @@ function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
    )
end
 
-struct PositionalEmbedding{E} <: AbstractLuxWrapperLayer{:embedding}
-    embedding::E
-
-    function PositionalEmbedding(args...; kwargs...)
-        embed = Embedding(args...; kwargs...)
-        return new{typeof(embed)}(embed)
-    end
-end
-
-(pe::PositionalEmbedding)(x, ps, st) = pe.embedding(1:size(x, 1), ps, st)
-
-@concrete struct GPT <: AbstractLuxWrapperLayer{:layer}
-    layer
+@concrete struct GPT2 <: AbstractLuxContainerLayer{(:tok_emb, :pos_emb, :gpt_blocks)}
+    tok_emb
+    pos_emb
+    gpt_blocks
end
 
-function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
-    return GPT(
+function GPT2(;
+    n_vocab, embed_dim, num_heads, hidden_dim, dropout_rate, block_size, n_layers
+)
+    return GPT2(
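+        # Learned token and positional embedding tables (GPT-2 style).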
+        Embedding(n_vocab => embed_dim),
+        Embedding(block_size => embed_dim),
        Chain(
-            Parallel(
-                +, Embedding(n_vocab => n_embed), PositionalEmbedding(block_size => n_embed)
-            ),
            Dropout(dropout_rate),
            Chain(
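+                # n_layers structurally identical transformer blocks; Lux still
+                # initializes an independent parameter set for each Chain entry.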
                ntuple(
-                    Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
+                    Returns(GPT2Block(; embed_dim, num_heads, dropout_rate, hidden_dim)),
+                    n_layers,
                )...,
            ),
-            LayerNorm((n_embed, 1)),
-            Dense(n_embed => n_vocab; use_bias),
+            LayerNorm(embed_dim; dims=nothing),
        ),
    )
end
 
-#=
+function (model::GPT2)(x, ps, st)
+    token_embeddings, st_tok_emb = model.tok_emb(x, ps.tok_emb, st.tok_emb)
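+    # Positional embeddings for positions 1:seq_len, broadcast across the batch.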
+    pos_embeddings, st_pos_emb = model.pos_emb(1:size(x, 1), ps.pos_emb, st.pos_emb)
+    embedding_output = token_embeddings .+ pos_embeddings
+
+    query, st_gpt_blocks = model.gpt_blocks(embedding_output, ps.gpt_blocks, st.gpt_blocks)
+    _, seq_len, batch_size = size(query)
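+    # Weight tying, as in GPT-2: project hidden states back to vocabulary logits
+    # with the transposed token-embedding matrix instead of a separate Dense layer.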
+    outputs = reshape(
+        ps.tok_emb.weight' * reshape(query, :, seq_len * batch_size), :, seq_len, batch_size
+    )
+
+    return outputs, (; tok_emb=st_tok_emb, pos_emb=st_pos_emb, gpt_blocks=st_gpt_blocks)
+end
 
dev = reactant_device(; force=true)
rng = Random.default_rng()
 
-model = GPT(;
-    n_vocab=50304, n_embed=768, block_size=1024, n_layers=12, dropout_rate=0.0, n_heads=12,
-    use_bias=true
+model = GPT2(;
+    n_vocab=50304,
+    embed_dim=768,
+    hidden_dim=3072,
+    block_size=1024,
+    n_layers=3,
+    dropout_rate=0.0,
+    num_heads=12,
)
-ps, st = Lux.setup(rng, model) |> dev
+ps, st = Lux.setup(rng, model) |> dev;
+
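+# Dummy batch: block_size-length sequences of random token ids for a batch of 32.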
+x = rand(1:50304, 1024, 32) |> dev;
 
-=#
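+# Trace one forward pass and inspect the HLO that Reactant will compile.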
+@code_hlo model(x, ps, st)
 
# Use the model to generate some text.
function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
@@ -180,7 +212,7 @@ function get_nanogpt_data(; sequence_length, test_split)
end
 
@main function main(;
-    n_embed::Int=64,
+    embed_dim::Int=64,
    n_hidden::Int=256,
    n_heads::Int=4,
    qk_dim::Int=16,
@@ -238,7 +270,7 @@
 
    model_config = (;
        n_vocab=length(alphabet),
-        n_embed,
+        embed_dim,
        sequence_length,
        n_hidden,
        n_layers,