Twelve MR Exercises - 4 - The TopN Problem

Problem:

There is a very large file whose contents are nothing but numbers. The task is to find the 10 largest numbers in this file.

Analysis:

This looks like a fairly simple problem. Even without a big-data framework it is easy to solve: read the numbers from the file one by one and put them into a min-heap of size 10; once the heap is full, compare each new number with the smallest value in the heap (the root), and if it is larger, remove that minimum and insert the new number; when the whole file has been read, whatever remains in the heap is the top 10.
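For reference, a minimal single-machine sketch of this heap-based approach could look like the code below (the class and method names are only for illustration; java.util.PriorityQueue is a min-heap under natural ordering, so its head is always the smallest of the current candidates):

import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.stream.LongStream;

public class TopTenHeap {

    // Keep the 10 largest values seen so far in a min-heap: the head is the smallest
    // candidate, so each new number only has to be compared against the head.
    public static long[] topTen(LongStream numbers) {
        PriorityQueue<Long> heap = new PriorityQueue<>();
        numbers.forEach(n -> {
            if (heap.size() < 10) {
                heap.offer(n);                 // fill the heap first
            } else if (n > heap.peek()) {
                heap.poll();                   // evict the current minimum
                heap.offer(n);                 // the heap again holds the 10 largest so far
            }
        });
        return heap.stream().mapToLong(Long::longValue).sorted().toArray();
    }

    public static void main(String[] args) {
        long[] top = topTen(LongStream.of(3, 41, 7, 99, 12, 5, 88, 23, 16, 42, 61, 8, 77));
        System.out.println(Arrays.toString(top));   // ascending; the 10 largest of the input
    }
}

Reading a real file would just mean building the stream from its lines (for example Files.lines(path).mapToLong(Long::parseLong)) instead of the literal values.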

Implementing it with MapReduce is not hard either: let the map tasks simply read the file and emit each number, and rely on the shuffle so that what the reducer writes out is a sorted result set; the last ten entries are then naturally the top 10. This works, but it is by no means a good solution, because the entire file has to be shuffled and sorted just to pick out ten numbers.
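To make the comparison concrete, a rough sketch of that naive variant (class names made up for illustration) could look like this: the mapper emits every number as a key, the shuffle sorts the keys, and a single pass-through reducer writes them back out, so the largest ten end up as the last ten lines of the output.

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;

public class NaiveSortTopN {

    // Emit every number as a key; the framework sorts the keys during the shuffle.
    public static class PassThroughMapper extends Mapper<Object, Text, IntWritable, NullWritable> {

        private final IntWritable out = new IntWritable();

        @Override
        public void map(Object key, Text value, Context context) throws IOException, InterruptedException {
            out.set(Integer.parseInt(value.toString().trim()));
            context.write(out, NullWritable.get());
        }
    }

    // With a single reducer the keys arrive in ascending order, so the output file is
    // fully sorted (one line per distinct value) and its last ten lines are the top 10.
    public static class PassThroughReducer extends Reducer<IntWritable, NullWritable, IntWritable, NullWritable> {

        @Override
        public void reduce(IntWritable key, Iterable<NullWritable> values, Context context) throws IOException, InterruptedException {
            context.write(key, NullWritable.get());
        }
    }
}

The driver would be set up the same way as the one shown further below, just with these two classes plugged in. The obvious drawback is that every number in the file goes through the shuffle and sort only for ten of them to be kept.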

A better approach: do a first round of filtering inside each map task (there is no need to add a separate Combiner for this) and extract the top 10 of each map; then, in the reducer, run a second round of selection and pick the overall top 10 out of all the per-map top 10s. This should be considerably more efficient.

Let's look at the implementation:

package com.zhyea.dev;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.TreeSet;

public class TopN {

    private static final Logger logger = LoggerFactory.getLogger(TopN.class);


    // Each mapper keeps its own top-10 candidates in a TreeSet. The set is ordered
    // ascending, so first() is always the smallest of the current candidates. Note
    // that a TreeSet collapses duplicates, i.e. it tracks the ten largest distinct values.
    public static class SplitterMapper extends Mapper<Object, Text, IntWritable, NullWritable> {

        private final IntWritable intWritable = new IntWritable();

        private final TreeSet<Integer> set = new TreeSet<>();

        @Override
        public void map(Object key, Text value, Context context) {
            int num = Integer.parseInt(value.toString().trim());

            // Collect the first 10 numbers unconditionally.
            if (set.size() < 10) {
                set.add(num);
                return;
            }

            // After that, keep a number only if it beats the current minimum,
            // and evict that minimum so the set stays at 10 elements.
            if (num > set.first()) {
                set.add(num);
                set.pollFirst();
            }
        }

        // Nothing is emitted per record; the mapper's local top 10 is written out once here.
        @Override
        public void cleanup(Context context) throws IOException, InterruptedException {
            for (Integer i : set) {
                intWritable.set(i);
                context.write(intWritable, NullWritable.get());
            }
        }
    }


    // A single reducer merges the per-mapper top-10 lists. Duplicate values from
    // different mappers collapse into one key, so the reducer sees each candidate
    // once and keeps the global top 10.
    public static class IntegrateReducer extends Reducer<IntWritable, NullWritable, IntWritable, NullWritable> {

        private final IntWritable intWritable = new IntWritable();
        private final TreeSet<Integer> set = new TreeSet<>();

        @Override
        public void reduce(IntWritable key, Iterable<NullWritable> values, Context context) {
            int num = key.get();

            // Same selection logic as in the mapper: fill up to 10 candidates,
            // then replace the minimum whenever a larger value shows up.
            if (set.size() < 10) {
                set.add(num);
                return;
            }

            if (num > set.first()) {
                set.add(num);
                set.pollFirst();
            }
        }

        // As in the mapper, the result is only written once all keys have been processed.
        // TreeSet iteration order means the output is ascending.
        @Override
        public void cleanup(Context context) throws IOException, InterruptedException {
            for (Integer i : set) {
                intWritable.set(i);
                context.write(intWritable, NullWritable.get());
            }
        }

    }


    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {

        Configuration conf = new Configuration();

        Job job = Job.getInstance(conf, "top-n");
        job.setJarByClass(TopN.class);

        job.setMapperClass(SplitterMapper.class);
        job.setReducerClass(IntegrateReducer.class);
        // The global top 10 has to be assembled by a single reducer.
        job.setNumReduceTasks(1);

        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(NullWritable.class);

        FileInputFormat.addInputPath(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }


}

The map and reduce methods themselves write no output at all; they only implement the selection logic. The real output happens in the cleanup methods, which run once after all records of a task have been processed.

With Spark, one option is to simply sort everything, deduplicate, and take the first N records (a sketch of that variant follows the code below). You can also apply the same idea as above, and that is exactly what the following code does:

package com.zhyea.dev

import java.util

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.spark.{SparkConf, SparkContext}

import collection.JavaConversions.asScalaIterator

object TopTen {

  def main(args: Array[String]): Unit = {
    val inputPath = args(0)
    val outputPath = args(1)
    val conf = new SparkConf().setAppName("Top Ten")
    val sc = new SparkContext(conf)
    // Read the file as (byte offset, line) pairs, take the top 10 of each partition,
    // then collapse everything into one partition and keep the global top 10.
    val data = sc.hadoopFile[LongWritable, Text, TextInputFormat](inputPath)
    data.mapPartitions[Long](findTopTen)
      .repartition(1)
      .distinct()
      .sortBy(_.toLong, false)
      .mapPartitions(itr => itr.slice(0, 10))
      .saveAsTextFile(outputPath)


    // Mirror of the MapReduce mapper: collect the ten largest values of a partition
    // in a TreeSet (ascending order, so first is the smallest candidate).
    def findTopTen(itr: Iterator[(LongWritable, Text)]): Iterator[Long] = {
      val set = new util.TreeSet[Long]()
      itr.foreach(p => {
        val v = p._2.toString.toLong
        if (set.size < 10) {
          set.add(v)
        } else if (v > set.first) {
          set.pollFirst()
          set.add(v)
        }
      })
      set.iterator
    }

  }

}
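For completeness, here is a minimal sketch of the full-sort variant mentioned earlier, written against Spark's Java API; the class name and the assumption of one number per line are mine:

import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;

public class TopTenBySort {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("Top Ten By Sort");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Deduplicate, sort everything in descending order into one partition,
        // then take the first 10 elements.
        List<Long> top10 = sc.textFile(args[0])
                .map(line -> Long.parseLong(line.trim()))
                .distinct()
                .sortBy(x -> x, false, 1)
                .take(10);

        top10.forEach(System.out::println);
        sc.stop();
    }
}

This version is shorter, but it shuffles and sorts every distinct number in the file, whereas the mapPartitions version above shuffles at most ten values per partition.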
