def is_balanced(root):
    """Return True if the binary tree rooted at *root* is height-balanced.

    A tree is height-balanced when, at every node, the heights of the left
    and right subtrees differ by at most 1. An empty tree is balanced.

    The traversal is an iterative post-order walk with an explicit stack, so
    very deep (e.g. linked-list-shaped) trees do not raise RecursionError.

    Args:
        root: The root node, or a falsy value (typically None) for an empty
            tree. Each node must expose ``left`` and ``right`` attributes.

    Returns:
        bool: True when every node satisfies the balance condition,
        False as soon as one node violates it.
    """
    # Each stack entry is (node, children_done). A node is pushed once to
    # schedule its children and once more to combine their heights.
    pending = [(root, False)]
    # Heights of completed subtrees, in post-order: when a node is combined,
    # the top two entries are its left and right subtree heights.
    heights = []

    while pending:
        node, children_done = pending.pop()

        if not node:
            # An empty subtree has height 0.
            heights.append(0)
            continue

        if not children_done:
            pending.append((node, True))
            # Right is pushed first so the left subtree is processed first.
            pending.append((node.right, False))
            pending.append((node.left, False))
            continue

        right_height = heights.pop()
        left_height = heights.pop()
        if abs(left_height - right_height) > 1:
            # The first unbalanced node decides the answer; stop early.
            return False
        heights.append(1 + max(left_height, right_height))

    return True