Music genre classifier¶
Neural networks as classifiers¶
This notebook should be seen as the third step in a series of notebooks aimed at building an ML audio classifier.
We continue our journey of music classification by training more complex models, such as CNNs and RNNs. After this, we will see how our results compare against a pretrained model. You can find the other notebooks for this experiment here:
Goal¶
Train simple neural net classifiers to predict the genre of a song.
Dataset¶
The dataset contains 1000 audio tracks each 30 seconds long. It contains 10 genres, each represented by 100 tracks. The tracks were all 22050Hz Mono 16-bit audio files in .wav format.
In preprocess.py, we convert the .wav files to MFCC features and store them as PyTorch tensors (mfcc.pt). Labels and file paths are stored as NumPy arrays.
Source¶
https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification/ (accessed 2023-10-20)
import os
os.chdir("/Users/per.morten.halvorsen@schibsted.com/personal/music-genre-classifiers/")
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
from skorch import NeuralNetClassifier
from tqdm import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
pio.renderers.default = "notebook"
torch.manual_seed(42)
<torch._C.Generator at 0x1568ad950>
Load data¶
mfcc_tensor = torch.load("data/mfcc.pt")
covariance_tensor = torch.load("data/covariance.pt")
file_paths = np.load("data/file_paths.npy")
labels = np.load("data/labels.npy")
mfcc_tensor.shape
torch.Size([999, 2986, 13])
labels.shape
(999,)
# for plotting
file_paths.shape
(999,)
labels_to_idx = {label: idx for idx, label in enumerate(np.unique(labels))}
idx_to_labels = {idx: label for idx, label in enumerate(np.unique(labels))}
labels_to_idx
{'blues': 0, 'classical': 1, 'country': 2, 'disco': 3, 'hiphop': 4, 'jazz': 5, 'metal': 6, 'pop': 7, 'reggae': 8, 'rock': 9}
Build simple NN classifiers¶
We limit the scope of this notebook to the most basic NN architectures, namely MLPs, CNNs and RNNs.
MLP¶
A multi-layer perceptron (MLP) is a feedforward neural network (FNN) made of fully connected layers, each followed by an activation function. MLPs can be considered the most "vanilla" neural networks: without the hidden layers and their non-linearities, they reduce to the simplest possible architecture, a linear classifier.
Architecture¶
The MLP consists of an input layer, one or more hidden layers and an output layer. The input layer is the same size as the input data, and the output layer is the same size as the number of classes. A hidden layer can be of any size; in this notebook it is much smaller than the input (38,818 values per track) and larger than the output (10 classes). The hidden layers are where the "magic" happens, as this is where the network learns the intermediate representations it uses to classify the data.
Training¶
The MLP is trained using backpropagation, which is an algorithm for computing the gradient of the loss function with respect to the weights of the network. The loss function is usually the cross-entropy loss, which is a measure of the difference between the predicted and the actual class. The weights are updated using gradient descent, which is an algorithm for minimizing the loss function.
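To make the training description concrete, here is a minimal NumPy sketch of the cross-entropy loss and a plain gradient descent update for a single linear layer. It is an illustration on random toy data, not the training code of this notebook (which relies on PyTorch autograd and the Adam optimizer); the helper names `softmax`, `cross_entropy` and `gradient_step` are made up for this example.

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def cross_entropy(logits, targets):
    # Mean negative log-probability assigned to the correct class
    probs = softmax(logits)
    return -np.log(probs[np.arange(len(targets)), targets]).mean()

def gradient_step(W, X, targets, lr=0.1):
    # Gradient of the mean cross-entropy w.r.t. W for logits = X @ W
    probs = softmax(X @ W)
    probs[np.arange(len(targets)), targets] -= 1.0
    grad = X.T @ probs / len(targets)
    return W - lr * grad

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))           # toy inputs, not the MFCC data
targets = rng.integers(0, 3, size=32)  # toy labels for 3 classes
W = np.zeros((5, 3))

loss_before = cross_entropy(X @ W, targets)
for _ in range(50):
    W = gradient_step(W, X, targets)
loss_after = cross_entropy(X @ W, targets)
print(loss_before, loss_after)
```

With all-zero weights every class gets probability 1/3, so the starting loss is log(3); the repeated updates then push it down.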
Activation functions¶
The activation function is a non-linear function that is applied to the output of each neuron in the network. The non-linearity is what lets an MLP act as a universal function approximator: without it, any stack of linear layers collapses into a single linear map. Activations are applied after each hidden layer; the output layer is usually left linear, because the cross-entropy loss applies the softmax itself. Common choices are the sigmoid, the hyperbolic tangent (tanh) and the rectified linear unit (ReLU). The models in this notebook use ReLU.
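As a small illustration of the activation functions named above, the following NumPy sketch evaluates sigmoid, tanh and ReLU on a toy vector. It also checks that two linear layers without an activation in between are equivalent to a single linear layer. The weight matrices are random placeholders, not weights from this notebook.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def relu(x):
    return np.maximum(0.0, x)

x = np.array([-2.0, 0.0, 2.0])
activations = {"sigmoid": sigmoid(x), "tanh": np.tanh(x), "relu": relu(x)}

# Two linear layers without an activation are equivalent to one linear layer
rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(4, 3)), rng.normal(size=(2, 4))
stacked = W2 @ (W1 @ x)
collapsed = (W2 @ W1) @ x

# Inserting a ReLU between the layers breaks that equivalence in general
with_relu = W2 @ relu(W1 @ x)
```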
features = 13
measurements = 2986
input_size = measurements * features # Flattened input: frames * MFCC coefficients per frame
hidden_size = 128 # Number of neurons in the hidden layer
num_classes = 10 # Number of music genres
criterion = nn.CrossEntropyLoss()
# Hyperparameters
num_epochs = 10
batch_size = 100
learning_rate = 0.001
rnn_layers = 2
class FFN(nn.Module):
    def __init__(self, input_size=input_size, hidden_size=hidden_size, num_classes=num_classes, num_layers=2):
        super(FFN, self).__init__()
        self.fc_first = nn.Linear(input_size, hidden_size)
        self.fc_last = nn.Linear(hidden_size, num_classes)
        self.num_layers = num_layers
        if num_layers > 1:
            self.fc_hidden = nn.ModuleList()
            for i in range(num_layers - 1):
                self.fc_hidden.append(nn.Linear(hidden_size, hidden_size))

    def forward(self, x):
        x = F.relu(self.fc_first(x))
        # hidden layers
        if self.num_layers > 1:
            for i in range(self.num_layers - 1):
                x = F.relu(self.fc_hidden[i](x))
        # output layer
        x = self.fc_last(x)
        return x
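To get a feel for the size of this model, the sketch below counts the trainable parameters of the FFN layout above in plain Python. It assumes a hidden size of 640 (128*5) and num_layers=2, the values used in the training cell later in this notebook; the helper names are made up for this example.

```python
def linear_params(n_in, n_out):
    # Weight matrix plus one bias per output unit
    return n_in * n_out + n_out

def ffn_params(input_size, hidden_size, num_classes, num_layers):
    total = linear_params(input_size, hidden_size)                       # fc_first
    total += (num_layers - 1) * linear_params(hidden_size, hidden_size)  # fc_hidden
    total += linear_params(hidden_size, num_classes)                     # fc_last
    return total

n_params = ffn_params(input_size=2986 * 13, hidden_size=128 * 5, num_classes=10, num_layers=2)
print(n_params)  # 25260810
```

Roughly 25 million parameters against 719 training tracks leaves a lot of room for memorising the training set, which is worth keeping in mind when comparing train and test losses.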
CNN¶
A convolutional neural network (CNN) is a feedforward neural network (FNN) built around convolutional layers, which slide small learned filters over the input. CNNs are particularly well suited for image classification, as they are able to learn spatial features. This means they may perform better on spectrogram transformations of our data than on the MFCC simplifications.
Architecture¶
A typical CNN consists of an input layer, one or more convolutional layers, pooling layers, optional fully connected hidden layers and an output layer. The input layer is the same size as the input data, and the output layer is the same size as the number of classes. The convolutional layers are where the "magic" happens, as this is where the network learns local features. A pooling layer reduces the dimensionality of the data and is usually placed after a convolutional layer. The model below keeps things minimal: one convolutional layer, a ReLU, a max-pooling layer and a single fully connected output layer.
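The size of the fully connected layer depends on how the convolution and pooling layers change the input shape. As a worked example, the sketch below applies the standard output-size formula to a 2986 x 13 MFCC matrix, assuming a 1x1 convolution with 32 output channels followed by 2x2 max pooling with stride 2 (the settings of the CNN class in this notebook). The helper `conv_out` is a name made up for this example.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Standard output-size formula for convolution and pooling layers
    return (size + 2 * padding - kernel) // stride + 1

frames, mfccs, out_channels = 2986, 13, 32

# 1x1 convolution with stride 1 keeps the spatial size
h, w = conv_out(frames, kernel=1), conv_out(mfccs, kernel=1)
# 2x2 max pooling with stride 2 halves each dimension (rounding down)
h, w = conv_out(h, kernel=2, stride=2), conv_out(w, kernel=2, stride=2)

flat_features = out_channels * h * w
print(h, w, flat_features)  # 1493 6 286656
```

The result equals `out_channels * measurements * 3`, because 1493 * 6 is the same as 2986 * 3.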
Training¶
The CNN is trained in the same way as the MLP: backpropagation computes the gradient of the cross-entropy loss with respect to the weights, and a gradient-based optimizer (Adam in this notebook) updates them.
Activation functions¶
As in the MLP, a non-linear activation function is applied to the output of the convolutional layer so that the network can learn non-linear functions. The CNN below uses ReLU.
class CNN(nn.Module):
    def __init__(self, num_channels=13, num_classes=10, out_channels=32, measurements=2986, verbose=False):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # With 13 MFCC features, 2x2 pooling turns (measurements, 13) into (measurements // 2, 6),
        # i.e. measurements * 3 values per channel when measurements is even.
        self.fc = nn.Linear(out_channels * measurements * 3, num_classes)  # Adjust the input size based on your data
        self.measurements = measurements
        self.verbose = verbose

    def _log(self, name, x):
        # Print intermediate tensor shapes when verbose is enabled
        if self.verbose:
            print(name, x.shape)

    def forward(self, x):
        self._log("input", x)
        x = self.conv1(x)
        self._log("conv1", x)
        x = self.relu(x)
        self._log("relu", x)
        x = self.maxpool(x)
        self._log("maxpool", x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.fc(x)
        self._log("fc", x)
        if self.verbose:
            print("-" * 10)
        return x
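One detail of the model above is worth spelling out: a convolution with kernel_size=1 and a single input channel (the setting used when this CNN is trained later in the notebook) never looks at neighbouring frames or coefficients. Each output channel is just a scaled and shifted copy of the input. The NumPy sketch below reproduces that computation on toy data; it is an illustration of the arithmetic, not a call into PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))    # toy single-channel "image" (frames x MFCCs)
weight = rng.normal(size=5)    # one scalar weight per output channel
bias = rng.normal(size=5)      # one bias per output channel

# 1x1 convolution with one input channel: out[c, i, j] = weight[c] * x[i, j] + bias[c]
out = weight[:, None, None] * x[None, :, :] + bias[:, None, None]
print(out.shape)  # (5, 4, 3)
```

A larger kernel (for example 3x3) would be needed for the convolution itself to pick up local time-frequency patterns; in the model above only the pooling step combines neighbouring values.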
RNN¶
A recurrent neural network (RNN) is a class of neural network that, unlike a feedforward network, processes its input one time step at a time and carries a hidden state from each step to the next. RNNs are particularly well suited for time series classification, as they are able to learn temporal features. This means they may perform better on raw audio data than on spectrogram transformations.
Architecture¶
An RNN consists of an input layer, one or more recurrent layers and an output layer. The input at each time step has the size of one frame (13 MFCC coefficients here), and the output layer is the same size as the number of classes. The recurrent layer is where the "magic" happens: it updates a hidden state frame by frame, and the final hidden state summarises the whole track. The hidden state can be of any size. The model below feeds the final hidden state of a GRU layer into a single fully connected output layer.
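The recurrence can be sketched in a few lines of NumPy. Note that this example uses a plain tanh (Elman-style) recurrent cell for simplicity, whereas the model in this notebook uses a GRU, a gated variant of the same idea. All weights and the toy sequence are random placeholders, and `rnn_last_hidden` is a name made up for this example.

```python
import numpy as np

def rnn_last_hidden(x, W_x, W_h, b):
    # x has shape (time_steps, input_size); returns the final hidden state
    h = np.zeros(W_h.shape[0])
    for x_t in x:
        # The same weights are reused at every time step
        h = np.tanh(W_x @ x_t + W_h @ h + b)
    return h

rng = np.random.default_rng(0)
input_size, hidden_size, num_classes = 13, 15, 10
x = rng.normal(size=(20, input_size))  # toy sequence of 20 frames
W_x = rng.normal(scale=0.1, size=(hidden_size, input_size))
W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)
W_out = rng.normal(scale=0.1, size=(num_classes, hidden_size))

h_last = rnn_last_hidden(x, W_x, W_h, b)
logits = W_out @ h_last  # one score per genre
print(h_last.shape, logits.shape)  # (15,) (10,)
```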
Training¶
The RNN is trained with backpropagation through time: the network is unrolled over the time steps, and the gradient of the cross-entropy loss is propagated back through every step. The weights are then updated with a gradient-based optimizer (Adam in this notebook).
Activation functions¶
Recurrent layers have their activation functions built in. A plain RNN cell typically uses tanh, while gated cells such as the LSTM and the GRU combine sigmoid gates with a tanh candidate state. No extra activation is applied before the output layer in the model below.
class RNN(nn.Module):
    def __init__(self, input_size=13, hidden_size=15, num_layers=2, num_classes=10, verbose=False):
        super(RNN, self).__init__()
        # num_layers is accepted for API compatibility but not used: the GRU is single-layer
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
        self.verbose = verbose

    def forward(self, x):
        # Forward pass through the GRU layer; x has shape (batch, time, features)
        gru_out, _ = self.gru(x)
        # Use the output at the last time step as the input for the fully connected layer
        output = self.fc(gru_out[:, -1, :])
        return output
Train test split¶
# Reshape the data into a 2D array (num_samples, num_features)
num_samples, num_frames, num_mfcc = mfcc_tensor.shape
mfcc_tensor_2d = np.reshape(mfcc_tensor, (num_samples, num_frames * num_mfcc))
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(mfcc_tensor_2d, labels, test_size=0.2, random_state=42)
# Get validation set
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
uniques, counts = np.unique(y_train, return_counts=True)
dict(zip(uniques, counts))
{'blues': 73, 'classical': 61, 'country': 68, 'disco': 71, 'hiphop': 73, 'jazz': 71, 'metal': 80, 'pop': 72, 'reggae': 77, 'rock': 73}
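The sizes of the three sets follow from the two splits above. As a quick sanity check, the sketch below redoes the arithmetic in plain Python, assuming scikit-learn's behaviour of rounding the test share up and leaving the remainder for training. The helper `split_sizes` is a name made up for this example.

```python
import math

def split_sizes(n_samples, test_size):
    # The test share is rounded up; the remainder goes to training
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_train, n_test = split_sizes(999, 0.2)     # first split: train+val vs test
n_train, n_val = split_sizes(n_train, 0.1)  # second split: train vs validation
print(n_train, n_val, n_test)  # 719 80 200
```

The 719 training tracks match the sum of the per-genre counts printed above. Since the splits are not stratified, the genres are not equally represented (61 classical tracks versus 80 metal tracks).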
Train methods¶
X_train.reshape(*[-1, 1, measurements, features]).shape
torch.Size([719, 1, 2986, 13])
def train_batch(batch, model, criterion, optimizer):
    # Get the batch of data
    batch_X, batch_y = batch
    # convert strings to ids
    batch_y = np.array([labels_to_idx[x] for x in batch_y])
    batch_X = batch_X.float()
    # Zero out the gradients
    optimizer.zero_grad()
    # Forward pass
    outputs = model(batch_X)
    loss = criterion(outputs, torch.tensor(batch_y))
    # Backward pass
    loss.backward()
    optimizer.step()
    return loss.item()
def eval_batch(X_test, y_test, model, criterion):
    # Evaluate
    model.eval()
    with torch.no_grad():
        X_test = X_test.float()
        y_test = torch.tensor([labels_to_idx[x] for x in y_test])
        outputs = model(X_test)
        loss = criterion(outputs, y_test)
    return loss.item()
def train_epoch(X_train, y_train, model, criterion, optimizer, batch_size=100):
    # Shuffle the training data (apply the same permutation to inputs and labels)
    indices = np.arange(len(X_train))
    np.random.shuffle(indices)
    X_train = X_train[torch.from_numpy(indices)]
    y_train = y_train[indices]
    # Create batches (samples that do not fill a complete batch are dropped)
    num_batches = len(X_train) // batch_size
    batches = [(X_train[i*batch_size:(i+1)*batch_size], y_train[i*batch_size:(i+1)*batch_size]) for i in range(num_batches)]
    # Train each batch
    losses = []
    for batch in batches:
        loss = train_batch(batch, model, criterion, optimizer)
        losses.append(loss)
    return losses
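Because train_epoch computes the number of batches with integer division, the samples left over after the last full batch are skipped (19 of the 719 training tracks at a batch size of 100). A possible alternative, sketched here in NumPy under the assumption that a smaller final batch is acceptable, iterates over a random permutation and keeps the remainder. The generator `iter_minibatches` is a name made up for this example and is not used elsewhere in this notebook.

```python
import numpy as np

def iter_minibatches(n_samples, batch_size, rng):
    # Yield index arrays covering every sample exactly once, in random order
    order = rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(42)
batches = list(iter_minibatches(n_samples=719, batch_size=100, rng=rng))
batch_sizes = [len(b) for b in batches]
print(batch_sizes)  # [100, 100, 100, 100, 100, 100, 100, 19]
```

Each index array can be used to select the matching rows of the inputs and the labels, so both stay aligned after shuffling.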
def train_model(X_train, y_train, X_test, y_test, model, criterion, optimizer, reshape=None, num_epochs=10, batch_size=100, verbose=False, n=10):
    train_losses = []
    eval_losses = []
    average_loss = []
    if reshape:
        X_train = X_train.reshape(*reshape)
        X_test = X_test.reshape(*reshape)
    for epoch in tqdm(range(num_epochs)):
        # Train
        model.train()
        losses = train_epoch(X_train, y_train, model, criterion, optimizer, batch_size=batch_size)
        train_losses.extend(losses)
        average_loss.append(np.mean(losses))
        # Evaluate
        model.eval()
        with torch.no_grad():
            eval_loss = eval_batch(X_test, y_test, model, criterion)
            eval_losses.append(eval_loss)
        # Report the loss of the last training batch and the current evaluation loss
        if verbose or epoch % n == 0:
            print(
                'Epoch: {}'.format(epoch),
                'Train loss: {:.4f}'.format(losses[-1]),
                'Test loss: {:.4f}'.format(eval_losses[-1])
            )
    return train_losses, eval_losses, average_loss
X_train.numpy().shape
(719, 38818)
def plot_losses(train_losses, val_losses, model=""):
    """Plot using Plotly Express"""
    import plotly.express as px
    import pandas as pd
    pd.options.plotting.backend = "plotly"
    df = pd.DataFrame({
        'epoch': np.arange(len(train_losses)),
        'train_loss': train_losses,
        'val_loss': val_losses
    })
    fig = px.line(df, x='epoch', y=['train_loss', 'val_loss'], title=f'Losses {model}')
    fig.show()
Training¶
Feed Forward¶
input_size = 38818 # fixed: 2986 frames * 13 MFCC coefficients
num_classes = 10 # fixed
hidden_size = 128*5 # tunable
num_layers = 2 # tunable
lr = 0.00001 # tunable
ffn_model = FFN(input_size, hidden_size, num_classes, num_layers)
optimizer_ffn = optim.Adam(ffn_model.parameters(), lr=lr)
ffn_train_losses, ffn_val_losses, ffn_avg_losses = train_model(
X_train, y_train, X_test, y_test, ffn_model, criterion, optimizer_ffn,
num_epochs=15, batch_size=batch_size, verbose=False
)
plot_losses(ffn_avg_losses, ffn_val_losses, "FFN")
7%|▋ | 1/15 [00:00<00:05, 2.43it/s]
Epoch: 0 Train loss: 2.1800 Test loss: 1.9666
73%|███████▎ | 11/15 [00:03<00:01, 3.01it/s]
Epoch: 10 Train loss: 0.0359 Test loss: 1.5808
100%|██████████| 15/15 [00:04<00:00, 3.01it/s]
CNN¶
# The MFCC matrix is fed to the CNN as a single-channel image
num_channels = 1
num_classes = 10 # Number of output classes
# Initialize the CNN model and optimizer (the loss function was defined earlier)
cnn_model = CNN(num_channels, num_classes, out_channels=32, measurements=2986)
optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=0.0001)
cnn_train_losses, cnn_val_losses, cnn_avg_losses = train_model(
X_train, y_train, X_test, y_test, cnn_model, criterion, optimizer_cnn,
reshape=[-1, 1, measurements, features],
num_epochs=15, batch_size=batch_size, verbose=False
)
plot_losses(cnn_avg_losses, cnn_val_losses, "CNN")
0%| | 0/15 [00:00<?, ?it/s] 7%|▋ | 1/15 [00:03<00:54, 3.89s/it]
Epoch: 0 Train loss: 682.3203 Test loss: 656.0635
73%|███████▎ | 11/15 [00:40<00:14, 3.71s/it]
Epoch: 10 Train loss: 9.3461 Test loss: 20.5606
100%|██████████| 15/15 [00:55<00:00, 3.73s/it]
RNN¶
features = 13 # Number of features
hidden_size = 128 # Number of hidden units in the RNN layer
num_layers = 2 # Number of RNN layers
num_classes = 10 # Number of output classes (genres)
batch_size = 100 # Number of examples in a batch
# Initialize the RNN model, loss function, and optimizer
rnn_model = RNN(features, hidden_size, num_layers*2, num_classes, verbose=False)
optimizer_rnn = optim.Adam(rnn_model.parameters(), lr=.01, weight_decay=1e-7)
rnn_train_losses, rnn_val_losses, rnn_avg_losses = train_model(
X_train, y_train, X_test, y_test, rnn_model, criterion, optimizer_rnn,
reshape=[-1, measurements, features],
num_epochs=10, batch_size=64, verbose=False
)
plot_losses(rnn_avg_losses, rnn_val_losses, "RNN")
10%|█ | 1/10 [00:17<02:34, 17.13s/it]
Epoch: 0 Train loss: 1.8639 Test loss: 1.7750
100%|██████████| 10/10 [02:49<00:00, 16.92s/it]
Evaluate¶
def metrics(predictions, y_labels, verbose=False):
    """Calculate accuracy, F1 score, and confusion matrix"""
    accuracy = accuracy_score(y_labels, predictions)
    f1 = f1_score(y_labels, predictions, average='weighted', zero_division=0)
    confusion = confusion_matrix(y_labels, predictions, labels=list(labels_to_idx))
    report = classification_report(y_labels, predictions, zero_division=0)
    if verbose:
        print("Accuracy:", accuracy)
        print("F1 Score:", f1)
    if verbose > 1:
        print("Classification Report:\n", report)
    return accuracy, f1, confusion, report
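The metrics helper reports a weighted F1 score, while the grid search later in this notebook optimises the macro F1 score. The NumPy sketch below shows how both are derived from a confusion matrix, using a tiny made-up two-class example. It is meant to illustrate the definitions, not to replace the scikit-learn functions; `f1_from_confusion` is a name made up for this example.

```python
import numpy as np

def f1_from_confusion(cm):
    # cm[i, j] counts samples with true class i predicted as class j
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)
    actual = cm.sum(axis=1)
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    macro = f1.mean()                              # every class counts equally
    weighted = (f1 * actual).sum() / actual.sum()  # classes weighted by support
    accuracy = tp.sum() / cm.sum()
    return accuracy, macro, weighted

cm = np.array([[8, 2],
               [1, 1]])
accuracy, macro_f1, weighted_f1 = f1_from_confusion(cm)
print(accuracy, macro_f1, weighted_f1)
```

In this toy example the rare second class has a low F1 score, which pulls the macro average down further than the weighted average.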
def plot_confusion_matrix(cm, classes=list((labels_to_idx)), name="", cmap=px.colors.sequential.Blues):
    """
    This function plots the confusion matrix.
    """
    fig = px.imshow(cm, x=classes, y=classes, color_continuous_scale=cmap)
    fig.update_layout(title="Confusion matrix "+name, xaxis_title="Predicted", yaxis_title="Actual")
    fig.show()
def evaluate(model, X_val=X_val, y_val=y_val, name="", reshape=None, plot=True, verbose=True):
    """Evaluate the model on the validation set"""
    model.eval()
    with torch.no_grad():
        X_val = X_val.reshape(*reshape).float() if reshape else X_val.float()
        outputs = model(X_val)
        # Predicted class = index of the largest logit
        _, predictions = torch.max(outputs.data, 1)
        predictions = np.array([idx_to_labels[pred.item()] for pred in predictions])
    accuracy, f1, confusion, report = metrics(predictions, y_val, verbose=verbose)
    if verbose > 1:
        print("(Prediction, Label): ", list(zip(predictions, y_val)))
    if plot:
        plot_confusion_matrix(confusion, name=name)
    return accuracy, f1, confusion, report
ffn_eval = evaluate(ffn_model, name="FFN", verbose=2)
Accuracy: 0.55 F1 Score: 0.5165582516898306 Classification Report: precision recall f1-score support blues 0.25 0.20 0.22 5 classical 0.80 0.73 0.76 11 country 0.50 0.10 0.17 10 disco 0.00 0.00 0.00 5 hiphop 1.00 0.57 0.73 7 jazz 0.58 0.78 0.67 9 metal 0.67 1.00 0.80 8 pop 0.50 0.86 0.63 7 reggae 0.50 0.62 0.56 8 rock 0.36 0.40 0.38 10 accuracy 0.55 80 macro avg 0.52 0.53 0.49 80 weighted avg 0.55 0.55 0.52 80 (Prediction, Label): [('disco', 'hiphop'), ('metal', 'metal'), ('rock', 'blues'), ('jazz', 'jazz'), ('jazz', 'country'), ('classical', 'classical'), ('country', 'classical'), ('classical', 'classical'), ('classical', 'classical'), ('rock', 'jazz'), ('jazz', 'reggae'), ('metal', 'metal'), ('blues', 'rock'), ('rock', 'reggae'), ('pop', 'rock'), ('reggae', 'disco'), ('blues', 'country'), ('blues', 'blues'), ('metal', 'metal'), ('metal', 'metal'), ('reggae', 'reggae'), ('pop', 'disco'), ('rock', 'rock'), ('pop', 'pop'), ('jazz', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('metal', 'rock'), ('rock', 'classical'), ('pop', 'rock'), ('classical', 'classical'), ('rock', 'rock'), ('reggae', 'hiphop'), ('country', 'country'), ('jazz', 'classical'), ('hiphop', 'hiphop'), ('disco', 'country'), ('jazz', 'country'), ('pop', 'disco'), ('jazz', 'jazz'), ('classical', 'classical'), ('reggae', 'blues'), ('rock', 'rock'), ('reggae', 'reggae'), ('rock', 'country'), ('metal', 'metal'), ('pop', 'pop'), ('classical', 'country'), ('pop', 'disco'), ('hiphop', 'hiphop'), ('metal', 'reggae'), ('disco', 'rock'), ('reggae', 'reggae'), ('reggae', 'reggae'), ('pop', 'pop'), ('jazz', 'jazz'), ('reggae', 'country'), ('jazz', 'country'), ('pop', 'pop'), ('metal', 'metal'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('blues', 'country'), ('jazz', 'jazz'), ('jazz', 'jazz'), ('metal', 'metal'), ('rock', 'blues'), ('rock', 'blues'), ('reggae', 'pop'), ('reggae', 'reggae'), ('jazz', 'jazz'), ('classical', 'classical'), ('metal', 'metal'), ('pop', 'hiphop'), ('pop', 'pop'), 
('hiphop', 'hiphop'), ('classical', 'jazz'), ('metal', 'rock'), ('rock', 'rock'), ('metal', 'disco')]
cnn_eval = evaluate(cnn_model, name="CNN", plot=True, reshape=[-1, 1, measurements, features])
Accuracy: 0.4625 F1 Score: 0.4765183394703768
rnn_eval = evaluate(rnn_model, X_val=X_val, name="RNN", plot=True, reshape=[-1, measurements, features])
Accuracy: 0.6 F1 Score: 0.5521900581788353
Hyperparameter tuning¶
def get_module_name(model):
"""Get the name of the model"""
return model.__class__.__name__
def grid_search(
    classifier, params, cv=5,
    X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    reshape=None,
    return_full_metrics=False, verbose=False, plot=True, name=""
):
    # Create a GridSearchCV object with the specified parameter grid and classifier
    gs = GridSearchCV(estimator=classifier, param_grid=params, cv=cv, n_jobs=-1, scoring="f1_macro")  # scoring= f1 with average='weighted'?
    if reshape:
        X_train = X_train.reshape(*reshape)
        X_test = X_test.reshape(*reshape)
    # Perform grid search on the training data
    label_idx_train = np.array([labels_to_idx[x] for x in y_train])
    gs.fit(X_train, label_idx_train)
    # Print the best parameters found by the grid search
    print("Best Parameters:", gs.best_params_)
    # Make predictions using the best estimator
    predictions = gs.predict(X_test)
    predictions = np.array([idx_to_labels[pred] for pred in predictions])
    if verbose > 1:
        print("(Prediction, Label): ", list(zip(predictions, y_test)))
    accuracy, f1, cm, cr = metrics(predictions, y_test, verbose=verbose)
    # append classifier name and params for plotting
    gs.name = get_module_name(classifier)
    gs.params = params
    gs.f1 = f1
    gs.accuracy = accuracy
    gs.cm = cm
    gs.cr = cr
    if return_full_metrics:
        return accuracy, f1, cm, cr, gs
    if plot:
        plot_confusion_matrix(cm, name=name)
    return gs
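Grid search gets expensive quickly, because every combination of parameters is trained once per cross-validation fold. The sketch below enumerates the FFN grid used later in this notebook with itertools and counts the resulting fits. It assumes GridSearchCV's default of refitting the best candidate once on the full training data, which adds one more fit.

```python
from itertools import product

param_grid = {
    "module__hidden_size": [128 * 5, 128 * 6, 128 * 7, 128 * 8, 128 * 9],
    "module__num_layers": [1, 2, 3, 4, 5],
}
cv = 5

names = sorted(param_grid)
combinations = [dict(zip(names, values)) for values in product(*(param_grid[n] for n in names))]
total_fits = len(combinations) * cv + 1  # +1: refit of the best candidate on all training data
print(len(combinations), total_fits)  # 25 126
```

With several seconds to minutes per fit, this explains why the CNN search below is restricted to a much smaller grid.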
def plot_grid_seach(gs, score_col="mean_test_score", param_cols=None, verbose=0):
    """3D plot of grid search with params on x and y and score on z axis"""
    gs_df = pd.DataFrame(gs.cv_results_)
    # get score column
    if score_col is None:
        score_col = [x for x in gs_df.columns if "score" in x].pop()
    if verbose > 1:
        print("score_col", score_col)
    # get param cols
    if param_cols is None:
        param_cols = [x for x in gs_df.columns if "param_" in x]
    if verbose > 3:
        print("param_cols", param_cols)
    # get sizes
    x_size = len(gs_df[param_cols[0]].unique())
    y_size = len(gs_df[param_cols[1]].unique())
    # get x, y, z: rows of cv_results_ are ordered with the first parameter varying slowest,
    # so reshaping to (x_size, y_size) puts the first parameter on the rows
    x = gs_df[param_cols[0]].values.reshape(x_size, y_size).T[0]
    y = gs_df[param_cols[1]].values.reshape(x_size, y_size)[0]
    z = gs_df[score_col].values.reshape(x_size, y_size).T
    fig = go.Figure(
        data=[go.Surface(
            x=x, y=y, z=z,
            hovertemplate=f"{param_cols[0]}: {'%{x}'}<br>{param_cols[1]}: {'%{y}'}<br>{score_col}: {'%{z}'}<extra></extra>",
        )]
    )
    fig.update_layout(
        title=f"GridSearchCV Results for {gs.name} Classifier",
        scene=dict(
            xaxis_title=param_cols[0],
            yaxis_title=param_cols[1],
            zaxis_title=score_col,
        ),
        height=750,
    )
    if verbose > 2:
        print("x", x)
        print("y", y)
        print("z", z)
    return fig.show()
Feed Forward¶
# Initialize the FFN model, loss function, and optimizer
ffn_net = NeuralNetClassifier(
FFN,
module__input_size=input_size, # 38818
module__num_classes=num_classes, # 10
criterion=nn.CrossEntropyLoss,
optimizer=optim.Adam,
lr=0.00001,
max_epochs=10,
batch_size=100,
# device="cuda" if torch.cuda.is_available() else "cpu",
verbose=0
)
ffn_gs = grid_search(
ffn_net,
params={
"module__hidden_size": [128*5, 128*6, 128*7, 128*8, 128*9],
"module__num_layers": [1, 2, 3, 4, 5],
},
name="FFN",
verbose=2,
)
# so far best params: {'module__hidden_size': 640, 'module__num_layers': 3, 'optimizer': <class 'torch.optim.adam.Adam'>, 'optimizer__lr': 1e-05}
# Accuracy: 0.525
# F1 Score: 0.5222471631893988
# weird that we got all the way up to .60 earlier
# though that one had 7 layers
/opt/homebrew/Caskroom/miniconda/base/envs/music-genre/lib/python3.11/site-packages/joblib/externals/loky/process_executor.py:752: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.
Best Parameters: {'module__hidden_size': 1024, 'module__num_layers': 3} (Prediction, Label): [('hiphop', 'reggae'), ('disco', 'hiphop'), ('disco', 'disco'), ('blues', 'blues'), ('disco', 'hiphop'), ('classical', 'classical'), ('rock', 'country'), ('classical', 'classical'), ('hiphop', 'reggae'), ('disco', 'rock'), ('jazz', 'jazz'), ('blues', 'rock'), ('rock', 'country'), ('disco', 'rock'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('country', 'country'), ('blues', 'reggae'), ('jazz', 'blues'), ('rock', 'disco'), ('classical', 'classical'), ('metal', 'blues'), ('blues', 'country'), ('jazz', 'country'), ('classical', 'jazz'), ('pop', 'pop'), ('metal', 'metal'), ('country', 'disco'), ('rock', 'country'), ('blues', 'blues'), ('classical', 'classical'), ('rock', 'blues'), ('blues', 'rock'), ('classical', 'classical'), ('rock', 'rock'), ('blues', 'rock'), ('metal', 'metal'), ('metal', 'metal'), ('disco', 'rock'), ('classical', 'classical'), ('blues', 'hiphop'), ('metal', 'blues'), ('jazz', 'jazz'), ('reggae', 'reggae'), ('metal', 'metal'), ('classical', 'classical'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('blues', 'country'), ('metal', 'rock'), ('classical', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'pop'), ('rock', 'blues'), ('pop', 'pop'), ('classical', 'classical'), ('classical', 'classical'), ('pop', 'hiphop'), ('country', 'pop'), ('disco', 'rock'), ('rock', 'country'), ('classical', 'classical'), ('metal', 'hiphop'), ('blues', 'blues'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('jazz', 'country'), ('country', 'pop'), ('metal', 'metal'), ('blues', 'rock'), ('metal', 'metal'), ('reggae', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'hiphop'), ('classical', 'classical'), ('pop', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('pop', 'pop'), ('country', 'country'), ('blues', 'blues'), ('rock', 'reggae'), ('country', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('hiphop', 
'reggae'), ('classical', 'blues'), ('rock', 'reggae'), ('jazz', 'country'), ('classical', 'classical'), ('jazz', 'hiphop'), ('pop', 'pop'), ('pop', 'pop'), ('pop', 'country'), ('blues', 'rock'), ('hiphop', 'country'), ('classical', 'classical'), ('hiphop', 'reggae'), ('rock', 'disco'), ('disco', 'jazz'), ('jazz', 'classical'), ('hiphop', 'jazz'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('country', 'jazz'), ('rock', 'rock'), ('blues', 'jazz'), ('metal', 'metal'), ('pop', 'pop'), ('metal', 'disco'), ('blues', 'blues'), ('metal', 'blues'), ('classical', 'classical'), ('rock', 'blues'), ('metal', 'hiphop'), ('rock', 'disco'), ('disco', 'disco'), ('country', 'disco'), ('jazz', 'country'), ('country', 'country'), ('hiphop', 'hiphop'), ('pop', 'jazz'), ('hiphop', 'metal'), ('rock', 'disco'), ('country', 'blues'), ('rock', 'country'), ('jazz', 'reggae'), ('disco', 'jazz'), ('classical', 'classical'), ('pop', 'pop'), ('metal', 'metal'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('pop', 'pop'), ('reggae', 'pop'), ('jazz', 'jazz'), ('classical', 'classical'), ('disco', 'disco'), ('disco', 'disco'), ('rock', 'disco'), ('metal', 'metal'), ('classical', 'country'), ('blues', 'blues'), ('disco', 'country'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('metal', 'disco'), ('country', 'jazz'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('rock', 'disco'), ('pop', 'pop'), ('rock', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('rock', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('classical', 'classical'), ('rock', 'rock'), ('blues', 'blues'), ('jazz', 'jazz'), ('classical', 'classical'), ('jazz', 'blues'), ('disco', 'disco'), ('blues', 'metal'), ('reggae', 'reggae'), ('blues', 'country'), ('metal', 'rock'), ('jazz', 'jazz'), ('rock', 'rock'), ('disco', 'hiphop'), ('jazz', 'jazz'), ('rock', 'reggae'), ('reggae', 'hiphop'), ('blues', 'reggae'), ('metal', 'disco'), ('metal', 'metal'), ('classical', 'hiphop'), ('blues', 'blues'), ('country', 'disco'), 
('country', 'country'), ('disco', 'disco'), ('classical', 'classical'), ('country', 'jazz'), ('reggae', 'country'), ('jazz', 'jazz'), ('blues', 'country'), ('blues', 'blues'), ('country', 'disco'), ('jazz', 'blues'), ('pop', 'pop'), ('pop', 'pop'), ('metal', 'disco'), ('disco', 'disco'), ('disco', 'reggae')] Accuracy: 0.505 F1 Score: 0.4895867164687552 Classification Report: precision recall f1-score support blues 0.42 0.45 0.43 22 classical 0.87 0.93 0.90 28 country 0.27 0.18 0.22 22 disco 0.42 0.33 0.37 24 hiphop 0.45 0.45 0.45 20 jazz 0.47 0.47 0.47 19 metal 0.48 0.83 0.61 12 pop 0.78 0.86 0.82 21 reggae 0.33 0.13 0.19 15 rock 0.22 0.29 0.25 17 accuracy 0.51 200 macro avg 0.47 0.49 0.47 200 weighted avg 0.49 0.51 0.49 200
plot_grid_seach(ffn_gs)
CNN¶
This one might be difficult to tune.
# Initialize the CNN model, loss function, and optimizer
cnn_net = NeuralNetClassifier(
    CNN,
    module__num_channels=1, # The MFCC matrix is treated as a single-channel image
    criterion=nn.CrossEntropyLoss,
    optimizer=optim.Adam,
    lr=0.00001,
    max_epochs=10,
    batch_size=100,
    # device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=0
)
cnn_gs = grid_search(
cnn_net,
params={
"optimizer__lr": [0.0001, 0.00001],
"optimizer": [optim.Adam, optim.SGD],
},
name="CNN",
reshape=[-1, 1, measurements, features],
verbose=2,
)
Best Parameters: {'optimizer': <class 'torch.optim.adam.Adam'>, 'optimizer__lr': 1e-05} (Prediction, Label): [('hiphop', 'reggae'), ('disco', 'hiphop'), ('disco', 'disco'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('country', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('country', 'rock'), ('hiphop', 'jazz'), ('blues', 'rock'), ('rock', 'country'), ('disco', 'rock'), ('metal', 'hiphop'), ('pop', 'pop'), ('jazz', 'country'), ('blues', 'reggae'), ('jazz', 'blues'), ('disco', 'disco'), ('classical', 'classical'), ('rock', 'blues'), ('jazz', 'country'), ('blues', 'country'), ('classical', 'jazz'), ('metal', 'pop'), ('metal', 'metal'), ('pop', 'disco'), ('reggae', 'country'), ('blues', 'blues'), ('classical', 'classical'), ('classical', 'blues'), ('rock', 'rock'), ('classical', 'classical'), ('blues', 'rock'), ('rock', 'rock'), ('hiphop', 'metal'), ('metal', 'metal'), ('disco', 'rock'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('metal', 'blues'), ('classical', 'jazz'), ('hiphop', 'reggae'), ('metal', 'metal'), ('classical', 'classical'), ('pop', 'disco'), ('hiphop', 'hiphop'), ('disco', 'pop'), ('blues', 'country'), ('blues', 'rock'), ('classical', 'classical'), ('disco', 'jazz'), ('pop', 'pop'), ('country', 'pop'), ('jazz', 'blues'), ('pop', 'pop'), ('classical', 'classical'), ('blues', 'classical'), ('hiphop', 'hiphop'), ('country', 'pop'), ('disco', 'rock'), ('country', 'country'), ('classical', 'classical'), ('metal', 'hiphop'), ('blues', 'blues'), ('classical', 'classical'), ('disco', 'hiphop'), ('metal', 'disco'), ('classical', 'country'), ('blues', 'pop'), ('metal', 'metal'), ('disco', 'rock'), ('metal', 'metal'), ('hiphop', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('disco', 'hiphop'), ('classical', 'classical'), ('hiphop', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('metal', 'pop'), ('blues', 'country'), ('blues', 'blues'), ('metal', 'reggae'), ('jazz', 'rock'), ('classical', 'classical'), 
('hiphop', 'reggae'), ('hiphop', 'reggae'), ('classical', 'blues'), ('disco', 'reggae'), ('jazz', 'country'), ('classical', 'classical'), ('jazz', 'hiphop'), ('pop', 'pop'), ('disco', 'pop'), ('pop', 'country'), ('blues', 'rock'), ('pop', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('disco', 'disco'), ('disco', 'jazz'), ('jazz', 'classical'), ('hiphop', 'jazz'), ('reggae', 'hiphop'), ('classical', 'classical'), ('metal', 'jazz'), ('jazz', 'rock'), ('blues', 'jazz'), ('disco', 'metal'), ('pop', 'pop'), ('disco', 'disco'), ('hiphop', 'blues'), ('blues', 'blues'), ('classical', 'classical'), ('jazz', 'blues'), ('hiphop', 'hiphop'), ('country', 'disco'), ('disco', 'disco'), ('disco', 'disco'), ('country', 'country'), ('jazz', 'country'), ('hiphop', 'hiphop'), ('disco', 'jazz'), ('metal', 'metal'), ('country', 'disco'), ('blues', 'blues'), ('hiphop', 'country'), ('blues', 'reggae'), ('country', 'jazz'), ('classical', 'classical'), ('disco', 'pop'), ('metal', 'metal'), ('metal', 'hiphop'), ('pop', 'pop'), ('jazz', 'pop'), ('hiphop', 'pop'), ('pop', 'jazz'), ('classical', 'classical'), ('disco', 'disco'), ('classical', 'disco'), ('disco', 'disco'), ('metal', 'metal'), ('disco', 'country'), ('blues', 'blues'), ('pop', 'country'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('blues', 'disco'), ('rock', 'jazz'), ('blues', 'blues'), ('metal', 'hiphop'), ('jazz', 'disco'), ('jazz', 'pop'), ('rock', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('country', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('classical', 'classical'), ('classical', 'rock'), ('blues', 'blues'), ('jazz', 'jazz'), ('classical', 'classical'), ('rock', 'blues'), ('disco', 'disco'), ('blues', 'metal'), ('blues', 'reggae'), ('rock', 'country'), ('metal', 'rock'), ('jazz', 'jazz'), ('jazz', 'rock'), ('disco', 'hiphop'), ('jazz', 'jazz'), ('rock', 'reggae'), ('reggae', 'hiphop'), ('country', 'reggae'), ('disco', 'disco'), ('metal', 'metal'), ('rock', 'hiphop'), 
('blues', 'blues'), ('country', 'disco'), ('hiphop', 'country'), ('disco', 'disco'), ('classical', 'classical'), ('country', 'jazz'), ('pop', 'country'), ('jazz', 'jazz'), ('country', 'country'), ('blues', 'blues'), ('rock', 'disco'), ('classical', 'blues'), ('pop', 'pop'), ('disco', 'pop'), ('jazz', 'disco'), ('disco', 'disco'), ('disco', 'reggae')]
Accuracy: 0.44
F1 Score: 0.42098351495635133
Classification Report:
| | precision | recall | f1-score | support |
---|---|---|---|---|
blues | 0.42 | 0.50 | 0.46 | 22 |
classical | 0.76 | 0.89 | 0.82 | 28 |
country | 0.29 | 0.18 | 0.22 | 22 |
disco | 0.39 | 0.50 | 0.44 | 24 |
hiphop | 0.35 | 0.40 | 0.37 | 20 |
jazz | 0.24 | 0.26 | 0.25 | 19 |
metal | 0.45 | 0.75 | 0.56 | 12 |
pop | 0.56 | 0.43 | 0.49 | 21 |
reggae | 0.40 | 0.13 | 0.20 | 15 |
rock | 0.27 | 0.18 | 0.21 | 17 |
accuracy | | | 0.44 | 200 |
macro avg | 0.41 | 0.42 | 0.40 | 200 |
weighted avg | 0.43 | 0.44 | 0.42 | 200 |
pd.DataFrame(cnn_gs.cv_results_).head()
| | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_optimizer | param_optimizer__lr | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | mean_test_score | std_test_score | rank_test_score |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 613.343941 | 83.638880 | 7.724842 | 1.238205 | <class 'torch.optim.adam.Adam'> | 0.0001 | {'optimizer': <class 'torch.optim.adam.Adam'>,... | 0.320108 | 0.318683 | 0.299336 | 0.320679 | 0.326986 | 0.317158 | 0.009356 | 2 |
1 | 589.521977 | 101.374310 | 4.592106 | 2.383811 | <class 'torch.optim.adam.Adam'> | 0.00001 | {'optimizer': <class 'torch.optim.adam.Adam'>,... | 0.384808 | 0.333293 | 0.304256 | 0.308969 | 0.391570 | 0.344579 | 0.037008 | 1 |
2 | 404.556484 | 198.926591 | 4.247250 | 3.705419 | <class 'torch.optim.sgd.SGD'> | 0.0001 | {'optimizer': <class 'torch.optim.sgd.SGD'>, '... | 0.206964 | 0.236803 | 0.176908 | 0.173035 | 0.172753 | 0.193293 | 0.025222 | 3 |
3 | 124.940959 | 32.440937 | 1.119628 | 0.078850 | <class 'torch.optim.sgd.SGD'> | 0.00001 | {'optimizer': <class 'torch.optim.sgd.SGD'>, '... | 0.105875 | 0.017722 | 0.020000 | 0.128026 | 0.163816 | 0.087088 | 0.058700 | 4 |
I had trouble plotting the optimizers, because they are stored as class objects rather than strings. I skipped the plot, since the scores were quite low anyway.
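One possible workaround for the plotting issue is to map each optimizer class to its name before plotting. The sketch below is only an illustration: `Adam` and `SGD` are stand-in classes so the snippet runs without PyTorch, `optimizer_label` is a hypothetical helper that is not part of the notebook, and the scores are copied from the `cv_results_` table above.

```python
# Stand-in classes so the sketch runs without PyTorch installed; in the
# notebook these would be torch.optim.Adam and torch.optim.SGD.
class Adam:
    pass


class SGD:
    pass


def optimizer_label(opt):
    """Return a short, plottable name for an optimizer class (hypothetical helper)."""
    # Classes expose their name via __name__; anything else falls back to str().
    return opt.__name__ if isinstance(opt, type) else str(opt)


# Mean test scores copied from the cv_results_ table above.
rows = [
    {"param_optimizer": Adam, "param_optimizer__lr": 1e-4, "mean_test_score": 0.317158},
    {"param_optimizer": Adam, "param_optimizer__lr": 1e-5, "mean_test_score": 0.344579},
    {"param_optimizer": SGD, "param_optimizer__lr": 1e-4, "mean_test_score": 0.193293},
    {"param_optimizer": SGD, "param_optimizer__lr": 1e-5, "mean_test_score": 0.087088},
]

# Build one string label per configuration, usable as a categorical plot axis.
labels = [
    f'{optimizer_label(r["param_optimizer"])} (lr={r["param_optimizer__lr"]:g})'
    for r in rows
]
best = max(rows, key=lambda r: r["mean_test_score"])

print(labels)
print("best:", optimizer_label(best["param_optimizer"]), best["param_optimizer__lr"])
```

On the real results frame, the same helper could be applied with `df["param_optimizer"].map(optimizer_label)` before handing the column to plotly.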
RNN¶
rnn_net = NeuralNetClassifier(
    module=RNN,
    module__input_size=13,
    module__hidden_size=13*4,
    module__num_layers=2,
    module__num_classes=10,
    optimizer=optim.Adam,
    optimizer__lr=0.01,
    optimizer__weight_decay=1e-7,
    max_epochs=10,
    criterion=nn.CrossEntropyLoss,
)
rnn_gs = grid_search(
    rnn_net,
    params={
        # "module__hidden_size": [128, 128*2, 128*3],  # this was tested, but took very long; commented out for the rerun
        "module__hidden_size": [128*2],
        "module__num_layers": [2, 4],
    },
    name="RNN",
    reshape=[-1, measurements, features],
    verbose=2,
)
epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ ------- 1 2.5345 0.2414 2.2735 82.2726 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ ------- 1 2.5644 0.2000 2.1648 93.6890 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ ------- 1 2.4373 0.2783 2.1152 93.1534 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ ------- 1 2.5388 0.2957 2.2550 92.3371 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ -------- 1 2.5434 0.2957 2.0040 100.4151 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ -------- 1 2.4403 0.2500 2.0954 101.2073 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ -------- 1 2.4720 0.2696 2.0817 101.4083 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ -------- 1 2.5218 0.2783 2.2030 118.2089 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ -------- 1 2.5695 0.2870 2.0945 123.8702 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ -------- 1 2.5273 0.2783 2.2367 125.8745 2 2.0470 0.3362 1.8447 81.0597 2 1.9725 0.3478 1.7494 90.1354 2 1.9374 0.3707 1.7632 90.7828 2 1.9183 0.3478 1.7623 92.0393 2 1.9541 0.3739 1.8473 96.2930 2 1.9408 0.3739 1.8124 84.8313 2 1.9354 0.3826 1.7385 105.2375 2 1.8960 0.3217 1.8241 81.6591 2 2.0058 0.3217 1.8419 114.9778 2 1.8977 0.3565 1.8114 96.9768 3 1.5350 0.4397 1.6256 109.3141 3 1.4797 0.4348 1.6595 93.2377 3 1.5328 0.3739 1.6606 73.9458 3 1.5086 0.3652 1.6555 101.7210 3 1.5599 0.4870 1.6202 95.6880 3 1.4986 0.4000 1.6345 106.7779 3 1.4662 0.4261 1.6565 101.8475 3 1.4936 0.4000 1.6438 114.7037 3 1.4820 0.4310 1.5537 116.2021 3 1.5446 0.3913 1.7129 113.9270 4 1.3178 0.4569 1.5212 92.5528 4 1.2898 0.4696 1.5105 90.3154 4 1.2860 0.4261 1.5334 
88.4542 4 1.2507 0.4087 1.5069 103.1782 4 1.2203 0.4087 1.5482 115.7054 4 1.2959 0.4261 1.5174 115.4407 4 1.2438 0.4609 1.5431 104.0178 4 1.2768 0.5043 1.4976 116.7335 4 1.2592 0.4828 1.4517 102.8468 4 1.3332 0.3913 1.6107 105.8596 5 1.0918 0.4741 1.4477 94.6113 5 1.0127 0.4828 1.3555 68.4263 5 1.0493 0.4609 1.4419 68.9238 5 1.0128 0.4609 1.4685 107.2677 5 1.0664 0.5217 1.4471 107.7476 5 0.9877 0.4261 1.4937 109.5112 5 0.9937 0.4870 1.5122 109.5989 5 1.0304 0.4609 1.4728 109.8205 5 1.0464 0.4696 1.4510 109.7124 5 1.1008 0.4000 1.5011 91.4517 6 0.9084 0.4914 1.4073 78.4099 6 0.8546 0.5043 1.3655 92.4999 6 0.8164 0.5086 1.2987 94.9683 6 0.8173 0.4609 1.3996 84.8938 6 0.8835 0.4957 1.4116 97.9907 6 0.8940 0.4174 1.4279 76.9430 6 0.8731 0.4435 1.4275 95.1965 6 0.8473 0.4870 1.4339 95.3319 6 0.7842 0.4261 1.4608 95.5084 7 0.7203 0.5086 1.3644 79.2748 6 0.7951 0.4348 1.4595 102.1088 7 0.6329 0.5345 1.2563 86.2738 7 0.6250 0.4696 1.3759 79.2079 7 0.6631 0.5130 1.3895 98.2096 7 0.6621 0.5130 1.3904 84.4852 7 0.6731 0.4870 1.3805 85.5710 7 0.6827 0.5478 1.3506 104.9648 7 0.7121 0.4261 1.4001 112.3767 7 0.6014 0.4783 1.4349 99.7546 8 0.5587 0.4914 1.3463 93.2165 7 0.6065 0.5130 1.4446 97.0740 8 0.4754 0.5862 1.2206 77.7356 8 0.4974 0.5391 1.3242 68.2936 8 0.4761 0.4609 1.3698 97.1594 8 0.5294 0.4174 1.3981 83.6451 8 0.5176 0.5391 1.3365 83.2688 8 0.5031 0.5130 1.3587 84.7626 8 0.4485 0.5043 1.4571 98.3014 8 0.4418 0.4348 1.4648 107.7187 9 0.4124 0.5172 1.3012 107.6907 9 0.3452 0.5948 1.2003 93.0845 8 0.5523 0.4609 1.3690 126.3014 9 0.3673 0.5043 1.3688 102.1818 9 0.3328 0.4609 1.3640 112.5498 10 0.2980 0.5000 1.3242 66.6579 10 0.2398 0.5862 1.1994 64.1645 9 0.3749 0.5565 1.3284 113.5875 9 0.3844 0.4522 1.3781 113.7713 9 0.3671 0.5043 1.3616 114.6264 9 0.4092 0.4783 1.3805 69.9123 9 0.3151 0.5043 1.4697 93.9941 9 0.3107 0.4870 1.4702 88.7709 10 0.2241 0.4696 1.3761 52.8965 10 0.2639 0.5304 1.3380 83.6427 10 0.2765 0.4696 1.4220 43.4793 10 0.2697 0.5391 1.3057 43.9570 10 
0.2596 0.5391 1.3461 42.8641 10 0.2934 0.4870 1.3740 36.0700 10 0.2132 0.5217 1.4834 35.9013 10 0.2121 0.4696 1.5222 35.9666 epoch train_loss valid_acc valid_loss dur ------- ------------ ----------- ------------ ------- 1 2.6088 0.2778 2.0543 23.5194 2 1.8811 0.3403 1.8143 22.2594 3 1.5858 0.4167 1.6038 22.2372 4 1.3012 0.4236 1.4980 22.5881 5 1.0834 0.4792 1.4152 22.6296 6 0.8792 0.4931 1.3751 22.5333 7 0.6935 0.5069 1.2915 22.8427 8 0.5321 0.5347 1.2698 22.3569 9 0.3869 0.5417 1.2719 22.6662 10 0.2824 0.5347 1.2778 22.8493 Best Parameters: {'module__hidden_size': 256, 'module__num_layers': 4} (Prediction, Label): [('rock', 'reggae'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('reggae', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('disco', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('metal', 'rock'), ('blues', 'jazz'), ('rock', 'rock'), ('rock', 'country'), ('disco', 'rock'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('country', 'country'), ('blues', 'reggae'), ('jazz', 'blues'), ('disco', 'disco'), ('classical', 'classical'), ('metal', 'blues'), ('country', 'country'), ('country', 'country'), ('classical', 'jazz'), ('pop', 'pop'), ('metal', 'metal'), ('jazz', 'disco'), ('rock', 'country'), ('blues', 'blues'), ('classical', 'classical'), ('classical', 'blues'), ('blues', 'rock'), ('classical', 'classical'), ('rock', 'rock'), ('jazz', 'rock'), ('hiphop', 'metal'), ('metal', 'metal'), ('metal', 'rock'), ('classical', 'classical'), ('blues', 'hiphop'), ('metal', 'blues'), ('jazz', 'jazz'), ('reggae', 'reggae'), ('metal', 'metal'), ('pop', 'classical'), ('disco', 'disco'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('country', 'country'), ('metal', 'rock'), ('classical', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'pop'), ('disco', 'blues'), ('pop', 'pop'), ('classical', 'classical'), ('blues', 'classical'), ('reggae', 'hiphop'), ('pop', 'pop'), ('pop', 'rock'), ('country', 'country'), ('classical', 'classical'), ('hiphop', 
'hiphop'), ('blues', 'blues'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('country', 'country'), ('pop', 'pop'), ('metal', 'metal'), ('blues', 'rock'), ('metal', 'metal'), ('classical', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'hiphop'), ('classical', 'classical'), ('hiphop', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('pop', 'pop'), ('country', 'country'), ('blues', 'blues'), ('reggae', 'reggae'), ('country', 'rock'), ('classical', 'classical'), ('blues', 'reggae'), ('hiphop', 'reggae'), ('classical', 'blues'), ('hiphop', 'reggae'), ('country', 'country'), ('classical', 'classical'), ('metal', 'hiphop'), ('pop', 'pop'), ('pop', 'pop'), ('pop', 'country'), ('blues', 'rock'), ('country', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('blues', 'disco'), ('disco', 'jazz'), ('country', 'classical'), ('reggae', 'jazz'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('jazz', 'jazz'), ('rock', 'rock'), ('rock', 'jazz'), ('metal', 'metal'), ('pop', 'pop'), ('hiphop', 'disco'), ('metal', 'blues'), ('metal', 'blues'), ('classical', 'classical'), ('metal', 'blues'), ('hiphop', 'hiphop'), ('rock', 'disco'), ('disco', 'disco'), ('pop', 'disco'), ('country', 'country'), ('country', 'country'), ('hiphop', 'hiphop'), ('pop', 'jazz'), ('disco', 'metal'), ('rock', 'disco'), ('blues', 'blues'), ('hiphop', 'country'), ('jazz', 'reggae'), ('blues', 'jazz'), ('classical', 'classical'), ('pop', 'pop'), ('metal', 'metal'), ('hiphop', 'hiphop'), ('hiphop', 'pop'), ('pop', 'pop'), ('hiphop', 'pop'), ('jazz', 'jazz'), ('classical', 'classical'), ('disco', 'disco'), ('classical', 'disco'), ('disco', 'disco'), ('metal', 'metal'), ('disco', 'country'), ('blues', 'blues'), ('rock', 'country'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('country', 'jazz'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('country', 'disco'), ('pop', 'pop'), ('jazz', 'rock'), ('jazz', 'classical'), ('hiphop', 'reggae'), 
('blues', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('classical', 'classical'), ('rock', 'rock'), ('metal', 'blues'), ('jazz', 'jazz'), ('classical', 'classical'), ('blues', 'blues'), ('blues', 'disco'), ('metal', 'metal'), ('hiphop', 'reggae'), ('country', 'country'), ('metal', 'rock'), ('blues', 'jazz'), ('rock', 'rock'), ('disco', 'hiphop'), ('metal', 'jazz'), ('hiphop', 'reggae'), ('reggae', 'hiphop'), ('jazz', 'reggae'), ('rock', 'disco'), ('metal', 'metal'), ('rock', 'hiphop'), ('blues', 'blues'), ('disco', 'disco'), ('jazz', 'country'), ('disco', 'disco'), ('classical', 'classical'), ('country', 'jazz'), ('pop', 'country'), ('jazz', 'jazz'), ('country', 'country'), ('metal', 'blues'), ('pop', 'disco'), ('blues', 'blues'), ('pop', 'pop'), ('pop', 'pop'), ('rock', 'disco'), ('disco', 'disco'), ('hiphop', 'reggae')]
Accuracy: 0.58
F1 Score: 0.5700566998538545
Classification Report:
| | precision | recall | f1-score | support |
---|---|---|---|---|
blues | 0.45 | 0.45 | 0.45 | 22 |
classical | 0.86 | 0.86 | 0.86 | 28 |
country | 0.72 | 0.59 | 0.65 | 22 |
disco | 0.61 | 0.46 | 0.52 | 24 |
hiphop | 0.50 | 0.65 | 0.57 | 20 |
jazz | 0.47 | 0.37 | 0.41 | 19 |
metal | 0.43 | 0.83 | 0.57 | 12 |
pop | 0.70 | 0.90 | 0.79 | 21 |
reggae | 0.50 | 0.27 | 0.35 | 15 |
rock | 0.33 | 0.29 | 0.31 | 17 |
accuracy | | | 0.58 | 200 |
macro avg | 0.56 | 0.57 | 0.55 | 200 |
weighted avg | 0.58 | 0.58 | 0.57 | 200 |
Validation¶
def evaluate(model, X_val=X_val, y_val=y_val, name="", reshape=None, plot=True, verbose=True):
    """Score a fitted model on the validation split and optionally plot its confusion matrix."""
    if verbose:
        print("Evaluating model:", name)
    if reshape:
        X_val = X_val.reshape(*reshape)
    # Make predictions using the best estimator
    predictions = model.predict(X_val)
    # Map the predicted class indices back to genre names
    predictions = np.array([idx_to_labels[pred] for pred in predictions])
    # The full (prediction, label) list is only printed at high verbosity
    if verbose > 3:
        print("(Prediction, Label): ", list(zip(predictions, y_val)))
    accuracy, f1, cm, cr = metrics(predictions, y_val, verbose=verbose)
    if plot:
        plot_confusion_matrix(cm, name=name)
    return accuracy, f1, cm, cr
ffn_eval = evaluate(ffn_gs, name="FFN", verbose=2)
cnn_eval = evaluate(cnn_gs, name="CNN", reshape=[-1, 1, measurements, features], verbose=2)
rnn_eval = evaluate(rnn_gs, name="RNN", reshape=[-1, measurements, features], verbose=2)
Evaluating model: FFN
Accuracy: 0.5125
F1 Score: 0.4950345875529699
Classification Report:
| | precision | recall | f1-score | support |
---|---|---|---|---|
blues | 0.38 | 0.60 | 0.46 | 5 |
classical | 0.82 | 0.82 | 0.82 | 11 |
country | 0.25 | 0.10 | 0.14 | 10 |
disco | 0.20 | 0.20 | 0.20 | 5 |
hiphop | 0.38 | 0.43 | 0.40 | 7 |
jazz | 0.55 | 0.67 | 0.60 | 9 |
metal | 0.67 | 0.75 | 0.71 | 8 |
pop | 0.56 | 0.71 | 0.63 | 7 |
reggae | 0.43 | 0.38 | 0.40 | 8 |
rock | 0.50 | 0.40 | 0.44 | 10 |
accuracy | | | 0.51 | 80 |
macro avg | 0.47 | 0.51 | 0.48 | 80 |
weighted avg | 0.49 | 0.51 | 0.50 | 80 |

Evaluating model: CNN
Accuracy: 0.3875
F1 Score: 0.38712636370337605
Classification Report:
| | precision | recall | f1-score | support |
---|---|---|---|---|
blues | 0.22 | 0.40 | 0.29 | 5 |
classical | 0.88 | 0.64 | 0.74 | 11 |
country | 0.50 | 0.10 | 0.17 | 10 |
disco | 0.10 | 0.20 | 0.13 | 5 |
hiphop | 0.20 | 0.29 | 0.24 | 7 |
jazz | 0.33 | 0.56 | 0.42 | 9 |
metal | 0.56 | 0.62 | 0.59 | 8 |
pop | 0.57 | 0.57 | 0.57 | 7 |
reggae | 0.00 | 0.00 | 0.00 | 8 |
rock | 0.67 | 0.40 | 0.50 | 10 |
accuracy | | | 0.39 | 80 |
macro avg | 0.40 | 0.38 | 0.36 | 80 |
weighted avg | 0.45 | 0.39 | 0.39 | 80 |

Evaluating model: RNN
Accuracy: 0.6875
F1 Score: 0.6745319300489074
Classification Report:
| | precision | recall | f1-score | support |
---|---|---|---|---|
blues | 0.60 | 0.60 | 0.60 | 5 |
classical | 0.92 | 1.00 | 0.96 | 11 |
country | 0.71 | 0.50 | 0.59 | 10 |
disco | 0.67 | 0.40 | 0.50 | 5 |
hiphop | 0.67 | 0.86 | 0.75 | 7 |
jazz | 0.78 | 0.78 | 0.78 | 9 |
metal | 0.64 | 0.88 | 0.74 | 8 |
pop | 0.60 | 0.86 | 0.71 | 7 |
reggae | 0.57 | 0.50 | 0.53 | 8 |
rock | 0.57 | 0.40 | 0.47 | 10 |
accuracy | | | 0.69 | 80 |
macro avg | 0.67 | 0.68 | 0.66 | 80 |
weighted avg | 0.69 | 0.69 | 0.67 | 80 |
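The "F1 Score" line in each report matches the "weighted avg" row rather than the "macro avg" row, which suggests (an assumption, since the `metrics` helper is not shown in this section) that it is a support-weighted average of the per-genre F1-scores. As a rough check, the sketch below recomputes both averages from the rounded RNN validation figures above; `macro_f1` and `weighted_f1` are hypothetical helper names.

```python
# (f1-score, support) per genre, copied from the RNN validation report.
# These are the two-decimal values from the report, so results are approximate.
rnn_report = {
    "blues": (0.60, 5),
    "classical": (0.96, 11),
    "country": (0.59, 10),
    "disco": (0.50, 5),
    "hiphop": (0.75, 7),
    "jazz": (0.78, 9),
    "metal": (0.74, 8),
    "pop": (0.71, 7),
    "reggae": (0.53, 8),
    "rock": (0.47, 10),
}


def macro_f1(report):
    """Unweighted mean of the per-class F1-scores."""
    return sum(f1 for f1, _ in report.values()) / len(report)


def weighted_f1(report):
    """Mean of the per-class F1-scores, weighted by class support."""
    total = sum(support for _, support in report.values())
    return sum(f1 * support for f1, support in report.values()) / total


print("macro F1:   ", round(macro_f1(rnn_report), 3))
print("weighted F1:", round(weighted_f1(rnn_report), 3))
```

Because the inputs are rounded, the weighted result only agrees with the reported 0.6745 to about two decimals. The macro average comes out close to the 0.66 shown in the "macro avg" row.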
Analysis¶
Hyperparameter tuning again improved performance. The best model was the RNN, with an F1-score of 0.67 on the validation set. Our best traditional classifier was the SVM, with a similar but slightly lower F1-score. The CNN scored worst, which was interesting, as CNNs have led the field of audio classification for a while now, thanks to advances in image classification with CNN-based architectures. We should reiterate, though, that these simple nets are not comparable to state-of-the-art models, which are much larger and more complex. The CNNs that previously dominated image and audio classification had many more layers, with convolution kernels of different sizes.
The most noticeable difference between these models and our traditional algorithms is the time that training and tuning take. Here, we really see how the curse of dimensionality limits NNs. Even these "simple" neural nets require a lot of memory and time to train and fine-tune due to their larger parameter space. On my local machine, the RNN grid search took almost 30 minutes to compare only 2 configurations (including the final retraining). Counting the two configurations and the retraining as 3 models, that equates to about 10 minutes per model to achieve roughly the same performance as the SVM, which took less than 20 seconds.
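The timing comparison can be made concrete with a little arithmetic. The sketch below enumerates the RNN grid and counts individual fits, assuming 5-fold cross-validation (as in the CNN results table) plus one refit of the best candidate. The fold count for the RNN run is an assumption, although the ten interleaved epoch-1 rows in the training log are consistent with it. The durations are the approximate figures quoted in the text.

```python
from itertools import product

# Parameter grid copied from the RNN grid search above.
param_grid = {
    "module__hidden_size": [128 * 2],
    "module__num_layers": [2, 4],
}
n_folds = 5  # assumption: same number of CV splits as in the CNN results table

# Every combination of parameter values is one candidate configuration.
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]

cv_fits = len(candidates) * n_folds  # one fit per candidate per fold
total_fits = cv_fits + 1             # plus a final refit of the best candidate

# Approximate durations quoted in the text.
grid_minutes = 30   # "almost 30 minutes" for the whole RNN grid search
svm_seconds = 20    # "less than 20 seconds" for the SVM

# The text counts 3 models (2 configurations + 1 refit), i.e. ~10 minutes each.
minutes_per_model = grid_minutes / (len(candidates) + 1)
slowdown = minutes_per_model * 60 / svm_seconds

print(candidates)
print("fits:", cv_fits, "CV fits +1 refit =", total_fits)
print(f"~{minutes_per_model:.0f} min per model, roughly {slowdown:.0f}x the SVM")
```

Under these rough numbers, a single RNN model costs on the order of 30 times the SVM's training time, before accounting for the cross-validation fits running in parallel.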
These models are very sensitive to hyperparameters, so there are likely better configurations than the few we tested. Still, even this small amount of tuning gives an indication of the potential neural networks have for learning to classify music genres.
The confusion matrices, plotted using the same code as in our previous notebook, help us see which genres are easier to predict than others. Across the board, the "classical" genre consistently achieves good results (a low misclassification rate), while "rock" is usually among the hardest. Given that our set-up extracts MFCCs from the audio files, this makes sense. Classical music is usually very structured, uses a limited set of instruments, and generally has a clean-sounding output with very little "noise". Rock music, on the other hand, is usually more chaotic, uses a wider variety of instruments, and generally has a "noisier" signal.
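The per-genre misclassification rate mentioned above is simply one minus the recall for that genre. As a minimal sketch, it can be computed directly from `(prediction, label)` pairs like the ones printed by the grid search. `misclassification_rates` is a hypothetical helper, and the sample is a handful of hand-picked pairs for illustration only, not a representative subset of the results.

```python
from collections import Counter


def misclassification_rates(pairs):
    """Return {label: share of tracks with that true label that were mispredicted}."""
    totals, wrong = Counter(), Counter()
    for prediction, label in pairs:
        totals[label] += 1
        if prediction != label:
            wrong[label] += 1
    return {label: wrong[label] / totals[label] for label in totals}


# Hand-picked (prediction, label) pairs for illustration only.
sample = [
    ("classical", "classical"),
    ("classical", "classical"),
    ("pop", "classical"),
    ("rock", "rock"),
    ("metal", "rock"),
    ("blues", "rock"),
    ("jazz", "rock"),
]

for genre, rate in misclassification_rates(sample).items():
    print(f"{genre}: {rate:.2f}")
```

Applied to the full validation predictions, this gives the same information as reading each row of the confusion matrix: the off-diagonal share of every true genre.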
Conclusion¶
The results we achieved here were very similar to those of our traditional classifiers. Considering the training time and compute power required to train, tune, and host these models, one might argue that the traditional classifiers are the better choice for this simplified task. Again, it is important to note that these simple nets are not fully representative of all neural architectures. Rather, they serve as a baseline for what can be achieved with neural architectures.
In our next notebook, we will see how our results compare against a pretrained, state-of-the-art model.