In-Depth Walkthrough: Hands-On Notes for Hung-yi Lee's Spring 2025 Machine Learning Assignment ML2025_Spring_HW4 on Kaggle

Training Transformer

TA’s Slide

Slide

Description

In this assignment, we pretrain a decoder-only transformer with a next-token prediction objective on Pokémon images.
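To make "next-token prediction on images" concrete, here is a minimal sketch of the idea. It assumes each image is stored as a sequence of palette indices read in raster order (left to right, top to bottom); the notebook later loads such sequences (`pixel_color`) together with a colormap. The tiny palette, the 2x3 image, and the helper names `image_to_sequence` and `next_token_pairs` are hypothetical and only serve the illustration.

```python
from typing import List, Tuple

# Hypothetical 3-color palette (RGB); the real colormap is loaded later in the notebook.
toy_colormap: List[Tuple[int, int, int]] = [(255, 255, 255), (0, 0, 0), (255, 0, 0)]

# A 2x3 toy "image" whose pixels are indices into the palette.
toy_image: List[List[int]] = [
    [0, 1, 1],
    [2, 0, 1],
]


def image_to_sequence(image: List[List[int]]) -> List[int]:
    """Flatten the image row by row (assumed raster order) into one token sequence."""
    return [index for row in image for index in row]


def next_token_pairs(sequence: List[int]) -> List[Tuple[List[int], int]]:
    """Pair every non-empty prefix with the token that follows it."""
    return [(sequence[:i], sequence[i]) for i in range(1, len(sequence))]


tokens = image_to_sequence(toy_image)
pairs = next_token_pairs(tokens)
for prefix, target in pairs:
    print(prefix, "->", target)

# Mapping tokens back through the palette recovers the RGB pixels.
rgb_pixels = [toy_colormap[t] for t in tokens]
```

Each color index plays the role of a word in a language model's vocabulary, so the vocabulary size equals the number of colors in the colormap.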

Please feel free to mail us if you have any questions.

ntu-ml-2025-spring-ta@googlegroups.com

Utilities

Download packages

!pip install datasets==3.3.2
Collecting datasets==3.3.2
  Using cached datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: filelock in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.17.0)
Requirement already satisfied: numpy>=1.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.0.1)
Requirement already satisfied: pyarrow>=15.0.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (21.0.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.3.2)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: pandas in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.3.1)
Requirement already satisfied: requests>=2.32.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.32.5)
Requirement already satisfied: tqdm>=4.66.3 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (4.67.1)
Requirement already satisfied: xxhash in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.6.0)
Requirement already satisfied: multiprocess<0.70.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.70.16)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Using cached fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.13.0)
Requirement already satisfied: huggingface-hub>=0.24.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.35.3)
Requirement already satisfied: packaging in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (25.0)
Requirement already satisfied: pyyaml>=5.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.4.0)
Requirement already satisfied: async-timeout<6.0,>=4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (5.0.1)
Requirement already satisfied: attrs>=17.3.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (0.4.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.22.0)
Requirement already satisfied: typing-extensions>=4.1.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from multidict<7.0,>=4.5->aiohttp->datasets==3.3.2) (4.15.0)
Requirement already satisfied: idna>=2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.3.2) (3.7)
Requirement already satisfied: charset_normalizer<4,>=2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2025.10.5)
Requirement already satisfied: colorama in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from tqdm>=4.66.3->datasets==3.3.2) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: six>=1.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets==3.3.2) (1.17.0)
Using cached datasets-3.3.2-py3-none-any.whl (485 kB)
Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Using cached fsspec-2024.12.0-py3-none-any.whl (183 kB)
Installing collected packages: fsspec, dill, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.9.0
    Uninstalling fsspec-2025.9.0:
      Successfully uninstalled fsspec-2025.9.0
  Attempting uninstall: dill
    Found existing installation: dill 0.4.0
    Uninstalling dill-0.4.0:
      Successfully uninstalled dill-0.4.0
  Attempting uninstall: datasets
    Found existing installation: datasets 4.1.1
    Uninstalling datasets-4.1.1:
      Successfully uninstalled datasets-4.1.1
Successfully installed datasets-3.3.2 dill-0.3.8 fsspec-2024.12.0
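Because the install downgrades three packages, it can be worth confirming which versions the running kernel actually sees. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical, and a notebook kernel may need a restart before a downgraded package is picked up.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a package, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Packages touched by the pip command above.
for name in ("datasets", "dill", "fsspec"):
    print(f"{name}: {installed_version(name)}")
```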

Import Packages

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional

Check Devices

!nvidia-smi
Wed Oct  8 18:50:06 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.97                 Driver Version: 580.97         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090 Ti   WDDM  |   00000000:07:00.0  On |                  Off |
| 47%   42C    P8             25W /  450W |   12684MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2292    C+G   C:\Windows\System32\dwm.exe           N/A      |
|    0   N/A  N/A            5552    C+G   ...8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A            9928    C+G   C:\Windows\explorer.exe               N/A      |
|    0   N/A  N/A           10036    C+G   ..._cw5n1h2txyewy\SearchHost.exe      N/A      |
|    0   N/A  N/A           10264    C+G   ...y\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A           10632    C+G   ...ogram Files\ToDesk\ToDesk.exe      N/A      |
|    0   N/A  N/A           14304    C+G   ...xyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A           15600    C+G   ...5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A           15812    C+G   ...ouryDevice\asus_framework.exe      N/A      |
|    0   N/A  N/A           18660    C+G   ...crosoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A           18668    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           21724    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           22748    C+G   ...s\TencentDocs\TencentDocs.exe      N/A      |
|    0   N/A  N/A           25412    C+G   ...ram Files\Tencent\QQNT\QQ.exe      N/A      |
|    0   N/A  N/A           25872    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           26600    C+G   ...ocal\Programs\Quark\quark.exe      N/A      |
|    0   N/A  N/A           28688    C+G   ...ntrolPanel\SystemSettings.exe      N/A      |
|    0   N/A  N/A           30104    C+G   ...de\Microsoft VS Code\Code.exe      N/A      |
|    0   N/A  N/A           31500    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           39276    C+G   ...t\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A           41696    C+G   ...PotPlayer\PotPlayerMini64.exe      N/A      |
|    0   N/A  N/A           44176    C+G   ...ffice6\promecefpluginhost.exe      N/A      |
|    0   N/A  N/A           72652      C   ...2025-Spring-Hw1\python.exe.c~      N/A      |
|    0   N/A  N/A          115660    C+G   ...ef.win7x64\steamwebhelper.exe      N/A      |
|    0   N/A  N/A          124396    C+G   ...yb3d8bbwe\WindowsTerminal.exe      N/A      |
+-----------------------------------------------------------------------------------------+

Set Random Seed

set_seed(0)
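`set_seed` from `transformers` seeds the Python, NumPy, and PyTorch random number generators in one call so that runs are repeatable. The effect of seeding can be sketched with the standard library alone; the helper `draw` below is hypothetical and only illustrates the principle, not the internals of `set_seed`.

```python
import random
from typing import List


def draw(seed: int, n: int = 3) -> List[int]:
    """Seed the generator, then draw n pseudo-random digits."""
    random.seed(seed)
    return [random.randint(0, 9) for _ in range(n)]


first = draw(0)
second = draw(0)
print(first == second)  # the same seed reproduces the same draws
```

Seeding alone does not guarantee bit-identical training runs on a GPU, since some CUDA kernels are non-deterministic, but it removes the largest source of run-to-run variation.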

Prepare Data

Define Dataset

from typing import List, Tuple, Union
import torch
from torch.utils.data import Dataset


class PixelSequenceDataset(Dataset):
    def __init__(self, data: List[List[int]], mode: str = "train"):
        """
        A dataset class for handling pixel sequences.

        Args:
            data (List[List[int]]): A list of sequences, where each sequence is a list of integers.
            mode (str): The mode of operation, either "train", "dev", or "test".
                - "train": Returns (input_ids, labels) where input_ids are sequence[:-1] and labels are sequence[1:].
                - "dev": Returns (input_ids, labels) where input_ids are sequence[:-160] and labels are sequence[-160:].
                - "test": Returns only input_ids, as labels are not available.
        """
        self.data = data
        self.mode = mode

    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Fetches a sequence from the dataset and processes it based on the mode.

        Args:
            idx (int): The index of the sequence.

        Returns:
            - If mode == "train": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "dev": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "test": torch.Tensor -> input_ids
        """
        sequence = self.data[idx]
        if self.mode == "train":
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "dev":
            input_ids = torch.tensor(sequence[:-160], dtype=torch.long)
            labels = torch.tensor(sequence[-160:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "test":
            input_ids = torch.tensor(sequence, dtype=torch.long)
            return input_ids
        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
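The three modes differ only in how a sequence is sliced. The sketch below mirrors that slicing on a short toy sequence, using plain lists instead of tensors so the result is easy to read. The function `split_sequence` and its `n_predict` parameter are hypothetical; the class above hard-codes 160 held-out tokens for "dev", and the toy call uses 4 to keep the example small.

```python
from typing import List, Optional, Tuple


def split_sequence(
    sequence: List[int], mode: str, n_predict: int = 160
) -> Tuple[List[int], Optional[List[int]]]:
    """Slice a sequence into (input_ids, labels) the way each dataset mode does."""
    if mode == "train":
        # Teacher forcing: labels are the inputs shifted left by one position.
        return sequence[:-1], sequence[1:]
    if mode == "dev":
        # The model sees the prefix and must generate the last n_predict tokens.
        return sequence[:-n_predict], sequence[-n_predict:]
    if mode == "test":
        # No labels are available at test time.
        return sequence, None
    raise ValueError(f"Invalid mode: {mode}. Choose from 'train', 'dev', or 'test'.")


toy = list(range(10))
train_inputs, train_labels = split_sequence(toy, "train")
dev_inputs, dev_labels = split_sequence(toy, "dev", n_predict=4)
test_inputs, test_labels = split_sequence(toy, "test")

print(train_inputs, train_labels)
print(dev_inputs, dev_labels)
print(test_inputs, test_labels)
```

In "train" mode the input and label sequences have equal length, so every position contributes to the loss. In "dev" mode the labels are a held-out suffix that the model has to generate from the prefix.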

Download Dataset & Prepare Dataloader

# Load the pokemon dataset from Hugging Face Hub
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")
# Load the colormap from Hugging Face Hub
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])
# Define number of classes
num_classes = len(colormap)
# Define batch size
batch_size = 16
# === Prepare Dataset and DataLoader for Training ===
train_dataset: PixelSequenceDataset = PixelSequenceDataset(
    pokemon_dataset["train"]["pixel_color"], mode="train"
)
train_dataloader: DataLoader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
# === Prepare Dataset and DataLoader for Validation ===
dev_dataset: PixelSequenceDataset = PixelSequenceDataset(
    pokemon_dataset["dev"]["pixel_color"], mode="dev"
)
dev_dataloader: DataLoader = DataLoader(
    dev_dataset, batch_size=batch_size, shuffle=False
)
# === Prepare Dataset and DataLoader for Testing ===
test_dataset: PixelSequenceDataset = PixelSequenceDataset(
    pokemon_dataset["test"][
posted @ 2025-11-09 15:01  yxysuanfa