1 import numpy as np
2
3
class matrix_factorization():
    """Dense matrix factorization trained by per-cell gradient descent.

    Approximates ``data`` (users x items) by the product of
    ``userFeatures`` (users x numOfFeatures) and ``itemFeatures``
    (numOfFeatures x items). Every cell of ``data`` is used as a training
    target; there is no masking of missing entries.
    """

    def __init__(self, data: np.ndarray, numOfFeatures: int = 2) -> None:
        """Store the target matrix and initialise both factor matrices to ones.

        Args:
            data: 2-D array; rows are users, columns are items.
            numOfFeatures: Size of the latent vectors.
        """
        self.data = data
        self.numOfUser = data.shape[0]
        self.numOfItem = data.shape[1]
        # Size of the latent vectors shared by users and items.
        self.numOfFeatures = numOfFeatures
        self.userFeatures = np.ones((self.numOfUser, self.numOfFeatures))
        self.itemFeatures = np.ones((self.numOfFeatures, self.numOfItem))

        # Default learning rate; train() overwrites it.
        self.lr = 0.1

    def prediction(self) -> np.ndarray:
        """Return the reconstructed matrix ``userFeatures @ itemFeatures``."""
        return np.dot(self.userFeatures, self.itemFeatures)

    def MSE(self, i: int, j: int) -> float:
        """Return the signed residual ``data[i][j] - prediction[i][j]``.

        Despite the name this is NOT squared: the update rules need the
        signed error, because the gradient of the squared error with respect
        to a factor is ``-2 * residual * other_factor``. Use ``loss()`` for
        the actual mean squared error.
        """
        user_row = self.userFeatures[i, :]  # latent vector of user i
        item_col = self.itemFeatures[:, j]  # latent vector of item j
        return self.data[i][j] - np.dot(user_row, item_col)

    def loss(self) -> float:
        """Return the mean squared error over every cell of ``data``.

        Residuals are squared before averaging so that positive and negative
        errors cannot cancel each other out. Returns 0.0 for an empty matrix.
        """
        residuals = self.data - self.prediction()
        if residuals.size == 0:
            return 0.0
        return float(np.mean(np.square(residuals)))

    def updateUserFeature(self, i: int, k: int, j: int):
        """Take one gradient step on ``userFeatures[i][k]`` using cell (i, j)."""
        self.userFeatures[i][k] += (
            2 * self.lr * self.MSE(i, j) * self.itemFeatures[k, j]
        )

    def updateItemFeature(self, i: int, k: int, j: int):
        """Take one gradient step on ``itemFeatures[k][j]`` using cell (i, j)."""
        self.itemFeatures[k][j] += (
            2 * self.lr * self.MSE(i, j) * self.userFeatures[i, k]
        )

    def updateUser(self, i: int, j: int):
        """Update every latent feature of user ``i`` from cell (i, j).

        The residual is recomputed for each feature, so later features see
        the effect of the earlier updates.
        """
        for k in range(self.numOfFeatures):
            self.updateUserFeature(i, k, j)

    def updateItem(self, i: int, j: int):
        """Update every latent feature of item ``j`` from cell (i, j).

        The residual is recomputed for each feature, so later features see
        the effect of the earlier updates.
        """
        for k in range(self.numOfFeatures):
            self.updateItemFeature(i, k, j)

    def update(self):
        """Run one full pass: for each cell update the user, then the item."""
        for i in range(self.numOfUser):
            for j in range(self.numOfItem):
                self.updateUser(i, j)
                self.updateItem(i, j)

    def train(self, lr=0.1, iteration=1000):
        """Run ``iteration`` full passes with learning rate ``lr``.

        Stores ``lr`` on the instance and prints the current loss before
        every 100th pass (including pass 0).
        """
        self.lr = lr
        for i in range(iteration):
            if i % 100 == 0:
                print(f"Echo {i},\tMSE={self.loss()}")
            self.update()
54
if __name__ == '__main__':
    # Demo: factorize a small 2x2 matrix (second row is twice the first)
    # with a single latent feature, then show the learned factors, the
    # reconstruction and the original data.
    ratings = np.array([[2, 1], [4, 2]])
    model = matrix_factorization(ratings, 1)
    model.train(iteration=1000)

    print(model.userFeatures, model.itemFeatures, sep='\n')
    print(model.prediction())
    print(model.data)
62
63 '''
Output (recorded from a sample run):
65 Echo 0, MSE=5.0
66 Echo 100, MSE=0.0
67 Echo 200, MSE=0.0
68 Echo 300, MSE=0.0
69 Echo 400, MSE=0.0
70 Echo 500, MSE=0.0
71 Echo 600, MSE=0.0
72 Echo 700, MSE=0.0
73 Echo 800, MSE=0.0
74 Echo 900, MSE=0.0
75 [[1.02102854]
76 [2.04205708]]
77
78 [[1.9588091 0.97940455]]
79
80 [[2. 1.]
81 [4. 2.]]
82
83 [[2 1]
84 [4 2]]
85 '''