Implementing Large Matrix Multiplication with MapReduce

Method 1

In the first stage's Map, for each element $(i,j,v)$: if it comes from matrix A, emit $(j,(A,i,v))$; if it comes from matrix B, emit $(i,(B,j,v))$. This way column $j$ of A and row $j$ of B end up at the same reduce node. On the Reduce side, store the data from A and B in arrays listA and listB respectively: for an A record $(j,(A,i,v))$, set listA[i] = v; for a B record $(i,(B,j,v))$, set listB[j] = v. Then multiply every entry of listA by every entry of listB and emit the products: for listA[i] and listB[j], output $((i,j), listA[i]*listB[j])$. The second stage simply sums, for each key, the values emitted by the first stage.
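
To see why this works, note that this grouping is just the standard index decomposition of the product: each entry of $C=A\times B$ is $C_{ik}=\sum_{j}A_{ij}B_{jk}$. The reducer for key $j$ holds column $j$ of A and row $j$ of B, so the partial products $A_{ij}B_{jk}$ it emits under the key $(i,k)$ are exactly the terms of this sum; the second stage adds them up.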

MatrixMultiplication1.java

package com.lagou.mining.hdfs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

/**
 * Matrix multiplication, method 1.
 * 
 * @Author:orisun
 * @Since:2015-6-26
 * @Version:1.0
 */
public class MatrixMultiplication1 {

    private static final String MATRIXFILE1 = "A";
    private static final String MATRIXFILE2 = "B";

    public static class EleEmitMapper extends
            Mapper<IntWritable, Text, IntWritable, Text> {

        private String matrixFile1 = null;
        private String matrixFile2 = null;

        @Override
        protected void setup(Context context) {
            matrixFile1 = context.getConfiguration().get(MATRIXFILE1);
            matrixFile2 = context.getConfiguration().get(MATRIXFILE2);
        }

        @Override
        protected void map(IntWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            IntWritable i = key; // the input key is the row number
            String filename = ((FileSplit) context.getInputSplit()).getPath()
                    .getName();
            String[] arr = value.toString().split("\\s+");
            if (filename.equals(matrixFile1)) {
                // For matrix A, key on the column index j
                for (int j = 0; j < arr.length; j++) {
                    double v = Double.parseDouble(arr[j]);
                    // Skip zero elements; for sparse matrices this saves a lot of traffic
                    if (v != 0) {
                        context.write(new IntWritable(j), new Text(MATRIXFILE1
                                + "\t" + i.toString() + "\t" + v));
                    }
                }
            } else if (filename.equals(matrixFile2)) {
                // For matrix B, key on the row index i
                for (int j = 0; j < arr.length; j++) {
                    double v = Double.parseDouble(arr[j]);
                    // Skip zero elements; for sparse matrices this saves a lot of traffic
                    if (v != 0) {
                        context.write(i, new Text(MATRIXFILE2 + "\t" + j + "\t"
                                + v));
                    }
                }
            }
        }
    }

    public static class MultiplicationReducer extends
            Reducer<IntWritable, Text, Text, DoubleWritable> {

        @Override
        protected void reduce(IntWritable key, Iterable<Text> value,
                Context context) throws IOException, InterruptedException {
            Map<Integer, Double> listA = new HashMap<Integer, Double>();
            Map<Integer, Double> listB = new HashMap<Integer, Double>();
            Iterator<Text> itr = value.iterator();
            while (itr.hasNext()) {
                String[] arr = itr.next().toString().split("\\s+");
                String matrixTag = arr[0];
                int pos = Integer.parseInt(arr[1]);
                double v = Double.parseDouble(arr[2]);
                if (MATRIXFILE1.equals(matrixTag)) {
                    listA.put(pos, v);
                } else if (MATRIXFILE2.equals(matrixTag)) {
                    listB.put(pos, v);
                }
            }
            // Up to N*N multiplications happen here
            for (Entry<Integer, Double> entryA : listA.entrySet()) {
                int posA = entryA.getKey();
                double valA = entryA.getValue();
                for (Entry<Integer, Double> entryB : listB.entrySet()) {
                    int posB = entryB.getKey();
                    double valB = entryB.getValue();
                    double production = valA * valB;
                    context.write(new Text(posA + "\t" + posB),
                            new DoubleWritable(production));
                }
            }
        }
    }

    public static class SumMapper extends
            Mapper<LongWritable, Text, Text, DoubleWritable> {

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String[] arr = value.toString().split("\\s+");
            if (arr.length == 3) {
                context.write(new Text(arr[0] + "\t" + arr[1]),
                        new DoubleWritable(Double.parseDouble(arr[2])));
            }
        }
    }

    public static class SumCombiner extends
            Reducer<Text, DoubleWritable, Text, DoubleWritable> {

        @Override
        protected void reduce(Text key, Iterable<DoubleWritable> value,
                Context context) throws IOException, InterruptedException {
            double sum = 0;
            Iterator<DoubleWritable> itr = value.iterator();
            while (itr.hasNext()) {
                sum += itr.next().get();
            }
            context.write(key, new DoubleWritable(sum));
        }
    }

    public static class SumReducer extends
            Reducer<Text, DoubleWritable, Text, DoubleWritable> {

        @Override
        protected void reduce(Text key, Iterable<DoubleWritable> value,
                Context context) throws IOException, InterruptedException {
            double sum = 0;
            Iterator<DoubleWritable> itr = value.iterator();
            while (itr.hasNext()) {
                sum += itr.next().get();
            }
            context.write(key, new DoubleWritable(sum));
        }
    }

    /**
     * matrix1 * matrix2 = product<br>
     * matrixFile1: input file, m rows by q columns.<br>
     * matrixFile2: input file, q rows by n columns.<br>
     * productFile: output file, m rows by n columns.<br>
     * Columns are separated by whitespace.
     */
    public static void main(String[] args) throws IOException,
            ClassNotFoundException, InterruptedException {
        if (args.length < 3) {
            System.err
                    .println("please input 3 cmd args: matrixFile1 matrixFile2 productFile");
            System.exit(1);
        }
        String matrixFile1 = args[0];
        String matrixFile2 = args[1];
        String productFile = args[2];

        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(conf);
        Path inFile1 = new Path(matrixFile1);
        Path inFile2 = new Path(matrixFile2);
        conf.set(MATRIXFILE1, inFile1.getName());
        conf.set(MATRIXFILE2, inFile2.getName());
        Path midFile = new Path(inFile1.getParent().toUri().getPath()
                + "/product_tmp");
        Path outFile = new Path(productFile);
        if (!fs.exists(inFile2) || !fs.exists(inFile1)) {
            System.err.println("input matrix file does not exist");
            System.exit(1);
        }
        if (fs.exists(midFile)) {
            fs.delete(midFile, true);
        }
        if (fs.exists(outFile)) {
            fs.delete(outFile, true);
        }

        {
            Job productionJob1 = Job.getInstance(conf);
            productionJob1.setJobName("MatrixMultiplication1_step1");
            productionJob1.setJarByClass(MatrixMultiplication1.class);

            FileInputFormat.addInputPath(productionJob1, inFile1);
            FileInputFormat.addInputPath(productionJob1, inFile2);
            productionJob1.setInputFormatClass(MatrixInputFormat.class);
            productionJob1.setMapperClass(EleEmitMapper.class);
            productionJob1.setMapOutputKeyClass(IntWritable.class);
            productionJob1.setMapOutputValueClass(Text.class);

            FileOutputFormat.setOutputPath(productionJob1, midFile);
            productionJob1.setReducerClass(MultiplicationReducer.class);
            productionJob1.setNumReduceTasks(12);
            productionJob1.setOutputKeyClass(Text.class);
            productionJob1.setOutputValueClass(DoubleWritable.class);

            productionJob1.waitForCompletion(true);
        }

        {
            Job productionJob2 = Job.getInstance(conf);
            productionJob2.setJobName("MatrixMultiplication1_step2");
            productionJob2.setJarByClass(MatrixMultiplication1.class);

            FileInputFormat.setInputPaths(productionJob2, midFile);
            productionJob2.setMapperClass(SumMapper.class);
            productionJob2.setMapOutputKeyClass(Text.class);
            productionJob2.setMapOutputValueClass(DoubleWritable.class);

            FileOutputFormat.setOutputPath(productionJob2, outFile);
            productionJob2.setCombinerClass(SumCombiner.class);
            productionJob2.setReducerClass(SumReducer.class);
            productionJob2.setNumReduceTasks(1);
            productionJob2.setOutputKeyClass(Text.class);
            productionJob2.setOutputValueClass(DoubleWritable.class);

            productionJob2.waitForCompletion(true);
        }
        fs.delete(midFile, true);
        System.exit(0);
    }
}
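
As a concrete illustration, the on-disk format implied by the parsing code is one matrix row per line, values separated by whitespace, with row indices given by line numbers. Multiplying

A:
1 2
3 4

B:
5 6
7 8

gives the product $[[19, 22], [43, 50]]$, so with the default TextOutputFormat the final output should consist of tab-separated row, column, value lines (possibly in a different order):

0	0	19.0
0	1	22.0
1	0	43.0
1	1	50.0

A hypothetical invocation (the jar name is a placeholder): hadoop jar matrix.jar com.lagou.mining.hdfs.MatrixMultiplication1 A B product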

MatrixInputFormat.java

package com.lagou.mining.hdfs;

import java.io.IOException;

import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;

public class MatrixInputFormat extends FileInputFormat<IntWritable, Text> {

    @Override
    public RecordReader<IntWritable, Text> createRecordReader(
            InputSplit split, TaskAttemptContext context)
            throws IOException, InterruptedException {
        return new MatrixLineRecordReader();
    }

    /*
     * The record reader assigns line numbers as it reads, so each file must be
     * handled by a single mapper for the numbering to stay consistent.
     */
    @Override
    protected boolean isSplitable(JobContext context, Path filename) {
        return false;
    }

}
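
Disabling splitting trades parallelism for correctness: each matrix file is consumed by exactly one map task, which is what makes the line-number keys globally consistent. If intra-file parallelism mattered, one alternative (not shown here) would be to store an explicit row index in each line of the input, so the files could be split freely.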

MatrixLineRecordReader.java

package com.lagou.mining.hdfs;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.util.LineReader;

public class MatrixLineRecordReader extends RecordReader<IntWritable, Text> {
    private LineReader in;
    private int lineno = 0; // line numbers start at 0
    private boolean more = true;
    private IntWritable key = null;
    private Text value = null;

    @Override
    public void initialize(InputSplit inputSplit, TaskAttemptContext context)
            throws IOException, InterruptedException {
        FileSplit split = (FileSplit) inputSplit;
        Configuration conf = context.getConfiguration();
        Path file = split.getPath();
        FileSystem fs = file.getFileSystem(conf);
        // Always opens the file at offset 0 and reads to EOF, ignoring split
        // boundaries; this is only correct when the file arrives as a single
        // split (see MatrixInputFormat.isSplitable).
        FSDataInputStream fileIn = fs.open(file);
        in = new LineReader(fileIn, conf);
    }

    @Override
    public boolean nextKeyValue() throws IOException, InterruptedException {
        if (key == null)
            key = new IntWritable();
        if (value == null)
            value = new Text();
        int readsize = in.readLine(value);
        if (readsize == 0) {
            more = false;
            return false;
        }
        key.set(lineno);
        lineno++;
        return true;
    }

    @Override
    public IntWritable getCurrentKey() throws IOException, InterruptedException {
        return key;
    }

    @Override
    public Text getCurrentValue() throws IOException, InterruptedException {
        return value;
    }

    @Override
    public float getProgress() throws IOException, InterruptedException {
        // Progress is a fraction in [0,1], not a percentage
        return more ? 0.0f : 1.0f;
    }

    @Override
    public void close() throws IOException {
        in.close();
    }

}

Method 1's strength is that it can handle arbitrarily large matrices. Its weakness is the heavy network I/O, which makes it slow. If one of the two matrices is not too large, method 2 can be used instead.

Method 2

For $A\times{B}$, if B is not too large, B can be placed in the distributed cache while A is split by rows across several mapper tasks; each mapper task loads B entirely into memory.
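
With B resident in memory, each mapper computes complete rows of the product on its own: $C_{i\cdot} = A_{i\cdot}B$, i.e. $C_{ij} = \sum_{k=1}^{R} A_{ik}B_{kj}$, so no reduce phase (and hence no shuffle) is needed at all. The price is memory: stored as Java doubles, B occupies roughly $8RN$ bytes of heap in every map task.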

MatrixMultiplication2.java

package com.lagou.mining.hdfs;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URISyntaxException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

/**
 * Computes C = A*B, where A is M*R and B is R*N, so C is M*N.<br>
 * A is split by rows across mappers; B is shipped as a distributed-cache file
 * and loaded entirely into each mapper's memory.
 * 
 * @Author:orisun
 * @Since:2015-7-14
 * @Version:1.0
 */
public class MatrixMultiplication2 {

//    private static final Log LOG = LogFactory.getLog(FileInputFormat.class);

    private static final String ROW1 = "M";
    private static final String COL1 = "R";
    private static final String COL2 = "N";

    public static class MultipleMapper extends
            Mapper<IntWritable, Text, Text, DoubleWritable> {

        private double[][] B;
        private int M;
        private int R;
        private int N;

        @Override
        protected void setup(Context context) throws IOException {
            M = Integer.parseInt(context.getConfiguration().get(ROW1));
            R = Integer.parseInt(context.getConfiguration().get(COL1));
            N = Integer.parseInt(context.getConfiguration().get(COL2));
            B = new double[R][];
            // URI[] localCacheFiles = context.getCacheFiles();
            // getCacheFiles() is the method recommended by the new API, but it
            // did not work here; the deprecated getLocalCacheFiles() does.
            @SuppressWarnings("deprecation")
            Path[] localFiles = context.getLocalCacheFiles();
            BufferedReader br = new BufferedReader(new FileReader(
                    localFiles[0].toString()));
            for (int i = 0; i < R; i++) {
                String line = br.readLine();
                assert line != null;
                B[i] = new double[N];
                String[] arr = line.split("\\s+");
                assert arr.length == N;
                for (int j = 0; j < N; j++) {
                    B[i][j] = Double.parseDouble(arr[j]);
                }
            }
            br.close();
//            LOG.warn("B scale:" + B.length + "," + B[B.length - 1].length);
            System.out.println("B scale:" + B.length + "," + B[B.length - 1].length);
        }

        @Override
        protected void map(IntWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            int lineno = key.get(); // row number in the first matrix
            assert lineno < M;
            String[] arr = value.toString().split("\\s+");
            assert arr.length == R;
            for (int j = 0; j < N; j++) {
                double sum = 0.0;
                for (int i = 0; i < R; i++) {
                    sum += B[i][j] * Double.parseDouble(arr[i]);
                }
                context.write(new Text(lineno + "," + j), new DoubleWritable(
                        sum));
            }
        }
    }

    public static void main(String[] args) throws IOException,
            URISyntaxException, ClassNotFoundException, InterruptedException {
        if (args.length < 6) {
            System.err
                    .println("please input 6 cmd args: matrixFile1 matrixFile2 row_of_m1 col_of_m1 col_of_m2 productFile");
            System.exit(1);
        }
        String matrixFile1 = args[0];
        String matrixFile2 = args[1];
        String productFile = args[5];

        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(conf);
        Path inFile1 = new Path(matrixFile1);
        Path inFile2 = new Path(matrixFile2);
        conf.set(ROW1, args[2]);
        conf.set(COL1, args[3]);
        conf.set(COL2, args[4]);

        Path outFile = new Path(productFile);
        if (!fs.exists(inFile1) || !fs.exists(inFile2)) {
            System.err.println("input matrix file does not exist");
            System.exit(1);
        }
        if (fs.exists(outFile)) {
            fs.delete(outFile, true);
        }

        Job job = Job.getInstance(conf);
        job.setJobName("MatrixMultiplication2");
        job.setJarByClass(MatrixMultiplication2.class);
        job.addCacheFile(inFile2.toUri()); // distribute B as a cache file

        FileInputFormat.setInputPaths(job, inFile1);
        FileOutputFormat.setOutputPath(job, outFile);
        job.setInputFormatClass(LineInputFormat.class);
        job.setMapperClass(MultipleMapper.class);
        job.setMapOutputKeyClass(Text.class);
        job.setMapOutputValueClass(DoubleWritable.class);

        job.setNumReduceTasks(0); // map-only job: no shuffle, no reduce
        int exitcode = job.waitForCompletion(true) ? 0 : 1;
        System.exit(exitcode);
    }
}
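
A hypothetical invocation (jar name and dimensions are placeholders): hadoop jar matrix.jar com.lagou.mining.hdfs.MatrixMultiplication2 A B 10000 500 300 product. Here B would be 500 rows by 300 columns, i.e. about 8 × 500 × 300 ≈ 1.2 MB per mapper — comfortably in memory, which is the regime this method is meant for.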

LineInputFormat.java

package com.lagou.mining.hdfs;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.BlockLocation;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;

public class LineInputFormat extends FileInputFormat<IntWritable, Text> {
    // Size of one split in bytes. Each input split is handled by its own map
    // task, so this value controls the number of mappers.
    public long mySplitSize = 5 * 1024 * 1024;
    private static final Log LOG = LogFactory.getLog(LineInputFormat.class);
    private static final double SPLIT_SLOP = 1.1; // 10% slop

    @Override
    public RecordReader<IntWritable, Text> createRecordReader(InputSplit split,
            TaskAttemptContext context) {
        // Caveat: MatrixLineRecordReader opens the file at offset 0 and numbers
        // lines from 0, ignoring split boundaries. With more than one split per
        // file, every mapper would re-read the whole file, so keep each input
        // file below mySplitSize or make the reader split-aware.
        return new MatrixLineRecordReader();
    }

    @Override
    protected boolean isSplitable(JobContext context, Path file) {
        return true; // splitting is required here
    }

    @Override
    public List<InputSplit> getSplits(JobContext job) throws IOException {
        // Generate the splits, following the logic of Hadoop's own
        // FileInputFormat but with a custom split size.
        List<InputSplit> splits = new ArrayList<InputSplit>();
        for (FileStatus file : listStatus(job)) {
            Path path = file.getPath();
            FileSystem fs = path.getFileSystem(job.getConfiguration());
            long length = file.getLen(); // length of the file in bytes
            BlockLocation[] blkLocations = fs.getFileBlockLocations(file, 0,
                    length); // locations of the blocks holding the file
            if ((length != 0) && isSplitable(job, path)) {
                // Adjusting splitSize controls how big each split is. Note that
                // splitting by bytes cannot guarantee that every split contains
                // the same number of lines.
                long splitSize = mySplitSize;
                long bytesRemaining = length;
                // Keep cutting full-size splits while the remainder exceeds
                // SPLIT_SLOP times the split size
                while (((double) bytesRemaining) / splitSize > SPLIT_SLOP) {
                    int blkIndex = getBlockIndex(blkLocations, length
                            - bytesRemaining); // block holding the split's first byte
                    splits.add(new FileSplit(path, length - bytesRemaining,
                            splitSize, blkLocations[blkIndex].getHosts()));
                    // FileSplit(Path file, long start, long length, String[] hosts)
                    bytesRemaining -= splitSize;
                }
                if (bytesRemaining != 0) {
                    splits.add(new FileSplit(path, length - bytesRemaining,
                            bytesRemaining,
                            blkLocations[blkLocations.length - 1].getHosts()));
                }
            } else if (length != 0) {
                splits.add(new FileSplit(path, 0, length, blkLocations[0]
                        .getHosts()));
            } else {
                splits.add(new FileSplit(path, 0, length, new String[0]));
            }
        }
        LOG.info("NUMBER OF SPLITS: " + splits.size());
        return splits;
    }
}
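
Two design points worth noting: mySplitSize directly determines the number of map tasks (one per split), and the SPLIT_SLOP factor of 1.1 mirrors Hadoop's own FileInputFormat, avoiding a tiny trailing split when the remainder is within 10% of a full one. The block-index lookup only serves data locality — it tells the scheduler which hosts hold the split's first byte.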

 

posted @ 2015-07-14 18:45 张朝阳