Skip to content

Commit f5936a6

Browse files
committed
DecodeConfig
1 parent 8cdf3a6 commit f5936a6

File tree

6 files changed

+171
-80
lines changed

6 files changed

+171
-80
lines changed

common/hugio/writers.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ package hugio
1515

1616
import (
1717
"io"
18-
19-
"github.com/gohugoio/hugo/common/hdebug"
2018
)
2119

2220
// As implemented by strings.Builder.
@@ -127,18 +125,3 @@ func (c PipeReadWriteCloser) Close() (err error) {
127125
func (c PipeReadWriteCloser) WriteString(s string) (int, error) {
128126
return c.PipeWriter.Write([]byte(s))
129127
}
130-
131-
// CounterWriter is an io.Writer that counts the number of bytes written to it.
132-
// and prints a message every 1024 bytes.
133-
type CounterWriter struct {
134-
count int64
135-
}
136-
137-
func (cw *CounterWriter) Write(p []byte) (n int, err error) {
138-
n = len(p)
139-
cw.count += int64(n)
140-
if cw.count%1024 == 0 {
141-
hdebug.Printf("%d bytes written", cw.count)
142-
}
143-
return n, nil
144-
}

common/maps/map.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ func (m *Map[K, T]) Set(key K, value T) {
7474
}
7575

7676
// WithWriteLock executes the given function with a write lock on the map.
77-
func (m *Map[K, T]) WithWriteLock(f func(m map[K]T)) {
77+
func (m *Map[K, T]) WithWriteLock(f func(m map[K]T) error) error {
7878
m.mu.Lock()
7979
defer m.mu.Unlock()
80-
f(m.m)
80+
return f(m.m)
8181
}
8282

8383
// SetIfAbsent sets the given key to the given value if the key does not already exist in the map.

common/maps/map_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@ func TestMap(t *testing.T) {
6262
c.Assert(found, qt.Equals, true)
6363
c.Assert(v, qt.Equals, 300)
6464

65-
m.WithWriteLock(func(m map[string]int) {
65+
m.WithWriteLock(func(m map[string]int) error {
6666
m["f"] = 500
67+
return nil
6768
})
6869
v, found = m.Lookup("f")
6970
c.Assert(found, qt.Equals, true)

hugolib/content_map_page_assembler.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,9 @@ func (a *allPagesAssembler) doCreatePages(prefix string, depth int) error {
282282
default:
283283
// Skip this page.
284284
a.droppedPages.WithWriteLock(
285-
func(m map[*Site][]string) {
285+
func(m map[*Site][]string) error {
286286
m[site] = append(m[site], s)
287+
return nil
287288
},
288289
)
289290

@@ -635,15 +636,16 @@ func (a *allPagesAssembler) doCreatePages(prefix string, depth int) error {
635636
continue
636637
}
637638
t := term{view: viewName, term: v}
638-
a.seenTerms.WithWriteLock(func(m map[term]sitesmatrix.Vectors) {
639+
a.seenTerms.WithWriteLock(func(m map[term]sitesmatrix.Vectors) error {
639640
vectors, found := m[t]
640641
if !found {
641642
m[t] = sitesmatrix.Vectors{
642643
vec: struct{}{},
643644
}
644-
return
645+
return nil
645646
}
646647
vectors[vec] = struct{}{}
648+
return nil
647649
})
648650
}
649651
}

internal/warpc/warpc.go

Lines changed: 104 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ import (
3333
"github.com/bep/textandbinaryreader"
3434

3535
"github.com/gohugoio/hugo/common/hdebug"
36+
"github.com/gohugoio/hugo/common/hstrings"
3637
"github.com/gohugoio/hugo/common/hugio"
38+
"github.com/gohugoio/hugo/common/maps"
3739
"golang.org/x/sync/errgroup"
3840

3941
"github.com/tetratelabs/wazero"
@@ -42,6 +44,11 @@ import (
4244
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
4345
)
4446

47+
const (
48+
MessageKindJSON string = "json"
49+
MessageKindBlob string = "blob"
50+
)
51+
4552
const currentVersion = 1
4653

4754
//go:embed wasm/quickjs.wasm
@@ -59,13 +66,51 @@ type Header struct {
5966
// Command is the command to execute.
6067
Command string `json:"command"`
6168

69+
// RequestKinds is a list of kinds in this RPC request,
70+
// e.g. {"json", "blob"}, or {"json"}.
71+
RequestKinds []string `json:"requestKinds"`
72+
// ResponseKinds is a list of kinds expected in the response,
73+
// e.g. {"json", "blob"}, or {"json"}.
74+
ResponseKinds []string `json:"responseKinds"`
75+
6276
// Set in the response if there was an error.
6377
Err string `json:"err"`
6478

6579
// Warnings is a list of warnings that may be returned in the response.
6680
Warnings []string `json:"warnings,omitempty"`
6781
}
6882

83+
func (m *Header) init() error {
84+
if m.ID == 0 {
85+
return errors.New("ID must not be 0 (note that this must be unique within the current request set time window)")
86+
}
87+
if m.Version == 0 {
88+
m.Version = currentVersion
89+
}
90+
if len(m.RequestKinds) == 0 {
91+
m.RequestKinds = []string{string(MessageKindJSON)}
92+
}
93+
if len(m.ResponseKinds) == 0 {
94+
m.ResponseKinds = []string{string(MessageKindJSON)}
95+
}
96+
if m.Version != currentVersion {
97+
return fmt.Errorf("unsupported version: %d", m.Version)
98+
}
99+
for range 2 {
100+
if len(m.RequestKinds) > 2 {
101+
return fmt.Errorf("invalid number of request kinds: %d", len(m.RequestKinds))
102+
}
103+
if len(m.ResponseKinds) > 2 {
104+
return fmt.Errorf("invalid number of response kinds: %d", len(m.ResponseKinds))
105+
}
106+
m.RequestKinds = hstrings.UniqueStringsReuse(m.RequestKinds)
107+
m.ResponseKinds = hstrings.UniqueStringsReuse(m.ResponseKinds)
108+
109+
}
110+
111+
return nil
112+
}
113+
69114
type Message[T any] struct {
70115
Header Header `json:"header"`
71116
Data T `json:"data"`
@@ -75,6 +120,10 @@ func (m Message[T]) GetID() uint32 {
75120
return m.Header.ID
76121
}
77122

123+
func (m *Message[T]) init() error {
124+
return m.Header.init()
125+
}
126+
78127
type SourceProvider interface {
79128
GetSource() io.Reader
80129
GetSourceLength() uint32
@@ -160,9 +209,6 @@ func putTimer(t *time.Timer) {
160209
// Execute sends a request to the dispatcher and waits for the response.
161210
func (p *dispatcherPool[Q, R]) Execute(ctx context.Context, q Message[Q]) (Message[R], error) {
162211
d := p.getDispatcher()
163-
if q.GetID() == 0 {
164-
return d.zeroR, errors.New("ID must not be 0 (note that this must be unique within the current request set time window)")
165-
}
166212

167213
call, err := d.newCall(q)
168214
if err != nil {
@@ -195,20 +241,22 @@ func (p *dispatcherPool[Q, R]) Execute(ctx context.Context, q Message[Q]) (Messa
195241
if err == nil && resp.Header.Err != "" {
196242
err = errors.New(resp.Header.Err)
197243
}
244+
198245
return resp, err
199246
}
200247

201248
func (d *dispatcher[Q, R]) newCall(q Message[Q]) (*call[Q, R], error) {
202-
responseCountdown := &atomic.Int32{}
203-
responseCountdown.Add(1) // Default is JSON response only.
204-
if _, ok := any(d.zeroQ.Data).(DestinationProvider); ok {
205-
// Expecting JSON followed by binary blob.
206-
responseCountdown.Add(1)
249+
if err := q.init(); err != nil {
250+
return nil, err
251+
}
252+
responseKinds := maps.NewMap[string, bool]()
253+
for _, rk := range q.Header.ResponseKinds {
254+
responseKinds.Set(rk, true)
207255
}
208256
call := &call[Q, R]{
209-
donec: make(chan *call[Q, R], 1),
210-
request: q,
211-
responseCountdown: responseCountdown,
257+
donec: make(chan *call[Q, R], 1),
258+
request: q,
259+
responseKinds: responseKinds,
212260
}
213261

214262
if d.shutdown || d.closing {
@@ -240,7 +288,7 @@ func (d *dispatcher[Q, R]) send(call *call[Q, R]) error {
240288
}
241289
if sp, ok := any(call.request.Data).(SourceProvider); ok {
242290
if sp.GetSourceLength() == 0 {
243-
panic("source length must be greater than 0")
291+
panic(fmt.Sprintf("source length must be greater than 0, header: %+v", call.request.Header))
244292
}
245293

246294
hdebug.Printf(" == == === Blob header %d %d", call.request.GetID(), sp.GetSourceLength())
@@ -272,7 +320,7 @@ func (d *dispatcher[Q, R]) inputBlobs() {
272320

273321
call := d.pendingCall(id)
274322

275-
hdebug.Printf("START === === === Read blob header id: %d len: %d countdown: %d", id, length, call.responseCountdown.Load())
323+
hdebug.Printf("START === === === Read blob header id: %d len: %d ", id, length)
276324

277325
if err := call.handleBlob(lr); err != nil {
278326
inputErr = err
@@ -282,18 +330,27 @@ func (d *dispatcher[Q, R]) inputBlobs() {
282330
inputErr = fmt.Errorf("blob %d: expected to read %d more bytes", id, lr.N)
283331
break
284332
}
285-
call.responseCountdown.Add(-1)
286-
if call.responseCountdown.Load() <= 0 {
287-
d.mu.Lock()
288-
delete(d.pending, id)
289-
d.mu.Unlock()
290-
call.done()
333+
if err := call.responseKinds.WithWriteLock(
334+
func(m map[string]bool) error {
335+
if _, ok := m[MessageKindBlob]; !ok {
336+
return fmt.Errorf("unexpected blob response for %q call ID %d", call.request.Header.Command, id)
337+
}
338+
delete(m, MessageKindBlob)
339+
// Message exchange is complete.
340+
d.mu.Lock()
341+
delete(d.pending, id)
342+
d.mu.Unlock()
343+
call.done()
344+
return nil
345+
}); err != nil {
346+
inputErr = err
347+
break
291348
}
292-
hdebug.Printf("END === === === Read blob header id: %d len: %d countdown: %d", id, length, call.responseCountdown.Load())
349+
hdebug.Printf("END === === === Read blob header id: %d len: %d", id, length)
293350
}
294351

295-
if inputErr != nil && inputErr != io.EOF && inputErr != io.ErrClosedPipe {
296-
fmt.Printf("ERR %s", inputErr) // TODO1
352+
if inputErr != nil {
353+
// panic(inputErr) // TODO1 fix me, consolidate with JSON error handling.
297354
}
298355
}
299356

@@ -309,17 +366,27 @@ func (d *dispatcher[Q, R]) inputJSON() {
309366

310367
call := d.pendingCall(r.GetID())
311368

312-
hdebug.Printf("END === === === get JSON id: %d countdown: %d", r.GetID(), call.responseCountdown.Load())
313-
314-
call.responseCountdown.Add(-1)
315-
call.response = r
316-
if call.responseCountdown.Load() <= 0 || r.Header.Err != "" {
317-
d.mu.Lock()
318-
delete(d.pending, r.GetID())
319-
d.mu.Unlock()
320-
call.done() // TODO1 check that this can be called multiple times safely.
369+
if err := call.responseKinds.WithWriteLock(
370+
func(m map[string]bool) error {
371+
call.response = r
372+
if _, ok := m[MessageKindJSON]; !ok {
373+
return fmt.Errorf("unexpected JSON response for call ID %d", r.GetID())
374+
}
375+
delete(m, MessageKindJSON)
376+
if len(m) == 0 {
377+
// Message exchange is complete.
378+
d.mu.Lock()
379+
delete(d.pending, r.GetID())
380+
d.mu.Unlock()
381+
call.done()
382+
}
383+
return nil
384+
}); err != nil {
385+
inputErr = err
386+
break
321387
}
322-
hdebug.Printf("END === === === get JSON id: %d countdown: %d", r.GetID(), call.responseCountdown.Load())
388+
389+
hdebug.Printf("END === === === get JSON id: %d", r.GetID())
323390

324391
}
325392

@@ -355,11 +422,11 @@ func (d *dispatcher[Q, R]) pendingCall(id uint32) *call[Q, R] {
355422
}
356423

357424
type call[Q, R any] struct {
358-
request Message[Q]
359-
response Message[R]
360-
responseCountdown *atomic.Int32
361-
err error
362-
donec chan *call[Q, R]
425+
request Message[Q]
426+
response Message[R]
427+
responseKinds *maps.Map[string, bool]
428+
err error
429+
donec chan *call[Q, R]
363430
}
364431

365432
func (c *call[Q, R]) handleBlob(r io.Reader) error {

0 commit comments

Comments
 (0)