/*
public class TreeNode {
int val = 0;
TreeNode left = null;
TreeNode right = null;
public TreeNode(int val) {
this.val = val;
}
}
*/
import java.util.*;
/**
 * Locates the k-th node of a binary tree in in-order (left, node, right) sequence
 * using an explicit stack instead of recursion.
 */
public class Solution {

    /**
     * Returns the k-th node visited by an in-order traversal of the given tree.
     * If the tree is a binary search tree, in-order visits values in ascending
     * order, so the result is the node holding the k-th smallest value.
     *
     * @param pRoot root of the tree; may be {@code null}
     * @param k     1-based position in the in-order sequence
     * @return the k-th in-order node, or {@code null} when {@code pRoot} is
     *         {@code null}, {@code k} is not positive, or the tree has fewer
     *         than {@code k} nodes
     */
    TreeNode KthNode(TreeNode pRoot, int k) {
        // A non-positive k can never match the 1-based visit counter, so bail
        // out before walking the tree.
        if (pRoot == null || k <= 0) {
            return null;
        }

        // ArrayDeque is the recommended stack implementation; the legacy
        // java.util.Stack is a synchronized Vector subclass.
        Deque<TreeNode> pending = new ArrayDeque<TreeNode>();
        pushLeftSpine(pending, pRoot);

        int visited = 0;
        while (!pending.isEmpty()) {
            // The top of the stack is always the next node in in-order sequence.
            TreeNode current = pending.pop();
            visited++;
            if (visited == k) {
                return current;
            }
            // After visiting a node, its right subtree comes next; the first
            // node of that subtree in in-order is its left-most descendant.
            pushLeftSpine(pending, current.right);
        }

        // Fewer than k nodes in the tree.
        return null;
    }

    /**
     * Pushes {@code node} and then each successive left child onto the stack,
     * so the left-most node ends up on top. Does nothing when {@code node} is
     * {@code null}.
     */
    private static void pushLeftSpine(Deque<TreeNode> stack, TreeNode node) {
        while (node != null) {
            stack.push(node);
            node = node.left;
        }
    }
}