From c6901af5aa24189db823e0eb868b80a8a440f5e0 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 11 Feb 2015 17:41:11 +0100 Subject: [PATCH] Add streaming encrypt functions --- key.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++ key_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 182 insertions(+), 4 deletions(-) diff --git a/key.go b/key.go index 0b8dff2ce..8162f21ea 100644 --- a/key.go +++ b/key.go @@ -1,6 +1,7 @@ package restic import ( + "bytes" "crypto/aes" "crypto/cipher" "crypto/hmac" @@ -9,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "hash" "io" "os" "os/user" @@ -315,6 +317,95 @@ func (k *Key) Encrypt(ciphertext, plaintext []byte) (int, error) { return k.encrypt(k.master, ciphertext, plaintext) } +type HashReader struct { + r io.Reader + h hash.Hash + sum []byte + closed bool +} + +func NewHashReader(r io.Reader, h hash.Hash) *HashReader { + return &HashReader{ + h: h, + r: io.TeeReader(r, h), + sum: make([]byte, 0, h.Size()), + } +} + +func (h *HashReader) Read(p []byte) (n int, err error) { + if !h.closed { + n, err = h.r.Read(p) + + if err == io.EOF { + h.closed = true + h.sum = h.h.Sum(h.sum) + } else if err != nil { + return + } + } + + if h.closed { + // output hash + r := len(p) - n + + if r > 0 { + c := copy(p[n:], h.sum) + h.sum = h.sum[c:] + + n += c + err = nil + } + + if len(h.sum) == 0 { + err = io.EOF + } + } + + return +} + +// encryptFrom encrypts and signs data read from rd with ks. The returned +// io.Reader reads IV || Ciphertext || HMAC. For the hash function, SHA256 is +// used. +func (k *Key) encryptFrom(ks *keys, rd io.Reader) io.Reader { + // create IV + iv := make([]byte, ivSize) + + _, err := io.ReadFull(rand.Reader, iv) + if err != nil { + panic(fmt.Sprintf("unable to generate new random iv: %v", err)) + } + + c, err := aes.NewCipher(ks.Encrypt) + if err != nil { + panic(fmt.Sprintf("unable to create cipher: %v", err)) + } + + ivReader := bytes.NewReader(iv) + + encryptReader := cipher.StreamReader{ + R: rd, + S: cipher.NewCTR(c, iv), + } + + return NewHashReader(io.MultiReader(ivReader, encryptReader), + hmac.New(sha256.New, ks.Sign)) +} + +// EncryptFrom encrypts and signs data read from rd with the master key. The +// returned io.Reader reads IV || Ciphertext || HMAC. For the hash function, +// SHA256 is used. +func (k *Key) EncryptFrom(rd io.Reader) io.Reader { + return k.encryptFrom(k.master, rd) +} + +// EncryptFrom encrypts and signs data read from rd with the user key. The +// returned io.Reader reads IV || Ciphertext || HMAC. For the hash function, +// SHA256 is used. +func (k *Key) EncryptUserFrom(rd io.Reader) io.Reader { + return k.encryptFrom(k.user, rd) +} + // Decrypt verifes and decrypts the ciphertext. Ciphertext must be in the form // IV || Ciphertext || HMAC. func (k *Key) decrypt(ks *keys, ciphertext []byte) ([]byte, error) { diff --git a/key_test.go b/key_test.go index f5c9005f5..f7ae5d9e7 100644 --- a/key_test.go +++ b/key_test.go @@ -1,6 +1,8 @@ package restic_test import ( + "bytes" + "crypto/sha256" "flag" "io" "io/ioutil" @@ -61,10 +63,7 @@ func TestEncryptDecrypt(t *testing.T) { for _, size := range tests { data := make([]byte, size) - f, err := os.Open("/dev/urandom") - ok(t, err) - - _, err = io.ReadFull(f, data) + _, err := io.ReadFull(randomReader(42, size), data) ok(t, err) ciphertext := restic.GetChunkBuf("TestEncryptDecrypt") @@ -128,6 +127,24 @@ func TestLargeEncrypt(t *testing.T) { } } +func BenchmarkEncryptReader(b *testing.B) { + size := 8 << 20 // 8MiB + rd := randomReader(23, size) + + be := setupBackend(b) + defer teardownBackend(b, be) + k := setupKey(b, be, testPassword) + + b.ResetTimer() + b.SetBytes(int64(size)) + + for i := 0; i < b.N; i++ { + rd.Seek(0, 0) + _, err := io.Copy(ioutil.Discard, k.EncryptFrom(rd)) + ok(b, err) + } +} + func BenchmarkEncrypt(b *testing.B) { size := 8 << 20 // 8MiB data := make([]byte, size) @@ -168,3 +185,73 @@ func BenchmarkDecrypt(b *testing.B) { } restic.FreeChunkBuf("BenchmarkDecrypt", ciphertext) } + +func TestHashReader(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + if *testLargeCrypto { + tests = append(tests, 7<<20+123) + } + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(randomReader(42, size), data) + ok(t, err) + + expectedHash := sha256.Sum256(data) + + rd := restic.NewHashReader(bytes.NewReader(data), sha256.New()) + + target := bytes.NewBuffer(nil) + n, err := io.Copy(target, rd) + ok(t, err) + + assert(t, n == int64(size)+int64(len(expectedHash)), + "HashReader: invalid number of bytes read: got %d, expected %d", + n, size+len(expectedHash)) + + r := target.Bytes() + resultingHash := r[len(r)-len(expectedHash):] + assert(t, bytes.Equal(expectedHash[:], resultingHash), + "HashReader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + + // try to read again, must return io.EOF + n2, err := rd.Read(make([]byte, 100)) + assert(t, n2 == 0, "HashReader returned %d additional bytes", n) + assert(t, err == io.EOF, "HashReader returned %v instead of EOF", err) + } +} + +func TestEncryptStreamReader(t *testing.T) { + s := setupBackend(t) + defer teardownBackend(t, s) + k := setupKey(t, s, testPassword) + + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + if *testLargeCrypto { + tests = append(tests, 7<<20+123) + } + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(randomReader(42, size), data) + ok(t, err) + + erd := k.EncryptFrom(bytes.NewReader(data)) + + ciphertext, err := ioutil.ReadAll(erd) + ok(t, err) + + l := len(data) + restic.CiphertextExtension + assert(t, len(ciphertext) == l, + "wrong ciphertext length: expected %d, got %d", + l, len(ciphertext)) + + // decrypt with default function + plaintext, err := k.Decrypt(ciphertext) + ok(t, err) + assert(t, bytes.Equal(data, plaintext), + "wrong plaintext after decryption: expected %02x, got %02x", + data, plaintext) + } +}