forked from TrueCloudLab/frostfs-s3-gw
7ab473a688
Signed-off-by: Denis Kirillov <denis@nspcc.ru>
342 lines
7.4 KiB
Go
342 lines
7.4 KiB
Go
package encryption
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
errorsStd "errors"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/minio/sio"
|
|
)
|
|
|
|
// Params holds the customer-provided key that is used to encrypt and
// decrypt object payloads.
type Params struct {
	// customerKey is the raw key bytes; NewParams only accepts keys that are
	// exactly aes256KeySize bytes long.
	customerKey []byte
}
|
|
|
|
// ObjectEncryption describes the encryption state recorded for an object,
// as parsed from its encryption headers.
type ObjectEncryption struct {
	// Enabled reports whether the object is encrypted.
	Enabled bool
	// Algorithm is the encryption algorithm name.
	// NOTE(review): it is not interpreted anywhere in this file.
	Algorithm string
	// HMACKey is the hex-encoded HMAC that MatchObjectEncryption compares
	// against the one computed from the provided key and HMACSalt.
	HMACKey string
	// HMACSalt is the hex-encoded salt the HMAC was computed over.
	HMACSalt string
}
|
|
|
|
// encryptedPart describes one separately encrypted piece of an object.
type encryptedPart struct {
	// size is the length of the part's decrypted payload.
	size uint64
	// encryptedSize is the length of the part as stored, i.e. after encryption.
	encryptedSize uint64
}
|
|
|
|
// Range is an interval of the decrypted payload.
// Both bounds are byte offsets and both are inclusive.
type Range struct {
	// Start is the offset of the first requested byte.
	Start uint64
	// End is the offset of the last requested byte.
	End uint64
}
|
|
|
|
// Decrypter allows decrypt payload of encrypted object.
|
|
type Decrypter struct {
|
|
reader io.Reader
|
|
decReader io.Reader
|
|
parts []encryptedPart
|
|
currentPart int
|
|
encryption Params
|
|
|
|
rangeParam *Range
|
|
|
|
partDataRemain uint64
|
|
encPartRangeLen uint64
|
|
|
|
seqNumber uint64
|
|
decLen uint64
|
|
skipLen uint64
|
|
|
|
ln uint64
|
|
off uint64
|
|
}
|
|
|
|
const (
	// blockSize is the amount of payload carried by one encrypted block: 64 KiB.
	blockSize = 64 * 1024
	// fullBlockSize is the size of one block in the encrypted stream.
	// NOTE(review): the extra 32 bytes are presumably the per-block overhead
	// added by sio — confirm against the sio format description.
	fullBlockSize = blockSize + 32
	// aes256KeySize is the only key length (in bytes) accepted by NewParams.
	aes256KeySize = 32
)
|
|
|
|
// NewParams creates new params to encrypt with provided key.
|
|
func NewParams(key []byte) (*Params, error) {
|
|
if len(key) != aes256KeySize {
|
|
return nil, fmt.Errorf("invalid key size: %d", len(key))
|
|
}
|
|
var p Params
|
|
p.customerKey = make([]byte, aes256KeySize)
|
|
copy(p.customerKey, key)
|
|
return &p, nil
|
|
}
|
|
|
|
// Key returns encryption key.
|
|
func (p Params) Key() []byte {
|
|
return p.customerKey
|
|
}
|
|
|
|
// Enabled returns true if key isn't empty.
|
|
func (p Params) Enabled() bool {
|
|
return len(p.customerKey) > 0
|
|
}
|
|
|
|
// HMAC computes salted HMAC.
|
|
func (p Params) HMAC() ([]byte, []byte, error) {
|
|
mac := hmac.New(sha256.New, p.Key())
|
|
|
|
salt := make([]byte, 16)
|
|
if _, err := rand.Read(salt); err != nil {
|
|
return nil, nil, errorsStd.New("failed to init create salt")
|
|
}
|
|
|
|
mac.Write(salt)
|
|
return mac.Sum(nil), salt, nil
|
|
}
|
|
|
|
// MatchObjectEncryption checks if encryption params are valid for provided object.
|
|
func (p Params) MatchObjectEncryption(encInfo ObjectEncryption) error {
|
|
if p.Enabled() != encInfo.Enabled {
|
|
return errorsStd.New("invalid encryption view")
|
|
}
|
|
|
|
if !encInfo.Enabled {
|
|
return nil
|
|
}
|
|
|
|
hmacSalt, err := hex.DecodeString(encInfo.HMACSalt)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid hmacSalt '%s': %w", encInfo.HMACSalt, err)
|
|
}
|
|
|
|
hmacKey, err := hex.DecodeString(encInfo.HMACKey)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid hmacKey '%s': %w", encInfo.HMACKey, err)
|
|
}
|
|
|
|
mac := hmac.New(sha256.New, p.Key())
|
|
mac.Write(hmacSalt)
|
|
expectedHmacKey := mac.Sum(nil)
|
|
if !bytes.Equal(expectedHmacKey, hmacKey) {
|
|
return errorsStd.New("mismatched hmac key")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NewMultipartDecrypter creates new decrypted that can decrypt multipart object
|
|
// that contains concatenation of encrypted parts.
|
|
func NewMultipartDecrypter(p Params, decryptedObjectSize uint64, partsSizes []uint64, r *Range) (*Decrypter, error) {
|
|
parts := make([]encryptedPart, len(partsSizes))
|
|
|
|
for i, size := range partsSizes {
|
|
encPartSize, err := sio.EncryptedSize(size)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("compute encrypted size: %w", err)
|
|
}
|
|
|
|
parts[i] = encryptedPart{
|
|
size: size,
|
|
encryptedSize: encPartSize,
|
|
}
|
|
}
|
|
|
|
rangeParam := r
|
|
if rangeParam == nil {
|
|
rangeParam = &Range{
|
|
End: decryptedObjectSize - 1,
|
|
}
|
|
}
|
|
|
|
return newDecrypter(p, parts, rangeParam)
|
|
}
|
|
|
|
// NewDecrypter creates decrypter for regular encrypted object.
|
|
func NewDecrypter(p Params, encryptedObjectSize uint64, r *Range) (*Decrypter, error) {
|
|
decSize, err := sio.DecryptedSize(encryptedObjectSize)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("compute decrypted size: %w", err)
|
|
}
|
|
|
|
parts := []encryptedPart{{
|
|
size: decSize,
|
|
encryptedSize: encryptedObjectSize,
|
|
}}
|
|
|
|
return newDecrypter(p, parts, r)
|
|
}
|
|
|
|
func newDecrypter(p Params, parts []encryptedPart, r *Range) (*Decrypter, error) {
|
|
if !p.Enabled() {
|
|
return nil, errorsStd.New("couldn't create decrypter with disabled encryption")
|
|
}
|
|
|
|
if r != nil && r.Start > r.End {
|
|
return nil, fmt.Errorf("invalid range: %d %d", r.Start, r.End)
|
|
}
|
|
|
|
decReader := &Decrypter{
|
|
parts: parts,
|
|
rangeParam: r,
|
|
encryption: p,
|
|
}
|
|
|
|
decReader.initRangeParams()
|
|
|
|
return decReader, nil
|
|
}
|
|
|
|
// DecryptedLength is actual (decrypted) length of data.
|
|
func (d Decrypter) DecryptedLength() uint64 {
|
|
return d.decLen
|
|
}
|
|
|
|
// EncryptedLength is size of encrypted data that should be read for successful decryption.
|
|
func (d Decrypter) EncryptedLength() uint64 {
|
|
return d.ln
|
|
}
|
|
|
|
// EncryptedOffset is offset of encrypted payload for successful decryption.
|
|
func (d Decrypter) EncryptedOffset() uint64 {
|
|
return d.off
|
|
}
|
|
|
|
func (d *Decrypter) initRangeParams() {
|
|
d.partDataRemain = d.parts[d.currentPart].size
|
|
d.encPartRangeLen = d.parts[d.currentPart].encryptedSize
|
|
if d.rangeParam == nil {
|
|
d.decLen = d.partDataRemain
|
|
d.ln = d.encPartRangeLen
|
|
return
|
|
}
|
|
|
|
start, end := d.rangeParam.Start, d.rangeParam.End
|
|
|
|
var sum, encSum uint64
|
|
var partStart int
|
|
for i, part := range d.parts {
|
|
if start < sum+part.size {
|
|
partStart = i
|
|
break
|
|
}
|
|
sum += part.size
|
|
encSum += part.encryptedSize
|
|
}
|
|
|
|
d.skipLen = (start - sum) % blockSize
|
|
d.seqNumber = (start - sum) / blockSize
|
|
encOffPart := d.seqNumber * fullBlockSize
|
|
d.off = encSum + encOffPart
|
|
d.encPartRangeLen = d.encPartRangeLen - encOffPart
|
|
d.partDataRemain = d.partDataRemain + sum - start
|
|
|
|
var partEnd int
|
|
for i, part := range d.parts[partStart:] {
|
|
index := partStart + i
|
|
if end < sum+part.size {
|
|
partEnd = index
|
|
break
|
|
}
|
|
sum += part.size
|
|
encSum += part.encryptedSize
|
|
}
|
|
|
|
payloadPartEnd := (end - sum) / blockSize
|
|
endEnc := encSum + (payloadPartEnd+1)*fullBlockSize
|
|
|
|
endPartEnc := encSum + d.parts[partEnd].encryptedSize
|
|
if endPartEnc < endEnc {
|
|
endEnc = endPartEnc
|
|
}
|
|
d.ln = endEnc - d.off
|
|
d.decLen = end - start + 1
|
|
|
|
if d.ln < d.encPartRangeLen {
|
|
d.encPartRangeLen = d.ln
|
|
}
|
|
if d.decLen < d.partDataRemain {
|
|
d.partDataRemain = d.decLen
|
|
}
|
|
}
|
|
|
|
func (d *Decrypter) updateRangeParams() {
|
|
d.partDataRemain = d.parts[d.currentPart].size
|
|
d.encPartRangeLen = d.parts[d.currentPart].encryptedSize
|
|
d.seqNumber = 0
|
|
d.skipLen = 0
|
|
}
|
|
|
|
// Read implements io.Reader.
|
|
func (d *Decrypter) Read(p []byte) (int, error) {
|
|
if uint64(len(p)) < d.partDataRemain {
|
|
n, err := d.decReader.Read(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
d.partDataRemain -= uint64(n)
|
|
return n, nil
|
|
}
|
|
|
|
n1, err := io.ReadFull(d.decReader, p[:d.partDataRemain])
|
|
if err != nil {
|
|
return n1, err
|
|
}
|
|
|
|
d.currentPart++
|
|
if d.currentPart == len(d.parts) {
|
|
return n1, io.EOF
|
|
}
|
|
|
|
d.updateRangeParams()
|
|
|
|
err = d.initNextDecReader()
|
|
if err != nil {
|
|
return n1, err
|
|
}
|
|
|
|
n2, err := d.decReader.Read(p[n1:])
|
|
if err != nil {
|
|
return n1 + n2, err
|
|
}
|
|
|
|
d.partDataRemain -= uint64(n2)
|
|
|
|
return n1 + n2, nil
|
|
}
|
|
|
|
// SetReader sets encrypted payload reader that should be decrypted.
|
|
// Must be invoked before any read.
|
|
func (d *Decrypter) SetReader(r io.Reader) error {
|
|
d.reader = r
|
|
return d.initNextDecReader()
|
|
}
|
|
|
|
func (d *Decrypter) initNextDecReader() error {
|
|
if d.reader == nil {
|
|
return errorsStd.New("reader isn't set")
|
|
}
|
|
|
|
r, err := sio.DecryptReader(io.LimitReader(d.reader, int64(d.encPartRangeLen)),
|
|
sio.Config{
|
|
MinVersion: sio.Version20,
|
|
SequenceNumber: uint32(d.seqNumber),
|
|
Key: d.encryption.Key(),
|
|
CipherSuites: []byte{sio.AES_256_GCM},
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("couldn't create decrypter: %w", err)
|
|
}
|
|
|
|
if d.skipLen > 0 {
|
|
if _, err = io.CopyN(io.Discard, r, int64(d.skipLen)); err != nil {
|
|
return fmt.Errorf("couldn't skip some bytes: %w", err)
|
|
}
|
|
}
|
|
d.decReader = r
|
|
|
|
return nil
|
|
}
|