diff --git a/lib/pool/pool.go b/lib/pool/pool.go index acd44a12c..07d6f3b91 100644 --- a/lib/pool/pool.go +++ b/lib/pool/pool.go @@ -5,21 +5,31 @@ package pool import ( "fmt" "log" - "sync/atomic" + "sync" "time" "github.com/ncw/rclone/lib/mmap" ) // Pool of internal buffers +// +// We hold buffers in cache. Every time we Get or Put we update +// minFill which is the minimum len(cache) seen. +// +// Every flushTime we remove minFill buffers from the cache as they +// were not used in the previous flushTime interval. type Pool struct { - cache chan []byte - bufferSize int - timer *time.Timer - inUse int32 - flushTime time.Duration - alloc func(int) ([]byte, error) - free func([]byte) error + mu sync.Mutex + cache [][]byte + minFill int // the minimum fill of the cache + bufferSize int + poolSize int + timer *time.Timer + inUse int + flushTime time.Duration + flushPending bool + alloc func(int) ([]byte, error) + free func([]byte) error } // New makes a buffer pool @@ -30,7 +40,8 @@ type Pool struct { // useMmap should be set to use mmap allocations func New(flushTime time.Duration, bufferSize, poolSize int, useMmap bool) *Pool { bp := &Pool{ - cache: make(chan []byte, poolSize), + cache: make([][]byte, 0, poolSize), + poolSize: poolSize, flushTime: flushTime, bufferSize: bufferSize, } @@ -45,47 +56,113 @@ func New(flushTime time.Duration, bufferSize, poolSize int, useMmap bool) *Pool return nil } } - bp.timer = time.AfterFunc(flushTime, bp.Flush) + bp.timer = time.AfterFunc(flushTime, bp.flushAged) return bp } +// get gets the last buffer in bp.cache +// +// Call with mu held +func (bp *Pool) get() []byte { + n := len(bp.cache) - 1 + buf := bp.cache[n] + bp.cache[n] = nil // clear buffer pointer from bp.cache + bp.cache = bp.cache[:n] + return buf +} + +// put puts the buffer on the end of bp.cache +// +// Call with mu held +func (bp *Pool) put(buf []byte) { + bp.cache = append(bp.cache, buf) +} + +// flush n entries from the entire buffer pool +// Call with mu held +func (bp *Pool) flush(n int) { + for i := 0; i < n; i++ { + _ = bp.get() + } + bp.minFill = len(bp.cache) +} + // Flush the entire buffer pool func (bp *Pool) Flush() { - for { - select { - case b := <-bp.cache: - bp.freeBuffer(b) - default: - return - } + bp.mu.Lock() + bp.flush(len(bp.cache)) + bp.mu.Unlock() +} + +// Remove bp.minFill buffers +func (bp *Pool) flushAged() { + bp.mu.Lock() + bp.flushPending = false + bp.flush(bp.minFill) + // If there are still items in the cache, schedule another flush + if len(bp.cache) != 0 { + bp.kickFlusher() } + bp.mu.Unlock() } -// InUse returns the approximate number of buffers in use which -// haven't been returned to the pool. +// InUse returns the number of buffers in use which haven't been +// returned to the pool func (bp *Pool) InUse() int { - return int(atomic.LoadInt32(&bp.inUse)) + bp.mu.Lock() + defer bp.mu.Unlock() + return bp.inUse } -// starts or resets the buffer flusher timer +// InPool returns the number of buffers in the pool +func (bp *Pool) InPool() int { + bp.mu.Lock() + defer bp.mu.Unlock() + return len(bp.cache) +} + +// starts or resets the buffer flusher timer - call with mu held func (bp *Pool) kickFlusher() { + if bp.flushPending { + return + } + bp.flushPending = true bp.timer.Reset(bp.flushTime) } +// Make sure minFill is correct - call with mu held +func (bp *Pool) updateMinFill() { + if len(bp.cache) < bp.minFill { + bp.minFill = len(bp.cache) + } +} + // Get a buffer from the pool or allocate one func (bp *Pool) Get() []byte { - select { - case b := <-bp.cache: - return b - default: + bp.mu.Lock() + var buf []byte + waitTime := time.Millisecond + for { + if len(bp.cache) > 0 { + buf = bp.get() + break + } else { + var err error + buf, err = bp.alloc(bp.bufferSize) + if err == nil { + break + } + log.Printf("Failed to get memory for buffer, waiting for %v: %v", waitTime, err) + bp.mu.Unlock() + time.Sleep(waitTime) + bp.mu.Lock() + waitTime *= 2 + } } - mem, err := bp.alloc(bp.bufferSize) - if err != nil { - log.Printf("Failed to get memory for buffer, waiting for a freed one: %v", err) - return <-bp.cache - } - atomic.AddInt32(&bp.inUse, 1) - return mem + bp.inUse++ + bp.updateMinFill() + bp.mu.Unlock() + return buf } // freeBuffer returns mem to the os if required @@ -93,8 +170,6 @@ func (bp *Pool) freeBuffer(mem []byte) { err := bp.free(mem) if err != nil { log.Printf("Failed to free memory: %v", err) - } else { - atomic.AddInt32(&bp.inUse, -1) } } @@ -102,17 +177,19 @@ func (bp *Pool) freeBuffer(mem []byte) { // // Note that if you try to return a buffer of the wrong size to Put it // will panic. -func (bp *Pool) Put(mem []byte) { - mem = mem[0:cap(mem)] - if len(mem) != bp.bufferSize { - panic(fmt.Sprintf("Returning buffer sized %d but expecting %d", len(mem), bp.bufferSize)) +func (bp *Pool) Put(buf []byte) { + bp.mu.Lock() + defer bp.mu.Unlock() + buf = buf[0:cap(buf)] + if len(buf) != bp.bufferSize { + panic(fmt.Sprintf("Returning buffer sized %d but expecting %d", len(buf), bp.bufferSize)) } - select { - case bp.cache <- mem: - bp.kickFlusher() - return - default: + if len(bp.cache) < bp.poolSize { + bp.put(buf) + } else { + bp.freeBuffer(buf) } - bp.freeBuffer(mem) - mem = nil + bp.inUse-- + bp.updateMinFill() + bp.kickFlusher() } diff --git a/lib/pool/pool_test.go b/lib/pool/pool_test.go index 7f9ceb4f9..c2e97281c 100644 --- a/lib/pool/pool_test.go +++ b/lib/pool/pool_test.go @@ -1,53 +1,95 @@ package pool import ( + "errors" + "fmt" + "math/rand" "testing" "time" "github.com/stretchr/testify/assert" ) -func testGetPut(t *testing.T, useMmap bool) { +// makes the allocations be unreliable +func makeUnreliable(bp *Pool) { + bp.alloc = func(size int) ([]byte, error) { + if rand.Intn(3) != 0 { + return nil, errors.New("failed to allocate memory") + } + return make([]byte, size), nil + } + bp.free = func(b []byte) error { + if rand.Intn(3) != 0 { + return errors.New("failed to free memory") + } + return nil + } +} + +func testGetPut(t *testing.T, useMmap bool, unreliable bool) { bp := New(60*time.Second, 4096, 2, useMmap) + if unreliable { + makeUnreliable(bp) + } assert.Equal(t, 0, bp.InUse()) b1 := bp.Get() assert.Equal(t, 1, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) b2 := bp.Get() assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) b3 := bp.Get() assert.Equal(t, 3, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) bp.Put(b1) - assert.Equal(t, 3, bp.InUse()) + assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, 1, bp.InPool()) bp.Put(b2) - assert.Equal(t, 3, bp.InUse()) + assert.Equal(t, 1, bp.InUse()) + assert.Equal(t, 2, bp.InPool()) bp.Put(b3) - assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 2, bp.InPool()) + addr := func(b []byte) string { + return fmt.Sprintf("%p", &b[0]) + } b1a := bp.Get() - assert.Equal(t, b1, b1a) - assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, addr(b2), addr(b1a)) + assert.Equal(t, 1, bp.InUse()) + assert.Equal(t, 1, bp.InPool()) b2a := bp.Get() - assert.Equal(t, b1, b2a) + assert.Equal(t, addr(b1), addr(b2a)) assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) bp.Put(b1a) bp.Put(b2a) - assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 2, bp.InPool()) + + assert.Panics(t, func() { + bp.Put(make([]byte, 1)) + }) bp.Flush() assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) } -func testFlusher(t *testing.T, useMmap bool) { +func testFlusher(t *testing.T, useMmap bool, unreliable bool) { bp := New(50*time.Millisecond, 4096, 2, useMmap) + if unreliable { + makeUnreliable(bp) + } b1 := bp.Get() b2 := bp.Get() @@ -55,38 +97,109 @@ func testFlusher(t *testing.T, useMmap bool) { bp.Put(b1) bp.Put(b2) bp.Put(b3) - assert.Equal(t, 2, bp.InUse()) + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 2, bp.InPool()) + bp.mu.Lock() + assert.Equal(t, 0, bp.minFill) + assert.Equal(t, true, bp.flushPending) + bp.mu.Unlock() - checkFlushHasHappened := func() { + checkFlushHasHappened := func(desired int) { var n int for i := 0; i < 10; i++ { time.Sleep(100 * time.Millisecond) - n = bp.InUse() - if n == 0 { + n = bp.InPool() + if n <= desired { break } } - assert.Equal(t, 0, n) + assert.Equal(t, desired, n) } - checkFlushHasHappened() + checkFlushHasHappened(0) + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) + bp.mu.Lock() + assert.Equal(t, 0, bp.minFill) + assert.Equal(t, false, bp.flushPending) + bp.mu.Unlock() + // Now do manual aging to check it is working properly + bp = New(100*time.Second, 4096, 2, useMmap) + + // Check the new one doesn't get flushed b1 = bp.Get() + b2 = bp.Get() bp.Put(b1) - assert.Equal(t, 1, bp.InUse()) + bp.Put(b2) - checkFlushHasHappened() + bp.mu.Lock() + assert.Equal(t, 0, bp.minFill) + assert.Equal(t, true, bp.flushPending) + bp.mu.Unlock() + + bp.flushAged() + + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 2, bp.InPool()) + bp.mu.Lock() + assert.Equal(t, 2, bp.minFill) + assert.Equal(t, true, bp.flushPending) + bp.mu.Unlock() + + bp.Put(bp.Get()) + + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 2, bp.InPool()) + bp.mu.Lock() + assert.Equal(t, 1, bp.minFill) + assert.Equal(t, true, bp.flushPending) + bp.mu.Unlock() + + bp.flushAged() + + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 1, bp.InPool()) + bp.mu.Lock() + assert.Equal(t, 1, bp.minFill) + assert.Equal(t, true, bp.flushPending) + bp.mu.Unlock() + + bp.flushAged() + + assert.Equal(t, 0, bp.InUse()) + assert.Equal(t, 0, bp.InPool()) + bp.mu.Lock() + assert.Equal(t, 0, bp.minFill) + assert.Equal(t, false, bp.flushPending) + bp.mu.Unlock() } func TestPool(t *testing.T) { - for _, useMmap := range []bool{false, true} { - name := "make" - if useMmap { - name = "mmap" - } - t.Run(name, func(t *testing.T) { - t.Run("GetPut", func(t *testing.T) { testGetPut(t, useMmap) }) - t.Run("Flusher", func(t *testing.T) { testFlusher(t, useMmap) }) + for _, test := range []struct { + name string + useMmap bool + unreliable bool + }{ + { + name: "make", + useMmap: false, + unreliable: false, + }, + { + name: "mmap", + useMmap: true, + unreliable: false, + }, + { + name: "canFail", + useMmap: false, + unreliable: true, + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Run("GetPut", func(t *testing.T) { testGetPut(t, test.useMmap, test.unreliable) }) + t.Run("Flusher", func(t *testing.T) { testFlusher(t, test.useMmap, test.unreliable) }) }) } }