diff --git a/graphs_trees/bst/bst.py b/graphs_trees/bst/bst.py index 39a9836..e4a1d2b 100644 --- a/graphs_trees/bst/bst.py +++ b/graphs_trees/bst/bst.py @@ -19,11 +19,13 @@ def insert(root, data): if root.left is None: root.left = insert(root.left, data) root.left.parent = root + return root.left else: - insert(root.left, data) + return insert(root.left, data) else: if root.right is None: root.right = insert(root.right, data) root.right.parent = root + return root.right else: - insert(root.right, data) \ No newline at end of file + return insert(root.right, data) \ No newline at end of file diff --git a/graphs_trees/bst/bst_challenge.ipynb b/graphs_trees/bst/bst_challenge.ipynb index bcf6340..b3a968e 100644 --- a/graphs_trees/bst/bst_challenge.ipynb +++ b/graphs_trees/bst/bst_challenge.ipynb @@ -158,18 +158,18 @@ "\n", " def test_tree(self):\n", " node = Node(5)\n", - " insert(node, 2)\n", - " insert(node, 8)\n", - " insert(node, 1)\n", - " insert(node, 3)\n", + " assert_equal(insert(node, 2).data, 2)\n", + " assert_equal(insert(node, 8).data, 8)\n", + " assert_equal(insert(node, 1).data, 1)\n", + " assert_equal(insert(node, 3).data, 3)\n", " in_order_traversal(node, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 5, 8]')\n", " self.results.clear_results()\n", "\n", " node = insert(None, 1)\n", - " insert(node, 2)\n", - " insert(node, 3)\n", - " insert(node, 4)\n", + " assert_equal(insert(node, 2).data, 2)\n", + " assert_equal(insert(node, 3).data, 3)\n", + " assert_equal(insert(node, 4).data, 4)\n", " insert(node, 5)\n", " in_order_traversal(node, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 4, 5]')\n", diff --git a/graphs_trees/bst/bst_solution.ipynb b/graphs_trees/bst/bst_solution.ipynb index 3bab5e0..8a092e4 100644 --- a/graphs_trees/bst/bst_solution.ipynb +++ b/graphs_trees/bst/bst_solution.ipynb @@ -136,14 +136,16 @@ " if root.left is None:\n", " root.left = insert(root.left, data)\n", " root.left.parent = root\n", + " return root.left\n", " else:\n", - " insert(root.left, data)\n", + " return insert(root.left, data)\n", " else:\n", " if root.right is None:\n", " root.right = insert(root.right, data)\n", " root.right.parent = root\n", + " return root.right\n", " else:\n", - " insert(root.right, data)" + " return insert(root.right, data)" ] }, { @@ -213,18 +215,18 @@ "\n", " def test_tree(self):\n", " node = Node(5)\n", - " insert(node, 2)\n", - " insert(node, 8)\n", - " insert(node, 1)\n", - " insert(node, 3)\n", + " assert_equal(insert(node, 2).data, 2)\n", + " assert_equal(insert(node, 8).data, 8)\n", + " assert_equal(insert(node, 1).data, 1)\n", + " assert_equal(insert(node, 3).data, 3)\n", " in_order_traversal(node, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 5, 8]')\n", " self.results.clear_results()\n", "\n", " node = insert(None, 1)\n", - " insert(node, 2)\n", - " insert(node, 3)\n", - " insert(node, 4)\n", + " assert_equal(insert(node, 2).data, 2)\n", + " assert_equal(insert(node, 3).data, 3)\n", + " assert_equal(insert(node, 4).data, 4)\n", " insert(node, 5)\n", " in_order_traversal(node, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 4, 5]')\n", diff --git a/graphs_trees/bst/test_bst.py b/graphs_trees/bst/test_bst.py index 7f4f895..a900ea4 100644 --- a/graphs_trees/bst/test_bst.py +++ b/graphs_trees/bst/test_bst.py @@ -8,18 +8,18 @@ class TestTree(object): def test_tree(self): node = Node(5) - insert(node, 2) - insert(node, 8) - insert(node, 1) - insert(node, 3) + assert_equal(insert(node, 2).data, 2) + assert_equal(insert(node, 8).data, 8) + assert_equal(insert(node, 1).data, 1) + assert_equal(insert(node, 3).data, 3) in_order_traversal(node, self.results.add_result) assert_equal(str(self.results), '[1, 2, 3, 5, 8]') self.results.clear_results() node = insert(None, 1) - insert(node, 2) - insert(node, 3) - insert(node, 4) + assert_equal(insert(node, 2).data, 2) + assert_equal(insert(node, 3).data, 3) + assert_equal(insert(node, 4).data, 4) insert(node, 5) in_order_traversal(node, self.results.add_result) assert_equal(str(self.results), '[1, 2, 3, 4, 5]')