diff --git a/pkg/webhook/webhook.go b/pkg/webhook/webhook.go index b3df2611b..42cbf715b 100644 --- a/pkg/webhook/webhook.go +++ b/pkg/webhook/webhook.go @@ -192,8 +192,8 @@ func (m *podMutator) mutateContainers(containers []corev1.Container, clientID st if _, ok := skipContainers[containers[i].Name]; ok { continue } - // add environment variables to container if not exists - containers[i] = addEnvironmentVariables(containers[i], clientID, tenantID, m.azureAuthorityHost) + // set environment variables to container + containers[i] = setEnvironmentVariables(containers[i], clientID, tenantID, m.azureAuthorityHost) // add the volume mount if not exists containers[i] = addProjectedTokenVolumeMount(containers[i]) } @@ -357,28 +357,36 @@ func getTenantID(sa *corev1.ServiceAccount, c *config.Config) string { return c.TenantID } -// addEnvironmentVariables adds the clientID, tenantID and token file path environment variables needed for SDK -func addEnvironmentVariables(container corev1.Container, clientID, tenantID, azureAuthorityHost string) corev1.Container { - m := make(map[string]string) - for _, env := range container.Env { - m[env.Name] = env.Value +// createAppendOrUpdateForEnvVars returns a function +// that efficiently appends or updates a container's environment variables +func createAppendOrUpdateForEnvVars( + container *corev1.Container, +) func(key string, value string) { + m := make(map[string]int) + for envIndex, env := range container.Env { + m[env.Name] = envIndex + } + + return func(key string, value string) { + if ind, ok := m[key]; !ok { + container.Env = append(container.Env, corev1.EnvVar{Name: key, Value: value}) + } else { + container.Env[ind].Value = value + } } +} + +// setEnvironmentVariables sets the clientID, tenantID and token file path environment variables needed for SDK +func setEnvironmentVariables(container corev1.Container, clientID, tenantID, azureAuthorityHost string) corev1.Container { + appendOrUpdate := createAppendOrUpdateForEnvVars(&container) // add the clientID env var - if _, ok := m[AzureClientIDEnvVar]; !ok { - container.Env = append(container.Env, corev1.EnvVar{Name: AzureClientIDEnvVar, Value: clientID}) - } + appendOrUpdate(AzureClientIDEnvVar, clientID) // add the tenantID env var - if _, ok := m[AzureTenantIDEnvVar]; !ok { - container.Env = append(container.Env, corev1.EnvVar{Name: AzureTenantIDEnvVar, Value: tenantID}) - } + appendOrUpdate(AzureTenantIDEnvVar, tenantID) // add the token file env var - if _, ok := m[AzureFederatedTokenFileEnvVar]; !ok { - container.Env = append(container.Env, corev1.EnvVar{Name: AzureFederatedTokenFileEnvVar, Value: filepath.Join(TokenFileMountPath, TokenFilePathName)}) - } + appendOrUpdate(AzureFederatedTokenFileEnvVar, filepath.Join(TokenFileMountPath, TokenFilePathName)) // add the azure authority host env var - if _, ok := m[AzureAuthorityHostEnvVar]; !ok { - container.Env = append(container.Env, corev1.EnvVar{Name: AzureAuthorityHostEnvVar, Value: azureAuthorityHost}) - } + appendOrUpdate(AzureAuthorityHostEnvVar, azureAuthorityHost) return container } diff --git a/pkg/webhook/webhook_test.go b/pkg/webhook/webhook_test.go index 248d402d3..ea68c4f92 100644 --- a/pkg/webhook/webhook_test.go +++ b/pkg/webhook/webhook_test.go @@ -527,7 +527,7 @@ func TestAddEnvironmentVariables(t *testing.T) { }, }, { - name: "existing environment variables not replaced", + name: "existing environment variables should be replaced to support admission reinvocation", container: corev1.Container{ Name: "cont1", Image: "image", @@ -542,11 +542,11 @@ func TestAddEnvironmentVariables(t *testing.T) { }, { Name: AzureFederatedTokenFileEnvVar, - Value: filepath.Join(TokenFileMountPath, TokenFilePathName), + Value: "/tmp/token", }, { Name: AzureAuthorityHostEnvVar, - Value: "https://login.microsoftonline.com/", + Value: "https://localhost:8080/", }, }, }, @@ -556,11 +556,11 @@ func TestAddEnvironmentVariables(t *testing.T) { Env: []corev1.EnvVar{ { Name: AzureClientIDEnvVar, - Value: "myClientID", + Value: "clientID", }, { Name: AzureTenantIDEnvVar, - Value: "myTenantID", + Value: "tenantID", }, { Name: AzureFederatedTokenFileEnvVar, @@ -616,7 +616,7 @@ func TestAddEnvironmentVariables(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - actualContainer := addEnvironmentVariables(test.container, "clientID", "tenantID", "https://login.microsoftonline.com/") + actualContainer := setEnvironmentVariables(test.container, "clientID", "tenantID", "https://login.microsoftonline.com/") if !reflect.DeepEqual(actualContainer, test.expectedContainer) { t.Fatalf("expected: %v, got: %v", test.expectedContainer, actualContainer) }