diff --git a/assert_test.go b/assert_test.go index 185e2547e..d25217bd7 100644 --- a/assert_test.go +++ b/assert_test.go @@ -10,6 +10,10 @@ import ( "testing" ) +func assertNilE(t *testing.T, actual any, descriptions ...string) { + errorOnNonEmpty(t, validateNil(actual, descriptions...)) +} + func assertNilF(t *testing.T, actual any, descriptions ...string) { fatalOnNonEmpty(t, validateNil(actual, descriptions...)) } diff --git a/auth_test.go b/auth_test.go index c3a39a707..4a6fd0e9f 100644 --- a/auth_test.go +++ b/auth_test.go @@ -656,6 +656,36 @@ func TestUnitAuthenticateWithConfigMFA(t *testing.T) { } } +func TestUnitAuthenticateWithConfigOkta(t *testing.T) { + var err error + sr := &snowflakeRestful{ + Protocol: "https", + Host: "abc.com", + Port: 443, + FuncPostAuthSAML: postAuthSAMLAuthSuccess, + FuncPostAuthOKTA: postAuthOKTASuccess, + FuncGetSSO: getSSOSuccess, + FuncPostAuth: postAuthSuccess, + TokenAccessor: getSimpleTokenAccessor(), + } + sc := getDefaultSnowflakeConn() + sc.cfg.Authenticator = AuthTypeOkta + sc.cfg.OktaURL = &url.URL{ + Scheme: "https", + Host: "abc.com", + } + sc.rest = sr + sc.ctx = context.Background() + + err = authenticateWithConfig(sc) + assertNilE(t, err, "expected to have no error.") + + sr.FuncPostAuthSAML = postAuthSAMLError + err = authenticateWithConfig(sc) + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to get SAML response") +} + func TestUnitAuthenticateExternalBrowser(t *testing.T) { var err error sr := &snowflakeRestful{ diff --git a/authokta.go b/authokta.go index d1a08a47e..818753af8 100644 --- a/authokta.go +++ b/authokta.go @@ -111,8 +111,8 @@ func authenticateBySAML( if tokenURL, err = url.Parse(respd.Data.TokenURL); err != nil { return nil, fmt.Errorf("failed to parse token URL. %v", respd.Data.TokenURL) } - if ssoURL, err = url.Parse(respd.Data.TokenURL); err != nil { - return nil, fmt.Errorf("failed to parse ssoURL URL. %v", respd.Data.SSOURL) + if ssoURL, err = url.Parse(respd.Data.SSOURL); err != nil { + return nil, fmt.Errorf("failed to parse SSO URL. %v", respd.Data.SSOURL) } if !isPrefixEqual(oktaURL, ssoURL) || !isPrefixEqual(oktaURL, tokenURL) { return nil, &SnowflakeError{ diff --git a/authokta_test.go b/authokta_test.go index a9a4b2772..56e151215 100644 --- a/authokta_test.go +++ b/authokta_test.go @@ -7,6 +7,7 @@ import ( "errors" "net/http" "net/url" + "strconv" "testing" "time" ) @@ -122,6 +123,10 @@ func TestUnitGetSSO(t *testing.T) { if err != nil { t.Fatalf("failed to get HTML content. err: %v", err) } + _, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "invalid!@url$%^", 0) + if err == nil { + t.Fatal("should have failed to parse URL.") + } } func postAuthSAMLError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { @@ -135,6 +140,14 @@ func postAuthSAMLAuthFail(_ context.Context, _ *snowflakeRestful, _ map[string]s }, nil } +func postAuthSAMLAuthFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { + return &authResponse{ + Success: false, + Code: strconv.Itoa(ErrCodeIdpConnectionError), + Message: "SAML auth failed", + }, nil +} + func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, @@ -146,6 +159,28 @@ func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful }, nil } +func postAuthSAMLAuthSuccessButInvalidTokenURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { + return &authResponse{ + Success: true, + Message: "", + Data: authResponseMain{ + TokenURL: "invalid!@url$%^", + SSOURL: "https://abc.com/sso", + }, + }, nil +} + +func postAuthSAMLAuthSuccessButInvalidSSOURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { + return &authResponse{ + Success: true, + Message: "", + Data: authResponseMain{ + TokenURL: "https://abc.com/token", + SSOURL: "invalid!@url$%^", + }, + }, nil +} + func postAuthSAMLAuthSuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) { return &authResponse{ Success: true, @@ -177,6 +212,10 @@ func getSSOSuccess(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[ return []byte(`
`), nil } +func getSSOSuccessButWrongPrefixURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) { + return []byte(``), nil +} + func TestUnitAuthenticateBySAML(t *testing.T) { authenticator := &url.URL{ Scheme: "https", @@ -195,46 +234,63 @@ func TestUnitAuthenticateBySAML(t *testing.T) { } var err error _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to get SAML response") + sr.FuncPostAuthSAML = postAuthSAMLAuthFail _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } - sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "strconv.Atoi: parsing \"\": invalid syntax") + + sr.FuncPostAuthSAML = postAuthSAMLAuthFailWithCode _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrCodeIdpConnectionError { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrCodeIdpConnectionError, driverErr.Number) - } + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError) + + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + driverErr, ok = err.(*SnowflakeError) + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError) + + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to parse token URL. invalid!@url$%^") + + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to parse SSO URL. invalid!@url$%^") + sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess sr.FuncPostAuthOKTA = postAuthOKTAError _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncPostAuthOKTA.") + assertEqualE(t, err.Error(), "failed to get SAML response") + sr.FuncPostAuthOKTA = postAuthOKTASuccess sr.FuncGetSSO = getSSOError _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncGetSSO.") + assertEqualE(t, err.Error(), "failed to get SSO html") + sr.FuncGetSSO = getSSOSuccessButInvalidURL _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err == nil { - t.Fatal("should have failed.") - } + assertNotNilF(t, err, "should have failed at FuncGetSSO.") + assertHasPrefixE(t, err.Error(), "failed to find action field in HTML response") + sr.FuncGetSSO = getSSOSuccess _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) - if err != nil { - t.Fatalf("failed. err: %v", err) - } + assertNilF(t, err, "should have succeeded at FuncGetSSO.") + + sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL + _, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password) + assertNotNilF(t, err, "should have failed at FuncGetSSO.") + driverErr, ok = err.(*SnowflakeError) + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch) } diff --git a/connector_test.go b/connector_test.go index 6c6b3caf4..76886199e 100644 --- a/connector_test.go +++ b/connector_test.go @@ -47,3 +47,23 @@ func TestConnector(t *testing.T) { t.Fatalf("Missing driver") } } + +func TestConnectorWithMissingConfig(t *testing.T) { + conn := snowflakeConn{} + mock := noopTestDriver{conn: &conn} + config := Config{ + User: "u", + Password: "p", + Account: "", + } + expectedErr := errEmptyAccount() + + connector := NewConnector(&mock, config) + _, err := connector.Connect(context.Background()) + assertNotNilF(t, err, "the connection should have failed due to empty account.") + + driverErr, ok := err.(*SnowflakeError) + assertTrueF(t, ok, "should be a SnowflakeError") + assertEqualE(t, driverErr.Number, expectedErr.Number) + assertEqualE(t, driverErr.Message, expectedErr.Message) +} diff --git a/converter_test.go b/converter_test.go index 98dda7aa2..a50115e9d 100644 --- a/converter_test.go +++ b/converter_test.go @@ -538,6 +538,70 @@ func TestArrowToValue(t *testing.T) { }, higherPrecision: true, }, + { + logical: "fixed", + physical: "int16", + values: []string{"1.2345", "2.3456"}, + rowType: execResponseRowType{Scale: 4}, + builder: array.NewInt16Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 4) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int16Builder).Append(int16(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 4) + if !ok { + return i + } + srcDec := intToBigFloat(num, 4) + dstDec := dst[i].(*big.Float) + if srcDec.Cmp(dstDec) != 0 { + return i + } + } + return -1 + }, + higherPrecision: true, + }, + { + logical: "fixed", + physical: "int16", + values: []string{"1.2345", "2.3456"}, + rowType: execResponseRowType{Scale: 4}, + builder: array.NewInt16Builder(pool), + append: func(b array.Builder, vs interface{}) { + for _, s := range vs.([]string) { + num, ok := stringFloatToInt(s, 4) + if !ok { + t.Fatalf("failed to convert to int") + } + b.(*array.Int16Builder).Append(int16(num)) + } + }, + compare: func(src interface{}, dst []snowflakeValue) int { + srcvs := src.([]string) + for i := range srcvs { + num, ok := stringFloatToInt(srcvs[i], 4) + if !ok { + return i + } + srcDec := fmt.Sprintf("%.*f", 4, float64(num)/math.Pow10(int(4))) + dstDec := dst[i] + if srcDec != dstDec { + return i + } + } + return -1 + }, + higherPrecision: false, + }, { logical: "fixed", physical: "int32", diff --git a/dsn_test.go b/dsn_test.go index f939f866c..9af66c90e 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -706,6 +706,10 @@ func TestParseDSN(t *testing.T) { ocspMode: ocspModeFailOpen, err: nil, }, + { + dsn: "u:p@a.snowflakecomputing.com:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", + err: errFailedToParseAuthenticator(), + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { diff --git a/errors.go b/errors.go index 1f1ad17a8..b41ad2723 100644 --- a/errors.go +++ b/errors.go @@ -327,6 +327,14 @@ func errInvalidRegion() *SnowflakeError { } } +// Returned if a DSN includes an invalid authenticator. +func errFailedToParseAuthenticator() *SnowflakeError { + return &SnowflakeError{ + Number: ErrCodeFailedToParseAuthenticator, + Message: "failed to parse an authenticator", + } +} + // Returned if the server side returns an error without meaningful message. func errUnknownError() *SnowflakeError { return &SnowflakeError{ diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index f4ec224c6..b94261ce3 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -658,3 +658,96 @@ func TestReadonlyTmpDirPathShouldFail(t *testing.T) { t.Fatalf("should not upload file as temporary directory is not readable") } } + +func TestUploadDownloadOneFileRequireCompress(t *testing.T) { + testUploadDownloadOneFile(t, false) +} + +func TestUploadDownloadOneFileRequireCompressStream(t *testing.T) { + testUploadDownloadOneFile(t, true) +} + +func testUploadDownloadOneFile(t *testing.T, isStream bool) { + tmpDir, err := os.MkdirTemp("", "data") + if err != nil { + t.Fatalf("cannot create temp directory: %v", err) + } + defer os.RemoveAll(tmpDir) + uploadFile := filepath.Join(tmpDir, "data.txt") + f, err := os.Create(uploadFile) + if err != nil { + t.Error(err) + } + f.WriteString("test1,test2\ntest3,test4\n") + f.Close() + + uploadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + dstFileName: "data.txt.gz", + srcFileName: uploadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + requireCompress: true, + } + + downloadFile := filepath.Join(tmpDir, "download.txt") + downloadMeta := &fileMetadata{ + name: "data.txt.gz", + stageLocationType: "local", + noSleepingTime: true, + client: local, + sha256Digest: "123456789abcdef", + stageInfo: &execResponseStageInfo{ + Location: tmpDir, + LocationType: "local", + }, + srcFileName: "data.txt.gz", + dstFileName: downloadFile, + overwrite: true, + options: &SnowflakeFileTransferOptions{ + MultiPartThreshold: dataSizeThreshold, + }, + } + + sfa := snowflakeFileTransferAgent{ + sc: &snowflakeConn{ + cfg: &Config{ + TmpDirPath: tmpDir, + }, + }, + stageLocationType: local, + } + + if isStream { + fileStream, _ := os.Open(uploadFile) + ctx := WithFileStream(context.Background(), fileStream) + uploadMeta.srcStream = getFileStream(ctx) + } + + _, err = sfa.uploadOneFile(uploadMeta) + if err != nil { + t.Fatal(err) + } + if uploadMeta.resStatus != uploaded { + t.Fatalf("failed to upload file") + } + + _, err = sfa.downloadOneFile(downloadMeta) + if err != nil { + t.Fatal(err) + } + defer os.Remove("download.txt") + if downloadMeta.resStatus != downloaded { + t.Fatalf("failed to download file") + } +} diff --git a/heartbeat_test.go b/heartbeat_test.go index 17235f57b..291b8f847 100644 --- a/heartbeat_test.go +++ b/heartbeat_test.go @@ -3,6 +3,7 @@ package gosnowflake import ( + "context" "testing" ) @@ -19,27 +20,45 @@ func TestUnitPostHeartbeat(t *testing.T) { restful: sr, } err := heartbeat.heartbeatMain() - if err != nil { - t.Fatalf("failed to heartbeat and renew session. err: %v", err) - } + assertNilF(t, err, "failed to heartbeat and renew session") + + heartbeat.restful.FuncPost = postTestError + err = heartbeat.heartbeatMain() + assertNotNilF(t, err, "should have failed to start heartbeat") + assertEqualE(t, err.Error(), "failed to run post method") heartbeat.restful.FuncPost = postTestSuccessButInvalidJSON err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed to start heartbeat") + assertHasPrefixE(t, err.Error(), "invalid character") heartbeat.restful.FuncPost = postTestAppForbiddenError err = heartbeat.heartbeatMain() - if err == nil { - t.Fatal("should have failed") - } + assertNotNilF(t, err, "should have failed to start heartbeat") driverErr, ok := err.(*SnowflakeError) - if !ok { - t.Fatalf("should be snowflake error. err: %v", err) - } - if driverErr.Number != ErrFailedToHeartbeat { - t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToHeartbeat, driverErr.Number) - } + assertTrueF(t, ok, "connection should be snowflakeConn") + assertEqualE(t, driverErr.Number, ErrFailedToHeartbeat) }) } + +func TestHeartbeatStartAndStop(t *testing.T) { + createDSNWithClientSessionKeepAlive() + config, err := ParseDSN(dsn) + if err != nil { + t.Fatalf("failed to parse dsn. err: %v", err) + } + driver := SnowflakeDriver{} + db, err := driver.OpenWithConfig(context.Background(), *config) + if err != nil { + t.Fatalf("failed to open with config. config: %v, err: %v", config, err) + } + + conn, ok := db.(*snowflakeConn) + assertTrueF(t, ok, "connection should be snowflakeConn") + assertNotNilF(t, conn.rest, "heartbeat should not be nil") + assertNotNilF(t, conn.rest.HeartBeat, "heartbeat should not be nil") + + err = db.Close() + assertNilF(t, err, "should not cause error in Close") + assertNilF(t, conn.rest, "heartbeat should be nil") +} diff --git a/transaction_test.go b/transaction_test.go index 5027e47f0..ccbf30fb3 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -95,3 +95,33 @@ func withRetry(fn func(context.Context, *sql.Conn) error, numAttempts int, timeo return fmt.Errorf("context deadline exceeded, failed after [%d] attempts", numAttempts) } } + +func TestTransactionError(t *testing.T) { + sr := &snowflakeRestful{ + FuncPostQuery: postQueryFail, + } + + tx := snowflakeTx{ + sc: &snowflakeConn{ + cfg: &Config{Params: map[string]*string{}}, + rest: sr, + }, + ctx: context.Background(), + } + + // test for post query error when executing the txCommand + err := tx.execTxCommand(rollback) + assertNotNilF(t, err, "") + assertEqualE(t, err.Error(), "failed to get query response") + + // test for invalid txCommand + err = tx.execTxCommand(2) + assertNotNilF(t, err, "") + assertEqualE(t, err.Error(), "unsupported transaction command") + + // test for bad connection error when snowflakeConn is nil + tx.sc = nil + err = tx.execTxCommand(rollback) + assertNotNilF(t, err, "") + assertEqualE(t, err.Error(), "driver: bad connection") +}