Java 实现 DBSCAN 聚类算法

  最近有一个需求:在地图上将客户按照距离进行聚合。比如,a 客户到 b 客户 5km,b 客户到 c 客户 5km,那么 a、b、c 就可以聚合成一个集合。首先想到的是找一个根据坐标来聚合的算法,调研了几种之后,选择了较为简单、也符合要求的 DBSCAN 聚类算法。

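  它是一种基于密度的聚类算法,简单来说就是根据样本的紧密程度和数量将其分成多个集合。这里的样本一般是一组坐标点。算法有两个参数:邻域半径(以欧式距离度量),以及密度阈值(一个点的邻域内至少要有多少个相邻点,才能作为核心点继续向外扩展)。最终返回多个样本集合。注意:直接对经纬度计算欧式距离,得到的单位是"度"而不是公里;如果要按公里聚合,需要先做换算(例如使用 Haversine 公式,或先投影到平面坐标)。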

 

2. Java 实现

  坐标点类 CustomerPoint:如果只是测试,只会用到其中的 point(坐标)属性。该类依赖 Apache Commons Math 2.x(org.apache.commons.math 包)中的 Clusterable 和 MathUtils。

import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;

import org.apache.commons.math.stat.clustering.Clusterable;
import org.apache.commons.math.util.MathUtils;

import bsh.This;

/**
 * A customer location used as a clustering sample.
 *
 * <p>Only the coordinate array ({@link #getPoint()}) takes part in distance, centroid and
 * equality computations; {@code sender}, {@code sender_addr} and {@code value} are not.
 *
 * @author xjx
 *
 */
public class CustomerPoint implements Clusterable<CustomerPoint> {

    private String sender;
    private String sender_addr;
    private int value;
    /** Coordinates; {@link #toString()} prints index 0 as "lat" and index 1 as "lng". */
    private final double[] point;

    public int getValue() {
        return value;
    }

    public void setValue(int value) {
        this.value = value;
    }

    public String getSender() {
        return sender;
    }

    public void setSender(String sender) {
        this.sender = sender;
    }

    public String getSender_addr() {
        return sender_addr;
    }

    public void setSender_addr(String sender_addr) {
        this.sender_addr = sender_addr;
    }

    /**
     * Creates a point at the given coordinates.
     *
     * <p>The array is stored as-is (not copied). Mutating it afterwards changes this point,
     * including its {@code equals}/{@code hashCode}, so do not modify it while the point is
     * used as a map key.
     *
     * @param point the coordinates; must not be null
     * @throws NullPointerException if {@code point} is null
     */
    public CustomerPoint(final double[] point) {
        // Fail here rather than later inside distance/equality computations.
        this.point = Objects.requireNonNull(point, "point");
    }

    /** Returns the (shared, not copied) coordinate array of this point. */
    public double[] getPoint() {
        return point;
    }

    /**
     * Returns the distance to {@code p} as computed by commons-math {@code MathUtils.distance}
     * (the Euclidean distance). The result is in the same units as the coordinates, e.g.
     * degrees for raw lat/lng values, not kilometers.
     *
     * @param p the other point
     * @return the distance between the two coordinate arrays
     */
    public double distanceFrom(final CustomerPoint p) {
        return MathUtils.distance(point, p.getPoint());
    }

    /**
     * Returns a new point whose coordinates are the per-dimension mean of {@code points}.
     * The result carries no sender/value data. For an empty collection every coordinate
     * is NaN (0.0 / 0).
     *
     * @param points the points to average; each is assumed to have this point's dimension
     * @return the centroid
     */
    public CustomerPoint centroidOf(final Collection<CustomerPoint> points) {
        final double[] centroid = new double[getPoint().length];
        for (final CustomerPoint p : points) {
            final double[] coordinates = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += coordinates[i];
            }
        }
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= points.size();
        }
        return new CustomerPoint(centroid);
    }

    /**
     * Two points are equal when their coordinate arrays are equal element-wise, using
     * {@link Arrays#equals(double[], double[])} semantics (NaN equals NaN; 0.0 and -0.0
     * differ). Sender and value are ignored.
     */
    @Override
    public boolean equals(final Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof CustomerPoint)) {
            return false;
        }
        return Arrays.equals(point, ((CustomerPoint) other).getPoint());
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. Without this override, equal points
     * could hash differently, which breaks {@code HashMap}/{@code HashSet} lookups.
     */
    @Override
    public int hashCode() {
        return Arrays.hashCode(point);
    }

    /** Formats as {@code {lat:<p[0]>,lng:<p[1]>,value:<value>}}; requires at least 2 coordinates. */
    @Override
    public String toString() {
        final double[] coordinates = getPoint();
        final StringBuilder buff = new StringBuilder("{");
        buff.append("lat:").append(coordinates[0]).append(",");
        buff.append("lng:").append(coordinates[1]).append(",");
        buff.append("value:").append(this.getValue());
        buff.append("}");
        return buff.toString();
    }
}

3. 算法实现和测试:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.math3.util.MathUtils;
import ...CustomerPoint;
/**
 * 
 * @author xjx
 *
 */
public class DBScanTest3{
    //欧式距离
    private final double distance;
    //最低要求的寻找邻居数量
    private final int minPoints;
    
    private final Map<CustomerPoint, PointStatus> visited = new HashMap<CustomerPoint, PointStatus>();
    //点的标记,point:聚合内的点,noise:噪音点
    private enum PointStatus {
        NOISE,POINT
    }


    public DBScanTest3(final double distance, final int minPoints)
        throws Exception {
        if (distance < 0.0d) {
            throw new Exception("距离小于0");
        }
        if (minPoints < 0) {
            throw new Exception("点数小于0");
        }
        this.distance = distance;
        this.minPoints = minPoints;
    }

    public double getDistance() {
        return distance;
    }

    public int getMinPoints() {
        return minPoints;
    }
    
    public Map<CustomerPoint, PointStatus> getVisited() {
        return visited;
    }
    /**
     * 返回customerPoint的多个聚合
     * @param points
     * @return
     */
    public List<List<CustomerPoint>> cluster(List<CustomerPoint> points){

        final List<List<CustomerPoint>> clusters = new ArrayList<List<CustomerPoint>>();
                
        for (CustomerPoint point : points) {
        //如果已经被标记
if (visited.get(point) != null) { continue; } List<CustomerPoint> neighbors = getNeighbors(point, points); if (neighbors.size() >= minPoints) { visited.put(point, PointStatus.POINT); List<CustomerPoint> cluster = new ArrayList<CustomerPoint>();           //遍历所有邻居继续拓展找点 clusters.add(expandCluster(cluster, point, neighbors, points, visited)); } else { visited.put(point, PointStatus.NOISE); } } return clusters; } private List<CustomerPoint> expandCluster( List<CustomerPoint> cluster, CustomerPoint point, List<CustomerPoint> neighbors, List<CustomerPoint> points, Map<CustomerPoint, PointStatus> visited) { cluster.add(point); visited.put(point, PointStatus.POINT); int index = 0; //遍历 所有的邻居 while (index < neighbors.size()) { //移动当前的点 CustomerPoint current = neighbors.get(index); PointStatus pStatus = visited.get(current); if (pStatus == null) { List<CustomerPoint> currentNeighbors = getNeighbors(current, points); neighbors.addAll(currentNeighbors); }
          //如果该点未被标记,将点进行标记并加入到集合中
if (pStatus != PointStatus.POINT) { visited.put(current, PointStatus.POINT); cluster.add(current); } index++; } return cluster; } //找到所有的邻居 private List<CustomerPoint> getNeighbors(CustomerPoint point,List<CustomerPoint> points) { List<CustomerPoint> neighbors = new ArrayList<CustomerPoint>(); for (CustomerPoint neighbor : points) { if (visited.get(neighbor) != null) { continue; } if (point != neighbor && neighbor.distanceFrom(point) <= distance) { neighbors.add(neighbor); } } return neighbors; }
  //做数据进行测试
public static void main(String[] args) throws Exception { CustomerPoint customerPoint = new CustomerPoint(new double[] {3,8}); CustomerPoint customerPoint1 = new CustomerPoint(new double[] {4,7}); CustomerPoint customerPoint2 = new CustomerPoint(new double[] {4,8}); CustomerPoint customerPoint3 = new CustomerPoint(new double[] {5,6}); CustomerPoint customerPoint4 = new CustomerPoint(new double[] {3,9}); CustomerPoint customerPoint5 = new CustomerPoint(new double[] {5,1}); CustomerPoint customerPoint6 = new CustomerPoint(new double[] {5,2}); CustomerPoint customerPoint7 = new CustomerPoint(new double[] {6,3}); CustomerPoint customerPoint8 = new CustomerPoint(new double[] {7,3}); CustomerPoint customerPoint9 = new CustomerPoint(new double[] {7,4}); CustomerPoint customerPoint10 = new CustomerPoint(new double[] {0,2}); CustomerPoint customerPoint11 = new CustomerPoint(new double[] {8,16}); CustomerPoint customerPoint12 = new CustomerPoint(new double[] {1,1}); CustomerPoint customerPoint13 = new CustomerPoint(new double[] {1,3}); List<CustomerPoint> cs = new ArrayList<>(); cs.add(customerPoint13); cs.add(customerPoint12); cs.add(customerPoint11); cs.add(customerPoint10); cs.add(customerPoint9); cs.add(customerPoint8); cs.add(customerPoint7); cs.add(customerPoint6); cs.add(customerPoint5); cs.add(customerPoint4); cs.add(customerPoint3); cs.add(customerPoint2); cs.add(customerPoint1); cs.add(customerPoint);
    //这里第一个参数为距离,第二个参数为最小邻居数量 DBScanTest3 db
= new DBScanTest3(1.5, 1);
    //返回结果并打印 List
<List<CustomerPoint>> aa =db.cluster(cs); for(int i =0;i<aa.size();i++) { for(int j=0;j<aa.get(i).size();j++) { System.out.print(aa.get(i).get(j).toString()); } System.out.println(); } } }

结果打印:

{lat:1.0,lng:3.0,value:0}{lat:0.0,lng:2.0,value:0}{lat:1.0,lng:1.0,value:0}
{lat:7.0,lng:4.0,value:0}{lat:7.0,lng:3.0,value:0}{lat:6.0,lng:3.0,value:0}{lat:5.0,lng:2.0,value:0}{lat:5.0,lng:1.0,value:0}
{lat:3.0,lng:9.0,value:0}{lat:4.0,lng:8.0,value:0}{lat:3.0,lng:8.0,value:0}{lat:4.0,lng:7.0,value:0}{lat:5.0,lng:6.0,value:0}

这里返回了 3 个集合,剩下的点 (8,16) 是噪音点。读者可以把这些坐标点画在网格图上,会看到它们分为 3 部分:同一部分中的每个点,都至少与该部分中另一个点的距离不超过 1.5,整部分由这样一串相邻点连通起来(并不是任意两点之间的距离都小于 1.5)。

 

posted @ 2019-06-18 16:44  戏言xjx  阅读(4348)  评论(0编辑  收藏  举报