diff --git a/act/artifactcache/handler_test.go b/act/artifactcache/handler_test.go index d078b90e..1d813d56 100644 --- a/act/artifactcache/handler_test.go +++ b/act/artifactcache/handler_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "code.forgejo.org/forgejo/runner/v9/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/timshannon/bolthold" @@ -23,21 +22,26 @@ const ( cacheRepo = "testuser/repo" cacheRunnum = "1" cacheTimestamp = "0" - cacheMac = "c13854dd1ac599d1d61680cd93c26b77ba0ee10f374a3408bcaea82f38ca1865" + cacheMac = "bc2e9167f9e310baebcead390937264e4c0b21d2fdd49f5b9470d54406099360" ) var handlerExternalURL string type AuthHeaderTransport struct { - T http.RoundTripper - WriteIsolationKey string + T http.RoundTripper + WriteIsolationKey string + OverrideDefaultMac string } func (t *AuthHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Set("Forgejo-Cache-Repo", cacheRepo) req.Header.Set("Forgejo-Cache-RunNumber", cacheRunnum) req.Header.Set("Forgejo-Cache-Timestamp", cacheTimestamp) - req.Header.Set("Forgejo-Cache-MAC", cacheMac) + if t.OverrideDefaultMac != "" { + req.Header.Set("Forgejo-Cache-MAC", t.OverrideDefaultMac) + } else { + req.Header.Set("Forgejo-Cache-MAC", cacheMac) + } req.Header.Set("Forgejo-Cache-Host", handlerExternalURL) if t.WriteIsolationKey != "" { req.Header.Set("Forgejo-Cache-WriteIsolationKey", t.WriteIsolationKey) @@ -467,25 +471,28 @@ func TestHandler(t *testing.T) { uploadCacheNormally(t, base, key, version, "TestWriteKey", make([]byte, 64)) - httpClientTransport.WriteIsolationKey = "AnotherTestWriteKey" - resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) - require.NoError(t, err) - require.Equal(t, 204, resp.StatusCode) + func() { + defer overrideWriteIsolationKey("AnotherTestWriteKey")() + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 204, resp.StatusCode) + }() - httpClientTransport.WriteIsolationKey = "" - resp, err = httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) - require.NoError(t, err) - require.Equal(t, 204, resp.StatusCode) + { + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 204, resp.StatusCode) + } - httpClientTransport.WriteIsolationKey = "TestWriteKey" - resp, err = httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) - require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode) + func() { + defer overrideWriteIsolationKey("TestWriteKey")() + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + }() }) t.Run("find prefers WriteIsolationKey match", func(t *testing.T) { - defer func() { httpClientTransport.WriteIsolationKey = "" }() - version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21" key := strings.ToLower(t.Name()) @@ -494,47 +501,51 @@ func TestHandler(t *testing.T) { uploadCacheNormally(t, base, key, version, "", make([]byte, 128)) // We should read the value with the matching WriteIsolationKey from the cache... - httpClientTransport.WriteIsolationKey = "TestWriteKey" - resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) - require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode) + func() { + defer overrideWriteIsolationKey("TestWriteKey")() - got := struct { - ArchiveLocation string `json:"archiveLocation"` - }{} - require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) - contentResp, err := httpClient.Get(got.ArchiveLocation) - require.NoError(t, err) - require.Equal(t, 200, contentResp.StatusCode) - content, err := io.ReadAll(contentResp.Body) - require.NoError(t, err) - // Which we finally check matches the correct WriteIsolationKey's content here. - assert.Equal(t, make([]byte, 64), content) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + + got := struct { + ArchiveLocation string `json:"archiveLocation"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + contentResp, err := httpClient.Get(got.ArchiveLocation) + require.NoError(t, err) + require.Equal(t, 200, contentResp.StatusCode) + content, err := io.ReadAll(contentResp.Body) + require.NoError(t, err) + // Which we finally check matches the correct WriteIsolationKey's content here. + assert.Equal(t, make([]byte, 64), content) + }() }) t.Run("find falls back if matching WriteIsolationKey not available", func(t *testing.T) { - defer func() { httpClientTransport.WriteIsolationKey = "" }() - version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d21" key := strings.ToLower(t.Name()) uploadCacheNormally(t, base, key, version, "", make([]byte, 128)) - httpClientTransport.WriteIsolationKey = "TestWriteKey" - resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) - require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode) + func() { + defer overrideWriteIsolationKey("TestWriteKey")() - got := struct { - ArchiveLocation string `json:"archiveLocation"` - }{} - require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) - contentResp, err := httpClient.Get(got.ArchiveLocation) - require.NoError(t, err) - require.Equal(t, 200, contentResp.StatusCode) - content, err := io.ReadAll(contentResp.Body) - require.NoError(t, err) - assert.Equal(t, make([]byte, 128), content) + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + + got := struct { + ArchiveLocation string `json:"archiveLocation"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + contentResp, err := httpClient.Get(got.ArchiveLocation) + require.NoError(t, err) + require.Equal(t, 200, contentResp.StatusCode) + content, err := io.ReadAll(contentResp.Body) + require.NoError(t, err) + assert.Equal(t, make([]byte, 128), content) + }() }) t.Run("case insensitive", func(t *testing.T) { @@ -666,7 +677,7 @@ func TestHandler(t *testing.T) { }) t.Run("upload across WriteIsolationKey", func(t *testing.T) { - defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, "CorrectKey")() + defer overrideWriteIsolationKey("CorrectKey")() key := strings.ToLower(t.Name()) version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" @@ -692,8 +703,8 @@ func TestHandler(t *testing.T) { id = got.CacheID } // upload, but with the incorrect write isolation key relative to the cache obj created - { - httpClientTransport.WriteIsolationKey = "WrongKey" + func() { + defer overrideWriteIsolationKey("WrongKey")() req, err := http.NewRequest(http.MethodPatch, fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content)) require.NoError(t, err) @@ -702,11 +713,11 @@ func TestHandler(t *testing.T) { resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, 403, resp.StatusCode) - } + }() }) t.Run("commit across WriteIsolationKey", func(t *testing.T) { - defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, "CorrectKey")() + defer overrideWriteIsolationKey("CorrectKey")() key := strings.ToLower(t.Name()) version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" @@ -743,12 +754,12 @@ func TestHandler(t *testing.T) { assert.Equal(t, 200, resp.StatusCode) } // commit, but with the incorrect write isolation key relative to the cache obj created - { - httpClientTransport.WriteIsolationKey = "WrongKey" + func() { + defer overrideWriteIsolationKey("WrongKey")() resp, err := httpClient.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) require.NoError(t, err) assert.Equal(t, 403, resp.StatusCode) - } + }() }) t.Run("get across WriteIsolationKey", func(t *testing.T) { @@ -762,8 +773,8 @@ func TestHandler(t *testing.T) { // Perform the 'get' without the right WriteIsolationKey for the cache entry... should be OK for `key` since it // was written with WriteIsolationKey "" meaning it is available for non-isolated access - { - httpClientTransport.WriteIsolationKey = "WhoopsWrongKey" + func() { + defer overrideWriteIsolationKey("WhoopsWrongKey")() resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) require.NoError(t, err) @@ -777,30 +788,52 @@ func TestHandler(t *testing.T) { require.NoError(t, err) require.Equal(t, 200, contentResp.StatusCode) httpClientTransport.WriteIsolationKey = "CorrectKey" // reset for next find - } + }() // Perform the 'get' without the right WriteIsolationKey for the cache entry... should be 403 for `keyIsolated` // because it was written with a different WriteIsolationKey. { - httpClientTransport.WriteIsolationKey = "CorrectKey" // for test purposes make the `find` successful... - resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, keyIsolated, version)) - require.NoError(t, err) - require.Equal(t, 200, resp.StatusCode) - got := struct { + got := func() struct { ArchiveLocation string `json:"archiveLocation"` - }{} - require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + } { + defer overrideWriteIsolationKey("CorrectKey")() // for test purposes make the `find` successful... + resp, err := httpClient.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, keyIsolated, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + got := struct { + ArchiveLocation string `json:"archiveLocation"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + return got + }() - httpClientTransport.WriteIsolationKey = "WhoopsWrongKey" // but then access w/ the wrong key for `get` - contentResp, err := httpClient.Get(got.ArchiveLocation) - require.NoError(t, err) - require.Equal(t, 403, contentResp.StatusCode) + func() { + defer overrideWriteIsolationKey("WhoopsWrongKey")() // but then access w/ the wrong key for `get` + contentResp, err := httpClient.Get(got.ArchiveLocation) + require.NoError(t, err) + require.Equal(t, 403, contentResp.StatusCode) + }() } }) } +func overrideWriteIsolationKey(writeIsolationKey string) func() { + originalWriteIsolationKey := httpClientTransport.WriteIsolationKey + originalMac := httpClientTransport.OverrideDefaultMac + + httpClientTransport.WriteIsolationKey = writeIsolationKey + httpClientTransport.OverrideDefaultMac = computeMac("secret", cacheRepo, cacheRunnum, cacheTimestamp, httpClientTransport.WriteIsolationKey) + + return func() { + httpClientTransport.WriteIsolationKey = originalWriteIsolationKey + httpClientTransport.OverrideDefaultMac = originalMac + } +} + func uploadCacheNormally(t *testing.T, base, key, version, writeIsolationKey string, content []byte) { - defer testutils.MockVariable(&httpClientTransport.WriteIsolationKey, writeIsolationKey)() + if writeIsolationKey != "" { + defer overrideWriteIsolationKey(writeIsolationKey)() + } var id uint64 { diff --git a/act/artifactcache/mac.go b/act/artifactcache/mac.go index 6ed20b01..72e5fa2e 100644 --- a/act/artifactcache/mac.go +++ b/act/artifactcache/mac.go @@ -22,7 +22,7 @@ func (h *Handler) validateMac(rundata cacheproxy.RunData) (string, error) { return "", ErrValidation } - expectedMAC := computeMac(h.secret, rundata.RepositoryFullName, rundata.RunNumber, rundata.Timestamp) + expectedMAC := computeMac(h.secret, rundata.RepositoryFullName, rundata.RunNumber, rundata.Timestamp, rundata.WriteIsolationKey) if hmac.Equal([]byte(expectedMAC), []byte(rundata.RepositoryMAC)) { return rundata.RepositoryFullName, nil } @@ -40,12 +40,14 @@ func validateAge(ts string) bool { return true } -func computeMac(secret, repo, run, ts string) string { +func computeMac(secret, repo, run, ts, writeIsolationKey string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write([]byte(repo)) mac.Write([]byte(">")) mac.Write([]byte(run)) mac.Write([]byte(">")) mac.Write([]byte(ts)) + mac.Write([]byte(">")) + mac.Write([]byte(writeIsolationKey)) return hex.EncodeToString(mac.Sum(nil)) } diff --git a/act/artifactcache/mac_test.go b/act/artifactcache/mac_test.go index a4b801a9..0e1f3be6 100644 --- a/act/artifactcache/mac_test.go +++ b/act/artifactcache/mac_test.go @@ -19,7 +19,7 @@ func TestMac(t *testing.T) { run := "1" ts := strconv.FormatInt(time.Now().Unix(), 10) - mac := computeMac(handler.secret, name, run, ts) + mac := computeMac(handler.secret, name, run, ts, "") rundata := cacheproxy.RunData{ RepositoryFullName: name, RunNumber: run, @@ -37,7 +37,7 @@ func TestMac(t *testing.T) { run := "1" ts := "9223372036854775807" // This should last us for a while... - mac := computeMac(handler.secret, name, run, ts) + mac := computeMac(handler.secret, name, run, ts, "") rundata := cacheproxy.RunData{ RepositoryFullName: name, RunNumber: run, @@ -72,9 +72,12 @@ func TestMac(t *testing.T) { run := "42" ts := "1337" - mac := computeMac(secret, name, run, ts) - expectedMac := "f666f06f917acb7186e152195b2a8c8d36d068ce683454be0878806e08e04f2b" // * Precomputed, anytime the computeMac function changes this needs to be recalculated + mac := computeMac(secret, name, run, ts, "") + expectedMac := "4754474b21329e8beadd2b4054aa4be803965d66e710fa1fee091334ed804f29" // * Precomputed, anytime the computeMac function changes this needs to be recalculated + require.Equal(t, expectedMac, mac) - require.Equal(t, mac, expectedMac) + mac = computeMac(secret, name, run, ts, "refs/pull/12/head") + expectedMac = "9ca8f4cb5e1b083ee8cd215215bc00f379b28511d3ef7930bf054767de34766d" // * Precomputed, anytime the computeMac function changes this needs to be recalculated + require.Equal(t, expectedMac, mac) }) } diff --git a/act/cacheproxy/handler.go b/act/cacheproxy/handler.go index a2d76527..2b2b3707 100644 --- a/act/cacheproxy/handler.go +++ b/act/cacheproxy/handler.go @@ -55,7 +55,7 @@ type RunData struct { } func (h *Handler) CreateRunData(fullName, runNumber, timestamp, writeIsolationKey string) RunData { - mac := computeMac(h.cacheSecret, fullName, runNumber, timestamp) + mac := computeMac(h.cacheSecret, fullName, runNumber, timestamp, writeIsolationKey) return RunData{ RepositoryFullName: fullName, RunNumber: runNumber, @@ -212,12 +212,14 @@ func (h *Handler) Close() error { return retErr } -func computeMac(secret, repo, run, ts string) string { +func computeMac(secret, repo, run, ts, writeIsolationKey string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write([]byte(repo)) mac.Write([]byte(">")) mac.Write([]byte(run)) mac.Write([]byte(">")) mac.Write([]byte(ts)) + mac.Write([]byte(">")) + mac.Write([]byte(writeIsolationKey)) return hex.EncodeToString(mac.Sum(nil)) }