#!/usr/bin/env python
# coding: utf-8
# In[3]:
from scipy import fft
from scipy.io import wavfile
from matplotlib.pyplot import specgram
import matplotlib.pyplot as plt
import numpy as np
# Load the first blues clip and report its sample rate and array shape.
blues_clip = "D://genres//blues//converted//blues.00000.au.wav"
sample_rate, X = wavfile.read(blues_clip)
print(sample_rate, X.shape)
# In[5]:
# Draw the clip's spectrogram on a labelled, gridded 10x4 inch canvas.
plt.figure(dpi=80, figsize=(10, 4))
plt.grid(True, color='0.75', linestyle='-')
plt.ylabel("frequency")
plt.xlabel("time")
# xextent pins the horizontal axis of the image to the range 0..30.
specgram(X, Fs=sample_rate, xextent=(0, 30))
plt.show()
# In[10]:
def plotSpec(g, n):
    """Draw the spectrogram of one genre clip and title the panel.

    g is the genre folder name and n is the track number as a string
    (e.g. '00001'); both are substituted into the clip's file path.
    The title is the genre followed by the last character of n.
    """
    clip_path = f"D://genres//{g}//converted//{g}.{n}.au.wav"
    rate, samples = wavfile.read(clip_path)
    specgram(samples, Fs=rate, xextent=(0, 30))
    plt.title("-".join((g, n[-1])))
# Overview figure: one row per genre, three clips per row (6 x 3 panels).
plt.figure(num=None, figsize=(18, 9), dpi=80, facecolor='w', edgecolor='k')
panel = 0
for genre_name in ("classical", "jazz", "country", "pop", "rock", "metal"):
    for track in ('00001', '00002', '00003'):
        # Panels are numbered 1..18, filling each row left to right.
        panel += 1
        plt.subplot(6, 3, panel)
        plotSpec(genre_name, track)
plt.tight_layout(pad=0.4, w_pad=0, h_pad=1.0)
plt.show()
# In[13]:
# Show one metal clip two ways: as a spectrogram and as a frequency spectrum.
sample_rate, X = wavfile.read("D://genres//metal//converted//metal.00000.au.wav")
plt.figure(num=None, figsize=(9,6), dpi=60, facecolor="w", edgecolor='k')
# Top panel: spectrogram (time on x, frequency on y).
plt.subplot(2, 1, 1)
plt.xlabel("time")
plt.ylabel("frequency")
specgram(X, Fs=sample_rate, xextent=(0, 30))  # x axis spans 30 seconds
# Bottom panel: amplitude per frequency, limited to the first 3000 bins.
plt.subplot(2,1,2)
plt.xlabel("frequency")
plt.xlim((0, 3000))
plt.ylabel("amplitude")
# The FFT maps the time-domain signal to the frequency domain. Using
# n=sample_rate transforms one second of samples, so bins are 1 Hz apart and
# the bin index reads directly as a frequency in Hz.
# The output is complex: plot its magnitude, otherwise matplotlib silently
# drops the imaginary part and the curve is not an amplitude.
# np.fft.fft is used because `from scipy import fft` yields a module (not a
# callable) on current SciPy releases.
plt.plot(np.abs(np.fft.fft(X, sample_rate)))
plt.show()
# # Feature extraction: apply the Fourier transform to every music file and save the transformed data to disk
# In[22]:
def create_file(g, b, base_dir="D:/genres"):
    """Compute FFT features for one clip and save them next to the audio.

    Reads ``<base_dir>/<g>/converted/<g>.<b zero-padded to 5>.au.wav`` and
    writes the magnitudes of its first 1000 FFT coefficients to the same
    stem with a ``.fft`` suffix (``np.save`` appends ``.npy``).

    Args:
        g: Genre name; used as both the folder name and the file prefix.
        b: Track number within the genre.
        base_dir: Root folder of the dataset. Defaults to the location the
            rest of this script uses.
    """
    # Build the stem from the `b` argument; the previous version read a
    # global `n`, which only worked when called from a loop using that name.
    stem = f"{base_dir}/{g}/converted/{g}.{str(b).zfill(5)}"
    sample_rate, X = wavfile.read(stem + '.au.wav')
    # Keep only the first 1000 coefficients as the feature vector (the
    # original author's note says components above that are not audible).
    # np.fft.fft is used because `from scipy import fft` yields a module
    # (not a callable) on current SciPy releases.
    fft_features = np.abs(np.fft.fft(X)[:1000])
    np.save(stem + '.fft', fft_features)
# In[23]:
# Genres to process. A genre's position in this list is also used further
# down as its numeric class label.
genre_list = ["classical", "jazz", "country", "pop", "rock", "metal"]
# Write the FFT feature file for tracks 0..99 of every genre.
for g in genre_list:
    for n in range(100):
        create_file(g, n)
# # Model training
# In[24]:
import pickle
# In[28]:
# Load the saved FFT features and arrange them as the X (features) and
# y (labels) arrays needed for training.
X = []
y = []
for label, g in enumerate(genre_list):
    for i in range(100):
        # Index the file with the loop variable `i`. The previous version
        # used a leftover global `n`, so it loaded the same track 100 times
        # per genre.
        rad = "D:/genres/"+g+"/converted/"+g+'.'+str(i).zfill(5)+'.fft.npy'
        fft_features = np.load(rad)
        X.append(fft_features)
        # The class label is the genre's position in genre_list.
        y.append(label)
X = np.array(X)
y = np.array(y)
# In[32]:
from sklearn.linear_model import LogisticRegression
from pprint import pprint
# In[30]:
# Fit a logistic-regression classifier on the FFT features and persist it.
model = LogisticRegression()
model.fit(X, y)
# The context manager closes model.pkl even if pickling raises.
with open("model.pkl", "wb") as out_put:
    pickle.dump(model, out_put)
# In[33]:
# Reload the persisted classifier and print its configuration.
# NOTE: pickle.load can execute arbitrary code; only load a model.pkl that
# this script wrote itself.
with open("model.pkl", "rb") as pkl_file:
    model_loaded = pickle.load(pkl_file)
pprint(model_loaded)
# In[50]:
# Classify a single clip with the reloaded model.
music_path = r"D:\genres\country\converted\country.00002.au.wav"
sample_rate, X = wavfile.read(music_path)
# Same feature recipe as the training files: magnitudes of the first 1000
# FFT coefficients. np.fft.fft is used because `from scipy import fft` yields
# a module (not a callable) on current SciPy releases.
test_fft_features = np.abs(np.fft.fft(X)[:1000])
predict = model_loaded.predict([test_fft_features])[0]
# Print the numeric label and its genre name; a bare expression is only
# displayed inside a notebook and does nothing when run as a script.
print(predict, genre_list[predict])
# In[41]: