diff --git a/rbtree.go b/rbtree.go index 96f062e..e66b21f 100644 --- a/rbtree.go +++ b/rbtree.go @@ -30,22 +30,28 @@ func NewTree[K Ordered, V any]() *Tree[K, V] { // Find finds the node and return its value. func (t *Tree[K, V]) Find(key K) V { + var result V + if t == nil { + return result + } n := t.findnode(key) if n != nil { return n.Value } - var result V return result } // FindIt finds the node and return it as an iterator. func (t *Tree[K, V]) FindIt(key K) *node[K, V] { + if t == nil || t.root == nil { + return nil + } return t.findnode(key) } // Empty checks whether the rbtree is empty. func (t *Tree[K, V]) Empty() bool { - if t.root == nil { + if t == nil || t.root == nil { return true } return false @@ -61,13 +67,18 @@ func (t *Tree[K, V]) Iterator() *node[K, V] { // Size returns the size of the rbtree. func (t *Tree[K, V]) Size() int { + if t == nil { + return 0 + } return t.size } // Clear destroys the rbtree. func (t *Tree[K, V]) Clear() { - t.root = nil - t.size = 0 + if t != nil { + t.root = nil + t.size = 0 + } } // Insert inserts the key-value pair into the rbtree. @@ -102,6 +113,10 @@ func (t *Tree[K, V]) Insert(key K, value V) { // Delete deletes the node by key func (t *Tree[K, V]) Delete(key K) { + if t == nil || t.root == nil { + return + } + z := t.findnode(key) if z == nil { return diff --git a/rbtree_test.go b/rbtree_test.go index 67e8cbb..c31aad6 100644 --- a/rbtree_test.go +++ b/rbtree_test.go @@ -2,6 +2,7 @@ package rbtree import ( "fmt" + "strings" "testing" ) @@ -57,6 +58,56 @@ func TestPreorder(t *testing.T) { tree.Preorder() } +func TestTree(t *testing.T) { + tree := NewTree[int, string]() + + t.Run("empty", func(t *testing.T) { + tree.Clear() + tree.Delete(1) + + if !tree.Empty() { + t.Fatal("tree isn't empty") + } + + size := tree.Size() + if size != 0 { + t.Errorf("unexpected tree size of %d", size) + } + }) + + t.Run("nil", func(t *testing.T) { + tree = nil + + tree.Clear() + tree.Delete(1) + + if !tree.Empty() { + t.Fatal("tree isn't empty") + } + + size := tree.Size() + if size != 0 { + t.Errorf("unexpected tree size of %d", size) + } + }) + + var caught interface{} + t.Run("catch insert panic", func(t *testing.T) { + defer func() { + if err := recover(); err != nil { + caught = err + } + }() + // panics + tree.Insert(1, "abc") + }) + + error := fmt.Sprintf("%v", caught) + if !strings.Contains(error, "nil pointer dereference") { + t.Fatalf("unexpected error: %#v", caught) + } +} + func TestFind(t *testing.T) { tree := NewTree[int, string]() @@ -93,6 +144,20 @@ func TestFind(t *testing.T) { t.Fatalf("got %q", value) } }) + + t.Run("nil", func(t *testing.T) { + tree = nil + + n := tree.FindIt(4) + if n != nil { + t.Fatalf("got %#v", n) + } + + value := tree.Find(5) + if value != "" { + t.Fatalf("got %q", value) + } + }) } func TestIterator(t *testing.T) { @@ -115,7 +180,20 @@ func TestIterator(t *testing.T) { tree = NewTree[int, string]() next := tree.Iterator() - t.Logf("tree.Iterator()=%#v", next) + if next != nil { + t.Fatalf(".Iterator() returned %#v", next) + } + + size := tree.Size() + if size != 0 { + t.Fatalf("got size %d", size) + } + }) + + t.Run("nil", func(t *testing.T) { + tree = nil + + next := tree.Iterator() if next != nil { t.Fatalf(".Iterator() returned %#v", next) } @@ -158,6 +236,16 @@ func TestDelete(t *testing.T) { t.Fatalf("after size is %d", size) } }) + + t.Run("nil", func(t *testing.T) { + tree = nil + tree.Delete(1) + + size := tree.Size() + if size != 0 { + t.Fatalf("after size is %d", size) + } + }) } func TestDelete2(t *testing.T) {