From 05aedee59eb9efdb36aa440f697609ef8f04f454 Mon Sep 17 00:00:00 2001 From: Pavel Korotkov Date: Wed, 8 Jul 2020 14:16:48 +0300 Subject: [PATCH] Add encrypt/decrypt logic --- neofs/layer/auth.go | 104 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 82 insertions(+), 22 deletions(-) diff --git a/neofs/layer/auth.go b/neofs/layer/auth.go index f4776cf8..69702900 100644 --- a/neofs/layer/auth.go +++ b/neofs/layer/auth.go @@ -3,55 +3,115 @@ package layer import ( "crypto/rand" "crypto/rsa" + "crypto/sha256" "github.com/klauspost/compress/zstd" "github.com/nspcc-dev/neofs-api-go/service" "github.com/pkg/errors" ) -type KeyPair struct { +const ( + GatewayKeySize = 2048 +) + +type keyPair struct { PrivateKey *rsa.PrivateKey PublicKey *rsa.PublicKey } type AuthCenter struct { - gatewayKeys KeyPair + gatewayKeys keyPair } func NewAuthCenter() (*AuthCenter, error) { - var kp KeyPair - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + var ( + err error + privateKey *rsa.PrivateKey + ) + privateKey, err = pullGatewayPrivateKey() if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to pull gateway private key from trusted enclave") } - kp.PrivateKey = privateKey - kp.PublicKey = &privateKey.PublicKey - ac := &AuthCenter{ - gatewayKeys: kp, + if privateKey == nil { + if privateKey, err = rsa.GenerateKey(rand.Reader, GatewayKeySize); err != nil { + return nil, errors.Wrap(err, "failed to generate gateway private key") + } + if err = pushGatewayPrivateKey(privateKey); err != nil { + return nil, errors.Wrap(err, "failed to push gateway private key to trusted enclave") + } } + ac := &AuthCenter{gatewayKeys: keyPair{ + PrivateKey: privateKey, + PublicKey: &privateKey.PublicKey, + }} return ac, nil } -func (ac *AuthCenter) PackBearerToken(bt service.BearerToken) ([]byte, error) { - // TODO - panic("unimplemented method") +func (ac *AuthCenter) PackBearerToken(bearerToken *service.BearerTokenMsg) ([]byte, error) { + data, err := bearerToken.Marshal() + if err != nil { + return nil, errors.Wrap(err, "failed to marshal bearer token") + } + encryptedKeyID, err := ac.encrypt(compress(data)) + if err != nil { + return nil, errors.Wrap(err, "") + } + return append(sha256Hash(data), encryptedKeyID...), nil } -func (ac *AuthCenter) UnpackBearerToken(packedCredentials []byte) (service.BearerToken, error) { - zstdDecoder, _ := zstd.NewReader(nil) - // secretHash := packedCredentials[:32] - _ = packedCredentials[:32] - compressedKeyID := packedCredentials[32:] - // Get an encrypted key. - var encryptedKeyID []byte - if _, err := zstdDecoder.DecodeAll(compressedKeyID, encryptedKeyID); err != nil { +func (ac *AuthCenter) UnpackBearerToken(packedBearerToken []byte) (*service.BearerTokenMsg, error) { + compressedKeyID := packedBearerToken[32:] + encryptedKeyID, err := decompress(compressedKeyID) + if err != nil { return nil, errors.Wrap(err, "failed to decompress key ID") } - // TODO: Decrypt the key ID. - var keyID []byte + keyID, err := ac.decrypt(encryptedKeyID) + if err != nil { + return nil, errors.Wrap(err, "failed to decrypt key ID") + } bearerToken := new(service.BearerTokenMsg) if err := bearerToken.Unmarshal(keyID); err != nil { return nil, errors.Wrap(err, "failed to unmarshal embedded bearer token") } return bearerToken, nil } + +func pullGatewayPrivateKey() (*rsa.PrivateKey, error) { + // TODO: Pull the private key from a persistent and trusted enclave. + return nil, nil +} + +func pushGatewayPrivateKey(key *rsa.PrivateKey) error { + // TODO: Push the private key to a persistent and trusted enclave. + return nil +} + +func (ac *AuthCenter) encrypt(data []byte) ([]byte, error) { + return rsa.EncryptOAEP(sha256.New(), rand.Reader, ac.gatewayKeys.PublicKey, data, []byte{}) +} + +func (ac *AuthCenter) decrypt(data []byte) ([]byte, error) { + return rsa.DecryptOAEP(sha256.New(), rand.Reader, ac.gatewayKeys.PrivateKey, data, []byte{}) +} + +func compress(data []byte) []byte { + var compressedData []byte + zstdEncoder, _ := zstd.NewWriter(nil) + zstdEncoder.EncodeAll(data, compressedData) + return compressedData +} + +func decompress(data []byte) ([]byte, error) { + var decompressedData []byte + zstdDecoder, _ := zstd.NewReader(nil) + if _, err := zstdDecoder.DecodeAll(data, decompressedData); err != nil { + return nil, err + } + return decompressedData, nil +} + +func sha256Hash(data []byte) []byte { + hash := sha256.New() + hash.Write(data) + return hash.Sum(nil) +}