diff --git a/neofs/layer/auth.go b/neofs/layer/auth.go index 69702900..d6e32fed 100644 --- a/neofs/layer/auth.go +++ b/neofs/layer/auth.go @@ -1,6 +1,7 @@ package layer import ( + "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/sha256" @@ -11,39 +12,101 @@ import ( ) const ( - GatewayKeySize = 2048 + gatewayEncryptionKeySize = 2048 ) -type keyPair struct { - PrivateKey *rsa.PrivateKey - PublicKey *rsa.PublicKey +type ( + signatureKeyName byte + encryptionKeyName byte +) + +const ( + _ signatureKeyName = iota + // Indicates that the key is a NeoFS ECDSA key. + gateNeoFSECDSAKey + gateNeoFSEd25519Key +) + +const ( + _ encryptionKeyName = iota + // Indicates that the key is used to encrypt + // a bearer token to pass auth procedure. + gateUserAuthKey +) + +type ( + signatureKeyPair struct { + PrivateKey *ecdsa.PrivateKey + PublicKey *ecdsa.PublicKey + } + + encryptionKeyPair struct { + PrivateKey *rsa.PrivateKey + PublicKey *rsa.PublicKey + } +) + +type secureEnclave struct { + signatureKeys map[signatureKeyName]signatureKeyPair + encryptionKeys map[encryptionKeyName]encryptionKeyPair } -type AuthCenter struct { - gatewayKeys keyPair +func newSecureEnclave() (*secureEnclave, error) { + // TODO: Get private keys. + // TODO: Fetch NeoFS and Auth private keys from app settings. + return &secureEnclave{ + signatureKeys: map[signatureKeyName]signatureKeyPair{}, + encryptionKeys: map[encryptionKeyName]encryptionKeyPair{}, + }, nil } -func NewAuthCenter() (*AuthCenter, error) { - var ( - err error - privateKey *rsa.PrivateKey - ) - privateKey, err = pullGatewayPrivateKey() +func (se *secureEnclave) Encrypt(keyName encryptionKeyName, data []byte) ([]byte, error) { + return rsa.EncryptOAEP(sha256.New(), rand.Reader, se.encryptionKeys[keyName].PublicKey, data, []byte{}) +} + +func (se *secureEnclave) Decrypt(keyName encryptionKeyName, data []byte) ([]byte, error) { + return rsa.DecryptOAEP(sha256.New(), rand.Reader, se.encryptionKeys[keyName].PrivateKey, data, []byte{}) +} + +var globalEnclave *secureEnclave + +func init() { + var err error + globalEnclave, err = newSecureEnclave() if err != nil { - return nil, errors.Wrap(err, "failed to pull gateway private key from trusted enclave") + panic("failed to initialize secure enclave") } - if privateKey == nil { - if privateKey, err = rsa.GenerateKey(rand.Reader, GatewayKeySize); err != nil { - return nil, errors.Wrap(err, "failed to generate gateway private key") - } - if err = pushGatewayPrivateKey(privateKey); err != nil { - return nil, errors.Wrap(err, "failed to push gateway private key to trusted enclave") - } - } - ac := &AuthCenter{gatewayKeys: keyPair{ - PrivateKey: privateKey, - PublicKey: &privateKey.PublicKey, - }} +} + +// AuthCenter is a central app's authentication/authorization management unit. +type AuthCenter struct { + zstdEncoder *zstd.Encoder + zstdDecoder *zstd.Decoder +} + +// NewAuthCenter creates an instance of AuthCenter. +func NewAuthCenter() (*AuthCenter, error) { + // var ( + // err error + // privateKey *rsa.PrivateKey + // ) + // secureEnclave := &SecureEnclave{} + // privateKey, err = secureEnclave.PullGatewayEncryptionPrivateKey() + // if err != nil { + // return nil, errors.Wrap(err, "failed to pull gateway private key from trusted enclave") + // } + // if privateKey == nil { + // // TODO: Move this logic to the enclave. + // if privateKey, err = rsa.GenerateKey(rand.Reader, gatewayEncryptionKeySize); err != nil { + // return nil, errors.Wrap(err, "failed to generate gateway private key") + // } + // // if err = keysEnclave.PushGatewayEncryptionPrivateKey(privateKey); err != nil { + // // return nil, errors.Wrap(err, "failed to push gateway private key to trusted enclave") + // // } + // } + zstdEncoder, _ := zstd.NewWriter(nil) + zstdDecoder, _ := zstd.NewReader(nil) + ac := &AuthCenter{zstdEncoder: zstdEncoder, zstdDecoder: zstdDecoder} return ac, nil } @@ -52,7 +115,7 @@ func (ac *AuthCenter) PackBearerToken(bearerToken *service.BearerTokenMsg) ([]by if err != nil { return nil, errors.Wrap(err, "failed to marshal bearer token") } - encryptedKeyID, err := ac.encrypt(compress(data)) + encryptedKeyID, err := globalEnclave.Encrypt(gateUserAuthKey, ac.compress(data)) if err != nil { return nil, errors.Wrap(err, "") } @@ -61,11 +124,11 @@ func (ac *AuthCenter) PackBearerToken(bearerToken *service.BearerTokenMsg) ([]by func (ac *AuthCenter) UnpackBearerToken(packedBearerToken []byte) (*service.BearerTokenMsg, error) { compressedKeyID := packedBearerToken[32:] - encryptedKeyID, err := decompress(compressedKeyID) + encryptedKeyID, err := ac.decompress(compressedKeyID) if err != nil { return nil, errors.Wrap(err, "failed to decompress key ID") } - keyID, err := ac.decrypt(encryptedKeyID) + keyID, err := globalEnclave.Decrypt(gateUserAuthKey, encryptedKeyID) if err != nil { return nil, errors.Wrap(err, "failed to decrypt key ID") } @@ -76,35 +139,17 @@ func (ac *AuthCenter) UnpackBearerToken(packedBearerToken []byte) (*service.Bear return bearerToken, nil } -func pullGatewayPrivateKey() (*rsa.PrivateKey, error) { - // TODO: Pull the private key from a persistent and trusted enclave. - return nil, nil -} - -func pushGatewayPrivateKey(key *rsa.PrivateKey) error { - // TODO: Push the private key to a persistent and trusted enclave. - return nil -} - -func (ac *AuthCenter) encrypt(data []byte) ([]byte, error) { - return rsa.EncryptOAEP(sha256.New(), rand.Reader, ac.gatewayKeys.PublicKey, data, []byte{}) -} - -func (ac *AuthCenter) decrypt(data []byte) ([]byte, error) { - return rsa.DecryptOAEP(sha256.New(), rand.Reader, ac.gatewayKeys.PrivateKey, data, []byte{}) -} - -func compress(data []byte) []byte { +func (ac *AuthCenter) compress(data []byte) []byte { + ac.zstdEncoder.Reset(nil) var compressedData []byte - zstdEncoder, _ := zstd.NewWriter(nil) - zstdEncoder.EncodeAll(data, compressedData) + ac.zstdEncoder.EncodeAll(data, compressedData) return compressedData } -func decompress(data []byte) ([]byte, error) { +func (ac *AuthCenter) decompress(data []byte) ([]byte, error) { + ac.zstdDecoder.Reset(nil) var decompressedData []byte - zstdDecoder, _ := zstd.NewReader(nil) - if _, err := zstdDecoder.DecodeAll(data, decompressedData); err != nil { + if _, err := ac.zstdDecoder.DecodeAll(data, decompressedData); err != nil { return nil, err } return decompressedData, nil