7.11学习记录 - 字典调用 数组切片 双三次插值降采样等

  刚刚入门python,在做超分辨有关的项目,解决了以下问题:

    1、从字典中读取key,把其中一个或几个key的内容摘出来;

    2、对多维数组进行切片;

    3、进行h5和mat文件的读取与保存;

    4、对数组进行双三次插值降采样。

  所选软件:PyCharm,图标如图所示:

  新建一个Python文件,根据今天的目标内容需要调用以下几个库:

import scipy.io as sio # 读取mat文件时需要
import h5py # 读取保存h5文件时需要
import numpy as np # 双三次插值降采样时建立新数组需要
from scipy import misc # 双三次插值降采样的imresize由此得到
  • 读取mat文件并得到保存所需数据的数组
    • 首先我们的目的是从数据库文件(即mat文件)中得到用于实验的数据,但直接将mat文件赋值给数组进行运算会报错(因为mat文件包含内容不只是一个数组,后面会解释),因此要进行debug:
    • 点击窗口左侧图标里最右边的“Evaluate Expression”,弹出窗口,输入“data”发现mat文件是一个字典而非数组,由四个key(键值)和它们相关联的元素组成。

    • 由上图可以看出‘salinas_corrected’键值对应内容为所需数组,将此键值关联的数组提取出来进行下一步的实验:

      salina = data['salinas_corrected']  # 从字典data中提取特定key赋给salina数组
    • 由此得到salina数组,且上图可得shape = (512, 217, 204),是三维数组。

 

  • 数组切片
    • 此时得到了整个数据集,因为实验需要划分训练集测试集,所以要对整个数组进行切片。
    • 在这里将数组第三维度(c即channels)保持不变,前两个维度(h, w即长宽)划分为45×45的小块,
    • 将数组以步长为10分割成(45, 45, 204)的小片(为了代码复用,数字都使用参数形式):
      stride = 10 # 步长为10
      piece = 45 # 分成45×45小块
      data_split_get = []
      for i in range(0, salina.shape[0]-piece, stride): # 为了代码复用推荐使用参数形式,salina.shape[0]意为salina数组第一个维度的长度,下同
          for j in range(0, salina.shape[1]-piece, stride):
              data_split = salina[i:i+piece, j:j+piece, :] # 第三个维度不变,前两个维度切片
              data_split_get.append(data_split)# 将切片所得小块分块连接存储在data_split_get中

       

  • 双三次插值降采样
    • 经过数组分片,得到(846, 45, 45, 204)即(n, h, w, c)的四维数组,为了获取超分辨需要的groundtruth,进行超分辨训练,需要对数组进行降采样处理,得到(846, 15, 15, 204)的数组(长宽变为原来的1/scale,在这里scale取3,也可以取其他数值)。
      n, h, w, c = dataset.shape # dataset是切片后数组,n h w c为它四个维度的长度
      scale = 3.0 # 定义scale,可根据实验效果改变数值
      h2, w2 = int(h/scale), int(w/scale) # 对降采样后数组进行定义时要注意二三维度发生了改变,且必须是int形式,不能为float形式
      data_get = np.zeros([n, h2, w2, c])
      for i in range(0, n):
          for j in range(0, c):
              data_get[i, :, :, j] = misc.imresize(dataset[i, :, :, j], 1.0 / scale, 'bicubic') # 将二三维变为原来的1/scale,进行降采样

       

  • 读取、保存h5文件
    • 保存为h5文件:
      file = h5py.File("data_split.h5", 'w')
      file.create_dataset('split', data=data_split_get) # h5文件也是字典,split为数组所对应的键值
      file.close()
    • 读取h5文件赋给data_new:
      data_new = h5py.File('./data_split.h5') # 括号内为文件地址
      dataset = data_new['split'].value # 读取文件特定键值对应数组值,得到所需数组

       

  • 全部代码如下:
     1 import numpy as np
     2 from scipy import misc
     3 import scipy.io as sio
     4 import h5py
     5 
     6 # get 45*45 piece
     7 data = sio.loadmat('./Salinas_corrected.mat')
     8 salina = data['salinas_corrected']
     9 stride = 10
    10 piece = 4511 data_split_get = []12 for i in range(0, salina.shape[0]-piece, stride):
    13     for j in range(0, salina.shape[1]-piece, stride):
    14         data_split = salina[i:i+piece, j:j+piece, :]
    15         data_split_get.append(data_split)
    16 # save file
    17 file = h5py.File("data_split.h5", 'w')
    18 file.create_dataset('split', data=data_split_get)
    19 file.close()
    20 
    21 # get 15*15 Bicubic interpolation downsample
    22 data_new = h5py.File('./data_split.h5')
    23 dataset = data_new['split'].value
    24 n, h, w, c = dataset.shape
    25 scale = 3.0
    26 h2, w2 = int(h/scale), int(w/scale)
    27 data_get = np.zeros([n, h2, w2, c])
    28 for i in range(0, n):
    29     for j in range(0, c):
    30         data_get[i, :, :, j] = misc.imresize(dataset[i, :, :, j], 1.0 / scale, 'bicubic')
    31 # save file
    32 file = h5py.File("data_split_get.h5", 'w')
    33 file.create_dataset('get', data=data_get)
    34 file.close()
    35 print("hello")
    • see you~

Coconut.

2017-07-11

posted @ 2017-07-11 22:19  Coconut-x  阅读(1367)  评论(1)    收藏  举报