From d0b9bbe138cfe6b2e4a860a761091ee050637701 Mon Sep 17 00:00:00 2001 From: Min <76146890+MinH-09@users.noreply.github.com> Date: Tue, 11 Jul 2023 16:14:04 +0800 Subject: [PATCH] test: add unit test for model and digest (#2538) Signed-off-by: MinH-09 <2107139596@qq.com> --- manager/types/model.go | 4 +- manager/types/model_test.go | 199 ++++++++++++++++++++++++++++++++++++ pkg/digest/digest_test.go | 4 + 3 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 manager/types/model_test.go diff --git a/manager/types/model.go b/manager/types/model.go index 7842e6d9544..386738c641d 100644 --- a/manager/types/model.go +++ b/manager/types/model.go @@ -63,12 +63,12 @@ type ModelEvaluation struct { MAE float64 `json:"mae" binding:"omitempty,gte=0"` } -// MakeModelName returns model name of GNN. +// MakeGNNModelName returns model name of GNN. func MakeGNNModelName(hostname, ip string, clusterID uint64) string { return fmt.Sprintf("%s_%s_%s_%s", ip, hostname, fmt.Sprint(clusterID), GNNModelNameSuffix) } -// MakeModelName returns model name of MLP. +// MakeMLPModelName returns model name of MLP. func MakeMLPModelName(hostname, ip string, clusterID uint64) string { return fmt.Sprintf("%s_%s_%s_%s", ip, hostname, fmt.Sprint(clusterID), MLPModelNameSuffix) } diff --git a/manager/types/model_test.go b/manager/types/model_test.go new file mode 100644 index 00000000000..614f2a993f6 --- /dev/null +++ b/manager/types/model_test.go @@ -0,0 +1,199 @@ +/* + * Copyright 2023 The Dragonfly Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_MakeGNNModelName(t *testing.T) { + tests := []struct { + name string + hostname string + ip string + clusterID uint64 + expect func(t *testing.T, s string) + }{ + { + name: "make gnn model name", + hostname: "foo", + ip: "127.0.0.1", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "127.0.0.1_foo_1_gnn") + }, + }, + { + name: "hostname is empty", + hostname: "", + ip: "127.0.0.1", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "127.0.0.1__1_gnn") + }, + }, + { + name: "ip is empty", + hostname: "foo", + ip: "", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "_foo_1_gnn") + }, + }, + { + name: "hostname and ip are empty", + hostname: "", + ip: "", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "__1_gnn") + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.expect(t, MakeGNNModelName(tc.hostname, tc.ip, tc.clusterID)) + }) + } +} + +func Test_MakeMLPModelName(t *testing.T) { + tests := []struct { + name string + hostname string + ip string + clusterID uint64 + expect func(t *testing.T, s string) + }{ + { + name: "make mlp model name", + hostname: "foo", + ip: "127.0.0.1", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "127.0.0.1_foo_1_mlp") + }, + }, + { + name: "hostname is empty", + hostname: "", + ip: "127.0.0.1", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "127.0.0.1__1_mlp") + }, + }, + { + name: "ip is empty", + hostname: "foo", + ip: "", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "_foo_1_mlp") + }, + }, + { + name: "hostname and ip are empty", + hostname: "", + ip: "", + clusterID: uint64(1), + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "__1_mlp") + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.expect(t, MakeMLPModelName(tc.hostname, tc.ip, tc.clusterID)) + }) + } +} + +func Test_MakeObjectKeyOfModelFile(t *testing.T) { + tests := []struct { + name string + modelName string + version int + expect func(t *testing.T, s string) + }{ + { + name: "make objectKey of model file", + modelName: "foo", + version: 1, + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "foo/1/model.graphdef") + }, + }, + { + name: "modelName is empty", + modelName: "", + version: 1, + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "/1/model.graphdef") + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.expect(t, MakeObjectKeyOfModelFile(tc.modelName, tc.version)) + }) + } +} + +func Test_MakeObjectKeyOfModelConfigFile(t *testing.T) { + tests := []struct { + name string + modelName string + version int + expect func(t *testing.T, s string) + }{ + { + name: "make objectKey of model file", + modelName: "foo", + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "foo/config.pbtxt") + }, + }, + { + name: "modelName is empty", + modelName: "", + expect: func(t *testing.T, s string) { + assert := assert.New(t) + assert.Equal(s, "/config.pbtxt") + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.expect(t, MakeObjectKeyOfModelConfigFile(tc.modelName)) + }) + } +} diff --git a/pkg/digest/digest_test.go b/pkg/digest/digest_test.go index b5afdeca632..685a47ac5d5 100644 --- a/pkg/digest/digest_test.go +++ b/pkg/digest/digest_test.go @@ -171,3 +171,7 @@ func TestDigest_MD5FromBytes(t *testing.T) { func TestDigest_SHA256FromStrings(t *testing.T) { assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", SHA256FromStrings("hello")) } + +func TestDigest_SHA256FromBytes(t *testing.T) { + assert.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", SHA256FromBytes([]byte("hello"))) +}