diff --git a/internal/pkg/report/reporter.go b/internal/pkg/report/reporter.go index 68fa3db4..3233cb2a 100644 --- a/internal/pkg/report/reporter.go +++ b/internal/pkg/report/reporter.go @@ -209,6 +209,17 @@ func (r *Reporter) logf(format string, a ...any) { } } +func (r *Reporter) GetOutputs() map[string]string { + outputs := make(map[string]string) + r.outputs.Range(func(k, v any) bool { + if val, ok := v.(string); ok { + outputs[k.(string)] = val + } + return true + }) + return outputs +} + func (r *Reporter) SetOutputs(outputs map[string]string) error { r.stateMu.Lock() defer r.stateMu.Unlock() @@ -360,13 +371,7 @@ func (r *Reporter) ReportState() error { state := proto.Clone(r.state).(*runnerv1.TaskState) r.stateMu.RUnlock() - outputs := make(map[string]string) - r.outputs.Range(func(k, v any) bool { - if val, ok := v.(string); ok { - outputs[k.(string)] = val - } - return true - }) + outputs := r.GetOutputs() resp, err := r.client.UpdateTask(r.ctx, connect.NewRequest(&runnerv1.UpdateTaskRequest{ State: state, @@ -380,7 +385,8 @@ func (r *Reporter) ReportState() error { r.outputs.Store(k, struct{}{}) } - if resp.Msg.GetState().GetResult() == runnerv1.Result_RESULT_CANCELLED { + switch resp.Msg.GetState().GetResult() { + case runnerv1.Result_RESULT_CANCELLED, runnerv1.Result_RESULT_FAILURE: r.cancel() } diff --git a/internal/pkg/report/reporter_test.go b/internal/pkg/report/reporter_test.go index bbb084b1..65fcbfbe 100644 --- a/internal/pkg/report/reporter_test.go +++ b/internal/pkg/report/reporter_test.go @@ -286,6 +286,112 @@ func TestReporter_Fire(t *testing.T) { }) } +func TestReporterReportState(t *testing.T) { + for _, testCase := range []struct { + name string + fixture func(t *testing.T, reporter *Reporter, client *mocks.Client) + assert func(t *testing.T, reporter *Reporter, ctx context.Context, err error) + }{ + { + name: "PartialOutputs", + fixture: func(t *testing.T, reporter *Reporter, client *mocks.Client) { + t.Helper() + outputKey1 := "KEY1" + outputValue1 := "VALUE1" + outputKey2 := "KEY2" + outputValue2 := "VALUE2" + reporter.SetOutputs(map[string]string{ + outputKey1: outputValue1, + outputKey2: outputValue2, + }) + + client.On("UpdateTask", mock.Anything, mock.Anything).Return(func(_ context.Context, req *connect_go.Request[runnerv1.UpdateTaskRequest]) (*connect_go.Response[runnerv1.UpdateTaskResponse], error) { + t.Logf("Received UpdateTask: %s", req.Msg.String()) + return connect_go.NewResponse(&runnerv1.UpdateTaskResponse{ + SentOutputs: []string{outputKey1}, + }), nil + }) + }, + assert: func(t *testing.T, reporter *Reporter, ctx context.Context, err error) { + t.Helper() + require.ErrorContains(t, err, "not all logs are submitted 1 remain") + outputs := reporter.GetOutputs() + assert.Equal(t, map[string]string{ + "KEY2": "VALUE2", + }, outputs) + assert.NoError(t, ctx.Err()) + }, + }, + { + name: "AllDone", + fixture: func(t *testing.T, reporter *Reporter, client *mocks.Client) { + t.Helper() + client.On("UpdateTask", mock.Anything, mock.Anything).Return(func(_ context.Context, req *connect_go.Request[runnerv1.UpdateTaskRequest]) (*connect_go.Response[runnerv1.UpdateTaskResponse], error) { + t.Logf("Received UpdateTask: %s", req.Msg.String()) + return connect_go.NewResponse(&runnerv1.UpdateTaskResponse{}), nil + }) + }, + assert: func(t *testing.T, reporter *Reporter, ctx context.Context, err error) { + t.Helper() + require.NoError(t, err) + assert.NoError(t, ctx.Err()) + }, + }, + { + name: "Canceled", + fixture: func(t *testing.T, reporter *Reporter, client *mocks.Client) { + t.Helper() + client.On("UpdateTask", mock.Anything, mock.Anything).Return(func(_ context.Context, req *connect_go.Request[runnerv1.UpdateTaskRequest]) (*connect_go.Response[runnerv1.UpdateTaskResponse], error) { + t.Logf("Received UpdateTask: %s", req.Msg.String()) + return connect_go.NewResponse(&runnerv1.UpdateTaskResponse{ + State: &runnerv1.TaskState{ + Result: runnerv1.Result_RESULT_CANCELLED, + }, + }), nil + }) + }, + assert: func(t *testing.T, reporter *Reporter, ctx context.Context, err error) { + t.Helper() + require.NoError(t, err) + assert.ErrorIs(t, ctx.Err(), context.Canceled) + }, + }, + { + name: "Failed", + fixture: func(t *testing.T, reporter *Reporter, client *mocks.Client) { + t.Helper() + client.On("UpdateTask", mock.Anything, mock.Anything).Return(func(_ context.Context, req *connect_go.Request[runnerv1.UpdateTaskRequest]) (*connect_go.Response[runnerv1.UpdateTaskResponse], error) { + t.Logf("Received UpdateTask: %s", req.Msg.String()) + return connect_go.NewResponse(&runnerv1.UpdateTaskResponse{ + State: &runnerv1.TaskState{ + Result: runnerv1.Result_RESULT_FAILURE, + }, + }), nil + }) + }, + assert: func(t *testing.T, reporter *Reporter, ctx context.Context, err error) { + t.Helper() + require.NoError(t, err) + assert.ErrorIs(t, ctx.Err(), context.Canceled) + }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + client := mocks.NewClient(t) + ctx, cancel := context.WithCancel(context.Background()) + taskCtx, err := structpb.NewStruct(map[string]any{}) + require.NoError(t, err) + reporter := NewReporter(common.WithDaemonContext(ctx, t.Context()), cancel, client, &runnerv1.Task{ + Context: taskCtx, + }, time.Second) + + testCase.fixture(t, reporter, client) + err = reporter.ReportState() + testCase.assert(t, reporter, ctx, err) + }) + } +} + func TestReporterReportLogLost(t *testing.T) { reporter, client, _ := mockReporter(t) reporter.logRows = stringToRows("A")