Visualizing CIFAR-10 Representations with t-SNE

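For the theory behind t-SNE, see the post "t-SNE Algorithm". This post extracts CIFAR-10 representations with the pretrained ResNet50 shipped with PyTorch (torchvision) and visualizes them with t-SNE.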

Loading the pretrained ResNet50

import torch
from torchvision.models import resnet50, ResNet50_Weights

# Load the ResNet model
resnet = resnet50(weights=ResNet50_Weights.DEFAULT)

# Remove the final fully connected layer
resnet_fe = torch.nn.Sequential(*(list(resnet.children())[:-1]))
resnet_fe.cuda()
resnet_fe.eval()

Loading the CIFAR-10 dataset

from torchvision.datasets import CIFAR10
from torchvision import transforms

transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

dataset = CIFAR10(root='./data', train=True, download=True, transform=transformer)
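
To make the normalization step concrete, here is a small NumPy sketch of the arithmetic that `transforms.Normalize` applies: each channel has its mean subtracted and is divided by its standard deviation. The random `image` is a stand-in for a tensor produced by `ToTensor()` and is an assumption of this example, not real CIFAR-10 data. Note that the statistics used in this post are the commonly quoted CIFAR-10 values; ImageNet-pretrained weights are often paired with ImageNet statistics instead, so this choice is worth keeping in mind when interpreting the features.

```python
import numpy as np

# Per-channel statistics taken from the transform above, shaped for
# broadcasting over a (channels, height, width) image.
mean = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
std = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)

# Stand-in for a ToTensor() output: values in [0, 1), shape (3, 32, 32).
rng = np.random.default_rng(0)
image = rng.random((3, 32, 32))

# What Normalize computes: (x - mean) / std, channel by channel.
normalized = (image - mean) / std

# The operation is invertible, which is handy for displaying images later.
restored = normalized * std + mean
```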

Extracting CIFAR-10 representations

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
features = []
labels = []
for x, y in dataloader:
    x = x.cuda()
    with torch.no_grad():
        feature = resnet_fe(x)  # feature shape: (batch_size, 2048, 1, 1)
    feature = feature.view(feature.size(0), -1).cpu()  # feature shape: (batch_size, 2048)
    for f, l in zip(feature, y):
        features.append(f.numpy())
        labels.append(l.numpy())
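
The loop above appends one sample at a time. An alternative, sketched below under the assumption that each batch is already available as a NumPy array, is to collect whole batches and concatenate them once at the end. The `fake_batches` list is hypothetical stand-in data (the last batch is deliberately smaller, as happens with a real `DataLoader`), and the names `feature_chunks`, `label_chunks`, `all_features`, and `all_labels` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for (features, labels) batches of sizes 64, 64, 22.
fake_batches = [
    (rng.random((n, 2048), dtype=np.float32), rng.integers(0, 10, size=n))
    for n in (64, 64, 22)
]

feature_chunks, label_chunks = [], []
for feats, labs in fake_batches:
    # Keep each batch as one array instead of splitting it into samples.
    feature_chunks.append(feats)
    label_chunks.append(labs)

# One concatenation along the sample axis produces the final arrays.
all_features = np.concatenate(feature_chunks, axis=0)
all_labels = np.concatenate(label_chunks, axis=0)
print(all_features.shape, all_labels.shape)
```

The result has the same layout as the arrays built sample by sample: one row of features and one label per image.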

Fitting t-SNE

from sklearn.manifold import TSNE
import numpy as np

features = np.array(features)
labels = np.array(labels)
tsne = TSNE(n_components=2, random_state=0).fit_transform(X=features)
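
Running t-SNE directly on tens of thousands of high-dimensional vectors can be slow. A common option, which this post does not use, is to reduce the features with PCA first (the scikit-learn documentation suggests around 50 dimensions for very high-dimensional data). The sketch below shows the idea on small synthetic data; `X`, `X_reduced`, and `embedding` are illustrative names and the data is random noise, not CIFAR-10 features.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Synthetic stand-in for extracted features: 300 samples, 256 dimensions.
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 256)).astype(np.float32)

# Step 1: linear reduction to 50 dimensions with PCA.
X_reduced = PCA(n_components=50, random_state=0).fit_transform(X)

# Step 2: t-SNE on the reduced features to obtain 2-D coordinates.
embedding = TSNE(n_components=2, random_state=0).fit_transform(X_reduced)
print(X_reduced.shape, embedding.shape)
```

Whether the extra PCA step changes the picture noticeably depends on the data, so it is best treated as an optional speed-up rather than a required stage.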

Visualization

import altair as alt
import pandas as pd

# Extract the x and y coordinates
k = 5000  # Otherwise Altair raises: MaxRowsError: The number of rows in your dataset is greater than the maximum allowed (5000).
label = labels[:k]
x = tsne[:k, 0]
y = tsne[:k, 1]

# Create the DataFrame
df = pd.DataFrame({'x': x, 'y': y, 'label': label})

# Create the scatter plot
chart = alt.Chart(df).mark_point(filled=True).encode(x="x", y="y", color="label:N").properties(width=400, height=400)
chart = chart.configure_axis(
    disable=True,  # hide the axes
)
chart
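
Taking the first 5000 rows is one way to stay under Altair's row limit. Another possibility, sketched below, is to draw the same number of points from every class so that all ten labels are equally represented in the plot. The DataFrame here is filled with random numbers purely for illustration, and `df_all`, `per_class`, and `df_small` are hypothetical names. Altair also provides `alt.data_transformers.disable_max_rows()` to lift the limit entirely, at the cost of larger notebooks.

```python
import numpy as np
import pandas as pd

# Random stand-in for a full table of 2-D t-SNE coordinates and labels.
rng = np.random.default_rng(0)
n = 20000
df_all = pd.DataFrame({
    "x": rng.normal(size=n),
    "y": rng.normal(size=n),
    "label": rng.integers(0, 10, size=n),
})

# Draw the same number of rows from each class: 10 classes x 500 = 5000,
# which stays within Altair's default limit of 5000 rows.
per_class = 500
df_small = df_all.groupby("label").sample(n=per_class, random_state=0)
print(len(df_small))
```

The resulting `df_small` can be passed to `alt.Chart` in the same way as the DataFrame built from the first `k` rows.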

References

  1. Representation Learning — CIFAR-10
  2. t-SNE Algorithm
  3. sklearn.manifold.TSNE

Runtime environment

# Name                    Version                   Build  Channel
altair                    5.0.1           py312haa95532_0 
jupyter                   1.0.0           py312haa95532_9
pandas                    2.2.1           py312h0158946_0
pytorch                   2.2.2           py3.12_cuda12.1_cudnn8_0    pytorch
scikit-learn              1.3.0           py312hc7c4135_2

Altair charts sometimes fail to render in Jupyter notebooks run from PyCharm 2024.1 (Professional Edition). In that case, open the .ipynb file in a browser to view them.

posted @ 2024-04-16 19:35  zh-jp