Skip to content

Commit 6299d12

Browse files
committed
cleanup
1 parent 95fcea8 commit 6299d12

9 files changed

Lines changed: 57 additions & 101 deletions

File tree

internal/agent/handler.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package agent
44
import (
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"net/http"
910
"strings"
@@ -341,7 +342,7 @@ func (h *Handler) GetConversation(w http.ResponseWriter, r *http.Request) {
341342

342343
conv, err := GetConversation(ctx, userID, conversationID)
343344
if err != nil {
344-
if strings.Contains(err.Error(), "not found") {
345+
if errors.Is(err, ErrConversationNotFound) {
345346
writeJSONError(w, http.StatusNotFound, "conversation not found")
346347
return
347348
}
@@ -378,7 +379,7 @@ func (h *Handler) DeleteConversation(w http.ResponseWriter, r *http.Request) {
378379
}
379380

380381
if err := DeleteConversation(ctx, userID, conversationID); err != nil {
381-
if strings.Contains(err.Error(), "not found") {
382+
if errors.Is(err, ErrConversationNotFound) {
382383
writeJSONError(w, http.StatusNotFound, "conversation not found")
383384
return
384385
}

internal/agent/orchestrator.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ const (
3131
OrchestratorEventThinking OrchestratorEventType = "thinking"
3232
OrchestratorEventChart OrchestratorEventType = "chart"
3333
OrchestratorEventError OrchestratorEventType = "error"
34-
OrchestratorEventUsage OrchestratorEventType = "usage"
3534
OrchestratorEventDone OrchestratorEventType = "done"
3635
)
3736

@@ -70,6 +69,9 @@ type ContentBlock struct {
7069
ToolCallID string `json:"tool_call_id,omitempty"`
7170
// ToolName is the name of the tool (for "tool_use" type blocks).
7271
ToolName string `json:"tool_name,omitempty"`
72+
// ToolArguments are the arguments passed to the tool (for "tool_use" type blocks).
73+
// Stored so the full tool call can be reconstructed when replaying history to the LLM.
74+
ToolArguments map[string]any `json:"tool_arguments,omitempty"`
7375
// ToolSuccess indicates whether the tool execution succeeded (for "tool_use" type blocks).
7476
ToolSuccess bool `json:"tool_success,omitempty"`
7577
// ToolResult is the result returned by the tool (for "tool_use" type blocks).
@@ -265,11 +267,12 @@ func (cl *conversationLoop) executeToolCall(toolCall chat.ToolCall) toolExecutio
265267
// tool result message for the next LLM iteration.
266268
func (cl *conversationLoop) recordToolExecution(result toolExecutionResult) {
267269
cl.blocks = append(cl.blocks, ContentBlock{
268-
Type: ContentBlockTypeToolUse,
269-
ToolCallID: result.ToolCall.ID,
270-
ToolName: result.ToolCall.Name,
271-
ToolSuccess: result.Success,
272-
ToolResult: result.Result,
270+
Type: ContentBlockTypeToolUse,
271+
ToolCallID: result.ToolCall.ID,
272+
ToolName: result.ToolCall.Name,
273+
ToolArguments: result.ToolCall.Arguments,
274+
ToolSuccess: result.Success,
275+
ToolResult: result.Result,
273276
})
274277

275278
if result.ChartData != nil {

internal/agent/queries.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package agent
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"strings"
89

@@ -13,6 +14,9 @@ import (
1314
"github.com/nais/api/internal/database"
1415
)
1516

17+
// ErrConversationNotFound is returned when a conversation does not exist or does not belong to the user.
18+
var ErrConversationNotFound = errors.New("conversation not found")
19+
1620
// GetOrCreateConversation retrieves an existing conversation or creates a new one.
1721
func GetOrCreateConversation(ctx context.Context, userID uuid.UUID, conversationID string, firstMessage string) (uuid.UUID, error) {
1822
if conversationID != "" {
@@ -29,7 +33,7 @@ func GetOrCreateConversation(ctx context.Context, userID uuid.UUID, conversation
2933
return uuid.Nil, fmt.Errorf("failed to check conversation: %w", err)
3034
}
3135
if !exists {
32-
return uuid.Nil, fmt.Errorf("conversation not found")
36+
return uuid.Nil, ErrConversationNotFound
3337
}
3438

3539
if err := db(ctx).TouchConversation(ctx, id); err != nil {
@@ -133,7 +137,7 @@ func GetConversation(ctx context.Context, userID uuid.UUID, conversationID uuid.
133137
})
134138
if err != nil {
135139
if err == pgx.ErrNoRows {
136-
return nil, fmt.Errorf("conversation not found")
140+
return nil, ErrConversationNotFound
137141
}
138142
return nil, fmt.Errorf("failed to get conversation: %w", err)
139143
}
@@ -195,7 +199,7 @@ func DeleteConversation(ctx context.Context, userID uuid.UUID, conversationID uu
195199
}
196200

197201
if rowsAffected == 0 {
198-
return fmt.Errorf("conversation not found")
202+
return ErrConversationNotFound
199203
}
200204

201205
return nil
@@ -291,8 +295,9 @@ func extractToolCallsFromBlocks(blocks []ContentBlock) []chat.ToolCall {
291295
for _, block := range blocks {
292296
if block.Type == ContentBlockTypeToolUse {
293297
toolCalls = append(toolCalls, chat.ToolCall{
294-
ID: block.ToolCallID,
295-
Name: block.ToolName,
298+
ID: block.ToolCallID,
299+
Name: block.ToolName,
300+
Arguments: block.ToolArguments,
296301
})
297302
}
298303
}

internal/agent/tools/chart.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func NewChartTools() *ChartTools {
1616

1717
// RenderChart validates chart parameters and prepares the chart data.
1818
// Note: The actual rendering happens on the client side using the returned ChartData.
19-
func (t *ChartTools) RenderChart(ctx context.Context, input RenderChartInput) (*ChartData, error) {
19+
func (t *ChartTools) RenderChart(ctx context.Context, input ChartData) (*ChartData, error) {
2020
// Validate required fields
2121
if input.ChartType == "" {
2222
return nil, fmt.Errorf("chart_type is required")
@@ -52,13 +52,5 @@ func (t *ChartTools) RenderChart(ctx context.Context, input RenderChartInput) (*
5252
}
5353
}
5454

55-
return &ChartData{
56-
ChartType: input.ChartType,
57-
Title: input.Title,
58-
Environment: input.Environment,
59-
Query: input.Query,
60-
Interval: input.Interval,
61-
YFormat: input.YFormat,
62-
LabelTemplate: input.LabelTemplate,
63-
}, nil
55+
return &input, nil
6456
}

internal/agent/tools/chart_test.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ func TestRenderChart(t *testing.T) {
1111

1212
tests := []struct {
1313
name string
14-
input RenderChartInput
14+
input ChartData
1515
wantErr bool
1616
errContains string
1717
validate func(*testing.T, *ChartData)
1818
}{
1919
{
2020
name: "valid chart with all fields",
21-
input: RenderChartInput{
21+
input: ChartData{
2222
ChartType: "line",
2323
Title: "CPU Usage",
2424
Environment: "dev",
@@ -54,7 +54,7 @@ func TestRenderChart(t *testing.T) {
5454
},
5555
{
5656
name: "valid chart with required fields only",
57-
input: RenderChartInput{
57+
input: ChartData{
5858
ChartType: "line",
5959
Title: "Memory Usage",
6060
Environment: "prod",
@@ -78,49 +78,49 @@ func TestRenderChart(t *testing.T) {
7878
},
7979
{
8080
name: "missing chart_type",
81-
input: RenderChartInput{Title: "CPU Usage", Environment: "dev", Query: "some_query"},
81+
input: ChartData{Title: "CPU Usage", Environment: "dev", Query: "some_query"},
8282
wantErr: true,
8383
errContains: "chart_type is required",
8484
},
8585
{
8686
name: "empty chart_type",
87-
input: RenderChartInput{ChartType: "", Title: "CPU Usage", Environment: "dev", Query: "some_query"},
87+
input: ChartData{ChartType: "", Title: "CPU Usage", Environment: "dev", Query: "some_query"},
8888
wantErr: true,
8989
errContains: "chart_type is required",
9090
},
9191
{
9292
name: "unsupported chart_type",
93-
input: RenderChartInput{ChartType: "bar", Title: "CPU Usage", Environment: "dev", Query: "some_query"},
93+
input: ChartData{ChartType: "bar", Title: "CPU Usage", Environment: "dev", Query: "some_query"},
9494
wantErr: true,
9595
errContains: "unsupported chart_type",
9696
},
9797
{
9898
name: "missing title",
99-
input: RenderChartInput{ChartType: "line", Environment: "dev", Query: "some_query"},
99+
input: ChartData{ChartType: "line", Environment: "dev", Query: "some_query"},
100100
wantErr: true,
101101
errContains: "title is required",
102102
},
103103
{
104104
name: "missing environment",
105-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Query: "some_query"},
105+
input: ChartData{ChartType: "line", Title: "CPU Usage", Query: "some_query"},
106106
wantErr: true,
107107
errContains: "environment is required",
108108
},
109109
{
110110
name: "missing query",
111-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev"},
111+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev"},
112112
wantErr: true,
113113
errContains: "query is required",
114114
},
115115
{
116116
name: "invalid interval",
117-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "2h"},
117+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "2h"},
118118
wantErr: true,
119119
errContains: "invalid interval",
120120
},
121121
{
122122
name: "valid interval 6h",
123-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "6h"},
123+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "6h"},
124124
validate: func(t *testing.T, c *ChartData) {
125125
if c.Interval != "6h" {
126126
t.Errorf("expected interval '6h', got %q", c.Interval)
@@ -129,7 +129,7 @@ func TestRenderChart(t *testing.T) {
129129
},
130130
{
131131
name: "valid interval 1d",
132-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "1d"},
132+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "1d"},
133133
validate: func(t *testing.T, c *ChartData) {
134134
if c.Interval != "1d" {
135135
t.Errorf("expected interval '1d', got %q", c.Interval)
@@ -138,7 +138,7 @@ func TestRenderChart(t *testing.T) {
138138
},
139139
{
140140
name: "valid interval 7d",
141-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "7d"},
141+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "7d"},
142142
validate: func(t *testing.T, c *ChartData) {
143143
if c.Interval != "7d" {
144144
t.Errorf("expected interval '7d', got %q", c.Interval)
@@ -147,7 +147,7 @@ func TestRenderChart(t *testing.T) {
147147
},
148148
{
149149
name: "valid interval 30d",
150-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "30d"},
150+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", Interval: "30d"},
151151
validate: func(t *testing.T, c *ChartData) {
152152
if c.Interval != "30d" {
153153
t.Errorf("expected interval '30d', got %q", c.Interval)
@@ -156,13 +156,13 @@ func TestRenderChart(t *testing.T) {
156156
},
157157
{
158158
name: "invalid y_format",
159-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "invalid"},
159+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "invalid"},
160160
wantErr: true,
161161
errContains: "invalid y_format",
162162
},
163163
{
164164
name: "valid y_format number",
165-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "number"},
165+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "number"},
166166
validate: func(t *testing.T, c *ChartData) {
167167
if c.YFormat != "number" {
168168
t.Errorf("expected y_format 'number', got %q", c.YFormat)
@@ -171,7 +171,7 @@ func TestRenderChart(t *testing.T) {
171171
},
172172
{
173173
name: "valid y_format percentage",
174-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "percentage"},
174+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "percentage"},
175175
validate: func(t *testing.T, c *ChartData) {
176176
if c.YFormat != "percentage" {
177177
t.Errorf("expected y_format 'percentage', got %q", c.YFormat)
@@ -180,7 +180,7 @@ func TestRenderChart(t *testing.T) {
180180
},
181181
{
182182
name: "valid y_format bytes",
183-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "bytes"},
183+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "bytes"},
184184
validate: func(t *testing.T, c *ChartData) {
185185
if c.YFormat != "bytes" {
186186
t.Errorf("expected y_format 'bytes', got %q", c.YFormat)
@@ -189,7 +189,7 @@ func TestRenderChart(t *testing.T) {
189189
},
190190
{
191191
name: "valid y_format cpu_cores",
192-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "cpu_cores"},
192+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "cpu_cores"},
193193
validate: func(t *testing.T, c *ChartData) {
194194
if c.YFormat != "cpu_cores" {
195195
t.Errorf("expected y_format 'cpu_cores', got %q", c.YFormat)
@@ -198,7 +198,7 @@ func TestRenderChart(t *testing.T) {
198198
},
199199
{
200200
name: "valid y_format duration",
201-
input: RenderChartInput{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "duration"},
201+
input: ChartData{ChartType: "line", Title: "CPU Usage", Environment: "dev", Query: "q", YFormat: "duration"},
202202
validate: func(t *testing.T, c *ChartData) {
203203
if c.YFormat != "duration" {
204204
t.Errorf("expected y_format 'duration', got %q", c.YFormat)

internal/agent/tools/graphql.go

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ type GraphQLClient interface {
3030

3131
// UserInfo contains information about an authenticated user.
3232
type UserInfo struct {
33-
Name string
34-
IsAdmin bool
33+
Name string `json:"name"`
34+
IsAdmin bool `json:"is_admin,omitempty"`
3535
}
3636

3737
// TeamInfo contains information about a team the user belongs to.
3838
type TeamInfo struct {
39-
Slug string
40-
Purpose string
41-
Role string
39+
Slug string `json:"slug"`
40+
Purpose string `json:"purpose,omitempty"`
41+
Role string `json:"role"`
4242
}
4343

4444
// GraphQLTools provides GraphQL execution functionality.
@@ -73,21 +73,9 @@ func (g *GraphQLTools) GetNaisContext(ctx context.Context) (GetNaisContextOutput
7373
return GetNaisContextOutput{}, fmt.Errorf("failed to get user teams: %w", err)
7474
}
7575

76-
// Build teams list
77-
teamsList := make([]NaisTeamInfo, 0, len(teams))
78-
for _, team := range teams {
79-
teamsList = append(teamsList, NaisTeamInfo{
80-
Slug: team.Slug,
81-
Purpose: team.Purpose,
82-
Role: team.Role,
83-
})
84-
}
85-
8676
return GetNaisContextOutput{
87-
User: NaisUserInfo{
88-
Name: user.Name,
89-
},
90-
Teams: teamsList,
77+
User: *user,
78+
Teams: teams,
9179
ConsoleBaseURL: g.consoleBaseURL,
9280
ConsoleURLPatterns: g.urlPatterns,
9381
}, nil

0 commit comments

Comments
 (0)