def max_path_sum(root):
    """Return the largest sum of node values along any path in a binary tree.

    A path is any sequence of nodes connected by parent/child edges. It
    contains at least one node and need not pass through the root. Nodes
    are expected to expose ``val``, ``left`` and ``right`` attributes. A
    falsy child such as ``None`` is treated as absent.

    The traversal uses an explicit stack instead of recursion, so very deep
    (degenerate) trees do not hit Python's recursion limit.

    Args:
        root: Root node of the tree, or a falsy value for an empty tree.

    Returns:
        The maximum path sum, or ``-float('inf')`` when ``root`` is falsy.
    """
    max_sum = -float('inf')

    # Collect nodes so that every parent appears before its children.
    # Reversing this order then guarantees children are handled first.
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        if not node:
            continue
        order.append(node)
        stack.append(node.left)
        stack.append(node.right)

    # Best downward gain for each visited node, keyed by identity. Nodes
    # stay alive for the whole call, so their ids are stable. A missing or
    # falsy child is never stored, so it looks up as 0.
    gain = {}
    for node in reversed(order):
        # A negative branch would only lower the sum, so clamp it at 0.
        left_gain = max(gain.get(id(node.left), 0), 0)
        right_gain = max(gain.get(id(node.right), 0), 0)
        # Best path that "bends" at this node, using both branches.
        max_sum = max(max_sum, node.val + left_gain + right_gain)
        # Only one branch can continue upward to the parent.
        gain[id(node)] = node.val + max(left_gain, right_gain)

    return max_sum