I implemented ViT in pure NumPy: a hand-written Vision Transformer that runs on a Raspberry Pi. The model is fine-tuned with Hugging Face, its parameters are saved in NumPy format, and inference uses nothing but NumPy.

First, a bit of background copied from a Zhihu article:

Following the flow chart above, the ViT forward pass can be broken down into the following steps:

(1) Patch embedding: with a 224x224 input image and a fixed 16x16 patch size, each image yields 224x224/(16x16) = 196 patches, i.e. an input sequence of length 196. Each patch has dimension 16x16x3 = 768, and the linear projection layer has shape 768xN (here N = 768), so after the projection the tensor is still 196x768: 196 tokens, each 768-dimensional. A special [CLS] token is prepended, giving a final shape of 197x768. At this point, patch embedding has turned a vision problem into a sequence-to-sequence problem.

(2) Positional encoding (standard learnable 1D position embeddings): ViT also adds position embeddings. The position embedding can be thought of as a table with N rows, where N equals the input sequence length; each row is a vector with the same dimension as the token embeddings (768). Note that the position embedding is summed with the token embeddings, not concatenated, so the shape remains 197x768.

(3) LN / multi-head attention / LN: LayerNorm keeps the shape at 197x768. In multi-head self-attention, the input is first projected to q, k, v. With a single head, q, k and v each have shape 197x768; with 12 heads (768/12 = 64), each head's q, k and v have shape 197x64, and there are 12 such groups. The outputs of the 12 heads are concatenated back to 197x768, and another LayerNorm follows, so the shape is still 197x768.

(4) MLP: the dimension is expanded and then projected back down: 197x768 is expanded to 197x3072 and then reduced back to 197x768.

After one block the shape is the same as the input, 197x768, so multiple blocks can be stacked. Finally, the output Z0 corresponding to the [CLS] token is taken as the encoder's output, i.e. the final image representation (an alternative is to drop the [CLS] token and average over all token outputs), as in formula (4) of the original paper; an MLP head is then attached for image classification. The sketch below walks through these shapes.
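
To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the embedding step and the two sub-layers of a block, using random arrays. Everything here is illustrative only (random weights, no LayerNorm, no softmax); the real implementation with the trained checkpoint follows below.

import numpy as np

rng = np.random.default_rng(0)
D, H = 768, 12                                      # hidden size, number of heads

patches = rng.standard_normal((1, 196, D))          # 196 patch tokens after the linear projection
cls = rng.standard_normal((1, 1, D))                # [CLS] token
pos = rng.standard_normal((1, 197, D))              # learnable 1D position embeddings
x = np.concatenate([cls, patches], axis=1) + pos    # summed, not concatenated -> (1, 197, 768)

Wq = rng.standard_normal((D, D))                    # one illustrative projection matrix
q = x @ Wq                                          # (1, 197, 768)
q_heads = q.reshape(1, 197, H, D // H)              # split into 12 heads, each 197x64

W1 = rng.standard_normal((D, 4 * D))                # MLP expansion 768 -> 3072
W2 = rng.standard_normal((4 * D, D))                # MLP projection 3072 -> 768
mlp_out = np.maximum(0, x @ W1) @ W2                # back to (1, 197, 768)

print(x.shape, q_heads.shape, mlp_out.shape)        # (1, 197, 768) (1, 197, 12, 64) (1, 197, 768)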

The NumPy implementation of ViT below makes the details of each component easy to follow. It differs from BERT in a couple of ways: besides the embedding layer, the block structure is slightly different, mainly in that layer normalization is applied before the attention and feed-forward layers (pre-norm), whereas BERT applies it after them (post-norm).
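
Schematically, the ordering difference looks like this (ln, attention and mlp are placeholders here, just to show where the residual additions and normalizations go; they are not the functions defined in the code below):

import numpy as np

def ln(x):                              # placeholder LayerNorm without scale/shift
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-12)

attention = mlp = lambda x: x           # identity placeholders

def vit_block(x):                       # pre-norm (ViT): normalize, transform, then add the residual
    x = x + attention(ln(x))
    x = x + mlp(ln(x))
    return x

def bert_block(x):                      # post-norm (BERT): transform, add the residual, then normalize
    x = ln(x + attention(x))
    x = ln(x + mlp(x))
    return x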

import numpy as np
import os
from PIL import Image

# Load the saved model parameters
model_data = np.load('vit_model_params.npz')
for i in model_data:
    # print(i)
    print(i,model_data[i].shape)

patch_embedding_weight = model_data["vit.embeddings.patch_embeddings.projection.weight"]
patch_embedding_bias = model_data["vit.embeddings.patch_embeddings.projection.bias"]
position_embeddings = model_data["vit.embeddings.position_embeddings"]
cls_token_embeddings = model_data["vit.embeddings.cls_token"]

def patch_embedding(images):
    # Patch size: a 16x16 convolution kernel with stride 16 extracts one patch per window
    kernel_size = 16    
    return conv2d(images, patch_embedding_weight,patch_embedding_bias,stride=kernel_size)

def position_embedding():
    return position_embeddings

def model_input(images):
    
    patch_embedded = np.transpose(patch_embedding(images).reshape([1, 768, -1]), (0, 2, 1))   # (1, 768, 14, 14) -> (1, 196, 768)

    patch_embedded = np.concatenate([cls_token_embeddings, patch_embedded], axis=1)   # prepend [CLS] -> (1, 197, 768)

    # Position embedding table of shape (1, max_position, embedding_size); it is added, not concatenated
    position_embedded = position_embedding()

    embedding_output = patch_embedded + position_embedded

    return embedding_output

def softmax(x, axis=None):
    # Subtract the max for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    sum_ex = np.sum(e_x, axis=axis,keepdims=True).astype(np.float32)
    return e_x / sum_ex

def conv2d(images, weight, bias, stride=1, padding=0):
    # Naive convolution; note that `padding` only enters the size formula and is
    # never actually applied to the input (it is always 0 in this script)
    N, C, H, W = images.shape
    F, _, HH, WW = weight.shape
    # Output spatial size
    H_out = (H - HH + 2 * padding) // stride + 1
    W_out = (W - WW + 2 * padding) // stride + 1
    # Allocate the output
    out = np.zeros((N, F, H_out, W_out))
    # Slide the window and accumulate
    for i in range(H_out):
        for j in range(W_out):
            # Current window
            window = images[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
            # Multiply-accumulate over channels and the window
            # (this broadcasting trick only works for batch size 1, which is all this script uses)
            out[:, :, i, j] = np.sum(window * weight, axis=(1, 2, 3)) + bias
    return out

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = np.matmul(Q, K.transpose(0, 2, 1)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, np.full_like(scores, -np.inf))
    attention_weights = softmax(scores, axis=-1)
    # print(attention_weights)
    # print(np.sum(attention_weights,axis=-1))
    output = np.matmul(attention_weights, V)
    return output, attention_weights

def multihead_attention(input, num_heads,W_Q,B_Q,W_K,B_K,W_V,B_V,W_O,B_O):

    q = np.matmul(input, W_Q.T)+B_Q
    k = np.matmul(input, W_K.T)+B_K
    v = np.matmul(input, W_V.T)+B_V

    # Split q, k, v into num_heads heads along the last dimension
    q = np.split(q, num_heads, axis=-1)
    k = np.split(k, num_heads, axis=-1)
    v = np.split(v, num_heads, axis=-1)

    outputs = []
    for q_,k_,v_ in zip(q,k,v):
        output, attention_weights = scaled_dot_product_attention(q_, k_, v_)
        outputs.append(output)
    outputs = np.concatenate(outputs, axis=-1)
    outputs = np.matmul(outputs, W_O.T)+B_O
    return outputs

def layer_normalization(x, weight, bias, eps=1e-12):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    std = np.sqrt(variance + eps)
    normalized_x = (x - mean) / std
    output = weight * normalized_x + bias
    return output

def feed_forward_layer(inputs, weight, bias, activation='relu'):
    linear_output = np.matmul(inputs, weight) + bias

    if activation == 'relu':
        activated_output = np.maximum(0, linear_output)  # ReLU
    elif activation == 'gelu':
        # tanh approximation of GELU
        activated_output = 0.5 * linear_output * (1 + np.tanh(np.sqrt(2 / np.pi) * (linear_output + 0.044715 * np.power(linear_output, 3))))
    elif activation == "tanh":
        activated_output = np.tanh(linear_output)
    else:
        activated_output = linear_output  # no activation

    return activated_output


def residual_connection(inputs, residual):
    # Residual connection
    residual_output = inputs + residual
    return residual_output

def vit(input,num_heads=12):

    for i in range(12):
        # Load the weights of encoder layer i
        W_Q = model_data['vit.encoder.layer.{}.attention.attention.query.weight'.format(i)]
        B_Q = model_data['vit.encoder.layer.{}.attention.attention.query.bias'.format(i)]
        W_K = model_data['vit.encoder.layer.{}.attention.attention.key.weight'.format(i)]
        B_K = model_data['vit.encoder.layer.{}.attention.attention.key.bias'.format(i)]
        W_V = model_data['vit.encoder.layer.{}.attention.attention.value.weight'.format(i)]
        B_V = model_data['vit.encoder.layer.{}.attention.attention.value.bias'.format(i)]
        W_O = model_data['vit.encoder.layer.{}.attention.output.dense.weight'.format(i)]
        B_O = model_data['vit.encoder.layer.{}.attention.output.dense.bias'.format(i)]
        intermediate_weight = model_data['vit.encoder.layer.{}.intermediate.dense.weight'.format(i)]
        intermediate_bias = model_data['vit.encoder.layer.{}.intermediate.dense.bias'.format(i)]
        dense_weight = model_data['vit.encoder.layer.{}.output.dense.weight'.format(i)]
        dense_bias = model_data['vit.encoder.layer.{}.output.dense.bias'.format(i)]
        LayerNorm_before_weight = model_data['vit.encoder.layer.{}.layernorm_before.weight'.format(i)]
        LayerNorm_before_bias = model_data['vit.encoder.layer.{}.layernorm_before.bias'.format(i)]
        LayerNorm_after_weight = model_data['vit.encoder.layer.{}.layernorm_after.weight'.format(i)]
        LayerNorm_after_bias = model_data['vit.encoder.layer.{}.layernorm_after.bias'.format(i)]

        # Pre-norm attention sub-layer: LN -> multi-head self-attention -> residual add
        output = layer_normalization(input, LayerNorm_before_weight, LayerNorm_before_bias)
        output = multihead_attention(output, num_heads, W_Q, B_Q, W_K, B_K, W_V, B_V, W_O, B_O)
        output1 = residual_connection(input, output)   # matches the PyTorch model's intermediate output

        # Pre-norm MLP sub-layer: LN -> expand to 3072 with GELU -> project back to 768 -> residual add
        output = layer_normalization(output1, LayerNorm_after_weight, LayerNorm_after_bias)   # matches
        output = feed_forward_layer(output, intermediate_weight.T, intermediate_bias, activation='gelu')
        output = feed_forward_layer(output, dense_weight.T, dense_bias, activation='')
        output2 = residual_connection(output1, output)

        input = output2

    # Final LayerNorm and classification head on the [CLS] token
    final_layernorm_weight = model_data['vit.layernorm.weight']
    final_layernorm_bias = model_data['vit.layernorm.bias']
    output = layer_normalization(output2[:, 0], final_layernorm_weight, final_layernorm_bias)   # matches
    classifier_weight = model_data['classifier.weight']
    classifier_bias = model_data['classifier.bias']
    output = feed_forward_layer(output, classifier_weight.T, classifier_bias, activation="")   # matches
    output = np.argmax(output, axis=-1)
    return output

folder_path = './cifar10'  # change this to the folder containing your test images
def infer_images_in_folder(folder_path):
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        if os.path.isfile(file_path) and file_name.endswith(('.jpg', '.jpeg', '.png')):
            image = Image.open(file_path)
            image = image.resize((224, 224))
            label = file_name.split(".")[0].split("_")[1]   # file names look like "8619_5.jpg": <index>_<label>
            image = np.array(image) / 255.0                 # same [0, 1] scaling as ToTensor() used during training
            image = np.transpose(image, (2, 0, 1))          # HWC -> CHW
            image = np.expand_dims(image, axis=0)           # add the batch dimension
            print("file_path:",file_path,"img size:",image.shape,"label:",label)
            input = model_input(image)
            predicted_class = vit(input)
            print('Predicted class:', predicted_class)


if __name__ == "__main__":

    infer_images_in_folder(folder_path)
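
A note on speed: on a Raspberry Pi, the double Python loop inside conv2d dominates the runtime. Because the patch-embedding convolution uses a stride equal to its kernel size, the same computation can be expressed as a reshape plus a single matrix multiplication. The function below is an optional sketch of that equivalent form (not part of the original script); it returns the (1, 196, 768) token matrix directly, so the reshape/transpose in model_input would no longer be needed.

def patch_embedding_fast(images, weight, bias, patch=16):
    # Equivalent to conv2d with stride == kernel size: cut the image into
    # non-overlapping 16x16 patches and apply one big matrix multiplication.
    N, C, H, W = images.shape                              # e.g. (1, 3, 224, 224)
    F = weight.shape[0]                                    # 768 output channels
    patches = images.reshape(N, C, H // patch, patch, W // patch, patch)
    patches = patches.transpose(0, 2, 4, 1, 3, 5).reshape(N, -1, C * patch * patch)
    return patches @ weight.reshape(F, -1).T + bias        # (N, 196, 768)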

Results:

file_path: ./cifar10/8619_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/6042_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/6801_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7946_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/6925_2.jpg img size: (1, 3, 224, 224) label: 2
Predicted class: [2]
file_path: ./cifar10/6007_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/7903_1.jpg img size: (1, 3, 224, 224) label: 1
Predicted class: [1]
file_path: ./cifar10/7064_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/2713_8.jpg img size: (1, 3, 224, 224) label: 8
Predicted class: [8]
file_path: ./cifar10/8575_9.jpg img size: (1, 3, 224, 224) label: 9
Predicted class: [9]
file_path: ./cifar10/1985_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/5312_5.jpg img size: (1, 3, 224, 224) label: 5
Predicted class: [5]
file_path: ./cifar10/593_6.jpg img size: (1, 3, 224, 224) label: 6
Predicted class: [6]
file_path: ./cifar10/8093_7.jpg img size: (1, 3, 224, 224) label: 7
Predicted class: [7]
file_path: ./cifar10/6862_5.jpg img size: (1, 3, 224, 224) label: 5

 

Model parameters (name and shape):

vit.embeddings.cls_token (1, 1, 768)
vit.embeddings.position_embeddings (1, 197, 768)
vit.embeddings.patch_embeddings.projection.weight (768, 3, 16, 16)
vit.embeddings.patch_embeddings.projection.bias (768,)
vit.encoder.layer.0.attention.attention.query.weight (768, 768)
vit.encoder.layer.0.attention.attention.query.bias (768,)
vit.encoder.layer.0.attention.attention.key.weight (768, 768)
vit.encoder.layer.0.attention.attention.key.bias (768,)
vit.encoder.layer.0.attention.attention.value.weight (768, 768)
vit.encoder.layer.0.attention.attention.value.bias (768,)
vit.encoder.layer.0.attention.output.dense.weight (768, 768)
vit.encoder.layer.0.attention.output.dense.bias (768,)
vit.encoder.layer.0.intermediate.dense.weight (3072, 768)
vit.encoder.layer.0.intermediate.dense.bias (3072,)
vit.encoder.layer.0.output.dense.weight (768, 3072)
vit.encoder.layer.0.output.dense.bias (768,)
vit.encoder.layer.0.layernorm_before.weight (768,)
vit.encoder.layer.0.layernorm_before.bias (768,)
vit.encoder.layer.0.layernorm_after.weight (768,)
vit.encoder.layer.0.layernorm_after.bias (768,)
vit.encoder.layer.1.attention.attention.query.weight (768, 768)
vit.encoder.layer.1.attention.attention.query.bias (768,)
vit.encoder.layer.1.attention.attention.key.weight (768, 768)
vit.encoder.layer.1.attention.attention.key.bias (768,)
vit.encoder.layer.1.attention.attention.value.weight (768, 768)
vit.encoder.layer.1.attention.attention.value.bias (768,)
vit.encoder.layer.1.attention.output.dense.weight (768, 768)
vit.encoder.layer.1.attention.output.dense.bias (768,)
vit.encoder.layer.1.intermediate.dense.weight (3072, 768)
vit.encoder.layer.1.intermediate.dense.bias (3072,)
vit.encoder.layer.1.output.dense.weight (768, 3072)
vit.encoder.layer.1.output.dense.bias (768,)
vit.encoder.layer.1.layernorm_before.weight (768,)
vit.encoder.layer.1.layernorm_before.bias (768,)
vit.encoder.layer.1.layernorm_after.weight (768,)
vit.encoder.layer.1.layernorm_after.bias (768,)
vit.encoder.layer.2.attention.attention.query.weight (768, 768)
vit.encoder.layer.2.attention.attention.query.bias (768,)
vit.encoder.layer.2.attention.attention.key.weight (768, 768)
vit.encoder.layer.2.attention.attention.key.bias (768,)
vit.encoder.layer.2.attention.attention.value.weight (768, 768)
vit.encoder.layer.2.attention.attention.value.bias (768,)
vit.encoder.layer.2.attention.output.dense.weight (768, 768)
vit.encoder.layer.2.attention.output.dense.bias (768,)
vit.encoder.layer.2.intermediate.dense.weight (3072, 768)
vit.encoder.layer.2.intermediate.dense.bias (3072,)
vit.encoder.layer.2.output.dense.weight (768, 3072)
vit.encoder.layer.2.output.dense.bias (768,)
vit.encoder.layer.2.layernorm_before.weight (768,)
vit.encoder.layer.2.layernorm_before.bias (768,)
vit.encoder.layer.2.layernorm_after.weight (768,)
vit.encoder.layer.2.layernorm_after.bias (768,)
vit.encoder.layer.3.attention.attention.query.weight (768, 768)
vit.encoder.layer.3.attention.attention.query.bias (768,)
vit.encoder.layer.3.attention.attention.key.weight (768, 768)
vit.encoder.layer.3.attention.attention.key.bias (768,)
vit.encoder.layer.3.attention.attention.value.weight (768, 768)
vit.encoder.layer.3.attention.attention.value.bias (768,)
vit.encoder.layer.3.attention.output.dense.weight (768, 768)
vit.encoder.layer.3.attention.output.dense.bias (768,)
vit.encoder.layer.3.intermediate.dense.weight (3072, 768)
vit.encoder.layer.3.intermediate.dense.bias (3072,)
vit.encoder.layer.3.output.dense.weight (768, 3072)
vit.encoder.layer.3.output.dense.bias (768,)
vit.encoder.layer.3.layernorm_before.weight (768,)
vit.encoder.layer.3.layernorm_before.bias (768,)
vit.encoder.layer.3.layernorm_after.weight (768,)
vit.encoder.layer.3.layernorm_after.bias (768,)
vit.encoder.layer.4.attention.attention.query.weight (768, 768)
vit.encoder.layer.4.attention.attention.query.bias (768,)
vit.encoder.layer.4.attention.attention.key.weight (768, 768)
vit.encoder.layer.4.attention.attention.key.bias (768,)
vit.encoder.layer.4.attention.attention.value.weight (768, 768)
vit.encoder.layer.4.attention.attention.value.bias (768,)
vit.encoder.layer.4.attention.output.dense.weight (768, 768)
vit.encoder.layer.4.attention.output.dense.bias (768,)
vit.encoder.layer.4.intermediate.dense.weight (3072, 768)
vit.encoder.layer.4.intermediate.dense.bias (3072,)
vit.encoder.layer.4.output.dense.weight (768, 3072)
vit.encoder.layer.4.output.dense.bias (768,)
vit.encoder.layer.4.layernorm_before.weight (768,)
vit.encoder.layer.4.layernorm_before.bias (768,)
vit.encoder.layer.4.layernorm_after.weight (768,)
vit.encoder.layer.4.layernorm_after.bias (768,)
vit.encoder.layer.5.attention.attention.query.weight (768, 768)
vit.encoder.layer.5.attention.attention.query.bias (768,)
vit.encoder.layer.5.attention.attention.key.weight (768, 768)
vit.encoder.layer.5.attention.attention.key.bias (768,)
vit.encoder.layer.5.attention.attention.value.weight (768, 768)
vit.encoder.layer.5.attention.attention.value.bias (768,)
vit.encoder.layer.5.attention.output.dense.weight (768, 768)
vit.encoder.layer.5.attention.output.dense.bias (768,)
vit.encoder.layer.5.intermediate.dense.weight (3072, 768)
vit.encoder.layer.5.intermediate.dense.bias (3072,)
vit.encoder.layer.5.output.dense.weight (768, 3072)
vit.encoder.layer.5.output.dense.bias (768,)
vit.encoder.layer.5.layernorm_before.weight (768,)
vit.encoder.layer.5.layernorm_before.bias (768,)
vit.encoder.layer.5.layernorm_after.weight (768,)
vit.encoder.layer.5.layernorm_after.bias (768,)
vit.encoder.layer.6.attention.attention.query.weight (768, 768)
vit.encoder.layer.6.attention.attention.query.bias (768,)
vit.encoder.layer.6.attention.attention.key.weight (768, 768)
vit.encoder.layer.6.attention.attention.key.bias (768,)
vit.encoder.layer.6.attention.attention.value.weight (768, 768)
vit.encoder.layer.6.attention.attention.value.bias (768,)
vit.encoder.layer.6.attention.output.dense.weight (768, 768)
vit.encoder.layer.6.attention.output.dense.bias (768,)
vit.encoder.layer.6.intermediate.dense.weight (3072, 768)
vit.encoder.layer.6.intermediate.dense.bias (3072,)
vit.encoder.layer.6.output.dense.weight (768, 3072)
vit.encoder.layer.6.output.dense.bias (768,)
vit.encoder.layer.6.layernorm_before.weight (768,)
vit.encoder.layer.6.layernorm_before.bias (768,)
vit.encoder.layer.6.layernorm_after.weight (768,)
vit.encoder.layer.6.layernorm_after.bias (768,)
vit.encoder.layer.7.attention.attention.query.weight (768, 768)
vit.encoder.layer.7.attention.attention.query.bias (768,)
vit.encoder.layer.7.attention.attention.key.weight (768, 768)
vit.encoder.layer.7.attention.attention.key.bias (768,)
vit.encoder.layer.7.attention.attention.value.weight (768, 768)
vit.encoder.layer.7.attention.attention.value.bias (768,)
vit.encoder.layer.7.attention.output.dense.weight (768, 768)
vit.encoder.layer.7.attention.output.dense.bias (768,)
vit.encoder.layer.7.intermediate.dense.weight (3072, 768)
vit.encoder.layer.7.intermediate.dense.bias (3072,)
vit.encoder.layer.7.output.dense.weight (768, 3072)
vit.encoder.layer.7.output.dense.bias (768,)
vit.encoder.layer.7.layernorm_before.weight (768,)
vit.encoder.layer.7.layernorm_before.bias (768,)
vit.encoder.layer.7.layernorm_after.weight (768,)
vit.encoder.layer.7.layernorm_after.bias (768,)
vit.encoder.layer.8.attention.attention.query.weight (768, 768)
vit.encoder.layer.8.attention.attention.query.bias (768,)
vit.encoder.layer.8.attention.attention.key.weight (768, 768)
vit.encoder.layer.8.attention.attention.key.bias (768,)
vit.encoder.layer.8.attention.attention.value.weight (768, 768)
vit.encoder.layer.8.attention.attention.value.bias (768,)
vit.encoder.layer.8.attention.output.dense.weight (768, 768)
vit.encoder.layer.8.attention.output.dense.bias (768,)
vit.encoder.layer.8.intermediate.dense.weight (3072, 768)
vit.encoder.layer.8.intermediate.dense.bias (3072,)
vit.encoder.layer.8.output.dense.weight (768, 3072)
vit.encoder.layer.8.output.dense.bias (768,)
vit.encoder.layer.8.layernorm_before.weight (768,)
vit.encoder.layer.8.layernorm_before.bias (768,)
vit.encoder.layer.8.layernorm_after.weight (768,)
vit.encoder.layer.8.layernorm_after.bias (768,)
vit.encoder.layer.9.attention.attention.query.weight (768, 768)
vit.encoder.layer.9.attention.attention.query.bias (768,)
vit.encoder.layer.9.attention.attention.key.weight (768, 768)
vit.encoder.layer.9.attention.attention.key.bias (768,)
vit.encoder.layer.9.attention.attention.value.weight (768, 768)
vit.encoder.layer.9.attention.attention.value.bias (768,)
vit.encoder.layer.9.attention.output.dense.weight (768, 768)
vit.encoder.layer.9.attention.output.dense.bias (768,)
vit.encoder.layer.9.intermediate.dense.weight (3072, 768)
vit.encoder.layer.9.intermediate.dense.bias (3072,)
vit.encoder.layer.9.output.dense.weight (768, 3072)
vit.encoder.layer.9.output.dense.bias (768,)
vit.encoder.layer.9.layernorm_before.weight (768,)
vit.encoder.layer.9.layernorm_before.bias (768,)
vit.encoder.layer.9.layernorm_after.weight (768,)
vit.encoder.layer.9.layernorm_after.bias (768,)
vit.encoder.layer.10.attention.attention.query.weight (768, 768)
vit.encoder.layer.10.attention.attention.query.bias (768,)
vit.encoder.layer.10.attention.attention.key.weight (768, 768)
vit.encoder.layer.10.attention.attention.key.bias (768,)
vit.encoder.layer.10.attention.attention.value.weight (768, 768)
vit.encoder.layer.10.attention.attention.value.bias (768,)
vit.encoder.layer.10.attention.output.dense.weight (768, 768)
vit.encoder.layer.10.attention.output.dense.bias (768,)
vit.encoder.layer.10.intermediate.dense.weight (3072, 768)
vit.encoder.layer.10.intermediate.dense.bias (3072,)
vit.encoder.layer.10.output.dense.weight (768, 3072)
vit.encoder.layer.10.output.dense.bias (768,)
vit.encoder.layer.10.layernorm_before.weight (768,)
vit.encoder.layer.10.layernorm_before.bias (768,)
vit.encoder.layer.10.layernorm_after.weight (768,)
vit.encoder.layer.10.layernorm_after.bias (768,)
vit.encoder.layer.11.attention.attention.query.weight (768, 768)
vit.encoder.layer.11.attention.attention.query.bias (768,)
vit.encoder.layer.11.attention.attention.key.weight (768, 768)
vit.encoder.layer.11.attention.attention.key.bias (768,)
vit.encoder.layer.11.attention.attention.value.weight (768, 768)
vit.encoder.layer.11.attention.attention.value.bias (768,)
vit.encoder.layer.11.attention.output.dense.weight (768, 768)
vit.encoder.layer.11.attention.output.dense.bias (768,)
vit.encoder.layer.11.intermediate.dense.weight (3072, 768)
vit.encoder.layer.11.intermediate.dense.bias (3072,)
vit.encoder.layer.11.output.dense.weight (768, 3072)
vit.encoder.layer.11.output.dense.bias (768,)
vit.encoder.layer.11.layernorm_before.weight (768,)
vit.encoder.layer.11.layernorm_before.bias (768,)
vit.encoder.layer.11.layernorm_after.weight (768,)
vit.encoder.layer.11.layernorm_after.bias (768,)
vit.layernorm.weight (768,)
vit.layernorm.bias (768,)
classifier.weight (10, 768)
classifier.bias (10,)
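
One detail worth keeping in mind when reading these shapes: PyTorch's nn.Linear (and therefore Hugging Face's ViT) stores each weight as (out_features, in_features) and computes y = x @ W.T + b. That is why the NumPy code above multiplies by W_Q.T, intermediate_weight.T, classifier_weight.T and so on. A minimal illustration with dummy zero arrays:

import numpy as np

x = np.zeros((197, 768))            # one sequence of tokens
W = np.zeros((3072, 768))           # e.g. vit.encoder.layer.0.intermediate.dense.weight
b = np.zeros(3072)
y = x @ W.T + b                     # (197, 3072): the 768 -> 3072 expansion in the MLP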

 

Hugging Face training code: fine-tune on CIFAR-10 and save the model parameters in NumPy format so that the NumPy implementation above can load them:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import ViTModel, ViTForImageClassification
from tqdm import tqdm
import numpy as np

# Set the random seed
torch.manual_seed(42)

# Hyperparameters
batch_size = 64
num_epochs = 1
learning_rate = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data preprocessing: resize to 224x224 and scale to [0, 1]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the CIFAR-10 dataset
train_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='/data/xinyuuliu/datas', train=False, download=True, transform=transform)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained ViT model
vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)

# Replace the classification head with a 10-class linear layer
num_classes = 10
# vit_model.config.classifier = 'mlp'
# vit_model.config.num_labels = num_classes
vit_model.classifier = nn.Linear(vit_model.config.hidden_size, num_classes).to(device)


# parameters = list(vit_model.parameters())
# for x in parameters[:-1]:
#     x.requires_grad = False


# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Fine-tune the ViT model
for epoch in range(num_epochs):
    print("epoch:",epoch)
    vit_model.train()
    train_loss = 0.0
    train_correct = 0

    bar = tqdm(train_loader,total=len(train_loader))
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = vit_model(images)
        loss = criterion(outputs.logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.logits, 1)
        train_correct += (predicted == labels).sum().item()

    # Accuracy on the training set
    train_accuracy = 100.0 * train_correct / len(train_dataset)

    # Evaluate on the test set
    vit_model.eval()
    test_loss = 0.0
    test_correct = 0

    with torch.no_grad():
        bar = tqdm(test_loader,total=len(test_loader))
        for images, labels in bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = vit_model(images)
            loss = criterion(outputs.logits, labels)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.logits, 1)
            test_correct += (predicted == labels).sum().item()

    # Accuracy on the test set
    test_accuracy = 100.0 * test_correct / len(test_dataset)

    # Print this epoch's training loss, training accuracy and test accuracy
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%')


torch.save(vit_model.state_dict(), 'vit_model_parameters.pth')

# Print the ViT model's weight names and shapes
for name, param in vit_model.named_parameters():
    print(name, param.data.shape)

# Save the model parameters in NumPy format for the pure-NumPy implementation above
model_params = {name: param.data.cpu().numpy() for name, param in vit_model.named_parameters()}
np.savez('vit_model_params.npz', **model_params)
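
Several comments in the NumPy code above mark points where intermediate results were checked against the PyTorch model. A minimal sketch of such an end-to-end check is shown below; it assumes the fine-tuned vit_model, the saved vit_model_params.npz and the NumPy functions (model_input, vit) are all available in one session, and that vit is temporarily changed to return the raw logits instead of the argmax.

import numpy as np
import torch

vit_model.eval()
x = torch.rand(1, 3, 224, 224)                          # dummy image in [0, 1], like ToTensor output
with torch.no_grad():
    torch_logits = vit_model(x.to(device)).logits.cpu().numpy()
numpy_logits = vit(model_input(x.numpy()))              # pure-NumPy forward pass on the same input
# The difference should be small but not exactly zero (float32 vs float64,
# plus the tanh approximation of GELU used in feed_forward_layer)
print(np.abs(torch_logits - numpy_logits).max())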

 

Epoch [1/1], Train Loss: 97.7498, Train Accuracy: 96.21%, Test Accuracy: 96.86%
