《航天宏图杯》比赛中基于最小生成树进行影像分割的算法详解

下面介绍利用最小生成树进行影像分割的算法。
Kruskal算法
此算法可以称为“加边法”,初始最小生成树边数为0,每迭代一次就选择一条满足条件的最小代价边,加入到最小生成树的边集合里。

  1. 把图中的所有边按代价从小到大排序;
  2. 把图中的n个顶点看成独立的n棵树组成的森林;
  3. 按权值从小到大选择边:若所选的边连接的两个顶点 $u_i, v_i$ 属于两棵不同的树,则该边成为最小生成树的一条边,并将这两棵树合并为一棵树;否则舍弃该边(加入它会形成环)。
  4. 重复(3),直到所有顶点都在一棵树内或者有n-1条边为止。

首先,我们对影像进行参数检查以及初始化创建掩膜

bool MSTSegmenter::_CheckAndInit()
{
    //1.Check the parameters
    if (_threshold<=0)
    {
        cout<<"Segmentation parameter \"threshold\" must greater than 0";
        return false;
    }

    if (_minObjectSize <= 0)
    {
        cout<<"Segmentation parameter \"minObjectSize\" must greater than 0";
        return false;
    }
    
    //2.Check Input Image
    GDALDataset* poSrcDS = ( GDALDataset*)GDALOpen(_inputImagePath, GA_ReadOnly);
	GDALDataset* poSrcDS2 = (GDALDataset*)GDALOpen(_inputImagePath2, GA_ReadOnly);
    if (poSrcDS == NULL)
    {
        cout<<"Open Image Failed!";
        return false;
    }
    srcDS = poSrcDS;
	srcDS2 = poSrcDS2;
    if (_layerWeights.size()==0)
    {
        _layerWeights.resize(poSrcDS->GetRasterCount());
        std::for_each(_layerWeights.begin(),_layerWeights.end(),[&](double &data)
        {
			data = 1;// . / poSrcDS->GetRasterCount();
        });
    }

    //3.Init Output Image
    GDALDriver* poDriver = (GDALDriver*)GDALGetDriverByName("GTiff");
    GDALDataset* poDstDS = poDriver->Create(_outputImagePath,poSrcDS->GetRasterXSize(),poSrcDS->GetRasterYSize(),1,GDT_Int32,NULL);
	GDALDataset* poDstDS2 = poDriver->Create(_outputImagePath2, poSrcDS->GetRasterXSize(), poSrcDS->GetRasterYSize(), 1, GDT_Int32, NULL);

    if ((poDstDS == NULL)|| (poDstDS2 == NULL))
    {
       cout<<("Create Image ")+_outputImagePath+(" Failed!");
        return false;
    }
    dstDS = poDstDS;
	dstDS2 = poDstDS2;
    poDstDS->CreateMaskBand(GMF_PER_DATASET);//创建掩膜层
    GDALRasterBand* poBand = poDstDS->GetRasterBand(1)->GetMaskBand();    
	poDstDS2->CreateMaskBand(GMF_PER_DATASET);//创建掩膜层
	GDALRasterBand* poBand2 = poDstDS2->GetRasterBand(1)->GetMaskBand();

    for (int i=0;i<poDstDS->GetRasterYSize();++i)
    {
        std::vector<unsigned char> mask(poDstDS->GetRasterXSize(),1);
        if (poSrcDS->GetRasterBand(1)->GetRasterDataType() == GDT_Byte && shield_0_255)
        {
            unsigned char* image = new unsigned char[mask.size()];
            for (int k=0;k<poSrcDS->GetRasterCount();++k)
            {
                poSrcDS->GetRasterBand(k+1)->RasterIO(GF_Read,0,i,mask.size(),1,image,mask.size(),1,GDT_Byte,0,0);
                for (int j=0;j<mask.size();++j)
                {
                    if (image[j]==0 || image[j]==255)
                        mask[j]=0;
                }
            }
            delete []image;
        }

        poBand->RasterIO(GF_Write, 0,i,mask.size(),1,&mask[0],mask.size(),1,GDT_Byte,0,0);
		poBand2->RasterIO(GF_Write, 0, i, mask.size(), 1, &mask[0], mask.size(), 1, GDT_Byte, 0, 0);
    }

    double adfGeoTransform[6];
    poDstDS->SetProjection(poSrcDS->GetProjectionRef());
	poDstDS2->SetProjection(poSrcDS->GetProjectionRef());
    poSrcDS->GetGeoTransform(adfGeoTransform);
    poDstDS->SetGeoTransform(adfGeoTransform);
	poDstDS2->SetGeoTransform(adfGeoTransform);
    return true;
}

接下来计算图中相邻像素之间每一条边的权值(左右相邻与上下相邻):

bool MSTSegmenter::_CreateEdgeWeights(void *p)
{
    EdgeVector* vecEdge = (EdgeVector*)p;
    vecEdge->clear();

    GDALDataset* poSrcDS = (GDALDataset*)srcDS;
    GDALDataset* poDstDS = (GDALDataset*)dstDS;
	GDALDataset* poSrcDS2 = (GDALDataset*)srcDS2;
	GDALDataset* poDstDS2 = (GDALDataset*)dstDS2;

    unsigned nBandCount = poSrcDS->GetRasterCount();
    unsigned width = poSrcDS->GetRasterXSize();
    unsigned height= poSrcDS->GetRasterYSize();
    GDALDataType gdalDataType = poSrcDS->GetRasterBand(1)->GetRasterDataType();
    unsigned pixelSize = GDALGetDataTypeSize(gdalDataType)/8;

    std::vector<double> buffer1(nBandCount*2);
    std::vector<double> buffer2(nBandCount*2);

    std::vector<unsigned char*> lineBufferUp(nBandCount);
    std::vector<unsigned char*> lineBufferDown(nBandCount);
	std::vector<unsigned char*> lineBufferUp2(nBandCount);
	std::vector<unsigned char*> lineBufferDown2(nBandCount);

    GDALRasterBand* poMask = poDstDS->GetRasterBand(1)->GetMaskBand();
    unsigned char* maskUp   = new unsigned char[width];
    unsigned char* maskDown = new unsigned char[width];
    for (unsigned k=0;k<nBandCount;++k)
    {
        lineBufferUp[k] = new unsigned char[width*pixelSize];
        lineBufferDown[k] = new unsigned char[width*pixelSize];
		lineBufferUp2[k] = new unsigned char[width*pixelSize];
		lineBufferDown2[k] = new unsigned char[width*pixelSize];
    }

    for (unsigned k=0;k<nBandCount;++k)
    {
		poSrcDS->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, 0, width, 1, lineBufferUp[k], width, 1, gdalDataType, 0, 0);
		poSrcDS2->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, 0, width, 1, lineBufferUp2[k], width, 1, gdalDataType, 0, 0);
    }
    poMask->RasterIO(GF_Read, 0,0,width,1,maskUp,width,1,GDT_Byte,0,0);

	for (unsigned y = 0; y < height - 1; ++y)
	{
		for (unsigned k = 0; k < nBandCount; ++k){
			poSrcDS->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, y + 1, width, 1, lineBufferDown[k], width, 1, gdalDataType, 0, 0);
		poSrcDS2->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, y + 1, width, 1, lineBufferDown2[k], width, 1, gdalDataType, 0, 0);
	}

        poMask->RasterIO(GF_Read, 0,y+1,width,1,maskDown,width,1,GDT_Byte,0,0);

        for(unsigned x=0;x<width;++x)
        {
            if (maskUp[x]==0) continue;

            unsigned nodeID1 = y*width+x;
            unsigned nodeIDNextLIne = nodeID1+width;
            if (x < width-1 )
            {
                if (maskUp[x+1]!=0)
                {
                    for(unsigned k=0;k<nBandCount;++k)
                    {
                        buffer1[k] = SRCVAL(lineBufferUp[k],gdalDataType,x);
                        buffer2[k] = SRCVAL(lineBufferUp[k],gdalDataType,x+1);

                    }
					for (unsigned k = 0; k<nBandCount; ++k)
					{
						buffer1[k + nBandCount] = SRCVAL(lineBufferUp2[k], gdalDataType, x);
						buffer2[k + nBandCount] = SRCVAL(lineBufferUp2[k], gdalDataType, x + 1);

					}
                    onComputeEdgeWeight(nodeID1,nodeID1+1,buffer1,buffer2,_layerWeights,vecEdge);//左右
                }
            }

            if(y < height-1)
            {
                if (maskDown[x]!=0)
                {
                    for(unsigned k=0;k<nBandCount;++k)
                    {
                        buffer1[k] = SRCVAL(lineBufferUp[k],gdalDataType,x);
                        buffer2[k] = SRCVAL(lineBufferDown[k],gdalDataType,x);
                    }
					for (unsigned k = 0; k<nBandCount; ++k)
					{
						buffer1[k+nBandCount] = SRCVAL(lineBufferUp2[k], gdalDataType, x);
						buffer2[k+nBandCount] = SRCVAL(lineBufferDown2[k], gdalDataType, x);
					}
                    onComputeEdgeWeight(nodeID1,nodeIDNextLIne,buffer1,buffer2,_layerWeights,vecEdge);//上下
                }
            }
        }

        std::vector<unsigned char*> tempBuffer = lineBufferDown;
        lineBufferDown = lineBufferUp;
        lineBufferUp = tempBuffer;

		std::vector<unsigned char*> tempBuffer2 = lineBufferDown2;
		lineBufferDown2 = lineBufferUp2;
		lineBufferUp2 = tempBuffer2;

        unsigned char* p    = maskUp;
        maskUp      = maskDown;
        maskDown    = p;

    }

    for (unsigned k=0;k<nBandCount;++k)
    {
        delete [](lineBufferUp[k]);
        delete [](lineBufferDown[k]);
		delete[](lineBufferUp2[k]);
		delete[](lineBufferDown2[k]);
    }
    delete []maskUp;
    delete []maskDown;
	
    return true;
}

接下来依次遍历各条边,将满足合并准则的相邻区域合并:

/**
 * Creates a fresh GraphKruskal with num_vertices nodes and walks the edge
 * list once, joining the two regions of an edge whenever they are distinct
 * and joinPredicate_sw (criterion 0) accepts the merge.
 *
 * @param graph         Receives the newly allocated graph.
 * @param p             Pointer to the EdgeVector to iterate.
 * @param num_vertices  Number of nodes of the new graph.
 * @param threshold     Merge threshold forwarded to the predicate as float.
 * @return Always true.
 */
bool MSTSegmenter::_ObjectMerge(GraphKruskal *&graph,
        void *p,
        unsigned num_vertices,
        double threshold)
{
    EdgeVector* edges = static_cast<EdgeVector*>(p);
    graph = new GraphKruskal(num_vertices);

    const int predicateMode = 0;
    for (; !edges->empty(); ++(*edges))
    {
        edge current = *(*edges);
        const unsigned rootA = graph->find(current.GetNode1());
        const unsigned rootB = graph->find(current.GetNode2());

        // Both ends already belong to the same region.
        if (rootA == rootB)
            continue;

        if (!graph->joinPredicate_sw(rootA, rootB, (float)threshold, current.GetWeight(), predicateMode))
            continue;

        graph->join_band_sw(rootA, rootB, current.GetWeight());
        graph->find(rootA);
    }

    return true;
}

接下来消除小区域,减少小图斑的影响

/**
 * Rewinds the edge list and walks it once more, joining the two regions of an
 * edge whenever they are distinct and at least one of them is smaller than
 * _minObjectSize.
 *
 * @param graph           Graph whose regions are merged.
 * @param p               Pointer to the EdgeVector to iterate.
 * @param _minObjectSize  Size below which a region is absorbed by a neighbour.
 * @return Always true.
 */
bool MSTSegmenter::_EliminateSmallArea(GraphKruskal * &graph,
                                       void *p,
                                       double _minObjectSize)
{
    EdgeVector* edges = static_cast<EdgeVector*>(p);

    for (edges->rewind(); !edges->empty(); ++(*edges))
    {
        edge current = *(*edges);
        const unsigned rootA = graph->find(current.GetNode1());
        const unsigned rootB = graph->find(current.GetNode2());

        // Both ends already belong to the same region.
        if (rootA == rootB)
            continue;

        // Neither region is below the minimum size: keep them separate.
        if (!(graph->size(rootA) < _minObjectSize) && !(graph->size(rootB) < _minObjectSize))
            continue;

        graph->join_band_sw(rootA, rootB, current.GetWeight());
        graph->find(rootA);
    }

    return true;
}

最后将结果写入flagimage影像中

/**
 * Writes the object-id label image to band 1 of both output datasets.
 *
 * Pixels whose mask value is 0 get label 0. Every other pixel gets the object
 * id that mapRootidObjectid associates with the root of its region. Labels
 * are assembled one row at a time and written with a single RasterIO call per
 * row and per output.
 *
 * @param graph              Graph used to look up each pixel's region root.
 * @param mapRootidObjectid  Maps a region root id to its object id.
 * @return Always true.
 * @throws std::out_of_range (from std::map::at) when the root of an unmasked
 *         pixel is missing from mapRootidObjectid.
 */
bool MSTSegmenter::_GenerateFlagImage(GraphKruskal *&graph,const std::map<unsigned, unsigned> &mapRootidObjectid)
{
    GDALDataset* poSrcDS = (GDALDataset*)srcDS;
    GDALDataset* poDstDS = (GDALDataset*)dstDS;
    GDALDataset* poDstDS2 = (GDALDataset*)dstDS2;

    const unsigned width = poSrcDS->GetRasterXSize();
    const unsigned height= poSrcDS->GetRasterYSize();
    if (width == 0 || height == 0)
        return true;

    GDALRasterBand* poFlagBand = poDstDS->GetRasterBand(1);
    GDALRasterBand* poFlagBand2 = poDstDS2->GetRasterBand(1);
    GDALRasterBand* poMaskBand = poFlagBand->GetMaskBand();

    // Plain in-memory row buffers: RasterIO needs a contiguous block of memory.
    std::vector<unsigned char> mask(width);
    std::vector<int> labels(width);

    unsigned index = 0; // linear pixel index, i*width+j
    for (unsigned i=0;i<height;++i)
    {
        poMaskBand->RasterIO(GF_Read,0,i,width,1,&mask[0],width,1,GDT_Byte,0,0);
        for (unsigned j=0;j<width;++j,++index)
        {
            int objectID = 0;
            if (mask[j]!=0)
                objectID = mapRootidObjectid.at(graph->find(index));
            labels[j] = objectID;
        }
        poFlagBand->RasterIO(GF_Write,0,i,width,1,&labels[0],width,1,GDT_Int32,0,0);
        poFlagBand2->RasterIO(GF_Write,0,i,width,1,&labels[0],width,1,GDT_Int32,0,0);
    }
    return true;
}

附:最小生成树算法所用的并查集(GraphKruskal)实现:

/**
 * Builds a disjoint-set forest of `elements` single-node regions.
 *
 * Each node starts as its own root with rank 0, edge-weight sum 0 and size 1.
 *
 * @param elements Number of nodes.
 */
GraphKruskal::GraphKruskal(unsigned elements)
    :elementCount(elements),num(elements),elts(NULL)
{
    // An earlier, now disabled, variant backed this array with a Windows
    // memory-mapped file (CreateFile / CreateFileMapping / MapViewOfFile).
    // The array is allocated on the heap instead.
    elts = new GraphElement[elements];

    // With a throwing operator new[] this check never fires. It is kept for
    // builds where new[] can return NULL, and now leaves the constructor
    // instead of running the initialisation loop on a NULL array.
    if (elts == NULL)
    {
        cout<<"Map Failed!";
        return;
    }

    for (unsigned i = 0; i < elements; ++i)
    {
        GraphElement& node = elts[i];
        node.rank = 0;
        node.p = i; // each region (set) initially has itself as root
        node.sw = 0;
        node.size = 1;
    }
}

GraphKruskal::~GraphKruskal()
{

	// Frees the element array that the constructor allocated with new[].
	// (The disabled memory-mapped-file variant unmapped the view, closed the
	// mapping and file handles and deleted the backing file at this point.)
	delete[] elts;
}

/**
 * Returns the root of the region that contains node x.
 *
 * After the root is found, every node on the path from x to the root is
 * re-parented directly to the root (path compression), which shortens later
 * lookups. Indices stay unsigned throughout, so large node ids are never
 * reinterpreted as negative array indices.
 *
 * @param x Node id; must be a valid index into elts.
 * @return Id of the root node of x's region.
 */
unsigned GraphKruskal::find( unsigned x )
{
    // First pass: follow parent links until a node that is its own parent.
    unsigned root = x;
    while (root != elts[root].p)
        root = elts[root].p;

    // Second pass: point every node on the walked path straight at the root.
    while (x != root)
    {
        const unsigned next = elts[x].p;
        elts[x].p = root;
        x = next;
    }
    return root;
}

/**
 * Merges the regions rooted at x and y (union by rank).
 *
 * The node with the strictly higher rank becomes the root; on equal ranks y
 * becomes the root and its rank is incremented. The surviving root absorbs the
 * other region's size, and its edge-weight sum becomes the sum of both
 * regions' sums plus edgeWeight. The region counter num is decremented.
 *
 * @param x          Root of the first region.
 * @param y          Root of the second region.
 * @param edgeWeight Weight of the edge that joins the two regions.
 * @return Id of the root of the merged region.
 */
unsigned GraphKruskal::join_band_sw( unsigned x,unsigned y,float edgeWeight )
{
    const bool xBecomesRoot = elts[x].rank > elts[y].rank;
    const unsigned root  = xBecomesRoot ? x : y;
    const unsigned child = xBecomesRoot ? y : x;

    elts[root].size += elts[child].size;            // size of the merged region
    elts[root].sw += elts[child].sw + edgeWeight;   // edge-weight sum of the merged region
    elts[child].p = root;

    // Equal ranks can only occur when y was chosen as root: bump its rank.
    if (!xBecomesRoot && elts[child].rank == elts[root].rank)
        elts[root].rank++;

    num--; // one region fewer after the merge
    return root;
}


//预测两个区域是否合并,控制边权和的大小

const float LOG20MULTI2 = 2*log(2/0.1f);
bool GraphKruskal::joinPredicate_sw(unsigned reg1, unsigned reg2, float th, float edgeWeight, int nPredict )
{
    GraphElement elt1 = elts[reg1];
    GraphElement elt2 = elts[reg2];

    float swreg1 = elt1.sw;
    unsigned size1=elt1.size;//区域1的像素数

    float swreg2 = elt2.sw;
    unsigned size2=elt2.size;//区域2的像素数

    float nedge1 = (float)size1-1+0.000001f;//区域1的边数
    float nedge2 = (float)size2-1+0.000001f;//区域2的边数

    double g=th;//255*sqrt(m_Bands);//

    float if1=(swreg1+edgeWeight)/(nedge1+2);//把当前边权加入区域1的边权后的区域边权均值
    float if2=(swreg2+edgeWeight)/(nedge2+2);//把当前边权加入区域2的边权后的区域边权均值


    float sn1 = (float)(g*sqrt(LOG20MULTI2/size1));
    float sn2 = (float)(g*sqrt(LOG20MULTI2/size2));


    bool bMerge=false;
    if (nPredict==0)//准则1
    {
        if(if1<=sn1 || if2<=sn2)//ok
            bMerge=true;
        else
            bMerge=false;
    }
    if (nPredict==1)//准则2
    {
        if ((edgeWeight<=sn1)||(edgeWeight<=sn2))//ok
            bMerge=true;
        else
            bMerge=false;
    }

    return bMerge;
}

/**
 * Assigns consecutive object ids (starting at 0) to the region roots, in the
 * order in which the regions are first met while scanning the mask band row
 * by row. Pixels whose mask value is 0 are skipped.
 *
 * @param poMaskBand         Mask band; its size defines the scanned raster.
 * @param mapRootidObjectid  Cleared, then filled with root id -> object id.
 * @return Number of object ids that were assigned.
 */
unsigned GraphKruskal::GetMapNodeidObjectid(GDALRasterBand *&poMaskBand, map<unsigned, unsigned> &mapRootidObjectid)
{
    mapRootidObjectid.clear();

    const int cols = poMaskBand->GetXSize();
    const int rows = poMaskBand->GetYSize();
    std::vector<unsigned char> rowMask(cols);

    unsigned nextObjectID = 0;
    unsigned pixelIndex = 0; // linear pixel index, row*cols+col
    for (int row = 0; row < rows; ++row)
    {
        poMaskBand->RasterIO(GF_Read,0,row,cols,1,&rowMask[0],cols,1,GDT_Byte,0,0);
        for (int col = 0; col < cols; ++col, ++pixelIndex)
        {
            if (rowMask[col] != 0)
            {
                const unsigned root = find(pixelIndex);
                // First pixel seen for this region: hand out the next id.
                if (mapRootidObjectid.count(root) == 0)
                {
                    mapRootidObjectid[root] = nextObjectID;
                    ++nextObjectID;
                }
            }
        }
    }
    return nextObjectID;
}
``````````
posted @ 2020-11-09 00:07  4AM_Ruiii  阅读(227)  评论(0)    收藏  举报