Sparse vector multiplication

// sparse vector mul 

// 1. If the vector is sparse, how can we save time?
// Store only the non-zero elements.
// What data structure should we use to keep the non-zero elements?

// Use a list of lists.
// Each inner list has two elements: the first is the index of a
// non-zero element, and the second is its (non-zero) value.

// For example, {2, 4} means the element at index 2 is non-zero and
// its value is 4.

// So we only need to multiply the values that share the same index.
// Say we have input like
// A : {{2, 4}, {0, 10}, {3, 15}}
// B : {{1, 3}, {3, 5}, {2, 6}}
// Only indices 2 and 3 appear in both, so A · B = 4 * 6 + 15 * 5 = 99.


// When the input is not sorted:
// with brute force, if A has m entries and B has n entries,
// the time is O(m * n), because we must check every pair.

// code 
/**
 * Computes the dot product of two sparse vectors by brute force.
 *
 * <p>Each vector is a list of {@code [index, value]} pairs holding only its non-zero entries;
 * the lists need not be sorted. Every pair of {@code a} is compared against the pairs of
 * {@code b}, so the running time is O(m * n) for vectors with m and n stored entries.
 *
 * @param a first sparse vector; may be null or empty
 * @param b second sparse vector; may be null or empty
 * @return the dot product, or 0 if either vector is null or empty
 */
public int SparseVectorMultiplication(List<List<Integer>> a, List<List<Integer>> b) {
  // sanity check
  if (a == null || b == null || a.isEmpty() || b.isEmpty()) {
    return 0;
  }

  int res = 0;
  for (List<Integer> pairA : a) {
    int indexA = pairA.get(0);

    for (List<Integer> pairB : b) {
      int indexB = pairB.get(0);

      if (indexA == indexB) {
        res += pairA.get(1) * pairB.get(1);
        // Assumes each index appears at most once in b, so the first match is the only one
        // and the rest of b need not be scanned for this index of a.
        break;
      }
    }
  }
  return res;
}


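// When the input is sorted by index, a two-pointer merge takes O(m + n) time.
// If it is not sorted, sort it first with a custom comparator that compares
// pairs by index; the sort adds O(m log m + n log n).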
public class Solution{
  private Comparator<ArrayList<Integer>> sparseVectorComparator = new Comparator<ArrayList<Integer>>(){
    public int compare (ArrayList<Integer> a1, ArrayList<Integer> a2 ){
      if(a1.get(0) < a2.get(0)){
        return -1;
      }else if(a1.get(0) > a2.get(0)){
        return 1;
      }else{
        return 0;
      }
    }
  }
  
  Collections.sort(a, sparseVectorComparator);
  Collections.sort(b, sparseVectorComparator);
  
  int m = a.size();
  int n = b.size();
  
  int i = 0; 
  int j = 0;
  
  while(i < m && j < n){
    ArrayList<Integer> pairA = a.get(i);
    ArrayList<Integer> pairB = b.get(j);
    
    int indexA = pairA.get(0);
    int indexB = pairB.get(0);
    
    if(indexA < indexB){
      i++;
    }else if(indexA > indexB){
      j++;
    }else{
      // i = j 
      res += pairA.get(1) * pairB.get(1);
      i++;
      j++;
    }
  }
  return res;
}



// Both inputs are sorted by index, but one input is much larger
// than the other.

// How can we solve it more efficiently?
// Iterate through the shorter one and binary-search the longer one;
// the time is O(m log n), where m is the shorter size and n the longer size.

/**
 * Computes the dot product of two sparse vectors that are both sorted by index.
 *
 * <p>Intended for the case where one vector stores far fewer entries than the other. It walks
 * the shorter vector and binary-searches the longer one, giving O(m log n) time, where m is the
 * shorter size and n the longer size. Because both inputs are sorted, each search resumes from
 * where the previous one stopped.
 *
 * @param a first sparse vector as {@code [index, value]} pairs, sorted by ascending index
 * @param b second sparse vector as {@code [index, value]} pairs, sorted by ascending index
 * @return the dot product, or 0 if either vector is null or empty
 */
public int sparseVectorMultiplication(List<List<Integer>> a, List<List<Integer>> b) {
  if (a == null || b == null || a.isEmpty() || b.isEmpty()) {
    return 0;
  }

  // Iterate the shorter vector, binary-search the longer one.
  List<List<Integer>> shorter = a;
  List<List<Integer>> longer = b;
  if (a.size() > b.size()) {
    shorter = b;
    longer = a;
  }

  int n = longer.size();
  int j = 0; // lower bound for the next search in `longer`
  int res = 0;

  for (List<Integer> pairA : shorter) {
    if (j == n) {
      break; // every remaining index lies past the end of the longer vector
    }
    int indexA = pairA.get(0);

    j = search(longer, j, n, indexA);

    if (j < n) {
      List<Integer> pairB = longer.get(j);
      int indexB = pairB.get(0);
      if (indexA == indexB) {
        res += pairA.get(1) * pairB.get(1);
        j++; // the next index in `shorter` is larger, so the next search starts after this slot
      }
    }
  }
  return res;
}

/**
 * Lower-bound binary search over {@code [start, end)}.
 *
 * <p>Returns the first position whose index is {@code >= target}, or {@code end} if there is
 * none. For example, with a : {{4, 3}} and b : {{0, x}, {1, x}, {2, x}, {4, 6}}, searching b for
 * 4 over [0, 4) returns the last position, 3. Position {@code end} itself is never read.
 *
 * @param array pairs sorted by ascending index
 * @param start first position to consider (inclusive)
 * @param end last position to consider (exclusive)
 * @param target index to look for
 * @return the lower-bound position for {@code target}, in {@code [start, end]}
 */
private int search(List<List<Integer>> array, int start, int end, int target) {
  while (start < end) {
    int mid = (start + end) >>> 1; // overflow-safe midpoint
    if (array.get(mid).get(0) < target) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }
  return start;
}

 

posted on 2018-08-09 17:26  猪猪🐷  阅读(193)  评论(0)    收藏  举报

导航