diff --git a/registry/storage/blobwriter.go b/registry/storage/blobwriter.go index 2bb25dcc..54c79b6f 100644 --- a/registry/storage/blobwriter.go +++ b/registry/storage/blobwriter.go @@ -33,7 +33,7 @@ type blobWriter struct { id string startedAt time.Time digester digest.Digester - written int64 // track the contiguous write + written int64 // track the write to digester fileWriter storagedriver.FileWriter driver storagedriver.StorageDriver @@ -119,7 +119,12 @@ func (bw *blobWriter) Write(p []byte) (int, error) { return 0, err } - n, err := io.MultiWriter(bw.fileWriter, bw.digester.Hash()).Write(p) + _, err := bw.fileWriter.Write(p) + if err != nil { + return 0, err + } + + n, err := bw.digester.Hash().Write(p) bw.written += int64(n) return n, err @@ -133,7 +138,11 @@ func (bw *blobWriter) ReadFrom(r io.Reader) (n int64, err error) { return 0, err } - nn, err := io.Copy(io.MultiWriter(bw.fileWriter, bw.digester.Hash()), r) + // Using a TeeReader instead of MultiWriter ensures Copy returns + // the amount written to the digester as well as ensuring that we + // write to the fileWriter first + tee := io.TeeReader(r, bw.fileWriter) + nn, err := io.Copy(bw.digester.Hash(), tee) bw.written += nn return nn, err diff --git a/registry/storage/blobwriter_resumable.go b/registry/storage/blobwriter_resumable.go index 1f3b93d4..b970e865 100644 --- a/registry/storage/blobwriter_resumable.go +++ b/registry/storage/blobwriter_resumable.go @@ -21,18 +21,13 @@ func (bw *blobWriter) resumeDigest(ctx context.Context) error { return errResumableDigestNotAvailable } - h, ok := bw.digester.Hash().(encoding.BinaryMarshaler) + h, ok := bw.digester.Hash().(encoding.BinaryUnmarshaler) if !ok { return errResumableDigestNotAvailable } - state, err := h.MarshalBinary() - if err != nil { - return err - } - offset := bw.fileWriter.Size() - if offset == int64(len(state)) { + if offset == bw.written { // State of digester is already at the requested offset. return nil } @@ -62,19 +57,14 @@ func (bw *blobWriter) resumeDigest(ctx context.Context) error { return err } - // This type assertion is safe since we already did an assertion at the beginning - if err = h.(encoding.BinaryUnmarshaler).UnmarshalBinary(storedState); err != nil { - return err - } - - state, err = h.(encoding.BinaryMarshaler).MarshalBinary() - if err != nil { + if err = h.UnmarshalBinary(storedState); err != nil { return err } + bw.written = hashStateMatch.offset } // Mind the gap. - if gapLen := offset - int64(len(state)); gapLen > 0 { + if gapLen := offset - bw.written; gapLen > 0 { return errResumableDigestNotAvailable } @@ -136,14 +126,14 @@ func (bw *blobWriter) storeHashState(ctx context.Context) error { state, err := h.MarshalBinary() if err != nil { - return fmt.Errorf("could not marshal: %v", err) + return err } uploadHashStatePath, err := pathFor(uploadHashStatePathSpec{ name: bw.blobStore.repository.Named().String(), id: bw.id, alg: bw.digester.Digest().Algorithm(), - offset: int64(len(state)), + offset: bw.written, }) if err != nil {