word2vec: a skip-gram implementation

A minimal PyTorch skip-gram model on a toy corpus: each center word (as a one-hot vector) is trained to predict the words within window_size positions on either side of it.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm

# toy corpus: short "sentences" of space-separated words
sentences = ["jack like dog", "jack like cat", "jack like animal",
             "dog cat animal", "banana apple cat dog like", "dog fish milk like",
             "dog cat animal like", "jack like apple", "apple like", "jack like banana",
             "apple banana jack movie book music like", "cat dog hate", "cat dog like"]

# Flatten the corpus and build word <-> index mappings
sentences_list = " ".join(sentences).split()
vocab = list(set(sentences_list))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for i, w in enumerate(vocab)}
vocab_size = len(vocab)

# hyperparameters
window_size = 2      # context words taken on each side of the center word
embedding_size = 2   # 2-D embeddings so they can be visualized directly
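
For this toy corpus the vocabulary ends up with 13 unique words, so every center word becomes a 13-dimensional one-hot vector. A quick sanity check (a minimal sketch, just printing the mappings built above):

print(vocab_size)                  # 13 unique words
print(word2idx["jack"])            # an integer id in [0, vocab_size); order varies because set() is unordered
print(idx2word[word2idx["jack"]])  # "jack"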

def make_data(sentences):
    # Flatten the corpus into one long word sequence
    word_sequence = " ".join(sentences).split()
    skip_grams = []
    for idx in range(window_size, len(word_sequence) - window_size):
        # center word, encoded as a one-hot vector
        center = word2idx[word_sequence[idx]]
        # context positions: window_size words on each side of the center
        context_idx = list(range(idx - window_size, idx)) + list(range(idx + 1, idx + window_size + 1))
        for ctx in context_idx:
            skip_grams.append([np.eye(vocab_size)[center], word2idx[word_sequence[ctx]]])

    input_data = []
    target_data = []
    for one_hot, target in skip_grams:
        input_data.append(one_hot)
        target_data.append(target)
    return torch.FloatTensor(np.array(input_data)), torch.LongTensor(target_data)
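
With window_size = 2, each valid center position yields 4 (center, context) pairs: the two words before it and the two words after it. A quick check of the output (a sketch reusing the globals defined above):

inputs, targets = make_data(sentences)
print(inputs.shape)    # [num_pairs, vocab_size] one-hot center words
print(targets.shape)   # [num_pairs] context word indices
print(idx2word[int(inputs[0].argmax())], "->", idx2word[int(targets[0])])  # first (center, context) pair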

class my_dataset(Dataset):
    def __init__(self, input_data, target_data):
        super(my_dataset, self).__init__()
        self.input_data = input_data
        self.target_data = target_data

    def __getitem__(self, index):
        return self.input_data[index], self.target_data[index]

    def __len__(self):
        return self.input_data.size(0)  # number of (center, context) pairs
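
my_dataset only implements the two methods DataLoader needs: __getitem__ returns one (one-hot center, context index) pair and __len__ reports how many pairs exist. A small indexing check (a sketch, assuming make_data and the globals above):

ds = my_dataset(*make_data(sentences))
print(len(ds))                 # total number of training pairs
one_hot, target = ds[0]
print(one_hot.shape, target)   # torch.Size([vocab_size]) and a scalar index tensor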
    
# Input: the one-hot center word; output: scores over possible context words
class SkipGram(nn.Module):
    def __init__(self, embedding_size):
        super(SkipGram, self).__init__()
        self.embedding_size = embedding_size
        # vocab_size -> embedding_size: a one-hot input selects a single column, i.e. an embedding lookup
        self.fc1 = torch.nn.Linear(vocab_size, self.embedding_size)
        # embedding_size -> vocab_size: one score (logit) per word in the vocabulary
        self.fc2 = torch.nn.Linear(self.embedding_size, vocab_size)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, center, context):
        """
        :param center: [batch_size, vocab_size] one-hot center words
        :param context: [batch_size] context word indices
        :return: cross-entropy loss over the batch
        """
        center = self.fc1(center)   # [batch_size, embedding_size]
        center = self.fc2(center)   # [batch_size, vocab_size] logits
        loss = self.loss(center, context)
        return loss
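
nn.CrossEntropyLoss applies log-softmax internally, so fc2 outputs raw logits. A quick shape/loss check on a dummy batch (a sketch, assuming the globals above are already defined):

m = SkipGram(embedding_size)
dummy_center = torch.eye(vocab_size)[:3]             # [3, vocab_size] one-hot rows
dummy_context = torch.randint(0, vocab_size, (3,))   # [3] random context indices
print(m(dummy_center, dummy_context))                # scalar loss, roughly log(vocab_size) before training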
    
batch_size = 2
center_data, context_data = make_data(sentences)
train_data = my_dataset(center_data, context_data)
train_loader = DataLoader(train_data, batch_size, shuffle=True)
epochs = 5
model = SkipGram(embedding_size=embedding_size)
model.train()

optim = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
    loop = tqdm(enumerate(train_loader), total = len(train_loader))
    for index, (center, context) in loop:
        loss = model(center, context)
        loop.set_description(f'Epoch [{epoch}/{epochs}]')
        loop.set_postfix(loss = loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
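
After training, the learned embeddings live in fc1.weight, which has shape [embedding_size, vocab_size]: the embedding of word i is column i, because the one-hot input selects exactly that column. A minimal sketch for inspecting and plotting them (matplotlib is an extra dependency not used above; plotting works here because embedding_size = 2):

import matplotlib.pyplot as plt

W = model.fc1.weight.data.T   # [vocab_size, embedding_size]
for i, word in idx2word.items():
    x, y = W[i, 0].item(), W[i, 1].item()
    plt.scatter(x, y)
    plt.annotate(word, (x, y))
plt.show()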

  
