Skip to content

Commit

Permalink
Merge branch 'master' into improved_scheduling_report
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesMurkin committed Jun 5, 2024
2 parents a585f8e + ee1ee05 commit a6eded9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion internal/executor/node/node_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
util2 "github.com/armadaproject/armada/internal/executor/util"
)

const defaultNodeType = "none"

type NodeInfoService interface {
IsAvailableProcessingNode(*v1.Node) bool
GetAllAvailableProcessingNodes() ([]*v1.Node, error)
Expand Down Expand Up @@ -61,7 +63,7 @@ func (kubernetesNodeInfoService *KubernetesNodeInfoService) GroupNodesByType(nod
}

func (kubernetesNodeInfoService *KubernetesNodeInfoService) GetType(node *v1.Node) string {
nodeType := kubernetesNodeInfoService.clusterContext.GetClusterPool()
nodeType := defaultNodeType

if labelValue, ok := node.Labels[kubernetesNodeInfoService.nodeTypeLabel]; ok {
nodeType = labelValue
Expand Down
10 changes: 5 additions & 5 deletions internal/executor/node/node_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestGetType_WhenNodeHasNoTaint(t *testing.T) {
node := createNodeWithTaints("node1")

result := nodeInfoService.GetType(node)
assert.Equal(t, result, context.GetClusterPool())
assert.Equal(t, result, defaultNodeType)
}

func TestGetType_WhenNodeHasNodeTypeLabel(t *testing.T) {
Expand All @@ -42,7 +42,7 @@ func TestGetType_WhenNodeHasUntoleratedTaint(t *testing.T) {
node := createNodeWithTaints("node1", "untolerated")

result := nodeInfoService.GetType(node)
assert.Equal(t, result, context.GetClusterPool())
assert.Equal(t, result, defaultNodeType)
}

func TestGetType_WhenNodeHasToleratedTaint(t *testing.T) {
Expand Down Expand Up @@ -81,9 +81,9 @@ func TestGroupNodesByType(t *testing.T) {
assert.Equal(t, len(groupedNodes), 3)

expected := map[string][]*v1.Node{
context.GetClusterPool(): {node1, node2},
"tolerated1": {node3, node4},
"tolerated1,tolerated2": {node5},
defaultNodeType: {node1, node2},
"tolerated1": {node3, node4},
"tolerated1,tolerated2": {node5},
}

for _, nodeGroup := range groupedNodes {
Expand Down

0 comments on commit a6eded9

Please sign in to comment.