Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions coderd/aitasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,10 +977,27 @@ func (api *API) authAndDoWithTaskAppClient(
ctx := r.Context()

if task.Status != database.TaskStatusActive {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task status must be active.",
Detail: fmt.Sprintf("Task status is %q, it must be %q to interact with the task.", task.Status, codersdk.TaskStatusActive),
})
// Return 409 Conflict for valid requests blocked by current state
// (pending/initializing are transitional, paused requires resume).
// Return 400 Bad Request for error/unknown states.
switch task.Status {
case database.TaskStatusPending, database.TaskStatusInitializing:
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
Message: fmt.Sprintf("Task is %s.", task.Status),
Detail: "The task is resuming. Wait for the task to become active before sending messages.",
})
case database.TaskStatusPaused:
return httperror.NewResponseError(http.StatusConflict, codersdk.Response{
Message: "Task is paused.",
Detail: "Resume the task to send messages.",
})
default:
// Default handler for error and unknown status.
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Message: "Task must be active.",
Detail: fmt.Sprintf("Task status is %q, it must be %q to interact with the task.", task.Status, codersdk.TaskStatusActive),
})
}
}
if !task.WorkspaceID.Valid {
return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{
Expand Down
219 changes: 157 additions & 62 deletions coderd/aitasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
Expand All @@ -39,6 +40,66 @@ import (
"github.com/coder/quartz"
)

// createTaskInState is a helper to create a task in the desired state.
// It returns a function that takes context, test, and status, and returns the task ID.
// The caller is responsible for setting up the database, owner, and user.
func createTaskInState(db database.Store, ownerSubject rbac.Subject, ownerOrgID, userID uuid.UUID) func(context.Context, *testing.T, database.TaskStatus) uuid.UUID {
return func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
ctx = dbauthz.As(ctx, ownerSubject)

builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: ownerOrgID,
OwnerID: userID,
}).
WithTask(database.TaskTable{
OrganizationID: ownerOrgID,
OwnerID: userID,
}, nil)

switch status {
case database.TaskStatusPending:
builder = builder.Pending()
case database.TaskStatusInitializing:
builder = builder.Starting()
case database.TaskStatusPaused:
builder = builder.Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
})
case database.TaskStatusError:
// For error state, create a completed build then manipulate app health.
default:
require.Fail(t, "unsupported task status in test helper", "status: %s", status)
}

resp := builder.Do()
taskID := resp.Task.ID

// Post-process by manipulating agent and app state.
if status == database.TaskStatusError {
// First, set agent to ready state so agent_status returns 'active'.
// This ensures the cascade reaches app_status.
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: resp.Agents[0].ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
require.NoError(t, err)

// Then set workspace app health to unhealthy to trigger error state.
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
require.NoError(t, err)
require.Len(t, apps, 1, "expected exactly one app for task")

err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: apps[0].ID,
Health: database.WorkspaceAppHealthUnhealthy,
})
require.NoError(t, err)
}

return taskID
}
}

func TestTasks(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -591,6 +652,94 @@ func TestTasks(t *testing.T) {
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})

t.Run("SendToNonActiveStates", func(t *testing.T) {
t.Parallel()

client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitMedium)

ownerUser, err := client.User(ctx, owner.UserID.String())
require.NoError(t, err)
ownerSubject := coderdtest.AuthzUserSubject(ownerUser)

// Create a regular user for task ownership.
_, user := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)

createTask := createTaskInState(db, ownerSubject, owner.OrganizationID, user.ID)

t.Run("Paused", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPaused)

err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})

var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "paused")
require.Contains(t, sdkErr.Detail, "Resume")
})

t.Run("Initializing", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusInitializing)

err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})

var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "initializing")
require.Contains(t, sdkErr.Detail, "resuming")
})

t.Run("Pending", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusPending)

err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})

var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusConflict, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "pending")
require.Contains(t, sdkErr.Detail, "resuming")
})

t.Run("Error", func(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTask(ctx, t, database.TaskStatusError)

err := client.TaskSend(ctx, "me", taskID, codersdk.TaskSendRequest{
Input: "Hello",
})

var sdkErr *codersdk.Error
require.Error(t, err)
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
require.Contains(t, sdkErr.Message, "must be active")
})
})
})

t.Run("Logs", func(t *testing.T) {
Expand Down Expand Up @@ -737,61 +886,7 @@ func TestTasks(t *testing.T) {
// Create a regular user to test snapshot access.
client, user := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID)

// Helper to create a task in the desired state.
createTaskInState := func(ctx context.Context, t *testing.T, status database.TaskStatus) uuid.UUID {
ctx = dbauthz.As(ctx, ownerSubject)

builder := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: user.ID,
}).
WithTask(database.TaskTable{
OrganizationID: owner.OrganizationID,
OwnerID: user.ID,
}, nil)

switch status {
case database.TaskStatusPending:
builder = builder.Pending()
case database.TaskStatusInitializing:
builder = builder.Starting()
case database.TaskStatusPaused:
builder = builder.Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
})
case database.TaskStatusError:
// For error state, create a completed build then manipulate app health.
default:
require.Fail(t, "unsupported task status in test helper", "status: %s", status)
}

resp := builder.Do()
taskID := resp.Task.ID

// Post-process by manipulating agent and app state.
if status == database.TaskStatusError {
// First, set agent to ready state so agent_status returns 'active'.
// This ensures the cascade reaches app_status.
err := db.UpdateWorkspaceAgentLifecycleStateByID(ctx, database.UpdateWorkspaceAgentLifecycleStateByIDParams{
ID: resp.Agents[0].ID,
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
})
require.NoError(t, err)

// Then set workspace app health to unhealthy to trigger error state.
apps, err := db.GetWorkspaceAppsByAgentID(ctx, resp.Agents[0].ID)
require.NoError(t, err)
require.Len(t, apps, 1, "expected exactly one app for task")

err = db.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{
ID: apps[0].ID,
Health: database.WorkspaceAppHealthUnhealthy,
})
require.NoError(t, err)
}

return taskID
}
createTask := createTaskInState(db, ownerSubject, owner.OrganizationID, user.ID)

// Prepare snapshot data used across tests.
snapshotMessages := []agentapisdk.Message{
Expand Down Expand Up @@ -853,7 +948,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)

err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
Expand All @@ -871,7 +966,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusInitializing)
taskID := createTask(ctx, t, database.TaskStatusInitializing)

err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
Expand All @@ -889,7 +984,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusPaused)
taskID := createTask(ctx, t, database.TaskStatusPaused)

err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
Expand All @@ -907,7 +1002,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)

logsResp, err := client.TaskLogs(ctx, "me", taskID)
require.NoError(t, err)
Expand All @@ -921,7 +1016,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)

invalidEnvelope := coderd.TaskLogSnapshotEnvelope{
Format: "unknown-format",
Expand Down Expand Up @@ -950,7 +1045,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusPending)
taskID := createTask(ctx, t, database.TaskStatusPending)

err := db.UpsertTaskSnapshot(dbauthz.As(ctx, ownerSubject), database.UpsertTaskSnapshotParams{
TaskID: taskID,
Expand All @@ -971,7 +1066,7 @@ func TestTasks(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, testutil.WaitMedium)
taskID := createTaskInState(ctx, t, database.TaskStatusError)
taskID := createTask(ctx, t, database.TaskStatusError)

_, err := client.TaskLogs(ctx, "me", taskID)
require.Error(t, err)
Expand Down
Loading