diff --git a/internal/app/cmd/cmd.go b/internal/app/cmd/cmd.go
index e8db01de..aada3bbe 100644
--- a/internal/app/cmd/cmd.go
+++ b/internal/app/cmd/cmd.go
@@ -10,7 +10,6 @@ import (
 
 	"github.com/spf13/cobra"
 
-	"code.forgejo.org/forgejo/runner/v11/internal/pkg/common"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/config"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/ver"
 )
@@ -46,7 +45,7 @@ func Execute(ctx context.Context) {
 		Use:   "daemon",
 		Short: "Run as a runner daemon",
 		Args:  cobra.MaximumNArgs(1),
-		RunE:  runDaemon(common.WithDaemonContext(ctx, ctx), &configFile),
+		RunE:  getRunDaemonCommandProcessor(ctx, &configFile),
 	}
 	rootCmd.AddCommand(daemonCmd)
 
diff --git a/internal/app/cmd/daemon.go b/internal/app/cmd/daemon.go
index cb3d278a..d864af82 100644
--- a/internal/app/cmd/daemon.go
+++ b/internal/app/cmd/daemon.go
@@ -21,119 +21,82 @@ import (
 	"code.forgejo.org/forgejo/runner/v11/internal/app/poll"
 	"code.forgejo.org/forgejo/runner/v11/internal/app/run"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/client"
+	"code.forgejo.org/forgejo/runner/v11/internal/pkg/common"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/config"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/envcheck"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/labels"
 	"code.forgejo.org/forgejo/runner/v11/internal/pkg/ver"
 )
 
-func runDaemon(ctx context.Context, configFile *string) func(cmd *cobra.Command, args []string) error {
+func getRunDaemonCommandProcessor(signalContext context.Context, configFile *string) func(cmd *cobra.Command, args []string) error {
 	return func(cmd *cobra.Command, args []string) error {
-		cfg, err := config.LoadDefault(*configFile)
-		if err != nil {
-			return fmt.Errorf("invalid configuration: %w", err)
-		}
-
-		initLogging(cfg)
-		log.Infoln("Starting runner daemon")
-
-		reg, err := config.LoadRegistration(cfg.Runner.File)
-		if os.IsNotExist(err) {
-			log.Error("registration file not found, please register the runner first")
-			return err
-		} else if err != nil {
-			return fmt.Errorf("failed to load registration file: %w", err)
-		}
-
-		cfg.Tune(reg.Address)
-
-		lbls := reg.Labels
-		if len(cfg.Runner.Labels) > 0 {
-			lbls = cfg.Runner.Labels
-		}
-
-		ls := labels.Labels{}
-		for _, l := range lbls {
-			label, err := labels.Parse(l)
-			if err != nil {
-				log.WithError(err).Warnf("ignored invalid label %q", l)
-				continue
-			}
-			ls = append(ls, label)
-		}
-		if len(ls) == 0 {
-			log.Warn("no labels configured, runner may not be able to pick up jobs")
-		}
-
-		if ls.RequireDocker() {
-			dockerSocketPath, err := getDockerSocketPath(cfg.Container.DockerHost)
-			if err != nil {
-				return err
-			}
-			if err := envcheck.CheckIfDockerRunning(ctx, dockerSocketPath); err != nil {
-				return err
-			}
-			os.Setenv("DOCKER_HOST", dockerSocketPath)
-			if cfg.Container.DockerHost == "automount" {
-				cfg.Container.DockerHost = dockerSocketPath
-			}
-			// check the scheme, if the scheme is not npipe or unix
-			// set cfg.Container.DockerHost to "-" because it can't be mounted to the job container
-			if protoIndex := strings.Index(cfg.Container.DockerHost, "://"); protoIndex != -1 {
-				scheme := cfg.Container.DockerHost[:protoIndex]
-				if !strings.EqualFold(scheme, "npipe") && !strings.EqualFold(scheme, "unix") {
-					cfg.Container.DockerHost = "-"
-				}
-			}
-		}
-
-		cli := client.New(
-			reg.Address,
-			cfg.Runner.Insecure,
-			reg.UUID,
-			reg.Token,
-			ver.Version(),
-		)
-
-		runner := run.NewRunner(cfg, reg, cli)
-		// declare the labels of the runner before fetching tasks
-		resp, err := runner.Declare(ctx, ls.Names())
-		if err != nil && connect.CodeOf(err) == connect.CodeUnimplemented {
-			log.Warn("Because the Forgejo instance is an old version, skipping declaring the labels and version.")
-		} else if err != nil {
-			log.WithError(err).Error("fail to invoke Declare")
-			return err
-		} else {
-			log.Infof("runner: %s, with version: %s, with labels: %v, declared successfully",
-				resp.Msg.GetRunner().GetName(), resp.Msg.GetRunner().GetVersion(), resp.Msg.GetRunner().GetLabels())
-			// if declared successfully, override the labels in the.runner file with valid labels in the config file (if specified)
-			runner.Update(ctx, ls)
-			reg.Labels = ls.ToStrings()
-			if err := config.SaveRegistration(cfg.Runner.File, reg); err != nil {
-				return fmt.Errorf("failed to save runner config: %w", err)
-			}
-		}
-
-		poller := poll.New(ctx, cfg, cli, runner)
-
-		go poller.Poll()
-
-		<-ctx.Done()
-		log.Infof("runner: %s shutdown initiated, waiting [runner].shutdown_timeout=%s for running jobs to complete before shutting down", resp.Msg.GetRunner().GetName(), cfg.Runner.ShutdownTimeout)
-
-		ctx, cancel := context.WithTimeout(context.Background(), cfg.Runner.ShutdownTimeout)
-		defer cancel()
-
-		err = poller.Shutdown(ctx)
-		if err != nil {
-			log.Warnf("runner: %s cancelled in progress jobs during shutdown", resp.Msg.GetRunner().GetName())
-		}
-		return nil
+		return runDaemon(signalContext, configFile)
 	}
 }
 
+func runDaemon(signalContext context.Context, configFile *string) error {
+	// signalContext will be 'done' when we receive a graceful shutdown signal; daemonContext is not a derived context
+	// because we want it to 'outlive' the signalContext in order to perform graceful cleanup.
+	daemonContext, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx := common.WithDaemonContext(daemonContext, daemonContext)
+
+	cfg, err := initializeConfig(configFile)
+	if err != nil {
+		return err
+	}
+
+	initLogging(cfg)
+	log.Infoln("Starting runner daemon")
+
+	reg, err := loadRegistration(cfg)
+	if err != nil {
+		return err
+	}
+
+	cfg.Tune(reg.Address)
+	ls := extractLabels(cfg, reg)
+
+	err = configCheck(ctx, cfg, ls)
+	if err != nil {
+		return err
+	}
+
+	cli := createClient(cfg, reg)
+
+	runner, runnerName, err := createRunner(ctx, cfg, reg, cli, ls)
+	if err != nil {
+		return err
+	}
+
+	poller := createPoller(ctx, cfg, cli, runner)
+
+	go poller.Poll()
+
+	<-signalContext.Done()
+	log.Infof("runner: %s shutdown initiated, waiting [runner].shutdown_timeout=%s for running jobs to complete before shutting down", runnerName, cfg.Runner.ShutdownTimeout)
+
+	shutdownCtx, cancel := context.WithTimeout(daemonContext, cfg.Runner.ShutdownTimeout)
+	defer cancel()
+
+	err = poller.Shutdown(shutdownCtx)
+	if err != nil {
+		log.Warnf("runner: %s cancelled in progress jobs during shutdown", runnerName)
+	}
+	return nil
+}
+
+var initializeConfig = func(configFile *string) (*config.Config, error) {
+	cfg, err := config.LoadDefault(*configFile)
+	if err != nil {
+		return nil, fmt.Errorf("invalid configuration: %w", err)
+	}
+	return cfg, nil
+}
+
 // initLogging setup the global logrus logger.
-func initLogging(cfg *config.Config) {
+var initLogging = func(cfg *config.Config) {
 	isTerm := isatty.IsTerminal(os.Stdout.Fd())
 	format := &log.TextFormatter{
 		DisableColors: !isTerm,
@@ -170,6 +133,98 @@ func initLogging(cfg *config.Config) {
 	}
 }
 
+var loadRegistration = func(cfg *config.Config) (*config.Registration, error) {
+	reg, err := config.LoadRegistration(cfg.Runner.File)
+	if os.IsNotExist(err) {
+		log.Error("registration file not found, please register the runner first")
+		return nil, err
+	} else if err != nil {
+		return nil, fmt.Errorf("failed to load registration file: %w", err)
+	}
+	return reg, nil
+}
+
+var extractLabels = func(cfg *config.Config, reg *config.Registration) labels.Labels {
+	lbls := reg.Labels
+	if len(cfg.Runner.Labels) > 0 {
+		lbls = cfg.Runner.Labels
+	}
+
+	ls := labels.Labels{}
+	for _, l := range lbls {
+		label, err := labels.Parse(l)
+		if err != nil {
+			log.WithError(err).Warnf("ignored invalid label %q", l)
+			continue
+		}
+		ls = append(ls, label)
+	}
+	if len(ls) == 0 {
+		log.Warn("no labels configured, runner may not be able to pick up jobs")
+	}
+	return ls
+}
+
+var configCheck = func(ctx context.Context, cfg *config.Config, ls labels.Labels) error {
+	if ls.RequireDocker() {
+		dockerSocketPath, err := getDockerSocketPath(cfg.Container.DockerHost)
+		if err != nil {
+			return err
+		}
+		if err := envcheck.CheckIfDockerRunning(ctx, dockerSocketPath); err != nil {
+			return err
+		}
+		os.Setenv("DOCKER_HOST", dockerSocketPath)
+		if cfg.Container.DockerHost == "automount" {
+			cfg.Container.DockerHost = dockerSocketPath
+		}
+		// check the scheme, if the scheme is not npipe or unix
+		// set cfg.Container.DockerHost to "-" because it can't be mounted to the job container
+		if protoIndex := strings.Index(cfg.Container.DockerHost, "://"); protoIndex != -1 {
+			scheme := cfg.Container.DockerHost[:protoIndex]
+			if !strings.EqualFold(scheme, "npipe") && !strings.EqualFold(scheme, "unix") {
+				cfg.Container.DockerHost = "-"
+			}
+		}
+	}
+	return nil
+}
+
+var createClient = func(cfg *config.Config, reg *config.Registration) client.Client {
+	return client.New(
+		reg.Address,
+		cfg.Runner.Insecure,
+		reg.UUID,
+		reg.Token,
+		ver.Version(),
+	)
+}
+
+var createRunner = func(ctx context.Context, cfg *config.Config, reg *config.Registration, cli client.Client, ls labels.Labels) (run.RunnerInterface, string, error) {
+	runner := run.NewRunner(cfg, reg, cli)
+	// declare the labels of the runner before fetching tasks
+	resp, err := runner.Declare(ctx, ls.Names())
+	if err != nil && connect.CodeOf(err) == connect.CodeUnimplemented {
+		log.Warn("Because the Forgejo instance is an old version, skipping declaring the labels and version.")
+	} else if err != nil {
+		log.WithError(err).Error("fail to invoke Declare")
+		return nil, "", err
+	} else {
+		log.Infof("runner: %s, with version: %s, with labels: %v, declared successfully",
+			resp.Msg.GetRunner().GetName(), resp.Msg.GetRunner().GetVersion(), resp.Msg.GetRunner().GetLabels())
+		// if declared successfully, override the labels in the.runner file with valid labels in the config file (if specified)
+		runner.Update(ctx, ls)
+		reg.Labels = ls.ToStrings()
+		if err := config.SaveRegistration(cfg.Runner.File, reg); err != nil {
+			return nil, "", fmt.Errorf("failed to save runner config: %w", err)
+		}
+	}
+	return runner, resp.Msg.GetRunner().GetName(), nil
+}
+
+// func(ctx context.Context, cfg *config.Config, cli client.Client, runner run.RunnerInterface) poll.Poller
+var createPoller = poll.New
+
 var commonSocketPaths = []string{
 	"/var/run/docker.sock",
 	"/run/podman/podman.sock",
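The refactored `runDaemon` hands the long-lived daemon context to `common.WithDaemonContext` so that downstream code receives a context that is not cancelled by the shutdown signal. That helper's implementation is not part of this diff; a minimal sketch of what it could look like, assuming it simply stores the daemon context as a context value (the key type and the `DaemonContext` accessor below are hypothetical, inferred only from the call sites in this patch):

```go
// Hypothetical sketch only: the real code.forgejo.org/forgejo/runner/v11/internal/pkg/common
// implementation is not shown in this diff.
package common

import "context"

type daemonContextKey struct{}

// WithDaemonContext returns a copy of ctx carrying daemonCtx as a value, so code deep in the
// call tree can retrieve a context that outlives ctx's cancellation during graceful shutdown.
func WithDaemonContext(ctx, daemonCtx context.Context) context.Context {
	return context.WithValue(ctx, daemonContextKey{}, daemonCtx)
}

// DaemonContext returns the stored daemon context, falling back to ctx itself. (Hypothetical accessor.)
func DaemonContext(ctx context.Context) context.Context {
	if v, ok := ctx.Value(daemonContextKey{}).(context.Context); ok {
		return v
	}
	return ctx
}
```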
diff --git a/internal/app/cmd/daemon_test.go b/internal/app/cmd/daemon_test.go
new file mode 100644
index 00000000..16ce8e50
--- /dev/null
+++ b/internal/app/cmd/daemon_test.go
@@ -0,0 +1,117 @@
+// Copyright 2025 The Forgejo Authors
+// SPDX-License-Identifier: MIT
+
+package cmd
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"code.forgejo.org/forgejo/runner/v11/internal/app/poll"
+	mock_poller "code.forgejo.org/forgejo/runner/v11/internal/app/poll/mocks"
+	"code.forgejo.org/forgejo/runner/v11/internal/app/run"
+	mock_runner "code.forgejo.org/forgejo/runner/v11/internal/app/run/mocks"
+	"code.forgejo.org/forgejo/runner/v11/internal/pkg/client"
+	mock_client "code.forgejo.org/forgejo/runner/v11/internal/pkg/client/mocks"
+	"code.forgejo.org/forgejo/runner/v11/internal/pkg/config"
+	"code.forgejo.org/forgejo/runner/v11/internal/pkg/labels"
+	"code.forgejo.org/forgejo/runner/v11/testutils"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRunDaemonGracefulShutdown(t *testing.T) {
+	// Key assertions for graceful shutdown test:
+	//
+	// - ctx passed to createRunner, createPoller, and Shutdown must outlive signalContext passed to runDaemon, allowing
+	//   the poller to operate without errors after termination signal is received: #1
+	//
+	// - When shutting down, the order of operations should be: close signalContext, which causes Shutdown mock to be
+	//   invoked, and Shutdown mock causes the Poll method to be stopped: #2
+
+	mockClient := mock_client.NewClient(t)
+	mockRunner := mock_runner.NewRunnerInterface(t)
+	mockPoller := mock_poller.NewPoller(t)
+
+	defer testutils.MockVariable(&initializeConfig, func(configFile *string) (*config.Config, error) {
+		return &config.Config{
+			Runner: config.Runner{
+				// Default ShutdownTimeout of 0s won't work for the graceful shutdown test.
+				ShutdownTimeout: 30 * time.Second,
+			},
+		}, nil
+	})()
+	defer testutils.MockVariable(&initLogging, func(cfg *config.Config) {})()
+	defer testutils.MockVariable(&loadRegistration, func(cfg *config.Config) (*config.Registration, error) {
+		return &config.Registration{}, nil
+	})()
+	defer testutils.MockVariable(&extractLabels, func(cfg *config.Config, reg *config.Registration) labels.Labels {
+		return labels.Labels{}
+	})()
+	defer testutils.MockVariable(&configCheck, func(ctx context.Context, cfg *config.Config, ls labels.Labels) error {
+		return nil
+	})()
+	defer testutils.MockVariable(&createClient, func(cfg *config.Config, reg *config.Registration) client.Client {
+		return mockClient
+	})()
+	var runnerContext context.Context
+	defer testutils.MockVariable(&createRunner, func(ctx context.Context, cfg *config.Config, reg *config.Registration, cli client.Client, ls labels.Labels) (run.RunnerInterface, string, error) {
+		runnerContext = ctx
+		return mockRunner, "runner", nil
+	})()
+	var pollerContext context.Context
+	defer testutils.MockVariable(&createPoller, func(ctx context.Context, cfg *config.Config, cli client.Client, runner run.RunnerInterface) poll.Poller {
+		pollerContext = ctx
+		return mockPoller
+	})()
+
+	pollBegunChannel := make(chan interface{})
+	shutdownChannel := make(chan interface{})
+	mockPoller.On("Poll").Run(func(args mock.Arguments) {
+		close(pollBegunChannel)
+		// Simulate running the poll by waiting and doing nothing until shutdownChannel says Shutdown was invoked
+		require.NotNil(t, pollerContext)
+		select {
+		case <-pollerContext.Done():
+			assert.Fail(t, "pollerContext was closed before shutdownChannel") // #1
+			return
+		case <-shutdownChannel:
+			return
+		}
+	})
+	mockPoller.On("Shutdown", mock.Anything).Run(func(args mock.Arguments) {
+		shutdownContext := args.Get(0).(context.Context)
+		select {
+		case <-shutdownContext.Done():
+			assert.Fail(t, "shutdownContext was closed, but was expected to be open") // #1
+			return
+		case <-runnerContext.Done():
+			assert.Fail(t, "runnerContext was closed, but was expected to be open") // #1
+			return
+		case <-time.After(time.Microsecond):
+			close(shutdownChannel)
+			return
+		}
+	}).Return(nil)
+
+	// When runDaemon is begun, it will run "forever" until the passed-in context is done. So, let's start that in a goroutine...
+	mockSignalContext, cancelSignal := context.WithCancel(t.Context())
+	runDaemonComplete := make(chan interface{})
+	go func() {
+		configFile := "config.yaml"
+		err := runDaemon(mockSignalContext, &configFile)
+		close(runDaemonComplete)
+		require.NoError(t, err)
+	}()
+
+	// Wait until runDaemon reaches poller.Poll(), where we expect graceful shutdown to trigger
+	<-pollBegunChannel
+
+	// Now we'll signal to the daemon to begin graceful shutdown; this begins the events described in #2
+	cancelSignal()
+
+	// Wait for the daemon goroutine to stop
+	<-runDaemonComplete
+}
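The test above swaps the package-level function variables through `testutils.MockVariable`, which is not included in this diff. Judging from its call sites (`defer testutils.MockVariable(&initializeConfig, ...)()`), it is presumably a small generic swap-and-restore helper; a sketch under that assumption, not the actual implementation:

```go
// Hypothetical sketch of testutils.MockVariable; the real helper lives in the repository's
// testutils package and is not part of this diff.
package testutils

// MockVariable replaces *target with replacement and returns a function that restores the
// original value, intended to be used as:
//
//	defer testutils.MockVariable(&someVar, fake)()
func MockVariable[T any](target *T, replacement T) func() {
	original := *target
	*target = replacement
	return func() { *target = original }
}
```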
diff --git a/internal/app/poll/mocks/Poller.go b/internal/app/poll/mocks/Poller.go
new file mode 100644
index 00000000..9d2b1e4d
--- /dev/null
+++ b/internal/app/poll/mocks/Poller.go
@@ -0,0 +1,52 @@
+// Code generated by mockery v2.53.5. DO NOT EDIT.
+
+package mocks
+
+import (
+	context "context"
+
+	mock "github.com/stretchr/testify/mock"
+)
+
+// Poller is an autogenerated mock type for the Poller type
+type Poller struct {
+	mock.Mock
+}
+
+// Poll provides a mock function with no fields
+func (_m *Poller) Poll() {
+	_m.Called()
+}
+
+// Shutdown provides a mock function with given fields: ctx
+func (_m *Poller) Shutdown(ctx context.Context) error {
+	ret := _m.Called(ctx)
+
+	if len(ret) == 0 {
+		panic("no return value specified for Shutdown")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context) error); ok {
+		r0 = rf(ctx)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// NewPoller creates a new instance of Poller. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewPoller(t interface {
+	mock.TestingT
+	Cleanup(func())
+},
+) *Poller {
+	mock := &Poller{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/internal/app/poll/poller.go b/internal/app/poll/poller.go
index 452f356d..7b954cb7 100644
--- a/internal/app/poll/poller.go
+++ b/internal/app/poll/poller.go
@@ -22,6 +22,7 @@ import (
 
 const PollerID = "PollerID"
 
+//go:generate mockery --name Poller
 type Poller interface {
 	Poll()
 	Shutdown(ctx context.Context) error
diff --git a/internal/app/run/mocks/RunnerInterface.go b/internal/app/run/mocks/RunnerInterface.go
new file mode 100644
index 00000000..16dbfecb
--- /dev/null
+++ b/internal/app/run/mocks/RunnerInterface.go
@@ -0,0 +1,49 @@
+// Code generated by mockery v2.53.5. DO NOT EDIT.
+
+package mocks
+
+import (
+	context "context"
+
+	mock "github.com/stretchr/testify/mock"
+
+	runnerv1 "code.forgejo.org/forgejo/actions-proto/runner/v1"
+)
+
+// RunnerInterface is an autogenerated mock type for the RunnerInterface type
+type RunnerInterface struct {
+	mock.Mock
+}
+
+// Run provides a mock function with given fields: ctx, task
+func (_m *RunnerInterface) Run(ctx context.Context, task *runnerv1.Task) error {
+	ret := _m.Called(ctx, task)
+
+	if len(ret) == 0 {
+		panic("no return value specified for Run")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, *runnerv1.Task) error); ok {
+		r0 = rf(ctx, task)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// NewRunnerInterface creates a new instance of RunnerInterface. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewRunnerInterface(t interface {
+	mock.TestingT
+	Cleanup(func())
+},
+) *RunnerInterface {
+	mock := &RunnerInterface{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/internal/app/run/runner.go b/internal/app/run/runner.go
index 98e1c524..3f2b0127 100644
--- a/internal/app/run/runner.go
+++ b/internal/app/run/runner.go
@@ -48,6 +48,7 @@ type Runner struct {
 	runningTasks sync.Map
 }
 
+//go:generate mockery --name RunnerInterface
 type RunnerInterface interface {
 	Run(ctx context.Context, task *runnerv1.Task) error
 }
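The generated mocks follow the standard mockery/testify pattern, so they can be used beyond `daemon_test.go`. A hypothetical example of stubbing the generated `RunnerInterface` mock's `Run` method (illustrative only; the test name and the task value are made up, not part of this change):

```go
// Illustrative usage of the generated RunnerInterface mock; not part of this diff.
package run_test

import (
	"testing"

	runnerv1 "code.forgejo.org/forgejo/actions-proto/runner/v1"
	mock_runner "code.forgejo.org/forgejo/runner/v11/internal/app/run/mocks"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

func TestRunnerInterfaceMockUsage(t *testing.T) {
	mockRunner := mock_runner.NewRunnerInterface(t)
	// Expect exactly one Run call with any context and any task, and make it succeed.
	mockRunner.On("Run", mock.Anything, mock.Anything).Return(nil).Once()

	err := mockRunner.Run(t.Context(), &runnerv1.Task{})
	require.NoError(t, err)
	// NewRunnerInterface registered a t.Cleanup that asserts the expectation was met.
}
```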