From 46107ba1d99e544db77378d90edac4141178d8a0 Mon Sep 17 00:00:00 2001 From: Akhil Rane Date: Wed, 25 May 2022 02:48:56 -0400 Subject: [PATCH] Allow google oauth client to consume creds for workload identity New short lived credentials for workload identity have different type called 'external_account'. Current code does not support this type. Signed-off-by: Akhil Rane --- registry/storage/driver/gcs/gcs.go | 47 ++++++++++-- registry/storage/driver/gcs/gcs_test.go | 98 ++++++++++++++++++++----- 2 files changed, 120 insertions(+), 25 deletions(-) diff --git a/registry/storage/driver/gcs/gcs.go b/registry/storage/driver/gcs/gcs.go index 86dc87f14..6499b109d 100644 --- a/registry/storage/driver/gcs/gcs.go +++ b/registry/storage/driver/gcs/gcs.go @@ -148,17 +148,17 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri } var ts oauth2.TokenSource - jwtConf := new(jwt.Config) + var creds *google.Credentials if keyfile, ok := parameters["keyfile"]; ok { jsonKey, err := ioutil.ReadFile(fmt.Sprint(keyfile)) if err != nil { return nil, err } - jwtConf, err = google.JWTConfigFromJSON(jsonKey, storage.ScopeFullControl) + creds, err = google.CredentialsFromJSON(context.Background(), jsonKey, storage.ScopeFullControl) if err != nil { return nil, err } - ts = jwtConf.TokenSource(context.Background()) + ts = creds.TokenSource } else if credentials, ok := parameters["credentials"]; ok { credentialMap, ok := credentials.(map[interface{}]interface{}) if !ok { @@ -179,11 +179,11 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri return nil, fmt.Errorf("Failed to marshal gcs credentials to json") } - jwtConf, err = google.JWTConfigFromJSON(data, storage.ScopeFullControl) + creds, err = google.CredentialsFromJSON(context.Background(), data, storage.ScopeFullControl) if err != nil { return nil, err } - ts = jwtConf.TokenSource(context.Background()) + ts = creds.TokenSource } else { var err error ts, err = google.DefaultTokenSource(context.Background(), storage.ScopeFullControl) @@ -200,8 +200,8 @@ func FromParameters(parameters map[string]interface{}) (storagedriver.StorageDri params := driverParameters{ bucket: fmt.Sprint(bucket), rootDirectory: fmt.Sprint(rootDirectory), - email: jwtConf.Email, - privateKey: jwtConf.PrivateKey, + email: getEmailFromCredentialsJSON(creds.JSON), + privateKey: getPrivateKeyFromCredentialsJSON(creds.JSON), client: oauth2.NewClient(context.Background(), ts), chunkSize: chunkSize, maxConcurrency: maxConcurrency, @@ -928,3 +928,36 @@ func (d *driver) pathToDirKey(path string) string { func (d *driver) keyToPath(key string) string { return "/" + strings.Trim(strings.TrimPrefix(key, d.rootDirectory), "/") } + +func getEmailFromCredentialsJSON(JSON []byte) string { + var credsFile struct { + Email string `json:"client_email"` + } + err := json.Unmarshal(JSON, &credsFile) + if err == nil && credsFile.Email != "" { + return credsFile.Email + } + return "" +} + +func getPrivateKeyFromCredentialsJSON(JSON []byte) []byte { + var credsFile struct { + PrivateKey string `json:"private_key"` + } + err := json.Unmarshal(JSON, &credsFile) + if err == nil && credsFile.PrivateKey != "" { + return []byte(credsFile.PrivateKey) + } + return nil +} + +func getTypeFromCredentialsJSON(JSON []byte) string { + var credsFile struct { + Type string `json:"type"` + } + err := json.Unmarshal(JSON, &credsFile) + if err == nil && credsFile.Type != "" { + return credsFile.Type + } + return "" +} diff --git a/registry/storage/driver/gcs/gcs_test.go b/registry/storage/driver/gcs/gcs_test.go index e58216be0..1a7e26f86 100644 --- a/registry/storage/driver/gcs/gcs_test.go +++ b/registry/storage/driver/gcs/gcs_test.go @@ -3,6 +3,7 @@ package gcs import ( + "context" "fmt" "io/ioutil" "os" @@ -23,14 +24,15 @@ func Test(t *testing.T) { check.TestingT(t) } var gcsDriverConstructor func(rootDirectory string) (storagedriver.StorageDriver, error) var skipGCS func() string +var credentialsType string func init() { bucket := os.Getenv("REGISTRY_STORAGE_GCS_BUCKET") - credentials := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") + credentials_path := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") // Skip GCS storage driver tests if environment variable parameters are not provided skipGCS = func() string { - if bucket == "" || credentials == "" { + if bucket == "" || credentials_path == "" { return "The following environment variables must be set to enable these tests: REGISTRY_STORAGE_GCS_BUCKET, GOOGLE_APPLICATION_CREDENTIALS" } return "" @@ -40,6 +42,11 @@ func init() { return } + credentialsJSON, err := ioutil.ReadFile(fmt.Sprint(credentials_path)) + if err != nil { + panic(err) + } + root, err := ioutil.TempDir("", "driver-") if err != nil { panic(err) @@ -49,24 +56,15 @@ func init() { var email string var privateKey []byte - ts, err = google.DefaultTokenSource(dcontext.Background(), storage.ScopeFullControl) + creds, err := google.CredentialsFromJSON(context.Background(), []byte(credentialsJSON), storage.ScopeFullControl) if err != nil { - // Assume that the file contents are within the environment variable since it exists - // but does not contain a valid file path - jwtConfig, err := google.JWTConfigFromJSON([]byte(credentials), storage.ScopeFullControl) - if err != nil { - panic(fmt.Sprintf("Error reading JWT config : %s", err)) - } - email = jwtConfig.Email - privateKey = []byte(jwtConfig.PrivateKey) - if len(privateKey) == 0 { - panic("Error reading JWT config : missing private_key property") - } - if email == "" { - panic("Error reading JWT config : missing client_email property") - } - ts = jwtConfig.TokenSource(dcontext.Background()) + panic(fmt.Sprintf("Error reading credentials json file : %s", err)) } + credentialsType = getTypeFromCredentialsJSON(creds.JSON) + email = getEmailFromCredentialsJSON(creds.JSON) + privateKey = getPrivateKeyFromCredentialsJSON(creds.JSON) + + ts = creds.TokenSource gcsDriverConstructor = func(rootDirectory string) (storagedriver.StorageDriver, error) { parameters := driverParameters{ @@ -76,6 +74,7 @@ func init() { privateKey: privateKey, client: oauth2.NewClient(dcontext.Background(), ts), chunkSize: defaultChunkSize, + maxConcurrency: defaultMaxConcurrency, } return New(parameters) @@ -309,3 +308,66 @@ func TestMoveDirectory(t *testing.T) { t.Fatalf("Moving directory /parent/dir /parent/other should have return a non-nil error\n") } } + +// Test getting signed URL for a stored object +func TestURLFor(t *testing.T) { + if skipGCS() != "" { + t.Skip(skipGCS()) + } + + validRoot, err := ioutil.TempDir("", "driver-") + if err != nil { + t.Fatalf("unexpected error creating temporary directory: %v", err) + } + defer os.Remove(validRoot) + + driver, err := gcsDriverConstructor(validRoot) + if err != nil { + t.Fatalf("unexpected error creating rooted driver: %v", err) + } + + filename := "/test" + ctx := dcontext.Background() + + contents := make([]byte, defaultChunkSize) + writer, err := driver.Writer(ctx, filename, false) + defer driver.Delete(ctx, filename) + if err != nil { + t.Fatalf("driver.Writer: unexpected error: %v", err) + } + _, err = writer.Write(contents) + if err != nil { + t.Fatalf("writer.Write: unexpected error: %v", err) + } + err = writer.Commit() + if err != nil { + t.Fatalf("writer.Commit: unexpected error: %v", err) + } + err = writer.Close() + if err != nil { + t.Fatalf("writer.Close: unexpected error: %v", err) + } + if writer.Size() != int64(len(contents)) { + t.Fatalf("writer.Size: %d != %d", writer.Size(), len(contents)) + } + + options := make(map[string]interface{}) + + // fetch and verify if signed URL for the stored object is not empty + signedURL, err := driver.URLFor(ctx, filename, options) + + if credentialsType == "service_account" { + if err != nil { + t.Fatalf("driver.URLFor: unexpected error: %v", err) + } + if len(signedURL) == 0 { + t.Fatalf("signed URL is empty") + } + } else if credentialsType == "external_account" { + if err == nil { + t.Fatalf("driver.URLFor: expected error: %v", storagedriver.ErrUnsupportedMethod{}) + } + } else { + t.Fatalf("driver.URLFor: unexpected credentials type") + } +}