Skip to content
Prev Previous commit
Next Next commit
refactor: address feedback on optionalParamOK helper
  • Loading branch information
monotykamary committed Apr 5, 2025
commit d206a7c2e81979a565c2b975566324d67f203bf2
112 changes: 112 additions & 0 deletions pkg/github/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,115 @@ func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent {
assert.Equal(t, "text", textContent.Type)
return textContent
}

func TestOptionalParamOK(t *testing.T) {
tests := []struct {
name string
args map[string]interface{}
paramName string
expectedVal interface{}
expectedOk bool
expectError bool
errorMsg string
}{
{
name: "present and correct type (string)",
args: map[string]interface{}{"myParam": "hello"},
paramName: "myParam",
expectedVal: "hello",
expectedOk: true,
expectError: false,
},
{
name: "present and correct type (bool)",
args: map[string]interface{}{"myParam": true},
paramName: "myParam",
expectedVal: true,
expectedOk: true,
expectError: false,
},
{
name: "present and correct type (number)",
args: map[string]interface{}{"myParam": float64(123)},
paramName: "myParam",
expectedVal: float64(123),
expectedOk: true,
expectError: false,
},
{
name: "present but wrong type (string expected, got bool)",
args: map[string]interface{}{"myParam": true},
paramName: "myParam",
expectedVal: "", // Zero value for string
expectedOk: true, // ok is true because param exists
expectError: true,
errorMsg: "parameter myParam is not of type string, is bool",
},
{
name: "present but wrong type (bool expected, got string)",
args: map[string]interface{}{"myParam": "true"},
paramName: "myParam",
expectedVal: false, // Zero value for bool
expectedOk: true, // ok is true because param exists
expectError: true,
errorMsg: "parameter myParam is not of type bool, is string",
},
{
name: "parameter not present",
args: map[string]interface{}{"anotherParam": "value"},
paramName: "myParam",
expectedVal: "", // Zero value for string
expectedOk: false,
expectError: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := createMCPRequest(tc.args)

// Test with string type assertion
if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" {
val, ok, err := optionalParamOK[string](request, tc.paramName)
if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errorMsg)
assert.Equal(t, tc.expectedOk, ok) // Check ok even on error
assert.Equal(t, tc.expectedVal, val) // Check zero value on error
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedVal, val)
}
}

// Test with bool type assertion
if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" {
val, ok, err := optionalParamOK[bool](request, tc.paramName)
if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errorMsg)
assert.Equal(t, tc.expectedOk, ok) // Check ok even on error
assert.Equal(t, tc.expectedVal, val) // Check zero value on error
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedVal, val)
}
}

// Test with float64 type assertion (for number case)
if _, isFloat := tc.expectedVal.(float64); isFloat {
val, ok, err := optionalParamOK[float64](request, tc.paramName)
if tc.expectError {
// This case shouldn't happen for float64 in the defined tests
require.Fail(t, "Unexpected error case for float64")
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedOk, ok)
assert.Equal(t, tc.expectedVal, val)
}
}
})
}
}
10 changes: 5 additions & 5 deletions pkg/github/pullrequests.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,35 +118,35 @@ func updatePullRequest(client *github.Client, t translations.TranslationHelperFu
update := &github.PullRequest{}
updateNeeded := false

if title, ok, err := optionalParamOk[string](request, "title"); err != nil {
if title, ok, err := optionalParamOK[string](request, "title"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Title = github.Ptr(title)
updateNeeded = true
}

if body, ok, err := optionalParamOk[string](request, "body"); err != nil {
if body, ok, err := optionalParamOK[string](request, "body"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Body = github.Ptr(body)
updateNeeded = true
}

if state, ok, err := optionalParamOk[string](request, "state"); err != nil {
if state, ok, err := optionalParamOK[string](request, "state"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.State = github.Ptr(state)
updateNeeded = true
}

if base, ok, err := optionalParamOk[string](request, "base"); err != nil {
if base, ok, err := optionalParamOK[string](request, "base"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.Base = &github.PullRequestBranch{Ref: github.Ptr(base)}
updateNeeded = true
}

if maintainerCanModify, ok, err := optionalParamOk[bool](request, "maintainer_can_modify"); err != nil {
if maintainerCanModify, ok, err := optionalParamOK[bool](request, "maintainer_can_modify"); err != nil {
return mcp.NewToolResultError(err.Error()), nil
} else if ok {
update.MaintainerCanModify = github.Ptr(maintainerCanModify)
Expand Down
24 changes: 14 additions & 10 deletions pkg/github/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,28 @@ func getMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc
}
}

// optionalParamOk is a helper function that can be used to fetch a requested parameter from the request.
// optionalParamOK is a helper function that can be used to fetch a requested parameter from the request.
// It returns the value, a boolean indicating if the parameter was present, and an error if the type is wrong.
func optionalParamOk[T any](r mcp.CallToolRequest, p string) (T, bool, error) {
var zero T

func optionalParamOK[T any](r mcp.CallToolRequest, p string) (value T, ok bool, err error) {
// Check if the parameter is present in the request
val, ok := r.Params.Arguments[p]
if !ok {
return zero, false, nil // Not present, return zero value, false, no error
val, exists := r.Params.Arguments[p]
if !exists {
// Not present, return zero value, false, no error
return
}

// Check if the parameter is of the expected type
typedVal, ok := val.(T)
value, ok = val.(T)
if !ok {
return zero, true, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, val) // Present but wrong type
// Present but wrong type
err = fmt.Errorf("parameter %s is not of type %T, is %T", p, value, val)
ok = true // Set ok to true because the parameter *was* present, even if wrong type
return
}

return typedVal, true, nil // Present and correct type
// Present and correct type
ok = true
return
}

// isAcceptedError checks if the error is an accepted error.
Expand Down