"""Demo training script: trains and evaluates a small linear regression model."""
import argparse
import logging
import os
import random
import time
import numpy as np
import torch
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader, Dataset
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Every option has a default, so the script runs with no flags at all.
    """
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--seed", type=int, default=123, help="Fix random seed")
    parser.add_argument("--log_file", type=str, default="test_train.log", help="Log file")
    parser.add_argument("--log_path", type=str, default="./training_log/", help="Model path")
    parser.add_argument("--train_epochs", type=int, default=5, help="Epochs of training")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--device", type=str, default="cpu", help="Run on which device")
    parser.add_argument(
        "--cuda_visible_devices", type=str, default="0", help="Cuda visible devices"
    )
    return parser.parse_args()
def init_logging(log_file: str, level: str = "INFO") -> None:
    """Configure root logging to write to *log_file* and mirror to the console.

    Args:
        log_file: Path of the log file (truncated on each call, filemode="w").
        level: Root log level name, e.g. "INFO" or "DEBUG".

    ``force=True`` removes any handlers a previous call (or another
    library) already attached to the root logger. Without it,
    ``basicConfig`` is a silent no-op when logging is already configured,
    and each call here stacked an extra duplicate StreamHandler.
    """
    logging.basicConfig(
        filename=log_file,
        filemode="w",
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=level,
        force=True,
    )
    # Mirror every record to the console in addition to the file.
    logging.getLogger().addHandler(logging.StreamHandler())
def set_seed(seed: int) -> None:
    """Seed every RNG in use (Python, NumPy, PyTorch) for reproducibility.

    Also forces deterministic cuDNN/cuBLAS behavior, which can slow
    training and makes nondeterministic ops raise instead of running.
    """
    # NOTE(review): setting PYTHONHASHSEED at runtime only affects child
    # processes, not this interpreter's own hash randomization — confirm intent.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Workspace config required by cuBLAS for deterministic results (per
    # the PyTorch reproducibility notes); harmless on CPU-only runs.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible CUDA devices; a no-op on CPU-only installs.
    torch.cuda.manual_seed_all(seed)
    # Disable the cuDNN autotuner (it picks kernels nondeterministically)
    # and pin deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Error out on ops that have no deterministic implementation rather
    # than silently producing irreproducible results.
    torch.use_deterministic_algorithms(True)
def seed_worker(work_id: int) -> None:
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Args:
        work_id: Worker index, passed positionally by the DataLoader
            (not used directly; kept for the worker_init_fn signature).

    The original seeded with the worker id alone, so every worker drew an
    identical stream on every epoch and the global seed was ignored.
    ``torch.initial_seed()`` is already distinct per worker and per epoch
    (derived from the DataLoader's base seed), so deriving from it keeps
    runs reproducible while varying augmentation across epochs.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
class DatasetClass(Dataset):
    """Synthetic regression dataset: target = input[:, 0] + 1.0.

    1,000,000 rows of two uniform features in [0, 1); the second feature
    does not enter the target and must be learned to be irrelevant.
    """

    def __init__(self):
        self.input = np.random.rand(1000000, 2).astype(np.float32)
        # float32 (instead of np.zeros' default float64) so the collated
        # target tensor matches the model's float32 output dtype, avoiding
        # silent promotion of the loss to float64.
        self.target = np.zeros([1000000, 1], dtype=np.float32)
        self.target[:, 0] = self.input[:, 0] + 1.0

    def __len__(self) -> int:
        return len(self.input)

    def __getitem__(self, idx: int) -> tuple:
        """Return one (features, target) pair of NumPy arrays."""
        return self.input[idx], self.target[idx]
class ModelClass(torch.nn.Module):
    """Minimal model: a single linear layer mapping 2 features to 1 output."""

    def __init__(self):
        super().__init__()
        self.my_layer = nn.Linear(2, 1)

    def forward(self, inputs: Tensor) -> Tensor:
        """Apply the linear layer to *inputs* and return its output."""
        return self.my_layer(inputs)
def get_loss(model_output: Tensor, target: Tensor) -> Tensor:
    """Return the sum of per-sample Euclidean (L2) distances to the target."""
    residual = model_output - target
    return residual.norm(dim=-1).sum()
def training() -> None:
    """Train the demo model and save its weights under ``args.log_path``.

    Relies on the module-level ``args`` namespace parsed in ``__main__``.
    Writes ``trained_model.pth`` into ``args.log_path`` when training ends.
    """
    train_set = DatasetClass()
    # Dedicated generator so DataLoader shuffling is reproducible.
    g = torch.Generator()
    g.manual_seed(args.seed)
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=os.cpu_count(),
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    model = ModelClass()
    if args.device == "cuda":
        model = nn.DataParallel(model)
    model.to(args.device)
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    for epoch in range(args.train_epochs):
        model.train()
        for batch_index, (features, labels) in enumerate(train_loader):
            features = features.to(args.device)
            labels = labels.to(args.device)
            model_outputs = model(features)
            optimizer.zero_grad(set_to_none=True)
            loss = get_loss(model_outputs, labels)
            loss.backward()
            optimizer.step()
            if batch_index % 1000 == 0:
                logging.info(
                    "Epoch: %s, Batch index: %s, Loss: %s",
                    epoch,
                    batch_index,
                    loss.item(),
                )
    # Bug fix: the output directory was never created, so torch.save
    # raised FileNotFoundError on a fresh checkout.
    os.makedirs(args.log_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.log_path, "trained_model.pth"))
def testing() -> None:
    """Evaluate the saved model and log the mean per-sample loss per batch.

    Relies on the module-level ``args`` namespace parsed in ``__main__``.
    Loads ``trained_model.pth`` from ``args.log_path``.
    """
    test_set = DatasetClass()
    g = torch.Generator()
    g.manual_seed(args.seed)
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=os.cpu_count(),
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    model = ModelClass()
    if args.device == "cuda":
        model = nn.DataParallel(model)
    # map_location makes a checkpoint saved on one device loadable on
    # another (e.g. trained on GPU, evaluated on CPU); without it,
    # torch.load fails when the saving device is unavailable.
    state = torch.load(
        os.path.join(args.log_path, "trained_model.pth"),
        map_location=args.device,
    )
    model.load_state_dict(state)
    model.to(args.device)
    model.eval()
    with torch.no_grad():
        for batch_index, (features, labels) in enumerate(test_loader):
            features = features.to(args.device)
            labels = labels.to(args.device)
            model_outputs = model(features)
            loss = get_loss(model_outputs, labels)
            if batch_index % 1000 == 0:
                logging.info(
                    "Batch index: %s, Loss: %s",
                    batch_index,
                    # Normalize by the actual batch size: the last batch can
                    # be smaller than args.batch_size, which previously
                    # under-reported its per-sample loss.
                    loss.item() / features.size(0),
                )
if __name__ == "__main__":
    args = parse_args()
    # Bug fix: must be set before anything touches CUDA. set_seed() calls
    # torch.cuda.manual_seed_all, which can initialize the CUDA context;
    # CUDA_VISIBLE_DEVICES changes made after that point have no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
    set_seed(args.seed)
    init_logging(args.log_file)
    main_start_time = time.time()
    training()
    main_end_time = time.time()
    # Timer covers training only; testing runs (and logs) afterwards.
    logging.info("Main time: %s", main_end_time - main_start_time)
    testing()