Jigsaw Unintended Bias in Toxicity Classification: Complete Code

from collections import Counter
from contextlib import contextmanager
import copy
from functools import partial
from itertools import chain
import os
import random
import re
import string
import time
import warnings

import numpy as np
import pandas as pd

from nltk.stem import PorterStemmer, SnowballStemmer
from nltk.stem.lancaster import LancasterStemmer

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, KFold
from sklearn.utils import shuffle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, Sampler, DataLoader
from torch.optim.optimizer import Optimizer

EMBEDDING_FASTTEXT = '../input/fasttext-crawl-300d-2m/crawl-300d-2M.vec'
TRAIN_DATA = '../input/jigsaw-unintended-bias-in-toxicity-classification/train.csv'
TEST_DATA = '../input/jigsaw-unintended-bias-in-toxicity-classification/test.csv'
SAMPLE_SUBMISSION = '../input/jigsaw-unintended-bias-in-toxicity-classification/sample_submission.csv'

embed_size = 300
max_features = 100000
max_len = 220

batch_size = 512
train_epochs = 6
n_splits = 5

mu = 0.9                # EMA decay for the shadow weights
updates_per_epoch = 10  # how many EMA weight updates to make per epoch

seed = 1029
device = torch.device('cuda:0')

ps = PorterStemmer()
lc = LancasterStemmer()
sb = SnowballStemmer('english')

@contextmanager
def timer(msg):
    t0 = time.time()
    print(f'[{msg}] start.')
    yield
    elapsed_time = time.time() - t0
    print(f'[{msg}] done in {elapsed_time / 60:.2f} min.')


def seed_torch(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

misspell_dict = {"aren't": "are not", "can't": "cannot", "couldn't": "could not",
                 "didn't": "did not", "doesn't": "does not", "don't": "do not",
                 "hadn't": "had not", "hasn't": "has not", "haven't": "have not",
                 "he'd": "he would", "he'll": "he will", "he's": "he is",
                 "i'd": "I had", "i'll": "I will", "i'm": "I am", "isn't": "is not",
                 "it's": "it is", "it'll": "it will", "i've": "I have", "let's": "let us",
                 "mightn't": "might not", "mustn't": "must not", "shan't": "shall not",
                 "she'd": "she would", "she'll": "she will", "she's": "she is",
                 "shouldn't": "should not", "that's": "that is", "there's": "there is",
                 "they'd": "they would", "they'll": "they will", "they're": "they are",
                 "they've": "they have", "we'd": "we would", "we're": "we are",
                 "weren't": "were not", "we've": "we have", "what'll": "what will",
                 "what're": "what are", "what's": "what is", "what've": "what have",
                 "where's": "where is", "who'd": "who would", "who'll": "who will",
                 "who're": "who are", "who's": "who is", "who've": "who have",
                 "won't": "will not", "wouldn't": "would not", "you'd": "you would",
                 "you'll": "you will", "you're": "you are", "you've": "you have",
                 "'re": " are", "wasn't": "was not", "we'll": "we will", "tryin'": "trying"}


def _get_misspell(misspell_dict):
    misspell_re = re.compile('(%s)' % '|'.join(misspell_dict.keys()))
    return misspell_dict, misspell_re


def replace_typical_misspell(text):
    misspellings, misspellings_re = _get_misspell(misspell_dict)

    def replace(match):
        return misspellings[match.group(0)]

    return misspellings_re.sub(replace, text)
    

puncts = [',', '.', '"', ':', ')', '(', '-', '!', '?', '|', ';', "'", '$', '&', '/', '[', ']',
          '>', '%', '=', '#', '*', '+', '\\', '•', '~', '@', '£', '·', '_', '{', '}', '©', '^',
          '®', '`', '<', '→', '°', '€', '™', '›', '♥', '←', '×', '§', '″', '′', 'Â', '█',
          '½', 'à', '…', '“', '★', '”', '–', '●', 'â', '►', '−', '¢', '²', '¬', '░', '¶',
          '↑', '±', '¿', '▾', '═', '¦', '║', '―', '¥', '▓', '—', '‹', '─', '▒', ':', '¼',
          '⊕', '▼', '▪', '†', '■', '’', '▀', '¨', '▄', '♫', '☆', 'é', '¯', '♦', '¤', '▲',
          'è', '¸', '¾', 'Ã', '⋅', '‘', '∞', '∙', ')', '↓', '、', '│', '(', '»', ',', '♪',
          '╩', '╚', '³', '・', '╦', '╣', '╔', '╗', '▬', '❤', 'ï', 'Ø', '¹', '≤', '‡', '√']


def clean_text(x):
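    # pad each punctuation character with spaces so whitespace tokenization keeps it as its own token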
    x = str(x)
    for punct in puncts + list(string.punctuation):
        if punct in x:
            x = x.replace(punct, f' {punct} ')
    return x


def clean_numbers(x):
    return re.sub(r'\d+', ' ', x)

def load_embedding(embedding_path, word_index):

    def get_coefs(word, *arr):
        return word, np.asarray(arr, dtype='float32')

    embeddings_index = dict(get_coefs(*o.strip().split(' ')) for o in open(embedding_path))
    
    nb_words = min(max_features + 2, len(word_index))
    embedding_matrix = np.zeros((nb_words, embed_size))

    for key, i in word_index.items():
        # try the raw token first, then case variants, then three stemmed forms
        for transform in (lambda w: w, str.lower, str.upper, str.capitalize,
                          ps.stem, lc.stem, sb.stem):
            embedding_vector = embeddings_index.get(transform(key))
            if embedding_vector is not None:
                embedding_matrix[i] = embedding_vector
                break

    return embedding_matrix

def load_and_prec():
    train = pd.read_csv(TRAIN_DATA, index_col='id')
    test = pd.read_csv(TEST_DATA, index_col='id')
    
    # lower
    train['comment_text'] = train['comment_text'].str.lower()
    test['comment_text'] = test['comment_text'].str.lower()

    # clean misspellings
    train['comment_text'] = train['comment_text'].apply(replace_typical_misspell)
    test['comment_text'] = test['comment_text'].apply(replace_typical_misspell)

    # clean the text
    train['comment_text'] = train['comment_text'].apply(clean_text)
    test['comment_text'] = test['comment_text'].apply(clean_text)

    # clean numbers
    train['comment_text'] = train['comment_text'].apply(clean_numbers)
    test['comment_text'] = test['comment_text'].apply(clean_numbers)
    
    # strip
    train['comment_text'] = train['comment_text'].str.strip()
    test['comment_text'] = test['comment_text'].str.strip()
    
    # replace blank with nan
    train['comment_text'] = train['comment_text'].replace('', np.nan)
    test['comment_text'] = test['comment_text'].replace('', np.nan)

    # nan prediction
    nan_pred = train['target'][train['comment_text'].isna()].mean()
    
    # fill up the missing values
    train_x = train['comment_text'].fillna('_##_').values
    test_x = test['comment_text'].fillna('_##_').values
    
    # get the target values
    identity_columns = [
        'male', 'female', 'homosexual_gay_or_lesbian', 'christian', 'jewish',
        'muslim', 'black', 'white', 'psychiatric_or_mental_illness']

    # sample weights: 1 + 3 * (summed identity annotations) + 8 * toxicity target, normalized by the max
    weights = np.ones((len(train),))
    weights += train[identity_columns].fillna(0).values.sum(axis=1) * 3
    weights += train['target'].values * 8
    weights /= weights.max()
    train_y = np.vstack([train['target'].values, weights]).T
    
    train_y_identity = train[identity_columns].values

    # shuffling the data
    np.random.seed(seed)
    train_idx = np.random.permutation(len(train_x))

    train_x = train_x[train_idx]
    train_y = train_y[train_idx]
    train_y_identity = train_y_identity[train_idx]

    return train_x, train_y, train_y_identity, test_x, nan_pred

def build_vocab(texts, max_features):
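    # ids 1..max_features go to the most frequent tokens; <PAD> is 0 and <UNK> is max_features + 1,
    # which is also the id tokenize() falls back to for out-of-vocabulary tokens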
    counter = Counter()
    for text in texts:
        counter.update(text.split())

    vocab = {
        'token2id': {'<PAD>': 0, '<UNK>': max_features + 1},
        'id2token': {}
    }
    vocab['token2id'].update(
        {token: _id + 1 for _id, (token, count) in
         enumerate(counter.most_common(max_features))})
    vocab['id2token'] = {v: k for k, v in vocab['token2id'].items()}
    return vocab


def tokenize(texts, vocab):
    
    def text2ids(text, token2id):
        return [
            token2id.get(token, len(token2id) - 1)
            for token in text.split()[:max_len]]
    
    return [
        text2ids(text, vocab['token2id'])
        for text in texts]

class TextDataset(Dataset):

    def __init__(self, seqs, targets=None, maxlen=200):
        if targets is not None:
            self.targets = targets
        else:
            # dummy targets for the unlabeled test set so __getitem__ keeps a uniform structure
            self.targets = np.random.randint(2, size=(len(seqs),))
        
        self.seqs = seqs
        self.maxlen = maxlen
        
    def __len__(self):
        return len(self.seqs)
        
    def get_keys(self):
        lens = np.fromiter(
            ((min(self.maxlen, len(seq))) for seq in self.seqs),
            dtype=np.int32)
        return lens
        
    def __getitem__(self, index):
        return index, self.seqs[index], self.targets[index]


def collate_fn(data):

    def _pad_sequences(seqs):
        lens = [len(seq) for seq in seqs]
        max_len = max(lens)

        padded_seqs = torch.zeros(len(seqs), max_len).long()
        for i, seq in enumerate(seqs):
            start = max_len - lens[i]
            padded_seqs[i, start:] = torch.LongTensor(seq)
        return padded_seqs

    index, seqs, targets = zip(*data)
    seqs = _pad_sequences(seqs)
    return index, seqs, torch.FloatTensor(targets)


class BucketSampler(Sampler):
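    # Buckets sequences of similar length into the same batch to minimize padding; within each
    # bucket, indices are sorted by descending length and the full batches are shuffled.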

    def __init__(self, data_source, sort_keys, bucket_size=None, batch_size=1048, shuffle_data=True):
        super().__init__(data_source)
        self.shuffle = shuffle_data
        self.batch_size = batch_size
        self.sort_keys = sort_keys
        self.bucket_size = bucket_size if bucket_size is not None else len(sort_keys)
        self.weights = None

        if not shuffle_data:
            self.index = self.prepare_buckets()
        else:
            self.index = None

    def set_weights(self, weights):
        assert np.all(weights >= 0)
        total = np.sum(weights)
        if total != 1:
            weights = weights / total
        self.weights = weights

    def __iter__(self):
        indices = None
        if self.weights is not None:
            total = len(self.sort_keys)
            indices = np.random.choice(total, (total,), p=self.weights)
        if self.shuffle:
            self.index = self.prepare_buckets(indices)
        return iter(self.index)

    def get_reverse_indexes(self):
        indexes = np.zeros((len(self.index),), dtype=np.int32)
        for i, j in enumerate(self.index):
            indexes[j] = i
        return indexes

    def __len__(self):
        return len(self.sort_keys)
        
    def prepare_buckets(self, indices=None):
        lens = - self.sort_keys
        assert self.bucket_size % self.batch_size == 0 or self.bucket_size == len(lens)

        if indices is None:
            if self.shuffle:
                indices = shuffle(np.arange(len(lens), dtype=np.int32))
                lens = lens[indices]
            else:
                indices = np.arange(len(lens), dtype=np.int32)

        #  bucket iterator
        def divide_chunks(l, n):
            if n == len(l):
                yield np.arange(len(l), dtype=np.int32), l
            else:
                # looping till length l
                for i in range(0, len(l), n):
                    data = l[i:i + n]
                    yield np.arange(i, i + len(data), dtype=np.int32), data
    
        new_indices = []
        extra_batch = None
        for chunk_index, chunk in divide_chunks(lens, self.bucket_size):
            # sort indices in bucket by descending order of length
            indices_sorted = chunk_index[np.argsort(chunk, axis=-1)]
            batches = []
            for _, batch in divide_chunks(indices_sorted, self.batch_size):
                if len(batch) == self.batch_size:
                    batches.append(batch.tolist())
                else:
                    assert extra_batch is None
                    assert batch is not None
                    extra_batch = batch
    
            # shuffling batches within buckets
            if self.shuffle:
                batches = shuffle(batches)
            for batch in batches:
                new_indices.extend(batch)
    
        if extra_batch is not None:
            new_indices.extend(extra_batch)
        return indices[new_indices]

class NeuralNet(nn.Module):
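    # Frozen fastText embeddings -> spatial dropout -> BiLSTM -> BiGRU; the final GRU hidden
    # state is concatenated with average- and max-pooled GRU outputs before a small dense head.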

    def __init__(self, embedding_matrix):
        super(NeuralNet, self).__init__()

        lstm_hidden_size = 120
        gru_hidden_size = 60
        self.gru_hidden_size = gru_hidden_size

        self.embedding = nn.Embedding(*embedding_matrix.shape)
        self.embedding.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        self.embedding.weight.requires_grad = False
        self.embedding_dropout = nn.Dropout2d(0.2)

        self.lstm = nn.LSTM(embedding_matrix.shape[1], lstm_hidden_size, bidirectional=True, batch_first=True)
        self.gru = nn.GRU(lstm_hidden_size * 2, gru_hidden_size, bidirectional=True, batch_first=True)

        self.linear = nn.Linear(gru_hidden_size * 6, 20)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.out = nn.Linear(20, 1)
        
    def apply_spatial_dropout(self, h_embedding):
        h_embedding = h_embedding.transpose(1, 2).unsqueeze(2)
        h_embedding = self.embedding_dropout(h_embedding).squeeze(2).transpose(1, 2)
        return h_embedding

    def forward(self, x):
        h_embedding = self.embedding(x)
        h_embedding = self.apply_spatial_dropout(h_embedding)

        h_lstm, _ = self.lstm(h_embedding)
        h_gru, hh_gru = self.gru(h_lstm)

        # h_n of a bidirectional GRU has shape (2, batch, hidden); concatenate the two directions
        # per sample (a raw .view(-1, hidden * 2) would interleave different samples)
        hh_gru = torch.cat((hh_gru[0], hh_gru[1]), dim=1)

        avg_pool = torch.mean(h_gru, 1)
        max_pool, _ = torch.max(h_gru, 1)

        conc = torch.cat((hh_gru, avg_pool, max_pool), 1)
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        out = self.out(conc)

        return out

class EMA:
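    # Keeps an exponential moving average of the trainable weights: _update runs every n batches
    # (level='batch') or once per epoch (level='epoch'); set_weights copies the averaged weights
    # into a separate model that is used only for evaluation.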

    def __init__(self, model, mu, level='batch', n=1):
        self.mu = mu
        self.level = level
        self.n = n
        self.cnt = self.n
        # snapshot the current weights as the initial shadow (averaged) weights
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def _update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                new_average = (1 - self.mu) * param.data + self.mu * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def set_weights(self, ema_model):
        for name, param in ema_model.named_parameters():
            if param.requires_grad:
                param.data = self.shadow[name]

    def on_batch_end(self, model):
        if self.level == 'batch':
            self.cnt -= 1
            if self.cnt == 0:
                self._update(model)
                self.cnt = self.n
                
    def on_epoch_end(self, model):
        if self.level == 'epoch':
            self._update(model)

class ParamScheduler:
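    # Sets the learning rate before every batch by evaluating scale_fn at the fraction of
    # training completed (last_batch_iteration / step_size).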
    
    def __init__(self, optimizer, scale_fn, step_size):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        
        self.optimizer = optimizer
        self.scale_fn = scale_fn
        self.step_size = step_size
        self.last_batch_iteration = 0
        
    def batch_step(self):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.scale_fn(self.last_batch_iteration / self.step_size)
        
        self.last_batch_iteration += 1


def combine_scale_functions(scale_fns, phases=None):
    if phases is None:
        phases = [1. / len(scale_fns)] * len(scale_fns)
    phases = [phase / sum(phases) for phase in phases]
    phases = torch.tensor([0] + phases)
    phases = torch.cumsum(phases, 0)
    
    def _inner(x):
        idx = (x >= phases).nonzero().max()
        actual_x = (x - phases[idx]) / (phases[idx + 1] - phases[idx])
        return scale_fns[idx](actual_x)
        
    return _inner


def scale_cos(start, end, x):
    return start + (1 + np.cos(np.pi * (1 - x))) * (end - start) / 2

class JigsawEvaluator:
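    # Competition metric: 0.25 * overall AUC + 0.75 * average of the p = -5 power means of
    # subgroup AUC, BPSN AUC and BNSP AUC across the identity subgroups.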
    
    def __init__(self, y_binary, y_identity_binary, power=-5, overall_model_weight=0.25):
        self.y = y_binary
        self.y_i = y_identity_binary
        self.n_subgroups = self.y_i.shape[1]
        self.power = power
        self.overall_model_weight = overall_model_weight
        
    @staticmethod
    def _compute_auc(y_true, y_pred):
        try:
            return roc_auc_score(y_true, y_pred)
        except ValueError:
            return np.nan
        
    def _compute_subgroup_auc(self, i, y_pred):
        mask = self.y_i[:, i] == 1
        return self._compute_auc(self.y[mask], y_pred[mask])
        
    def _compute_bpsn_auc(self, i, y_pred):
        mask = self.y_i[:, i] + self.y == 1
        return self._compute_auc(self.y[mask], y_pred[mask])
        
    def _compute_bnsp_auc(self, i, y_pred):
        mask = self.y_i[:, i] + self.y != 1
        return self._compute_auc(self.y[mask], y_pred[mask])
        
    def compute_bias_metrics_for_model(self, y_pred):
        records = np.zeros((3, self.n_subgroups))
        for i in range(self.n_subgroups):
            records[0, i] = self._compute_subgroup_auc(i, y_pred)
            records[1, i] = self._compute_bpsn_auc(i, y_pred)
            records[2, i] = self._compute_bnsp_auc(i, y_pred)
        return records
        
    def _calculate_overall_auc(self, y_pred):
        return roc_auc_score(self.y, y_pred)
        
    def _power_mean(self, array):
        total = sum(np.power(array, self.power))
        return np.power(total / len(array), 1 / self.power)
        
    def get_final_metric(self, y_pred):
        bias_metrics = self.compute_bias_metrics_for_model(y_pred)
        bias_score = np.average([
            self._power_mean(bias_metrics[0]),
            self._power_mean(bias_metrics[1]),
            self._power_mean(bias_metrics[2])
        ])
        overall_score = self.overall_model_weight * self._calculate_overall_auc(y_pred)
        bias_score = (1 - self.overall_model_weight) * bias_score
        return overall_score + bias_score


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def eval_model(model, data_loader):
    model.eval()
    preds_fold = np.zeros(len(data_loader.dataset))

    with torch.no_grad():
        for index, x_batch, y_batch in data_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            y_pred = model(x_batch).detach()
            preds_fold[list(index)] = sigmoid(y_pred.cpu().numpy())[:, 0]

    return preds_fold

warnings.filterwarnings('ignore')
seed_torch(seed)

with timer('load data'):
    train_x, train_y, train_y_identity, test_x, nan_pred = load_and_prec()
    train_nan_mask = train_x == '_##_'
    test_nan_mask = test_x == '_##_'
    y_binary = (train_y[:, 0] >= 0.5).astype(int)
    y_identity_binary = (train_y_identity >= 0.5).astype(int)
    vocab = build_vocab(chain(train_x, test_x), max_features)
    embedding_matrix = load_embedding(EMBEDDING_FASTTEXT, vocab['token2id'])

    # dtype=object because the tokenized sequences are ragged (variable length)
    train_x = np.array(tokenize(train_x, vocab), dtype=object)
    test_x = np.array(tokenize(test_x, vocab), dtype=object)

with timer('pseudo label'):
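    # Train one model on all of the training data and use it to produce soft pseudo-labels
    # (test_y) for the test set; these rows are mixed into each CV fold's training data below.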
    train_preds = np.zeros((len(train_x)))
    test_preds = np.zeros((len(test_x)))

    ema_train_preds = np.zeros((len(train_x)))
    ema_test_preds = np.zeros((len(test_x)))

    train_dataset = TextDataset(train_x, targets=train_y, maxlen=max_len)
    test_dataset = TextDataset(test_x, maxlen=max_len)

    train_sampler = BucketSampler(train_dataset, train_dataset.get_keys(),
                                  bucket_size=batch_size * 20, batch_size=batch_size)
    test_sampler = BucketSampler(test_dataset, test_dataset.get_keys(),
                                 batch_size=batch_size, shuffle_data=False)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              sampler=train_sampler, num_workers=0, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler,
                             shuffle=False, num_workers=0, collate_fn=collate_fn)

    models = {}
    model = NeuralNet(embedding_matrix).to(device)

    ema_model = copy.deepcopy(model)
    ema_model.eval()

    ema_n = int(len(train_loader.dataset) / (updates_per_epoch * batch_size))
    ema = EMA(model, mu, n=ema_n)

    # LR schedule: cosine ramp 1e-4 -> 5e-3 over the first 20% of steps, then cosine decay to 1e-3
    scale_fn = combine_scale_functions(
        [partial(scale_cos, 1e-4, 5e-3), partial(scale_cos, 5e-3, 1e-3)], [0.2, 0.8])

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    scheduler = ParamScheduler(optimizer, scale_fn, train_epochs * len(train_loader))

    all_test_preds = []

    for epoch in range(train_epochs):
        start_time = time.time()
        model.train()

        for _, x_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            scheduler.batch_step()
            y_pred = model(x_batch)

            # column 0 of the target is the toxicity label, column 1 the per-sample weight
            loss = nn.BCEWithLogitsLoss(weight=y_batch[:, 1])(y_pred[:, 0], y_batch[:, 0])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ema.on_batch_end(model)

        elapsed_time = time.time() - start_time
        print('Epoch {}/{} \t time={:.2f}s'.format(
            epoch + 1, train_epochs, elapsed_time))

        # keep per-epoch predictions in all_test_preds; do not clobber the global test_preds accumulator
        all_test_preds.append(eval_model(model, test_loader))

        ema.on_epoch_end(model)

    ema.set_weights(ema_model)
    ema_model.lstm.flatten_parameters()
    ema_model.gru.flatten_parameters()

    # average the per-epoch checkpoints with exponentially increasing weights (2 ** epoch)
    checkpoint_weights = np.array([2 ** epoch for epoch in range(train_epochs)])
    checkpoint_weights = checkpoint_weights / checkpoint_weights.sum()

    ema_test_y = eval_model(ema_model, test_loader)
    test_y = np.average(all_test_preds, weights=checkpoint_weights, axis=0)
    test_y = np.mean([test_y, ema_test_y], axis=0)
    test_y[test_nan_mask] = nan_pred
    weight = np.ones((len(test_y)))
    test_y = np.vstack((test_y, weight)).T

    models['model'] = model.state_dict()
    models['ema_model'] = ema_model.state_dict()

with timer('train'):
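    # 5-fold StratifiedKFold on the binarized target; each fold trains on its training split plus
    # a disjoint slice of the pseudo-labeled test data and is scored with the bias-aware metric.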
    splits = list(
        StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed).split(train_x, y_binary))
    splits_test = list(KFold(n_splits=n_splits, shuffle=True, random_state=seed).split(test_x))

    for fold, ((train_idx, valid_idx), (train_idx_test, _)) in enumerate(zip(splits, splits_test)):
        print(f'Fold {fold + 1}')

        x_train_fold = np.concatenate((train_x[train_idx], test_x[train_idx_test]), axis=0)
        y_train_fold = np.concatenate((train_y[train_idx], test_y[train_idx_test]), axis=0)

        x_valid_fold = train_x[valid_idx]
        y_valid_fold = train_y[valid_idx]

        valid_nan_mask = train_nan_mask[valid_idx]

        y_valid_fold_binary = y_binary[valid_idx]
        y_valid_fold_identity_binary = y_identity_binary[valid_idx]
        evaluator = JigsawEvaluator(y_valid_fold_binary, y_valid_fold_identity_binary)

        train_dataset = TextDataset(x_train_fold, targets=y_train_fold, maxlen=max_len)
        valid_dataset = TextDataset(x_valid_fold, targets=y_valid_fold, maxlen=max_len)

        train_sampler = BucketSampler(train_dataset, train_dataset.get_keys(),
                                      bucket_size=batch_size * 20, batch_size=batch_size)
        valid_sampler = BucketSampler(valid_dataset, valid_dataset.get_keys(),
                                      batch_size=batch_size, shuffle_data=False)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                                  sampler=train_sampler, num_workers=0, collate_fn=collate_fn)
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                                  sampler=valid_sampler, collate_fn=collate_fn)

        model = NeuralNet(embedding_matrix).to(device)

        ema_model = copy.deepcopy(model)
        ema_model.eval()

        ema_n = int(len(train_loader.dataset) / (updates_per_epoch * batch_size))
        ema = EMA(model, mu, n=ema_n)

        scale_fn = combine_scale_functions(
            [partial(scale_cos, 1e-4, 5e-3), partial(scale_cos, 5e-3, 1e-3)], [0.2, 0.8])

        optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
        scheduler = ParamScheduler(optimizer, scale_fn, train_epochs * len(train_loader))

        all_valid_preds = []
        all_test_preds = []

        for epoch in range(train_epochs):
            start_time = time.time()
            model.train()

            for _, x_batch, y_batch in train_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                scheduler.batch_step()
                y_pred = model(x_batch)

                loss = nn.BCEWithLogitsLoss(weight=y_batch[:, 1])(y_pred[:, 0], y_batch[:, 0])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ema.on_batch_end(model)

            valid_preds = eval_model(model, valid_loader)
            valid_preds[valid_nan_mask] = nan_pred
            all_valid_preds.append(valid_preds)

            auc_score = evaluator.get_final_metric(valid_preds)
            elapsed_time = time.time() - start_time
            print('Epoch {}/{} \t auc={:.5f} \t time={:.2f}s'.format(
                epoch + 1, train_epochs, auc_score, elapsed_time))

            all_test_preds.append(eval_model(model, test_loader))

            models[f'model_{fold}{epoch}'] = model.state_dict()

            ema.on_epoch_end(model)

        ema.set_weights(ema_model)
        ema_model.lstm.flatten_parameters()
        ema_model.gru.flatten_parameters()

        models[f'ema_model_{fold}'] = ema_model.state_dict()

        checkpoint_weights = np.array([2 ** epoch for epoch in range(train_epochs)])
        checkpoint_weights = checkpoint_weights / checkpoint_weights.sum()

        valid_preds_fold = np.average(all_valid_preds, weights=checkpoint_weights, axis=0)
        valid_preds_fold[valid_nan_mask] = nan_pred
        auc_score = evaluator.get_final_metric(valid_preds_fold)
        print(f'cv model \t auc={auc_score:.5f}')

        ema_valid_preds_fold = eval_model(ema_model, valid_loader)
        ema_valid_preds_fold[valid_nan_mask] = nan_pred
        auc_score = evaluator.get_final_metric(ema_valid_preds_fold)
        print(f'EMA model \t auc={auc_score:.5f}')

        train_preds[valid_idx] = valid_preds_fold
        ema_train_preds[valid_idx] = ema_valid_preds_fold

        test_preds_fold = np.average(all_test_preds, weights=checkpoint_weights, axis=0)
        ema_test_preds_fold = eval_model(ema_model, test_loader)

        test_preds += test_preds_fold / n_splits
        ema_test_preds += ema_test_preds_fold / n_splits

torch.save(models, 'model.pt')
test_preds[test_nan_mask] = nan_pred
ema_test_preds[test_nan_mask] = nan_pred
evaluator = JigsawEvaluator(y_binary, y_identity_binary)
auc_score = evaluator.get_final_metric(train_preds)
ema_auc_score = evaluator.get_final_metric(ema_train_preds)
print(f'cv score: {auc_score:<8.5f}')
print(f'EMA cv score: {ema_auc_score:<8.5f}')

train_preds = np.mean([train_preds, ema_train_preds], axis=0)
test_preds = np.mean([test_preds, ema_test_preds], axis=0)
auc_score = evaluator.get_final_metric(train_preds)
print(f'final prediction score: {auc_score:<8.5f}')

submission = pd.read_csv(SAMPLE_SUBMISSION, index_col='id')
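# blend: 90% CV-ensemble predictions, 10% pseudo-label model predictions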
submission['prediction'] = test_preds * 0.9 + test_y[:, 0] * 0.1
submission.reset_index(drop=False, inplace=True)
submission.to_csv('submission.csv', index=False)
submission.head()
