Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import {
SdkClientType,
SdkEnumType,
SdkHttpOperation,
SdkModelType,
SdkType,
SdkUnionType,
UsageFlags,
} from "@azure-tools/typespec-client-generator-core";
import { CSharpEmitterContext } from "../sdk-context.js";
Expand Down Expand Up @@ -152,11 +155,90 @@ function fixNamingConflicts(models: InputModelType[], constants: InputLiteralTyp
}
}

function indexModelsUsedInUnions(sdkContext: CSharpEmitterContext): Record<string, Set<string>> {
const unionUsageMap: Record<string, Set<string>> = {};
for (const u of sdkContext.sdkPackage.unions.filter((u) => u.kind === "union")) {
const modelsInUnion = new Set<string>(
u.variantTypes.filter((m) => m.kind === "model").map((m) => m.name),
);
for (const modelName of modelsInUnion) {
if (!unionUsageMap[modelName]) {
unionUsageMap[modelName] = new Set<string>();
}
unionUsageMap[modelName].add(u.name);
}
}
return unionUsageMap;
}

function duplicateModelsUsedInUnions(sdkContext: CSharpEmitterContext) {
const modelUnionUsageMap = indexModelsUsedInUnions(sdkContext);
const entries = Object.entries(modelUnionUsageMap).filter(
([, unionNames]) => unionNames.size > 1,
);
if (entries.length === 0) {
return;
}
const modelsLookup = Object.fromEntries(
Array.from(sdkContext.__typeCache.types.keys())
.filter((k) => k.kind === "model")
.map((k) => {
return [k.name, k];
}),
);
const unionsLookup = Object.fromEntries(
sdkContext.sdkPackage.unions
.filter((u) => u.kind === "union")
.map((u) => {
return [u.name, u];
}),
);
for (const [modelName, unionNames] of entries) {
duplicateModelInUnion(modelName, unionNames, sdkContext, modelsLookup, unionsLookup);
}
}

function duplicateModelInUnion(
modelName: string,
unionNames: Set<string>,
sdkContext: CSharpEmitterContext,
modelsLookup: {
[k: string]: SdkModelType;
},
unionsLookup: {
[k: string]: SdkUnionType<SdkType>;
},
) {
const modelType = modelsLookup[modelName];
for (const unionName of unionNames) {
const unionType = unionsLookup[unionName];
// Create a duplicate of the model
const duplicatedModel = {
...modelType,
name: `${modelType.name}For${unionType.name[0].toUpperCase()}${unionType.name.slice(1)}`,
};
// Update the union to use the duplicated model
unionType.variantTypes = unionType.variantTypes.map((vt) => {
if (vt.kind === "model" && vt.name === modelType.name) {
return duplicatedModel;
}
return vt;
});
}
// remove the original model from the sdk context type cache
sdkContext.__typeCache.types.delete(modelType);
delete modelsLookup[modelName];
}

function navigateModels(sdkContext: CSharpEmitterContext) {
for (const m of sdkContext.sdkPackage.models) {
fromSdkType(sdkContext, m);
}
for (const e of sdkContext.sdkPackage.enums) {
fromSdkType(sdkContext, e);
}
duplicateModelsUsedInUnions(sdkContext);
for (const u of sdkContext.sdkPackage.unions) {
fromSdkType(sdkContext, u);
}
}
186 changes: 185 additions & 1 deletion packages/http-client-csharp/emitter/src/lib/type-converter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {
DecoratorInfo,
SdkArrayType,
SdkBuiltInKinds,
SdkBuiltInType,
SdkConstantType,
SdkDateTimeType,
Expand Down Expand Up @@ -385,12 +386,180 @@ function fromSdkBuiltInType(
};
}

function fromUnionType(sdkContext: CSharpEmitterContext, union: SdkUnionType): InputUnionType {
function discriminatorPropertyFromUnion(
sdkContext: CSharpEmitterContext,
union: SdkUnionType,
variantTypes: InputType[],
): InputModelProperty | undefined {
if (!union.discriminatedOptions) {
return undefined;
}

const discriminatorPropertyName = union.discriminatedOptions.discriminatorPropertyName;
const discriminatorProperties = variantTypes
.map((variant) => {
if (variant.kind === "model") {
const discProp = variant.properties.find((p) => p.name === discriminatorPropertyName);
if (discProp) {
return discProp;
}
}
return undefined;
})
.filter((p) => p !== undefined);

if (discriminatorProperties.length === 0) {
return undefined;
}

// Declare an enum for all the constant values
const discriminatorEnumType: InputEnumType = {
kind: "enum",
name: `${union.name}${discriminatorPropertyName[0].toUpperCase()}${discriminatorPropertyName.slice(1)}`,
valueType: fromSdkBuiltInType(sdkContext, {
kind: "string",
name: "string",
crossLanguageDefinitionId: "TypeSpec.string",
decorators: [],
}),
values: [],
namespace: union.namespace,
crossLanguageDefinitionId: "",
access: undefined,
usage: UsageFlags.None,
decorators: [],
isFixed: false,
isFlags: false,
};

const enumValues: InputEnumValueType[] = discriminatorProperties.map((prop) => {
if (prop.type.kind === "constant") {
return {
kind: "enumvalue",
name: prop.type.value === null ? "Null" : prop.type.value.toString(),
value: typeof prop.type.value === "boolean" ? (prop.type.value ? 1 : 0) : prop.type.value,
enumType: discriminatorEnumType,
valueType: prop.type.valueType,
} as InputEnumValueType;
}
throw new Error(
`Discriminator property ${discriminatorPropertyName} in union ${union.name} is not a constant type.`,
);
// TODO handle numeric constants
// TODO handle default variants
// TODO handle string values
// TODO handle open ended enums
});

discriminatorEnumType.values.push(...enumValues);

sdkContext.__typeCache.updateSdkTypeReferences(
{
kind: "enum",
name: discriminatorEnumType.name,
valueType: fromSdkBuiltInType(sdkContext, {
kind: "string",
name: "string",
crossLanguageDefinitionId: "TypeSpec.string",
decorators: [],
}) as SdkBuiltInType<SdkBuiltInKinds>,
values: [],
namespace: union.namespace,
crossLanguageDefinitionId: "",
access: "public",
usage: UsageFlags.None,
decorators: [],
isFixed: false,
isFlags: false,
isGeneratedName: true,
isUnionAsEnum: false,
apiVersions: [],
},
discriminatorEnumType,
);

return {
kind: "property",
name: discriminatorPropertyName,
serializedName: discriminatorPropertyName,
type: discriminatorEnumType,
optional: false,
readOnly: false,
decorators: [],
flatten: false,
discriminator: true,
isHttpMetadata: false,
isApiVersion: false,
crossLanguageDefinitionId: "",
serializationOptions: {
json: { name: discriminatorPropertyName },
},
};
}

function removeDiscriminatorPropertiesFromVariants(
variantTypes: InputType[],
discriminatorPropertyName: string,
) {
for (const variant of variantTypes) {
if (variant.kind === "model") {
const discriminatorProperty = variant.properties.find(
(p) => p.name === discriminatorPropertyName,
);
variant.properties = variant.properties.filter((p) => p.name !== discriminatorPropertyName);
variant.discriminatorValue =
discriminatorProperty?.type.kind === "constant"
? discriminatorProperty.type.value?.toString()
: undefined;
}
}
}

function fromUnionType(
sdkContext: CSharpEmitterContext,
union: SdkUnionType,
): InputUnionType | InputModelType {
const variantTypes: InputType[] = [];
for (const value of union.variantTypes) {
const variantType = fromSdkType(sdkContext, value);
variantTypes.push(variantType);
}
if (isDiscriminatedUnion(union)) {
const discriminatorProperty = discriminatorPropertyFromUnion(sdkContext, union, variantTypes);
const properties = discriminatorProperty ? [discriminatorProperty] : [];
if (discriminatorProperty) {
removeDiscriminatorPropertiesFromVariants(variantTypes, discriminatorProperty.name);
}
const baseType: InputModelType = {
kind: "model",
name: union.name,
namespace: union.namespace,
discriminatorProperty: discriminatorProperty,
crossLanguageDefinitionId: union.crossLanguageDefinitionId,
access: union.access,
usage: union.usage,
properties: properties,
serializationOptions: {},
summary: union.summary,
doc: union.doc,
deprecation: union.deprecation,
decorators: union.decorators,
external: fromSdkExternalTypeInfo(union),
} as InputModelType;
const discriminatedSubtypes: Record<string, InputModelType> = {};
variantTypes.forEach((variant) => {
if (variant.kind === "model") {
variant.baseModel = baseType;
if (variant.discriminatorValue !== undefined) {
discriminatedSubtypes[variant.discriminatorValue] = variant;
}
}
});
if (Object.keys(discriminatedSubtypes).length > 0) {
baseType.discriminatedSubtypes = discriminatedSubtypes;
}
return baseType;
}

return {
kind: "union",
Expand All @@ -402,6 +571,21 @@ function fromUnionType(sdkContext: CSharpEmitterContext, union: SdkUnionType): I
};
}

function isDiscriminatedUnion(sdkType: SdkUnionType): boolean {
if (
!sdkType.discriminatedOptions ||
sdkType.discriminatedOptions.envelope === "object" ||
(sdkType.discriminatedOptions.envelopePropertyName !== undefined &&
sdkType.discriminatedOptions.envelopePropertyName !== "")
) {
return false;
}

return sdkType.variantTypes.every((variant) => {
return variant.kind === "model" && !variant.baseModel;
});
}

function fromSdkConstantType(
sdkContext: CSharpEmitterContext,
constantType: SdkConstantType,
Expand Down
Loading
Loading