《航天宏图杯》比赛中基于最小生成树进行影像分割的算法详解
下面介绍利用最小生成树进行影像分割的算法。
Kruskal算法
此算法可以称为“加边法”,初始最小生成树边数为0,每迭代一次就选择一条满足条件的最小代价边,加入到最小生成树的边集合里。
- 把图中的所有边按代价从小到大排序;
- 把图中的n个顶点看成独立的n棵树组成的森林;
- 按权值从小到大选择边:若所选的边连接的两个顶点 $u_i, v_i$ 属于两棵不同的树,则该边成为最小生成树的一条边,并将这两棵树合并为一棵树;
- 重复上一步,直到所有顶点都在同一棵树内,或者已选出 n-1 条边为止。

首先,我们检查分割参数、打开输入影像,并初始化输出影像及其掩膜:
bool MSTSegmenter::_CheckAndInit()
{
//1.Check the parameters
if (_threshold<=0)
{
cout<<"Segmentation parameter \"threshold\" must greater than 0";
return false;
}
if (_minObjectSize <= 0)
{
cout<<"Segmentation parameter \"minObjectSize\" must greater than 0";
return false;
}
//2.Check Input Image
GDALDataset* poSrcDS = ( GDALDataset*)GDALOpen(_inputImagePath, GA_ReadOnly);
GDALDataset* poSrcDS2 = (GDALDataset*)GDALOpen(_inputImagePath2, GA_ReadOnly);
if (poSrcDS == NULL)
{
cout<<"Open Image Failed!";
return false;
}
srcDS = poSrcDS;
srcDS2 = poSrcDS2;
if (_layerWeights.size()==0)
{
_layerWeights.resize(poSrcDS->GetRasterCount());
std::for_each(_layerWeights.begin(),_layerWeights.end(),[&](double &data)
{
data = 1;// . / poSrcDS->GetRasterCount();
});
}
//3.Init Output Image
GDALDriver* poDriver = (GDALDriver*)GDALGetDriverByName("GTiff");
GDALDataset* poDstDS = poDriver->Create(_outputImagePath,poSrcDS->GetRasterXSize(),poSrcDS->GetRasterYSize(),1,GDT_Int32,NULL);
GDALDataset* poDstDS2 = poDriver->Create(_outputImagePath2, poSrcDS->GetRasterXSize(), poSrcDS->GetRasterYSize(), 1, GDT_Int32, NULL);
if ((poDstDS == NULL)|| (poDstDS2 == NULL))
{
cout<<("Create Image ")+_outputImagePath+(" Failed!");
return false;
}
dstDS = poDstDS;
dstDS2 = poDstDS2;
poDstDS->CreateMaskBand(GMF_PER_DATASET);//创建掩膜层
GDALRasterBand* poBand = poDstDS->GetRasterBand(1)->GetMaskBand();
poDstDS2->CreateMaskBand(GMF_PER_DATASET);//创建掩膜层
GDALRasterBand* poBand2 = poDstDS2->GetRasterBand(1)->GetMaskBand();
for (int i=0;i<poDstDS->GetRasterYSize();++i)
{
std::vector<unsigned char> mask(poDstDS->GetRasterXSize(),1);
if (poSrcDS->GetRasterBand(1)->GetRasterDataType() == GDT_Byte && shield_0_255)
{
unsigned char* image = new unsigned char[mask.size()];
for (int k=0;k<poSrcDS->GetRasterCount();++k)
{
poSrcDS->GetRasterBand(k+1)->RasterIO(GF_Read,0,i,mask.size(),1,image,mask.size(),1,GDT_Byte,0,0);
for (int j=0;j<mask.size();++j)
{
if (image[j]==0 || image[j]==255)
mask[j]=0;
}
}
delete []image;
}
poBand->RasterIO(GF_Write, 0,i,mask.size(),1,&mask[0],mask.size(),1,GDT_Byte,0,0);
poBand2->RasterIO(GF_Write, 0, i, mask.size(), 1, &mask[0], mask.size(), 1, GDT_Byte, 0, 0);
}
double adfGeoTransform[6];
poDstDS->SetProjection(poSrcDS->GetProjectionRef());
poDstDS2->SetProjection(poSrcDS->GetProjectionRef());
poSrcDS->GetGeoTransform(adfGeoTransform);
poDstDS->SetGeoTransform(adfGeoTransform);
poDstDS2->SetGeoTransform(adfGeoTransform);
return true;
}
接下来计算相邻像素(左右、上下)之间每一条边的权值:
bool MSTSegmenter::_CreateEdgeWeights(void *p)
{
EdgeVector* vecEdge = (EdgeVector*)p;
vecEdge->clear();
GDALDataset* poSrcDS = (GDALDataset*)srcDS;
GDALDataset* poDstDS = (GDALDataset*)dstDS;
GDALDataset* poSrcDS2 = (GDALDataset*)srcDS2;
GDALDataset* poDstDS2 = (GDALDataset*)dstDS2;
unsigned nBandCount = poSrcDS->GetRasterCount();
unsigned width = poSrcDS->GetRasterXSize();
unsigned height= poSrcDS->GetRasterYSize();
GDALDataType gdalDataType = poSrcDS->GetRasterBand(1)->GetRasterDataType();
unsigned pixelSize = GDALGetDataTypeSize(gdalDataType)/8;
std::vector<double> buffer1(nBandCount*2);
std::vector<double> buffer2(nBandCount*2);
std::vector<unsigned char*> lineBufferUp(nBandCount);
std::vector<unsigned char*> lineBufferDown(nBandCount);
std::vector<unsigned char*> lineBufferUp2(nBandCount);
std::vector<unsigned char*> lineBufferDown2(nBandCount);
GDALRasterBand* poMask = poDstDS->GetRasterBand(1)->GetMaskBand();
unsigned char* maskUp = new unsigned char[width];
unsigned char* maskDown = new unsigned char[width];
for (unsigned k=0;k<nBandCount;++k)
{
lineBufferUp[k] = new unsigned char[width*pixelSize];
lineBufferDown[k] = new unsigned char[width*pixelSize];
lineBufferUp2[k] = new unsigned char[width*pixelSize];
lineBufferDown2[k] = new unsigned char[width*pixelSize];
}
for (unsigned k=0;k<nBandCount;++k)
{
poSrcDS->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, 0, width, 1, lineBufferUp[k], width, 1, gdalDataType, 0, 0);
poSrcDS2->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, 0, width, 1, lineBufferUp2[k], width, 1, gdalDataType, 0, 0);
}
poMask->RasterIO(GF_Read, 0,0,width,1,maskUp,width,1,GDT_Byte,0,0);
for (unsigned y = 0; y < height - 1; ++y)
{
for (unsigned k = 0; k < nBandCount; ++k){
poSrcDS->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, y + 1, width, 1, lineBufferDown[k], width, 1, gdalDataType, 0, 0);
poSrcDS2->GetRasterBand(k + 1)->RasterIO(GF_Read, 0, y + 1, width, 1, lineBufferDown2[k], width, 1, gdalDataType, 0, 0);
}
poMask->RasterIO(GF_Read, 0,y+1,width,1,maskDown,width,1,GDT_Byte,0,0);
for(unsigned x=0;x<width;++x)
{
if (maskUp[x]==0) continue;
unsigned nodeID1 = y*width+x;
unsigned nodeIDNextLIne = nodeID1+width;
if (x < width-1 )
{
if (maskUp[x+1]!=0)
{
for(unsigned k=0;k<nBandCount;++k)
{
buffer1[k] = SRCVAL(lineBufferUp[k],gdalDataType,x);
buffer2[k] = SRCVAL(lineBufferUp[k],gdalDataType,x+1);
}
for (unsigned k = 0; k<nBandCount; ++k)
{
buffer1[k + nBandCount] = SRCVAL(lineBufferUp2[k], gdalDataType, x);
buffer2[k + nBandCount] = SRCVAL(lineBufferUp2[k], gdalDataType, x + 1);
}
onComputeEdgeWeight(nodeID1,nodeID1+1,buffer1,buffer2,_layerWeights,vecEdge);//左右
}
}
if(y < height-1)
{
if (maskDown[x]!=0)
{
for(unsigned k=0;k<nBandCount;++k)
{
buffer1[k] = SRCVAL(lineBufferUp[k],gdalDataType,x);
buffer2[k] = SRCVAL(lineBufferDown[k],gdalDataType,x);
}
for (unsigned k = 0; k<nBandCount; ++k)
{
buffer1[k+nBandCount] = SRCVAL(lineBufferUp2[k], gdalDataType, x);
buffer2[k+nBandCount] = SRCVAL(lineBufferDown2[k], gdalDataType, x);
}
onComputeEdgeWeight(nodeID1,nodeIDNextLIne,buffer1,buffer2,_layerWeights,vecEdge);//上下
}
}
}
std::vector<unsigned char*> tempBuffer = lineBufferDown;
lineBufferDown = lineBufferUp;
lineBufferUp = tempBuffer;
std::vector<unsigned char*> tempBuffer2 = lineBufferDown2;
lineBufferDown2 = lineBufferUp2;
lineBufferUp2 = tempBuffer2;
unsigned char* p = maskUp;
maskUp = maskDown;
maskDown = p;
}
for (unsigned k=0;k<nBandCount;++k)
{
delete [](lineBufferUp[k]);
delete [](lineBufferDown[k]);
delete[](lineBufferUp2[k]);
delete[](lineBufferDown2[k]);
}
delete []maskUp;
delete []maskDown;
return true;
}
接下来依次遍历各条边,对其两端所在的区域进行合并:
/**
 * Region-merging pass over the edge list.
 *
 * Allocates a fresh GraphKruskal with num_vertices single-pixel regions and
 * stores it in `graph` (the previous pointer value is overwritten, not
 * deleted). The edge list is then walked from its current position to its
 * end. Whenever the two end points of an edge lie in different regions and
 * joinPredicate_sw() accepts the merge under criterion 0, the regions are
 * joined.
 *
 * @param graph         Receives the newly allocated union-find structure.
 * @param p             Pointer to the EdgeVector to traverse.
 * @param num_vertices  Number of graph nodes.
 * @param threshold     Merge threshold, passed to the predicate as float.
 * @return Always true.
 */
bool MSTSegmenter::_ObjectMerge(GraphKruskal *&graph,
                                void *p,
                                unsigned num_vertices,
                                double threshold)
{
    EdgeVector* edges = (EdgeVector*)p;
    graph = new GraphKruskal(num_vertices);

    for (; !edges->empty(); ++(*edges))
    {
        edge current = *(*edges);
        const unsigned rootA = graph->find(current.GetNode1());
        const unsigned rootB = graph->find(current.GetNode2());
        if (rootA == rootB)
            continue;   // both ends already belong to the same region

        const int criterion = 0;
        if (!graph->joinPredicate_sw(rootA, rootB, (float)threshold, current.GetWeight(), criterion))
            continue;

        graph->join_band_sw(rootA, rootB, current.GetWeight());
        graph->find(rootA);
    }
    return true;
}
接下来消除小区域,减少小图斑的影响
/**
 * Removes small regions by merging them into a neighbour.
 *
 * The edge list is rewound and walked once more. For every edge whose end
 * points lie in different regions, the two regions are joined if at least one
 * of them is smaller than _minObjectSize.
 *
 * @param graph           Union-find structure holding the current regions.
 * @param p               Pointer to the EdgeVector to traverse.
 * @param _minObjectSize  Minimum region size; this parameter hides the member
 *                        of the same name inside the function.
 * @return Always true.
 */
bool MSTSegmenter::_EliminateSmallArea(GraphKruskal * &graph,
                                       void *p,
                                       double _minObjectSize)
{
    EdgeVector* edges = (EdgeVector*)p;

    for (edges->rewind(); !edges->empty(); ++(*edges))
    {
        edge current = *(*edges);
        const unsigned rootA = graph->find(current.GetNode1());
        const unsigned rootB = graph->find(current.GetNode2());
        if (rootA == rootB)
            continue;   // same region, nothing to merge

        if (!((graph->size(rootA) < _minObjectSize) || (graph->size(rootB) < _minObjectSize)))
            continue;   // both regions are large enough

        graph->join_band_sw(rootA, rootB, current.GetWeight());
        graph->find(rootA);
    }
    return true;
}
最后将结果写入flagimage影像中
/**
 * Writes the object id of every pixel into band 1 of both output images.
 *
 * Pixels whose mask value is 0 get object id 0. For every other pixel the
 * region root is looked up in `graph` and translated to an object id through
 * mapRootidObjectid. Pixel (i, j) has node index i * width + j.
 *
 * Labels are assembled one image row at a time in ordinary contiguous buffers
 * and written with one RasterIO call per row and band.
 *
 * @param graph              Union-find structure holding the final regions.
 * @param mapRootidObjectid  Maps region root index to object id.
 * @return true on success; false if an input is missing or a raster
 *         read/write fails.
 * @throws std::out_of_range (from std::map::at) if an unmasked pixel's root
 *         is not a key of mapRootidObjectid.
 */
bool MSTSegmenter::_GenerateFlagImage(GraphKruskal *&graph,const std::map<unsigned, unsigned> &mapRootidObjectid)
{
    GDALDataset* poSrcDS = (GDALDataset*)srcDS;
    GDALDataset* poDstDS = (GDALDataset*)dstDS;
    GDALDataset* poDstDS2 = (GDALDataset*)dstDS2;
    if (graph == NULL || poSrcDS == NULL || poDstDS == NULL || poDstDS2 == NULL)
        return false;

    const unsigned width = poSrcDS->GetRasterXSize();
    const unsigned height = poSrcDS->GetRasterYSize();
    if (width == 0 || height == 0)
        return true;    // nothing to write

    GDALRasterBand* poFlagBand = poDstDS->GetRasterBand(1);
    GDALRasterBand* poFlagBand2 = poDstDS2->GetRasterBand(1);
    GDALRasterBand* poMaskBand = poFlagBand->GetMaskBand();

    std::vector<unsigned char> mask(width);
    std::vector<int> labels(width);     // written as GDT_Int32
    unsigned index = 0;                 // running node index, row-major

    for (unsigned i = 0; i < height; ++i)
    {
        if (poMaskBand->RasterIO(GF_Read, 0, i, width, 1,
                                 mask.data(), width, 1, GDT_Byte, 0, 0) != CE_None)
            return false;

        for (unsigned j = 0; j < width; ++j, ++index)
        {
            int objectID = 0;
            if (mask[j] != 0)
            {
                const unsigned root = graph->find(index);
                objectID = mapRootidObjectid.at(root);
            }
            labels[j] = objectID;
        }

        if (poFlagBand->RasterIO(GF_Write, 0, i, width, 1,
                                 labels.data(), width, 1, GDT_Int32, 0, 0) != CE_None)
            return false;
        if (poFlagBand2->RasterIO(GF_Write, 0, i, width, 1,
                                  labels.data(), width, 1, GDT_Int32, 0, 0) != CE_None)
            return false;
    }
    return true;
}
附:最小生成树算法:
/**
 * Creates a union-find forest of `elements` single-node regions.
 *
 * Every element starts as its own root (p == its index) with rank 0, edge
 * weight sum 0 and size 1. The element array lives on the heap; an earlier
 * revision backed it with a Win32 memory-mapped file instead.
 *
 * @param elements  Number of graph nodes.
 */
GraphKruskal::GraphKruskal(unsigned elements)
:elementCount(elements),num(elements),elts(NULL)
{
    elts = new GraphElement[elements];
    if (elts == NULL)
    {
        // Only reachable when operator new reports failure by returning NULL.
        // Stop here instead of initialising through a null pointer.
        cout<<"Map Failed!";
        return;
    }

    for (unsigned i = 0; i < elements; ++i)
    {
        GraphElement& element = elts[i];
        element.rank = 0;
        element.p = i;      // each region's initial root is the node itself
        element.sw = 0;
        element.size = 1;
    }
}
/**
 * Releases the element array allocated by the constructor.
 */
GraphKruskal::~GraphKruskal()
{
    // delete[] on a null pointer is a no-op, so no guard is needed.
    // NOTE(review): the class owns a raw array; confirm that copying is
    // disabled in the class declaration, which is not visible here.
    delete[] elts;
}
/**
 * Returns the root (representative) of the region containing node x.
 *
 * The index is kept unsigned throughout, so node indices above INT_MAX are
 * followed correctly. After the root is found, every node on the path from x
 * to the root is re-parented directly to the root. This only shortens later
 * look-ups; which node is the root of a region never changes here.
 *
 * @param x  Node index; must be smaller than the element count.
 * @return Index of the root of x's region.
 */
unsigned GraphKruskal::find( unsigned x )
{
    // Pass 1: climb until a node is its own parent, i.e. the root.
    unsigned root = x;
    while (root != elts[root].p)
        root = elts[root].p;

    // Pass 2: point every node on the path straight at the root.
    while (x != root)
    {
        const unsigned parent = elts[x].p;
        elts[x].p = root;
        x = parent;
    }
    return root;
}
/**
 * Merges the regions rooted at x and y (union by rank).
 *
 * The root with the lower rank is attached below the other one; on equal
 * ranks x is attached below y and y's rank is incremented. The surviving root
 * accumulates the pixel count of both regions, and its edge weight sum
 * becomes the sum of both regions plus edgeWeight. The region counter `num`
 * is decremented.
 *
 * Joining a root with itself is a no-op, so the region's size, weight sum and
 * the region counter are left untouched.
 *
 * @param x           Root of the first region.
 * @param y           Root of the second region.
 * @param edgeWeight  Weight of the edge that connects the two regions.
 * @return Root of the merged region.
 */
unsigned GraphKruskal::join_band_sw( unsigned x,unsigned y,float edgeWeight )
{
    if (x == y)
        return x;

    const bool xSurvives = elts[x].rank > elts[y].rank;
    const unsigned root = xSurvives ? x : y;
    const unsigned child = xSurvives ? y : x;

    // Equal ranks: the tree rooted at y grows one level deeper.
    if (!xSurvives && elts[x].rank == elts[y].rank)
        elts[y].rank++;

    elts[root].size += elts[child].size;                // size of the merged region
    elts[root].sw += elts[child].sw + edgeWeight;       // edge weight sum of the merged region
    elts[child].p = root;
    num--;                                              // one region fewer after the merge
    return root;
}
//预测两个区域是否合并,控制边权和的大小
const float LOG20MULTI2 = 2*log(2/0.1f);
bool GraphKruskal::joinPredicate_sw(unsigned reg1, unsigned reg2, float th, float edgeWeight, int nPredict )
{
GraphElement elt1 = elts[reg1];
GraphElement elt2 = elts[reg2];
float swreg1 = elt1.sw;
unsigned size1=elt1.size;//区域1的像素数
float swreg2 = elt2.sw;
unsigned size2=elt2.size;//区域2的像素数
float nedge1 = (float)size1-1+0.000001f;//区域1的边数
float nedge2 = (float)size2-1+0.000001f;//区域2的边数
double g=th;//255*sqrt(m_Bands);//
float if1=(swreg1+edgeWeight)/(nedge1+2);//把当前边权加入区域1的边权后的区域边权均值
float if2=(swreg2+edgeWeight)/(nedge2+2);//把当前边权加入区域2的边权后的区域边权均值
float sn1 = (float)(g*sqrt(LOG20MULTI2/size1));
float sn2 = (float)(g*sqrt(LOG20MULTI2/size2));
bool bMerge=false;
if (nPredict==0)//准则1
{
if(if1<=sn1 || if2<=sn2)//ok
bMerge=true;
else
bMerge=false;
}
if (nPredict==1)//准则2
{
if ((edgeWeight<=sn1)||(edgeWeight<=sn2))//ok
bMerge=true;
else
bMerge=false;
}
return bMerge;
}
/**
 * Assigns consecutive object ids to the regions that contain unmasked pixels.
 *
 * The mask band is scanned row by row. The first time a region root is met
 * for a pixel with a non-zero mask value, it receives the next free id,
 * starting at 0. Pixel (row, col) has node index row * width + col.
 *
 * @param poMaskBand         Mask band; pixels with value 0 are skipped.
 * @param mapRootidObjectid  Cleared, then filled with root index -> object id.
 * @return Number of object ids assigned.
 */
unsigned GraphKruskal::GetMapNodeidObjectid(GDALRasterBand *&poMaskBand, map<unsigned, unsigned> &mapRootidObjectid)
{
    mapRootidObjectid.clear();

    const int cols = poMaskBand->GetXSize();
    const int rows = poMaskBand->GetYSize();
    std::vector<unsigned char> rowMask(cols);

    unsigned nextObjectID = 0;
    unsigned rowStart = 0;      // node index of the first pixel of the current row

    for (int row = 0; row < rows; ++row, rowStart += cols)
    {
        poMaskBand->RasterIO(GF_Read,0,row,cols,1,&rowMask[0],cols,1,GDT_Byte,0,0);
        for (int col = 0; col < cols; ++col)
        {
            if (rowMask[col] != 0)
            {
                const unsigned root = find(rowStart + col);
                // insert() leaves an existing entry alone and reports whether
                // the root was new.
                if (mapRootidObjectid.insert(map<unsigned, unsigned>::value_type(root, nextObjectID)).second)
                    ++nextObjectID;
            }
        }
    }
    return nextObjectID;
}
``````````

浙公网安备 33010602011771号