Fix a race condition in pull through cache population by removing the functionality

of readers joining current downloads.  Concurrent requests for the same blob
will not block, but only the first instance will be comitted locally.

Signed-off-by: Richard Scothern <richard.scothern@gmail.com>
This commit is contained in:
Richard Scothern 2015-09-18 16:11:35 -07:00
parent 1039e2dc26
commit 36fa22c821
3 changed files with 260 additions and 175 deletions

View file

@ -22,15 +22,10 @@ type proxyBlobStore struct {
scheduler *scheduler.TTLExpirationScheduler scheduler *scheduler.TTLExpirationScheduler
} }
var _ distribution.BlobStore = proxyBlobStore{} var _ distribution.BlobStore = &proxyBlobStore{}
type inflightBlob struct {
refCount int
bw distribution.BlobWriter
}
// inflight tracks currently downloading blobs // inflight tracks currently downloading blobs
var inflight = make(map[digest.Digest]*inflightBlob) var inflight = make(map[digest.Digest]struct{})
// mu protects inflight // mu protects inflight
var mu sync.Mutex var mu sync.Mutex
@ -42,140 +37,113 @@ func setResponseHeaders(w http.ResponseWriter, length int64, mediaType string, d
w.Header().Set("Etag", digest.String()) w.Header().Set("Etag", digest.String())
} }
func (pbs proxyBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error { func (pbs *proxyBlobStore) copyContent(ctx context.Context, dgst digest.Digest, writer io.Writer) (distribution.Descriptor, error) {
desc, err := pbs.localStore.Stat(ctx, dgst) desc, err := pbs.remoteStore.Stat(ctx, dgst)
if err != nil && err != distribution.ErrBlobUnknown {
return err
}
if err == nil {
proxyMetrics.BlobPush(uint64(desc.Size))
return pbs.localStore.ServeBlob(ctx, w, r, dgst)
}
desc, err = pbs.remoteStore.Stat(ctx, dgst)
if err != nil { if err != nil {
return err return distribution.Descriptor{}, err
}
if w, ok := writer.(http.ResponseWriter); ok {
setResponseHeaders(w, desc.Size, desc.MediaType, dgst)
} }
remoteReader, err := pbs.remoteStore.Open(ctx, dgst) remoteReader, err := pbs.remoteStore.Open(ctx, dgst)
if err != nil { if err != nil {
return err return distribution.Descriptor{}, err
} }
bw, isNew, cleanup, err := getOrCreateBlobWriter(ctx, pbs.localStore, desc) _, err = io.CopyN(writer, remoteReader, desc.Size)
if err != nil { if err != nil {
return err return distribution.Descriptor{}, err
} }
defer cleanup()
if isNew { proxyMetrics.BlobPush(uint64(desc.Size))
go func() {
err := streamToStorage(ctx, remoteReader, desc, bw) return desc, nil
}
func (pbs *proxyBlobStore) serveLocal(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) (bool, error) {
localDesc, err := pbs.localStore.Stat(ctx, dgst)
if err != nil { if err != nil {
context.GetLogger(ctx).Error(err) // Stat can report a zero sized file here if it's checked between creation
// and population. Return nil error, and continue
return false, nil
} }
proxyMetrics.BlobPull(uint64(desc.Size)) if err == nil {
proxyMetrics.BlobPush(uint64(localDesc.Size))
return true, pbs.localStore.ServeBlob(ctx, w, r, dgst)
}
return false, nil
}
func (pbs *proxyBlobStore) storeLocal(ctx context.Context, dgst digest.Digest) error {
defer func() {
mu.Lock()
delete(inflight, dgst)
mu.Unlock()
}() }()
err := streamToClient(ctx, w, desc, bw)
var desc distribution.Descriptor
var err error
var bw distribution.BlobWriter
bw, err = pbs.localStore.Create(ctx)
if err != nil { if err != nil {
return err return err
} }
proxyMetrics.BlobPush(uint64(desc.Size)) desc, err = pbs.copyContent(ctx, dgst, bw)
pbs.scheduler.AddBlob(dgst.String(), blobTTL)
return nil
}
err = streamToClient(ctx, w, desc, bw)
if err != nil { if err != nil {
return err return err
} }
proxyMetrics.BlobPush(uint64(desc.Size))
_, err = bw.Commit(ctx, desc)
if err != nil {
return err
}
return nil return nil
} }
type cleanupFunc func() func (pbs *proxyBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
served, err := pbs.serveLocal(ctx, w, r, dgst)
// getOrCreateBlobWriter will track which blobs are currently being downloaded and enable client requesting
// the same blob concurrently to read from the existing stream.
func getOrCreateBlobWriter(ctx context.Context, blobs distribution.BlobService, desc distribution.Descriptor) (distribution.BlobWriter, bool, cleanupFunc, error) {
mu.Lock()
defer mu.Unlock()
dgst := desc.Digest
cleanup := func() {
mu.Lock()
defer mu.Unlock()
inflight[dgst].refCount--
if inflight[dgst].refCount == 0 {
defer delete(inflight, dgst)
_, err := inflight[dgst].bw.Commit(ctx, desc)
if err != nil { if err != nil {
// There is a narrow race here where Commit can be called while this blob's TTL is expiring context.GetLogger(ctx).Errorf("Error serving blob from local storage: %s", err.Error())
// and its being removed from storage. In that case, the client stream will continue return err
// uninterruped and the blob will be pulled through on the next request, so just log it
context.GetLogger(ctx).Errorf("Error committing blob: %q", err)
} }
} if served {
return nil
} }
var bw distribution.BlobWriter mu.Lock()
_, ok := inflight[dgst] _, ok := inflight[dgst]
if ok { if ok {
bw = inflight[dgst].bw mu.Unlock()
inflight[dgst].refCount++ _, err := pbs.copyContent(ctx, dgst, w)
return bw, false, cleanup, nil return err
} }
inflight[dgst] = struct{}{}
mu.Unlock()
var err error go func(dgst digest.Digest) {
bw, err = blobs.Create(ctx) if err := pbs.storeLocal(ctx, dgst); err != nil {
if err != nil { context.GetLogger(ctx).Errorf("Error committing to storage: %s", err.Error())
return nil, false, nil, err
} }
pbs.scheduler.AddBlob(dgst.String(), repositoryTTL)
}(dgst)
inflight[dgst] = &inflightBlob{refCount: 1, bw: bw} _, err = pbs.copyContent(ctx, dgst, w)
return bw, true, cleanup, nil
}
func streamToStorage(ctx context.Context, remoteReader distribution.ReadSeekCloser, desc distribution.Descriptor, bw distribution.BlobWriter) error {
_, err := io.CopyN(bw, remoteReader, desc.Size)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func streamToClient(ctx context.Context, w http.ResponseWriter, desc distribution.Descriptor, bw distribution.BlobWriter) error { func (pbs *proxyBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
setResponseHeaders(w, desc.Size, desc.MediaType, desc.Digest)
reader, err := bw.Reader()
if err != nil {
return err
}
defer reader.Close()
teeReader := io.TeeReader(reader, w)
buf := make([]byte, 32768, 32786)
var soFar int64
for {
rd, err := teeReader.Read(buf)
if err == nil || err == io.EOF {
soFar += int64(rd)
if soFar < desc.Size {
// buffer underflow, keep trying
continue
}
return nil
}
return err
}
}
func (pbs proxyBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
desc, err := pbs.localStore.Stat(ctx, dgst) desc, err := pbs.localStore.Stat(ctx, dgst)
if err == nil { if err == nil {
return desc, err return desc, err
@ -189,26 +157,26 @@ func (pbs proxyBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distrib
} }
// Unsupported functions // Unsupported functions
func (pbs proxyBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) { func (pbs *proxyBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) {
return distribution.Descriptor{}, distribution.ErrUnsupported return distribution.Descriptor{}, distribution.ErrUnsupported
} }
func (pbs proxyBlobStore) Create(ctx context.Context) (distribution.BlobWriter, error) { func (pbs *proxyBlobStore) Create(ctx context.Context) (distribution.BlobWriter, error) {
return nil, distribution.ErrUnsupported return nil, distribution.ErrUnsupported
} }
func (pbs proxyBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) { func (pbs *proxyBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
return nil, distribution.ErrUnsupported return nil, distribution.ErrUnsupported
} }
func (pbs proxyBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) { func (pbs *proxyBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) {
return nil, distribution.ErrUnsupported return nil, distribution.ErrUnsupported
} }
func (pbs proxyBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) { func (pbs *proxyBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) {
return nil, distribution.ErrUnsupported return nil, distribution.ErrUnsupported
} }
func (pbs proxyBlobStore) Delete(ctx context.Context, dgst digest.Digest) error { func (pbs *proxyBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
return distribution.ErrUnsupported return distribution.ErrUnsupported
} }

View file

@ -1,10 +1,13 @@
package proxy package proxy
import ( import (
"fmt" "io/ioutil"
"math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time"
"github.com/docker/distribution" "github.com/docker/distribution"
"github.com/docker/distribution/context" "github.com/docker/distribution/context"
@ -12,75 +15,119 @@ import (
"github.com/docker/distribution/registry/proxy/scheduler" "github.com/docker/distribution/registry/proxy/scheduler"
"github.com/docker/distribution/registry/storage" "github.com/docker/distribution/registry/storage"
"github.com/docker/distribution/registry/storage/cache/memory" "github.com/docker/distribution/registry/storage/cache/memory"
"github.com/docker/distribution/registry/storage/driver/filesystem"
"github.com/docker/distribution/registry/storage/driver/inmemory" "github.com/docker/distribution/registry/storage/driver/inmemory"
) )
var sbsMu sync.Mutex
type statsBlobStore struct { type statsBlobStore struct {
stats map[string]int stats map[string]int
blobs distribution.BlobStore blobs distribution.BlobStore
} }
func (sbs statsBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) { func (sbs statsBlobStore) Put(ctx context.Context, mediaType string, p []byte) (distribution.Descriptor, error) {
sbsMu.Lock()
sbs.stats["put"]++ sbs.stats["put"]++
sbsMu.Unlock()
return sbs.blobs.Put(ctx, mediaType, p) return sbs.blobs.Put(ctx, mediaType, p)
} }
func (sbs statsBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) { func (sbs statsBlobStore) Get(ctx context.Context, dgst digest.Digest) ([]byte, error) {
sbsMu.Lock()
sbs.stats["get"]++ sbs.stats["get"]++
sbsMu.Unlock()
return sbs.blobs.Get(ctx, dgst) return sbs.blobs.Get(ctx, dgst)
} }
func (sbs statsBlobStore) Create(ctx context.Context) (distribution.BlobWriter, error) { func (sbs statsBlobStore) Create(ctx context.Context) (distribution.BlobWriter, error) {
sbsMu.Lock()
sbs.stats["create"]++ sbs.stats["create"]++
sbsMu.Unlock()
return sbs.blobs.Create(ctx) return sbs.blobs.Create(ctx)
} }
func (sbs statsBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) { func (sbs statsBlobStore) Resume(ctx context.Context, id string) (distribution.BlobWriter, error) {
sbsMu.Lock()
sbs.stats["resume"]++ sbs.stats["resume"]++
sbsMu.Unlock()
return sbs.blobs.Resume(ctx, id) return sbs.blobs.Resume(ctx, id)
} }
func (sbs statsBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) { func (sbs statsBlobStore) Open(ctx context.Context, dgst digest.Digest) (distribution.ReadSeekCloser, error) {
sbsMu.Lock()
sbs.stats["open"]++ sbs.stats["open"]++
sbsMu.Unlock()
return sbs.blobs.Open(ctx, dgst) return sbs.blobs.Open(ctx, dgst)
} }
func (sbs statsBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error { func (sbs statsBlobStore) ServeBlob(ctx context.Context, w http.ResponseWriter, r *http.Request, dgst digest.Digest) error {
sbsMu.Lock()
sbs.stats["serveblob"]++ sbs.stats["serveblob"]++
sbsMu.Unlock()
return sbs.blobs.ServeBlob(ctx, w, r, dgst) return sbs.blobs.ServeBlob(ctx, w, r, dgst)
} }
func (sbs statsBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) { func (sbs statsBlobStore) Stat(ctx context.Context, dgst digest.Digest) (distribution.Descriptor, error) {
sbsMu.Lock()
sbs.stats["stat"]++ sbs.stats["stat"]++
sbsMu.Unlock()
return sbs.blobs.Stat(ctx, dgst) return sbs.blobs.Stat(ctx, dgst)
} }
func (sbs statsBlobStore) Delete(ctx context.Context, dgst digest.Digest) error { func (sbs statsBlobStore) Delete(ctx context.Context, dgst digest.Digest) error {
sbsMu.Lock()
sbs.stats["delete"]++ sbs.stats["delete"]++
sbsMu.Unlock()
return sbs.blobs.Delete(ctx, dgst) return sbs.blobs.Delete(ctx, dgst)
} }
type testEnv struct { type testEnv struct {
numUnique int
inRemote []distribution.Descriptor inRemote []distribution.Descriptor
store proxyBlobStore store proxyBlobStore
ctx context.Context ctx context.Context
} }
func (te testEnv) LocalStats() *map[string]int { func (te *testEnv) LocalStats() *map[string]int {
sbsMu.Lock()
ls := te.store.localStore.(statsBlobStore).stats ls := te.store.localStore.(statsBlobStore).stats
sbsMu.Unlock()
return &ls return &ls
} }
func (te testEnv) RemoteStats() *map[string]int { func (te *testEnv) RemoteStats() *map[string]int {
sbsMu.Lock()
rs := te.store.remoteStore.(statsBlobStore).stats rs := te.store.remoteStore.(statsBlobStore).stats
sbsMu.Unlock()
return &rs return &rs
} }
// Populate remote store and record the digests // Populate remote store and record the digests
func makeTestEnv(t *testing.T, name string) testEnv { func makeTestEnv(t *testing.T, name string) *testEnv {
ctx := context.Background() ctx := context.Background()
localRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption) truthDir, err := ioutil.TempDir("", "truth")
if err != nil {
t.Fatalf("unable to create tempdir: %s", err)
}
cacheDir, err := ioutil.TempDir("", "cache")
if err != nil {
t.Fatalf("unable to create tempdir: %s", err)
}
// todo: create a tempfile area here
localRegistry, err := storage.NewRegistry(ctx, filesystem.New(truthDir), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()), storage.EnableRedirect, storage.DisableDigestResumption)
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
} }
@ -89,7 +136,7 @@ func makeTestEnv(t *testing.T, name string) testEnv {
t.Fatalf("unexpected error getting repo: %v", err) t.Fatalf("unexpected error getting repo: %v", err)
} }
truthRegistry, err := storage.NewRegistry(ctx, inmemory.New(), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider())) truthRegistry, err := storage.NewRegistry(ctx, filesystem.New(cacheDir), storage.BlobDescriptorCacheProvider(memory.NewInMemoryBlobDescriptorCacheProvider()))
if err != nil { if err != nil {
t.Fatalf("error creating registry: %v", err) t.Fatalf("error creating registry: %v", err)
} }
@ -116,33 +163,59 @@ func makeTestEnv(t *testing.T, name string) testEnv {
scheduler: s, scheduler: s,
} }
te := testEnv{ te := &testEnv{
store: proxyBlobStore, store: proxyBlobStore,
ctx: ctx, ctx: ctx,
} }
return te return te
} }
func populate(t *testing.T, te *testEnv, blobCount int) { func makeBlob(size int) []byte {
var inRemote []distribution.Descriptor blob := make([]byte, size, size)
for i := 0; i < blobCount; i++ { for i := 0; i < size; i++ {
bytes := []byte(fmt.Sprintf("blob%d", i)) blob[i] = byte('A' + rand.Int()%48)
}
return blob
}
func init() {
rand.Seed(42)
}
func perm(m []distribution.Descriptor) []distribution.Descriptor {
for i := 0; i < len(m); i++ {
j := rand.Intn(i + 1)
tmp := m[i]
m[i] = m[j]
m[j] = tmp
}
return m
}
func populate(t *testing.T, te *testEnv, blobCount, size, numUnique int) {
var inRemote []distribution.Descriptor
for i := 0; i < numUnique; i++ {
bytes := makeBlob(size)
for j := 0; j < blobCount/numUnique; j++ {
desc, err := te.store.remoteStore.Put(te.ctx, "", bytes) desc, err := te.store.remoteStore.Put(te.ctx, "", bytes)
if err != nil { if err != nil {
t.Errorf("Put in store") t.Fatalf("Put in store")
} }
inRemote = append(inRemote, desc) inRemote = append(inRemote, desc)
} }
}
te.inRemote = inRemote te.inRemote = inRemote
te.numUnique = numUnique
} }
func TestProxyStoreStat(t *testing.T) { func TestProxyStoreStat(t *testing.T) {
te := makeTestEnv(t, "foo/bar") te := makeTestEnv(t, "foo/bar")
remoteBlobCount := 1 remoteBlobCount := 1
populate(t, &te, remoteBlobCount) populate(t, te, remoteBlobCount, 10, 1)
localStats := te.LocalStats() localStats := te.LocalStats()
remoteStats := te.RemoteStats() remoteStats := te.RemoteStats()
@ -164,43 +237,91 @@ func TestProxyStoreStat(t *testing.T) {
} }
} }
func TestProxyStoreServe(t *testing.T) { func TestProxyStoreServeHighConcurrency(t *testing.T) {
te := makeTestEnv(t, "foo/bar") te := makeTestEnv(t, "foo/bar")
remoteBlobCount := 1 blobSize := 200
populate(t, &te, remoteBlobCount) blobCount := 10
numUnique := 1
populate(t, te, blobCount, blobSize, numUnique)
numClients := 16
testProxyStoreServe(t, te, numClients)
}
func TestProxyStoreServeMany(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
blobSize := 200
blobCount := 10
numUnique := 4
populate(t, te, blobCount, blobSize, numUnique)
numClients := 4
testProxyStoreServe(t, te, numClients)
}
// todo(richardscothern): blobCount must be smaller than num clients
func TestProxyStoreServeBig(t *testing.T) {
te := makeTestEnv(t, "foo/bar")
blobSize := 2 << 20
blobCount := 4
numUnique := 2
populate(t, te, blobCount, blobSize, numUnique)
numClients := 4
testProxyStoreServe(t, te, numClients)
}
// testProxyStoreServe will create clients to consume all blobs
// populated in the truth store
func testProxyStoreServe(t *testing.T, te *testEnv, numClients int) {
localStats := te.LocalStats() localStats := te.LocalStats()
remoteStats := te.RemoteStats() remoteStats := te.RemoteStats()
var wg sync.WaitGroup
for i := 0; i < numClients; i++ {
// Serveblob - pulls through blobs // Serveblob - pulls through blobs
for _, dr := range te.inRemote { wg.Add(1)
go func() {
defer wg.Done()
for _, remoteBlob := range te.inRemote {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "", nil) r, err := http.NewRequest("GET", "", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = te.store.ServeBlob(te.ctx, w, r, dr.Digest) err = te.store.ServeBlob(te.ctx, w, r, remoteBlob.Digest)
if err != nil { if err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
dl, err := digest.FromBytes(w.Body.Bytes()) bodyBytes := w.Body.Bytes()
localDigest, err := digest.FromBytes(bodyBytes)
if err != nil { if err != nil {
t.Fatalf("Error making digest from blob") t.Fatalf("Error making digest from blob")
} }
if dl != dr.Digest { if localDigest != remoteBlob.Digest {
t.Errorf("Mismatching blob fetch from proxy") t.Fatalf("Mismatching blob fetch from proxy")
} }
} }
}()
}
if (*localStats)["stat"] != remoteBlobCount && (*localStats)["create"] != remoteBlobCount { wg.Wait()
t.Fatalf("unexpected local stats")
} remoteBlobCount := len(te.inRemote)
if (*remoteStats)["stat"] != remoteBlobCount && (*remoteStats)["open"] != remoteBlobCount { if (*localStats)["stat"] != remoteBlobCount*numClients && (*localStats)["create"] != te.numUnique {
t.Fatalf("unexpected local stats") t.Fatal("Expected: stat:", remoteBlobCount*numClients, "create:", remoteBlobCount)
} }
// Wait for any async storage goroutines to finish
time.Sleep(3 * time.Second)
remoteStatCount := (*remoteStats)["stat"]
remoteOpenCount := (*remoteStats)["open"]
// Serveblob - blobs come from local // Serveblob - blobs come from local
for _, dr := range te.inRemote { for _, dr := range te.inRemote {
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -223,15 +344,11 @@ func TestProxyStoreServe(t *testing.T) {
} }
} }
// Stat to find local, but no new blobs were created localStats = te.LocalStats()
if (*localStats)["stat"] != remoteBlobCount*2 && (*localStats)["create"] != remoteBlobCount*2 { remoteStats = te.RemoteStats()
t.Fatalf("unexpected local stats")
}
// Remote unchanged // Ensure remote unchanged
if (*remoteStats)["stat"] != remoteBlobCount && (*remoteStats)["open"] != remoteBlobCount { if (*remoteStats)["stat"] != remoteStatCount && (*remoteStats)["open"] != remoteOpenCount {
fmt.Printf("\tlocal=%#v, \n\tremote=%#v\n", localStats, remoteStats) t.Fatalf("unexpected remote stats: %#v", remoteStats)
t.Fatalf("unexpected local stats")
} }
} }

View file

@ -94,7 +94,7 @@ func (pr *proxyingRegistry) Repository(ctx context.Context, name string) (distri
} }
return &proxiedRepository{ return &proxiedRepository{
blobStore: proxyBlobStore{ blobStore: &proxyBlobStore{
localStore: localRepo.Blobs(ctx), localStore: localRepo.Blobs(ctx),
remoteStore: remoteRepo.Blobs(ctx), remoteStore: remoteRepo.Blobs(ctx),
scheduler: pr.scheduler, scheduler: pr.scheduler,