diff --git a/.golangci.yml b/.golangci.yml index e6c4cf24b..4e07ba3cc 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,7 +11,7 @@ linters-settings: - github.com/CheckmarxDev/containers-resolver/pkg/containerResolver - github.com/Checkmarx/gen-ai-prompts/prompts/sast_result_remediation - github.com/spf13/viper - - github.com/checkmarxDev/gpt-wrapper + - github.com/Checkmarx/gen-ai-wrapper - github.com/spf13/cobra - github.com/pkg/errors - github.com/google diff --git a/go.mod b/go.mod index 75629f156..34400c7d7 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,10 @@ go 1.22.7 require ( github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63 + github.com/Checkmarx/gen-ai-wrapper v1.0.2 github.com/CheckmarxDev/containers-resolver v1.0.14 github.com/MakeNowJust/heredoc v1.0.0 github.com/bouk/monkey v1.0.0 - github.com/checkmarxDev/gpt-wrapper v0.0.0-20230721160222-85da2fd1cc4c github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gomarkdown/markdown v0.0.0-20241102151059-6bc1ffdc6e8c github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 @@ -17,7 +17,7 @@ require ( github.com/mssola/user_agent v0.6.0 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.8.1 - github.com/spf13/viper v1.18.2 + github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/xeipuuv/gojsonschema v1.2.0 @@ -195,7 +195,7 @@ require ( github.com/opencontainers/selinux v1.11.0 // indirect github.com/pborman/indent v1.2.1 // indirect github.com/pelletier/go-toml v1.9.5 // indirect - github.com/pelletier/go-toml/v2 v2.2.1 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect diff --git a/go.sum b/go.sum index 767da068c..b22124381 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,8 @@ github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63 h1:SCuTcE+CFvgjbIxUNL8rsdB2sAhfuNx85HvxImKta3g= github.com/Checkmarx/gen-ai-prompts v0.0.0-20240807143411-708ceec12b63/go.mod h1:MI6lfLerXU+5eTV/EPTDavgnV3owz3GPT4g/msZBWPo= +github.com/Checkmarx/gen-ai-wrapper v1.0.2 h1:T6X40+4hYnwfDsvkjWs9VIcE6s1O+8DUu0+sDdCY3GI= +github.com/Checkmarx/gen-ai-wrapper v1.0.2/go.mod h1:xwRLefezwNNnRGu1EjGS6wNiR9FVV/eP9D+oXwLViVM= github.com/CheckmarxDev/containers-resolver v1.0.14 h1:jOc1POn6XVETZtXMyer6Rp6K/MLfkuGCKgMk6ZLEf/w= github.com/CheckmarxDev/containers-resolver v1.0.14/go.mod h1:ne5YunT/hKQ7fnZejFVGodlfOUReNE7hZW2KLbpQi48= github.com/CycloneDX/cyclonedx-go v0.9.0 h1:inaif7qD8bivyxp7XLgxUYtOXWtDez7+j72qKTMQTb8= @@ -195,8 +197,6 @@ github.com/charmbracelet/x/term v0.1.1 h1:3cosVAiPOig+EV4X9U+3LDgtwwAoEzJjNdwbXD github.com/charmbracelet/x/term v0.1.1/go.mod h1:wB1fHt5ECsu3mXYusyzcngVWWlu1KKUmmLhfgr/Flxw= github.com/charmbracelet/x/windows v0.1.0 h1:gTaxdvzDM5oMa/I2ZNF7wN78X/atWemG9Wph7Ika2k4= github.com/charmbracelet/x/windows v0.1.0/go.mod h1:GLEO/l+lizvFDBPLIOk+49gdX49L9YWMB5t+DZd0jkQ= -github.com/checkmarxDev/gpt-wrapper v0.0.0-20230721160222-85da2fd1cc4c h1:oKI4C1dXYpi0B8pltDDzp1ZRiyeILv5enbp9h4ASQ3s= -github.com/checkmarxDev/gpt-wrapper v0.0.0-20230721160222-85da2fd1cc4c/go.mod h1:l+0rISRGaps2HWkpvKbYPE1nsNx28vBj6bKorEm1M5o= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -763,8 +763,8 @@ github.com/pborman/indent v1.2.1/go.mod h1:FitS+t35kIYtB5xWTZAPhnmrxcciEEOdbyrrp github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= -github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= -github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= @@ -885,8 +885,8 @@ github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.10.0/go.mod h1:SoyBPwAtKDzypXNDFKN5kzH7ppppbGZtls1UpIy5AsM= -github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= -github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= +github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= diff --git a/internal/commands/chat-kics.go b/internal/commands/chat-kics.go index 25e2461af..d2453a468 100644 --- a/internal/commands/chat-kics.go +++ b/internal/commands/chat-kics.go @@ -3,15 +3,15 @@ package commands import ( "fmt" "os" + "strings" + "github.com/Checkmarx/gen-ai-wrapper/pkg/message" + "github.com/Checkmarx/gen-ai-wrapper/pkg/role" + gptWrapper "github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper" "github.com/checkmarx/ast-cli/internal/commands/util/printer" "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" - "github.com/checkmarxDev/gpt-wrapper/pkg/connector" - "github.com/checkmarxDev/gpt-wrapper/pkg/message" - "github.com/checkmarxDev/gpt-wrapper/pkg/role" - "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper" "github.com/google/uuid" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -41,18 +41,24 @@ const dropLen = 4 const FileErrorFormat = "It seems that %s is not available for AI Guided Remediation. Please ensure that you have opened the correct workspace or the relevant file." +// chatModel model to use when calling the CheckmarxAI +const checkmarxAiChatModel = "gpt-4" +const tenantIDClaimKey = "tenant_id" +const guidedRemediationFeatureNameKics = "cli-guided-remediation-kics" +const guidedRemediationFeatureNameSast = "cli-guided-remediation-sast" + type OutputModel struct { ConversationID string `json:"conversationId"` Response []string `json:"response"` } -func ChatKicsSubCommand(chatWrapper wrappers.ChatWrapper) *cobra.Command { +func ChatKicsSubCommand(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.TenantConfigurationWrapper) *cobra.Command { chatKicsCmd := &cobra.Command{ Use: "kics", Short: "Chat about KICS result with OpenAI models", Long: "Chat about KICS result with OpenAI models", Hidden: true, - RunE: runChatKics(chatWrapper), + RunE: runChatKics(chatWrapper, tenantWrapper), } chatKicsCmd.Flags().String(params.ChatAPIKey, "", "OpenAI API key") @@ -65,7 +71,6 @@ func ChatKicsSubCommand(chatWrapper wrappers.ChatWrapper) *cobra.Command { chatKicsCmd.Flags().String(params.ChatKicsResultVulnerability, "", "IaC result vulnerability name") _ = chatKicsCmd.MarkFlagRequired(params.ChatUserInput) - _ = chatKicsCmd.MarkFlagRequired(params.ChatAPIKey) _ = chatKicsCmd.MarkFlagRequired(params.ChatKicsResultFile) _ = chatKicsCmd.MarkFlagRequired(params.ChatKicsResultLine) _ = chatKicsCmd.MarkFlagRequired(params.ChatKicsResultSeverity) @@ -74,27 +79,29 @@ func ChatKicsSubCommand(chatWrapper wrappers.ChatWrapper) *cobra.Command { return chatKicsCmd } -func runChatKics(chatKicsWrapper wrappers.ChatWrapper) func(cmd *cobra.Command, args []string) error { +func runChatKics( + chatKicsWrapper wrappers.ChatWrapper, tenantWrapper wrappers.TenantConfigurationWrapper, +) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { - chatAPIKey, _ := cmd.Flags().GetString(params.ChatAPIKey) chatConversationID, _ := cmd.Flags().GetString(params.ChatConversationID) - chatModel, _ := cmd.Flags().GetString(params.ChatModel) chatResultFile, _ := cmd.Flags().GetString(params.ChatKicsResultFile) chatResultLine, _ := cmd.Flags().GetString(params.ChatKicsResultLine) chatResultSeverity, _ := cmd.Flags().GetString(params.ChatKicsResultSeverity) chatResultVulnerability, _ := cmd.Flags().GetString(params.ChatKicsResultVulnerability) userInput, _ := cmd.Flags().GetString(params.ChatUserInput) - statefulWrapper := wrapper.NewStatefulWrapper(connector.NewFileSystemConnector(""), chatAPIKey, chatModel, dropLen, 0) - - if chatConversationID == "" { - chatConversationID = statefulWrapper.GenerateId().String() + azureAiEnabled, checkmarxAiEnabled, tenantConfigurationResponses, err := getEngineSelection(cmd, tenantWrapper) + if err != nil { + return err } - id, err := uuid.Parse(chatConversationID) + statefulWrapper, customerToken := CreateStatefulWrapper(cmd, azureAiEnabled, checkmarxAiEnabled, tenantConfigurationResponses) + + tenantID := getTenantID(customerToken) + + id, err := getKicsConversationID(cmd, chatConversationID, statefulWrapper) if err != nil { - logger.PrintIfVerbose(err.Error()) - return outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID)) + return err } chatResultCode, err := os.ReadFile(chatResultFile) @@ -103,14 +110,13 @@ func runChatKics(chatKicsWrapper wrappers.ChatWrapper) func(cmd *cobra.Command, return outputError(cmd, id, errors.Errorf(FileErrorFormat, chatResultFile)) } - newMessages := buildMessages(chatResultCode, chatResultVulnerability, chatResultLine, chatResultSeverity, userInput) - response, err := chatKicsWrapper.Call(statefulWrapper, id, newMessages) + newMessages := buildKicsMessages(chatResultCode, chatResultVulnerability, chatResultLine, chatResultSeverity, userInput) + + responseContent, err := sendRequest(statefulWrapper, azureAiEnabled, checkmarxAiEnabled, tenantID, chatKicsWrapper, id, newMessages, customerToken, guidedRemediationFeatureNameKics) if err != nil { return outputError(cmd, id, err) } - responseContent := getMessageContents(response) - return printer.Print(cmd.OutOrStdout(), &OutputModel{ ConversationID: id.String(), Response: responseContent, @@ -118,7 +124,82 @@ func runChatKics(chatKicsWrapper wrappers.ChatWrapper) func(cmd *cobra.Command, } } -func buildMessages(chatResultCode []byte, +func getKicsConversationID(cmd *cobra.Command, chatConversationID string, statefulWrapper gptWrapper.StatefulWrapper) (uuid.UUID, error) { + if chatConversationID == "" { + chatConversationID = statefulWrapper.GenerateId().String() + } + + id, err := uuid.Parse(chatConversationID) + if err != nil { + logger.PrintIfVerbose(err.Error()) + return uuid.UUID{}, outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID)) + } + return id, err +} + +func getTenantID(customerToken string) string { + tenantID, _ := wrappers.ExtractFromTokenClaims(customerToken, tenantIDClaimKey) + // remove from tenant id all the string before :: + if strings.Contains(tenantID, "::") { + tenantID = tenantID[strings.LastIndex(tenantID, "::")+2:] + } + return tenantID +} + +func sendRequest(statefulWrapper gptWrapper.StatefulWrapper, azureAiEnabled bool, checkmarxAiEnabled bool, tenantID string, chatKicsWrapper wrappers.ChatWrapper, + id uuid.UUID, newMessages []message.Message, customerToken string, featureName string) (responseContent []string, err error) { + requestID := statefulWrapper.GenerateId().String() + + var response []message.Message + + if azureAiEnabled || checkmarxAiEnabled { + metadata := message.MetaData{ + TenantID: tenantID, + RequestID: requestID, + UserAgent: params.DefaultAgent, + Feature: featureName, + } + if azureAiEnabled { + logger.Printf("Sending message to Azure AI model for " + featureName + " guided remediation. RequestID: " + requestID) + } else { + logger.Printf("Sending message to Checkmarx AI model for " + featureName + " guided remediation. RequestID: " + requestID) + } + response, err = chatKicsWrapper.SecureCall(statefulWrapper, id, newMessages, &metadata, customerToken) + if err != nil { + return nil, err + } + } else { // if chatgpt is enabled or no engine is enabled + logger.Printf("Sending message to ChatGPT model for " + featureName + " guided remediation. RequestID: " + requestID) + response, err = chatKicsWrapper.Call(statefulWrapper, id, newMessages) + if err != nil { + return nil, err + } + } + + responseContent = getMessageContents(response) + return responseContent, nil +} + +func getEngineSelection(cmd *cobra.Command, tenantWrapper wrappers.TenantConfigurationWrapper) (azureAiEnabled, checkmarxAiEnabled bool, + tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse, err error) { + if !isCxOneAPIKeyAvailable() { + azureAiEnabled = false + checkmarxAiEnabled = false + logger.Printf("CxOne API key is not available, ChatGPT model will be used for guided remediation.") + } else { + var err error + tenantConfigurationResponses, err = GetTenantConfigurationResponses(tenantWrapper) + if err != nil { + return false, false, nil, outputError(cmd, uuid.Nil, err) + } + + azureAiEnabled = isAzureAiGuidedRemediationEnabled(tenantConfigurationResponses) + checkmarxAiEnabled = isCheckmarxAiGuidedRemediationEnabled(tenantConfigurationResponses) + } + return azureAiEnabled, checkmarxAiEnabled, tenantConfigurationResponses, nil +} + +func buildKicsMessages(chatResultCode []byte, chatResultVulnerability, chatResultLine, chatResultSeverity, userInput string) []message.Message { var newMessages []message.Message newMessages = append(newMessages, message.Message{ diff --git a/internal/commands/chat-kics_test.go b/internal/commands/chat-kics_test.go index e743a4f72..c110cee33 100644 --- a/internal/commands/chat-kics_test.go +++ b/internal/commands/chat-kics_test.go @@ -6,7 +6,11 @@ import ( "strings" "testing" + "github.com/checkmarx/ast-cli/internal/params" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/mock" "github.com/google/uuid" + "github.com/spf13/viper" "gotest.tools/assert" ) @@ -49,7 +53,6 @@ func TestChatKicsInvalidFile(t *testing.T) { func TestChatKicsCorrectResponse(t *testing.T) { buffer, err := executeRedirectedTestCommand("chat", "kics", "--conversation-id", uuid.New().String(), - "--chat-apikey", "apiKey", "--user-input", "userInput", "--result-file", "./data/Dockerfile", "--result-line", "0", @@ -61,3 +64,76 @@ func TestChatKicsCorrectResponse(t *testing.T) { s := strings.ToLower(string(output)) assert.Assert(t, strings.Contains(s, "mock"), s) } + +func TestChatKicsAzureAICorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.azureAiGuidedRemediation", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "azureai", + }, + } + origAPIKey := viper.GetString(params.AstAPIKey) + viper.Set(params.AstAPIKey, "SomeKey") + + buffer, err := executeRedirectedTestCommand("chat", "kics", + "--conversation-id", uuid.New().String(), + "--user-input", "userInput", + "--result-file", "./data/Dockerfile", + "--result-line", "0", + "--result-severity", "LOW", + "--result-vulnerability", "Vulnerability") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{} + viper.Set(params.AstAPIKey, origAPIKey) + + assert.Assert(t, strings.Contains(s, "mock message from securecall"), s) +} + +func TestChatKicsCheckmarxAICorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.checkmarxAiGuidedRemediation", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "checkmarxai", + }, + } + origAPIKey := viper.GetString(params.AstAPIKey) + viper.Set(params.AstAPIKey, "SomeKey") + + buffer, err := executeRedirectedTestCommand("chat", "kics", + "--conversation-id", uuid.New().String(), + "--chat-apikey", "apiKey", + "--user-input", "userInput", + "--result-file", "./data/Dockerfile", + "--result-line", "0", + "--result-severity", "LOW", + "--result-vulnerability", "Vulnerability") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{} + viper.Set(params.AstAPIKey, origAPIKey) + + assert.Assert(t, strings.Contains(s, "mock message from securecall"), s) +} diff --git a/internal/commands/chat-sast.go b/internal/commands/chat-sast.go index d39a81d7c..7e9345672 100644 --- a/internal/commands/chat-sast.go +++ b/internal/commands/chat-sast.go @@ -3,19 +3,21 @@ package commands import ( "fmt" "strconv" + "strings" sastchat "github.com/Checkmarx/gen-ai-prompts/prompts/sast_result_remediation" + "github.com/Checkmarx/gen-ai-wrapper/pkg/connector" + "github.com/Checkmarx/gen-ai-wrapper/pkg/message" + "github.com/Checkmarx/gen-ai-wrapper/pkg/role" + "github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper" "github.com/checkmarx/ast-cli/internal/commands/util/printer" "github.com/checkmarx/ast-cli/internal/logger" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" - "github.com/checkmarxDev/gpt-wrapper/pkg/connector" - "github.com/checkmarxDev/gpt-wrapper/pkg/message" - "github.com/checkmarxDev/gpt-wrapper/pkg/role" - "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper" "github.com/google/uuid" "github.com/pkg/errors" "github.com/spf13/cobra" + "github.com/spf13/viper" ) const UserInputRequiredErrorFormat = "%s is required when %s is provided" @@ -38,7 +40,6 @@ func ChatSastSubCommand(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers chatSastCmd.Flags().String(params.ChatSastSourceDir, "", "Source code root directory relevant for the results file") chatSastCmd.Flags().String(params.ChatSastResultID, "", "ID of the result to remediate") - _ = chatSastCmd.MarkFlagRequired(params.ChatAPIKey) _ = chatSastCmd.MarkFlagRequired(params.ChatSastScanResultsFile) _ = chatSastCmd.MarkFlagRequired(params.ChatSastSourceDir) _ = chatSastCmd.MarkFlagRequired(params.ChatSastResultID) @@ -46,68 +47,43 @@ func ChatSastSubCommand(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers return chatSastCmd } -func runChatSast(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.TenantConfigurationWrapper) func(cmd *cobra.Command, args []string) error { +func runChatSast( + chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.TenantConfigurationWrapper, +) func(cmd *cobra.Command, args []string) error { return func(cmd *cobra.Command, args []string) error { - if !isAiGuidedRemediationEnabled(tenantWrapper) { + tenantConfigurationResponses, err := GetTenantConfigurationResponses(tenantWrapper) + if err != nil { + return outputError(cmd, uuid.Nil, err) + } + if !isAiGuidedRemediationEnabled(tenantConfigurationResponses) { return outputError(cmd, uuid.Nil, errors.Errorf(AiGuidedRemediationDisabledError)) } - chatAPIKey, _ := cmd.Flags().GetString(params.ChatAPIKey) chatConversationID, _ := cmd.Flags().GetString(params.ChatConversationID) - chatModel, _ := cmd.Flags().GetString(params.ChatModel) scanResultsFile, _ := cmd.Flags().GetString(params.ChatSastScanResultsFile) sourceDir, _ := cmd.Flags().GetString(params.ChatSastSourceDir) sastResultID, _ := cmd.Flags().GetString(params.ChatSastResultID) + azureAiEnabled := isAzureAiGuidedRemediationEnabled(tenantConfigurationResponses) + checkmarxAiEnabled := isCheckmarxAiGuidedRemediationEnabled(tenantConfigurationResponses) - statefulWrapper := wrapper.NewStatefulWrapper(connector.NewFileSystemConnector(""), chatAPIKey, chatModel, dropLen, 0) - - newConversation := false - var userInput string - if chatConversationID == "" { - newConversation = true - chatConversationID = statefulWrapper.GenerateId().String() - } else { - userInput, _ = cmd.Flags().GetString(params.ChatUserInput) - if userInput == "" { - msg := fmt.Sprintf(UserInputRequiredErrorFormat, params.ChatUserInput, params.ChatConversationID) - logger.PrintIfVerbose(msg) - return outputError(cmd, uuid.Nil, errors.Errorf(msg)) - } - } + statefulWrapper, customerToken := CreateStatefulWrapper(cmd, azureAiEnabled, checkmarxAiEnabled, tenantConfigurationResponses) + + tenantID := getTenantID(customerToken) - id, err := uuid.Parse(chatConversationID) + newConversation, userInput, id, err := getSastConversationDetails(cmd, chatConversationID, statefulWrapper) if err != nil { - logger.PrintIfVerbose(err.Error()) - return outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID)) + return err } - var newMessages []message.Message - if newConversation { - systemPrompt, userPrompt, e := sastchat.BuildPrompt(scanResultsFile, sastResultID, sourceDir) - if e != nil { - logger.PrintIfVerbose(e.Error()) - return outputError(cmd, id, e) - } - newMessages = append(newMessages, message.Message{ - Role: role.System, - Content: systemPrompt, - }, message.Message{ - Role: role.User, - Content: userPrompt, - }) - } else { - newMessages = append(newMessages, message.Message{ - Role: role.User, - Content: userInput, - }) + newMessages, err := buildSastMessages(cmd, newConversation, scanResultsFile, sastResultID, sourceDir, id, userInput) + if err != nil { + return err } - response, err := chatWrapper.Call(statefulWrapper, id, newMessages) + responseContent, err := sendRequest(statefulWrapper, azureAiEnabled, checkmarxAiEnabled, tenantID, chatWrapper, id, newMessages, customerToken, guidedRemediationFeatureNameSast) if err != nil { return outputError(cmd, id, err) } - responseContent := getMessageContents(response) - responseContent = sastchat.AddDescriptionForIdentifier(responseContent) return printer.Print(cmd.OutOrStdout(), &OutputModel{ @@ -117,23 +93,135 @@ func runChatSast(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.Tenant } } -func isAiGuidedRemediationEnabled(tenantWrapper wrappers.TenantConfigurationWrapper) bool { +func getSastConversationDetails(cmd *cobra.Command, chatConversationID string, statefulWrapper wrapper.StatefulWrapper) ( + isNewConversation bool, userInput string, conversationID uuid.UUID, err error) { + newConversation := false + if chatConversationID == "" { + newConversation = true + chatConversationID = statefulWrapper.GenerateId().String() + } else { + userInput, _ = cmd.Flags().GetString(params.ChatUserInput) + if userInput == "" { + msg := fmt.Sprintf(UserInputRequiredErrorFormat, params.ChatUserInput, params.ChatConversationID) + logger.PrintIfVerbose(msg) + return false, "", uuid.UUID{}, outputError(cmd, uuid.Nil, errors.Errorf(msg)) + } + } + + id, err := uuid.Parse(chatConversationID) + if err != nil { + logger.PrintIfVerbose(err.Error()) + return false, "", uuid.UUID{}, outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID)) + } + return newConversation, userInput, id, nil +} + +func buildSastMessages(cmd *cobra.Command, newConversation bool, scanResultsFile, sastResultID, sourceDir string, id uuid.UUID, userInput string) ([]message.Message, error) { + var newMessages []message.Message + if newConversation { + systemPrompt, userPrompt, e := sastchat.BuildPrompt(scanResultsFile, sastResultID, sourceDir) + if e != nil { + logger.PrintIfVerbose(e.Error()) + return nil, outputError(cmd, id, e) + } + newMessages = append(newMessages, message.Message{ + Role: role.System, + Content: systemPrompt, + }, message.Message{ + Role: role.User, + Content: userPrompt, + }) + } else { + newMessages = append(newMessages, message.Message{ + Role: role.User, + Content: userInput, + }) + } + return newMessages, nil +} + +func CreateStatefulWrapper(cmd *cobra.Command, azureAiEnabled, checkmarxAiEnabled bool, tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse) ( + statefulWrapper wrapper.StatefulWrapper, customerToken string) { + conn := connector.NewFileSystemConnector("") + + customerToken, _ = wrappers.GetAccessToken() + + if azureAiEnabled { + aiProxyAzureAIRoute := viper.GetString(params.AiProxyAzureAiRouteKey) + aiProxyEndPoint, _ := wrappers.GetURL(aiProxyAzureAIRoute, customerToken) + statefulWrapper, _ = wrapper.NewStatefulWrapperNew(conn, aiProxyEndPoint, customerToken, "", dropLen, 0) + } else if checkmarxAiEnabled { + aiProxyCheckmarxAIRoute := viper.GetString(params.AiProxyCheckmarxAiRouteKey) + aiProxyEndPoint, _ := wrappers.GetURL(aiProxyCheckmarxAIRoute, customerToken) + model := checkmarxAiChatModel + statefulWrapper, _ = wrapper.NewStatefulWrapperNew(conn, aiProxyEndPoint, customerToken, model, dropLen, 0) + } else { + chatModel, _ := cmd.Flags().GetString(params.ChatModel) + chatAPIKey, _ := cmd.Flags().GetString(params.ChatAPIKey) + statefulWrapper = wrapper.NewStatefulWrapper(conn, chatAPIKey, chatModel, dropLen, 0) + } + return statefulWrapper, customerToken +} + +func GetTenantConfigurationResponses(tenantWrapper wrappers.TenantConfigurationWrapper) (*[]*wrappers.TenantConfigurationResponse, error) { tenantConfigurationResponse, errorModel, err := tenantWrapper.GetTenantConfiguration() if err != nil { - return false + return nil, err } if errorModel != nil { - return false + return nil, errors.New(errorModel.Message) } - if tenantConfigurationResponse != nil { - for _, resp := range *tenantConfigurationResponse { - if resp.Key == AiGuidedRemediationEnabled { - isEnabled, _ := strconv.ParseBool(resp.Value) - return isEnabled + return tenantConfigurationResponse, nil +} + +func GetTenantConfiguration(tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse, configKey string) (string, error) { + if tenantConfigurationResponses != nil { + for _, resp := range *tenantConfigurationResponses { + if resp.Key == configKey { + return resp.Value, nil } } } - return false + return "", errors.New(configKey + " not found") +} + +func GetTenantConfigurationBool(tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse, configKey string) (bool, error) { + value, err := GetTenantConfiguration(tenantConfigurationResponses, configKey) + if err != nil { + return false, err + } + return strconv.ParseBool(value) +} + +func isAiGuidedRemediationEnabled(tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse) bool { + isEnabled, err := GetTenantConfigurationBool(tenantConfigurationResponses, AiGuidedRemediationEnabled) + if err != nil { + return false + } + return isEnabled +} + +func isCxOneAPIKeyAvailable() bool { + apiKey := viper.GetString(params.AstAPIKey) + return apiKey != "" +} + +func isAzureAiGuidedRemediationEnabled(tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse) bool { + engine, err := GetTenantConfiguration(tenantConfigurationResponses, AiGuidedRemediationEngine) + if err != nil { + return false + } + isEnabled := strings.EqualFold(engine, AiGuidedRemediationAzureAiValue) + return isEnabled +} + +func isCheckmarxAiGuidedRemediationEnabled(tenantConfigurationResponses *[]*wrappers.TenantConfigurationResponse) bool { + engine, err := GetTenantConfiguration(tenantConfigurationResponses, AiGuidedRemediationEngine) + if err != nil { + return false + } + isEnabled := strings.EqualFold(engine, AiGuidedRemediationCheckmarxAiValue) + return isEnabled } func getMessageContents(response []message.Message) []string { diff --git a/internal/commands/chat-sast_test.go b/internal/commands/chat-sast_test.go index 6161d0114..447eca86e 100644 --- a/internal/commands/chat-sast_test.go +++ b/internal/commands/chat-sast_test.go @@ -112,6 +112,20 @@ func TestChatSastInvalidSourceDir(t *testing.T) { } func TestChatSastFirstMessageCorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "openai", + }, + { + Key: "scan.config.plugins.aiGuidedRemediation", + Value: "true", + }, + } buffer, err := executeRedirectedTestCommand("chat", "sast", "--chat-apikey", "apiKey", "--scan-results-file", "./data/cx_result.json", @@ -124,7 +138,75 @@ func TestChatSastFirstMessageCorrectResponse(t *testing.T) { assert.Assert(t, strings.Contains(s, "mock"), s) } +func TestChatSastAzureAIFirstMessageCorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "azureai", + }, + { + Key: "scan.config.plugins.aiGuidedRemediation", + Value: "true", + }, + } + buffer, err := executeRedirectedTestCommand("chat", "sast", + "--scan-results-file", "./data/cx_result.json", + "--source-dir", "./data", + "--sast-result-id", "13588362") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{} + assert.Assert(t, strings.Contains(s, "mock"), s) +} + +func TestChatSastCheckmarxAIFirstMessageCorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "checkmarxai", + }, + { + Key: "scan.config.plugins.aiGuidedRemediation", + Value: "true", + }, + } + buffer, err := executeRedirectedTestCommand("chat", "sast", + "--scan-results-file", "./data/cx_result.json", + "--source-dir", "./data", + "--sast-result-id", "13588362") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{} + assert.Assert(t, strings.Contains(s, "mock"), s) +} + func TestChatSastSecondMessageCorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "openai", + }, + { + Key: "scan.config.plugins.aiGuidedRemediation", + Value: "true", + }, + } buffer, err := executeRedirectedTestCommand("chat", "sast", "--chat-apikey", "apiKey", "--scan-results-file", "./data/cx_result.json", @@ -138,3 +220,61 @@ func TestChatSastSecondMessageCorrectResponse(t *testing.T) { s := strings.ToLower(string(output)) assert.Assert(t, strings.Contains(s, "mock"), s) } + +func TestChatSastAzureAISecondMessageCorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "azureai", + }, + { + Key: "scan.config.plugins.aiGuidedRemediation", + Value: "true", + }, + } + buffer, err := executeRedirectedTestCommand("chat", "sast", + "--scan-results-file", "./data/cx_result.json", + "--source-dir", "./data", + "--sast-result-id", "13588362", + "--conversation-id", uuid.New().String(), + "--user-input", "userInput") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{} + assert.Assert(t, strings.Contains(s, "mock message from securecall"), s) +} + +func TestChatSastCheckmarxAISecondMessageCorrectResponse(t *testing.T) { + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "checkmarxai", + }, + { + Key: "scan.config.plugins.aiGuidedRemediation", + Value: "true", + }, + } + buffer, err := executeRedirectedTestCommand("chat", "sast", + "--scan-results-file", "./data/cx_result.json", + "--source-dir", "./data", + "--sast-result-id", "13588362", + "--conversation-id", uuid.New().String(), + "--user-input", "userInput") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + mock.TenantConfiguration = []*wrappers.TenantConfigurationResponse{} + assert.Assert(t, strings.Contains(s, "mock message from securecall"), s) +} diff --git a/internal/commands/chat.go b/internal/commands/chat.go index 28c205462..795394c2f 100644 --- a/internal/commands/chat.go +++ b/internal/commands/chat.go @@ -6,8 +6,11 @@ import ( ) const ( - ConversationIDErrorFormat = "Invalid conversation ID %s" - AiGuidedRemediationEnabled = "scan.config.plugins.aiGuidedRemediation" + ConversationIDErrorFormat = "Invalid conversation ID %s" + AiGuidedRemediationEnabled = "scan.config.plugins.aiGuidedRemediation" + AiGuidedRemediationEngine = "scan.config.plugins.aiGuidedRemediationAiEngine" + AiGuidedRemediationAzureAiValue = "azureai" + AiGuidedRemediationCheckmarxAiValue = "checkmarxai" ) func NewChatCommand(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.TenantConfigurationWrapper) *cobra.Command { @@ -17,7 +20,7 @@ func NewChatCommand(chatWrapper wrappers.ChatWrapper, tenantWrapper wrappers.Ten Long: "Chat with OpenAI models regarding KICS or SAST results", Hidden: true, } - chatKicsCmd := ChatKicsSubCommand(chatWrapper) + chatKicsCmd := ChatKicsSubCommand(chatWrapper, tenantWrapper) chatSastCmd := ChatSastSubCommand(chatWrapper, tenantWrapper) chatCmd.AddCommand(chatKicsCmd, chatSastCmd) diff --git a/internal/commands/util/maskSecret.go b/internal/commands/util/maskSecret.go index ef894bf09..99ad8db06 100644 --- a/internal/commands/util/maskSecret.go +++ b/internal/commands/util/maskSecret.go @@ -3,12 +3,12 @@ package util import ( "os" + "github.com/Checkmarx/gen-ai-wrapper/pkg/connector" + "github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper" "github.com/MakeNowJust/heredoc" "github.com/checkmarx/ast-cli/internal/commands/util/printer" "github.com/checkmarx/ast-cli/internal/params" "github.com/checkmarx/ast-cli/internal/wrappers" - "github.com/checkmarxDev/gpt-wrapper/pkg/connector" - "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper" "github.com/pkg/errors" "github.com/spf13/cobra" ) diff --git a/internal/params/binds.go b/internal/params/binds.go index 71b359d1d..065187913 100644 --- a/internal/params/binds.go +++ b/internal/params/binds.go @@ -63,5 +63,7 @@ var EnvVarsBinds = []struct { {PolicyEvaluationPathKey, PolicyEvaluationPathEnv, "api/policy_management_service_uri/evaluation"}, {AccessManagementPathKey, AccessManagementPathEnv, "api/access-management"}, {ByorPathKey, ByorPathEnv, "api/byor"}, + {AiProxyAzureAiRouteKey, AiProxyAzureAiRouteEnv, "api/ai-proxy/redirect/externalAzure"}, + {AiProxyCheckmarxAiRouteKey, AiProxyCheckmarxAiRouteEnv, "api/ai-proxy/redirect/azure"}, {ASCAPortKey, ASCAPortEnv, ""}, } diff --git a/internal/params/envs.go b/internal/params/envs.go index cd29f1081..500f032af 100644 --- a/internal/params/envs.go +++ b/internal/params/envs.go @@ -62,5 +62,7 @@ const ( AccessManagementPathEnv = "CX_ACCESS_MANAGEMENT_PATH" ByorPathEnv = "CX_BYOR_PATH" IgnoreProxyEnv = "CX_IGNORE_PROXY" + AiProxyAzureAiRouteEnv = "CX_AIPROXY_AZUREAI_ROUTE" + AiProxyCheckmarxAiRouteEnv = "CX_AIPROXY_CHECKMARXAI_ROUTE" ASCAPortEnv = "CX_ASCA_PORT" ) diff --git a/internal/params/keys.go b/internal/params/keys.go index ed23c6b9a..8c944f761 100644 --- a/internal/params/keys.go +++ b/internal/params/keys.go @@ -62,5 +62,7 @@ var ( PolicyEvaluationPathKey = strings.ToLower(PolicyEvaluationPathEnv) AccessManagementPathKey = strings.ToLower(AccessManagementPathEnv) ByorPathKey = strings.ToLower(ByorPathEnv) + AiProxyAzureAiRouteKey = strings.ToLower(AiProxyAzureAiRouteEnv) + AiProxyCheckmarxAiRouteKey = strings.ToLower(AiProxyCheckmarxAiRouteEnv) ASCAPortKey = strings.ToLower(ASCAPortEnv) ) diff --git a/internal/wrappers/chat-http.go b/internal/wrappers/chat-http.go index 35c6f70cb..ad9118de7 100644 --- a/internal/wrappers/chat-http.go +++ b/internal/wrappers/chat-http.go @@ -1,9 +1,9 @@ package wrappers import ( - gptWrapperMaskedSecret "github.com/checkmarxDev/gpt-wrapper/pkg/maskedSecret" - gptWrapperMessage "github.com/checkmarxDev/gpt-wrapper/pkg/message" - gptWrapper "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper" + gptWrapperMaskedSecret "github.com/Checkmarx/gen-ai-wrapper/pkg/maskedSecret" + gptWrapperMessage "github.com/Checkmarx/gen-ai-wrapper/pkg/message" + gptWrapper "github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper" "github.com/google/uuid" ) @@ -18,6 +18,13 @@ func (c ChatHTTPWrapper) Call(w gptWrapper.StatefulWrapper, id uuid.UUID, messag return w.Call(id, messages) } +func (c ChatHTTPWrapper) SecureCall(w gptWrapper.StatefulWrapper, historyID uuid.UUID, messages []gptWrapperMessage.Message, metaData *gptWrapperMessage.MetaData, cxAuth string) ( + []gptWrapperMessage.Message, + error, +) { + return w.SecureCall(cxAuth, metaData, historyID, messages) +} + func NewChatWrapper() ChatWrapper { return ChatHTTPWrapper{} } diff --git a/internal/wrappers/chat.go b/internal/wrappers/chat.go index 3cdb62f25..f9428587f 100644 --- a/internal/wrappers/chat.go +++ b/internal/wrappers/chat.go @@ -1,13 +1,14 @@ package wrappers import ( - gptWrapperMaskedSecret "github.com/checkmarxDev/gpt-wrapper/pkg/maskedSecret" - gptWrapperMessage "github.com/checkmarxDev/gpt-wrapper/pkg/message" - gptWrapper "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper" + gptWrapperMaskedSecret "github.com/Checkmarx/gen-ai-wrapper/pkg/maskedSecret" + gptWrapperMessage "github.com/Checkmarx/gen-ai-wrapper/pkg/message" + gptWrapper "github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper" "github.com/google/uuid" ) type ChatWrapper interface { Call(gptWrapper.StatefulWrapper, uuid.UUID, []gptWrapperMessage.Message) ([]gptWrapperMessage.Message, error) + SecureCall(gptWrapper.StatefulWrapper, uuid.UUID, []gptWrapperMessage.Message, *gptWrapperMessage.MetaData, string) ([]gptWrapperMessage.Message, error) MaskSecrets(gptWrapper.StatefulWrapper, string) (*gptWrapperMaskedSecret.MaskedEntry, error) } diff --git a/internal/wrappers/client.go b/internal/wrappers/client.go index 9b132dc04..d8fb131d0 100644 --- a/internal/wrappers/client.go +++ b/internal/wrappers/client.go @@ -662,7 +662,7 @@ func GetAuthURI() (string, error) { apiKey := viper.GetString(commonParams.AstAPIKey) if len(apiKey) > 0 { logger.PrintIfVerbose("Base Auth URI - Extract from API KEY") - authURI, err = extractFromTokenClaims(apiKey, audienceClaimKey) + authURI, err = ExtractFromTokenClaims(apiKey, audienceClaimKey) if err != nil { return "", err } @@ -710,7 +710,7 @@ func GetURL(path, accessToken string) (string, error) { override := viper.GetBool(commonParams.ApikeyOverrideFlag) if accessToken != "" { logger.PrintIfVerbose("Base URI - Extract from JWT token") - cleanURL, err = extractFromTokenClaims(accessToken, baseURLKey) + cleanURL, err = ExtractFromTokenClaims(accessToken, baseURLKey) if err != nil { return "", err } @@ -731,7 +731,7 @@ func GetURL(path, accessToken string) (string, error) { return fmt.Sprintf("%s/%s", cleanURL, path), nil } -func extractFromTokenClaims(accessToken, claim string) (string, error) { +func ExtractFromTokenClaims(accessToken, claim string) (string, error) { var value string token, _, err := new(jwt.Parser).ParseUnverified(accessToken, jwt.MapClaims{}) if err != nil { diff --git a/internal/wrappers/feature-flags-http.go b/internal/wrappers/feature-flags-http.go index 40908f783..b1aa13b74 100644 --- a/internal/wrappers/feature-flags-http.go +++ b/internal/wrappers/feature-flags-http.go @@ -98,7 +98,7 @@ func (f FeatureFlagsHTTPWrapper) getTenantDetails() (params map[string]string, c return nil, 0, tokenError } - tenantIDFromClaims, extractError := extractFromTokenClaims(accessToken, tenantIDClaimKey) + tenantIDFromClaims, extractError := ExtractFromTokenClaims(accessToken, tenantIDClaimKey) if extractError != nil { return nil, 0, extractError } diff --git a/internal/wrappers/mock/chat-mock.go b/internal/wrappers/mock/chat-mock.go index ad3540828..ee97da9db 100644 --- a/internal/wrappers/mock/chat-mock.go +++ b/internal/wrappers/mock/chat-mock.go @@ -1,10 +1,10 @@ package mock import ( - gptWrapperMaskedSecret "github.com/checkmarxDev/gpt-wrapper/pkg/maskedSecret" - gptWrapperMessage "github.com/checkmarxDev/gpt-wrapper/pkg/message" - gptWrapperRole "github.com/checkmarxDev/gpt-wrapper/pkg/role" - gptWrapper "github.com/checkmarxDev/gpt-wrapper/pkg/wrapper" + gptWrapperMaskedSecret "github.com/Checkmarx/gen-ai-wrapper/pkg/maskedSecret" + gptWrapperMessage "github.com/Checkmarx/gen-ai-wrapper/pkg/message" + gptWrapperRole "github.com/Checkmarx/gen-ai-wrapper/pkg/role" + gptWrapper "github.com/Checkmarx/gen-ai-wrapper/pkg/wrapper" "github.com/google/uuid" ) @@ -28,3 +28,13 @@ func (c ChatMockWrapper) Call(_ gptWrapper.StatefulWrapper, _ uuid.UUID, _ []gpt Content: "Mock message", }}, nil } + +func (c ChatMockWrapper) SecureCall(w gptWrapper.StatefulWrapper, id uuid.UUID, messages []gptWrapperMessage.Message, metaData *gptWrapperMessage.MetaData, cxAuth string) ( + []gptWrapperMessage.Message, + error, +) { + return []gptWrapperMessage.Message{{ + Role: gptWrapperRole.Assistant, + Content: "Mock message from SecureCall", + }}, nil +} diff --git a/internal/wrappers/mock/tenant-mock.go b/internal/wrappers/mock/tenant-mock.go index 3ef26e66a..4cacbd990 100644 --- a/internal/wrappers/mock/tenant-mock.go +++ b/internal/wrappers/mock/tenant-mock.go @@ -22,7 +22,15 @@ func (t TenantConfigurationMockWrapper) GetTenantConfiguration() ( Key: "scan.config.plugins.aiGuidedRemediation", Value: "true", }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "azureai", + }, } } return &TenantConfiguration, nil, nil } + +func (t TenantConfigurationMockWrapper) SetTenantConfiguration(response []*wrappers.TenantConfigurationResponse) { + TenantConfiguration = response +} diff --git a/test/integration/chat_test.go b/test/integration/chat_test.go index a9b1bfc2e..eb0cfd7f6 100644 --- a/test/integration/chat_test.go +++ b/test/integration/chat_test.go @@ -3,14 +3,21 @@ package integration import ( + "bytes" "strings" "testing" "github.com/checkmarx/ast-cli/internal/commands" + "github.com/checkmarx/ast-cli/internal/wrappers" + "github.com/checkmarx/ast-cli/internal/wrappers/mock" "github.com/google/uuid" "gotest.tools/assert" ) +const ( + INCORRECT_API_ERROR = "Error Code: 401, Access denied due to invalid subscription key or wrong API endpoint" +) + func TestChatKicsInvalidAPIKey(t *testing.T) { args := []string{ "chat", "kics", @@ -26,7 +33,7 @@ func TestChatKicsInvalidAPIKey(t *testing.T) { assert.NilError(t, err) outputModel := commands.OutputModel{} unmarshall(t, respBuffer, &outputModel, "Reading results should pass") - assert.Assert(t, strings.Contains(outputModel.Response[0], "Incorrect API key provided"), "Expecting incorrect api key error") + assert.Assert(t, strings.Contains(outputModel.Response[0], "Incorrect API key provided"), "Expecting incorrect api key error. Got: "+outputModel.Response[0]) } func TestChatSastInvalidAPIKey(t *testing.T) { @@ -41,5 +48,55 @@ func TestChatSastInvalidAPIKey(t *testing.T) { assert.NilError(t, err) outputModel := commands.OutputModel{} unmarshall(t, respBuffer, &outputModel, "Reading results should pass") - assert.Assert(t, strings.Contains(outputModel.Response[0], "Incorrect API key provided"), "Expecting incorrect api key error") + assert.Assert(t, strings.Contains(outputModel.Response[0], "Incorrect API key provided"), "Expecting incorrect api key error. Got: "+outputModel.Response[0]) +} + +func TestChatKicsAzureAIInvalidAPIKey(t *testing.T) { + t.Skip("Skipping this test since not all services are deployed to production yet") + createASTIntegrationTestCommand(t) + mockConfig := []*wrappers.TenantConfigurationResponse{ + { + Key: "scan.config.plugins.ideScans", + Value: "true", + }, + { + Key: "scan.config.plugins.azureAiGuidedRemediation", + Value: "true", + }, + { + Key: "scan.config.plugins.aiGuidedRemediationAiEngine", + Value: "azureai", + }, + } + + mockTenant := mock.TenantConfigurationMockWrapper{} + mockTenant.SetTenantConfiguration(mockConfig) + + args := []string{ + "chat", "kics", + "--conversation-id", uuid.New().String(), + "--chat-apikey", "invalidApiKey", + "--user-input", "userInput", + "--result-file", "./data/Dockerfile", + "--result-line", "0", + "--result-severity", "LOW", + "--result-vulnerability", "Vulnerability", + } + + response, responseString := RunKicsChatForTest(t, mockTenant, args...) + assert.Assert(t, strings.Contains(response.Response[0], INCORRECT_API_ERROR), "Expecting incorrect api key error. Got: "+responseString) + +} + +func RunKicsChatForTest(t *testing.T, tenantWrapper mock.TenantConfigurationMockWrapper, args ...string) (commands.OutputModel, string) { + outputBuffer := bytes.NewBufferString("") + cmd := commands.ChatKicsSubCommand(wrappers.NewChatWrapper(), tenantWrapper) + cmd.SetArgs(args) + cmd.SetOut(outputBuffer) + err := cmd.Execute() + assert.NilError(t, err) + outputModel := commands.OutputModel{} + unmarshall(t, outputBuffer, &outputModel, "Reading results should pass") + outputString := outputBuffer.String() + return outputModel, outputString }