一种基于 ZooKeeper 的分布式队列的设计与实现

package com.ysl.zkclient.queue;

import com.ysl.zkclient.ZKClient;
import com.ysl.zkclient.exception.ZKNoNodeException;
import com.ysl.zkclient.utils.ExceptionUtil;
import org.apache.zookeeper.CreateMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.List;

/**
 * A distributed FIFO queue built on ZooKeeper.
 *
 * <p>Every element is stored as a {@link CreateMode#PERSISTENT_SEQUENTIAL} child of the
 * queue's root path, created with the name prefix {@code node-}. The child whose name sorts
 * first is treated as the head of the queue. Any number of {@code ZKDistributedQueue}
 * instances may share one root path: a consumer takes an element only if its delete of that
 * child succeeds; otherwise it retries with the new head.
 *
 * @param <T> element type; instances are handed to {@link ZKClient} as node data
 */
public class ZKDistributedQueue<T extends Serializable> {

    private static final Logger LOG = LoggerFactory.getLogger(ZKDistributedQueue.class);

    /** Name prefix of every element node; ZooKeeper appends the sequence number after it. */
    private static final String ELEMENT_NAME = "node";

    private final ZKClient client;
    private final String rootPath;

    /**
     * Creates a queue view over an existing root node.
     *
     * @param client   ZooKeeper client used for every queue operation
     * @param rootPath path of the queue's root node; it must already exist
     * @throws ZKNoNodeException if {@code rootPath} does not exist
     */
    public ZKDistributedQueue(ZKClient client, String rootPath) {
        this.client = client;
        this.rootPath = rootPath;
        if (!client.exists(rootPath)) {
            throw new ZKNoNodeException("the root path is not exists, please create path first [" + rootPath + "]");
        }
    }

    /**
     * Appends an element to the tail of the queue.
     *
     * @param node element to store as the new node's data
     * @return always {@code true}; failures are reported by exception
     * @throws RuntimeException the exception produced by
     *         {@link ExceptionUtil#convertToRuntimeException} if creating the node fails
     */
    public boolean offer(T node) {
        try {
            client.create(rootPath + "/" + ELEMENT_NAME + "-", node, CreateMode.PERSISTENT_SEQUENTIAL);
        } catch (Exception e) {
            throw ExceptionUtil.convertToRuntimeException(e);
        }
        return true;
    }

    /**
     * Removes and returns the head of the queue.
     *
     * <p>Does not block: if the queue is empty, returns {@code null} immediately. When several
     * consumers race for the same head, only the one whose delete succeeds gets it; the others
     * retry with the next head.
     *
     * @return the removed element, or {@code null} if the queue is empty
     */
    public T poll() {
        while (true) {
            Node<T> head = getFirstNode();
            if (head == null) {
                return null;
            }
            if (tryDelete(head.getName())) {
                return head.getData();
            }
            // Another consumer removed this element first; look at the new head.
        }
    }

    /**
     * Removes and returns the head of the queue.
     *
     * @return the removed element, or {@code null} if the queue is empty
     * @deprecated misspelled name kept only for existing callers; use {@link #poll()}.
     */
    @Deprecated
    public T pool() {
        return poll();
    }

    /**
     * Attempts to delete one element node.
     *
     * @param path full path of the element node
     * @return {@code true} if this call deleted the node, {@code false} if it was already gone
     */
    private boolean tryDelete(String path) {
        try {
            return client.delete(path);
        } catch (ZKNoNodeException ignored) {
            // The node vanished between reading and deleting it: another consumer won the race.
            // Handled like a false return so the caller retries with the new head.
            return false;
        } catch (Exception e) {
            throw ExceptionUtil.convertToRuntimeException(e);
        }
    }

    /**
     * Reads the current head of the queue without removing it.
     *
     * @return the head element together with its full node path, or {@code null} if the queue
     *         is empty
     */
    @SuppressWarnings("unchecked") // node data is whatever offer() stored, i.e. a T
    private Node<T> getFirstNode() {
        try {
            while (true) {
                List<String> children = client.getChild(rootPath, true);
                if (children == null || children.isEmpty()) {
                    return null;
                }

                String headPath = rootPath + "/" + getNodeName(children);
                try {
                    return new Node<T>(headPath, (T) client.getData(headPath));
                } catch (ZKNoNodeException ignored) {
                    // The head was taken by another consumer between listing and reading; re-list.
                }
            }
        } catch (Exception e) {
            throw ExceptionUtil.convertToRuntimeException(e);
        }
    }

    /**
     * Returns the child name that sorts first lexicographically, i.e. the oldest element.
     *
     * <p>Assumes every child was created by {@link #offer} (same {@code node-} prefix) and that
     * ZooKeeper's sequence suffix is fixed-width and zero-padded, so string order matches
     * creation order.
     *
     * @param children non-empty list of child names of the root node
     * @return the smallest child name
     */
    private String getNodeName(List<String> children) {
        String smallest = children.get(0);
        for (String candidate : children) {
            if (candidate.compareTo(smallest) < 0) {
                smallest = candidate;
            }
        }
        return smallest;
    }

    /**
     * Reports whether the queue currently has no elements.
     *
     * @return {@code true} if the root node has no children (or the client returns none)
     */
    public boolean isEmpty() {
        List<String> children = client.getChild(rootPath, true);
        return children == null || children.isEmpty();
    }

    /**
     * Returns the head of the queue without removing it.
     *
     * @return the head element, or {@code null} if the queue is empty
     */
    public T peek() {
        Node<T> head = getFirstNode();
        return head == null ? null : head.getData();
    }

    /** Immutable pair of an element node's full path and the data read from it. */
    private static final class Node<E> {

        private final String name;
        private final E data;

        Node(String name, E data) {
            this.name = name;
            this.data = data;
        }

        String getName() {
            return name;
        }

        E getData() {
            return data;
        }
    }
}

测试代码

/**
     * Tests the distributed queue with 21 concurrent producers and 20 concurrent consumers.
     *
     * <p>Producers record their element under the same lock that guards {@code offer}, so
     * {@code list1} reflects the order in which the sequential nodes were created. Consumers run
     * concurrently, so the order in which they record results is not deterministic. The test
     * therefore checks FIFO behaviour as "exactly the 20 oldest elements were taken and the
     * newest one is left".
     *
     * @throws Exception if setup or waiting fails
     */
    @Test
    public void testDistributedQueue() throws Exception {
        final String rootPath = "/zk/queue";
        // Create the queue's root node.
        zkClient.createRecursive(rootPath, null, CreateMode.PERSISTENT);

        // Both lists are written from many threads; a plain ArrayList would race and lose adds.
        final List<String> list1 = java.util.Collections.synchronizedList(new ArrayList<String>());
        final List<String> list2 = java.util.Collections.synchronizedList(new ArrayList<String>());
        final Object offerLock = new Object();
        for (int i = 0; i < 21; i++) {
            Thread producer = new Thread(new Runnable() {
                public void run() {
                    ZKDistributedQueue<String> queue = new ZKDistributedQueue<String>(zkClient, rootPath);
                    String name = Thread.currentThread().getName();
                    // Offer and record atomically so that list1 order matches enqueue order.
                    synchronized (offerLock) {
                        queue.offer(name);
                        list1.add(name);
                    }
                }
            });
            producer.start();
        }

        // Wait for all producers to record their element.
        int size1 = TestUtil.waitUntil(21, new Callable<Integer>() {
            @Override
            public Integer call() throws Exception {
                return list1.size();
            }
        }, TimeUnit.SECONDS, 100);
        assertThat(size1).isEqualTo(21);
        System.out.println(zkClient.getChildren(rootPath));

        for (int i = 0; i < 20; i++) {
            Thread consumer = new Thread(new Runnable() {
                public void run() {
                    ZKDistributedQueue<String> queue = new ZKDistributedQueue<String>(zkClient, rootPath);
                    list2.add(queue.poll());
                }
            });
            consumer.start();
        }
        // Wait for all consumers to record what they took.
        int size2 = TestUtil.waitUntil(20, new Callable<Integer>() {
            @Override
            public Integer call() throws Exception {
                return list2.size();
            }
        }, TimeUnit.SECONDS, 100);
        assertThat(size2).isEqualTo(20);

        // FIFO: the 20 consumers must have taken exactly the 20 oldest elements.
        assertThat(new java.util.HashSet<String>(list2))
                .isEqualTo(new java.util.HashSet<String>(list1.subList(0, 20)));

        // The newest element is the one left; peek must return it without removing it.
        ZKDistributedQueue<String> queue = new ZKDistributedQueue<String>(zkClient, rootPath);
        String remaining = queue.peek();
        assertThat(remaining).isEqualTo(list1.get(20));
        assertThat(queue.poll()).isEqualTo(remaining);
    }

 

posted @ 2017-11-16 16:55  木易森林  阅读(441)  评论(0)  编辑  收藏  举报