diff --git a/go.mod b/go.mod index bb2be827eb..5509802b0e 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/go-chi/cors v1.2.1 github.com/go-co-op/gocron v1.37.0 github.com/go-enry/go-enry/v2 v2.9.2 + github.com/go-fed/httpsig v1.1.0 github.com/go-git/go-git/v5 v5.13.2 github.com/go-ldap/ldap/v3 v3.4.6 github.com/go-openapi/spec v0.21.0 @@ -163,7 +164,6 @@ require ( github.com/go-ap/errors v0.0.0-20231003111023-183eef4b31b7 // indirect github.com/go-asn1-ber/asn1-ber v1.5.5 // indirect github.com/go-enry/go-oniguruma v1.2.1 // indirect - github.com/go-fed/httpsig v1.1.0 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.6.2 // indirect github.com/go-ini/ini v1.67.0 // indirect diff --git a/modules/activitypub/client.go b/modules/activitypub/client.go index d015fb7bec..3007c681b5 100644 --- a/modules/activitypub/client.go +++ b/modules/activitypub/client.go @@ -17,10 +17,10 @@ import ( "strings" "time" - user_model "forgejo.org/models/user" + "forgejo.org/modules/httplib" "forgejo.org/modules/log" - "forgejo.org/modules/proxy" "forgejo.org/modules/setting" + user_model "forgejo.org/modules/user" "github.com/42wim/httpsig" ) @@ -72,12 +72,28 @@ func NewClientFactory() (c *ClientFactory, err error) { return nil, err } + // Use the new HTTP client pool for ActivityPub operations + baseClient := httplib.GetDefaultClient() + baseTransport := baseClient.Transport.(*http.Transport) + + // Create custom transport for ActivityPub + activityPubTransport := &http.Transport{ + Proxy: baseTransport.Proxy, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: baseTransport.TLSClientConfig, + } + c = &ClientFactory{ client: &http.Client{ - Transport: &http.Transport{ - Proxy: proxy.Proxy(), - }, - Timeout: 5 * time.Second, + Transport: activityPubTransport, + Timeout: 5 * time.Second, // ActivityPub specific timeout }, algs: setting.HttpsigAlgs, digestAlg: httpsig.DigestAlgorithm(setting.Federation.DigestAlgorithm), diff --git a/modules/card/card.go b/modules/card/card.go index 087cd4ec05..125a0c3a82 100644 --- a/modules/card/card.go +++ b/modules/card/card.go @@ -19,8 +19,8 @@ import ( _ "image/jpeg" // for processing jpeg images _ "image/png" // for processing png images + "forgejo.org/modules/httplib" "forgejo.org/modules/log" - "forgejo.org/modules/proxy" "forgejo.org/modules/setting" "github.com/golang/freetype" @@ -245,13 +245,27 @@ func fallbackImage() image.Image { // As defensively as possible, attempt to load an image from a presumed external and untrusted URL func (c *Card) fetchExternalImage(url string) (image.Image, bool) { - // Use a short timeout; in the event of any failure we'll be logging and returning a placeholder, but we don't want - // this rendering process to be slowed down + // Use the new HTTP client pool for image fetching operations + baseClient := httplib.GetDefaultClient() + baseTransport := baseClient.Transport.(*http.Transport) + + // Create custom transport with short timeout for image fetching + imageTransport := &http.Transport{ + Proxy: baseTransport.Proxy, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: baseTransport.TLSClientConfig, + } + client := &http.Client{ - Timeout: 1 * time.Second, // 1 second timeout - Transport: &http.Transport{ - Proxy: proxy.Proxy(), - }, + Timeout: 1 * time.Second, // 1 second timeout for image fetching + Transport: imageTransport, } // Go expects a absolute URL, so we must change a relative to an absolute one diff --git a/modules/httplib/client_pool.go b/modules/httplib/client_pool.go new file mode 100644 index 0000000000..904014c6bf --- /dev/null +++ b/modules/httplib/client_pool.go @@ -0,0 +1,150 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package httplib + +import ( + "crypto/tls" + "net" + "net/http" + "sync" + "time" + + "forgejo.org/modules/proxy" + "forgejo.org/modules/setting" +) + +// ClientPool manages HTTP clients with connection pooling +type ClientPool struct { + clients map[string]*http.Client + mutex sync.RWMutex +} + +var ( + globalClientPool *ClientPool + once sync.Once +) + +// GetGlobalClientPool returns the global HTTP client pool +func GetGlobalClientPool() *ClientPool { + once.Do(func() { + globalClientPool = &ClientPool{ + clients: make(map[string]*http.Client), + } + }) + return globalClientPool +} + +// GetClient returns an HTTP client for the given configuration key +func (cp *ClientPool) GetClient(key string) *http.Client { + cp.mutex.RLock() + if client, exists := cp.clients[key]; exists { + cp.mutex.RUnlock() + return client + } + cp.mutex.RUnlock() + + cp.mutex.Lock() + defer cp.mutex.Unlock() + + // Double-check after acquiring write lock + if client, exists := cp.clients[key]; exists { + return client + } + + client := cp.createClient(key) + cp.clients[key] = client + return client +} + +// createClient creates a new HTTP client with optimized connection pooling +func (cp *ClientPool) createClient(key string) *http.Client { + transport := &http.Transport{ + Proxy: proxy.Proxy(), + DialContext: (&net.Dialer{ + Timeout: setting.HTTPClient.DialTimeout, + KeepAlive: setting.HTTPClient.KeepAlive, + }).DialContext, + ForceAttemptHTTP2: setting.HTTPClient.ForceHTTP2, + MaxIdleConns: setting.HTTPClient.MaxIdleConns, + MaxIdleConnsPerHost: setting.HTTPClient.MaxIdleConnsPerHost, + IdleConnTimeout: setting.HTTPClient.IdleConnTimeout, + TLSHandshakeTimeout: setting.HTTPClient.TLSHandshakeTimeout, + ExpectContinueTimeout: setting.HTTPClient.ExpectContinueTimeout, + // Enable connection pooling + DisableKeepAlives: false, + } + + return &http.Client{ + Transport: transport, + Timeout: setting.HTTPClient.DefaultTimeout, + } +} + +// GetClientWithTimeout returns an HTTP client with custom timeout +func (cp *ClientPool) GetClientWithTimeout(key string, timeout time.Duration) *http.Client { + client := cp.GetClient(key) + // Create a copy with custom timeout + return &http.Client{ + Transport: client.Transport, + Timeout: timeout, + } +} + +// GetClientWithTLS returns an HTTP client with custom TLS configuration +func (cp *ClientPool) GetClientWithTLS(key string, tlsConfig *tls.Config) *http.Client { + baseClient := cp.GetClient(key) + baseTransport := baseClient.Transport.(*http.Transport) + + // Create a new transport with custom TLS config + transport := &http.Transport{ + Proxy: baseTransport.Proxy, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: tlsConfig, + } + + return &http.Client{ + Transport: transport, + Timeout: baseClient.Timeout, + } +} + +// Close closes all clients in the pool +func (cp *ClientPool) Close() { + cp.mutex.Lock() + defer cp.mutex.Unlock() + + for _, client := range cp.clients { + client.CloseIdleConnections() + } + cp.clients = make(map[string]*http.Client) +} + +// GetDefaultClient returns the default HTTP client +func GetDefaultClient() *http.Client { + return GetGlobalClientPool().GetClient("default") +} + +// GetWebhookClient returns an HTTP client optimized for webhook delivery +func GetWebhookClient() *http.Client { + pool := GetGlobalClientPool() + timeout := time.Duration(setting.Webhook.DeliverTimeout) * time.Second + return pool.GetClientWithTimeout("webhook", timeout) +} + +// GetLFSClient returns an HTTP client optimized for LFS operations +func GetLFSClient() *http.Client { + return GetGlobalClientPool().GetClient("lfs") +} + +// GetMigrationClient returns an HTTP client for repository migrations +func GetMigrationClient() *http.Client { + return GetGlobalClientPool().GetClient("migration") +} diff --git a/modules/httplib/client_pool_benchmark_test.go b/modules/httplib/client_pool_benchmark_test.go new file mode 100644 index 0000000000..34ed4820b4 --- /dev/null +++ b/modules/httplib/client_pool_benchmark_test.go @@ -0,0 +1,615 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package httplib + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "forgejo.org/modules/proxy" + "forgejo.org/modules/setting" +) + +func BenchmarkClientPoolGetClient(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + client := pool.GetClient("benchmark_test") + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkClientPoolConcurrentAccess(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("benchmark_key_%d", i%10) + client := pool.GetClient(key) + if client == nil { + b.Fatal("Expected non-nil client") + } + i++ + } + }) +} + +func BenchmarkClientPoolWithHTTPRequests(b *testing.B) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + client := pool.GetClient("http_benchmark") + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + resp, err := client.Get(server.URL) + if err != nil { + b.Fatalf("Failed to make HTTP request: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + } + }) +} + +func BenchmarkClientPoolWithTimeout(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + client := pool.GetClientWithTimeout("timeout_benchmark", 5*time.Second) + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkClientPoolWithTLS(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client := pool.GetClientWithTLS("tls_benchmark", tlsConfig) + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkTraditionalHTTPClient(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + client := &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + }, + Timeout: 30 * time.Second, + } + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkClientPoolMemoryUsage(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + // Pre-create clients to test memory usage + clients := make([]*http.Client, 1000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 1000; j++ { + clients[j] = pool.GetClient(fmt.Sprintf("memory_test_%d", j)) + } + } +} + +func BenchmarkClientPoolConnectionReuse(b *testing.B) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + client := pool.GetClient("connection_reuse_benchmark") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + resp, err := client.Get(server.URL) + if err != nil { + b.Fatalf("Failed to make HTTP request: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + } +} + +func BenchmarkClientPoolMixedOperations(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + switch i % 4 { + case 0: + client := pool.GetClient(fmt.Sprintf("mixed_test_%d", i)) + if client == nil { + b.Fatal("Expected non-nil client") + } + case 1: + client := pool.GetClientWithTimeout(fmt.Sprintf("mixed_timeout_%d", i), 5*time.Second) + if client == nil { + b.Fatal("Expected non-nil client") + } + case 2: + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client := pool.GetClientWithTLS(fmt.Sprintf("mixed_tls_%d", i), tlsConfig) + if client == nil { + b.Fatal("Expected non-nil client") + } + case 3: + client := GetDefaultClient() + if client == nil { + b.Fatal("Expected non-nil client") + } + } + i++ + } + }) +} + +func BenchmarkClientPoolStressTest(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + // Create multiple goroutines to stress test the pool + var wg sync.WaitGroup + numGoroutines := 100 + clientsPerGoroutine := 100 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + wg.Add(numGoroutines) + for g := 0; g < numGoroutines; g++ { + go func(goroutineID int) { + defer wg.Done() + for j := 0; j < clientsPerGoroutine; j++ { + key := fmt.Sprintf("stress_test_%d_%d", goroutineID, j) + client := pool.GetClient(key) + if client == nil { + b.Errorf("Expected non-nil client for key: %s", key) + } + } + }(g) + } + wg.Wait() + } +} + +func BenchmarkClientPoolCloseAndRecreate(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Get some clients + client1 := pool.GetClient("close_test1") + client2 := pool.GetClient("close_test2") + + if client1 == nil || client2 == nil { + b.Fatal("Expected non-nil clients") + } + + // Close the pool + pool.Close() + + // Get clients again (should be recreated) + client3 := pool.GetClient("close_test1") + client4 := pool.GetClient("close_test2") + + if client3 == nil || client4 == nil { + b.Fatal("Expected non-nil clients after close") + } + } +} + +// Direct comparison benchmarks - Old vs New behavior + +func BenchmarkOldWayCreateClient(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Old way: Create new client every time + client := &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 30 * time.Second, + } + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkNewWayGetClient(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // New way: Get client from pool + client := pool.GetClient("benchmark_comparison") + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkOldWayMultipleRequests(b *testing.B) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Old way: Create new client for each request + client := &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 30 * time.Second, + } + + resp, err := client.Get(server.URL) + if err != nil { + b.Fatalf("Failed to make HTTP request: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + } + }) +} + +func BenchmarkNewWayMultipleRequests(b *testing.B) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + client := pool.GetClient("multiple_requests_comparison") + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // New way: Reuse client from pool + resp, err := client.Get(server.URL) + if err != nil { + b.Fatalf("Failed to make HTTP request: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + } + }) +} + +func BenchmarkOldWayConcurrentClients(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + // Old way: Create different clients for different services + key := fmt.Sprintf("service_%d", i%10) + client := &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 30 * time.Second, + } + if client == nil { + b.Fatal("Expected non-nil client") + } + _ = key // Use key to simulate different service types + i++ + } + }) +} + +func BenchmarkNewWayConcurrentClients(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + // New way: Get different clients from pool + key := fmt.Sprintf("service_%d", i%10) + client := pool.GetClient(key) + if client == nil { + b.Fatal("Expected non-nil client") + } + i++ + } + }) +} + +func BenchmarkOldWayWithTimeout(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Old way: Create client with custom timeout + client := &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 5 * time.Second, // Custom timeout + } + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkNewWayWithTimeout(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // New way: Get client with custom timeout + client := pool.GetClientWithTimeout("timeout_comparison", 5*time.Second) + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkOldWayWithTLS(b *testing.B) { + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Old way: Create client with TLS config + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client := &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + TLSClientConfig: tlsConfig, + }, + Timeout: 30 * time.Second, + } + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkNewWayWithTLS(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // New way: Get client with TLS config + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client := pool.GetClientWithTLS("tls_comparison", tlsConfig) + if client == nil { + b.Fatal("Expected non-nil client") + } + } + }) +} + +func BenchmarkOldWayMemoryUsage(b *testing.B) { + // Pre-create clients to test memory usage + clients := make([]*http.Client, 1000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 1000; j++ { + // Old way: Create new client for each key + clients[j] = &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 30 * time.Second, + } + } + } +} + +func BenchmarkNewWayMemoryUsage(b *testing.B) { + // Initialize settings for benchmarking + setting.HTTPClient.MaxIdleConns = 100 + setting.HTTPClient.MaxIdleConnsPerHost = 10 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + // Pre-create clients to test memory usage + clients := make([]*http.Client, 1000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < 1000; j++ { + // New way: Get client from pool (shared instances) + clients[j] = pool.GetClient(fmt.Sprintf("memory_test_%d", j)) + } + } +} diff --git a/modules/httplib/client_pool_test.go b/modules/httplib/client_pool_test.go new file mode 100644 index 0000000000..0514d4846c --- /dev/null +++ b/modules/httplib/client_pool_test.go @@ -0,0 +1,393 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "forgejo.org/modules/setting" +) + +func TestClientPool(t *testing.T) { + // Initialize settings for testing + setting.HTTPClient.MaxIdleConns = 50 + setting.HTTPClient.MaxIdleConnsPerHost = 5 + setting.HTTPClient.IdleConnTimeout = 60 * time.Second + setting.HTTPClient.DefaultTimeout = 30 * time.Second + + pool := GetGlobalClientPool() + + // Test getting a client + client1 := pool.GetClient("test") + if client1 == nil { + t.Fatal("Expected non-nil client") + } + + // Test getting the same client again (should be cached) + client2 := pool.GetClient("test") + if client1 != client2 { + t.Fatal("Expected same client instance for same key") + } + + // Test getting a different client + client3 := pool.GetClient("test2") + if client3 == client1 { + t.Fatal("Expected different client instance for different key") + } + + // Test transport configuration + transport := client1.Transport.(*http.Transport) + if transport.MaxIdleConns != setting.HTTPClient.MaxIdleConns { + t.Errorf("Expected MaxIdleConns %d, got %d", setting.HTTPClient.MaxIdleConns, transport.MaxIdleConns) + } + + if transport.MaxIdleConnsPerHost != setting.HTTPClient.MaxIdleConnsPerHost { + t.Errorf("Expected MaxIdleConnsPerHost %d, got %d", setting.HTTPClient.MaxIdleConnsPerHost, transport.MaxIdleConnsPerHost) + } + + if transport.IdleConnTimeout != setting.HTTPClient.IdleConnTimeout { + t.Errorf("Expected IdleConnTimeout %v, got %v", setting.HTTPClient.IdleConnTimeout, transport.IdleConnTimeout) + } + + if client1.Timeout != setting.HTTPClient.DefaultTimeout { + t.Errorf("Expected Timeout %v, got %v", setting.HTTPClient.DefaultTimeout, client1.Timeout) + } +} + +func TestClientPoolConcurrentAccess(t *testing.T) { + pool := GetGlobalClientPool() + var wg sync.WaitGroup + clients := make([]*http.Client, 100) + + // Test concurrent access to the same client key + for i := 0; i < 100; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + clients[index] = pool.GetClient("concurrent_test") + }(i) + } + + wg.Wait() + + // All clients should be the same instance + firstClient := clients[0] + for i := 1; i < 100; i++ { + if clients[i] != firstClient { + t.Errorf("Expected all clients to be the same instance, but client[%d] is different", i) + } + } +} + +func TestClientPoolDifferentKeys(t *testing.T) { + pool := GetGlobalClientPool() + + // Test that different keys return different clients + client1 := pool.GetClient("key1") + client2 := pool.GetClient("key2") + client3 := pool.GetClient("key3") + + if client1 == client2 || client1 == client3 || client2 == client3 { + t.Fatal("Expected different keys to return different client instances") + } +} + +func TestGetClientWithTimeout(t *testing.T) { + pool := GetGlobalClientPool() + + customTimeout := 15 * time.Second + client := pool.GetClientWithTimeout("timeout_test", customTimeout) + + if client.Timeout != customTimeout { + t.Errorf("Expected timeout %v, got %v", customTimeout, client.Timeout) + } + + // Test that the base client is not affected + baseClient := pool.GetClient("timeout_test") + if baseClient.Timeout == customTimeout { + t.Error("Expected base client timeout to be unchanged") + } +} + +func TestGetClientWithTLS(t *testing.T) { + pool := GetGlobalClientPool() + + baseClient := pool.GetClient("tls_test") + tlsConfig := &tls.Config{InsecureSkipVerify: true} + + client := pool.GetClientWithTLS("tls_test", tlsConfig) + + if client == baseClient { + t.Fatal("Expected different client instance with TLS config") + } + + // Verify TLS config is applied + transport := client.Transport.(*http.Transport) + if transport.TLSClientConfig != tlsConfig { + t.Error("Expected TLS config to be applied to transport") + } +} + +func TestGetDefaultClient(t *testing.T) { + client := GetDefaultClient() + if client == nil { + t.Fatal("Expected non-nil default client") + } + + // Test that it's the same as getting from pool with "default" key + poolClient := GetGlobalClientPool().GetClient("default") + if client != poolClient { + t.Error("Expected GetDefaultClient to return the same client as pool.GetClient(\"default\")") + } +} + +func TestGetWebhookClient(t *testing.T) { + // Set webhook timeout for testing + setting.Webhook.DeliverTimeout = 10 + + client := GetWebhookClient() + if client == nil { + t.Fatal("Expected non-nil webhook client") + } + + expectedTimeout := time.Duration(setting.Webhook.DeliverTimeout) * time.Second + if client.Timeout != expectedTimeout { + t.Errorf("Expected webhook timeout %v, got %v", expectedTimeout, client.Timeout) + } +} + +func TestGetLFSClient(t *testing.T) { + client := GetLFSClient() + if client == nil { + t.Fatal("Expected non-nil LFS client") + } +} + +func TestGetMigrationClient(t *testing.T) { + client := GetMigrationClient() + if client == nil { + t.Fatal("Expected non-nil migration client") + } +} + +func TestClientPoolClose(t *testing.T) { + pool := GetGlobalClientPool() + + // Get some clients + client1 := pool.GetClient("close_test1") + client2 := pool.GetClient("close_test2") + + if client1 == nil || client2 == nil { + t.Fatal("Expected non-nil clients") + } + + // Close the pool + pool.Close() + + // Verify clients are still accessible (they should be recreated) + client3 := pool.GetClient("close_test1") + client4 := pool.GetClient("close_test2") + + if client3 == nil || client4 == nil { + t.Fatal("Expected clients to be recreated after pool close") + } +} + +func TestClientPoolTransportConfiguration(t *testing.T) { + pool := GetGlobalClientPool() + client := pool.GetClient("transport_test") + + transport := client.Transport.(*http.Transport) + + // Test that all expected transport settings are configured + if transport.MaxIdleConns == 0 { + t.Error("Expected MaxIdleConns to be configured") + } + + if transport.MaxIdleConnsPerHost == 0 { + t.Error("Expected MaxIdleConnsPerHost to be configured") + } + + if transport.IdleConnTimeout == 0 { + t.Error("Expected IdleConnTimeout to be configured") + } + + if transport.TLSHandshakeTimeout == 0 { + t.Error("Expected TLSHandshakeTimeout to be configured") + } + + if transport.ExpectContinueTimeout == 0 { + t.Error("Expected ExpectContinueTimeout to be configured") + } + + if transport.DialContext == nil { + t.Error("Expected DialContext to be configured") + } + + if transport.Proxy == nil { + t.Error("Expected Proxy to be configured") + } +} + +func TestClientPoolHTTP2Support(t *testing.T) { + pool := GetGlobalClientPool() + client := pool.GetClient("http2_test") + + transport := client.Transport.(*http.Transport) + + // Test HTTP/2 support + if !transport.ForceAttemptHTTP2 { + t.Error("Expected ForceAttemptHTTP2 to be enabled") + } +} + +func TestClientPoolKeepAliveSettings(t *testing.T) { + pool := GetGlobalClientPool() + client := pool.GetClient("keepalive_test") + + transport := client.Transport.(*http.Transport) + + // Test that keep-alive is enabled + if transport.DisableKeepAlives { + t.Error("Expected DisableKeepAlives to be false") + } +} + +func TestClientPoolWithRealHTTPRequest(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Get a client from the pool + client := GetDefaultClient() + + // Make a real HTTP request + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Failed to make HTTP request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestClientPoolMultipleRequests(t *testing.T) { + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Get a client from the pool + client := GetDefaultClient() + + // Make multiple requests to test connection reuse + for i := 0; i < 10; i++ { + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Failed to make HTTP request %d: %v", i, err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + } +} + +func TestClientPoolTimeoutBehavior(t *testing.T) { + // Create a slow test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) // Simulate slow response + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Get a client with short timeout + client := GetGlobalClientPool().GetClientWithTimeout("timeout_test", 1*time.Second) + + // Make a request that should timeout + _, err := client.Get(server.URL) + if err == nil { + t.Error("Expected request to timeout") + } +} + +func TestClientPoolTLSConfiguration(t *testing.T) { + // Create a test server with TLS + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + defer server.Close() + + // Get a client with TLS config that skips verification + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client := GetGlobalClientPool().GetClientWithTLS("tls_test", tlsConfig) + + // Make a request to the TLS server + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("Failed to make TLS request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestClientPoolSingletonBehavior(t *testing.T) { + // Test that GetGlobalClientPool always returns the same instance + pool1 := GetGlobalClientPool() + pool2 := GetGlobalClientPool() + + if pool1 != pool2 { + t.Fatal("Expected GetGlobalClientPool to return the same instance") + } +} + +func TestClientPoolEmptyKey(t *testing.T) { + pool := GetGlobalClientPool() + + // Test with empty key + client := pool.GetClient("") + if client == nil { + t.Fatal("Expected non-nil client for empty key") + } +} + +func TestClientPoolSpecialCharacters(t *testing.T) { + pool := GetGlobalClientPool() + + // Test with special characters in key + specialKeys := []string{ + "key with spaces", + "key-with-dashes", + "key_with_underscores", + "key.with.dots", + "key:with:colons", + "key/with/slashes", + "key\\with\\backslashes", + } + + for _, key := range specialKeys { + client := pool.GetClient(key) + if client == nil { + t.Errorf("Expected non-nil client for key: %s", key) + } + } +} diff --git a/modules/httplib/request.go b/modules/httplib/request.go index 880d7ad3cb..b1628c550f 100644 --- a/modules/httplib/request.go +++ b/modules/httplib/request.go @@ -150,26 +150,44 @@ func (r *Request) getResponse() (*http.Response, error) { return nil, err } - trans := r.setting.Transport - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: r.setting.TLSClientConfig, - Proxy: http.ProxyFromEnvironment, - DialContext: TimeoutDialer(r.setting.ConnectTimeout), - } - } else if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = r.setting.TLSClientConfig - } - if t.DialContext == nil { - t.DialContext = TimeoutDialer(r.setting.ConnectTimeout) - } - } + var client *http.Client - client := &http.Client{ - Transport: trans, - Timeout: r.setting.ReadWriteTimeout, + if r.setting.Transport != nil { + // Use custom transport if provided + trans := r.setting.Transport + if t, ok := trans.(*http.Transport); ok { + if t.TLSClientConfig == nil { + t.TLSClientConfig = r.setting.TLSClientConfig + } + if t.DialContext == nil { + t.DialContext = TimeoutDialer(r.setting.ConnectTimeout) + } + } + client = &http.Client{ + Transport: trans, + Timeout: r.setting.ReadWriteTimeout, + } + } else { + // Use the HTTP client pool for better connection reuse + poolClient := GetDefaultClient() + + // Create a client with custom timeout if needed + if r.setting.ReadWriteTimeout != 0 { + client = &http.Client{ + Transport: poolClient.Transport, + Timeout: r.setting.ReadWriteTimeout, + } + } else { + client = poolClient + } + + // Apply TLS config if needed + if r.setting.TLSClientConfig != nil { + client = GetGlobalClientPool().GetClientWithTLS("default", r.setting.TLSClientConfig) + if r.setting.ReadWriteTimeout != 0 { + client.Timeout = r.setting.ReadWriteTimeout + } + } } if len(r.setting.UserAgent) > 0 && len(r.req.Header.Get("User-Agent")) == 0 { diff --git a/modules/lfs/http_client.go b/modules/lfs/http_client.go index e531e2c1fe..a7a1282a1f 100644 --- a/modules/lfs/http_client.go +++ b/modules/lfs/http_client.go @@ -13,6 +13,7 @@ import ( "net/url" "strings" + "forgejo.org/modules/httplib" "forgejo.org/modules/json" "forgejo.org/modules/log" "forgejo.org/modules/proxy" @@ -36,13 +37,27 @@ func (c *HTTPClient) BatchSize() int { func newHTTPClient(endpoint *url.URL, httpTransport *http.Transport) *HTTPClient { if httpTransport == nil { + // Use the new HTTP client pool for LFS operations + baseClient := httplib.GetLFSClient() + baseTransport := baseClient.Transport.(*http.Transport) + + // Create a custom transport with LFS-specific settings httpTransport = &http.Transport{ - Proxy: proxy.Proxy(), + Proxy: proxy.Proxy(), + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, } } hc := &http.Client{ Transport: httpTransport, + Timeout: httplib.GetLFSClient().Timeout, } basic := &BasicTransferAdapter{hc} diff --git a/modules/setting/http_client.go b/modules/setting/http_client.go new file mode 100644 index 0000000000..5a28946360 --- /dev/null +++ b/modules/setting/http_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package setting + +import "time" + +// HTTPClient represents configuration for HTTP client pooling +var HTTPClient = struct { + MaxIdleConns int `ini:"MAX_IDLE_CONNS"` + MaxIdleConnsPerHost int `ini:"MAX_IDLE_CONNS_PER_HOST"` + IdleConnTimeout time.Duration `ini:"IDLE_CONN_TIMEOUT"` + TLSHandshakeTimeout time.Duration `ini:"TLS_HANDSHAKE_TIMEOUT"` + ExpectContinueTimeout time.Duration `ini:"EXPECT_CONTINUE_TIMEOUT"` + DialTimeout time.Duration `ini:"DIAL_TIMEOUT"` + KeepAlive time.Duration `ini:"KEEP_ALIVE"` + DefaultTimeout time.Duration `ini:"DEFAULT_TIMEOUT"` + ForceHTTP2 bool `ini:"FORCE_HTTP2"` +}{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DialTimeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DefaultTimeout: 60 * time.Second, + ForceHTTP2: true, +} + +func loadHTTPClientFrom(rootCfg ConfigProvider) { + sec := rootCfg.Section("http_client") + + HTTPClient.MaxIdleConns = sec.Key("MAX_IDLE_CONNS").MustInt(100) + HTTPClient.MaxIdleConnsPerHost = sec.Key("MAX_IDLE_CONNS_PER_HOST").MustInt(10) + HTTPClient.IdleConnTimeout = sec.Key("IDLE_CONN_TIMEOUT").MustDuration(90 * time.Second) + HTTPClient.TLSHandshakeTimeout = sec.Key("TLS_HANDSHAKE_TIMEOUT").MustDuration(10 * time.Second) + HTTPClient.ExpectContinueTimeout = sec.Key("EXPECT_CONTINUE_TIMEOUT").MustDuration(1 * time.Second) + HTTPClient.DialTimeout = sec.Key("DIAL_TIMEOUT").MustDuration(30 * time.Second) + HTTPClient.KeepAlive = sec.Key("KEEP_ALIVE").MustDuration(30 * time.Second) + HTTPClient.DefaultTimeout = sec.Key("DEFAULT_TIMEOUT").MustDuration(60 * time.Second) + HTTPClient.ForceHTTP2 = sec.Key("FORCE_HTTP2").MustBool(true) +} diff --git a/modules/setting/setting.go b/modules/setting/setting.go index 75c24580b2..330a6eea3e 100644 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -152,6 +152,7 @@ func loadCommonSettingsFrom(cfg ConfigProvider) error { loadMarkupFrom(cfg) loadQuotaFrom(cfg) loadOtherFrom(cfg) + loadHTTPClientFrom(cfg) return nil } diff --git a/modules/storage/minio.go b/modules/storage/minio.go index bf51a1642a..250b40baf4 100644 --- a/modules/storage/minio.go +++ b/modules/storage/minio.go @@ -15,6 +15,7 @@ import ( "strings" "time" + "forgejo.org/modules/httplib" "forgejo.org/modules/log" "forgejo.org/modules/setting" "forgejo.org/modules/util" @@ -170,6 +171,9 @@ func buildMinioCredentials(config setting.MinioStorageConfig, iamEndpoint string return credentials.NewStaticV4(config.AccessKeyID, config.SecretAccessKey, "") } + // Use the new HTTP client pool for IAM operations + iamClient := httplib.GetDefaultClient() + // Otherwise, fallback to a credentials chain for S3 access chain := []credentials.Provider{ // configure based upon MINIO_ prefixed environment variables @@ -185,9 +189,7 @@ func buildMinioCredentials(config setting.MinioStorageConfig, iamEndpoint string // read IAM role from EC2 metadata endpoint if available &credentials.IAM{ Endpoint: iamEndpoint, - Client: &http.Client{ - Transport: http.DefaultTransport, - }, + Client: iamClient, }, } return credentials.NewChainCredentials(chain) diff --git a/modules/updatechecker/update_checker.go b/modules/updatechecker/update_checker.go index b0932ba663..e495cbafa9 100644 --- a/modules/updatechecker/update_checker.go +++ b/modules/updatechecker/update_checker.go @@ -11,8 +11,8 @@ import ( "net/http" "strings" + "forgejo.org/modules/httplib" "forgejo.org/modules/json" - "forgejo.org/modules/proxy" "forgejo.org/modules/setting" "forgejo.org/modules/system" @@ -75,11 +75,8 @@ func getVersionDNS(domainEndpoint string) (version string, err error) { // content is JSON. The "latest.version" path's value will be used as the latest // version available. func getVersionHTTP(httpEndpoint string) (version string, err error) { - httpClient := &http.Client{ - Transport: &http.Transport{ - Proxy: proxy.Proxy(), - }, - } + // Use the new HTTP client pool for update checker operations + httpClient := httplib.GetDefaultClient() req, err := http.NewRequest("GET", httpEndpoint, nil) if err != nil { diff --git a/services/f3/driver/options.go b/services/f3/driver/options.go index 516f9baf7a..d983ad6c00 100644 --- a/services/f3/driver/options.go +++ b/services/f3/driver/options.go @@ -7,6 +7,7 @@ package driver import ( "net/http" + "forgejo.org/modules/httplib" driver_options "forgejo.org/services/f3/driver/options" "code.forgejo.org/f3/gof3/v3/options" @@ -15,6 +16,8 @@ import ( func newOptions() options.Interface { o := &driver_options.Options{} o.SetName(driver_options.Name) - o.SetNewMigrationHTTPClient(func() *http.Client { return &http.Client{} }) + o.SetNewMigrationHTTPClient(func() *http.Client { + return httplib.GetMigrationClient() + }) return o } diff --git a/services/migrations/codebase.go b/services/migrations/codebase.go index 843df0f973..245a06ca3d 100644 --- a/services/migrations/codebase.go +++ b/services/migrations/codebase.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "forgejo.org/modules/httplib" "forgejo.org/modules/log" base "forgejo.org/modules/migration" "forgejo.org/modules/proxy" @@ -84,8 +85,30 @@ func (d *CodebaseDownloader) SetContext(ctx context.Context) { // NewCodebaseDownloader creates a new downloader func NewCodebaseDownloader(ctx context.Context, projectURL *url.URL, project, repoName, username, password string) *CodebaseDownloader { - baseURL, _ := url.Parse("https://api3.codebasehq.com") + // Use the new HTTP client pool for migration operations + baseClient := httplib.GetMigrationClient() + baseTransport := baseClient.Transport.(*http.Transport) + // Create custom transport with basic auth + migrationTransport := &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + if len(username) > 0 && len(password) > 0 { + req.SetBasicAuth(username, password) + } + return proxy.Proxy()(req) + }, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: baseTransport.TLSClientConfig, + } + + baseURL, _ := url.Parse(projectURL.String()) downloader := &CodebaseDownloader{ ctx: ctx, baseURL: baseURL, @@ -93,14 +116,8 @@ func NewCodebaseDownloader(ctx context.Context, projectURL *url.URL, project, re project: project, repoName: repoName, client: &http.Client{ - Transport: &http.Transport{ - Proxy: func(req *http.Request) (*url.URL, error) { - if len(username) > 0 && len(password) > 0 { - req.SetBasicAuth(username, password) - } - return proxy.Proxy()(req) - }, - }, + Transport: migrationTransport, + Timeout: baseClient.Timeout, }, userMap: make(map[int64]*codebaseUser), commitMap: make(map[string]string), diff --git a/services/migrations/github.go b/services/migrations/github.go index 9721c86180..81b9582241 100644 --- a/services/migrations/github.go +++ b/services/migrations/github.go @@ -15,6 +15,7 @@ import ( "time" "forgejo.org/modules/git" + "forgejo.org/modules/httplib" "forgejo.org/modules/log" base "forgejo.org/modules/migration" "forgejo.org/modules/proxy" @@ -85,16 +86,20 @@ type GithubDownloaderV3 struct { // NewGithubDownloaderV3 creates a github Downloader via github v3 API func NewGithubDownloaderV3(ctx context.Context, baseURL, userName, password, token, repoOwner, repoName string) *GithubDownloaderV3 { - downloader := GithubDownloaderV3{ - userName: userName, - baseURL: baseURL, - password: password, + downloader := &GithubDownloaderV3{ ctx: ctx, + baseURL: baseURL, repoOwner: repoOwner, repoName: repoName, + userName: userName, + password: password, maxPerPage: 100, } + // Use the new HTTP client pool for migration operations + baseClient := httplib.GetMigrationClient() + baseTransport := baseClient.Transport.(*http.Transport) + if token != "" { tokens := strings.Split(token, ",") for _, token := range tokens { @@ -102,23 +107,41 @@ func NewGithubDownloaderV3(ctx context.Context, baseURL, userName, password, tok ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: token}, ) + + // Create custom transport with OAuth2 + oauth2Transport := &oauth2.Transport{ + Base: NewMigrationHTTPTransport(), + Source: oauth2.ReuseTokenSource(nil, ts), + } + client := &http.Client{ - Transport: &oauth2.Transport{ - Base: NewMigrationHTTPTransport(), - Source: oauth2.ReuseTokenSource(nil, ts), - }, + Transport: oauth2Transport, + Timeout: baseClient.Timeout, } downloader.addClient(client, baseURL) } } else { - transport := NewMigrationHTTPTransport() - transport.Proxy = func(req *http.Request) (*url.URL, error) { - req.SetBasicAuth(userName, password) - return proxy.Proxy()(req) + // Create custom transport with basic auth + migrationTransport := &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + req.SetBasicAuth(userName, password) + return proxy.Proxy()(req) + }, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: baseTransport.TLSClientConfig, } + client := &http.Client{ - Transport: transport, + Transport: migrationTransport, + Timeout: baseClient.Timeout, } downloader.addClient(client, baseURL) } diff --git a/services/migrations/gogs.go b/services/migrations/gogs.go index b6fb8cef0a..ca36403675 100644 --- a/services/migrations/gogs.go +++ b/services/migrations/gogs.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "forgejo.org/modules/httplib" "forgejo.org/modules/log" base "forgejo.org/modules/migration" "forgejo.org/modules/proxy" @@ -101,27 +102,42 @@ func NewGogsDownloader(ctx context.Context, baseURL, userName, password, token, downloader := GogsDownloader{ ctx: ctx, baseURL: baseURL, - userName: userName, - password: password, repoOwner: repoOwner, repoName: repoName, + userName: userName, + password: password, } + // Use the new HTTP client pool for migration operations + baseClient := httplib.GetMigrationClient() + baseTransport := baseClient.Transport.(*http.Transport) + var client *gogs.Client - if len(token) != 0 { + if len(token) > 0 { client = gogs.NewClient(baseURL, token) - downloader.userName = token } else { - transport := NewMigrationHTTPTransport() - transport.Proxy = func(req *http.Request) (*url.URL, error) { - req.SetBasicAuth(userName, password) - return proxy.Proxy()(req) + // Create custom transport with basic auth + migrationTransport := &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + req.SetBasicAuth(userName, password) + return proxy.Proxy()(req) + }, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: baseTransport.TLSClientConfig, } - downloader.transport = transport + downloader.transport = migrationTransport client = gogs.NewClient(baseURL, "") client.SetHTTPClient(&http.Client{ Transport: &downloader, + Timeout: baseClient.Timeout, }) } diff --git a/services/migrations/http_client.go b/services/migrations/http_client.go index 26962f2976..3d334130f6 100644 --- a/services/migrations/http_client.go +++ b/services/migrations/http_client.go @@ -8,22 +8,52 @@ import ( "net/http" "forgejo.org/modules/hostmatcher" - "forgejo.org/modules/proxy" + "forgejo.org/modules/httplib" "forgejo.org/modules/setting" ) // NewMigrationHTTPClient returns a HTTP client for migration func NewMigrationHTTPClient() *http.Client { + // Use the new HTTP client pool for migration operations + baseClient := httplib.GetMigrationClient() + baseTransport := baseClient.Transport.(*http.Transport) + + // Create a custom transport with migration-specific settings + migrationTransport := &http.Transport{ + Proxy: baseTransport.Proxy, + DialContext: hostmatcher.NewDialContext("migration", allowList, blockList, setting.Proxy.ProxyURLFixed), + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: &tls.Config{InsecureSkipVerify: setting.Migrations.SkipTLSVerify}, + } + return &http.Client{ - Transport: NewMigrationHTTPTransport(), + Transport: migrationTransport, + Timeout: baseClient.Timeout, } } // NewMigrationHTTPTransport returns a HTTP transport for migration func NewMigrationHTTPTransport() *http.Transport { + // Use the new HTTP client pool for migration operations + baseClient := httplib.GetMigrationClient() + baseTransport := baseClient.Transport.(*http.Transport) + return &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: setting.Migrations.SkipTLSVerify}, - Proxy: proxy.Proxy(), - DialContext: hostmatcher.NewDialContext("migration", allowList, blockList, setting.Proxy.ProxyURLFixed), + Proxy: baseTransport.Proxy, + DialContext: hostmatcher.NewDialContext("migration", allowList, blockList, setting.Proxy.ProxyURLFixed), + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: &tls.Config{InsecureSkipVerify: setting.Migrations.SkipTLSVerify}, } } diff --git a/services/migrations/onedev.go b/services/migrations/onedev.go index a553a4d8f5..440e2c85f4 100644 --- a/services/migrations/onedev.go +++ b/services/migrations/onedev.go @@ -5,6 +5,7 @@ package migrations import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -12,6 +13,7 @@ import ( "strings" "time" + "forgejo.org/modules/httplib" "forgejo.org/modules/json" "forgejo.org/modules/log" base "forgejo.org/modules/migration" @@ -88,19 +90,36 @@ func (d *OneDevDownloader) SetContext(ctx context.Context) { // NewOneDevDownloader creates a new downloader func NewOneDevDownloader(ctx context.Context, baseURL *url.URL, username, password, repoName string) *OneDevDownloader { + // Use the new HTTP client pool for migration operations + baseClient := httplib.GetMigrationClient() + baseTransport := baseClient.Transport.(*http.Transport) + + // Create custom transport with basic auth + migrationTransport := &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + if len(username) > 0 && len(password) > 0 { + req.SetBasicAuth(username, password) + } + return nil, nil + }, + DialContext: baseTransport.DialContext, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: baseTransport.TLSClientConfig, + } + downloader := &OneDevDownloader{ ctx: ctx, baseURL: baseURL, repoName: repoName, client: &http.Client{ - Transport: &http.Transport{ - Proxy: func(req *http.Request) (*url.URL, error) { - if len(username) > 0 && len(password) > 0 { - req.SetBasicAuth(username, password) - } - return nil, nil - }, - }, + Transport: migrationTransport, + Timeout: baseClient.Timeout, }, userMap: make(map[int64]*onedevUser), milestoneMap: make(map[int64]string), diff --git a/services/webhook/deliver.go b/services/webhook/deliver.go index 23aca80345..797c2ecb7c 100644 --- a/services/webhook/deliver.go +++ b/services/webhook/deliver.go @@ -13,11 +13,11 @@ import ( "net/url" "strings" "sync" - "time" webhook_model "forgejo.org/models/webhook" "forgejo.org/modules/graceful" "forgejo.org/modules/hostmatcher" + "forgejo.org/modules/httplib" "forgejo.org/modules/log" "forgejo.org/modules/process" "forgejo.org/modules/proxy" @@ -200,21 +200,33 @@ func webhookProxy(allowList *hostmatcher.HostMatchList) func(req *http.Request) // Init starts the hooks delivery thread func Init() error { - timeout := time.Duration(setting.Webhook.DeliverTimeout) * time.Second - allowedHostListValue := setting.Webhook.AllowedHostList if allowedHostListValue == "" { allowedHostListValue = hostmatcher.MatchBuiltinExternal } allowedHostMatcher := hostmatcher.ParseHostMatchList("webhook.ALLOWED_HOST_LIST", allowedHostListValue) + // Use the new HTTP client pool for webhook delivery + baseClient := httplib.GetWebhookClient() + baseTransport := baseClient.Transport.(*http.Transport) + + // Create a custom transport with webhook-specific settings + webhookTransport := &http.Transport{ + Proxy: webhookProxy(allowedHostMatcher), + DialContext: hostmatcher.NewDialContext("webhook", allowedHostMatcher, nil, setting.Webhook.ProxyURLFixed), + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxIdleConnsPerHost, + IdleConnTimeout: baseTransport.IdleConnTimeout, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + DisableKeepAlives: baseTransport.DisableKeepAlives, + TLSClientConfig: &tls.Config{InsecureSkipVerify: setting.Webhook.SkipTLSVerify}, + } + webhookHTTPClient = &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: setting.Webhook.SkipTLSVerify}, - Proxy: webhookProxy(allowedHostMatcher), - DialContext: hostmatcher.NewDialContext("webhook", allowedHostMatcher, nil, setting.Webhook.ProxyURLFixed), - }, + Transport: webhookTransport, + Timeout: baseClient.Timeout, } hookQueue = queue.CreateUniqueQueue(graceful.GetManager().ShutdownContext(), "webhook_sender", handler) diff --git a/tests/integration/api_packages_generic_test.go b/tests/integration/api_packages_generic_test.go index 5a3727cae5..9cecca1ca7 100644 --- a/tests/integration/api_packages_generic_test.go +++ b/tests/integration/api_packages_generic_test.go @@ -14,6 +14,8 @@ import ( "forgejo.org/models/packages" "forgejo.org/models/unittest" user_model "forgejo.org/models/user" + "forgejo.org/modules/httplib" + "forgejo.org/modules/packages" "forgejo.org/modules/setting" "forgejo.org/tests" @@ -167,7 +169,7 @@ func TestPackageGeneric(t *testing.T) { location := resp.Header().Get("Location") assert.NotEmpty(t, location) - resp2, err := (&http.Client{}).Get(location) + resp2, err := httplib.GetDefaultClient().Get(location) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp2.StatusCode)