# Image2Graph: build a graph from an image using each pixel's 4 or 8 neighbors
import argparse
import logging
import time
from random import random
from PIL import Image, ImageFilter
from skimage import io
import numpy as np
def diff(img, x1, y1, x2, y2):
    """Return the Euclidean distance between two pixels (the edge weight).

    Args:
        img: NumPy array indexed as ``img[x, y]``; each entry is either a
            scalar intensity or a vector of channel values.
        x1, y1: Index of the first pixel.
        x2, y2: Index of the second pixel.

    Returns:
        The square root of the summed squared per-channel differences, as a
        NumPy float.
    """
    # Convert to float before subtracting: unsigned integer dtypes such as
    # uint8 would otherwise wrap around on negative differences and overflow
    # when squared, silently producing wrong weights.
    first = np.asarray(img[x1, y1], dtype=float)
    second = np.asarray(img[x2, y2], dtype=float)
    return np.sqrt(np.sum((first - second) ** 2))
def create_edge(img, width, x, y, x1, y1, diff):
    """Build one weighted edge joining pixel (x, y) to pixel (x1, y1).

    Args:
        img: Image data, passed through untouched to ``diff``.
        width: Stride used to flatten a pixel coordinate into a vertex id.
        x, y: Coordinates of the first endpoint.
        x1, y1: Coordinates of the second endpoint.
        diff: Callable ``diff(img, x, y, x1, y1)`` returning the edge weight.

    Returns:
        A tuple ``(first_vertex_id, second_vertex_id, weight)`` where each
        vertex id is ``y * width + x`` of the corresponding pixel.
    """
    weight = diff(img, x, y, x1, y1)
    source_id = y * width + x
    target_id = y1 * width + x1
    return (source_id, target_id, weight)
def build_graph(img, width, height, diff, neighborhood_8=False):
    """Collect the weighted edges between neighbouring pixels of an image.

    Every (x, y) with ``x`` in ``range(width)`` and ``y`` in
    ``range(height)`` is visited, ``y`` in the outer loop. Each pixel is
    linked to its (x-1, y) and (x, y-1) neighbours when they exist, so
    every 4-connected pair is emitted exactly once. With ``neighborhood_8``
    set, the two diagonals on the x-1 side, (x-1, y-1) and (x-1, y+1), are
    linked as well.

    Args:
        img: Image data handed to ``diff`` through ``create_edge``.
        width: Number of x positions; also the stride for vertex ids.
        height: Number of y positions.
        diff: Weight function forwarded to ``create_edge``.
        neighborhood_8: When truthy, also add the diagonal edges.

    Returns:
        A list of the edge tuples returned by ``create_edge``, in visiting
        order.
    """
    edges = []
    last_y = height - 1
    for y in range(height):
        for x in range(width):
            has_previous_x = x > 0
            has_previous_y = y > 0

            # Gather this pixel's neighbours, keeping the emission order:
            # (x-1, y), (x, y-1), then the diagonals.
            neighbours = []
            if has_previous_x:
                neighbours.append((x - 1, y))
            if has_previous_y:
                neighbours.append((x, y - 1))
            if neighborhood_8 and has_previous_x:
                if has_previous_y:
                    neighbours.append((x - 1, y - 1))
                if y < last_y:
                    neighbours.append((x - 1, y + 1))

            edges.extend(
                create_edge(img, width, x, y, nx, ny, diff)
                for nx, ny in neighbours
            )
    return edges
def GetGaussianBlurImage2Graph(sigma, neighbor, input_file):
    """Blur an image with a Gaussian filter and build its pixel graph.

    Args:
        sigma: Value handed to ``ImageFilter.GaussianBlur``.
        neighbor: Neighborhood size, 4 or 8. Any other value is reported
            with a warning and then treated like 4 (only ``neighbor == 8``
            enables the diagonal edges).
        input_file: Image location passed to ``Image.open``.

    Returns:
        The list of edges produced by ``build_graph`` for the blurred image.
    """
    # Resolve the logger locally so the function also works when this module
    # is imported (the script entry point is not executed in that case).
    log = logging.getLogger(__name__)

    if neighbor not in (4, 8):
        log.warning('Invalid neighborhood choosed. The acceptable values are 4 or 8.')

    start_time = time.perf_counter()

    # The context manager guarantees the underlying file is closed, even if
    # filtering fails.
    with Image.open(input_file) as image_file:
        size = image_file.size  # (width, height) in Pillow/PIL
        log.info('Image info: {} | {} | {}'.format(image_file.format, size, image_file.mode))

        # Gaussian Filter
        log.info("GaussianBlur...")
        smooth = image_file.filter(ImageFilter.GaussianBlur(sigma))

    # Array layout is rows first (height x width x 3 for an RGB image).
    # Casting to int keeps pixel differences from wrapping around in uint8.
    smooth = np.array(smooth).astype(int)

    log.info("Creating graph...")
    # build_graph's x coordinate is used as the FIRST array index
    # (``img[x, y]``), so its ``width`` argument must be the number of rows
    # (the image height) and its ``height`` argument the number of columns.
    graph_edges = build_graph(smooth, size[1], size[0], diff, neighbor == 8)

    log.info("Numbers of graph edges: {}".format(len(graph_edges)))
    log.info('Total running time: {:0.4}s'.format(time.perf_counter() - start_time))
    return graph_edges
if __name__ == '__main__':
    # Basic logging settings; ``logger`` is kept as a module-level name.
    logging.basicConfig(
        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
        datefmt='%m-%d %H:%M',
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    # Command-line options: blur parameter, neighborhood size, image path.
    parser = argparse.ArgumentParser(description='Img2Graph(Graph-based Segmentation)')
    parser.add_argument(
        '--sigma',
        type=float,
        default=0.5,
        help='a float for the Gaussin Filter',
    )
    parser.add_argument(
        '--neighbor',
        type=int,
        default=8,
        choices=[4, 8],
        help='choose the neighborhood format, 4 or 8',
    )
    parser.add_argument(
        '--input-file',
        type=str,
        default="./datas/BigTree.jpg",
        help='the file path of the input image',
    )
    args = parser.parse_args()

    # Echo the effective settings, then run the conversion.
    print(
        "input:",
        'sigma=', args.sigma,
        'neighbor=', args.neighbor,
        'input-file=', args.input_file,
    )
    GetGaussianBlurImage2Graph(
        sigma=args.sigma,
        neighbor=args.neighbor,
        input_file=args.input_file,
    )
个人学习记录

浙公网安备 33010602011771号