diff --git a/act/artifactcache/handler.go b/act/artifactcache/handler.go index 252564f4..2da6febd 100644 --- a/act/artifactcache/handler.go +++ b/act/artifactcache/handler.go @@ -20,6 +20,7 @@ import ( "github.com/timshannon/bolthold" "go.etcd.io/bbolt" + "github.com/nektos/act/pkg/cacheproxy" "github.com/nektos/act/pkg/common" ) @@ -34,6 +35,7 @@ type Handler struct { listener net.Listener server *http.Server logger logrus.FieldLogger + secret string gcing atomic.Bool gcAt time.Time @@ -41,8 +43,10 @@ type Handler struct { outboundIP string } -func StartHandler(dir, outboundIP string, port uint16, logger logrus.FieldLogger) (*Handler, error) { - h := &Handler{} +func StartHandler(dir, outboundIP string, port uint16, secret string, logger logrus.FieldLogger) (*Handler, error) { + h := &Handler{ + secret: secret, + } if logger == nil { discard := logrus.New() @@ -155,7 +159,14 @@ func (h *Handler) openDB() (*bolthold.Store, error) { } // GET /_apis/artifactcache/cache -func (h *Handler) find(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (h *Handler) find(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + rundata := runDataFromHeaders(r) + repo, err := h.validateMac(rundata) + if err != nil { + h.responseJSON(w, r, 403, err) + return + } + keys := strings.Split(r.URL.Query().Get("keys"), ",") // cache keys are case insensitive for i, key := range keys { @@ -170,7 +181,7 @@ func (h *Handler) find(w http.ResponseWriter, r *http.Request, _ httprouter.Para } defer db.Close() - cache, err := findCache(db, keys, version) + cache, err := findCache(db, repo, keys, version) if err != nil { h.responseJSON(w, r, 500, err) return @@ -188,15 +199,23 @@ func (h *Handler) find(w http.ResponseWriter, r *http.Request, _ httprouter.Para h.responseJSON(w, r, 204) return } + archiveLocation := fmt.Sprintf("%s/%s%s/artifacts/%d", r.Header.Get("Forgejo-Cache-Host"), r.Header.Get("Forgejo-Cache-RunId"), urlBase, cache.ID) h.responseJSON(w, r, 200, map[string]any{ "result": "hit", - "archiveLocation": fmt.Sprintf("%s%s/artifacts/%d", h.ExternalURL(), urlBase, cache.ID), + "archiveLocation": archiveLocation, "cacheKey": cache.Key, }) } // POST /_apis/artifactcache/caches -func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { +func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + rundata := runDataFromHeaders(r) + repo, err := h.validateMac(rundata) + if err != nil { + h.responseJSON(w, r, 403, err) + return + } + api := &Request{} if err := json.NewDecoder(r.Body).Decode(api); err != nil { h.responseJSON(w, r, 400, err) @@ -216,6 +235,7 @@ func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, _ httprouter.P now := time.Now().Unix() cache.CreatedAt = now cache.UsedAt = now + cache.Repo = repo if err := insertCache(db, cache); err != nil { h.responseJSON(w, r, 500, err) return @@ -227,6 +247,13 @@ func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, _ httprouter.P // PATCH /_apis/artifactcache/caches/:id func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + rundata := runDataFromHeaders(r) + repo, err := h.validateMac(rundata) + if err != nil { + h.responseJSON(w, r, 403, err) + return + } + id, err := strconv.ParseUint(params.ByName("id"), 10, 64) if err != nil { h.responseJSON(w, r, 400, err) @@ -249,11 +276,17 @@ func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprout return } + // Should not happen + if cache.Repo != repo { + h.responseJSON(w, r, 500, ErrValidation) + return + } + if cache.Complete { h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key)) return } - db.Close() + defer db.Close() start, _, err := parseContentRange(r.Header.Get("Content-Range")) if err != nil { h.responseJSON(w, r, 400, err) @@ -262,12 +295,19 @@ func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprout if err := h.storage.Write(cache.ID, start, r.Body); err != nil { h.responseJSON(w, r, 500, err) } - h.useCache(id) + h.useCache(db, cache) h.responseJSON(w, r, 200) } // POST /_apis/artifactcache/caches/:id func (h *Handler) commit(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + rundata := runDataFromHeaders(r) + repo, err := h.validateMac(rundata) + if err != nil { + h.responseJSON(w, r, 403, err) + return + } + id, err := strconv.ParseUint(params.ByName("id"), 10, 64) if err != nil { h.responseJSON(w, r, 400, err) @@ -290,6 +330,12 @@ func (h *Handler) commit(w http.ResponseWriter, r *http.Request, params httprout return } + // Should not happen + if cache.Repo != repo { + h.responseJSON(w, r, 500, ErrValidation) + return + } + if cache.Complete { h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key)) return @@ -323,17 +369,53 @@ func (h *Handler) commit(w http.ResponseWriter, r *http.Request, params httprout // GET /_apis/artifactcache/artifacts/:id func (h *Handler) get(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + rundata := runDataFromHeaders(r) + repo, err := h.validateMac(rundata) + if err != nil { + h.responseJSON(w, r, 403, err) + return + } + id, err := strconv.ParseUint(params.ByName("id"), 10, 64) if err != nil { h.responseJSON(w, r, 400, err) return } - h.useCache(id) + + cache := &Cache{} + db, err := h.openDB() + if err != nil { + h.responseJSON(w, r, 500, err) + return + } + defer db.Close() + if err := db.Get(id, cache); err != nil { + if errors.Is(err, bolthold.ErrNotFound) { + h.responseJSON(w, r, 404, fmt.Errorf("cache %d: not reserved", id)) + return + } + h.responseJSON(w, r, 500, err) + return + } + + // Should not happen + if cache.Repo != repo { + h.responseJSON(w, r, 500, ErrValidation) + return + } + + h.useCache(db, cache) h.storage.Serve(w, r, id) } // POST /_apis/artifactcache/clean func (h *Handler) clean(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + rundata := runDataFromHeaders(r) + _, err := h.validateMac(rundata) + if err != nil { + h.responseJSON(w, r, 403, err) + return + } // TODO: don't support force deleting cache entries // see: https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries @@ -349,12 +431,13 @@ func (h *Handler) middleware(handler httprouter.Handle) httprouter.Handle { } // if not found, return (nil, nil) instead of an error. -func findCache(db *bolthold.Store, keys []string, version string) (*Cache, error) { +func findCache(db *bolthold.Store, repo string, keys []string, version string) (*Cache, error) { cache := &Cache{} for _, prefix := range keys { // if a key in the list matches exactly, don't return partial matches if err := db.FindOne(cache, - bolthold.Where("Key").Eq(prefix). + bolthold.Where("Repo").Eq(repo). + And("Key").Eq(prefix). And("Version").Eq(version). And("Complete").Eq(true). SortBy("CreatedAt").Reverse()); err == nil || !errors.Is(err, bolthold.ErrNotFound) { @@ -369,7 +452,8 @@ func findCache(db *bolthold.Store, keys []string, version string) (*Cache, error continue } if err := db.FindOne(cache, - bolthold.Where("Key").RegExp(re). + bolthold.Where("Repo").Eq(repo). + And("Key").RegExp(re). And("Version").Eq(version). And("Complete").Eq(true). SortBy("CreatedAt").Reverse()); err != nil { @@ -394,16 +478,7 @@ func insertCache(db *bolthold.Store, cache *Cache) error { return nil } -func (h *Handler) useCache(id uint64) { - db, err := h.openDB() - if err != nil { - return - } - defer db.Close() - cache := &Cache{} - if err := db.Get(id, cache); err != nil { - return - } +func (h *Handler) useCache(db *bolthold.Store, cache *Cache) { cache.UsedAt = time.Now().Unix() _ = db.Update(cache.ID, cache) } @@ -554,3 +629,12 @@ func parseContentRange(s string) (uint64, uint64, error) { } return start, stop, nil } + +func runDataFromHeaders(r *http.Request) cacheproxy.RunData { + return cacheproxy.RunData{ + RepositoryFullName: r.Header.Get("Forgejo-Cache-Repo"), + RunNumber: r.Header.Get("Forgejo-Cache-RunNumber"), + Timestamp: r.Header.Get("Forgejo-Cache-Timestamp"), + RepositoryMAC: r.Header.Get("Forgejo-Cache-MAC"), + } +} diff --git a/act/artifactcache/handler_test.go b/act/artifactcache/handler_test.go index 252c4209..ebfdad77 100644 --- a/act/artifactcache/handler_test.go +++ b/act/artifactcache/handler_test.go @@ -18,11 +18,35 @@ import ( "go.etcd.io/bbolt" ) +const cache_repo = "testuser/repo" +const cache_runnum = "1" +const cache_timestamp = "0" +const cache_mac = "c13854dd1ac599d1d61680cd93c26b77ba0ee10f374a3408bcaea82f38ca1865" + +var handlerExternalUrl string + +type AuthHeaderTransport struct { + T http.RoundTripper +} + +func (t *AuthHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Forgejo-Cache-Repo", cache_repo) + req.Header.Set("Forgejo-Cache-RunNumber", cache_runnum) + req.Header.Set("Forgejo-Cache-Timestamp", cache_timestamp) + req.Header.Set("Forgejo-Cache-MAC", cache_mac) + req.Header.Set("Forgejo-Cache-Host", handlerExternalUrl) + return t.T.RoundTrip(req) +} + +var httpClientTransport = AuthHeaderTransport{http.DefaultTransport} +var httpClient = http.Client{Transport: &httpClientTransport} + func TestHandler(t *testing.T) { dir := filepath.Join(t.TempDir(), "artifactcache") - handler, err := StartHandler(dir, "", 0, nil) + handler, err := StartHandler(dir, "", 0, "secret", nil) require.NoError(t, err) + handlerExternalUrl = handler.ExternalURL() base := fmt.Sprintf("%s%s", handler.ExternalURL(), urlBase) defer func() { @@ -41,7 +65,7 @@ func TestHandler(t *testing.T) { require.NoError(t, handler.Close()) assert.Nil(t, handler.server) assert.Nil(t, handler.listener) - _, err := http.Post(fmt.Sprintf("%s/caches/%d", base, 1), "", nil) + _, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, 1), "", nil) assert.Error(t, err) }) }() @@ -49,7 +73,7 @@ func TestHandler(t *testing.T) { t.Run("get not exist", func(t *testing.T) { key := strings.ToLower(t.Name()) version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" - resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) require.NoError(t, err) require.Equal(t, 204, resp.StatusCode) }) @@ -64,7 +88,7 @@ func TestHandler(t *testing.T) { }) t.Run("clean", func(t *testing.T) { - resp, err := http.Post(fmt.Sprintf("%s/clean", base), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/clean", base), "", nil) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) }) @@ -72,7 +96,7 @@ func TestHandler(t *testing.T) { t.Run("reserve with bad request", func(t *testing.T) { body := []byte(`invalid json`) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) }) @@ -90,7 +114,7 @@ func TestHandler(t *testing.T) { Size: 100, }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -104,7 +128,7 @@ func TestHandler(t *testing.T) { Size: 100, }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -121,7 +145,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) }) @@ -132,7 +156,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) }) @@ -151,7 +175,7 @@ func TestHandler(t *testing.T) { Size: 100, }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -167,12 +191,12 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } { - resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } @@ -182,7 +206,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } @@ -202,7 +226,7 @@ func TestHandler(t *testing.T) { Size: 100, }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -218,7 +242,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes xx-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } @@ -226,7 +250,7 @@ func TestHandler(t *testing.T) { t.Run("commit with bad id", func(t *testing.T) { { - resp, err := http.Post(fmt.Sprintf("%s/caches/invalid_id", base), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/invalid_id", base), "", nil) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } @@ -234,7 +258,7 @@ func TestHandler(t *testing.T) { t.Run("commit with not exist id", func(t *testing.T) { { - resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, 100), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, 100), "", nil) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } @@ -254,7 +278,7 @@ func TestHandler(t *testing.T) { Size: 100, }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -270,17 +294,17 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } { - resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } { - resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) } @@ -300,7 +324,7 @@ func TestHandler(t *testing.T) { Size: 100, }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -316,35 +340,62 @@ func TestHandler(t *testing.T) { require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-59/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } { - resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) require.NoError(t, err) assert.Equal(t, 500, resp.StatusCode) } }) t.Run("get with bad id", func(t *testing.T) { - resp, err := http.Get(fmt.Sprintf("%s/artifacts/invalid_id", base)) + resp, err := httpClient.Get(fmt.Sprintf("%s/artifacts/invalid_id", base)) require.NoError(t, err) require.Equal(t, 400, resp.StatusCode) }) t.Run("get with not exist id", func(t *testing.T) { - resp, err := http.Get(fmt.Sprintf("%s/artifacts/%d", base, 100)) + resp, err := httpClient.Get(fmt.Sprintf("%s/artifacts/%d", base, 100)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) }) t.Run("get with not exist id", func(t *testing.T) { - resp, err := http.Get(fmt.Sprintf("%s/artifacts/%d", base, 100)) + resp, err := httpClient.Get(fmt.Sprintf("%s/artifacts/%d", base, 100)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) }) + t.Run("get with bad MAC", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46b4ee758284e26bb3045ad11d9d20" + content := make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + + uploadCacheNormally(t, base, key, version, content) + + // Perform the request with the custom `httpClient` which will send correct MAC data + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + + // Perform the same request with incorrect MAC data + req, err := http.NewRequest("GET", fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version), nil) + require.NoError(t, err) + req.Header.Set("Forgejo-Cache-Repo", cache_repo) + req.Header.Set("Forgejo-Cache-RunNumber", cache_runnum) + req.Header.Set("Forgejo-Cache-Timestamp", cache_timestamp) + req.Header.Set("Forgejo-Cache-MAC", "33f0e850ba0bdfd2f3e66ff79c1f8004b8226114e3b2e65c229222bb59df0f9d") // ! This is not the correct MAC + req.Header.Set("Forgejo-Cache-Host", handlerExternalUrl) + resp, err = http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 403, resp.StatusCode) + }) + t.Run("get with multiple keys", func(t *testing.T) { version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" key := strings.ToLower(t.Name()) @@ -371,7 +422,7 @@ func TestHandler(t *testing.T) { key + "_a", }, ",") - resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) @@ -391,7 +442,7 @@ func TestHandler(t *testing.T) { assert.Equal(t, "hit", got.Result) assert.Equal(t, keys[except], got.CacheKey) - contentResp, err := http.Get(got.ArchiveLocation) + contentResp, err := httpClient.Get(got.ArchiveLocation) require.NoError(t, err) require.Equal(t, 200, contentResp.StatusCode) content, err := io.ReadAll(contentResp.Body) @@ -409,7 +460,7 @@ func TestHandler(t *testing.T) { { reqKey := key + "_aBc" - resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKey, version)) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKey, version)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) got := struct { @@ -448,7 +499,7 @@ func TestHandler(t *testing.T) { key + "_a_b", }, ",") - resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) @@ -466,7 +517,7 @@ func TestHandler(t *testing.T) { require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) assert.Equal(t, keys[expect], got.CacheKey) - contentResp, err := http.Get(got.ArchiveLocation) + contentResp, err := httpClient.Get(got.ArchiveLocation) require.NoError(t, err) require.Equal(t, 200, contentResp.StatusCode) content, err := io.ReadAll(contentResp.Body) @@ -500,7 +551,7 @@ func TestHandler(t *testing.T) { key + "_a_b", }, ",") - resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) @@ -519,7 +570,7 @@ func TestHandler(t *testing.T) { require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) assert.Equal(t, keys[expect], got.CacheKey) - contentResp, err := http.Get(got.ArchiveLocation) + contentResp, err := httpClient.Get(got.ArchiveLocation) require.NoError(t, err) require.Equal(t, 200, contentResp.StatusCode) content, err := io.ReadAll(contentResp.Body) @@ -537,7 +588,7 @@ func uploadCacheNormally(t *testing.T, base, key, version string, content []byte Size: int64(len(content)), }) require.NoError(t, err) - resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -553,18 +604,18 @@ func uploadCacheNormally(t *testing.T, base, key, version string, content []byte require.NoError(t, err) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Range", "bytes 0-99/*") - resp, err := http.DefaultClient.Do(req) + resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } { - resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) } var archiveLocation string { - resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) got := struct { @@ -578,7 +629,7 @@ func uploadCacheNormally(t *testing.T, base, key, version string, content []byte archiveLocation = got.ArchiveLocation } { - resp, err := http.Get(archiveLocation) //nolint:gosec + resp, err := httpClient.Get(archiveLocation) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) got, err := io.ReadAll(resp.Body) @@ -589,7 +640,7 @@ func uploadCacheNormally(t *testing.T, base, key, version string, content []byte func TestHandler_gcCache(t *testing.T) { dir := filepath.Join(t.TempDir(), "artifactcache") - handler, err := StartHandler(dir, "", 0, nil) + handler, err := StartHandler(dir, "", 0, "", nil) require.NoError(t, err) defer func() { diff --git a/act/artifactcache/mac.go b/act/artifactcache/mac.go new file mode 100644 index 00000000..16f46c26 --- /dev/null +++ b/act/artifactcache/mac.go @@ -0,0 +1,53 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package artifactcache + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "errors" + "strconv" + "time" + + "github.com/nektos/act/pkg/cacheproxy" +) + +var ( + ErrValidation = errors.New("validation error") +) + +func (h *Handler) validateMac(rundata cacheproxy.RunData) (string, error) { + // TODO: allow configurable max age + if !validateAge(rundata.Timestamp) { + return "", ErrValidation + } + + expectedMAC := computeMac(h.secret, rundata.RepositoryFullName, rundata.RunNumber, rundata.Timestamp) + if hmac.Equal([]byte(expectedMAC), []byte(rundata.RepositoryMAC)) { + return rundata.RepositoryFullName, nil + } + return "", ErrValidation +} + +func validateAge(ts string) bool { + tsInt, err := strconv.ParseInt(ts, 10, 64) + if err != nil { + return false + } + if tsInt > time.Now().Unix() { + return false + } + return true +} + +func computeMac(secret, repo, run, ts string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(repo)) + mac.Write([]byte(">")) + mac.Write([]byte(run)) + mac.Write([]byte(">")) + mac.Write([]byte(ts)) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/act/artifactcache/mac_test.go b/act/artifactcache/mac_test.go new file mode 100644 index 00000000..b59280c7 --- /dev/null +++ b/act/artifactcache/mac_test.go @@ -0,0 +1,80 @@ +package artifactcache + +import ( + "strconv" + "testing" + "time" + + "github.com/nektos/act/pkg/cacheproxy" + "github.com/stretchr/testify/require" +) + +func TestMac(t *testing.T) { + handler := &Handler{ + secret: "secret for testing", + } + + t.Run("validate correct mac", func(t *testing.T) { + name := "org/reponame" + run := "1" + ts := strconv.FormatInt(time.Now().Unix(), 10) + + mac := computeMac(handler.secret, name, run, ts) + rundata := cacheproxy.RunData{ + RepositoryFullName: name, + RunNumber: run, + Timestamp: ts, + RepositoryMAC: mac, + } + + repoName, err := handler.validateMac(rundata) + require.NoError(t, err) + require.Equal(t, name, repoName) + }) + + t.Run("validate incorrect timestamp", func(t *testing.T) { + name := "org/reponame" + run := "1" + ts := "9223372036854775807" // This should last us for a while... + + mac := computeMac(handler.secret, name, run, ts) + rundata := cacheproxy.RunData{ + RepositoryFullName: name, + RunNumber: run, + Timestamp: ts, + RepositoryMAC: mac, + } + + _, err := handler.validateMac(rundata) + require.Error(t, err) + }) + + t.Run("validate incorrect mac", func(t *testing.T) { + name := "org/reponame" + run := "1" + ts := strconv.FormatInt(time.Now().Unix(), 10) + + rundata := cacheproxy.RunData{ + RepositoryFullName: name, + RunNumber: run, + Timestamp: ts, + RepositoryMAC: "this is not the right mac :D", + } + + repoName, err := handler.validateMac(rundata) + require.Error(t, err) + require.Equal(t, "", repoName) + }) + + t.Run("compute correct mac", func(t *testing.T) { + secret := "this is my cool secret string :3" + name := "org/reponame" + run := "42" + ts := "1337" + + mac := computeMac(secret, name, run, ts) + expectedMac := "f666f06f917acb7186e152195b2a8c8d36d068ce683454be0878806e08e04f2b" // * Precomputed, anytime the computeMac function changes this needs to be recalculated + + require.Equal(t, mac, expectedMac) + }) +} diff --git a/act/artifactcache/model.go b/act/artifactcache/model.go index 57812b31..1c0f855d 100644 --- a/act/artifactcache/model.go +++ b/act/artifactcache/model.go @@ -25,6 +25,7 @@ func (c *Request) ToCache() *Cache { type Cache struct { ID uint64 `json:"id" boltholdKey:"ID"` + Repo string `json:"repo" boltholdIndex:"Repo"` Key string `json:"key" boltholdIndex:"Key"` Version string `json:"version" boltholdIndex:"Version"` Size int64 `json:"cacheSize"` diff --git a/act/cacheproxy/handler.go b/act/cacheproxy/handler.go new file mode 100644 index 00000000..bd047926 --- /dev/null +++ b/act/cacheproxy/handler.go @@ -0,0 +1,227 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package cacheproxy + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "regexp" + "strconv" + "sync" + "time" + + "github.com/julienschmidt/httprouter" + "github.com/sirupsen/logrus" + + "github.com/nektos/act/pkg/common" +) + +const ( + urlBase = "/_apis/artifactcache" +) + +var ( + urlRegex = regexp.MustCompile(`/(\w+)(/_apis/artifactcache/.+)`) +) + +type Handler struct { + router *httprouter.Router + listener net.Listener + server *http.Server + logger logrus.FieldLogger + + outboundIP string + + cacheServerHost string + + cacheSecret string + + runs sync.Map +} + +type RunData struct { + RepositoryFullName string + RunNumber string + Timestamp string + RepositoryMAC string +} + +func (h *Handler) CreateRunData(fullName string, runNumber string, timestamp string) RunData { + mac := computeMac(h.cacheSecret, fullName, runNumber, timestamp) + return RunData{ + RepositoryFullName: fullName, + RunNumber: runNumber, + Timestamp: timestamp, + RepositoryMAC: mac, + } +} + +func StartHandler(targetHost string, outboundIP string, port uint16, cacheSecret string, logger logrus.FieldLogger) (*Handler, error) { + h := &Handler{} + + if logger == nil { + discard := logrus.New() + discard.Out = io.Discard + logger = discard + } + logger = logger.WithField("module", "artifactcache") + h.logger = logger + + h.cacheSecret = cacheSecret + + if outboundIP != "" { + h.outboundIP = outboundIP + } else if ip := common.GetOutboundIP(); ip == nil { + return nil, fmt.Errorf("unable to determine outbound IP address") + } else { + h.outboundIP = ip.String() + } + + h.cacheServerHost = targetHost + + proxy, err := h.newReverseProxy(targetHost) + if err != nil { + return nil, fmt.Errorf("unable to set up proxy to target host") + } + + router := httprouter.New() + router.HandlerFunc("GET", "/:runId"+urlBase+"/cache", proxyRequestHandler(proxy)) + router.HandlerFunc("POST", "/:runId"+urlBase+"/caches", proxyRequestHandler(proxy)) + router.HandlerFunc("PATCH", "/:runId"+urlBase+"/caches/:id", proxyRequestHandler(proxy)) + router.HandlerFunc("POST", "/:runId"+urlBase+"/caches/:id", proxyRequestHandler(proxy)) + router.HandlerFunc("GET", "/:runId"+urlBase+"/artifacts/:id", proxyRequestHandler(proxy)) + router.HandlerFunc("POST", "/:runId"+urlBase+"/clean", proxyRequestHandler(proxy)) + + h.router = router + + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) // listen on all interfaces + if err != nil { + return nil, err + } + server := &http.Server{ + ReadHeaderTimeout: 2 * time.Second, + Handler: router, + } + go func() { + if err := server.Serve(listener); err != nil && errors.Is(err, net.ErrClosed) { + logger.Errorf("http serve: %v", err) + } + }() + h.listener = listener + h.server = server + + return h, nil +} + +func proxyRequestHandler(proxy *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) { + return proxy.ServeHTTP +} + +func (h *Handler) newReverseProxy(targetHost string) (*httputil.ReverseProxy, error) { + targetURL, err := url.Parse(targetHost) + if err != nil { + return nil, err + } + + proxy := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + matches := urlRegex.FindStringSubmatch(r.In.URL.Path) + id := matches[1] + data, ok := h.runs.Load(id) + if !ok { + // The ID doesn't exist. + h.logger.Warn(fmt.Sprintf("Tried starting a cache proxy with id %s, which does not exist.", id)) + return + } + var runData = data.(RunData) + uri := matches[2] + + r.SetURL(targetURL) + r.Out.URL.Path = uri + + r.Out.Header.Set("Forgejo-Cache-Repo", runData.RepositoryFullName) + r.Out.Header.Set("Forgejo-Cache-RunNumber", runData.RunNumber) + r.Out.Header.Set("Forgejo-Cache-RunId", id) + r.Out.Header.Set("Forgejo-Cache-Timestamp", runData.Timestamp) + r.Out.Header.Set("Forgejo-Cache-MAC", runData.RepositoryMAC) + r.Out.Header.Set("Forgejo-Cache-Host", h.ExternalURL()) + }, + } + return proxy, nil +} + +func (h *Handler) ExternalURL() string { + // TODO: make the external url configurable if necessary + return fmt.Sprintf("http://%s", net.JoinHostPort(h.outboundIP, strconv.Itoa(h.listener.Addr().(*net.TCPAddr).Port))) +} + +// Informs the proxy of a workflow run that can make cache requests. +// The RunData contains the information about the repository. +// The function returns the 32-bit random key which the run will use to identify itself. +func (h *Handler) AddRun(data RunData) (string, error) { + for retries := 0; retries < 3; retries++ { + keyBytes := make([]byte, 4) + _, err := rand.Read(keyBytes) + if err != nil { + return "", errors.New("Could not generate the run id") + } + key := hex.EncodeToString(keyBytes) + + _, loaded := h.runs.LoadOrStore(key, data) + if !loaded { + // The key was unique and added successfully + return key, nil + } + } + return "", errors.New("Repeated collisions in generating run id") +} + +func (h *Handler) RemoveRun(runID string) error { + _, existed := h.runs.LoadAndDelete(runID) + if !existed { + return errors.New("The run id was not known to the proxy") + } + return nil +} + +func (h *Handler) Close() error { + if h == nil { + return nil + } + var retErr error + if h.server != nil { + err := h.server.Close() + if err != nil { + retErr = err + } + h.server = nil + } + if h.listener != nil { + err := h.listener.Close() + if !errors.Is(err, net.ErrClosed) { + retErr = err + } + h.listener = nil + } + return retErr +} + +func computeMac(secret, repo, run, ts string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write([]byte(repo)) + mac.Write([]byte(">")) + mac.Write([]byte(run)) + mac.Write([]byte(">")) + mac.Write([]byte(ts)) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/cmd/input.go b/cmd/input.go index 36af6d86..e5aff64e 100644 --- a/cmd/input.go +++ b/cmd/input.go @@ -59,6 +59,7 @@ type Input struct { logPrefixJobID bool networkName string useNewActionCache bool + secret string } func (i *Input) resolve(path string) string { diff --git a/cmd/root.go b/cmd/root.go index 349e6ac4..c5d4ce0a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -632,7 +632,7 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str var cacheHandler *artifactcache.Handler if !input.noCacheServer && envs[cacheURLKey] == "" { var err error - cacheHandler, err = artifactcache.StartHandler(input.cacheServerPath, input.cacheServerAddr, input.cacheServerPort, common.Logger(ctx)) + cacheHandler, err = artifactcache.StartHandler(input.cacheServerPath, input.cacheServerAddr, input.cacheServerPort, input.secret, common.Logger(ctx)) if err != nil { return err }