diff --git a/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts b/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts index 94f1bfd767d..89eab9fc138 100644 --- a/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts +++ b/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts @@ -5,6 +5,9 @@ import { SdkClientType, SdkEnumType, SdkHttpOperation, + SdkModelType, + SdkType, + SdkUnionType, UsageFlags, } from "@azure-tools/typespec-client-generator-core"; import { CSharpEmitterContext } from "../sdk-context.js"; @@ -152,6 +155,81 @@ function fixNamingConflicts(models: InputModelType[], constants: InputLiteralTyp } } +function indexModelsUsedInUnions(sdkContext: CSharpEmitterContext): Record> { + const unionUsageMap: Record> = {}; + for (const u of sdkContext.sdkPackage.unions.filter((u) => u.kind === "union")) { + const modelsInUnion = new Set( + u.variantTypes.filter((m) => m.kind === "model").map((m) => m.name), + ); + for (const modelName of modelsInUnion) { + if (!unionUsageMap[modelName]) { + unionUsageMap[modelName] = new Set(); + } + unionUsageMap[modelName].add(u.name); + } + } + return unionUsageMap; +} + +function duplicateModelsUsedInUnions(sdkContext: CSharpEmitterContext) { + const modelUnionUsageMap = indexModelsUsedInUnions(sdkContext); + const entries = Object.entries(modelUnionUsageMap).filter( + ([, unionNames]) => unionNames.size > 1, + ); + if (entries.length === 0) { + return; + } + const modelsLookup = Object.fromEntries( + Array.from(sdkContext.__typeCache.types.keys()) + .filter((k) => k.kind === "model") + .map((k) => { + return [k.name, k]; + }), + ); + const unionsLookup = Object.fromEntries( + sdkContext.sdkPackage.unions + .filter((u) => u.kind === "union") + .map((u) => { + return [u.name, u]; + }), + ); + for (const [modelName, unionNames] of entries) { + duplicateModelInUnion(modelName, unionNames, sdkContext, modelsLookup, unionsLookup); + } +} + +function duplicateModelInUnion( + modelName: string, + unionNames: Set, + sdkContext: CSharpEmitterContext, + modelsLookup: { + [k: string]: SdkModelType; + }, + unionsLookup: { + [k: string]: SdkUnionType; + }, +) { + const modelType = modelsLookup[modelName]; + for (const unionName of unionNames) { + const unionType = unionsLookup[unionName]; + // Create a duplicate of the model + const duplicatedModel = { + ...modelType, + name: `${modelType.name}For${unionType.name[0].toUpperCase()}${unionType.name.slice(1)}`, + }; + // Update the union to use the duplicated model + unionType.variantTypes = unionType.variantTypes.map((vt) => { + if (vt.kind === "model" && vt.name === modelType.name) { + return duplicatedModel; + } + return vt; + }); + } + // remove the original model from the sdk context type cache + sdkContext.__typeCache.types.delete(modelType); + delete modelsLookup[modelName]; +} + function navigateModels(sdkContext: CSharpEmitterContext) { for (const m of sdkContext.sdkPackage.models) { fromSdkType(sdkContext, m); @@ -159,4 +237,8 @@ function navigateModels(sdkContext: CSharpEmitterContext) { for (const e of sdkContext.sdkPackage.enums) { fromSdkType(sdkContext, e); } + duplicateModelsUsedInUnions(sdkContext); + for (const u of sdkContext.sdkPackage.unions) { + fromSdkType(sdkContext, u); + } } diff --git a/packages/http-client-csharp/emitter/src/lib/type-converter.ts b/packages/http-client-csharp/emitter/src/lib/type-converter.ts index f3a008ed668..57f25300429 100644 --- a/packages/http-client-csharp/emitter/src/lib/type-converter.ts +++ b/packages/http-client-csharp/emitter/src/lib/type-converter.ts @@ -4,6 +4,7 @@ import { DecoratorInfo, SdkArrayType, + SdkBuiltInKinds, SdkBuiltInType, SdkConstantType, SdkDateTimeType, @@ -385,12 +386,180 @@ function fromSdkBuiltInType( }; } -function fromUnionType(sdkContext: CSharpEmitterContext, union: SdkUnionType): InputUnionType { +function discriminatorPropertyFromUnion( + sdkContext: CSharpEmitterContext, + union: SdkUnionType, + variantTypes: InputType[], +): InputModelProperty | undefined { + if (!union.discriminatedOptions) { + return undefined; + } + + const discriminatorPropertyName = union.discriminatedOptions.discriminatorPropertyName; + const discriminatorProperties = variantTypes + .map((variant) => { + if (variant.kind === "model") { + const discProp = variant.properties.find((p) => p.name === discriminatorPropertyName); + if (discProp) { + return discProp; + } + } + return undefined; + }) + .filter((p) => p !== undefined); + + if (discriminatorProperties.length === 0) { + return undefined; + } + + // Declare an enum for all the constant values + const discriminatorEnumType: InputEnumType = { + kind: "enum", + name: `${union.name}${discriminatorPropertyName[0].toUpperCase()}${discriminatorPropertyName.slice(1)}`, + valueType: fromSdkBuiltInType(sdkContext, { + kind: "string", + name: "string", + crossLanguageDefinitionId: "TypeSpec.string", + decorators: [], + }), + values: [], + namespace: union.namespace, + crossLanguageDefinitionId: "", + access: undefined, + usage: UsageFlags.None, + decorators: [], + isFixed: false, + isFlags: false, + }; + + const enumValues: InputEnumValueType[] = discriminatorProperties.map((prop) => { + if (prop.type.kind === "constant") { + return { + kind: "enumvalue", + name: prop.type.value === null ? "Null" : prop.type.value.toString(), + value: typeof prop.type.value === "boolean" ? (prop.type.value ? 1 : 0) : prop.type.value, + enumType: discriminatorEnumType, + valueType: prop.type.valueType, + } as InputEnumValueType; + } + throw new Error( + `Discriminator property ${discriminatorPropertyName} in union ${union.name} is not a constant type.`, + ); + // TODO handle numeric constants + // TODO handle default variants + // TODO handle string values + // TODO handle open ended enums + }); + + discriminatorEnumType.values.push(...enumValues); + + sdkContext.__typeCache.updateSdkTypeReferences( + { + kind: "enum", + name: discriminatorEnumType.name, + valueType: fromSdkBuiltInType(sdkContext, { + kind: "string", + name: "string", + crossLanguageDefinitionId: "TypeSpec.string", + decorators: [], + }) as SdkBuiltInType, + values: [], + namespace: union.namespace, + crossLanguageDefinitionId: "", + access: "public", + usage: UsageFlags.None, + decorators: [], + isFixed: false, + isFlags: false, + isGeneratedName: true, + isUnionAsEnum: false, + apiVersions: [], + }, + discriminatorEnumType, + ); + + return { + kind: "property", + name: discriminatorPropertyName, + serializedName: discriminatorPropertyName, + type: discriminatorEnumType, + optional: false, + readOnly: false, + decorators: [], + flatten: false, + discriminator: true, + isHttpMetadata: false, + isApiVersion: false, + crossLanguageDefinitionId: "", + serializationOptions: { + json: { name: discriminatorPropertyName }, + }, + }; +} + +function removeDiscriminatorPropertiesFromVariants( + variantTypes: InputType[], + discriminatorPropertyName: string, +) { + for (const variant of variantTypes) { + if (variant.kind === "model") { + const discriminatorProperty = variant.properties.find( + (p) => p.name === discriminatorPropertyName, + ); + variant.properties = variant.properties.filter((p) => p.name !== discriminatorPropertyName); + variant.discriminatorValue = + discriminatorProperty?.type.kind === "constant" + ? discriminatorProperty.type.value?.toString() + : undefined; + } + } +} + +function fromUnionType( + sdkContext: CSharpEmitterContext, + union: SdkUnionType, +): InputUnionType | InputModelType { const variantTypes: InputType[] = []; for (const value of union.variantTypes) { const variantType = fromSdkType(sdkContext, value); variantTypes.push(variantType); } + if (isDiscriminatedUnion(union)) { + const discriminatorProperty = discriminatorPropertyFromUnion(sdkContext, union, variantTypes); + const properties = discriminatorProperty ? [discriminatorProperty] : []; + if (discriminatorProperty) { + removeDiscriminatorPropertiesFromVariants(variantTypes, discriminatorProperty.name); + } + const baseType: InputModelType = { + kind: "model", + name: union.name, + namespace: union.namespace, + discriminatorProperty: discriminatorProperty, + crossLanguageDefinitionId: union.crossLanguageDefinitionId, + access: union.access, + usage: union.usage, + properties: properties, + serializationOptions: {}, + summary: union.summary, + doc: union.doc, + deprecation: union.deprecation, + decorators: union.decorators, + external: fromSdkExternalTypeInfo(union), + } as InputModelType; + const discriminatedSubtypes: Record = {}; + variantTypes.forEach((variant) => { + if (variant.kind === "model") { + variant.baseModel = baseType; + if (variant.discriminatorValue !== undefined) { + discriminatedSubtypes[variant.discriminatorValue] = variant; + } + } + }); + if (Object.keys(discriminatedSubtypes).length > 0) { + baseType.discriminatedSubtypes = discriminatedSubtypes; + } + return baseType; + } return { kind: "union", @@ -402,6 +571,21 @@ function fromUnionType(sdkContext: CSharpEmitterContext, union: SdkUnionType): I }; } +function isDiscriminatedUnion(sdkType: SdkUnionType): boolean { + if ( + !sdkType.discriminatedOptions || + sdkType.discriminatedOptions.envelope === "object" || + (sdkType.discriminatedOptions.envelopePropertyName !== undefined && + sdkType.discriminatedOptions.envelopePropertyName !== "") + ) { + return false; + } + + return sdkType.variantTypes.every((variant) => { + return variant.kind === "model" && !variant.baseModel; + }); +} + function fromSdkConstantType( sdkContext: CSharpEmitterContext, constantType: SdkConstantType, diff --git a/packages/http-client-csharp/emitter/test/Unit/client-model-builder.test.ts b/packages/http-client-csharp/emitter/test/Unit/client-model-builder.test.ts index 611fafc9ec9..82412cf5db7 100644 --- a/packages/http-client-csharp/emitter/test/Unit/client-model-builder.test.ts +++ b/packages/http-client-csharp/emitter/test/Unit/client-model-builder.test.ts @@ -441,3 +441,95 @@ describe("parseApiVersions", () => { ok(barClient.apiVersions.includes("bv2"), "Bar client should include bv2"); }); }); + +describe("union usage", () => { + let runner: TestHost; + + beforeEach(async () => { + runner = await createEmitterTestHost(); + }); + it("should duplicate models used in multiple unions", async () => { + const program = await typeSpecCompile( + ` + model ModelShared { + id: string; + type: "type1"; + } + + model ModelUniqueA { + name: string; + type: "type2"; + } + + model ModelUniqueB { + value: int32; + type: "type2"; + } + + @discriminated(#{ discriminatorPropertyName: "type", envelope: "none" }) + union UnionA { + "type1": ModelShared; + "type2": ModelUniqueA; + } + + @discriminated(#{ discriminatorPropertyName: "type", envelope: "none" }) + union UnionB { + "type1": ModelShared; + "type2": ModelUniqueB; + } + + op testA(@body input: UnionA): UnionB; + `, + runner, + { IsTCGCNeeded: true }, + ); + //TODO what happens if the model being duplicated is also being used elsewhere, e.g., as operation parameter or response type? + const context = createEmitterContext(program); + const sdkContext = await createCSharpSdkContext(context); + const root = createModel(sdkContext); + + const modelUniqueA = root.models.find((m) => m.name === "ModelUniqueA"); + ok(modelUniqueA, "ModelUniqueA should exist"); + + const modelUniqueB = root.models.find((m) => m.name === "ModelUniqueB"); + ok(modelUniqueB, "ModelUniqueB should exist"); + + const modelSharedOriginal = root.models.find((m) => m.name === "ModelShared"); + ok(!modelSharedOriginal, "Original ModelShared should NOT exist"); + + const modelSharedForA = root.models.find((m) => m.name === "ModelSharedForUnionA"); + ok(modelSharedForA, "ModelSharedForUnionA should exist"); + + const modelSharedForB = root.models.find((m) => m.name === "ModelSharedForUnionB"); + ok(modelSharedForB, "ModelSharedForUnionB should exist"); + + strictEqual( + modelSharedForA.properties.length, + modelSharedForB.properties.length, + "Both duplicated ModelShared should have same number of properties", + ); + strictEqual( + modelSharedForA.properties[0].name, + modelSharedForB.properties[0].name, + "Both duplicated ModelShared should have same property names", + ); + + const unionA = root.models.find((m) => m.name === "UnionA"); + ok(unionA, "UnionA should exist"); + strictEqual(unionA.kind, "model"); + ok(unionA.discriminatedSubtypes, "UnionA should have discriminatedSubtypes"); + const unionASubTypes = new Set(Object.values(unionA.discriminatedSubtypes).map((t) => t.name)); + ok(unionASubTypes.has("ModelSharedForUnionA"), "UnionA should reference ModelSharedForUnionA"); + ok(unionASubTypes.has("ModelUniqueA"), "UnionA should reference ModelUniqueA"); + ok(!unionASubTypes.has("ModelShared"), "UnionA should NOT reference original ModelShared"); + + const unionB = root.models.find((m) => m.name === "UnionB"); + ok(unionB, "UnionB should exist"); + strictEqual(unionB.kind, "model"); + ok(unionB.discriminatedSubtypes, "UnionB should have discriminatedSubtypes"); + const unionBSubTypes = new Set(Object.values(unionB.discriminatedSubtypes).map((t) => t.name)); + ok(unionBSubTypes.has("ModelSharedForUnionB"), "UnionB should reference ModelSharedForUnionB"); + ok(unionBSubTypes.has("ModelUniqueB"), "UnionB should reference ModelUniqueB"); + ok(!unionBSubTypes.has("ModelShared"), "UnionB should NOT reference original ModelShared"); + }); +}); diff --git a/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts b/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts index 081708c903f..d9446eba72f 100644 --- a/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts +++ b/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts @@ -180,3 +180,212 @@ describe("External types", () => { strictEqual((jsonElementProp.type as any).external.minVersion, "8.0.0"); }); }); + +describe("Union types to model hierarchies", () => { + let runner: TestHost; + + beforeEach(async () => { + runner = await createEmitterTestHost(); + }); + const supportedCases = [ + { + name: "request bodies", + opDefinition: `op test(@body input: MyUnion): void;`, + }, + { + name: "response bodies", + opDefinition: `op test(): MyUnion;`, + }, + { + name: "properties", + opDefinition: ` + model ContainerModel { + unionProp: MyUnion; + } + op test(): ContainerModel; + `, + }, + ]; + supportedCases.forEach(({ name, opDefinition }) => + it(`should convert ${name} union with members to model hierarchy`, async () => { + const program = await typeSpecCompile( + ` + model Alpha { + alphaProp: string; + type: "alpha"; + } + model Beta { + betaProp: int32; + type: "beta"; + } + @discriminated(#{ discriminatorPropertyName: "type", envelope: "none" }) + union MyUnion { + "alpha": Alpha, + "beta": Beta + } + ${opDefinition} + `, + runner, + { IsTCGCNeeded: true }, + ); + const context = createEmitterContext(program); + const sdkContext = await createCSharpSdkContext(context); + const root = createModel(sdkContext); + + const alphaModel = root.models.find((m) => m.name === "Alpha"); + ok(alphaModel, "Alpha should exist"); + + const betaModel = root.models.find((m) => m.name === "Beta"); + ok(betaModel, "Beta should exist"); + + const myUnion = root.models.find((m) => m.name === "MyUnion"); + ok(myUnion, "MyUnion should exist"); + + const enumDefinition = root.enums.find((e) => e.name === "MyUnionType"); + ok(enumDefinition, "Discriminator enum MyUnionType should exist"); + + // Validate that MyUnion is a model + strictEqual(myUnion.kind, "model", "MyUnion should be converted to a model"); + + // Validate that Alpha and Beta inherit from MyUnion + strictEqual(alphaModel.baseModel, myUnion, "Alpha should inherit from MyUnion"); + strictEqual(betaModel.baseModel, myUnion, "Beta should inherit from MyUnion"); + + // Validate the base model has the discriminator property + const discriminatorProperty = myUnion.properties.find((p) => p.name === "type"); + ok(discriminatorProperty, "MyUnion should have a discriminator property 'type'"); + strictEqual( + discriminatorProperty.type.kind, + "enum", + "Discriminator property 'type' should be of type string", + ); + + strictEqual(discriminatorProperty.kind, "property"); + strictEqual(discriminatorProperty.name, "type"); + strictEqual(discriminatorProperty.serializedName, "type"); + strictEqual(discriminatorProperty.type.kind, "enum"); + strictEqual(discriminatorProperty.optional, false); + strictEqual(discriminatorProperty.readOnly, false); + strictEqual(discriminatorProperty.discriminator, true); + + strictEqual( + discriminatorProperty, + myUnion.discriminatorProperty, + "Discriminator property should be set on MyUnion", + ); + + // Validate that the discriminator property has the correct enum values + const enumValues = new Set(discriminatorProperty.type.values.map((v) => v.name)); + strictEqual(enumValues.has("alpha"), true, "Discriminator enum should include 'alpha'"); + strictEqual(enumValues.has("beta"), true, "Discriminator enum should include 'beta'"); + + // Validate that Alpha and Beta have a discriminatorValue + strictEqual( + alphaModel.discriminatorValue, + "alpha", + "Alpha should have discriminatorValue 'alpha'", + ); + strictEqual( + betaModel.discriminatorValue, + "beta", + "Beta should have discriminatorValue 'beta'", + ); + + // Validate that Alpha and Beta DO NOT have the discriminator property + const alphaDiscriminatorProp = alphaModel.properties.find((p) => p.name === "type"); + strictEqual( + alphaDiscriminatorProp, + undefined, + "Alpha should not have the discriminator property 'type'", + ); + + const betaDiscriminatorProp = betaModel.properties.find((p) => p.name === "type"); + strictEqual( + betaDiscriminatorProp, + undefined, + "Beta should not have the discriminator property 'type'", + ); + + // Validate the operation request body is of the new model type + if (opDefinition.includes("@body")) { + const testOp = root.clients[0].methods.find((op) => op.name === "test"); + ok(testOp, "Operation 'test' should exist"); + const bodyParam = testOp.parameters.find((p) => p.name === "input"); + ok(bodyParam, "Body parameter 'input' should exist"); + strictEqual( + bodyParam.type, + myUnion, + "Body parameter 'input' type should be the converted MyUnion model", + ); + } + + // Validate the operation response body is of the new model type + if (opDefinition.includes("): MyUnion;")) { + const testOp = root.clients[0].methods.find((op) => op.name === "test"); + ok(testOp, "Operation 'test' should exist"); + strictEqual( + testOp.response.type, + myUnion, + "Operation return type should be the converted MyUnion model", + ); + } + + // Validate the property type is of the new model type + if (opDefinition.includes("model ContainerModel")) { + const containerModel = root.models.find((m) => m.name === "ContainerModel"); + ok(containerModel, "ContainerModel should exist"); + const unionProp = containerModel.properties.find((p) => p.name === "unionProp"); + ok(unionProp, "Property 'unionProp' should exist"); + strictEqual( + unionProp.type, + myUnion, + "Property 'unionProp' type should be the converted MyUnion model", + ); + } + }), + ); + + const unsupportedCases = [ + { + name: "envelopped", + unionDefinition: `@discriminated(#{ discriminatorPropertyName: "type", envelopePropertyName: "data" }) + union MyUnion { + "alpha": Alpha, + "beta": Beta + }`, + }, + ]; + unsupportedCases.forEach(({ name, unionDefinition }) => + it(`should NOT convert unsupported ${name} union with members to model hierarchy`, async () => { + const program = await typeSpecCompile( + ` + model Alpha { + alphaProp: string; + type: "alpha"; + } + model Beta { + betaProp: int32; + type: "beta"; + } + ${unionDefinition} + + op test(@body input: MyUnion): void; + `, + runner, + { IsTCGCNeeded: true }, + ); + const context = createEmitterContext(program); + const sdkContext = await createCSharpSdkContext(context); + const root = createModel(sdkContext); + + const alphaModel = root.models.find((m) => m.name === "Alpha"); + ok(alphaModel, "Alpha should exist"); + + const betaModel = root.models.find((m) => m.name === "Beta"); + ok(betaModel, "Beta should exist"); + + const myUnion = root.models.find((m) => m.name === "MyUnion"); + ok(!myUnion, "MyUnion should NOT exist"); + }), + ); +}); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs index b7741d64749..276bf6cb970 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs @@ -885,7 +885,7 @@ p.Property is null if (type is { IsFrameworkType: false, IsEnum: true }) { - if (_inputModel.BaseModel.DiscriminatorProperty!.Type is InputEnumType inputEnumType) + if (_inputModel.BaseModel.DiscriminatorProperty?.Type is InputEnumType inputEnumType) { var discriminatorProvider = CodeModelGenerator.Instance.TypeFactory.CreateEnum(enumType: inputEnumType);