// This file is https://github.com/aws/aws-sdk-go-v2/blob/a2b751d1ba71f59175a41f9cae5f159f1044360f/internal/v4a/credentials_test.go package v4a import ( "context" "fmt" "testing" "github.com/aws/aws-sdk-go-v2/aws" ) type rotatingCredsProvider struct { count int fail chan struct{} } func (r *rotatingCredsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { select { case <-r.fail: return aws.Credentials{}, fmt.Errorf("rotatingCredsProvider error") default: } credentials := aws.Credentials{ AccessKeyID: fmt.Sprintf("ACCESS_KEY_ID_%d", r.count), SecretAccessKey: fmt.Sprintf("SECRET_ACCESS_KEY_%d", r.count), SessionToken: fmt.Sprintf("SESSION_TOKEN_%d", r.count), } return credentials, nil } func TestSymmetricCredentialAdaptor(t *testing.T) { provider := &rotatingCredsProvider{ count: 0, fail: make(chan struct{}), } adaptor := &SymmetricCredentialAdaptor{SymmetricProvider: provider} if symCreds, err := adaptor.Retrieve(context.Background()); err != nil { t.Fatalf("expect no error, got %v", err) } else if !symCreds.HasKeys() { t.Fatalf("expect symmetric credentials to have keys") } if load := adaptor.asymmetric.Load(); load != nil { t.Errorf("expect asymmetric credentials to be nil") } if asymCreds, err := adaptor.RetrievePrivateKey(context.Background()); err != nil { t.Fatalf("expect no error, got %v", err) } else if !asymCreds.HasKeys() { t.Fatalf("expect asymmetric credentials to have keys") } if _, err := adaptor.Retrieve(context.Background()); err != nil { t.Fatalf("expect no error, got %v", err) } if load := adaptor.asymmetric.Load(); load.(*Credentials) == nil { t.Errorf("expect asymmetric credentials to be not nil") } provider.count++ if _, err := adaptor.Retrieve(context.Background()); err != nil { t.Fatalf("expect no error, got %v", err) } if load := adaptor.asymmetric.Load(); load.(*Credentials) != nil { t.Errorf("expect asymmetric credentials to be nil") } close(provider.fail) // All requests to the original provider will now fail from this point-on. _, err := adaptor.Retrieve(context.Background()) if err == nil { t.Error("expect error, got nil") } }