diff --git a/crypt/cipher.go b/crypt/cipher.go index 9ce04e9e1..909712767 100644 --- a/crypt/cipher.go +++ b/crypt/cipher.go @@ -8,7 +8,6 @@ import ( "encoding/base32" "fmt" "io" - "io/ioutil" "strings" "sync" "unicode/utf8" @@ -48,6 +47,7 @@ var ( ErrorBadBase32Encoding = errors.New("bad base32 filename encoding") ErrorFileClosed = errors.New("file already closed") ErrorNotAnEncryptedFile = errors.New("not an encrypted file - no \"" + encryptedSuffix + "\" suffix") + ErrorBadSeek = errors.New("Seek beyond end of file") defaultSalt = []byte{0xA8, 0x0D, 0xF4, 0x3A, 0x8F, 0xBD, 0x03, 0x08, 0xA7, 0xCA, 0xB8, 0x3E, 0x58, 0x1F, 0x86, 0xB1} ) @@ -404,6 +404,7 @@ func (n *nonce) add(x uint64) { // encrypter encrypts an io.Reader on the fly type encrypter struct { + mu sync.Mutex in io.Reader c *cipher nonce nonce @@ -437,6 +438,9 @@ func (c *cipher) newEncrypter(in io.Reader) (*encrypter, error) { // Read as per io.Reader func (fh *encrypter) Read(p []byte) (n int, err error) { + fh.mu.Lock() + defer fh.mu.Unlock() + if fh.err != nil { return 0, fh.err } @@ -474,7 +478,9 @@ func (fh *encrypter) finish(err error) (int, error) { } fh.err = err fh.c.putBlock(fh.buf) + fh.buf = nil fh.c.putBlock(fh.readBuf) + fh.readBuf = nil return 0, err } @@ -489,6 +495,7 @@ func (c *cipher) EncryptData(in io.Reader) (io.Reader, error) { // decrypter decrypts an io.ReaderCloser on the fly type decrypter struct { + mu sync.Mutex rc io.ReadCloser nonce nonce initialNonce nonce @@ -551,37 +558,48 @@ func (c *cipher) newDecrypterSeek(open OpenAtOffset, offset int64) (fh *decrypte return fh, nil } +// read data into internal buffer - call with fh.mu held +func (fh *decrypter) fillBuffer() (err error) { + // FIXME should overlap the reads with a go-routine and 2 buffers? + readBuf := fh.readBuf + n, err := io.ReadFull(fh.rc, readBuf) + if err == io.EOF { + // ReadFull only returns n=0 and EOF + return io.EOF + } else if err == io.ErrUnexpectedEOF { + // Next read will return EOF + } else if err != nil { + return err + } + // Check header + 1 byte exists + if n <= blockHeaderSize { + return ErrorEncryptedFileBadHeader + } + // Decrypt the block using the nonce + block := fh.buf + _, ok := secretbox.Open(block[:0], readBuf[:n], fh.nonce.pointer(), &fh.c.dataKey) + if !ok { + return ErrorEncryptedBadBlock + } + fh.bufIndex = 0 + fh.bufSize = n - blockHeaderSize + fh.nonce.increment() + return nil +} + // Read as per io.Reader func (fh *decrypter) Read(p []byte) (n int, err error) { + fh.mu.Lock() + defer fh.mu.Unlock() + if fh.err != nil { return 0, fh.err } if fh.bufIndex >= fh.bufSize { - // Read data - // FIXME should overlap the reads with a go-routine and 2 buffers? - readBuf := fh.readBuf - n, err = io.ReadFull(fh.rc, readBuf) - if err == io.EOF { - // ReadFull only returns n=0 and EOF - return 0, fh.finish(io.EOF) - } else if err == io.ErrUnexpectedEOF { - // Next read will return EOF - } else if err != nil { + err = fh.fillBuffer() + if err != nil { return 0, fh.finish(err) } - // Check header + 1 byte exists - if n <= blockHeaderSize { - return 0, fh.finish(ErrorEncryptedFileBadHeader) - } - // Decrypt the block using the nonce - block := fh.buf - _, ok := secretbox.Open(block[:0], readBuf[:n], fh.nonce.pointer(), &fh.c.dataKey) - if !ok { - return 0, fh.finish(ErrorEncryptedBadBlock) - } - fh.bufIndex = 0 - fh.bufSize = n - blockHeaderSize - fh.nonce.increment() } n = copy(p, fh.buf[fh.bufIndex:fh.bufSize]) fh.bufIndex += n @@ -590,6 +608,9 @@ func (fh *decrypter) Read(p []byte) (n int, err error) { // Seek as per io.Seeker func (fh *decrypter) Seek(offset int64, whence int) (int64, error) { + fh.mu.Lock() + defer fh.mu.Unlock() + if fh.open == nil { return 0, fh.finish(errors.New("can't seek - not initialised with newDecrypterSeek")) } @@ -599,7 +620,7 @@ func (fh *decrypter) Seek(offset int64, whence int) (int64, error) { // Reset error or return it if not EOF if fh.err == io.EOF { - fh.err = nil + fh.unFinish() } else if fh.err != nil { return 0, fh.err } @@ -636,16 +657,18 @@ func (fh *decrypter) Seek(offset int64, whence int) (int64, error) { fh.rc = rc } - // Empty the buffer - fh.bufIndex = 0 - fh.bufSize = 0 - - // Discard excess bytes - _, err := io.CopyN(ioutil.Discard, fh, discard) + // Fill the buffer + err := fh.fillBuffer() if err != nil { return 0, fh.finish(err) } + // Discard bytes from the buffer + if int(discard) > fh.bufSize { + return 0, fh.finish(ErrorBadSeek) + } + fh.bufIndex = int(discard) + return offset, nil } @@ -656,12 +679,31 @@ func (fh *decrypter) finish(err error) error { } fh.err = err fh.c.putBlock(fh.buf) + fh.buf = nil fh.c.putBlock(fh.readBuf) + fh.readBuf = nil return err } +// unFinish undoes the effects of finish +func (fh *decrypter) unFinish() { + // Clear error + fh.err = nil + + // reinstate the buffers + fh.buf = fh.c.getBlock() + fh.readBuf = fh.c.getBlock() + + // Empty the buffer + fh.bufIndex = 0 + fh.bufSize = 0 +} + // Close func (fh *decrypter) Close() error { + fh.mu.Lock() + defer fh.mu.Unlock() + // Check already closed if fh.err == ErrorFileClosed { return fh.err