
Commit 0650b61

feat: forward pass is now working
1 parent af7f081 commit 0650b61

File tree

1 file changed

+73 −41 lines changed


examples/NanoGPT/main.jl

Lines changed: 73 additions & 41 deletions
@@ -9,13 +9,14 @@ using ConcreteStructs,
     OneHotArrays,
     Reactant,
     Enzyme,
-    BytePairEncoding
+    BytePairEncoding,
+    NNlib
 using Comonicon: @main

-if !haskey(DataDeps.registry, "nanogpt_shakespeare_input")
+if !haskey(DataDeps.registry, "shakespeare_char")
     register(
         DataDep(
-            "nanogpt_shakespeare_input",
+            "shakespeare_char",
             "Shakespeare Input Text for training NanoGPT",
             "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
             "59a0ad62833b2e15ec811c548618876359e902717431236e52699a0e2bc253ca",
@@ -27,31 +28,51 @@ end

 @concrete struct CausalSelfAttention <: AbstractLuxWrapperLayer{:mha}
     mha
+
+    function CausalSelfAttention(args...; kwargs...)
+        mha = MultiHeadAttention(args...; kwargs...)
+        return new{typeof(mha)}(mha)
+    end
 end

 function (attn::CausalSelfAttention)(x::AbstractArray{T,3}, ps, st) where {T}
-    return attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
+    (y, α), stₙ = attn.mha((x, x, x, NNlib.make_causal_mask(x)), ps, st)
+    return y, stₙ
 end

-@concrete struct GPTBlock <: AbstractLuxWrapperLayer{:block}
+@concrete struct GPT2Block <: AbstractLuxWrapperLayer{:block}
     block
 end

-function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
-    return GPTBlock(
+function GPT2Block(; embed_dim, num_heads, hidden_dim, dropout_rate)
+    return GPT2Block(
         Chain(
             SkipConnection(
                 Chain(
-                    LayerNorm((n_embed, 1)),
-                    CausalSelfAttention(; n_embed, n_heads, dropout_rate, use_bias),
+                    LayerNorm(embed_dim; dims=nothing),
+                    CausalSelfAttention(
+                        embed_dim;
+                        nheads=num_heads,
+                        attention_dropout_probability=dropout_rate,
+                        dense_kwargs=(; init_weight=glorot_uniform, init_bias=zeros32),
+                    ),
                 ),
                 +,
             ),
             SkipConnection(
                 Chain(
-                    LayerNorm((n_embed, 1)),
-                    Dense(n_embed => 4 * n_embed, gelu; use_bias),
-                    Dense(4 * n_embed => n_embed; use_bias),
+                    LayerNorm(embed_dim; dims=nothing),
+                    Dense(
+                        embed_dim => hidden_dim,
+                        gelu;
+                        init_weight=glorot_uniform,
+                        init_bias=zeros32,
+                    ),
+                    Dense(
+                        hidden_dim => embed_dim;
+                        init_weight=glorot_uniform,
+                        init_bias=zeros32,
+                    ),
                     Dropout(dropout_rate),
                 ),
                 +,
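Note on the attention change above: the wrapped MultiHeadAttention call returns both the attended values and the attention weights (the (y, α) tuple), so the rewritten CausalSelfAttention forward unpacks that tuple and passes only the values downstream. Below is a minimal stand-alone sketch of the causal mask it forwards to the attention layer, using the same NNlib.make_causal_mask call as the diff; the array sizes are illustrative, not the model's configuration, and it assumes the default of masking along the second (sequence) dimension.

    using NNlib

    x = rand(Float32, 8, 5, 2)        # (features, seq_len, batch); toy sizes for illustration
    mask = NNlib.make_causal_mask(x)  # seq_len × seq_len Bool matrix blocking attention to future positions

    size(mask)  # (5, 5)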
@@ -60,51 +81,62 @@ function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
     )
 end

-struct PositionalEmbedding{E} <: AbstractLuxWrapperLayer{:embedding}
-    embedding::E
-
-    function PositionalEmbedding(args...; kwargs...)
-        embed = Embedding(args...; kwargs...)
-        return new{typeof(embed)}(embed)
-    end
-end
-
-(pe::PositionalEmbedding)(x, ps, st) = pe.embedding(1:size(x, 1), ps, st)
-
-@concrete struct GPT <: AbstractLuxWrapperLayer{:layer}
-    layer
+@concrete struct GPT2 <: AbstractLuxContainerLayer{(:tok_emb, :pos_emb, :gpt_blocks)}
+    tok_emb
+    pos_emb
+    gpt_blocks
 end

-function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
-    return GPT(
+function GPT2(;
+    n_vocab, embed_dim, num_heads, hidden_dim, dropout_rate, block_size, n_layers
+)
+    return GPT2(
+        Embedding(n_vocab => embed_dim),
+        Embedding(block_size => embed_dim),
         Chain(
-            Parallel(
-                +, Embedding(n_vocab => n_embed), PositionalEmbedding(block_size => n_embed)
-            ),
             Dropout(dropout_rate),
             Chain(
                 ntuple(
-                    Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
+                    Returns(GPT2Block(; embed_dim, num_heads, dropout_rate, hidden_dim)),
+                    n_layers,
                 )...,
             ),
-            LayerNorm((n_embed, 1)),
-            Dense(n_embed => n_vocab; use_bias),
+            LayerNorm(embed_dim; dims=nothing),
         ),
     )
 end

-#=
+function (model::GPT2)(x, ps, st)
+    token_embeddings, st_tok_emb = model.tok_emb(x, ps.tok_emb, st.tok_emb)
+    pos_embeddings, st_pos_emb = model.pos_emb(1:size(x, 1), ps.pos_emb, st.pos_emb)
+    embedding_output = token_embeddings .+ pos_embeddings
+
+    query, st_gpt_blocks = model.gpt_blocks(embedding_output, ps.gpt_blocks, st.gpt_blocks)
+    _, seq_len, batch_size = size(query)
+    outputs = reshape(
+        ps.tok_emb.weight' * reshape(query, :, seq_len * batch_size), :, seq_len, batch_size
+    )
+
+    return outputs, (; tok_emb=st_tok_emb, pos_emb=st_pos_emb, gpt_blocks=st_gpt_blocks)
+end

 dev = reactant_device(; force=true)
 rng = Random.default_rng()

-model = GPT(;
-    n_vocab=50304, n_embed=768, block_size=1024, n_layers=12, dropout_rate=0.0, n_heads=12,
-    use_bias=true
+model = GPT2(;
+    n_vocab=50304,
+    embed_dim=768,
+    hidden_dim=3072,
+    block_size=1024,
+    n_layers=3,
+    dropout_rate=0.0,
+    num_heads=12,
 )
-ps, st = Lux.setup(rng, model) |> dev
+ps, st = Lux.setup(rng, model) |> dev;
+
+x = rand(1:50304, 1024, 32) |> dev;

-=#
+@code_hlo model(x, ps, st)

 # Use the model to generate some text.
 function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
@@ -180,7 +212,7 @@ function get_nanogpt_data(; sequence_length, test_split)
 end

 @main function main(;
-    n_embed::Int=64,
+    embed_dim::Int=64,
     n_hidden::Int=256,
     n_heads::Int=4,
     qk_dim::Int=16,
@@ -238,7 +270,7 @@ end

     model_config = (;
         n_vocab=length(alphabet),
-        n_embed,
+        embed_dim,
         sequence_length,
         n_hidden,
         n_layers,
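For context on the new forward pass: GPT2 has no separate output Dense head; the old Dense(n_embed => n_vocab) is removed, and the logits are produced by reusing the token-embedding matrix, multiplying the block outputs by the transpose of ps.tok_emb.weight (weight tying). Below is a stand-alone sketch of that reshape-and-multiply step with plain Julia arrays, assuming the embedding weight is stored as embed_dim × n_vocab; the sizes are illustrative, not the model's configuration.

    # Weight-tied projection from hidden states to vocabulary logits.
    embed_dim, n_vocab, seq_len, batch_size = 8, 16, 4, 2

    tok_emb_weight = randn(Float32, embed_dim, n_vocab)      # matrix shared with the token embedding
    hidden = randn(Float32, embed_dim, seq_len, batch_size)  # stand-in for the GPT2 block outputs

    # Collapse (seq_len, batch_size) into one axis so a single matrix multiply
    # produces all logits, then restore the (n_vocab, seq_len, batch_size) layout.
    logits = reshape(tok_emb_weight' * reshape(hidden, embed_dim, :), n_vocab, seq_len, batch_size)

    size(logits)  # (16, 4, 2)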

0 commit comments
