diff --git a/act/artifactcache/caches.go b/act/artifactcache/caches.go index 6d499c8a..b4d9bd17 100644 --- a/act/artifactcache/caches.go +++ b/act/artifactcache/caches.go @@ -19,12 +19,13 @@ import ( //go:generate mockery --inpackage --name caches type caches interface { - openDB() (*bolthold.Store, error) + getDB() *bolthold.Store validateMac(rundata RunData) (string, error) readCache(id uint64, repo string) (*Cache, error) useCache(id uint64) error setgcAt(at time.Time) gcCache() + close() serve(w http.ResponseWriter, r *http.Request, id uint64) commit(id uint64, size int64) (int64, error) @@ -38,6 +39,8 @@ type cachesImpl struct { logger logrus.FieldLogger secret string + db *bolthold.Store + gcing atomic.Bool gcAt time.Time } @@ -68,12 +71,6 @@ func newCaches(dir, secret string, logger logrus.FieldLogger) (caches, error) { } c.storage = storage - c.gcCache() - - return c, nil -} - -func (c *cachesImpl) openDB() (*bolthold.Store, error) { file := filepath.Join(c.dir, "bolt.db") db, err := bolthold.Open(file, 0o644, &bolthold.Options{ Encoder: json.Marshal, @@ -87,7 +84,22 @@ func (c *cachesImpl) openDB() (*bolthold.Store, error) { if err != nil { return nil, fmt.Errorf("Open(%s): %w", file, err) } - return db, nil + c.db = db + + c.gcCache() + + return c, nil +} + +func (c *cachesImpl) close() { + if c.db != nil { + c.db.Close() + c.db = nil + } +} + +func (c *cachesImpl) getDB() *bolthold.Store { + return c.db } var findCacheWithIsolationKeyFallback = func(db *bolthold.Store, repo string, keys []string, version, writeIsolationKey string) (*Cache, error) { @@ -156,11 +168,7 @@ func insertCache(db *bolthold.Store, cache *Cache) error { } func (c *cachesImpl) readCache(id uint64, repo string) (*Cache, error) { - db, err := c.openDB() - if err != nil { - return nil, err - } - defer db.Close() + db := c.getDB() cache := &Cache{} if err := db.Get(id, cache); err != nil { return nil, fmt.Errorf("readCache: Get(%v): %w", id, err) @@ -173,11 +181,7 @@ func (c *cachesImpl) readCache(id uint64, repo string) (*Cache, error) { } func (c *cachesImpl) useCache(id uint64) error { - db, err := c.openDB() - if err != nil { - return err - } - defer db.Close() + db := c.getDB() cache := &Cache{} if err := db.Get(id, cache); err != nil { return fmt.Errorf("useCache: Get(%v): %w", id, err) @@ -232,12 +236,7 @@ func (c *cachesImpl) gcCache() { c.gcAt = time.Now() c.logger.Debugf("gc: %v", c.gcAt.String()) - db, err := c.openDB() - if err != nil { - fatal(c.logger, err) - return - } - defer db.Close() + db := c.getDB() // Remove the caches which are not completed for a while, they are most likely to be broken. var caches []*Cache diff --git a/act/artifactcache/caches_test.go b/act/artifactcache/caches_test.go index a08a9af7..34eab331 100644 --- a/act/artifactcache/caches_test.go +++ b/act/artifactcache/caches_test.go @@ -14,6 +14,7 @@ import ( func TestCacheReadWrite(t *testing.T) { caches, err := newCaches(t.TempDir(), "secret", logrus.New()) require.NoError(t, err) + defer caches.close() t.Run("NotFound", func(t *testing.T) { found, err := caches.readCache(456, "repo") assert.Nil(t, found) @@ -33,9 +34,7 @@ func TestCacheReadWrite(t *testing.T) { cache.Repo = repo t.Run("Insert", func(t *testing.T) { - db, err := caches.openDB() - require.NoError(t, err) - defer db.Close() + db := caches.getDB() assert.NoError(t, insertCache(db, cache)) }) diff --git a/act/artifactcache/handler.go b/act/artifactcache/handler.go index 0b574397..29ff61e3 100644 --- a/act/artifactcache/handler.go +++ b/act/artifactcache/handler.go @@ -122,6 +122,10 @@ func (h *handler) Close() error { return nil } var retErr error + if h.caches != nil { + h.caches.close() + h.caches = nil + } if h.server != nil { err := h.server.Close() if err != nil { @@ -151,6 +155,9 @@ func (h *handler) getCaches() caches { } func (h *handler) setCaches(caches caches) { + if h.caches != nil { + h.caches.close() + } h.caches = caches } @@ -170,12 +177,7 @@ func (h *handler) find(w http.ResponseWriter, r *http.Request, params httprouter } version := r.URL.Query().Get("version") - db, err := h.caches.openDB() - if err != nil { - h.responseFatalJSON(w, r, err) - return - } - defer db.Close() + db := h.caches.getDB() cache, err := findCacheWithIsolationKeyFallback(db, repo, keys, version, rundata.WriteIsolationKey) if err != nil { @@ -221,12 +223,7 @@ func (h *handler) reserve(w http.ResponseWriter, r *http.Request, params httprou api.Key = strings.ToLower(api.Key) cache := api.ToCache() - db, err := h.caches.openDB() - if err != nil { - h.responseFatalJSON(w, r, err) - return - } - defer db.Close() + db := h.caches.getDB() now := time.Now().Unix() cache.CreatedAt = now @@ -335,12 +332,7 @@ func (h *handler) commit(w http.ResponseWriter, r *http.Request, params httprout // write real size back to cache, it may be different from the current value when the request doesn't specify it. cache.Size = size - db, err := h.caches.openDB() - if err != nil { - h.responseFatalJSON(w, r, err) - return - } - defer db.Close() + db := h.caches.getDB() cache.Complete = true if err := db.Update(cache.ID, cache); err != nil { diff --git a/act/artifactcache/handler_test.go b/act/artifactcache/handler_test.go index 136f1a87..91754f77 100644 --- a/act/artifactcache/handler_test.go +++ b/act/artifactcache/handler_test.go @@ -78,9 +78,7 @@ func TestHandler(t *testing.T) { defer func() { t.Run("inspect db", func(t *testing.T) { - db, err := handler.getCaches().openDB() - require.NoError(t, err) - defer db.Close() + db := handler.getCaches().getDB() require.NoError(t, db.Bolt().View(func(tx *bbolt.Tx) error { return tx.Bucket([]byte("Cache")).ForEach(func(k, v []byte) error { t.Logf("%s: %s", k, v) @@ -937,40 +935,11 @@ func TestHandlerAPIFatalErrors(t *testing.T) { handler.find(w, req, nil) }, }, - { - name: "find open", - caches: func(t *testing.T, message string) caches { - caches := newMockCaches(t) - caches.On("validateMac", RunData{}).Return(cacheRepo, nil) - caches.On("openDB", mock.Anything, mock.Anything).Return(nil, errors.New(message)) - return caches - }, - call: func(t *testing.T, handler Handler, w http.ResponseWriter) { - req, err := http.NewRequest("GET", "example.com/cache", nil) - require.NoError(t, err) - handler.find(w, req, nil) - }, - }, - { - name: "reserve", - caches: func(t *testing.T, message string) caches { - caches := newMockCaches(t) - caches.On("validateMac", RunData{}).Return(cacheRepo, nil) - caches.On("openDB", mock.Anything, mock.Anything).Return(nil, errors.New(message)) - return caches - }, - call: func(t *testing.T, handler Handler, w http.ResponseWriter) { - body, err := json.Marshal(&Request{}) - require.NoError(t, err) - req, err := http.NewRequest("POST", "example.com/caches", bytes.NewReader(body)) - require.NoError(t, err) - handler.reserve(w, req, nil) - }, - }, { name: "upload", caches: func(t *testing.T, message string) caches { caches := newMockCaches(t) + caches.On("close").Return() caches.On("validateMac", RunData{}).Return(cacheRepo, nil) caches.On("readCache", mock.Anything, mock.Anything).Return(nil, errors.New(message)) return caches @@ -988,6 +957,7 @@ func TestHandlerAPIFatalErrors(t *testing.T) { name: "commit", caches: func(t *testing.T, message string) caches { caches := newMockCaches(t) + caches.On("close").Return() caches.On("validateMac", RunData{}).Return(cacheRepo, nil) caches.On("readCache", mock.Anything, mock.Anything).Return(nil, errors.New(message)) return caches @@ -1005,6 +975,7 @@ func TestHandlerAPIFatalErrors(t *testing.T) { name: "get", caches: func(t *testing.T, message string) caches { caches := newMockCaches(t) + caches.On("close").Return() caches.On("validateMac", RunData{}).Return(cacheRepo, nil) caches.On("readCache", mock.Anything, mock.Anything).Return(nil, errors.New(message)) return caches @@ -1042,10 +1013,12 @@ func TestHandlerAPIFatalErrors(t *testing.T) { dir := filepath.Join(t.TempDir(), "artifactcache") handler, err := StartHandler(dir, "", 0, "secret", nil) require.NoError(t, err) + defer handler.Close() fatalMessage = "" - handler.setCaches(testCase.caches(t, message)) + caches := testCase.caches(t, message) // doesn't need to be closed because it will be given to handler + handler.setCaches(caches) w := httptest.NewRecorder() testCase.call(t, handler, w) @@ -1138,18 +1111,15 @@ func TestHandler_gcCache(t *testing.T) { }, } - db, err := handler.getCaches().openDB() - require.NoError(t, err) + db := handler.getCaches().getDB() for _, c := range cases { require.NoError(t, insertCache(db, c.Cache)) } - require.NoError(t, db.Close()) handler.getCaches().setgcAt(time.Time{}) // ensure gcCache will not skip handler.getCaches().gcCache() - db, err = handler.getCaches().openDB() - require.NoError(t, err) + db = handler.getCaches().getDB() for i, v := range cases { t.Run(fmt.Sprintf("%d_%s", i, v.Cache.Key), func(t *testing.T) { cache := &Cache{} @@ -1161,7 +1131,6 @@ func TestHandler_gcCache(t *testing.T) { } }) } - require.NoError(t, db.Close()) } func TestHandler_ExternalURL(t *testing.T) { diff --git a/act/artifactcache/mock_caches.go b/act/artifactcache/mock_caches.go index 9d484f80..cadf0b95 100644 --- a/act/artifactcache/mock_caches.go +++ b/act/artifactcache/mock_caches.go @@ -19,6 +19,11 @@ type mockCaches struct { mock.Mock } +// close provides a mock function with no fields +func (_m *mockCaches) close() { + _m.Called() +} + // commit provides a mock function with given fields: id, size func (_m *mockCaches) commit(id uint64, size int64) (int64, error) { ret := _m.Called(id, size) @@ -80,19 +85,15 @@ func (_m *mockCaches) gcCache() { _m.Called() } -// openDB provides a mock function with no fields -func (_m *mockCaches) openDB() (*bolthold.Store, error) { +// getDB provides a mock function with no fields +func (_m *mockCaches) getDB() *bolthold.Store { ret := _m.Called() if len(ret) == 0 { - panic("no return value specified for openDB") + panic("no return value specified for getDB") } var r0 *bolthold.Store - var r1 error - if rf, ok := ret.Get(0).(func() (*bolthold.Store, error)); ok { - return rf() - } if rf, ok := ret.Get(0).(func() *bolthold.Store); ok { r0 = rf() } else { @@ -101,13 +102,7 @@ func (_m *mockCaches) openDB() (*bolthold.Store, error) { } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } // readCache provides a mock function with given fields: id, repo