Linear classification

  1 import java.util.ArrayList;
  2 import java.util.Random;
  3 
  4 
  5 public class LinearClassifier {
  6 
  7     private Vector V;
  8     private double vn;
  9     private double eta;//learning rate, the 3st input parameter
 10     
 11     int interval=10;//the 4st input parameter args[3], how often to check to stop
 12     double test_percent;
 13     int []test_pids;
 14     Vector[] testpoints;
 15     static int most=10000;
 16     
 17     double train_percent;
 18     int []train_pids;
 19     Vector[] trainpoints;
 20     double percent=0;//the percentage of all the data
 21     private Vector median;
 22     
 23     public LinearClassifier(int N) {
 24         V=new Vector(N);
 25         median=new Vector(N);
 26         eta=0.4;
 27         vn=0;
 28     }
 29     /**
 30      *  set V and vn, such as they represent an hyperplane of origin C and normal N.
 31      */
 32      private void set_weights(final Vector C, final Vector N){
 33          V=new Vector(N);
 34          vn=-V.dot(C);
 35      }
 36     /**
 37      * returns the signed distance of point X to the hyperplane represented by (V,vn).
 38      */
 39     private  double signed_dist(final Vector X){
 40         return vn+X.dot(V);
 41     }
 42     /**
 43      *  returns true if X is on the positive side of the hyperplane, 
 44      *  false if X is on the negative side of the hyperplane.
 45      */
 46     public boolean classify(final Vector X){
 47         if(signed_dist(X)>0)return true;
 48         else return false;    
 49     }
 50     /**
 51      * updates (V,vn)
 52      * 
 53      * inSide is true if X, a point from a dataset, should be on the positive side of the hyperplane.
 54      * inSide is false if X, a point from a dataset, should be on the negative side of the hyperplane.
 55      * update weights implements one iteration of the stochastic gradient descent. The learning rate is eta.
 56      * 
 57      */
 58     void update_weights(final Vector X, boolean inSide){
 59         double delt_v=0;
 60         double Fx=Math.tanh(signed_dist(X));
 61          Fx=1-Fx*Fx;
 62     //     System.out.print("***"+Fx+"    "+signed_dist(X)+"  -------  "+inSide+" ");    
 63         
 64          double tempvn=vn;
 65          Vector tempv=new Vector(V);
 66          
 67         double z=0,t=0;
 68         if(inSide)
 69             t=1;
 70         else t=-1;
 71         z=Math.tanh(signed_dist(X));
 72         
 73         double error=0.5*(t-z)*(t-z);    
 74         
 75         for(int i=0;i<V.get_length();i++){
 76             delt_v=eta*(t-z)*X.get(i)*Fx;
 77             V.set(i, V.get(i)+delt_v);
 78         }        
 79         vn+=eta*(t-z)*Fx;
 80         
 81         z=Math.tanh(signed_dist(X));
 82         double errornew=0.5*(t-z)*(t-z);
 83         if(error<errornew)
 84         {System.out.println("!!!!!!"+signed_dist(X));
 85             V=tempv;
 86             vn=tempvn;
 87         }
 88         
 89         
 90         
 91     }
 92     
 93     public void reset(Random rd){
 94         Vector N=new Vector(V.get_length());
 95         N.fill(0);    
 96         for(int i=0;i<N.get_length();i++)
 97             N.set(i,rd.nextGaussian());    
 98         //normalize the vector N
 99         N.mul(1/N.norm());    //N.printvec();
100         set_weights(new Vector(V.get_length()),N);
101         //set_weights(median, N);
102     }
103     /**
104      * to test the 1st and 2st dateset, each of them only have 4 Vector
105      */
106     void test1(Random rd,boolean[] inSide,Vector[] test_point){        
107         reset(rd);
108         int i=0;
109         //check the symbol are all the same or all the different
110         //while(!(check_equil(inSide, test_point)||check_equiloppose(inSide, test_point))){
111             while(!check_equil(inSide, test_point)||i>10000){
112             update_weights(test_point[i%4], inSide[i%4]);        
113             i++;
114         }
115         System.out.println("eta= "+eta+" iteration="+i);
116     }
117     /**
118      * check if the symbol are all the same
119      */
120     boolean check_equil(boolean[] inSide,Vector[] point){
121             for(int i=0;i<point.length;i++)
122             if(inSide[i]!=classify(point[i]))
123                 return false;            
124         return true;
125     }
126     /**
127      * check if the symbol are all opposite/on the contrary
128      */
129     boolean check_equiloppose(boolean[] inSide,Vector[] point){
130         for(int i=0;i<point.length;i++)
131         if(inSide[i]==classify(point[i]))
132             return false;            
133     return true;
134 }
135     /**
136      * Training the dataset by the train points,
137      * using test points to determine when to stop learning
138      */
139     void train(){
140         Random rd =new Random();
141         reset(rd);    
142         int i=0;
143         setpercent();
144         
145         double oldtest=test_percent;
146         double oldtrain=train_percent;
147         while(!stop_learning(oldtest,oldtrain)){
148         //while(true){
149             int a=rd.nextInt(trainpoints.length);
150             if(train_pids[a]>0)
151                 update_weights(trainpoints[a], true);
152             else 
153                 update_weights(trainpoints[a], false);
154             i++;
155             //if(i % interval==0)
156                 setpercent();
157             if(i%interval==interval/2){//update old value in different time
158                 oldtest=test_percent;
159                 oldtrain=train_percent;
160             }
161             
162         //    System.out.println(oldtest+"---"+oldtrain);    
163             if(i>most)break;
164         }
165         
166         System.out.println("iteration= "+i);
167         
168     }
169     /**
170      * set the value from 0-size-2 to be the "Vector", size-1 be the "side"
171      * Separate the dataset into two parts: test points and train points
172      */
173     void get_vector(ArrayList<Vector> p,Random rd){
174         int size=p.size();
175         int dim=p.get(0).get_length();    
176         
177         int test_size=(int) (size*0.3);
178         int train_size=size-test_size;
179         
180         train_pids=new int[train_size];
181         trainpoints=new Vector[train_size];
182         
183         test_pids=new int[test_size];
184         testpoints=new Vector[test_size];
185         
186         for(int i=0;i<test_size;i++){
187             int j=rd.nextInt(size);
188             testpoints[i]=Vector.get_sub_vector(p.get(j), 0, dim-2);//0~size-2
189             test_pids[i]=(int) p.get(j).get(dim-1);
190             p.remove(j);
191             size--;
192             testpoints[i].sub(median);
193         }
194         
195         for(int i=0;i<train_size;i++){
196             trainpoints[i]=Vector.get_sub_vector(p.get(i), 0, dim-2);//0~size-2
197             train_pids[i]=(int) p.get(i).get(dim-1);
198             trainpoints[i].sub(median);
199         }
200         
201         
202     }    
203     /**
204      * stop learning, according to the percentage of accuracy of testing and training points
205      * this function got executed every 10, 100, or more times after doing update
206      */
207     boolean stop_learning(double oldTest,double oldTrain){
208         
209         double d1=Math.abs(oldTest-oldTrain);//old delta value
210         double d2=Math.abs(train_percent-test_percent);//new delta value
211         double d3=train_percent+test_percent;
212         //Guarantee least correct, some parameters that I guess, cann't fit to any dataset
213         if(((d2 >2*d1)||(d2 <0.00001)||train_percent>0.85) && train_percent >0.75 &&test_percent>0.72 &&d3>1.6)
214             return  true;
215         return false;
216     }
217     
218     void setpercent(){
219         test_percent=0;
220         train_percent=0;
221         for(int i=0;i<testpoints.length;i++){
222             if(classify(testpoints[i])==true&&test_pids[i]==1)
223                 test_percent++;
224             if(classify(testpoints[i])==false&&test_pids[i]==0)
225                 test_percent++;
226         }
227         
228 
229         for(int i=0;i<trainpoints.length;i++){
230             if(classify(trainpoints[i])==true&&train_pids[i]==1)
231                 train_percent++;
232             if(classify(trainpoints[i])==false&&train_pids[i]==0)
233                 train_percent++;
234         }
235         
236         percent=test_percent+train_percent;
237         percent/=testpoints.length+trainpoints.length;
238         
239         test_percent/=testpoints.length;
240         train_percent/=trainpoints.length;
241         
242         //System.out.println("testPercent: "+test_percent+"  trainPercent: "+train_percent);    
243     }
244     
245     public static void main(String[] args) {
246 //        System.out.println("test 1: --------------------------");
247 //------ test one --------------------------------------------------------------/
248         Vector test_point[]=new Vector[4];
249         test_point[0]=new Vector(2);test_point[1]=new Vector(2);
250         test_point[2]=new Vector(2);test_point[3]=new Vector(2);
251 
252         test_point[0].set(0, -1);test_point[0].set(1, 1);
253         test_point[1].set(0, 1);test_point[1].set(1, 1);
254         test_point[2].set(0, -1);test_point[2].set(1, -1);
255         test_point[3].set(0, 1);test_point[3].set(1, -1);
256         
257         boolean [] inSide=new boolean[4];
258         inSide[0]=true;
259         inSide[1]=true;
260         inSide[2]=false;
261         inSide[3]=false;
262         
263         LinearClassifier test1=new LinearClassifier(2);
264         
265         test1.eta=0.5;
266         test1.test1(new Random(),inSide,test_point);test1.V.printvec();
267         test1.eta=0.01;
268         test1.test1(new Random(),inSide,test_point);test1.V.printvec();
269         test1.eta=0.001;
270         test1.test1(new Random(),inSide,test_point);test1.V.printvec();
271 //------ test two --------------------------------------------------------------/
272         System.out.println("test 2: ==========================");    
273         test_point[0].set(0, 100);test_point[0].set(1, 101);
274         test_point[1].set(0, 101);test_point[1].set(1, 101);
275         test_point[2].set(0, 100);test_point[2].set(1, 100);
276         test_point[3].set(0, 101);test_point[3].set(1, 100);
277         System.out.println("no Optimization:"+"\n too much time 。。。"+"\nuse median to do the optimization");
278 //        test1.eta=2.5;test1.test1(new Random(),inSide,test_point);
279         ArrayList<Vector> points=new ArrayList<Vector>();
280         points.add(test_point[0]);points.add(test_point[1]);
281         points.add(test_point[2]);points.add(test_point[3]);
282         
283         test1.median=new Vector(Vector.vector_median(points));
284         test1.median.printvec();
285         for(int i=0;i<4;i++)
286             test_point[i].sub(test1.median);
287         
288         test1.eta=0.5;
289         test1.test1(new Random(),inSide,test_point);
290         test1.eta=0.01;
291         test1.test1(new Random(),inSide,test_point);
292         test1.V.add(test1.median);
293         test1.V.printvec();
294         
295         for(int i=0;i<4;i++)
296             test_point[i].add(test1.median);
297 //------ test three --------------------------------------------------------------/
298         System.out.println("test 3: =========================="+
299                 "\nseperate the dataset into 30% testing part and 70% training part, " +
300                 "\nusing the percentage to determine when to stop the learning,the max iteration is "+most);        
301         points.clear();
302         if(args.length==0)
303             points=Vector.read_data("dataset-2");//the dataset
304         else
305             points=Vector.read_data(args[0]);
306         
307         int size=points.get(0).get_length();
308 
309         LinearClassifier lc=new LinearClassifier(size-1);
310         lc.median=Vector.get_sub_vector(Vector.vector_median(points), 0, size-2);;
311         lc.eta=0.4;
312         
313         if(args.length>3)
314             lc.interval=new Integer(args[3]);
315         if(args.length>=3)
316             lc.eta=new Double(args[2]);
317         
318         lc.get_vector(points,new Random());
319         lc.train();
320         lc.setpercent();
321         System.out.println(" percentage: "+lc.percent+" test: "+lc.test_percent+"  train: "+lc.train_percent);    
322         
323         int idtest[]=new int[lc.testpoints.length];
324         for(int i=0;i<lc.testpoints.length;i++){
325             if(lc.classify(lc.testpoints[i]))
326             idtest[i]=0;
327             else idtest[i]=1;
328             lc.testpoints[i].add(lc.median);
329         }
330         
331         int idtrain[] =new int[lc.trainpoints.length];
332         for(int i=0;i<lc.trainpoints.length;i++){
333             if(lc.classify(lc.trainpoints[i]))
334             idtrain[i]=0;
335             else idtrain[i]=1;
336             lc.trainpoints[i].add(lc.median);
337         }
338         
339         lc.V.add(lc.median);
340         lc.V.printvec();
341         if(args.length<2)
342         {
343             Vector.write_data_withID("out-dataset", lc.testpoints, idtest);
344             Vector.write_data_withID("out-dataset", lc.trainpoints, idtrain,true);
345         }
346         else {
347             Vector.write_data_withID(args[1], lc.testpoints, idtest);
348             Vector.write_data_withID(args[1], lc.trainpoints, idtrain,true);
349         }
350         
351     }
352 
353 }

 

posted @ 2013-04-11 10:15  SONGHY  阅读(269)  评论(0编辑  收藏  举报