From da7ef7c2a19cb4e94a9323eefe54c7a723a0aaaa Mon Sep 17 00:00:00 2001
From: Mathieu Fenniak
Date: Fri, 15 Aug 2025 20:26:35 -0600
Subject: [PATCH 1/5] fix: PRs cache artifacts separate from other runs

---
 act/artifactcache/handler.go | 29 +++++++++++++++++++++++++++--
 act/artifactcache/model.go   | 17 +++++++++--------
 act/cacheproxy/handler.go    |  7 ++++++-
 internal/app/run/runner.go   | 17 ++++++++++++++++-
 4 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/act/artifactcache/handler.go b/act/artifactcache/handler.go
index 0f78be85..8b8572ee 100644
--- a/act/artifactcache/handler.go
+++ b/act/artifactcache/handler.go
@@ -181,11 +181,19 @@ func (h *Handler) find(w http.ResponseWriter, r *http.Request, params httprouter
 	}
 	defer db.Close()

-	cache, err := findCache(db, repo, keys, version)
+	cache, err := findCache(db, repo, keys, version, rundata.WriteIsolationKey)
 	if err != nil {
 		h.responseJSON(w, r, 500, err)
 		return
 	}
+	// If the read was scoped to WriteIsolationKey and didn't find anything, we can fall back to the non-isolated cache read
+	if cache == nil && rundata.WriteIsolationKey != "" {
+		cache, err = findCache(db, repo, keys, version, "")
+		if err != nil {
+			h.responseJSON(w, r, 500, err)
+			return
+		}
+	}
 	if cache == nil {
 		h.responseJSON(w, r, 204)
 		return
@@ -236,6 +244,7 @@ func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, params httprou
 	cache.CreatedAt = now
 	cache.UsedAt = now
 	cache.Repo = repo
+	cache.WriteIsolationKey = rundata.WriteIsolationKey
 	if err := insertCache(db, cache); err != nil {
 		h.responseJSON(w, r, 500, err)
 		return
 	}
@@ -275,6 +284,10 @@ func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprout
 		h.responseJSON(w, r, 500, fmt.Errorf("cache repo is not valid"))
 		return
 	}
+	if cache.WriteIsolationKey != rundata.WriteIsolationKey {
+		h.responseJSON(w, r, 403, fmt.Errorf("cache authorized for write isolation %q, but attempting to operate on %q", rundata.WriteIsolationKey, cache.WriteIsolationKey))
+		return
+	}

 	if cache.Complete {
 		h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key))
@@ -326,6 +339,10 @@ func (h *Handler) commit(w http.ResponseWriter, r *http.Request, params httprout
 		h.responseJSON(w, r, 500, fmt.Errorf("cache repo is not valid"))
 		return
 	}
+	if cache.WriteIsolationKey != rundata.WriteIsolationKey {
+		h.responseJSON(w, r, 403, fmt.Errorf("cache authorized for write isolation %q, but attempting to operate on %q", rundata.WriteIsolationKey, cache.WriteIsolationKey))
+		return
+	}

 	if cache.Complete {
 		h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key))
@@ -386,6 +403,11 @@ func (h *Handler) get(w http.ResponseWriter, r *http.Request, params httprouter.
 		h.responseJSON(w, r, 500, fmt.Errorf("cache repo is not valid"))
 		return
 	}
+	// reads are permitted against caches with the same isolation key, or with no isolation key
+	if cache.WriteIsolationKey != rundata.WriteIsolationKey && cache.WriteIsolationKey != "" {
+		h.responseJSON(w, r, 403, fmt.Errorf("cache authorized for write isolation %q, but attempting to operate on %q", rundata.WriteIsolationKey, cache.WriteIsolationKey))
+		return
+	}

 	if err := h.useCache(id); err != nil {
 		h.responseJSON(w, r, 500, fmt.Errorf("cache useCache: %w", err))
@@ -417,7 +439,7 @@ func (h *Handler) middleware(handler httprouter.Handle) httprouter.Handle {
 }

 // if not found, return (nil, nil) instead of an error.
-func findCache(db *bolthold.Store, repo string, keys []string, version string) (*Cache, error) {
+func findCache(db *bolthold.Store, repo string, keys []string, version, writeIsolationKey string) (*Cache, error) {
 	cache := &Cache{}
 	for _, prefix := range keys {
 		// if a key in the list matches exactly, don't return partial matches
@@ -425,6 +447,7 @@ func findCache(db *bolthold.Store, repo string, keys []string, version string) (
 			bolthold.Where("Repo").Eq(repo).Index("Repo").
 				And("Key").Eq(prefix).
 				And("Version").Eq(version).
+				And("WriteIsolationKey").Eq(writeIsolationKey).
 				And("Complete").Eq(true).
 				SortBy("CreatedAt").Reverse()); err == nil || !errors.Is(err, bolthold.ErrNotFound) {
 			if err != nil {
@@ -441,6 +464,7 @@
 			bolthold.Where("Repo").Eq(repo).Index("Repo").
 				And("Key").RegExp(re).
 				And("Version").Eq(version).
+				And("WriteIsolationKey").Eq(writeIsolationKey).
 				And("Complete").Eq(true).
 				SortBy("CreatedAt").Reverse()); err != nil {
 			if errors.Is(err, bolthold.ErrNotFound) {
@@ -644,5 +668,6 @@ func runDataFromHeaders(r *http.Request) cacheproxy.RunData {
 		RunNumber:          r.Header.Get("Forgejo-Cache-RunNumber"),
 		Timestamp:          r.Header.Get("Forgejo-Cache-Timestamp"),
 		RepositoryMAC:      r.Header.Get("Forgejo-Cache-MAC"),
+		WriteIsolationKey:  r.Header.Get("Forgejo-Cache-WriteIsolationKey"),
 	}
 }
diff --git a/act/artifactcache/model.go b/act/artifactcache/model.go
index b27fd8ed..cb1af860 100644
--- a/act/artifactcache/model.go
+++ b/act/artifactcache/model.go
@@ -24,12 +24,13 @@ func (c *Request) ToCache() *Cache {
 }

 type Cache struct {
-	ID        uint64 `json:"id" boltholdKey:"ID"`
-	Repo      string `json:"repo" boltholdIndex:"Repo"`
-	Key       string `json:"key"`
-	Version   string `json:"version"`
-	Size      int64  `json:"cacheSize"`
-	Complete  bool   `json:"complete"`
-	UsedAt    int64  `json:"usedAt"`
-	CreatedAt int64  `json:"createdAt"`
+	ID                uint64 `json:"id" boltholdKey:"ID"`
+	Repo              string `json:"repo" boltholdIndex:"Repo"`
+	Key               string `json:"key"`
+	Version           string `json:"version"`
+	Size              int64  `json:"cacheSize"`
+	Complete          bool   `json:"complete"`
+	UsedAt            int64  `json:"usedAt"`
+	CreatedAt         int64  `json:"createdAt"`
+	WriteIsolationKey string `json:"writeIsolationKey"`
 }
diff --git a/act/cacheproxy/handler.go b/act/cacheproxy/handler.go
index 56b083a6..a2d76527 100644
--- a/act/cacheproxy/handler.go
+++ b/act/cacheproxy/handler.go
@@ -51,15 +51,17 @@ type RunData struct {
 	RunNumber          string
 	Timestamp          string
 	RepositoryMAC      string
+	WriteIsolationKey  string
 }

-func (h *Handler) CreateRunData(fullName, runNumber, timestamp string) RunData {
+func (h *Handler) CreateRunData(fullName, runNumber, timestamp, writeIsolationKey string) RunData {
 	mac := computeMac(h.cacheSecret, fullName, runNumber, timestamp)
 	return RunData{
 		RepositoryFullName: fullName,
 		RunNumber:          runNumber,
 		Timestamp:          timestamp,
 		RepositoryMAC:      mac,
+		WriteIsolationKey:  writeIsolationKey,
 	}
 }

@@ -152,6 +154,9 @@ func (h *Handler) newReverseProxy(targetHost string) (*httputil.ReverseProxy, er
 			r.Out.Header.Set("Forgejo-Cache-Timestamp", runData.Timestamp)
 			r.Out.Header.Set("Forgejo-Cache-MAC", runData.RepositoryMAC)
 			r.Out.Header.Set("Forgejo-Cache-Host", h.ExternalURL())
+			if runData.WriteIsolationKey != "" {
+				r.Out.Header.Set("Forgejo-Cache-WriteIsolationKey", runData.WriteIsolationKey)
+			}
 		},
 	}
 	return proxy, nil
diff --git a/internal/app/run/runner.go b/internal/app/run/runner.go
index b2fcf4fc..86cddd1a 100644
--- a/internal/app/run/runner.go
+++ b/internal/app/run/runner.go
@@ -265,8 +265,23 @@ func (r *Runner) run(ctx context.Context, task *runnerv1.Task, reporter *report.
 	// Register the run with the cacheproxy and modify the CACHE_URL
 	if r.cacheProxy != nil {
+		writeIsolationKey := ""
+
+		// When performing an action on an event from a PR, provide a "write isolation key" to the cache. The generated
+		// ACTIONS_CACHE_URL will be able to read the cache and write to it, but its writes will be isolated to
+		// future runs of the PR's workflows and won't be shared with other pull requests or actions. This is a security
+		// measure to prevent a malicious pull request from poisoning the cache with secret-stealing code which would
+		// later be executed in another workflow run.
+		if taskContext["event_name"].GetStringValue() == "pull_request" || taskContext["event_name"].GetStringValue() == "pull_request_target" {
+			// Ensure that `Ref` has the expected format so that we don't end up with a useless write isolation key
+			if !strings.HasPrefix(preset.Ref, "refs/pull/") {
+				return fmt.Errorf("write isolation key: expected preset.Ref to be refs/pull/..., but was %q", preset.Ref)
+			}
+			writeIsolationKey = preset.Ref
+		}
+
 		timestamp := strconv.FormatInt(time.Now().Unix(), 10)
-		cacheRunData := r.cacheProxy.CreateRunData(preset.Repository, preset.RunID, timestamp)
+		cacheRunData := r.cacheProxy.CreateRunData(preset.Repository, preset.RunID, timestamp, writeIsolationKey)
 		cacheRunID, err := r.cacheProxy.AddRun(cacheRunData)
 		if err == nil {
 			defer func() { _ = r.cacheProxy.RemoveRun(cacheRunID) }()

From 6c35ea4fd92288e142158e5955c66ac444f6e679 Mon Sep 17 00:00:00 2001
From: Mathieu Fenniak
Date: Sat, 16 Aug 2025 19:26:30 -0600
Subject: [PATCH 2/5] add unit tests for all changes in artifactcache

---
 act/artifactcache/handler_test.go | 236 ++++++++++++++++++++++++++++--
 1 file changed, 227 insertions(+), 9 deletions(-)

diff --git a/act/artifactcache/handler_test.go b/act/artifactcache/handler_test.go
index b8ff6769..d078b90e 100644
--- a/act/artifactcache/handler_test.go
+++ b/act/artifactcache/handler_test.go
@@ -12,6 +12,7 @@ import (
 	"testing"
 	"time"

+	"code.forgejo.org/forgejo/runner/v9/testutils"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/timshannon/bolthold"
@@ -28,7 +29,8 @@ const (

 var handlerExternalURL string

 type AuthHeaderTransport struct {
-	T http.RoundTripper
+	T                 http.RoundTripper
+	WriteIsolationKey string
 }

@@ -37,11 +39,14 @@ func (t *AuthHeaderTransport) RoundTrip(req *http.Request) (*http.Response, erro
 	req.Header.Set("Forgejo-Cache-Timestamp", cacheTimestamp)
 	req.Header.Set("Forgejo-Cache-MAC", cacheMac)
 	req.Header.Set("Forgejo-Cache-Host", handlerExternalURL)
+	if t.WriteIsolationKey != "" {
+		req.Header.Set("Forgejo-Cache-WriteIsolationKey", t.WriteIsolationKey)
+	}
 	return t.T.RoundTrip(req)
 }

 var (
-	httpClientTransport = AuthHeaderTransport{http.DefaultTransport}
+	httpClientTransport = AuthHeaderTransport{T: http.DefaultTransport}
 	httpClient          = http.Client{Transport: &httpClientTransport}
 )

@@ -88,7 +93,7 @@ func TestHandler(t *testing.T) {
 		content := make([]byte, 100)
 		_, err := rand.Read(content)
 		require.NoError(t, err)
-		uploadCacheNormally(t, base, key, version, content)
+		uploadCacheNormally(t, base, key, version, "", content)
 	})

 	t.Run("clean", func(t *testing.T) {
@@ -380,7 +385,7 @@ func TestHandler(t *testing.T) {
 		_, err := rand.Read(content)
 		require.NoError(t, err)

-		uploadCacheNormally(t, base, key, version, content)
+		uploadCacheNormally(t, base, key, version, "", content)
 		// Perform the request with the custom `httpClient` which will send correct MAC data
 		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
@@ -416,7 +421,7 @@ func TestHandler(t *testing.T) {
 		for i := range contents {
 			_, err := rand.Read(contents[i])
 			require.NoError(t, err)
-			uploadCacheNormally(t, base, keys[i], version, contents[i])
+			uploadCacheNormally(t, base, keys[i], version, "", contents[i])
 			time.Sleep(time.Second) // ensure CreatedAt of caches are different
 		}
@@ -454,13 +459,91 @@ func TestHandler(t *testing.T) {
 		assert.Equal(t, contents[except], content)
 	})

+	t.Run("find can't match without WriteIsolationKey match", func(t *testing.T) {
+		defer func() { httpClientTransport.WriteIsolationKey = "" }()
+
+		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+		key := strings.ToLower(t.Name())
+
+		uploadCacheNormally(t, base, key, version, "TestWriteKey", make([]byte, 64))
+
+		httpClientTransport.WriteIsolationKey = "AnotherTestWriteKey"
+		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+		require.NoError(t, err)
+		require.Equal(t, 204, resp.StatusCode)
+
+		httpClientTransport.WriteIsolationKey = ""
+		resp, err = httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+		require.NoError(t, err)
+		require.Equal(t, 204, resp.StatusCode)
+
+		httpClientTransport.WriteIsolationKey = "TestWriteKey"
+		resp, err = httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+		require.NoError(t, err)
+		require.Equal(t, 200, resp.StatusCode)
+	})
+
+	t.Run("find prefers WriteIsolationKey match", func(t *testing.T) {
+		defer func() { httpClientTransport.WriteIsolationKey = "" }()
+
+		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21"
+		key := strings.ToLower(t.Name())
+
+		// Between two values with the same `key`...
+		uploadCacheNormally(t, base, key, version, "TestWriteKey", make([]byte, 64))
+		uploadCacheNormally(t, base, key, version, "", make([]byte, 128))
+
+		// We should read the value with the matching WriteIsolationKey from the cache...
+		httpClientTransport.WriteIsolationKey = "TestWriteKey"
+		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+		require.NoError(t, err)
+		require.Equal(t, 200, resp.StatusCode)
+
+		got := struct {
+			ArchiveLocation string `json:"archiveLocation"`
+		}{}
+		require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+		contentResp, err := httpClient.Get(got.ArchiveLocation)
+		require.NoError(t, err)
+		require.Equal(t, 200, contentResp.StatusCode)
+		content, err := io.ReadAll(contentResp.Body)
+		require.NoError(t, err)
+		// Which we finally check matches the correct WriteIsolationKey's content here.
+		assert.Equal(t, make([]byte, 64), content)
+	})
+
+	t.Run("find falls back if matching WriteIsolationKey not available", func(t *testing.T) {
+		defer func() { httpClientTransport.WriteIsolationKey = "" }()
+
+		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21"
+		key := strings.ToLower(t.Name())
+
+		uploadCacheNormally(t, base, key, version, "", make([]byte, 128))
+
+		httpClientTransport.WriteIsolationKey = "TestWriteKey"
+		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+		require.NoError(t, err)
+		require.Equal(t, 200, resp.StatusCode)
+
+		got := struct {
+			ArchiveLocation string `json:"archiveLocation"`
+		}{}
+		require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+		contentResp, err := httpClient.Get(got.ArchiveLocation)
+		require.NoError(t, err)
+		require.Equal(t, 200, contentResp.StatusCode)
+		content, err := io.ReadAll(contentResp.Body)
+		require.NoError(t, err)
+		assert.Equal(t, make([]byte, 128), content)
+	})
+
 	t.Run("case insensitive", func(t *testing.T) {
 		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
 		key := strings.ToLower(t.Name())
 		content := make([]byte, 100)
 		_, err := rand.Read(content)
 		require.NoError(t, err)
-		uploadCacheNormally(t, base, key+"_ABC", version, content)
+		uploadCacheNormally(t, base, key+"_ABC", version, "", content)

 		{
 			reqKey := key + "_aBc"
@@ -494,7 +577,7 @@ func TestHandler(t *testing.T) {
 		for i := range contents {
 			_, err := rand.Read(contents[i])
 			require.NoError(t, err)
-			uploadCacheNormally(t, base, keys[i], version, contents[i])
+			uploadCacheNormally(t, base, keys[i], version, "", contents[i])
 			time.Sleep(time.Second) // ensure CreatedAt of caches are different
 		}
@@ -545,7 +628,7 @@ func TestHandler(t *testing.T) {
 		for i := range contents {
 			_, err := rand.Read(contents[i])
 			require.NoError(t, err)
-			uploadCacheNormally(t, base, keys[i], version, contents[i])
+			uploadCacheNormally(t, base, keys[i], version, "", contents[i])
 			time.Sleep(time.Second) // ensure CreatedAt of caches are different
 		}
@@ -581,9 +664,144 @@ func TestHandler(t *testing.T) {
 		require.NoError(t, err)
 		assert.Equal(t, contents[expect], content)
 	})
+
+	t.Run("upload across WriteIsolationKey", func(t *testing.T) {
+		defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, "CorrectKey")()
+
+		key := strings.ToLower(t.Name())
+		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+		content := make([]byte, 256)
+
+		var id uint64
+		// reserve
+		{
+			body, err := json.Marshal(&Request{
+				Key:     key,
+				Version: version,
+				Size:    int64(len(content)),
+			})
+			require.NoError(t, err)
+			resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+			require.NoError(t, err)
+			assert.Equal(t, 200, resp.StatusCode)
+
+			got := struct {
+				CacheID uint64 `json:"cacheId"`
+			}{}
+			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+			id = got.CacheID
+		}
+		// upload, but with the incorrect write isolation key relative to the cache obj created
+		{
+			httpClientTransport.WriteIsolationKey = "WrongKey"
+			req, err := http.NewRequest(http.MethodPatch,
+				fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+			require.NoError(t, err)
+			req.Header.Set("Content-Type", "application/octet-stream")
+			req.Header.Set("Content-Range", "bytes 0-99/*")
+			resp, err := httpClient.Do(req)
+			require.NoError(t, err)
+			assert.Equal(t, 403, resp.StatusCode)
+		}
+	})
+
+	t.Run("commit across WriteIsolationKey", func(t *testing.T) {
+		defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, "CorrectKey")()
+
+		key := strings.ToLower(t.Name())
+		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
+		content := make([]byte, 256)
+
+		var id uint64
+		// reserve
+		{
+			body, err := json.Marshal(&Request{
+				Key:     key,
+				Version: version,
+				Size:    int64(len(content)),
+			})
+			require.NoError(t, err)
+			resp, err := httpClient.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body))
+			require.NoError(t, err)
+			assert.Equal(t, 200, resp.StatusCode)
+
+			got := struct {
+				CacheID uint64 `json:"cacheId"`
+			}{}
+			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+			id = got.CacheID
+		}
+		// upload
+		{
+			req, err := http.NewRequest(http.MethodPatch,
+				fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
+			require.NoError(t, err)
+			req.Header.Set("Content-Type", "application/octet-stream")
+			req.Header.Set("Content-Range", "bytes 0-99/*")
+			resp, err := httpClient.Do(req)
+			require.NoError(t, err)
+			assert.Equal(t, 200, resp.StatusCode)
+		}
+		// commit, but with the incorrect write isolation key relative to the cache obj created
+		{
+			httpClientTransport.WriteIsolationKey = "WrongKey"
+			resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
+			require.NoError(t, err)
+			assert.Equal(t, 403, resp.StatusCode)
+		}
+	})
+
+	t.Run("get across WriteIsolationKey", func(t *testing.T) {
+		defer func() { httpClientTransport.WriteIsolationKey = "" }()
+
+		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21"
+		key := strings.ToLower(t.Name())
+		uploadCacheNormally(t, base, key, version, "", make([]byte, 128))
+		keyIsolated := strings.ToLower(t.Name()) + "_isolated"
+		uploadCacheNormally(t, base, keyIsolated, version, "CorrectKey", make([]byte, 128))
+
+		// Perform the 'get' without the right WriteIsolationKey for the cache entry... should be OK for `key` since it
+		// was written with WriteIsolationKey "" meaning it is available for non-isolated access
+		{
+			httpClientTransport.WriteIsolationKey = "WhoopsWrongKey"
+
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+			require.NoError(t, err)
+			require.Equal(t, 200, resp.StatusCode)
+			got := struct {
+				ArchiveLocation string `json:"archiveLocation"`
+			}{}
+			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+
+			contentResp, err := httpClient.Get(got.ArchiveLocation)
+			require.NoError(t, err)
+			require.Equal(t, 200, contentResp.StatusCode)
+			httpClientTransport.WriteIsolationKey = "CorrectKey" // reset for next find
+		}
+
+		// Perform the 'get' without the right WriteIsolationKey for the cache entry... should be 403 for `keyIsolated`
+		// because it was written with a different WriteIsolationKey.
+		{
+			httpClientTransport.WriteIsolationKey = "CorrectKey" // for test purposes make the `find` successful...
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, keyIsolated, version))
+			require.NoError(t, err)
+			require.Equal(t, 200, resp.StatusCode)
+			got := struct {
+				ArchiveLocation string `json:"archiveLocation"`
+			}{}
+			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+
+			httpClientTransport.WriteIsolationKey = "WhoopsWrongKey" // but then access w/ the wrong key for `get`
+			contentResp, err := httpClient.Get(got.ArchiveLocation)
+			require.NoError(t, err)
+			require.Equal(t, 403, contentResp.StatusCode)
+		}
+	})
 }

-func uploadCacheNormally(t *testing.T, base, key, version string, content []byte) {
+func uploadCacheNormally(t *testing.T, base, key, version, writeIsolationKey string, content []byte) {
+	defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, writeIsolationKey)()
+
 	var id uint64
 	{
 		body, err := json.Marshal(&Request{

From 4bd93294d4fea8102ccf2b9e4e0c3a84fdabff7f Mon Sep 17 00:00:00 2001
From: Mathieu Fenniak
Date: Tue, 19 Aug 2025 11:18:32 -0600
Subject: [PATCH 3/5] add WriteIsolationKey to MAC

---
 act/artifactcache/handler_test.go | 179 ++++++++++++++++++------------
 act/artifactcache/mac.go          |   6 +-
 act/artifactcache/mac_test.go     |  13 ++-
 act/cacheproxy/handler.go         |   6 +-
 4 files changed, 122 insertions(+), 82 deletions(-)

diff --git a/act/artifactcache/handler_test.go b/act/artifactcache/handler_test.go
index d078b90e..1d813d56 100644
--- a/act/artifactcache/handler_test.go
+++ b/act/artifactcache/handler_test.go
@@ -12,7 +12,6 @@ import (
 	"testing"
 	"time"

-	"code.forgejo.org/forgejo/runner/v9/testutils"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/timshannon/bolthold"
@@ -23,21 +22,26 @@ const (
 	cacheRepo      = "testuser/repo"
 	cacheRunnum    = "1"
 	cacheTimestamp = "0"
-	cacheMac       = "c13854dd1ac599d1d61680cd93c26b77ba0ee10f374a3408bcaea82f38ca1865"
+	cacheMac       = "bc2e9167f9e310baebcead390937264e4c0b21d2fdd49f5b9470d54406099360"
 )

 var handlerExternalURL string

 type AuthHeaderTransport struct {
-	T                 http.RoundTripper
-	WriteIsolationKey string
+	T                  http.RoundTripper
+	WriteIsolationKey  string
+	OverrideDefaultMac string
 }

 func (t *AuthHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	req.Header.Set("Forgejo-Cache-Repo", cacheRepo)
 	req.Header.Set("Forgejo-Cache-RunNumber", cacheRunnum)
 	req.Header.Set("Forgejo-Cache-Timestamp", cacheTimestamp)
-	req.Header.Set("Forgejo-Cache-MAC", cacheMac)
+	if t.OverrideDefaultMac != "" {
+		req.Header.Set("Forgejo-Cache-MAC", t.OverrideDefaultMac)
+	} else {
+		req.Header.Set("Forgejo-Cache-MAC", cacheMac)
+	}
 	req.Header.Set("Forgejo-Cache-Host", handlerExternalURL)
 	if t.WriteIsolationKey != "" {
 		req.Header.Set("Forgejo-Cache-WriteIsolationKey", t.WriteIsolationKey)
@@ -467,25 +471,28 @@ func TestHandler(t *testing.T) {

 		uploadCacheNormally(t, base, key, version, "TestWriteKey", make([]byte, 64))

-		httpClientTransport.WriteIsolationKey = "AnotherTestWriteKey"
-		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
-		require.NoError(t, err)
-		require.Equal(t, 204, resp.StatusCode)
+		func() {
+			defer overrideWriteIsolationKey("AnotherTestWriteKey")()
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+			require.NoError(t, err)
+			require.Equal(t, 204, resp.StatusCode)
+		}()

-		httpClientTransport.WriteIsolationKey = ""
-		resp, err = httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
-		require.NoError(t, err)
-		require.Equal(t, 204, resp.StatusCode)
+		{
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+			require.NoError(t, err)
+			require.Equal(t, 204, resp.StatusCode)
+		}

-		httpClientTransport.WriteIsolationKey = "TestWriteKey"
-		resp, err = httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
-		require.NoError(t, err)
-		require.Equal(t, 200, resp.StatusCode)
+		func() {
+			defer overrideWriteIsolationKey("TestWriteKey")()
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+			require.NoError(t, err)
+			require.Equal(t, 200, resp.StatusCode)
+		}()
 	})

 	t.Run("find prefers WriteIsolationKey match", func(t *testing.T) {
-		defer func() { httpClientTransport.WriteIsolationKey = "" }()
-
 		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21"
 		key := strings.ToLower(t.Name())
@@ -494,47 +501,51 @@ func TestHandler(t *testing.T) {
 		uploadCacheNormally(t, base, key, version, "", make([]byte, 128))

 		// We should read the value with the matching WriteIsolationKey from the cache...
-		httpClientTransport.WriteIsolationKey = "TestWriteKey"
-		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
-		require.NoError(t, err)
-		require.Equal(t, 200, resp.StatusCode)
+		func() {
+			defer overrideWriteIsolationKey("TestWriteKey")()

-		got := struct {
-			ArchiveLocation string `json:"archiveLocation"`
-		}{}
-		require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
-		contentResp, err := httpClient.Get(got.ArchiveLocation)
-		require.NoError(t, err)
-		require.Equal(t, 200, contentResp.StatusCode)
-		content, err := io.ReadAll(contentResp.Body)
-		require.NoError(t, err)
-		// Which we finally check matches the correct WriteIsolationKey's content here.
-		assert.Equal(t, make([]byte, 64), content)
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+			require.NoError(t, err)
+			require.Equal(t, 200, resp.StatusCode)
+
+			got := struct {
+				ArchiveLocation string `json:"archiveLocation"`
+			}{}
+			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+			contentResp, err := httpClient.Get(got.ArchiveLocation)
+			require.NoError(t, err)
+			require.Equal(t, 200, contentResp.StatusCode)
+			content, err := io.ReadAll(contentResp.Body)
+			require.NoError(t, err)
+			// Which we finally check matches the correct WriteIsolationKey's content here.
+			assert.Equal(t, make([]byte, 64), content)
+		}()
 	})

 	t.Run("find falls back if matching WriteIsolationKey not available", func(t *testing.T) {
-		defer func() { httpClientTransport.WriteIsolationKey = "" }()
-
 		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21"
 		key := strings.ToLower(t.Name())

 		uploadCacheNormally(t, base, key, version, "", make([]byte, 128))

-		httpClientTransport.WriteIsolationKey = "TestWriteKey"
-		resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
-		require.NoError(t, err)
-		require.Equal(t, 200, resp.StatusCode)
+		func() {
+			defer overrideWriteIsolationKey("TestWriteKey")()

-		got := struct {
-			ArchiveLocation string `json:"archiveLocation"`
-		}{}
-		require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
-		contentResp, err := httpClient.Get(got.ArchiveLocation)
-		require.NoError(t, err)
-		require.Equal(t, 200, contentResp.StatusCode)
-		content, err := io.ReadAll(contentResp.Body)
-		require.NoError(t, err)
-		assert.Equal(t, make([]byte, 128), content)
+			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
+			require.NoError(t, err)
+			require.Equal(t, 200, resp.StatusCode)
+
+			got := struct {
+				ArchiveLocation string `json:"archiveLocation"`
+			}{}
+			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+			contentResp, err := httpClient.Get(got.ArchiveLocation)
+			require.NoError(t, err)
+			require.Equal(t, 200, contentResp.StatusCode)
+			content, err := io.ReadAll(contentResp.Body)
+			require.NoError(t, err)
+			assert.Equal(t, make([]byte, 128), content)
+		}()
 	})

 	t.Run("case insensitive", func(t *testing.T) {
@@ -666,7 +677,7 @@ func TestHandler(t *testing.T) {
 	})

 	t.Run("upload across WriteIsolationKey", func(t *testing.T) {
-		defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, "CorrectKey")()
+		defer overrideWriteIsolationKey("CorrectKey")()

 		key := strings.ToLower(t.Name())
 		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
 		content := make([]byte, 256)
@@ -692,8 +703,8 @@ func TestHandler(t *testing.T) {
 			id = got.CacheID
 		}
 		// upload, but with the incorrect write isolation key relative to the cache obj created
-		{
-			httpClientTransport.WriteIsolationKey = "WrongKey"
+		func() {
+			defer overrideWriteIsolationKey("WrongKey")()
 			req, err := http.NewRequest(http.MethodPatch,
 				fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content))
 			require.NoError(t, err)
 			req.Header.Set("Content-Type", "application/octet-stream")
 			req.Header.Set("Content-Range", "bytes 0-99/*")
 			resp, err := httpClient.Do(req)
 			require.NoError(t, err)
 			assert.Equal(t, 403, resp.StatusCode)
-		}
+		}()
 	})

 	t.Run("commit across WriteIsolationKey", func(t *testing.T) {
-		defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, "CorrectKey")()
+		defer overrideWriteIsolationKey("CorrectKey")()

 		key := strings.ToLower(t.Name())
 		version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20"
 		content := make([]byte, 256)
@@ -743,12 +754,12 @@ func TestHandler(t *testing.T) {
 			assert.Equal(t, 200, resp.StatusCode)
 		}
 		// commit, but with the incorrect write isolation key relative to the cache obj created
-		{
-			httpClientTransport.WriteIsolationKey = "WrongKey"
+		func() {
+			defer overrideWriteIsolationKey("WrongKey")()
 			resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil)
 			require.NoError(t, err)
 			assert.Equal(t, 403, resp.StatusCode)
-		}
+		}()
 	})

 	t.Run("get across WriteIsolationKey", func(t *testing.T) {
@@ -762,8 +773,8 @@ func TestHandler(t *testing.T) {
 		// Perform the 'get' without the right WriteIsolationKey for the cache entry... should be OK for `key` since it
 		// was written with WriteIsolationKey "" meaning it is available for non-isolated access
-		{
-			httpClientTransport.WriteIsolationKey = "WhoopsWrongKey"
+		func() {
+			defer overrideWriteIsolationKey("WhoopsWrongKey")()

 			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version))
 			require.NoError(t, err)
@@ -777,30 +788,52 @@ func TestHandler(t *testing.T) {
 			require.NoError(t, err)
 			require.Equal(t, 200, contentResp.StatusCode)
 			httpClientTransport.WriteIsolationKey = "CorrectKey" // reset for next find
-		}
+		}()

 		// Perform the 'get' without the right WriteIsolationKey for the cache entry... should be 403 for `keyIsolated`
 		// because it was written with a different WriteIsolationKey.
 		{
-			httpClientTransport.WriteIsolationKey = "CorrectKey" // for test purposes make the `find` successful...
-			resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, keyIsolated, version))
-			require.NoError(t, err)
-			require.Equal(t, 200, resp.StatusCode)
-			got := struct {
+			got := func() struct {
 				ArchiveLocation string `json:"archiveLocation"`
-			}{}
-			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+			} {
+				defer overrideWriteIsolationKey("CorrectKey")() // for test purposes make the `find` successful...
+				resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, keyIsolated, version))
+				require.NoError(t, err)
+				require.Equal(t, 200, resp.StatusCode)
+				got := struct {
+					ArchiveLocation string `json:"archiveLocation"`
+				}{}
+				require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
+				return got
+			}()

-			httpClientTransport.WriteIsolationKey = "WhoopsWrongKey" // but then access w/ the wrong key for `get`
-			contentResp, err := httpClient.Get(got.ArchiveLocation)
-			require.NoError(t, err)
-			require.Equal(t, 403, contentResp.StatusCode)
+			func() {
+				defer overrideWriteIsolationKey("WhoopsWrongKey")() // but then access w/ the wrong key for `get`
+				contentResp, err := httpClient.Get(got.ArchiveLocation)
+				require.NoError(t, err)
+				require.Equal(t, 403, contentResp.StatusCode)
+			}()
 		}
 	})
 }

+func overrideWriteIsolationKey(writeIsolationKey string) func() {
+	originalWriteIsolationKey := httpClientTransport.WriteIsolationKey
+	originalMac := httpClientTransport.OverrideDefaultMac
+
+	httpClientTransport.WriteIsolationKey = writeIsolationKey
+	httpClientTransport.OverrideDefaultMac = computeMac("secret", cacheRepo, cacheRunnum, cacheTimestamp, httpClientTransport.WriteIsolationKey)
+
+	return func() {
+		httpClientTransport.WriteIsolationKey = originalWriteIsolationKey
+		httpClientTransport.OverrideDefaultMac = originalMac
+	}
+}
+
 func uploadCacheNormally(t *testing.T, base, key, version, writeIsolationKey string, content []byte) {
-	defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, writeIsolationKey)()
+	if writeIsolationKey != "" {
+		defer overrideWriteIsolationKey(writeIsolationKey)()
+	}

 	var id uint64
 	{
diff --git a/act/artifactcache/mac.go b/act/artifactcache/mac.go
index 6ed20b01..72e5fa2e 100644
--- a/act/artifactcache/mac.go
+++ b/act/artifactcache/mac.go
@@ -22,7 +22,7 @@ func (h *Handler) validateMac(rundata cacheproxy.RunData) (string, error) {
 		return "", ErrValidation
 	}

-	expectedMAC := computeMac(h.secret, rundata.RepositoryFullName, rundata.RunNumber, rundata.Timestamp)
+	expectedMAC := computeMac(h.secret, rundata.RepositoryFullName, rundata.RunNumber, rundata.Timestamp, rundata.WriteIsolationKey)
 	if hmac.Equal([]byte(expectedMAC), []byte(rundata.RepositoryMAC)) {
 		return rundata.RepositoryFullName, nil
 	}
@@ -40,12 +40,14 @@ func validateAge(ts string) bool {
 	return true
 }

-func computeMac(secret, repo, run, ts string) string {
+func computeMac(secret, repo, run, ts, writeIsolationKey string) string {
 	mac := hmac.New(sha256.New, []byte(secret))
 	mac.Write([]byte(repo))
 	mac.Write([]byte(">"))
 	mac.Write([]byte(run))
 	mac.Write([]byte(">"))
 	mac.Write([]byte(ts))
+	mac.Write([]byte(">"))
+	mac.Write([]byte(writeIsolationKey))
 	return hex.EncodeToString(mac.Sum(nil))
 }
diff --git a/act/artifactcache/mac_test.go b/act/artifactcache/mac_test.go
index a4b801a9..0e1f3be6 100644
--- a/act/artifactcache/mac_test.go
+++ b/act/artifactcache/mac_test.go
@@ -19,7 +19,7 @@ func TestMac(t *testing.T) {
 		run := "1"
 		ts := strconv.FormatInt(time.Now().Unix(), 10)

-		mac := computeMac(handler.secret, name, run, ts)
+		mac := computeMac(handler.secret, name, run, ts, "")
 		rundata := cacheproxy.RunData{
 			RepositoryFullName: name,
 			RunNumber:          run,
@@ -37,7 +37,7 @@ func TestMac(t *testing.T) {
 		run := "1"
 		ts := "9223372036854775807" // This should last us for a while...

-		mac := computeMac(handler.secret, name, run, ts)
+		mac := computeMac(handler.secret, name, run, ts, "")
 		rundata := cacheproxy.RunData{
 			RepositoryFullName: name,
 			RunNumber:          run,
@@ -72,9 +72,12 @@ func TestMac(t *testing.T) {
 		run := "42"
 		ts := "1337"

-		mac := computeMac(secret, name, run, ts)
-		expectedMac := "f666f06f917acb7186e152195b2a8c8d36d068ce683454be0878806e08e04f2b" // * Precomputed, anytime the computeMac function changes this needs to be recalculated
+		mac := computeMac(secret, name, run, ts, "")
+		expectedMac := "4754474b21329e8beadd2b4054aa4be803965d66e710fa1fee091334ed804f29" // * Precomputed, anytime the computeMac function changes this needs to be recalculated
+		require.Equal(t, expectedMac, mac)

-		require.Equal(t, mac, expectedMac)
+		mac = computeMac(secret, name, run, ts, "refs/pull/12/head")
+		expectedMac = "9ca8f4cb5e1b083ee8cd215215bc00f379b28511d3ef7930bf054767de34766d" // * Precomputed, anytime the computeMac function changes this needs to be recalculated
+		require.Equal(t, expectedMac, mac)
 	})
 }
diff --git a/act/cacheproxy/handler.go b/act/cacheproxy/handler.go
index a2d76527..2b2b3707 100644
--- a/act/cacheproxy/handler.go
+++ b/act/cacheproxy/handler.go
@@ -55,7 +55,7 @@ type RunData struct {
 }

 func (h *Handler) CreateRunData(fullName, runNumber, timestamp, writeIsolationKey string) RunData {
-	mac := computeMac(h.cacheSecret, fullName, runNumber, timestamp)
+	mac := computeMac(h.cacheSecret, fullName, runNumber, timestamp, writeIsolationKey)
 	return RunData{
 		RepositoryFullName: fullName,
 		RunNumber:          runNumber,
 		Timestamp:          timestamp,
@@ -212,12 +212,14 @@ func (h *Handler) Close() error {
 	return retErr
 }

-func computeMac(secret, repo, run, ts string) string {
+func computeMac(secret, repo, run, ts, writeIsolationKey string) string {
 	mac := hmac.New(sha256.New, []byte(secret))
 	mac.Write([]byte(repo))
 	mac.Write([]byte(">"))
 	mac.Write([]byte(run))
 	mac.Write([]byte(">"))
 	mac.Write([]byte(ts))
+	mac.Write([]byte(">"))
+	mac.Write([]byte(writeIsolationKey))
 	return hex.EncodeToString(mac.Sum(nil))
 }

From dded18c94d6e2732c53abf2ca57540284bb9a161 Mon Sep 17 00:00:00 2001
From: Mathieu Fenniak
Date: Sun, 17 Aug 2025 16:36:14 -0600
Subject: [PATCH 4/5] add an integration test for PR cache pollution

---
 internal/app/run/runner_test.go | 127 +++++++++++++++++++++++++++---
 1 file changed, 115 insertions(+), 12 deletions(-)

diff --git a/internal/app/run/runner_test.go b/internal/app/run/runner_test.go
index 94e476f4..9db43b4e 100644
--- a/internal/app/run/runner_test.go
+++ b/internal/app/run/runner_test.go
@@ -171,11 +171,7 @@ func TestRunnerCacheConfiguration(t *testing.T) {
 	// Must set up cache for our test
 	require.NotNil(t, runner.cacheProxy)

-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-
-	// Run a given workflow w/ event...
-	runWorkflow := func(yamlContent, eventName, ref, description string) {
+	runWorkflow := func(ctx context.Context, cancel context.CancelFunc, yamlContent, eventName, ref, description string) {
 		task := &runnerv1.Task{
 			WorkflowPayload: []byte(yamlContent),
 			Context: &structpb.Struct{
@@ -195,8 +191,12 @@ func TestRunnerCacheConfiguration(t *testing.T) {
 		require.NoError(t, err, description)
 	}

-	// Step 1: Populate shared cache with push workflow
-	populateYaml := `
+	t.Run("Cache accessible", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+		defer cancel()
+
+		// Step 1: Populate shared cache with push workflow
+		populateYaml := `
 name: Cache Testing Action
 on:
   push:
 jobs:
   job-cache-check-1:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/cache@v4
         with:
           path: cache_path_1
           key: cache-key-1
       - run: |
           mkdir -p cache_path_1
           echo "Hello from push workflow!" > cache_path_1/cache_content_1
 `
-	runWorkflow(populateYaml, "push", "refs/heads/main", "step 1: push cache populate expected to succeed")
+		runWorkflow(ctx, cancel, populateYaml, "push", "refs/heads/main", "step 1: push cache populate expected to succeed")

-	// Step 2: Validate that cache is accessible; mostly a sanity check that the test environment and mock context
-	// provides everything needed for the cache setup.
-	checkYaml := `
+		// Step 2: Validate that cache is accessible; mostly a sanity check that the test environment and mock context
+		// provides everything needed for the cache setup.
+		checkYaml := `
 name: Cache Testing Action
 on:
   push:
 jobs:
   job-cache-check-2:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/cache@v4
         with:
           path: cache_path_1
           key: cache-key-1
       - run: |
           [[ -f cache_path_1/cache_content_1 ]] && echo "Step 2: cache file found." || echo "Step 2: cache file missing!"
           [[ -f cache_path_1/cache_content_1 ]] || exit 1
 `
-	runWorkflow(checkYaml, "push", "refs/heads/main", "step 2: push cache check expected to succeed")
+		runWorkflow(ctx, cancel, checkYaml, "push", "refs/heads/main", "step 2: push cache check expected to succeed")
+	})
+
+	t.Run("PR cache pollution prevented", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+		defer cancel()
+
+		// Step 1: Populate shared cache with push workflow
+		populateYaml := `
+name: Cache Testing Action
+on:
+  push:
+jobs:
+  job-cache-check-1:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/cache@v4
+        with:
+          path: cache_path_1
+          key: cache-key-1
+      - run: |
+          mkdir -p cache_path_1
+          echo "Hello from push workflow!" > cache_path_1/cache_content_1
+`
+		runWorkflow(ctx, cancel, populateYaml, "push", "refs/heads/main", "step 1: push cache populate expected to succeed")
+
+		// Step 2: Check if pull_request can read push cache, should be available as it's a trusted cache.
+		checkPRYaml := `
+name: Cache Testing Action
+on:
+  pull_request:
+jobs:
+  job-cache-check-2:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/cache@v4
+        with:
+          path: cache_path_1
+          key: cache-key-1
+      - run: |
+          [[ -f cache_path_1/cache_content_1 ]] && echo "Step 2: cache file found." || echo "Step 2: cache file missing!"
+          [[ -f cache_path_1/cache_content_1 ]] || exit 1
+`
+		runWorkflow(ctx, cancel, checkPRYaml, "pull_request", "refs/pull/1234/head", "step 2: PR should read push cache")
+
+		// Step 3: Pull request writes to cache; here we need to use a new cache key because we'll get a cache-hit like we
+		// did in step #2 if we keep the same key, and then the cache contents won't be updated.
+		populatePRYaml := `
+name: Cache Testing Action
+on:
+  pull_request:
+jobs:
+  job-cache-check-3:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/cache@v4
+        with:
+          path: cache_path_1
+          key: cache-key-2
+      - run: |
+          mkdir -p cache_path_1
+          echo "Hello from PR workflow!" > cache_path_1/cache_content_2
+`
+		runWorkflow(ctx, cancel, populatePRYaml, "pull_request", "refs/pull/1234/head", "step 3: PR cache populate expected to succeed")
+
+		// Step 4: Check if pull_request can read its own cache written by step #3.
+		checkPRKey2Yaml := `
+name: Cache Testing Action
+on:
+  pull_request:
+jobs:
+  job-cache-check-4:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/cache@v4
+        with:
+          path: cache_path_1
+          key: cache-key-2
+      - run: |
+          [[ -f cache_path_1/cache_content_2 ]] && echo "Step 4 cache file found." || echo "Step 4 cache file missing!"
+          [[ -f cache_path_1/cache_content_2 ]] || exit 1
+`
+		runWorkflow(ctx, cancel, checkPRKey2Yaml, "pull_request", "refs/pull/1234/head", "step 4: PR should read its own cache")
+
+		// Step 5: Check that the push workflow cannot access the isolated cache that was written by the pull_request in
+		// step #3, ensuring that it's not possible to pollute the cache by predicting cache keys.
+		checkKey2Yaml := `
+name: Cache Testing Action
+on:
+  push:
+jobs:
+  job-cache-check-6:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/cache@v4
+        with:
+          path: cache_path_1
+          key: cache-key-2
+      - run: |
+          [[ -f cache_path_1/cache_content_2 ]] && echo "Step 5 cache file found, oh no!" || echo "Step 5: cache file missing as expected."
+          [[ -f cache_path_1/cache_content_2 ]] && exit 1 || exit 0
+`
+		runWorkflow(ctx, cancel, checkKey2Yaml, "push", "refs/heads/main", "step 5: push cache should not be polluted by PR")
+	})
 }

From 5a569d4ed1cecd6cee41fce9c954c0ab79486fce Mon Sep 17 00:00:00 2001
From: Mathieu Fenniak
Date: Thu, 21 Aug 2025 09:21:20 -0600
Subject: [PATCH 5/5] adopt t.Context() now that we're on go1.24; remove per-test explicit timeout

---
 internal/app/run/runner_test.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/internal/app/run/runner_test.go b/internal/app/run/runner_test.go
index 9db43b4e..92e96452 100644
--- a/internal/app/run/runner_test.go
+++ b/internal/app/run/runner_test.go
@@ -192,7 +192,7 @@ func TestRunnerCacheConfiguration(t *testing.T) {
 	}

 	t.Run("Cache accessible", func(t *testing.T) {
-		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+		ctx, cancel := context.WithCancel(t.Context())
 		defer cancel()

 		// Step 1: Populate shared cache with push workflow
@@ -236,7 +236,7 @@ jobs:
 	})

 	t.Run("PR cache pollution prevented", func(t *testing.T) {
-		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+		ctx, cancel := context.WithCancel(t.Context())
 		defer cancel()

 		// Step 1: Populate shared cache with push workflow