diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 852082add9..25340ec208 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -166,6 +166,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect github.com/sendgrid/rest v2.6.9+incompatible // indirect + github.com/shamaton/msgpack/v2 v2.2.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/cast v1.4.1 // indirect diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index e62eda562d..5a1c5a7c25 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -23,6 +23,7 @@ require ( github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.16.0 github.com/ray-project/kuberay/ray-operator v1.1.0-rc.1 + github.com/shamaton/msgpack/v2 v2.2.2 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index c8aa6c1254..ac78be7844 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -342,6 +342,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shamaton/msgpack/v2 v2.2.2 h1:GOIg0c9LV04VwzOOqZSrmsv/JzjNOOMxnS/HvOHGdgs= +github.com/shamaton/msgpack/v2 v2.2.2/go.mod h1:6khjYnkx73f7VQU7wjcFS9DFjs+59naVWJv1TB7qdOI= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= diff --git a/flyteplugins/go/tasks/pluginmachinery/core/template/template.go b/flyteplugins/go/tasks/pluginmachinery/core/template/template.go index 7a787c5590..5aea60c4b9 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/template/template.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/template/template.go @@ -27,7 +27,9 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/pkg/errors" + "github.com/shamaton/msgpack/v2" + "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" idlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" @@ -199,6 +201,19 @@ func serializeLiteralScalar(l *idlCore.Scalar) (string, error) { return o.Blob.Uri, nil case *idlCore.Scalar_Schema: return o.Schema.Uri, nil + case *idlCore.Scalar_Binary: + binaryBytes := o.Binary.Value + var currVal any + if o.Binary.Tag == coreutils.MESSAGEPACK { + err := msgpack.Unmarshal(binaryBytes, &currVal) + if err != nil { + return "", fmt.Errorf("failed to unmarshal messagepack bytes with literal:[%v], err:[%v]", l, err) + } + // TODO: Try to support Primitive_Datetime, Primitive_Duration, Flyte File, and Flyte Directory. + return fmt.Sprintf("%v", currVal), nil + } + return "", fmt.Errorf("unsupported binary tag [%v]", o.Binary.Tag) + default: return "", fmt.Errorf("received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) } diff --git a/flyteplugins/go/tasks/pluginmachinery/core/template/template_test.go b/flyteplugins/go/tasks/pluginmachinery/core/template/template_test.go index 956ec33cfd..0fa96a1a05 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/template/template_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/template/template_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" @@ -747,3 +748,55 @@ func TestSerializeLiteral(t *testing.T) { assert.Equal(t, "s3://some-bucket/fdsa/x.parquet", interpolated) }) } + +func TestSerializeLiteralScalar_BinaryMessagePack(t *testing.T) { + // Create a simple map to be serialized into MessagePack format + testMap := map[string]interface{}{ + "a": 1, + "b": true, + "c": 1.1, + "d": "string", + } + + // Serialize the map using MessagePack + encodedData, err := msgpack.Marshal(testMap) + assert.NoError(t, err) + + // Create the core.Scalar_Binary with the encoded MessagePack data and MESSAGEPACK tag + binaryScalar := &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: encodedData, + Tag: coreutils.MESSAGEPACK, + }, + }, + } + + // Call the function we want to test + result, err := serializeLiteralScalar(binaryScalar) + assert.NoError(t, err) + + // Since the map should be decoded back, we expect a simple string representation of the map + expectedResult := "map[a:1 b:true c:1.1 d:string]" + assert.Equal(t, expectedResult, result) +} + +func TestSerializeLiteralScalar_BinaryUnsupportedTag(t *testing.T) { + // Create some binary data for testing + binaryData := []byte{0x01, 0x02, 0x03} + + // Create a core.Scalar_Binary with an unsupported tag + binaryScalar := &core.Scalar{ + Value: &core.Scalar_Binary{ + Binary: &core.Binary{ + Value: binaryData, + Tag: "unsupported-tag", + }, + }, + } + + // Call the function and expect an error because the tag is unsupported + _, err := serializeLiteralScalar(binaryScalar) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported binary tag") +}