diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a98055f..7303146 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -45,6 +45,10 @@ jobs: check-latest: true - run: go version - run: go mod download + - name: disable ipv6 + run: | + sudo sysctl -w net.ipv6.conf.all.disable_ipv6=1 + sudo sysctl -w net.ipv6.conf.default.disable_ipv6=1 - name: Run tests run: | FORCE_RUN_INTEGRATION_TESTS=true make test exclude="${{ matrix.exclude }}" package=${{ matrix.package }} diff --git a/go.mod b/go.mod index 652e3f5..9a6b9ee 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/databricks/databricks-sql-go v1.5.5 github.com/dlclark/regexp2 v1.11.0 github.com/gliderlabs/ssh v0.3.7 - github.com/go-sql-driver/mysql v1.7.1 + github.com/go-sql-driver/mysql v1.8.1 github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 github.com/ory/dockertest/v3 v3.10.0 @@ -40,6 +40,7 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect cloud.google.com/go/iam v1.1.7 // indirect dario.cat/mergo v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/99designs/keyring v1.2.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 // indirect diff --git a/go.sum b/go.sum index e1a61a6..2aae60b 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ cloud.google.com/go/storage v1.40.0 h1:VEpDQV5CJxFmJ6ueWNsKxcr1QAYOXEgxDa+sBbJah cloud.google.com/go/storage v1.40.0/go.mod h1:Rrj7/hKlG87BLqDJYtwR0fbPld8uJPbQ2ucUMY7Ir0g= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMbk2FiG/kXiLl8BRyzTWDw7gX/Hz7Dd5eDMs= github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= @@ -161,8 +163,8 @@ github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= -github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0= diff --git a/sqlconnect/internal/mysql/legacy_mappings.go b/sqlconnect/internal/mysql/legacy_mappings.go index be4cc02..5bce242 100644 --- a/sqlconnect/internal/mysql/legacy_mappings.go +++ b/sqlconnect/internal/mysql/legacy_mappings.go @@ -2,6 +2,7 @@ package mysql import ( "encoding/json" + "fmt" "strconv" "strings" "time" @@ -13,8 +14,9 @@ func legacyJsonRowMapper(databaseTypeName string, value any) any { return nil } databaseTypeName = strings.Replace(databaseTypeName, "UNSIGNED ", "", 1) + switch databaseTypeName { - case "CHAR", "VARCHAR", "BLOB", "TEXT", "TINYBLOB", "TINYTEXT", "MEDIUMBLOB", "MEDIUMTEXT", "LONGBLOB", "LONGTEXT", "ENUM": + case "CHAR", "VARCHAR", "BLOB", "TEXT", "TINYBLOB", "TINYTEXT", "MEDIUMBLOB", "MEDIUMTEXT", "LONGBLOB", "LONGTEXT", "ENUM", "SET": switch v := value.(type) { case []uint8: return string(v) @@ -23,6 +25,8 @@ func legacyJsonRowMapper(databaseTypeName string, value any) any { } case "DATE", "DATETIME", "TIMESTAMP", "TIME", "YEAR": switch v := value.(type) { + case int, int32, int64, uint32, uint64: + return fmt.Sprintf("%d", v) case []uint8: return string(v) default: @@ -31,6 +35,18 @@ func legacyJsonRowMapper(databaseTypeName string, value any) any { case "FLOAT", "DOUBLE", "DECIMAL": switch v := value.(type) { + case int, int32, int64, uint32, uint64: + n, err := strconv.ParseFloat(fmt.Sprintf("%d", v), 64) + if err != nil { + panic(err) + } + return n + case float32, float64: + n, err := strconv.ParseFloat(fmt.Sprintf("%f", v), 64) + if err != nil { + panic(err) + } + return n case []uint8: n, err := strconv.ParseFloat(string(v), 64) if err != nil { @@ -46,6 +62,12 @@ func legacyJsonRowMapper(databaseTypeName string, value any) any { } case "INT", "TINYINT", "SMALLINt", "MEDIUMINT", "BIGINT", "UNSIGNED BIGINT": switch v := value.(type) { + case int, int32, int64, uint32, uint64: + n, err := strconv.ParseInt(fmt.Sprintf("%d", v), 10, 64) + if err != nil { + panic(err) + } + return n case []uint8: n, err := strconv.ParseInt(string(v), 10, 64) if err != nil { @@ -59,6 +81,16 @@ func legacyJsonRowMapper(databaseTypeName string, value any) any { } return n } + case "SMALLINT": + switch v := value.(type) { + case int, int32, int64, uint32, uint64: + // converting to []byte to be backwards compatible + return []byte(fmt.Sprintf("%d", v)) + case float32, float64: + // converting to []byte to be backwards compatible + return []byte(fmt.Sprintf("%f", v)) + } + } return value diff --git a/sqlconnect/internal/mysql/mappings.go b/sqlconnect/internal/mysql/mappings.go index de5c5a9..ef805d3 100644 --- a/sqlconnect/internal/mysql/mappings.go +++ b/sqlconnect/internal/mysql/mappings.go @@ -3,6 +3,7 @@ package mysql import ( "encoding/binary" "encoding/json" + "fmt" "strconv" "strings" "time" @@ -70,6 +71,12 @@ func jsonRowMapper(databaseTypeName string, value interface{}) interface{} { stringValue = v.String() case string: stringValue = v + case int, int32, int64, uint32, uint64: + stringValue = fmt.Sprintf("%d", v) + case float32, float64: + stringValue = fmt.Sprintf("%f", v) + default: + return value } switch databaseTypeName { @@ -100,12 +107,18 @@ func jsonRowMapper(databaseTypeName string, value interface{}) interface{} { case "JSON": return json.RawMessage(stringValue) case "FLOAT", "DOUBLE", "DECIMAL": + if stringValue == "" { + return nil + } n, err := strconv.ParseFloat(stringValue, 64) if err != nil { panic(err) } return n case "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT": + if stringValue == "" { + return nil + } n, err := strconv.ParseInt(stringValue, 10, 64) if err != nil { panic(err) diff --git a/sqlconnect/internal/mysql/testdata/legacy-column-mapping-test-columns-sql.json b/sqlconnect/internal/mysql/testdata/legacy-column-mapping-test-columns-sql.json index a3312f3..e116c2a 100644 --- a/sqlconnect/internal/mysql/testdata/legacy-column-mapping-test-columns-sql.json +++ b/sqlconnect/internal/mysql/testdata/legacy-column-mapping-test-columns-sql.json @@ -23,8 +23,8 @@ "_tinytext": "TEXT", "_mediumtext": "TEXT", "_longtext": "TEXT", - "_enum": "CHAR", - "_set": "CHAR", + "_enum": "ENUM", + "_set": "SET", "_date": "DATE", "_datetime": "DATETIME", "_timestamp": "TIMESTAMP", diff --git a/sqlconnect/internal/sshtunnel/tcp_tunnel.go b/sqlconnect/internal/sshtunnel/tcp_tunnel.go index 3d60d21..6b7403b 100644 --- a/sqlconnect/internal/sshtunnel/tcp_tunnel.go +++ b/sqlconnect/internal/sshtunnel/tcp_tunnel.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "strconv" + "time" "github.com/rudderlabs/sql-tunnels/tunnel" ) @@ -24,10 +25,30 @@ func NewTcpTunnel(c Config, remoteHost string, remotePort int) (Tunnel, error) { RemoteHost: remoteHost, RemotePort: remotePort, } + t, err := tunnel.ListenAndForward(&tunnelConfig) if err != nil { return nil, fmt.Errorf("creating ssh tunnel: %w", err) } + + // Wait for the tunnel to be ready (go routine) + var ( + established bool + retries int + ) + for !established && retries < 10 { + con, err := net.Dial("tcp", t.Addr()) + if con != nil { + _ = con.Close() + } + if err != nil { + retries++ + time.Sleep(10 * time.Millisecond) + continue + } + established = true + } + return &tcpTunnel{t}, nil }