Skip to content

Commit f64885f

Browse files
committed
feat: weighted sample
1 parent 3b2e15f commit f64885f

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

examples/NanoGPT/main.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,17 @@ sumabs2first(layer, x, ps, st) = sum(abs2, first(layer(x, ps, st)))
147147
=#
148148

149149
# Use the model to generate some text.
150-
# function weighted_sample(items::AbstractVector, weights::AbstractVector)
151150

152-
# end
151+
function weighted_sample!(rng, items::AbstractVector, weights::AbstractVector, n::Int)
152+
@assert length(items) == length(weights)
153+
154+
weights = weights ./ sum(weights)
155+
cumprobs = reshape(cumsum(weights), :, 1)
156+
random_vals = rand(rng, 1, n)
157+
158+
indices = dropdims(sum(cumprobs .< random_vals; dims=1); dims=1) .+ 1
159+
return items[indices]
160+
end
153161

154162
function generate_text(model, ps, st, seed; alphabet, output_length, sequence_length)
155163
dev = get_device((ps, st))

0 commit comments

Comments
 (0)