Python,Battleship field validator

# Battleship field validator
# https://www.codewars.com/kata/52bb6539a4cf1b12d90005b7/train/python

def validate_battlefield(field):
    # 4*1 + 3*2 + 2*3 + 1*4 = 20
    if sum([sum(row) for row in field]) != 20:
        return False
    # cells_status=[[1]*10]*10
    # [1]*10创建的列表内的元素是独立的,列表中的每个元素都是独立的整数,而不是对同一个对象的10个引用
    # 但是[[1]*10]*10创建的二维列表内的每一个子列表实际上是对同一个列表对象的引用,
    # 也就是说cells_status实际上是一个引用列表,内部保存了10个对同一个列表对象的引用
    # 上面这种创建方式会导致所有子列表引用同一个列表对象,修改其中一个子列表会导致其他子列表也被修改
    # cells_status[0]和cells_status[1]是同一个对象,都是同一个[1]*10,修改cells_status[0]会导致cells_status[1]也被修改
    # 可以使用列表推导式创建独立的子列表
    cells_status=[[True]*10 for _ in range(10)]
    # False -- forbidden,True -- allowed
    ships = [0,4,3,2,1]
    def make_forbidden(i,j):
        cells_status[i][j]=0
        for y in range(i-1,i+2):
            for x in range(j-1,j+2):
                if 0<=y<10 and 0<=x<10:
                    cells_status[y][x]=False

    for i in range(10):
        for j in range(10):
            if(not cells_status[i][j]):
                continue
            if field[i][j]==1:
                tempI = i
                tempJ = j
                make_forbidden(i,j)
                while field[tempI][j]==1:
                    #向下
                    make_forbidden(tempI,j)
                    tempI+=1
                    if tempI>=10:
                        break
                while field[i][tempJ]==1:
                    #向右
                    make_forbidden(i,tempJ)
                    tempJ+=1
                    if tempJ>=10:
                        break
                if max(tempI-i,tempJ-j)>4:
                    return False
                ships[max(tempI-i,tempJ-j)]-=1
    for ship in ships:
        if ship!=0:
            return False
    return True


# best practice
#scipy.ndimage.measurements

# 是 SciPy 库中的一个子模块,专门用于图像处理和分析。
# ndimage 代表 N-dimensional image(N 维图像),该模块提供了许多函数,用于处理和分析多维图像数据。
# label是一个用于标记连通区域的函数,它会返回一个标记数组和标记的数量。
# find_objects是一个用于获取标记数组中每个标记的位置的函数。
# np是numpy的别名
# np.ones((3,3))生成一个3*3的全1矩阵
from scipy.ndimage.measurements import label, find_objects
import numpy as np
def validate_battlefield(field):
    # numpy.array()可表示一维,二维,三维等多维数组
    # numpy.array()可传入列表,元组等可迭代对象,生成一个数组
    field = np.array(field)
    # field是一个10*10的数组,每个元素是0或1,表示战舰的位置

    #np.ones()函数用于生成一个全1数组,传入一个元组表示数组的形状
    labeled_array, num_features = label(field,np.ones((3,3)))
    #labeled_array是标记数组, num_features是标记出的联通区域的数量

    positions = find_objects(labeled_array)
    #positions是标记对象的位置,是一个切片对象的列表,每个切片对象包括行和列的范围,表示一个矩形区域

    result =  []
    for pos in positions:
        # pos 是一个切片对象,包括行和列的范围, 形如(slice(0, 2, None), slice(1, 4, None))
        ship = field[pos]
        # field[pos]是一个子数组,表示一个战舰的位置,是一个矩形区域,形如
        # [[0 1 1]
        #  [1 1 0]]

        if min(ship.shape)==1:
            #shape是一个数组的属性,表示数组的维度,
            #shape属性返回一个元组,其中包含数组每个维度的大小
            #min(ship.shape)==1说明ship的行数或列数为1,即ship占据一行多列或者一列多行或者只有一个元素
            #也就是说ship的形状是1*n或者n*1,符合要求

            # ship.size返回数组中元素的个数,即行数*列数
            # 由于ship的形状是1*n或者n*1,所以该联通区域中只有元素1,没有元素0
            # 即ship.size的值就是1的个数
            result.append(ship.size)
    #sorted()默认是升序排序
    return sorted(result) == [1,1,1,1,2,2,2,3,3,4]

    # original code
    # return sorted(
    #     ship.size if min(ship.shape) == 1 else 0
    #     for ship in (field[pos] for pos in find_objects(label(field, np.ones((3,3)))[0]))
    # ) == [1,1,1,1,2,2,2,3,3,4]


# 关于label() ============================================================================
# label(input,structure)函数用于标记数组中的联通区域,
# 连通区域是指在数组中相邻且具有相同值的元素组成的区域。相邻的定义可以通过 structure 参数来指定。
# 只有非零元素组成的联通区域才会被标记
# 默认情况下,structure 参数是一个与输入数组维度相同且所有元素为 1 的数组,
# 这意味着所有相邻的元素(包括对角线方向)都被视为连通的。
# label()函数会为每一个联通区域分配一个唯一的整数标记,整数标记从1开始, 并返回一个标记数组和标记的数量
# 标记数组的形状和输入数组相同, 其中每个元素的值是该元素所属的联通区域的整数标记
# 创建一个二维数组
import numpy as np
from scipy.ndimage import measurements
data = np.array([[0, 0, 1, 1, 0],
                 [0, 1, 1, 0, 0],
                 [0, 0, 0, 0, 1],
                 [1, 1, 0, 0, 0],
                 [0, 0, 0, 1, 1]])

# 标记连通区域
labeled_array, num_features = measurements.label(data)
print("Labeled Array:\n", labeled_array)
print("Number of Features:", num_features)
# Labeled Array:
#  [[0 0 1 1 0]
#  [0 1 1 0 0]
#  [0 0 0 0 2]
#  [3 3 0 0 0]
#  [0 0 0 4 4]]
# Number of Features: 4
#  标记数组值为1表示该元素属于第一个联通区域,值为2表示该元素属于第二个联通区域,以此类推
#  0表示该元素不属于任何联通区域
#  number_features表示联通区域的数量,这里有4个联通区域

# 关于find_objects() ========================================================================
# 用于查找标记数组中每个标记对象的位置。它返回一个切片对象的列表,每个切片对象表示一个标记对象的位置。
# 输入参数
# input:标记数组,通常是由 scipy.ndimage.measurements.label 函数生成的数组。
# max_label(可选):要查找的最大标签。如果未提供,默认查找所有标签。
# 返回返回一个切片对象的列表,每个切片对象表示一个标记对象的位置。切片对象可以用于索引原始数组,以提取标记对象。

# 使用上面的标记数组labeled_array
# 查找标记对象的切片
slices = measurements.find_objects(labeled_array)
print("Slices:", slices)
# Slices: [(slice(0, 2, None), slice(1, 4, None)), (slice(2, 3, None), slice(4, 5, None)), (slice(3, 4, None), slice(0, 2, None)), (slice(4, 5, None), slice(3, 5, None))]
print(data[slices[0]])
# [[0 1 1]
#  [1 1 0]]

# slice(start, stop, step)表示一个切片对象,表示[start,stop)范围内的元素,step表示步长
# (slice(0, 2, None),slice(1, 4, None))表示第一个联通区域的行索引范围是[0,2),列索引范围是[1,4)




# another solution
def validateBattlefield(field):  
    
    #print('\n'.join([''.join(['{:4}'.format(item) for item in row]) for row in field]))
    
    ships = []
    
    #this algorithm uses the field 2-dimensional array it self to store infomration about the size of ships found      
    for i in range(0, 10):            
        for j in range(0, 10):  
            #if not at end of any edge in 2d-array, check that sum of two cross diagonal elements is not more than max 
            #if it is then two ships are two close
            if j < 9 and i < 9: 
                if field[i][j] + field[i+1][j+1] > max(field[i][j], field[i+1][j+1]): 
                    return False 
                if field[i+1][j] + field[i][j+1] > max(field[i+1][j], field[i][j+1]):
                    return False
            #if the element at position (i, j) is occupied then add the current value of position to next
            if j < 9 and field[i][j] > 0 and field[i][j+1] > 0:
                field[i][j+1] += field[i][j]
            elif i < 9 and field[i][j] > 0 and field[i+1][j] > 0:
                field[i+1][j] += field[i][j]
            elif field[i][j] > 0:
                ships.append(field[i][j]) #since we add numbers
                
    ships.sort()

    return ships == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] #if the ships we have found are of correct configuration then it will equal this array



# 尝试重现别人的代码

# try 1
from scipy.ndimage.measurements import label, find_objects
import numpy as np
def validate_battlefield_s(field):
    field = np.array(field)
    labeled_array = label(field,np.ones((3,3)))[0]
    positions = find_objects(labeled_array)
    result = []
    for pos in positions:
        ship = field[pos]
        if min(ship.shape)==1:
            result.append(ship.size)
    return sorted(result) == [1,1,1,1,2,2,2,3,3,4]

# try 2
def validate_battlefield(field):
    #创建field的副本
    data = field[:]#使用切片操作创建field的副本
    ships = []
    for i in range(10):
        for j in range(10):
            if data[i][j]>0:
                #左下角
                if i+1<9 and j-1>=0 and data[i+1][j-1]>0:
                    return False
                #右下角
                if i+1<9 and j+1<9 and data[i+1][j+1]>0:
                    return False
                if j<9 and data[i][j+1]>0:
                    data[i][j+1]+=data[i][j]
                elif i<9 and data[i+1][j]>0:
                    data[i+1][j]+=data[i][j]
                else:
                    ships.append(data[i][j])
    ships.sort()
    return ships==[1,1,1,1,2,2,2,3,3,4]
                
    
posted @ 2025-03-25 15:15  Kazuma_124  阅读(19)  评论(0)    收藏  举报