diff --git a/fs/cache/cache.go b/fs/cache/cache.go index 61706b423..549833570 100644 --- a/fs/cache/cache.go +++ b/fs/cache/cache.go @@ -2,93 +2,39 @@ package cache import ( - "sync" - "time" - "github.com/rclone/rclone/fs" + "github.com/rclone/rclone/lib/cache" ) var ( - fsCacheMu sync.Mutex - fsCache = map[string]*cacheEntry{} - fsNewFs = fs.NewFs // for tests - expireRunning = false - cacheExpireDuration = 300 * time.Second // expire the cache entry when it is older than this - cacheExpireInterval = 60 * time.Second // interval to run the cache expire + c = cache.New() ) -type cacheEntry struct { - f fs.Fs // cached f - err error // nil or fs.ErrorIsFile - fsString string // remote string - lastUsed time.Time // time used for expiry +// GetFn gets a fs.Fs named fsString either from the cache or creates +// it afresh with the create function +func GetFn(fsString string, create func(fsString string) (fs.Fs, error)) (f fs.Fs, err error) { + value, err := c.Get(fsString, func(fsString string) (value interface{}, ok bool, error error) { + f, err := create(fsString) + ok = err == nil || err == fs.ErrorIsFile + return f, ok, err + }) + if err != nil { + return nil, err + } + return value.(fs.Fs), nil } // Get gets a fs.Fs named fsString either from the cache or creates it afresh func Get(fsString string) (f fs.Fs, err error) { - fsCacheMu.Lock() - entry, ok := fsCache[fsString] - if !ok { - fsCacheMu.Unlock() // Unlock in case Get is called recursively - f, err = fsNewFs(fsString) - if err != nil && err != fs.ErrorIsFile { - return f, err - } - entry = &cacheEntry{ - f: f, - fsString: fsString, - err: err, - } - fsCacheMu.Lock() - fsCache[fsString] = entry - } - defer fsCacheMu.Unlock() - entry.lastUsed = time.Now() - if !expireRunning { - time.AfterFunc(cacheExpireInterval, cacheExpire) - expireRunning = true - } - return entry.f, entry.err + return GetFn(fsString, fs.NewFs) } // Put puts an fs.Fs named fsString into the cache func Put(fsString string, f fs.Fs) { - fsCacheMu.Lock() - defer fsCacheMu.Unlock() - fsCache[fsString] = &cacheEntry{ - f: f, - fsString: fsString, - lastUsed: time.Now(), - } - if !expireRunning { - time.AfterFunc(cacheExpireInterval, cacheExpire) - expireRunning = true - } -} - -// cacheExpire expires any entries that haven't been used recently -func cacheExpire() { - fsCacheMu.Lock() - defer fsCacheMu.Unlock() - now := time.Now() - for fsString, entry := range fsCache { - if now.Sub(entry.lastUsed) > cacheExpireDuration { - delete(fsCache, fsString) - } - } - if len(fsCache) != 0 { - time.AfterFunc(cacheExpireInterval, cacheExpire) - expireRunning = true - } else { - expireRunning = false - } + c.Put(fsString, f) } // Clear removes everything from the cahce func Clear() { - fsCacheMu.Lock() - for k := range fsCache { - delete(fsCache, k) - } - fsCacheMu.Unlock() + c.Clear() } diff --git a/fs/cache/cache_test.go b/fs/cache/cache_test.go index 65de2a817..84c68e0a3 100644 --- a/fs/cache/cache_test.go +++ b/fs/cache/cache_test.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "testing" - "time" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fstest/mockfs" @@ -17,10 +16,9 @@ var ( errSentinel = errors.New("an error") ) -func mockNewFs(t *testing.T) func() { +func mockNewFs(t *testing.T) (func(), func(path string) (fs.Fs, error)) { called = 0 - oldFsNewFs := fsNewFs - fsNewFs = func(path string) (fs.Fs, error) { + create := func(path string) (fs.Fs, error) { assert.Equal(t, 0, called) called++ switch path { @@ -33,115 +31,74 @@ func mockNewFs(t *testing.T) func() { } panic(fmt.Sprintf("Unknown path %q", path)) } - return func() { - fsNewFs = oldFsNewFs - fsCacheMu.Lock() - fsCache = map[string]*cacheEntry{} - expireRunning = false - fsCacheMu.Unlock() + cleanup := func() { + c.Clear() } + return cleanup, create } func TestGet(t *testing.T) { - defer mockNewFs(t)() + cleanup, create := mockNewFs(t) + defer cleanup() - assert.Equal(t, 0, len(fsCache)) + assert.Equal(t, 0, c.Entries()) - f, err := Get("/") + f, err := GetFn("/", create) require.NoError(t, err) - assert.Equal(t, 1, len(fsCache)) + assert.Equal(t, 1, c.Entries()) - f2, err := Get("/") + f2, err := GetFn("/", create) require.NoError(t, err) assert.Equal(t, f, f2) } func TestGetFile(t *testing.T) { - defer mockNewFs(t)() + cleanup, create := mockNewFs(t) + defer cleanup() - assert.Equal(t, 0, len(fsCache)) + assert.Equal(t, 0, c.Entries()) - f, err := Get("/file.txt") + f, err := GetFn("/file.txt", create) require.Equal(t, fs.ErrorIsFile, err) - assert.Equal(t, 1, len(fsCache)) + assert.Equal(t, 1, c.Entries()) - f2, err := Get("/file.txt") + f2, err := GetFn("/file.txt", create) require.Equal(t, fs.ErrorIsFile, err) assert.Equal(t, f, f2) } func TestGetError(t *testing.T) { - defer mockNewFs(t)() + cleanup, create := mockNewFs(t) + defer cleanup() - assert.Equal(t, 0, len(fsCache)) + assert.Equal(t, 0, c.Entries()) - f, err := Get("/error") + f, err := GetFn("/error", create) require.Equal(t, errSentinel, err) require.Equal(t, nil, f) - assert.Equal(t, 0, len(fsCache)) + assert.Equal(t, 0, c.Entries()) } func TestPut(t *testing.T) { - defer mockNewFs(t)() + cleanup, create := mockNewFs(t) + defer cleanup() f := mockfs.NewFs("mock", "mock") - assert.Equal(t, 0, len(fsCache)) + assert.Equal(t, 0, c.Entries()) Put("/alien", f) - assert.Equal(t, 1, len(fsCache)) + assert.Equal(t, 1, c.Entries()) - fNew, err := Get("/alien") + fNew, err := GetFn("/alien", create) require.NoError(t, err) require.Equal(t, f, fNew) - assert.Equal(t, 1, len(fsCache)) -} - -func TestCacheExpire(t *testing.T) { - defer mockNewFs(t)() - - cacheExpireInterval = time.Millisecond - assert.Equal(t, false, expireRunning) - - _, err := Get("/") - require.NoError(t, err) - - fsCacheMu.Lock() - entry := fsCache["/"] - - assert.Equal(t, 1, len(fsCache)) - fsCacheMu.Unlock() - cacheExpire() - fsCacheMu.Lock() - assert.Equal(t, 1, len(fsCache)) - entry.lastUsed = time.Now().Add(-cacheExpireDuration - 60*time.Second) - assert.Equal(t, true, expireRunning) - fsCacheMu.Unlock() - time.Sleep(10 * time.Millisecond) - fsCacheMu.Lock() - assert.Equal(t, false, expireRunning) - assert.Equal(t, 0, len(fsCache)) - fsCacheMu.Unlock() -} - -func TestClear(t *testing.T) { - defer mockNewFs(t)() - - assert.Equal(t, 0, len(fsCache)) - - _, err := Get("/") - require.NoError(t, err) - - assert.Equal(t, 1, len(fsCache)) - - Clear() - - assert.Equal(t, 0, len(fsCache)) + assert.Equal(t, 1, c.Entries()) } diff --git a/lib/cache/cache.go b/lib/cache/cache.go new file mode 100644 index 000000000..59d853af4 --- /dev/null +++ b/lib/cache/cache.go @@ -0,0 +1,134 @@ +// Package cache implements a simple cache where the entries are +// expired after a given time (5 minutes of disuse by default). +package cache + +import ( + "sync" + "time" +) + +// Cache holds values indexed by string, but expired after a given (5 +// minutes by default). +type Cache struct { + mu sync.Mutex + cache map[string]*cacheEntry + expireRunning bool + expireDuration time.Duration // expire the cache entry when it is older than this + expireInterval time.Duration // interval to run the cache expire +} + +// New creates a new cache with the default expire duration and interval +func New() *Cache { + return &Cache{ + cache: map[string]*cacheEntry{}, + expireRunning: false, + expireDuration: 300 * time.Second, + expireInterval: 60 * time.Second, + } +} + +// cacheEntry is stored in the cache +type cacheEntry struct { + value interface{} // cached item + err error // creation error + key string // key + lastUsed time.Time // time used for expiry +} + +// CreateFunc is called to create new values. If the create function +// returns an error it will be cached if ok is true, otherwise the +// error will just be returned, allowing negative caching if required. +type CreateFunc func(key string) (value interface{}, ok bool, error error) + +// used marks an entry as accessed now and kicks the expire timer off +// should be called with the lock held +func (c *Cache) used(entry *cacheEntry) { + entry.lastUsed = time.Now() + if !c.expireRunning { + time.AfterFunc(c.expireInterval, c.cacheExpire) + c.expireRunning = true + } +} + +// Get gets a value named key either from the cache or creates it +// afresh with the create function. +func (c *Cache) Get(key string, create CreateFunc) (value interface{}, err error) { + c.mu.Lock() + entry, ok := c.cache[key] + if !ok { + c.mu.Unlock() // Unlock in case Get is called recursively + value, ok, err = create(key) + if err != nil && !ok { + return value, err + } + entry = &cacheEntry{ + value: value, + key: key, + err: err, + } + c.mu.Lock() + c.cache[key] = entry + } + defer c.mu.Unlock() + c.used(entry) + return entry.value, entry.err +} + +// Put puts an value named key into the cache +func (c *Cache) Put(key string, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + entry := &cacheEntry{ + value: value, + key: key, + } + c.used(entry) + c.cache[key] = entry +} + +// GetMaybe returns the key and true if found, nil and false if not +func (c *Cache) GetMaybe(key string) (value interface{}, found bool) { + c.mu.Lock() + defer c.mu.Unlock() + entry, found := c.cache[key] + if !found { + return nil, found + } + c.used(entry) + return entry.value, found +} + +// cacheExpire expires any entries that haven't been used recently +func (c *Cache) cacheExpire() { + c.mu.Lock() + defer c.mu.Unlock() + now := time.Now() + for key, entry := range c.cache { + if now.Sub(entry.lastUsed) > c.expireDuration { + delete(c.cache, key) + } + } + if len(c.cache) != 0 { + time.AfterFunc(c.expireInterval, c.cacheExpire) + c.expireRunning = true + } else { + c.expireRunning = false + } +} + +// Clear removes everything from the cahce +func (c *Cache) Clear() { + c.mu.Lock() + for k := range c.cache { + delete(c.cache, k) + } + c.mu.Unlock() +} + +// Entries returns the number of entries in the cache +func (c *Cache) Entries() int { + c.mu.Lock() + entries := len(c.cache) + c.mu.Unlock() + return entries +} diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go new file mode 100644 index 000000000..8c098f7f9 --- /dev/null +++ b/lib/cache/cache_test.go @@ -0,0 +1,174 @@ +package cache + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + called = 0 + errSentinel = errors.New("an error") + errCached = errors.New("a cached error") +) + +func setup(t *testing.T) (*Cache, CreateFunc) { + called = 0 + create := func(path string) (interface{}, bool, error) { + assert.Equal(t, 0, called) + called++ + switch path { + case "/": + return "/", true, nil + case "/file.txt": + return "/file.txt", true, errCached + case "/error": + return nil, false, errSentinel + } + panic(fmt.Sprintf("Unknown path %q", path)) + } + c := New() + return c, create +} + +func TestGet(t *testing.T) { + c, create := setup(t) + + assert.Equal(t, 0, len(c.cache)) + + f, err := c.Get("/", create) + require.NoError(t, err) + + assert.Equal(t, 1, len(c.cache)) + + f2, err := c.Get("/", create) + require.NoError(t, err) + + assert.Equal(t, f, f2) +} + +func TestGetFile(t *testing.T) { + c, create := setup(t) + + assert.Equal(t, 0, len(c.cache)) + + f, err := c.Get("/file.txt", create) + require.Equal(t, errCached, err) + + assert.Equal(t, 1, len(c.cache)) + + f2, err := c.Get("/file.txt", create) + require.Equal(t, errCached, err) + + assert.Equal(t, f, f2) +} + +func TestGetError(t *testing.T) { + c, create := setup(t) + + assert.Equal(t, 0, len(c.cache)) + + f, err := c.Get("/error", create) + require.Equal(t, errSentinel, err) + require.Equal(t, nil, f) + + assert.Equal(t, 0, len(c.cache)) +} + +func TestPut(t *testing.T) { + c, create := setup(t) + + assert.Equal(t, 0, len(c.cache)) + + c.Put("/alien", "slime") + + assert.Equal(t, 1, len(c.cache)) + + fNew, err := c.Get("/alien", create) + require.NoError(t, err) + require.Equal(t, "slime", fNew) + + assert.Equal(t, 1, len(c.cache)) +} + +func TestCacheExpire(t *testing.T) { + c, create := setup(t) + + c.expireInterval = time.Millisecond + assert.Equal(t, false, c.expireRunning) + + _, err := c.Get("/", create) + require.NoError(t, err) + + c.mu.Lock() + entry := c.cache["/"] + + assert.Equal(t, 1, len(c.cache)) + c.mu.Unlock() + c.cacheExpire() + c.mu.Lock() + assert.Equal(t, 1, len(c.cache)) + entry.lastUsed = time.Now().Add(-c.expireDuration - 60*time.Second) + assert.Equal(t, true, c.expireRunning) + c.mu.Unlock() + time.Sleep(10 * time.Millisecond) + c.mu.Lock() + assert.Equal(t, false, c.expireRunning) + assert.Equal(t, 0, len(c.cache)) + c.mu.Unlock() +} + +func TestClear(t *testing.T) { + c, create := setup(t) + + assert.Equal(t, 0, len(c.cache)) + + _, err := c.Get("/", create) + require.NoError(t, err) + + assert.Equal(t, 1, len(c.cache)) + + c.Clear() + + assert.Equal(t, 0, len(c.cache)) +} + +func TestEntries(t *testing.T) { + c, create := setup(t) + + assert.Equal(t, 0, c.Entries()) + + _, err := c.Get("/", create) + require.NoError(t, err) + + assert.Equal(t, 1, c.Entries()) + + c.Clear() + + assert.Equal(t, 0, c.Entries()) +} + +func TestGetMaybe(t *testing.T) { + c, create := setup(t) + + value, found := c.GetMaybe("/") + assert.Equal(t, false, found) + assert.Nil(t, value) + + f, err := c.Get("/", create) + require.NoError(t, err) + + value, found = c.GetMaybe("/") + assert.Equal(t, true, found) + assert.Equal(t, f, value) + + c.Clear() + + value, found = c.GetMaybe("/") + assert.Equal(t, false, found) + assert.Nil(t, value) +}