Skip to content

Commit 2dc6742

Browse files
committed
fix: more cleanup
1 parent aa30a59 commit 2dc6742

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

examples/NanoGPT/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71"
23
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
34
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
45
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"

examples/NanoGPT/main.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Taken from https://github.com/FluxML/model-zoo/pull/410
22
using ConcreteStructs, MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib,
3-
DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme
3+
DataDeps, StatsBase, OneHotArrays, JLD2, Reactant, Enzyme, BytePairEncoding
44
using Comonicon: @main
55

66
if !haskey(DataDeps.registry, "nanogpt")
@@ -51,7 +51,7 @@ end
5151
block
5252
end
5353

54-
function GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate, use_bias)
54+
function GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)
5555
return GPTBlock(Chain(
5656
SkipConnection(
5757
Chain(
@@ -63,8 +63,8 @@ function GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate, use
6363
SkipConnection(
6464
Chain(
6565
LayerNorm((n_embed, 1)),
66-
Dense(n_embed => n_hidden, gelu),
67-
Dense(n_hidden => n_embed),
66+
Dense(n_embed => 4 * n_embed, gelu; use_bias),
67+
Dense(4 * n_embed => n_embed; use_bias),
6868
Dropout(dropout_rate)
6969
),
7070
+
@@ -87,25 +87,35 @@ end
8787
layer
8888
end
8989

90-
function GPT(;
91-
n_vocab, n_embed, sequence_length, n_hidden, n_layers, dropout_rate,
92-
n_heads, qk_dim, v_dim
93-
)
90+
function GPT(; n_vocab, n_embed, block_size, n_layers, dropout_rate, n_heads, use_bias)
9491
return GPT(Chain(
9592
Parallel(
9693
+,
9794
Embedding(n_vocab => n_embed),
98-
PositionalEmbedding(sequence_length => n_embed)
95+
PositionalEmbedding(block_size => n_embed)
9996
),
10097
Dropout(dropout_rate),
101-
Chain(ntuple(n_layers) do i
102-
return GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
103-
end...),
98+
Chain(ntuple(
99+
Returns(GPTBlock(; n_embed, n_heads, dropout_rate, use_bias)), n_layers
100+
)...),
104101
LayerNorm((n_embed, 1)),
105-
Dense(n_embed => n_vocab)
102+
Dense(n_embed => n_vocab; use_bias)
106103
))
107104
end
108105

106+
#=
107+
108+
dev = reactant_device(; force=true)
109+
rng = Random.default_rng()
110+
111+
model = GPT(;
112+
n_vocab=50304, n_embed=768, block_size=1024, n_layers=12, dropout_rate=0.0, n_heads=12,
113+
use_bias=true
114+
)
115+
ps, st = Lux.setup(rng, model) |> dev
116+
117+
=#
118+
109119
# Use the model to generate some text.
110120
function generate_text(
111121
model, ps, st, seed; alphabet, output_length, sequence_length
@@ -152,13 +162,14 @@ function get_nanogpt_data(; sequence_length, test_split)
152162
data_file = joinpath(datadep"nanogpt", "shakespeare_input.txt")
153163
text = String(read(data_file))
154164

155-
# For aesthetic reasons, replace newlines with spaces. This is not necessary, but makes
156-
# strings print nicer.
157-
text = replace(text, r"\r?\n" => " ")
165+
idx = ceil(Int, length(text) * (1 - test_split))
166+
train_text = text[1:idx]
167+
test_text = text[(idx + 1):end]
168+
169+
tokenizer = BytePairEncoding.load_gpt2()
158170

159-
## an array of all unique characters
160-
alphabet = [unique(text)..., '_']
161-
stop = alphabet[end]
171+
train_tokens = tokenizer(train_text)
172+
test_tokens = tokenizer(test_text)
162173

163174
B = (length(text) - 1) ÷ sequence_length
164175
# We must collect() before indexing, because String indexing does strange things with multi-byte

0 commit comments

Comments
 (0)