Skip to content

Commit

Permalink
Prevent extra JSON from being stored in actions. (#1880)
Browse files Browse the repository at this point in the history
* Prevent extra JSON from being stored in actions.

* Make mapstructure safe in all uses.

* Simplify test
  • Loading branch information
crspeller authored Nov 2, 2023
1 parent f480a02 commit c4b026c
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 103 deletions.
106 changes: 79 additions & 27 deletions server/api/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
"net/url"

"github.com/mattermost/mattermost-plugin-playbooks/server/app"
"github.com/mitchellh/mapstructure"
"github.com/mattermost/mattermost-plugin-playbooks/server/safemapstructure"
"github.com/mattermost/mattermost-server/v6/model"
"github.com/pkg/errors"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -66,20 +67,8 @@ func (a *ActionsHandler) createChannelAction(c *Context, w http.ResponseWriter,
return
}

if channelAction.ActionType == app.ActionTypePromptRunPlaybook {
var payload app.PromptRunPlaybookFromKeywordsPayload
if err := mapstructure.Decode(channelAction.Payload, &payload); err != nil {
a.HandleErrorWithCode(w, c.logger, http.StatusBadRequest, "couldn't verify permissions for run playbook action", err)
return
}

if !a.PermissionsCheck(w, c.logger, a.permissions.PlaybookView(userID, payload.PlaybookID)) {
return
}
}

// Validate the action type and payload
if err := a.channelActionsService.Validate(channelAction); err != nil {
if err := a.ValidateChannelAction(c, w, &channelAction, userID); err != nil {
a.HandleErrorWithCode(w, c.logger, http.StatusBadRequest, "invalid action", err)
return
}
Expand All @@ -100,6 +89,81 @@ func (a *ActionsHandler) createChannelAction(c *Context, w http.ResponseWriter,
ReturnJSON(w, &result, http.StatusCreated)
}

func (a *ActionsHandler) ValidateChannelAction(c *Context, w http.ResponseWriter, action *app.GenericChannelAction, userID string) error {
// Validate the trigger type and action types
switch action.TriggerType {
case app.TriggerTypeNewMemberJoins:
switch action.ActionType {
case app.ActionTypeWelcomeMessage:
break
case app.ActionTypeCategorizeChannel:
break
default:
return fmt.Errorf("action type %q is not valid for trigger type %q", action.ActionType, action.TriggerType)
}
case app.TriggerTypeKeywordsPosted:
if action.ActionType != app.ActionTypePromptRunPlaybook {
return fmt.Errorf("action type %q is not valid for trigger type %q", action.ActionType, action.TriggerType)
}
default:
return fmt.Errorf("trigger type %q not recognized", action.TriggerType)
}

// Validate the payload depending on the action type
switch action.ActionType {
case app.ActionTypeWelcomeMessage:
var payload app.WelcomeMessagePayload
if err := safemapstructure.Decode(action.Payload, &payload); err != nil {
return fmt.Errorf("unable to decode payload from action")
}

// Force the payload to only include the recognized decoded fields.
action.Payload = payload
case app.ActionTypePromptRunPlaybook:
var payload app.PromptRunPlaybookFromKeywordsPayload
if err := safemapstructure.Decode(action.Payload, &payload); err != nil {
return fmt.Errorf("unable to decode payload from action")
}
if err := checkValidPromptRunPlaybookFromKeywordsPayload(payload); err != nil {
return err
}

if !a.PermissionsCheck(w, c.logger, a.permissions.PlaybookView(userID, payload.PlaybookID)) {
return fmt.Errorf("user does not have permissions to view playbook %s", payload.PlaybookID)
}

// Force the payload to only include the recognized decoded fields.
action.Payload = payload
case app.ActionTypeCategorizeChannel:
var payload app.CategorizeChannelPayload
if err := safemapstructure.Decode(action.Payload, &payload); err != nil {
return fmt.Errorf("unable to decode payload from action")
}

// Force the payload to only include the recognized decoded fields.
action.Payload = payload

default:
return fmt.Errorf("action type %q not recognized", action.ActionType)
}

return nil
}

func checkValidPromptRunPlaybookFromKeywordsPayload(payload app.PromptRunPlaybookFromKeywordsPayload) error {
for _, keyword := range payload.Keywords {
if keyword == "" {
return fmt.Errorf("payload field 'keywords' must contain only non-empty keywords")
}
}

if payload.PlaybookID != "" && !model.IsValidId(payload.PlaybookID) {
return fmt.Errorf("payload field 'playbook_id' must be a valid ID")
}

return nil
}

func isValidTrigger(trigger string) bool {
if trigger == "" {
return true
Expand Down Expand Up @@ -208,20 +272,8 @@ func (a *ActionsHandler) updateChannelAction(c *Context, w http.ResponseWriter,
return
}

if newChannelAction.ActionType == app.ActionTypePromptRunPlaybook {
var payload app.PromptRunPlaybookFromKeywordsPayload
if err := mapstructure.Decode(newChannelAction.Payload, &payload); err != nil {
a.HandleErrorWithCode(w, c.logger, http.StatusBadRequest, "couldn't verify permissions for run playbook action", err)
return
}

if !a.PermissionsCheck(w, c.logger, a.permissions.PlaybookView(userID, payload.PlaybookID)) {
return
}
}

// Validate the new action type and payload
if err := a.channelActionsService.Validate(newChannelAction); err != nil {
if err := a.ValidateChannelAction(c, w, &newChannelAction, userID); err != nil {
a.HandleErrorWithCode(w, c.logger, http.StatusBadRequest, "invalid action", err)
return
}
Expand Down
12 changes: 6 additions & 6 deletions server/api_actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"testing"

"github.com/mattermost/mattermost-plugin-playbooks/client"
"github.com/mattermost/mattermost-plugin-playbooks/server/safemapstructure"
"github.com/mattermost/mattermost-server/v6/model"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -269,19 +269,19 @@ func TestActionList(t *testing.T) {
switch action.ID {
case welcomeActionID:
var payload client.WelcomeMessagePayload
err = mapstructure.Decode(action.Payload, &payload)
err = safemapstructure.Decode(action.Payload, &payload)
assert.NoError(t, err)
assert.Equal(t, "msg", payload.Message)

case categorizeActionID:
var payload client.CategorizeChannelPayload
err = mapstructure.Decode(action.Payload, &payload)
err = safemapstructure.Decode(action.Payload, &payload)
assert.NoError(t, err)
assert.Equal(t, "category", payload.CategoryName)

case promptActionID:
var payload client.PromptRunPlaybookFromKeywordsPayload
err = mapstructure.Decode(action.Payload, &payload)
err = safemapstructure.Decode(action.Payload, &payload)
assert.NoError(t, err)
assert.EqualValues(t, []string{"one", "two"}, payload.Keywords)
assert.Equal(t, playbookID, payload.PlaybookID)
Expand Down Expand Up @@ -386,7 +386,7 @@ func TestActionUpdate(t *testing.T) {
assert.Len(t, actions, 1)
fetchedAction := actions[0]
var fetchedPayload client.PromptRunPlaybookFromKeywordsPayload
err = mapstructure.Decode(fetchedAction.Payload, &fetchedPayload)
err = safemapstructure.Decode(fetchedAction.Payload, &fetchedPayload)
assert.NoError(t, err)

// Verify that the payload of the created action has one keyword
Expand All @@ -411,7 +411,7 @@ func TestActionUpdate(t *testing.T) {
assert.Len(t, updatedActions, 1)
updatedAction := updatedActions[0]
var updatedPayload client.PromptRunPlaybookFromKeywordsPayload
err = mapstructure.Decode(updatedAction.Payload, &updatedPayload)
err = safemapstructure.Decode(updatedAction.Payload, &updatedPayload)
assert.NoError(t, err)

// Verify that the payload of the updated action has no keywords
Expand Down
4 changes: 0 additions & 4 deletions server/app/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ type ChannelActionService interface {
// filtered with the options if different from its zero value
GetChannelActions(channelID string, options GetChannelActionOptions) ([]GenericChannelAction, error)

// Validate checks that the action type, trigger type and
// payload are all valid and consistent with each other
Validate(action GenericChannelAction) error

// Update updates an existing action identified by action.ID
Update(action GenericChannelAction, userID string) error

Expand Down
70 changes: 4 additions & 66 deletions server/app/actions_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
pluginapi "github.com/mattermost/mattermost-plugin-api"
"github.com/mattermost/mattermost-plugin-playbooks/server/bot"
"github.com/mattermost/mattermost-plugin-playbooks/server/config"
"github.com/mattermost/mattermost-plugin-playbooks/server/safemapstructure"
"github.com/mattermost/mattermost-server/v6/model"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -115,68 +115,6 @@ func (a *channelActionServiceImpl) GetChannelActions(channelID string, options G
return a.store.GetChannelActions(channelID, options)
}

func (a *channelActionServiceImpl) Validate(action GenericChannelAction) error {
// Validate the trigger type and action types
switch action.TriggerType {
case TriggerTypeNewMemberJoins:
switch action.ActionType {
case ActionTypeWelcomeMessage:
break
case ActionTypeCategorizeChannel:
break
default:
return fmt.Errorf("action type %q is not valid for trigger type %q", action.ActionType, action.TriggerType)
}
case TriggerTypeKeywordsPosted:
if action.ActionType != ActionTypePromptRunPlaybook {
return fmt.Errorf("action type %q is not valid for trigger type %q", action.ActionType, action.TriggerType)
}
default:
return fmt.Errorf("trigger type %q not recognized", action.TriggerType)
}

// Validate the payload depending on the action type
switch action.ActionType {
case ActionTypeWelcomeMessage:
var payload WelcomeMessagePayload
if err := mapstructure.Decode(action.Payload, &payload); err != nil {
return fmt.Errorf("unable to decode payload from action")
}
case ActionTypePromptRunPlaybook:
var payload PromptRunPlaybookFromKeywordsPayload
if err := mapstructure.Decode(action.Payload, &payload); err != nil {
return fmt.Errorf("unable to decode payload from action")
}
if err := checkValidPromptRunPlaybookFromKeywordsPayload(payload); err != nil {
return err
}
case ActionTypeCategorizeChannel:
var payload CategorizeChannelPayload
if err := mapstructure.Decode(action.Payload, &payload); err != nil {
return fmt.Errorf("unable to decode payload from action")
}

default:
return fmt.Errorf("action type %q not recognized", action.ActionType)
}

return nil
}

func checkValidPromptRunPlaybookFromKeywordsPayload(payload PromptRunPlaybookFromKeywordsPayload) error {
for _, keyword := range payload.Keywords {
if keyword == "" {
return fmt.Errorf("payload field 'keywords' must contain only non-empty keywords")
}
}

if payload.PlaybookID != "" && !model.IsValidId(payload.PlaybookID) {
return fmt.Errorf("payload field 'playbook_id' must be a valid ID")
}

return nil
}

func (a *channelActionServiceImpl) Update(action GenericChannelAction, userID string) error {
oldAction, err := a.Get(action.ID)
if err != nil {
Expand Down Expand Up @@ -244,7 +182,7 @@ func (a *channelActionServiceImpl) UserHasJoinedChannel(userID, channelID, actor
}

var payload CategorizeChannelPayload
if err = mapstructure.Decode(action.Payload, &payload); err != nil {
if err = safemapstructure.Decode(action.Payload, &payload); err != nil {
logrus.WithError(err).Error("unable to decode payload of CategorizeChannelPayload")
return
}
Expand Down Expand Up @@ -366,7 +304,7 @@ func (a *channelActionServiceImpl) CheckAndSendMessageOnJoin(userID, channelID s
for _, action := range actions {
if action.ActionType == ActionTypeWelcomeMessage {
var payload WelcomeMessagePayload
if err := mapstructure.Decode(action.Payload, &payload); err != nil {
if err := safemapstructure.Decode(action.Payload, &payload); err != nil {
logrus.WithError(err).WithField("action_type", action.ActionType).Error("payload of action is not valid")
}

Expand Down Expand Up @@ -412,7 +350,7 @@ func (a *channelActionServiceImpl) MessageHasBeenPosted(post *model.Post) {
}

var payload PromptRunPlaybookFromKeywordsPayload
if err := mapstructure.Decode(action.Payload, &payload); err != nil {
if err := safemapstructure.Decode(action.Payload, &payload); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"payload": payload,
"actionType": action.ActionType,
Expand Down
20 changes: 20 additions & 0 deletions server/safemapstructure/safemapstructure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package safemapstructure

import (
"github.com/mitchellh/mapstructure"
)

func Decode(input interface{}, output interface{}) error {
config := &mapstructure.DecoderConfig{
Metadata: nil,
Result: output,
MatchName: func(a string, b string) bool { return a == b }, // Only match exactly
}

decoder, err := mapstructure.NewDecoder(config)
if err != nil {
return err
}

return decoder.Decode(input)
}
52 changes: 52 additions & 0 deletions server/safemapstructure/safemapstructure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package safemapstructure

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestDecodeNoMatchesCase(t *testing.T) {
type Test struct {
Test string `mapstructure:"test"`
OtherTest string `mapstructure:"other_test"`
}

input := map[string]interface{}{
"tEst": "incorrect",
"Test": "incorrect",
"Other_test": "incorrect",
"other_tEst": "incorrect",
}

var output Test
err := Decode(input, &output)
require.Nil(t, err)

require.Equal(t, "", output.Test)
require.Equal(t, "", output.OtherTest)
}

func TestDecodeHasMatch(t *testing.T) {
type Test struct {
Test string `mapstructure:"test"`
OtherTest string `mapstructure:"other_test"`
}

input := map[string]interface{}{
"tEst": "incorrect",
"test": "correct1",
"other_test": "correct2",
"other_tEst": "incorrect",
}

// Do it a bunch of times since map order is randomized
for i := 0; i < 100; i++ {
var output Test
err := Decode(input, &output)
require.Nil(t, err)

require.Equal(t, "correct1", output.Test)
require.Equal(t, "correct2", output.OtherTest)
}
}

0 comments on commit c4b026c

Please sign in to comment.