稀疏矩阵的加法、乘法和转置

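这里用行号和列号构成的 ArrayList 作为 HashMap 的键。只要两个 ArrayList 中的元素相同且顺序一致，equals 就返回 true——我看过它的源码，equals 就是逐个位置比较元素是否相等。同时，hashCode 也是按顺序由各元素计算出来的，所以相等的键一定有相同的哈希值，可以放心地作为 HashMap 的键。注意：放进 HashMap 之后不要再修改作为键的 ArrayList，否则会找不到对应的条目。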

所以这里这样用很好:

这样就能保证：只要行号和列号相同（且顺序一致），对应的键就相等，每个位置在表中只会出现一次。乘法要稍微麻烦一点。

public class SMatrix {
    /**
     * Triple (coordinate) representation of the matrix: each key is a two-element list
     * {@code [row, col]} and the mapped value is the element stored at that position.
     * Positions that are absent are treated as zero. This is {@code null} after the
     * no-argument constructor; every method in this class treats {@code null} as empty.
     * Keys must not be mutated after insertion (that would break the HashMap lookup).
     */
    public HashMap<ArrayList<Integer>, Integer> Triples;
    /** Number of rows of the matrix. */
    public int rowNum;
    /** Number of columns of the matrix. */
    public int colNum;

    /**
     * Creates a matrix backed by the given triple map (the map is used directly, not copied).
     *
     * @param triples map from {@code [row, col]} to element value; may be {@code null} (empty)
     * @param rowNum  number of rows
     * @param colNum  number of columns
     */
    public SMatrix(HashMap<ArrayList<Integer>, Integer> triples, int rowNum, int colNum) {
        Triples = triples;
        this.rowNum = rowNum;
        this.colNum = colNum;
    }

    /** Creates an empty 0 x 0 matrix whose {@link #Triples} is {@code null}. */
    public SMatrix() {}

    /**
     * Adds two sparse matrices element-wise.
     *
     * <p>The result uses fresh key lists (not shared with {@code M} or {@code N}), and
     * positions whose sum is zero are omitted so the result stays sparse.
     *
     * @param M left operand
     * @param N right operand
     * @return {@code M + N}, or {@code null} (after printing a message) if the dimensions differ
     * @throws NullPointerException if {@code M} or {@code N} is {@code null}
     */
    public static SMatrix Add(SMatrix M, SMatrix N) {
        if (M.colNum != N.colNum || M.rowNum != N.rowNum) {
            System.out.println("矩阵相加不满足条件");
            return null;
        }
        HashMap<ArrayList<Integer>, Integer> triples = new HashMap<ArrayList<Integer>, Integer>();
        accumulate(triples, M.Triples);
        accumulate(triples, N.Triples);
        // Previously the summed map was discarded and an empty SMatrix was returned.
        return new SMatrix(triples, M.rowNum, M.colNum);
    }

    /** Adds every entry of {@code source} (null means empty) into {@code target}. */
    private static void accumulate(HashMap<ArrayList<Integer>, Integer> target,
                                   HashMap<ArrayList<Integer>, Integer> source) {
        if (source == null) {
            return;
        }
        for (Entry<ArrayList<Integer>, Integer> entry : source.entrySet()) {
            // Copy the key so the result never aliases a caller-owned mutable list.
            addTo(target, new ArrayList<Integer>(entry.getKey()), entry.getValue());
        }
    }

    /** Adds {@code value} at {@code key} in {@code target}, dropping the entry if it becomes zero. */
    private static void addTo(HashMap<ArrayList<Integer>, Integer> target,
                              ArrayList<Integer> key, int value) {
        Integer old = target.get(key);
        int sum = (old == null ? 0 : old) + value;
        if (sum == 0) {
            target.remove(key);
        } else {
            target.put(key, sum);
        }
    }

    /**
     * Returns the transpose of this matrix: the element at {@code [r, c]} moves to
     * {@code [c, r]} and the row/column counts are swapped.
     *
     * @return a new matrix; this matrix is not modified
     */
    public SMatrix Transposition() {
        HashMap<ArrayList<Integer>, Integer> triples = new HashMap<ArrayList<Integer>, Integer>();
        if (this.Triples != null) {
            for (Entry<ArrayList<Integer>, Integer> entry : this.Triples.entrySet()) {
                ArrayList<Integer> position = entry.getKey();
                ArrayList<Integer> transP = new ArrayList<Integer>(2);
                transP.add(position.get(1));
                transP.add(position.get(0));
                triples.put(transP, entry.getValue());
            }
        }
        return new SMatrix(triples, this.colNum, this.rowNum);
    }

    /**
     * Multiplies two sparse matrices ({@code M * N}).
     *
     * <p>This method does not read {@code this}; it is an instance method only to keep the
     * original signature. Positions whose accumulated product is zero are omitted.
     *
     * @param M left operand ({@code rowNum x k})
     * @param N right operand ({@code k x colNum})
     * @return the product, or {@code null} (after printing a message) if
     *         {@code M.colNum != N.rowNum}
     * @throws NullPointerException if {@code M} or {@code N} is {@code null}
     */
    public SMatrix Multiply(SMatrix M, SMatrix N) {
        if (M.colNum != N.rowNum) {
            System.out.println("矩阵相乘不满足条件");
            return null;
        }
        // Group N's entries by row as {col, value} pairs. The HashMap<Integer, ...> lookup
        // uses equals(), which fixes the old `Integer == Integer` comparison that failed for
        // indices outside the Integer cache. It also avoids rescanning all of N for each entry of M.
        HashMap<Integer, ArrayList<int[]>> rowsOfN = new HashMap<Integer, ArrayList<int[]>>();
        if (N.Triples != null) {
            for (Entry<ArrayList<Integer>, Integer> entry : N.Triples.entrySet()) {
                Integer row = entry.getKey().get(0);
                ArrayList<int[]> bucket = rowsOfN.get(row);
                if (bucket == null) {
                    bucket = new ArrayList<int[]>();
                    rowsOfN.put(row, bucket);
                }
                bucket.add(new int[] {entry.getKey().get(1), entry.getValue()});
            }
        }

        HashMap<ArrayList<Integer>, Integer> triples = new HashMap<ArrayList<Integer>, Integer>();
        if (M.Triples != null) {
            for (Entry<ArrayList<Integer>, Integer> entry : M.Triples.entrySet()) {
                // M's column index must equal N's row index.
                ArrayList<int[]> matches = rowsOfN.get(entry.getKey().get(1));
                if (matches == null) {
                    continue;
                }
                int row = entry.getKey().get(0);
                int value = entry.getValue();
                for (int[] colAndValue : matches) {
                    ArrayList<Integer> position = new ArrayList<Integer>(2);
                    position.add(row);
                    position.add(colAndValue[0]);
                    addTo(triples, position, value * colAndValue[1]);
                }
            }
        }
        return new SMatrix(triples, M.rowNum, N.colNum);
    }
}

 

posted @ 2019-08-13 16:03  LeeJuly  阅读(535)  评论(0)    收藏  举报