package cn.aust.zyw.demo;
/**
 * An unbalanced binary search tree that maps comparable keys to values.
 *
 * <p>Every node caches the number of nodes in the subtree rooted at it, so
 * {@link #size()} is answered without traversing the tree. Lookups that find
 * nothing ({@link #get}, {@link #floor}, {@link #celling}, and
 * {@link #min}/{@link #max} on an empty tree) return {@code null}.
 *
 * @param <Key>   key type; ordering is defined by {@code compareTo}
 * @param <Value> type of the value associated with each key
 */
public class BST<Key extends Comparable<Key>, Value> {

    /** Small demo: builds a five-node tree and prints the ceiling of 2. */
    public static void main(String[] args) {
        BST<Integer, String> bst = new BST<>();
        bst.put(3, "three");
        bst.put(5, "five");
        bst.put(1, "one");
        bst.put(6, "six");
        bst.put(4, "four");
        System.out.println(bst.celling(2));
    }

    /** Root of the tree; {@code null} while the tree is empty. */
    private Node root;

    /** A single tree node holding one key/value pair and its two children. */
    private class Node {
        private Key key;
        private Value val;
        private Node left, right;
        /** Number of nodes in the subtree rooted at this node (itself included). */
        private int N;

        public Node(Key key, Value val, int N) {
            this.key = key;
            this.val = val;
            this.N = N;
        }
    }

    /**
     * Returns the number of key/value pairs stored in the tree.
     *
     * @return the node count, 0 for an empty tree
     */
    public int size() {
        return size(root);
    }

    /** Returns the cached subtree size of {@code x}, treating {@code null} as empty. */
    private int size(Node x) {
        return x == null ? 0 : x.N;
    }

    /**
     * Looks up the value stored under {@code key}.
     *
     * @param key the key to search for
     * @return the associated value, or {@code null} if the key is absent
     */
    public Value get(Key key) {
        return get(root, key);
    }

    /** Searches the subtree rooted at {@code x} for {@code key}. */
    private Value get(Node x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            return get(x.left, key);
        }
        if (cmp > 0) {
            return get(x.right, key);
        }
        return x.val;
    }

    /**
     * Inserts {@code key} with {@code val}, or overwrites the value if the
     * key is already present.
     *
     * @param key the key to insert; must not be {@code null}
     * @param val the value to associate with {@code key}
     * @throws NullPointerException if {@code key} is {@code null}
     */
    public void put(Key key, Value val) {
        // Without this check a null key inserted into an EMPTY tree would be
        // stored as the root (no comparison happens on that path) and every
        // later operation would fail inside compareTo.
        if (key == null) {
            throw new NullPointerException("key must not be null");
        }
        root = put(root, key, val);
    }

    /**
     * Inserts into the subtree rooted at {@code x} and returns the (possibly
     * new) root of that subtree. An existing key only has its value replaced;
     * a new key creates a leaf, and the cached sizes on the path back to the
     * root are recomputed.
     */
    private Node put(Node x, Key key, Value val) {
        if (x == null) {
            return new Node(key, val, 1);
        }
        int cmp = key.compareTo(x.key);
        if (cmp < 0) {
            x.left = put(x.left, key, val);
        } else if (cmp > 0) {
            x.right = put(x.right, key, val);
        } else {
            x.val = val;
        }
        x.N = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Returns the smallest key in the tree.
     *
     * @return the minimum key, or {@code null} if the tree is empty
     */
    public Key min() {
        Node smallest = min(root);
        return smallest == null ? null : smallest.key;
    }

    /** Returns the leftmost node of the subtree rooted at {@code x}, or {@code null} if {@code x} is {@code null}. */
    private Node min(Node x) {
        if (x == null) {
            return null;
        }
        while (x.left != null) {
            x = x.left;
        }
        return x;
    }

    /**
     * Returns the largest key in the tree.
     *
     * @return the maximum key, or {@code null} if the tree is empty
     */
    public Key max() {
        Node largest = max(root);
        return largest == null ? null : largest.key;
    }

    /**
     * Returns the rightmost node of the subtree rooted at {@code x}, or
     * {@code null} if {@code x} is {@code null}.
     *
     * <p>NOTE(review): this helper is public although {@code Node} is private;
     * the visibility is kept as-is so the existing signature does not change.
     */
    public Node max(Node x) {
        if (x == null) {
            return null;
        }
        while (x.right != null) {
            x = x.right;
        }
        return x;
    }

    /**
     * Returns the largest key that is less than or equal to {@code key}.
     *
     * @param key the reference key
     * @return the floor key, or {@code null} if every stored key is greater
     */
    public Key floor(Key key) {
        Node x = floor(root, key);
        return x == null ? null : x.key;
    }

    /** Finds the floor node of {@code key} within the subtree rooted at {@code x}. */
    private Node floor(Node x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp == 0) {
            return x;
        }
        if (cmp < 0) {
            // x is too large, so the floor can only be in the left subtree.
            return floor(x.left, key);
        }
        // x qualifies; a closer candidate may still exist to the right.
        Node t = floor(x.right, key);
        return t != null ? t : x;
    }

    /**
     * Returns the smallest key that is greater than or equal to {@code key}.
     * Correctly spelled alias of {@link #celling(Comparable)}.
     *
     * @param key the reference key
     * @return the ceiling key, or {@code null} if every stored key is smaller
     */
    public Key ceiling(Key key) {
        return celling(key);
    }

    /**
     * Returns the smallest key that is greater than or equal to {@code key}.
     * The misspelled name is kept so existing callers continue to compile.
     *
     * @param key the reference key
     * @return the ceiling key, or {@code null} if every stored key is smaller
     */
    public Key celling(Key key) {
        Node x = celling(root, key);
        return x == null ? null : x.key;
    }

    /** Finds the ceiling node of {@code key} within the subtree rooted at {@code x}. */
    private Node celling(Node x, Key key) {
        if (x == null) {
            return null;
        }
        int cmp = key.compareTo(x.key);
        if (cmp == 0) {
            return x;
        }
        if (cmp > 0) {
            // x is too small, so the ceiling can only be in the right subtree.
            return celling(x.right, key);
        }
        // x qualifies; a closer candidate may still exist to the left.
        Node t = celling(x.left, key);
        return t != null ? t : x;
    }
}