Skip to content

Commit 00474d4

Browse files
authored
Composite Tools Supports DefaultResults for Skippable Steps (#3006)
Changes Schema & Configuration Added defaultResults field to WorkflowStep CRD type (map[string]runtime.RawExtension) Added corresponding fields in raw config, internal config, and composer types Updated all converters (CRD→config, YAML→config, config→composer) Validation Implemented shared validation in workflow_validation.go Validates that defaultResults[field] is specified when: Step may be skipped (has condition or onError.action: continue) Downstream steps reference that specific output field Validation is reused by both VirtualMCPServer and VirtualMCPCompositeToolDefinition webhooks Template Reference Extraction Created pkg/templates/references.go with AST-based template parsing Uses text/template/parse package for reliable reference extraction Extracts .steps.<stepID>.output.<field> references for validation Runtime Updated RecordStepSkipped to accept and store defaultResults as step output Updated handleToolStepFailure to set defaultResults when continuing on error Documentation Added "Default Results" section to virtualmcpcompositetooldefinition-guide.md Added quick reference to composite-tools-quick-reference.md Added "Pattern 5: Default Results for Skippable Steps" to advanced-workflow-patterns.md --------- Signed-off-by: Jeremy Drouillard <[email protected]>
1 parent 08735de commit 00474d4

28 files changed

+1557
-30
lines changed

cmd/thv-operator/api/v1alpha1/virtualmcpcompositetooldefinition_webhook.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,12 @@ func (r *VirtualMCPCompositeToolDefinition) validateSteps() error {
164164
}
165165

166166
// Third pass: validate dependencies don't create cycles
167-
return r.validateDependencyCycles()
167+
if err := r.validateDependencyCycles(); err != nil {
168+
return err
169+
}
170+
171+
// Fourth pass: validate defaultResults for skippable steps
172+
return validateDefaultResultsForSteps("spec.steps", r.Spec.Steps, r.Spec.Output)
168173
}
169174

170175
// validateStep validates a single workflow step

cmd/thv-operator/api/v1alpha1/virtualmcpserver_types.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,15 @@ type WorkflowStep struct {
325325
// Timeout is the maximum execution time for this step
326326
// +optional
327327
Timeout string `json:"timeout,omitempty"`
328+
329+
// DefaultResults provides fallback output values when this step is skipped
330+
// (due to condition evaluating to false) or fails (when onError.action is "continue").
331+
// Each key corresponds to an output field name referenced by downstream steps.
332+
// Required if the step may be skipped AND downstream steps reference this step's output.
333+
// +optional
334+
// +kubebuilder:pruning:PreserveUnknownFields
335+
// +kubebuilder:validation:Schemaless
336+
DefaultResults map[string]runtime.RawExtension `json:"defaultResults,omitempty"`
328337
}
329338

330339
// ErrorHandling defines error handling behavior for workflow steps

cmd/thv-operator/api/v1alpha1/virtualmcpserver_webhook.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,13 @@ func (*VirtualMCPServer) validateCompositeTool(index int, tool CompositeToolSpec
228228
toolNames[tool.Name] = true
229229

230230
// Validate steps
231-
return validateCompositeToolSteps(index, tool.Steps)
231+
if err := validateCompositeToolSteps(index, tool.Steps); err != nil {
232+
return err
233+
}
234+
235+
// Validate defaultResults for skippable steps
236+
pathPrefix := fmt.Sprintf("spec.compositeTools[%d].steps", index)
237+
return validateDefaultResultsForSteps(pathPrefix, tool.Steps, tool.Output)
232238
}
233239

234240
// validateCompositeToolSteps validates all steps in a composite tool
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
package v1alpha1
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/stacklok/toolhive/pkg/templates"
9+
)
10+
11+
// stepFieldRef represents a reference to a specific field on a step's output.
12+
type stepFieldRef struct {
13+
stepID string
14+
field string
15+
}
16+
17+
// validateDefaultResultsForSteps validates that defaultResults is specified for steps that:
18+
// 1. May be skipped (have a condition or onError.action == "continue")
19+
// 2. Are referenced by downstream steps
20+
//
21+
// This is a shared validation function used by both VirtualMCPServer and VirtualMCPCompositeToolDefinition webhooks.
22+
// The pathPrefix parameter allows customizing error message paths (e.g., "spec.steps" or "spec.compositeTools[0].steps").
23+
// nolint:gocyclo // multiple passes of the workflow are required to validate references are safe.
24+
func validateDefaultResultsForSteps(pathPrefix string, steps []WorkflowStep, output *OutputSpec) error {
25+
// 1. Compute all skippable step IDs
26+
skippableStepIDs := make(map[string]struct{})
27+
for _, step := range steps {
28+
if stepMayBeSkipped(step) {
29+
skippableStepIDs[step.ID] = struct{}{}
30+
}
31+
}
32+
33+
// If no skippable steps, nothing to validate
34+
if len(skippableStepIDs) == 0 {
35+
return nil
36+
}
37+
38+
// 2. Compute map from skippable step ID to set of fields with default values
39+
skippableStepDefaults := make(map[string]map[string]struct{})
40+
for _, step := range steps {
41+
if _, ok := skippableStepIDs[step.ID]; ok {
42+
skippableStepDefaults[step.ID] = make(map[string]struct{})
43+
for key := range step.DefaultResults {
44+
skippableStepDefaults[step.ID][key] = struct{}{}
45+
}
46+
}
47+
}
48+
49+
// 3. For each step, check if any references are to skippable steps missing defaults for that field
50+
for _, step := range steps {
51+
refs, err := extractStepFieldRefsFromStep(step)
52+
if err != nil {
53+
return fmt.Errorf("failed to extract step references from step %s: %w", step.ID, err)
54+
}
55+
56+
for _, ref := range refs {
57+
// Check if this step is skippable
58+
defaultFields, isSkippable := skippableStepDefaults[ref.stepID]
59+
if !isSkippable {
60+
continue
61+
}
62+
63+
// Check if the referenced field has a default
64+
if _, hasDefault := defaultFields[ref.field]; !hasDefault {
65+
return fmt.Errorf(
66+
"%s[%s].defaultResults[%s] is required: step %q may be skipped and field %q is referenced by step %s",
67+
pathPrefix, ref.stepID, ref.field, ref.stepID, ref.field, step.ID)
68+
}
69+
}
70+
}
71+
72+
// Check output for references to skippable steps missing defaults
73+
if output != nil {
74+
outputRefs, err := extractStepFieldRefsFromOutput(output)
75+
if err != nil {
76+
return fmt.Errorf("failed to extract step references from output: %w", err)
77+
}
78+
79+
for _, ref := range outputRefs {
80+
defaultFields, isSkippable := skippableStepDefaults[ref.stepID]
81+
if !isSkippable {
82+
continue
83+
}
84+
85+
if _, hasDefault := defaultFields[ref.field]; !hasDefault {
86+
return fmt.Errorf(
87+
"%s[%s].defaultResults[%s] is required: step %q may be skipped and field %q is referenced by output",
88+
pathPrefix, ref.stepID, ref.field, ref.stepID, ref.field)
89+
}
90+
}
91+
}
92+
93+
return nil
94+
}
95+
96+
// stepMayBeSkipped returns true if a step may be skipped during execution.
97+
// A step may be skipped if:
98+
// - It has a condition (may evaluate to false)
99+
// - It has onError.action == "continue" (may fail and be skipped)
100+
func stepMayBeSkipped(step WorkflowStep) bool {
101+
// Step has a condition that may evaluate to false
102+
if step.Condition != "" {
103+
return true
104+
}
105+
106+
// Step has continue-on-error, meaning failure results in skip
107+
if step.OnError != nil && step.OnError.Action == ErrorActionContinue {
108+
return true
109+
}
110+
111+
return false
112+
}
113+
114+
// extractStepFieldRefsFromStep extracts step field references from a step's templates.
115+
func extractStepFieldRefsFromStep(step WorkflowStep) ([]stepFieldRef, error) {
116+
var allRefs []stepFieldRef
117+
118+
// Extract from condition
119+
if step.Condition != "" {
120+
refs, err := extractStepFieldRefsFromTemplate(step.Condition)
121+
if err != nil {
122+
return nil, err
123+
}
124+
allRefs = append(allRefs, refs...)
125+
}
126+
127+
// Extract from arguments
128+
if step.Arguments != nil && len(step.Arguments.Raw) > 0 {
129+
var args map[string]any
130+
if err := json.Unmarshal(step.Arguments.Raw, &args); err == nil {
131+
for _, argValue := range args {
132+
if strValue, ok := argValue.(string); ok {
133+
refs, err := extractStepFieldRefsFromTemplate(strValue)
134+
if err != nil {
135+
return nil, err
136+
}
137+
allRefs = append(allRefs, refs...)
138+
}
139+
}
140+
}
141+
}
142+
143+
// Extract from message (elicitation steps)
144+
if step.Message != "" {
145+
refs, err := extractStepFieldRefsFromTemplate(step.Message)
146+
if err != nil {
147+
return nil, err
148+
}
149+
allRefs = append(allRefs, refs...)
150+
}
151+
152+
return uniqueStepFieldRefs(allRefs), nil
153+
}
154+
155+
// extractStepFieldRefsFromOutput extracts step field references from output templates.
156+
func extractStepFieldRefsFromOutput(output *OutputSpec) ([]stepFieldRef, error) {
157+
if output == nil {
158+
return nil, nil
159+
}
160+
161+
var allRefs []stepFieldRef
162+
163+
for _, prop := range output.Properties {
164+
if prop.Value != "" {
165+
refs, err := extractStepFieldRefsFromTemplate(prop.Value)
166+
if err != nil {
167+
return nil, err
168+
}
169+
allRefs = append(allRefs, refs...)
170+
}
171+
172+
// Recursively check nested properties
173+
if len(prop.Properties) > 0 {
174+
nestedOutput := &OutputSpec{Properties: prop.Properties}
175+
nestedRefs, err := extractStepFieldRefsFromOutput(nestedOutput)
176+
if err != nil {
177+
return nil, err
178+
}
179+
allRefs = append(allRefs, nestedRefs...)
180+
}
181+
}
182+
183+
return uniqueStepFieldRefs(allRefs), nil
184+
}
185+
186+
// extractStepFieldRefsFromTemplate extracts step output field references from a template string.
187+
// Only references to .steps.<stepID>.output.<field> are extracted.
188+
// For ".steps.step1.output.foo.bar", it returns stepFieldRef{stepID: "step1", field: "foo"}.
189+
// References to .steps.<stepID>.status or .steps.<stepID>.error are ignored.
190+
func extractStepFieldRefsFromTemplate(tmplStr string) ([]stepFieldRef, error) {
191+
refs, err := templates.ExtractReferences(tmplStr)
192+
if err != nil {
193+
return nil, err
194+
}
195+
196+
var stepRefs []stepFieldRef
197+
for _, ref := range refs {
198+
// Look for ".steps.<stepID>.output.<field>" pattern
199+
if strings.HasPrefix(ref, ".steps.") {
200+
// Split: ["", "steps", "stepID", "output", "field", ...]
201+
parts := strings.SplitN(ref, ".", 6)
202+
// Must have at least 5 parts and the 4th must be "output"
203+
if len(parts) >= 5 && parts[3] == "output" {
204+
stepRefs = append(stepRefs, stepFieldRef{
205+
stepID: parts[2],
206+
field: parts[4],
207+
})
208+
}
209+
}
210+
}
211+
212+
return uniqueStepFieldRefs(stepRefs), nil
213+
}
214+
215+
// uniqueStepFieldRefs returns a deduplicated slice of stepFieldRefs.
216+
func uniqueStepFieldRefs(refs []stepFieldRef) []stepFieldRef {
217+
seen := make(map[stepFieldRef]struct{})
218+
result := make([]stepFieldRef, 0, len(refs))
219+
for _, r := range refs {
220+
if _, ok := seen[r]; !ok {
221+
seen[r] = struct{}{}
222+
result = append(result, r)
223+
}
224+
}
225+
return result
226+
}

0 commit comments

Comments
 (0)