"""
__project_ = 'TF2learning'
__file_name__ = 'quantization'
__author__ = 'qilibin'
__time__ = '2021/3/17 9:18'
__product_name = PyCharm
"""
import h5py
import pandas as pd
import numpy as np
'''
Read an existing weights-only H5 model, iterate over its layers, quantize each
weight of every layer to 16 or 8 bits, and save the quantized values into a
new H5 file.
'''
def quantization16bit(old_model_path, new_model_path, bit_num):
    '''
    Quantize every weight array of a weights-only H5 model file.

    Walks the layer groups of ``old_model_path`` and writes a copy to
    ``new_model_path`` in which each weight dataset is stored as float16
    (``bit_num == 16``) or affine-quantized uint8 (``bit_num == 8``).  For the
    8-bit path the per-tensor ``min_val`` / ``max_val`` needed to dequantize
    are stored as attributes on the new dataset.

    :param old_model_path: path to the unquantized model (weights only, no
        network topology)
    :param new_model_path: path the quantized model is written to
    :param bit_num: quantization width; 16 -> float16, 8 -> uint8; any other
        value copies the weights through unchanged
    :return: None
    '''
    # context managers guarantee both files are closed even on error
    with h5py.File(old_model_path, 'r') as f, h5py.File(new_model_path, 'w') as f2:
        for layer in f.keys():
            print(layer)
            g1 = f2.create_group(layer)
            g1.attrs["weight_names"] = layer
            if len(f[layer].keys()) == 0:
                continue  # layer has no weights: keep only the empty group
            # NOTE(review): assumes the Keras convention that the weight
            # subgroup carries the same name as the layer group — confirm.
            g2 = g1.create_group(layer)
            for weight in f[layer][layer].keys():
                print("weight name is :" + weight)
                oldparam = f[layer][layer][weight][:]
                print('-----------------------------------------old-----------------------')
                print(oldparam)
                min_val = max_val = None
                if not isinstance(oldparam, np.ndarray):
                    # non-array value: pass through with its native dtype
                    newparam, dtype = oldparam, None
                elif bit_num == 16:
                    newparam, dtype = np.float16(oldparam), np.float16
                elif bit_num == 8:
                    # affine min/max quantization onto [0, 255]
                    min_val = float(np.min(oldparam))
                    max_val = float(np.max(oldparam))
                    span = max_val - min_val
                    if span == 0:
                        # constant tensor: avoid division by zero
                        scaled = np.zeros_like(oldparam)
                    else:
                        scaled = np.round((oldparam - min_val) / span * 255)
                    newparam, dtype = np.uint8(scaled), np.uint8
                else:
                    # unsupported width: store the weights untouched instead
                    # of silently dropping them (old code wrote nothing here)
                    newparam, dtype = oldparam, None
                print('-----------------------------------------new-----------------------')
                print(newparam)
                d = g2.create_dataset(weight, data=newparam, dtype=dtype)
                if min_val is not None:
                    # keep the affine range so the tensor can be dequantized
                    d.attrs["min_val"] = min_val
                    d.attrs["max_val"] = max_val
if __name__ == '__main__':
    # Run the quantization only when executed as a script, not on import.
    old_model_path = './yolox_s.h5'
    new_model_path = './yolox_sq.h5'
    quantization16bit(old_model_path, new_model_path, 8)