Music genre classifier¶

Neural networks as classifiers¶

This notebook should be see as the third step in a series of notebooks aimed to build an ML audio classifier.

We continue our journey of music classification by trainging more complex models, such as CNNs and RNNs. After this, we will see how our results compare against a pretrained model. You can find the other notebooks for this experiment here:

  • data exploration
  • preprocessing
  • traditional classifiers
  • simple nn classifiers

Goal¶

Train simple neural net classifiers to predict the genre of a song.

Dataset¶

The dataset contains 1000 audio tracks each 30 seconds long. It contains 10 genres, each represented by 100 tracks. The tracks were all 22050Hz Mono 16-bit audio files in .wav format. In preprocess.py, we convert the .wav fiels to MFCC features, and store them as PyTorch tensors (mfcc.pt). Labels and file paths are stored as numpy-arrays.

Source¶

https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification/ (accessed 2023-10-20)

In [1]:
import os 
os.chdir("/Users/per.morten.halvorsen@schibsted.com/personal/music-genre-classifiers/")
In [2]:
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
from skorch import NeuralNetClassifier
from tqdm import tqdm

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

pio.renderers.default = "notebook"
torch.manual_seed(42)
Out[2]:
<torch._C.Generator at 0x1568ad950>

Load data¶

In [3]:
mfcc_tensor = torch.load("data/mfcc.pt")
covariance_tensor =  torch.load("data/covariance.pt")
file_paths = np.load("data/file_paths.npy")
labels = np.load("data/labels.npy")
In [4]:
mfcc_tensor.shape
Out[4]:
torch.Size([999, 2986, 13])
In [5]:
labels.shape
Out[5]:
(999,)
In [6]:
# for plotting
file_paths.shape
Out[6]:
(999,)
In [7]:
labels_to_idx = {label: idx for idx, label in enumerate(np.unique(labels))}
idx_to_labels = {idx: label for idx, label in enumerate(np.unique(labels))}
labels_to_idx
Out[7]:
{'blues': 0,
 'classical': 1,
 'country': 2,
 'disco': 3,
 'hiphop': 4,
 'jazz': 5,
 'metal': 6,
 'pop': 7,
 'reggae': 8,
 'rock': 9}

Build simple NN classifiers¶

We limit the scope of this notebook to the most basic NN architectures, namely MLPs, CNNs and RNNs.

MLP¶

A multi-layer perceptron (MLP) is a class of feedforward neural network (FNN) composed of multiple layers, each with their own activation function. Multilayer perceptrons can be considered the most "vanilla" neural networks, as they resemble the simplest possible neural network architecture: a linear classifier.

Architecture¶

The MLP consists of an input layer, a hidden layer and an output layer. The input layer is the same size as the input data, and the output layer is the same size as the number of classes. The hidden layer can be of any size, but is usually smaller than the input and output layers. The hidden layer is where the "magic" happens, as it is where the network learns to classify the data.

Training¶

The MLP is trained using backpropagation, which is an algorithm for computing the gradient of the loss function with respect to the weights of the network. The loss function is usually the cross-entropy loss, which is a measure of the difference between the predicted and the actual class. The weights are updated using gradient descent, which is an algorithm for minimizing the loss function.

Activation functions¶

The activation function is a non-linear function that is applied to the output of each neuron in the network. The activation function is what makes the MLP a universal function approximator, as it allows the network to learn non-linear functions. The activation function is usually applied to the output of each neuron in the hidden layer, but can also be applied to the output layer. The activation function is usually the sigmoid function, but can also be the hyperbolic tangent function or the rectified linear unit (ReLU) function.

In [8]:
features = 13
measurements = 2986
input_size = measurements * features  # Number of MFCC coefficients
hidden_size = 128  # Number of neurons in the hidden layer
num_classes = 10  # Number of music genres
criterion = nn.CrossEntropyLoss()

# Hyperparameters
num_epochs = 10
batch_size = 100
learning_rate = 0.001
rnn_layers = 2
In [9]:
class FFN(nn.Module):
    def __init__(self, input_size=input_size, hidden_size=hidden_size, num_classes=num_classes, num_layers=2):
        super(FFN, self).__init__()
        self.fc_first = nn.Linear(input_size, hidden_size)
        self.fc_last = nn.Linear(hidden_size, num_classes)
        self.num_layers = num_layers

        if num_layers > 1:
            self.fc_hidden = nn.ModuleList()
            for i in range(num_layers - 1):
                self.fc_hidden.append(nn.Linear(hidden_size, hidden_size))
    
    def forward(self, x):
        x = F.relu(self.fc_first(x))
        
        # hidden layers
        if self.num_layers > 1:
            for i in range(self.num_layers - 1):
                x = F.relu(self.fc_hidden[i](x))

        # output layer
        x = self.fc_last(x)
        
        return x
    

CNN¶

A convolutional neural network (CNN) is a class of feedforward neural network (FNN) composed of multiple layers, each with their own activation function. CNNs are a type of neural network that are particularly well suited for image classification, as they are able to learn spatial features. This means, they may perform better on spectrogram transformations of our data, rather than MFCC simplifications.

Architecture¶

The CNN consists of an input layer, a convolutional layer, a pooling layer, a hidden layer and an output layer. The input layer is the same size as the input data, and the output layer is the same size as the number of classes. The convolutional layer is where the "magic" happens, as it is where the network learns to classify the data. The pooling layer is used to reduce the dimensionality of the data, and is usually placed after the convolutional layer. The hidden layer can be of any size, but is usually smaller than the input and output layers.

Training¶

The CNN is trained using backpropagation, which is an algorithm for computing the gradient of the loss function with respect to the weights of the network. The loss function is usually the cross-entropy loss, which is a measure of the difference between the predicted and the actual class. The weights are updated using gradient descent, which is an algorithm for minimizing the loss function.

Activation functions¶

The activation function is a non-linear function that is applied to the output of each neuron in the network. The activation function is what makes the CNN a universal function approximator, as it allows the network to learn non-linear functions. The activation function is usually applied to the output of each neuron in the hidden layer, but can also be applied to the output layer. The activation function is usually the sigmoid function, but can also be the hyperbolic tangent function or the rectified linear unit (ReLU) function.

In [10]:
class CNN(nn.Module):
    def __init__(self, num_channels=13, num_classes=10, out_channels=32, measurements=2986, verbose=False):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(out_channels * measurements * 3, num_classes)  # Adjust the input size based on your data

        self.measurements = measurements
        
        self.verbose = verbose
    
    def forward(self, x):
        print("input", x.shape) if self.verbose else None
        x = self.conv1(x)
        print("conv1", x.shape) if self.verbose else None
        x = self.relu(x)
        print("relu", x.shape) if self.verbose else None
        x = self.maxpool(x)
        print("maxpool", x.shape) if self.verbose else None
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        print("fc", x.shape) if self.verbose else None
        print("-"*10) if self.verbose else None
        return x

RNN¶

A recurrent neural network (RNN) is a class of feedforward neural network (FNN) composed of multiple layers, each with their own activation function. RNNs are a type of neural network that are particularly well suited for time series classification, as they are able to learn temporal features. This means, they may perform better on raw audio data, rather than spectrogram transformations.

Architecture¶

The RNN consists of an input layer, a recurrent layer, a hidden layer and an output layer. The input layer is the same size as the input data, and the output layer is the same size as the number of classes. The recurrent layer is where the "magic" happens, as it is where the network learns to classify the data. The hidden layer can be of any size, but is usually smaller than the input and output layers.

Training¶

The RNN is trained using backpropagation, which is an algorithm for computing the gradient of the loss function with respect to the weights of the network. The loss function is usually the cross-entropy loss, which is a measure of the difference between the predicted and the actual class. The weights are updated using gradient descent, which is an algorithm for minimizing the loss function.

Activation functions¶

The activation function is a non-linear function that is applied to the output of each neuron in the network. The activation function is what makes the RNN a universal function approximator, as it allows the network to learn non-linear functions. The activation function is usually applied to the output of each neuron in the hidden layer, but can also be applied to the output layer. The activation function is usually the sigmoid function, but can also be the hyperbolic tangent function or the rectified linear unit (ReLU) function.

In [11]:
class RNN(nn.Module):
    def __init__(self, input_size=13, hidden_size=15, num_layers=2, num_classes=10, verbose=False):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

        self.verbose = verbose
    
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Forward pass through the GRU layer
        gru_out, _ = self.gru(x)

        # Use the final hidden state as the input for the fully connected layer
        output = self.fc(gru_out[:, -1, :])

        return output

Train test split¶

In [12]:
# Reshape the data into a 2D array (num_samples, num_features)
num_samples, num_frames, num_mfcc = mfcc_tensor.shape
mfcc_tensor_2d = np.reshape(mfcc_tensor, (num_samples, num_frames * num_mfcc))

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(mfcc_tensor_2d, labels, test_size=0.2, random_state=42)

# Get validation set
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)
In [13]:
uniques, counts = np.unique(y_train, return_counts=True)
dict(zip(uniques, counts))
Out[13]:
{'blues': 73,
 'classical': 61,
 'country': 68,
 'disco': 71,
 'hiphop': 73,
 'jazz': 71,
 'metal': 80,
 'pop': 72,
 'reggae': 77,
 'rock': 73}

Train methods¶

In [14]:
X_train.reshape(*[-1, 1, measurements, features]).shape
Out[14]:
torch.Size([719, 1, 2986, 13])
In [15]:
def train_batch(batch, model, criterion, optimizer):
    # Get the batch of data
    batch_X, batch_y = batch
    # convert strings to ids
    batch_y = np.array([labels_to_idx[x] for x in batch_y])
    batch_X = batch_X.float()
    
    # Zero out the gradients
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(batch_X)
    loss = criterion(outputs, torch.tensor(batch_y))
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    return loss.item()  


def eval_batch(X_test, y_test, model, criterion):
     # Evaluate
    model.eval()
    with torch.no_grad():
        X_test = X_test.float()
        y_test = torch.tensor([labels_to_idx[x] for x in y_test])
        outputs = model(X_test)
        loss = criterion(outputs, y_test)
    
    return loss.item()


def train_epoch(X_train, y_train, model, criterion, optimizer, batch_size=100):
    # Shuffle the training data
    indices = np.arange(len(X_train))
    np.random.shuffle(indices)
    
    # Create batches
    num_batches = len(X_train) // batch_size
    batches = [(X_train[i*batch_size:(i+1)*batch_size], y_train[i*batch_size:(i+1)*batch_size]) for i in range(num_batches)]
    
    # Train each batch
    losses = []
    for batch in batches:
        loss = train_batch(batch, model, criterion, optimizer)
        losses.append(loss)
    
    return losses


def train_model(X_train, y_train, X_test, y_test, model, criterion, optimizer, reshape=None, num_epochs=10, batch_size=100, verbose=False, n=10):
    train_losses = []
    eval_losses = []
    average_loss = []

    if reshape:
        X_train = X_train.reshape(*reshape)
        X_test = X_test.reshape(*reshape)
    
    for epoch in tqdm(range(num_epochs)):
        
        # Train
        model.train()
        losses = train_epoch(X_train, y_train, model, criterion, optimizer, batch_size=batch_size)
        train_losses.extend(losses)
        average_loss.append(np.mean(losses))
        
        # Evaluate
        model.eval()
        with torch.no_grad():
            eval_loss = eval_batch(X_test, y_test, model, criterion)
            eval_losses.append(eval_loss)
        
        if verbose or epoch % n == 0:
            print(
                'Epoch: {}'.format(epoch),
                'Train loss: {:.4f}'.format(losses[-1]),
                'Test  loss: {:.4f}'.format(eval_losses[-1])
            )
        

    return train_losses, eval_losses, average_loss


X_train.numpy().shape
Out[15]:
(719, 38818)
In [16]:
def plot_losses(train_losses, val_losses, model=""):
    """Plot using Plotly Express"""
    import plotly.express as px
    import pandas as pd
    pd.options.plotting.backend = "plotly"
    
    df = pd.DataFrame({
        'epoch': np.arange(len(train_losses)),
        'train_loss': train_losses,
        'val_loss': val_losses
    })
    
    fig = px.line(df, x='epoch', y=['train_loss', 'val_loss'], title=f'Losses {model}')
    fig.show()

Training¶

Feed Forward¶

In [17]:
input_size = 38818  # fixed
num_classes = 10    # fixed

hidden_size = 128*7 # tunable
num_layers = 2      # tunable
lr=0.00001          # tunable
In [18]:
ffn_model = FFN(input_size, 128*5, num_classes)
optimizer_ffn = optim.Adam(ffn_model.parameters(), lr=0.00001)  # double check optimizer set-up

ffn_train_losses, ffn_val_losses, ffn_avg_losses = train_model(
    X_train, y_train, X_test, y_test, ffn_model, criterion, optimizer_ffn, 
    num_epochs=15, batch_size=batch_size, verbose=False
)

plot_losses(ffn_avg_losses, ffn_val_losses, "FFN")
  7%|▋         | 1/15 [00:00<00:05,  2.43it/s]
Epoch: 0 Train loss: 2.1800 Test  loss: 1.9666
 73%|███████▎  | 11/15 [00:03<00:01,  3.01it/s]
Epoch: 10 Train loss: 0.0359 Test  loss: 1.5808
100%|██████████| 15/15 [00:04<00:00,  3.01it/s]

CNN¶

In [19]:
# Initialize the CNN model
num_channels = 1  # Since each feature is treated as a channel
num_classes = 10  # Number of output classes
In [20]:
# Initialize the CNN model, loss function, and optimizer
cnn_model = CNN(num_channels, num_classes, out_channels=32, measurements=2986)
optimizer_cnn = optim.Adam(cnn_model.parameters(), lr=0.0001)

cnn_train_losses, cnn_val_losses, cnn_avg_losses = train_model(
    X_train, y_train, X_test, y_test, cnn_model, criterion, optimizer_cnn, 
    reshape=[-1, 1, measurements, features],
    num_epochs=15, batch_size=batch_size, verbose=False
)

plot_losses(cnn_avg_losses, cnn_val_losses, "CNN")
  0%|          | 0/15 [00:00<?, ?it/s]  7%|▋         | 1/15 [00:03<00:54,  3.89s/it]
Epoch: 0 Train loss: 682.3203 Test  loss: 656.0635
 73%|███████▎  | 11/15 [00:40<00:14,  3.71s/it]
Epoch: 10 Train loss: 9.3461 Test  loss: 20.5606
100%|██████████| 15/15 [00:55<00:00,  3.73s/it]

RNN¶

In [21]:
features = 13  # Number of features
hidden_size = 128  # Number of hidden units in the RNN layer
num_layers = 2  # Number of RNN layers
num_classes = 10  # Number of output classes (genres)
batch_size = 100  # Number of examples in a batch
In [22]:
# Initialize the RNN model, loss function, and optimizer
rnn_model = RNN(features, hidden_size, num_layers*2, num_classes, verbose=False)
optimizer_rnn = optim.Adam(rnn_model.parameters(), lr=.01, weight_decay=1e-7)

rnn_train_losses, rnn_val_losses, rnn_avg_losses = train_model(
    X_train, y_train, X_test, y_test, rnn_model, criterion, optimizer_rnn, 
    reshape=[-1, measurements, features],
    num_epochs=10, batch_size=64, verbose=False
)

plot_losses(rnn_avg_losses, rnn_val_losses, "RNN")
 10%|█         | 1/10 [00:17<02:34, 17.13s/it]
Epoch: 0 Train loss: 1.8639 Test  loss: 1.7750
100%|██████████| 10/10 [02:49<00:00, 16.92s/it]

Evaluate¶

In [23]:
def metrics(predictions, y_labels, verbose=False):
    """Calculate accuracy, F1 score, and confusion matrix"""
    accuracy = accuracy_score(y_labels, predictions)
    f1 = f1_score(y_labels, predictions, average='weighted', zero_division=0)
    confusion = confusion_matrix(y_labels, predictions, labels=list(labels_to_idx))
    report = classification_report(y_labels, predictions, zero_division=0)

    if verbose:
        print("Accuracy:", accuracy)
        print("F1 Score:", f1) 
        if verbose > 1:
            print("Classification Report:\n", report)
    return accuracy, f1, confusion, report
In [24]:
def plot_confusion_matrix(cm, classes=list((labels_to_idx)), name="", cmap=px.colors.sequential.Blues):
    """
    This function prints and plots the confusion matrix.
    """
    fig = px.imshow(cm, x=classes, y=classes, color_continuous_scale=cmap)
    fig.update_layout(title="Confusion matrix "+name, xaxis_title="Predicted", yaxis_title="Actual")
    fig.show()
In [25]:
def evaluate(model, X_val=X_val, y_val=y_val, name="", reshape=None, plot=True, verbose=True):
    """Evaluate the model on the validation set"""
    model.eval()
    with torch.no_grad():
        X_val = X_val.reshape(*reshape).float() if reshape else X_val.float()
        # y_val = torch.tensor([labels_to_idx[x] for x in y_val])

        outputs = model(X_val)

        # ffn_preds = np.argmax(ffn_outputs, axis=1)
        _, predictions = torch.max(outputs.data, 1)
        predictions = np.array([idx_to_labels[pred.item()] for pred in predictions])
        accuracy, f1, confusion, report = metrics(predictions, y_val, verbose=verbose)

        print("(Prediction, Label): ", list(zip(predictions, y_val))) if verbose>1 else None

        plot_confusion_matrix(confusion, name=name) if plot else None
    
    return accuracy, f1, confusion, report
In [26]:
ffn_eval = evaluate(ffn_model, name="FFN", verbose=2)
Accuracy: 0.55
F1 Score: 0.5165582516898306
Classification Report:
               precision    recall  f1-score   support

       blues       0.25      0.20      0.22         5
   classical       0.80      0.73      0.76        11
     country       0.50      0.10      0.17        10
       disco       0.00      0.00      0.00         5
      hiphop       1.00      0.57      0.73         7
        jazz       0.58      0.78      0.67         9
       metal       0.67      1.00      0.80         8
         pop       0.50      0.86      0.63         7
      reggae       0.50      0.62      0.56         8
        rock       0.36      0.40      0.38        10

    accuracy                           0.55        80
   macro avg       0.52      0.53      0.49        80
weighted avg       0.55      0.55      0.52        80

(Prediction, Label):  [('disco', 'hiphop'), ('metal', 'metal'), ('rock', 'blues'), ('jazz', 'jazz'), ('jazz', 'country'), ('classical', 'classical'), ('country', 'classical'), ('classical', 'classical'), ('classical', 'classical'), ('rock', 'jazz'), ('jazz', 'reggae'), ('metal', 'metal'), ('blues', 'rock'), ('rock', 'reggae'), ('pop', 'rock'), ('reggae', 'disco'), ('blues', 'country'), ('blues', 'blues'), ('metal', 'metal'), ('metal', 'metal'), ('reggae', 'reggae'), ('pop', 'disco'), ('rock', 'rock'), ('pop', 'pop'), ('jazz', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('metal', 'rock'), ('rock', 'classical'), ('pop', 'rock'), ('classical', 'classical'), ('rock', 'rock'), ('reggae', 'hiphop'), ('country', 'country'), ('jazz', 'classical'), ('hiphop', 'hiphop'), ('disco', 'country'), ('jazz', 'country'), ('pop', 'disco'), ('jazz', 'jazz'), ('classical', 'classical'), ('reggae', 'blues'), ('rock', 'rock'), ('reggae', 'reggae'), ('rock', 'country'), ('metal', 'metal'), ('pop', 'pop'), ('classical', 'country'), ('pop', 'disco'), ('hiphop', 'hiphop'), ('metal', 'reggae'), ('disco', 'rock'), ('reggae', 'reggae'), ('reggae', 'reggae'), ('pop', 'pop'), ('jazz', 'jazz'), ('reggae', 'country'), ('jazz', 'country'), ('pop', 'pop'), ('metal', 'metal'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('blues', 'country'), ('jazz', 'jazz'), ('jazz', 'jazz'), ('metal', 'metal'), ('rock', 'blues'), ('rock', 'blues'), ('reggae', 'pop'), ('reggae', 'reggae'), ('jazz', 'jazz'), ('classical', 'classical'), ('metal', 'metal'), ('pop', 'hiphop'), ('pop', 'pop'), ('hiphop', 'hiphop'), ('classical', 'jazz'), ('metal', 'rock'), ('rock', 'rock'), ('metal', 'disco')]
In [27]:
cnn_eval = evaluate(cnn_model, name="CNN", plot=True, reshape=[-1, 1, measurements, features])
Accuracy: 0.4625
F1 Score: 0.4765183394703768
In [28]:
rnn_eval = evaluate(rnn_model, X_val=X_val, name="RNN", plot=True, reshape=[-1, measurements, features])
Accuracy: 0.6
F1 Score: 0.5521900581788353

Hyperparameter tuning¶

In [29]:
def get_module_name(model):
    """Get the name of the model"""
    return model.__class__.__name__
In [30]:
def grid_search(
        classifier, params, cv=5, 
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, 
        reshape=None,
        return_full_metrics=False, verbose=False, plot=True, name=""
    ):
    # Create a GridSearchCV object with the specified parameter grid and classifier
    grid_search = GridSearchCV(estimator=classifier, param_grid=params, cv=cv, n_jobs=-1, scoring="f1_macro")  #scoring= f1 with average='weighted'?

    if reshape:
        X_train = X_train.reshape(*reshape)
        X_test = X_test.reshape(*reshape)

    # Perform grid search on your data
    label_idx_train = np.array([labels_to_idx[x] for x in y_train])
    grid_search.fit(X_train, label_idx_train)

    # Print the best parameters found by the grid search
    print("Best Parameters:", grid_search.best_params_)

    # Make predictions using the best estimator
    predictions = grid_search.predict(X_test)
    predictions = np.array([idx_to_labels[pred] for pred in predictions])
    print("(Prediction, Label): ", list(zip(predictions, y_test))) if verbose>1 else None

    accuracy, f1, cm, cr = metrics(predictions, y_test, verbose=verbose)
    
    # append classifier name and params for plotting
    grid_search.name = get_module_name(classifier)
    grid_search.params = params
    grid_search.f1 = f1
    grid_search.accuracy = accuracy
    grid_search.cm = cm
    grid_search.cr = cr

    if return_full_metrics:
        return accuracy, f1, cm, cr, grid_search
    
    plot_confusion_matrix(cm, name=name) if plot else None

    return grid_search


def plot_grid_seach(gs, score_col="mean_test_score", param_cols=None, verbose=0):
    """#d plot of grid search with params on x and y and score on z axis"""
    gs_df = pd.DataFrame(gs.cv_results_)

    # get score column
    if score_col is None:
        score_col = [x for x in gs_df.columns if "score" in x].pop()
        print("score_col", score_col) if verbose>1 else None

    # get param cols
    if param_cols is None:
        param_cols = [x for x in gs_df.columns if "param_" in x]
        print("param_cols", param_cols) if verbose >3 else None

    # get sizes
    x_size = len(gs_df[param_cols[0]].unique())
    y_size = len(gs_df[param_cols[1]].unique())

    # get x, y, z  # need smart way of finding df size..
    x = gs_df[param_cols[0]].values.reshape(x_size, y_size).T[0]
    y = gs_df[param_cols[1]].values.reshape(y_size, x_size)[0]
    z = gs_df[score_col].values.reshape(x_size, y_size).T

    fig = go.Figure(
        data=[go.Surface(
            x=x, y=y, z=z, 
            hovertemplate=f"{param_cols[0]}: {'%{x}'}<br>{param_cols[1]}: {'%{y}'}<br>{score_col}: {'%{z}'}<extra></extra>",
        )]
    )
    fig.update_layout(
        title=f"GridSearchCV Results for {gs.name} Classifier", 
        scene=dict(
            xaxis_title=param_cols[0], 
            yaxis_title=param_cols[1], 
            zaxis_title=score_col,
            # xaxis_type="log" if "x" in log else "linear",
            # yaxis_type="log" if "y" in log else "linear",
            # zaxis_type="log" if "z" in log else "linear",
        ),
        height=750,
    )

    if verbose > 2:
        print("x", x)
        print("y", y)
        print("z", z)

    return fig.show()

Feed Forward¶

In [31]:
# Initialize the FFN model, loss function, and optimizer
ffn_net = NeuralNetClassifier(
    FFN, 
    module__input_size=input_size,  # 38818
    module__num_classes=num_classes,  # 10
    criterion=nn.CrossEntropyLoss, 
    optimizer=optim.Adam, 
    lr=0.00001, 
    max_epochs=10, 
    batch_size=100, 
    # device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=0
)
In [32]:
ffn_gs = grid_search(
    ffn_net,
    params={
        "module__hidden_size": [128*5, 128*6, 128*7, 128*8, 128*9],
        "module__num_layers": [1, 2, 3, 4, 5],
    },
    name="FFN",
    verbose=2,
)

# so far best params: {'module__hidden_size': 640, 'module__num_layers': 3, 'optimizer': <class 'torch.optim.adam.Adam'>, 'optimizer__lr': 1e-05}
# Accuracy: 0.525
# F1 Score: 0.5222471631893988
# weird that we got all the way up to .60 earlier 
# though that one had 7 layers  
/opt/homebrew/Caskroom/miniconda/base/envs/music-genre/lib/python3.11/site-packages/joblib/externals/loky/process_executor.py:752: UserWarning:

A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.

Best Parameters: {'module__hidden_size': 1024, 'module__num_layers': 3}
(Prediction, Label):  [('hiphop', 'reggae'), ('disco', 'hiphop'), ('disco', 'disco'), ('blues', 'blues'), ('disco', 'hiphop'), ('classical', 'classical'), ('rock', 'country'), ('classical', 'classical'), ('hiphop', 'reggae'), ('disco', 'rock'), ('jazz', 'jazz'), ('blues', 'rock'), ('rock', 'country'), ('disco', 'rock'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('country', 'country'), ('blues', 'reggae'), ('jazz', 'blues'), ('rock', 'disco'), ('classical', 'classical'), ('metal', 'blues'), ('blues', 'country'), ('jazz', 'country'), ('classical', 'jazz'), ('pop', 'pop'), ('metal', 'metal'), ('country', 'disco'), ('rock', 'country'), ('blues', 'blues'), ('classical', 'classical'), ('rock', 'blues'), ('blues', 'rock'), ('classical', 'classical'), ('rock', 'rock'), ('blues', 'rock'), ('metal', 'metal'), ('metal', 'metal'), ('disco', 'rock'), ('classical', 'classical'), ('blues', 'hiphop'), ('metal', 'blues'), ('jazz', 'jazz'), ('reggae', 'reggae'), ('metal', 'metal'), ('classical', 'classical'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('blues', 'country'), ('metal', 'rock'), ('classical', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'pop'), ('rock', 'blues'), ('pop', 'pop'), ('classical', 'classical'), ('classical', 'classical'), ('pop', 'hiphop'), ('country', 'pop'), ('disco', 'rock'), ('rock', 'country'), ('classical', 'classical'), ('metal', 'hiphop'), ('blues', 'blues'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('jazz', 'country'), ('country', 'pop'), ('metal', 'metal'), ('blues', 'rock'), ('metal', 'metal'), ('reggae', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'hiphop'), ('classical', 'classical'), ('pop', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('pop', 'pop'), ('country', 'country'), ('blues', 'blues'), ('rock', 'reggae'), ('country', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('hiphop', 'reggae'), ('classical', 'blues'), ('rock', 'reggae'), ('jazz', 'country'), ('classical', 'classical'), ('jazz', 'hiphop'), ('pop', 'pop'), ('pop', 'pop'), ('pop', 'country'), ('blues', 'rock'), ('hiphop', 'country'), ('classical', 'classical'), ('hiphop', 'reggae'), ('rock', 'disco'), ('disco', 'jazz'), ('jazz', 'classical'), ('hiphop', 'jazz'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('country', 'jazz'), ('rock', 'rock'), ('blues', 'jazz'), ('metal', 'metal'), ('pop', 'pop'), ('metal', 'disco'), ('blues', 'blues'), ('metal', 'blues'), ('classical', 'classical'), ('rock', 'blues'), ('metal', 'hiphop'), ('rock', 'disco'), ('disco', 'disco'), ('country', 'disco'), ('jazz', 'country'), ('country', 'country'), ('hiphop', 'hiphop'), ('pop', 'jazz'), ('hiphop', 'metal'), ('rock', 'disco'), ('country', 'blues'), ('rock', 'country'), ('jazz', 'reggae'), ('disco', 'jazz'), ('classical', 'classical'), ('pop', 'pop'), ('metal', 'metal'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('pop', 'pop'), ('reggae', 'pop'), ('jazz', 'jazz'), ('classical', 'classical'), ('disco', 'disco'), ('disco', 'disco'), ('rock', 'disco'), ('metal', 'metal'), ('classical', 'country'), ('blues', 'blues'), ('disco', 'country'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('metal', 'disco'), ('country', 'jazz'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('rock', 'disco'), ('pop', 'pop'), ('rock', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('rock', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('classical', 'classical'), ('rock', 'rock'), ('blues', 'blues'), ('jazz', 'jazz'), ('classical', 'classical'), ('jazz', 'blues'), ('disco', 'disco'), ('blues', 'metal'), ('reggae', 'reggae'), ('blues', 'country'), ('metal', 'rock'), ('jazz', 'jazz'), ('rock', 'rock'), ('disco', 'hiphop'), ('jazz', 'jazz'), ('rock', 'reggae'), ('reggae', 'hiphop'), ('blues', 'reggae'), ('metal', 'disco'), ('metal', 'metal'), ('classical', 'hiphop'), ('blues', 'blues'), ('country', 'disco'), ('country', 'country'), ('disco', 'disco'), ('classical', 'classical'), ('country', 'jazz'), ('reggae', 'country'), ('jazz', 'jazz'), ('blues', 'country'), ('blues', 'blues'), ('country', 'disco'), ('jazz', 'blues'), ('pop', 'pop'), ('pop', 'pop'), ('metal', 'disco'), ('disco', 'disco'), ('disco', 'reggae')]
Accuracy: 0.505
F1 Score: 0.4895867164687552
Classification Report:
               precision    recall  f1-score   support

       blues       0.42      0.45      0.43        22
   classical       0.87      0.93      0.90        28
     country       0.27      0.18      0.22        22
       disco       0.42      0.33      0.37        24
      hiphop       0.45      0.45      0.45        20
        jazz       0.47      0.47      0.47        19
       metal       0.48      0.83      0.61        12
         pop       0.78      0.86      0.82        21
      reggae       0.33      0.13      0.19        15
        rock       0.22      0.29      0.25        17

    accuracy                           0.51       200
   macro avg       0.47      0.49      0.47       200
weighted avg       0.49      0.51      0.49       200

In [33]:
plot_grid_seach(ffn_gs)

CNN¶

This one might be difficult to hyperparameter to tune.

In [34]:
# Initialize the FFN model, loss function, and optimizer
cnn_net = NeuralNetClassifier(
    CNN, 
    module__num_channels=1,  # Since each feature is treated as a channel
    criterion=nn.CrossEntropyLoss, 
    optimizer=optim.Adam, 
    lr=0.00001, 
    max_epochs=10, 
    batch_size=100, 
    # device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=0
)
In [35]:
cnn_gs = grid_search(
    cnn_net,
    params={
        "optimizer__lr": [0.0001, 0.00001],
        "optimizer": [optim.Adam, optim.SGD],
    },
    name="CNN",
    reshape=[-1, 1, measurements, features],
    verbose=2,
)
Best Parameters: {'optimizer': <class 'torch.optim.adam.Adam'>, 'optimizer__lr': 1e-05}
(Prediction, Label):  [('hiphop', 'reggae'), ('disco', 'hiphop'), ('disco', 'disco'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('country', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('country', 'rock'), ('hiphop', 'jazz'), ('blues', 'rock'), ('rock', 'country'), ('disco', 'rock'), ('metal', 'hiphop'), ('pop', 'pop'), ('jazz', 'country'), ('blues', 'reggae'), ('jazz', 'blues'), ('disco', 'disco'), ('classical', 'classical'), ('rock', 'blues'), ('jazz', 'country'), ('blues', 'country'), ('classical', 'jazz'), ('metal', 'pop'), ('metal', 'metal'), ('pop', 'disco'), ('reggae', 'country'), ('blues', 'blues'), ('classical', 'classical'), ('classical', 'blues'), ('rock', 'rock'), ('classical', 'classical'), ('blues', 'rock'), ('rock', 'rock'), ('hiphop', 'metal'), ('metal', 'metal'), ('disco', 'rock'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('metal', 'blues'), ('classical', 'jazz'), ('hiphop', 'reggae'), ('metal', 'metal'), ('classical', 'classical'), ('pop', 'disco'), ('hiphop', 'hiphop'), ('disco', 'pop'), ('blues', 'country'), ('blues', 'rock'), ('classical', 'classical'), ('disco', 'jazz'), ('pop', 'pop'), ('country', 'pop'), ('jazz', 'blues'), ('pop', 'pop'), ('classical', 'classical'), ('blues', 'classical'), ('hiphop', 'hiphop'), ('country', 'pop'), ('disco', 'rock'), ('country', 'country'), ('classical', 'classical'), ('metal', 'hiphop'), ('blues', 'blues'), ('classical', 'classical'), ('disco', 'hiphop'), ('metal', 'disco'), ('classical', 'country'), ('blues', 'pop'), ('metal', 'metal'), ('disco', 'rock'), ('metal', 'metal'), ('hiphop', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('disco', 'hiphop'), ('classical', 'classical'), ('hiphop', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('metal', 'pop'), ('blues', 'country'), ('blues', 'blues'), ('metal', 'reggae'), ('jazz', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('hiphop', 'reggae'), ('classical', 'blues'), ('disco', 'reggae'), ('jazz', 'country'), ('classical', 'classical'), ('jazz', 'hiphop'), ('pop', 'pop'), ('disco', 'pop'), ('pop', 'country'), ('blues', 'rock'), ('pop', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('disco', 'disco'), ('disco', 'jazz'), ('jazz', 'classical'), ('hiphop', 'jazz'), ('reggae', 'hiphop'), ('classical', 'classical'), ('metal', 'jazz'), ('jazz', 'rock'), ('blues', 'jazz'), ('disco', 'metal'), ('pop', 'pop'), ('disco', 'disco'), ('hiphop', 'blues'), ('blues', 'blues'), ('classical', 'classical'), ('jazz', 'blues'), ('hiphop', 'hiphop'), ('country', 'disco'), ('disco', 'disco'), ('disco', 'disco'), ('country', 'country'), ('jazz', 'country'), ('hiphop', 'hiphop'), ('disco', 'jazz'), ('metal', 'metal'), ('country', 'disco'), ('blues', 'blues'), ('hiphop', 'country'), ('blues', 'reggae'), ('country', 'jazz'), ('classical', 'classical'), ('disco', 'pop'), ('metal', 'metal'), ('metal', 'hiphop'), ('pop', 'pop'), ('jazz', 'pop'), ('hiphop', 'pop'), ('pop', 'jazz'), ('classical', 'classical'), ('disco', 'disco'), ('classical', 'disco'), ('disco', 'disco'), ('metal', 'metal'), ('disco', 'country'), ('blues', 'blues'), ('pop', 'country'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('blues', 'disco'), ('rock', 'jazz'), ('blues', 'blues'), ('metal', 'hiphop'), ('jazz', 'disco'), ('jazz', 'pop'), ('rock', 'rock'), ('classical', 'classical'), ('hiphop', 'reggae'), ('country', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('classical', 'classical'), ('classical', 'rock'), ('blues', 'blues'), ('jazz', 'jazz'), ('classical', 'classical'), ('rock', 'blues'), ('disco', 'disco'), ('blues', 'metal'), ('blues', 'reggae'), ('rock', 'country'), ('metal', 'rock'), ('jazz', 'jazz'), ('jazz', 'rock'), ('disco', 'hiphop'), ('jazz', 'jazz'), ('rock', 'reggae'), ('reggae', 'hiphop'), ('country', 'reggae'), ('disco', 'disco'), ('metal', 'metal'), ('rock', 'hiphop'), ('blues', 'blues'), ('country', 'disco'), ('hiphop', 'country'), ('disco', 'disco'), ('classical', 'classical'), ('country', 'jazz'), ('pop', 'country'), ('jazz', 'jazz'), ('country', 'country'), ('blues', 'blues'), ('rock', 'disco'), ('classical', 'blues'), ('pop', 'pop'), ('disco', 'pop'), ('jazz', 'disco'), ('disco', 'disco'), ('disco', 'reggae')]
Accuracy: 0.44
F1 Score: 0.42098351495635133
Classification Report:
               precision    recall  f1-score   support

       blues       0.42      0.50      0.46        22
   classical       0.76      0.89      0.82        28
     country       0.29      0.18      0.22        22
       disco       0.39      0.50      0.44        24
      hiphop       0.35      0.40      0.37        20
        jazz       0.24      0.26      0.25        19
       metal       0.45      0.75      0.56        12
         pop       0.56      0.43      0.49        21
      reggae       0.40      0.13      0.20        15
        rock       0.27      0.18      0.21        17

    accuracy                           0.44       200
   macro avg       0.41      0.42      0.40       200
weighted avg       0.43      0.44      0.42       200

In [36]:
pd.DataFrame(cnn_gs.cv_results_).head()
Out[36]:
mean_fit_time std_fit_time mean_score_time std_score_time param_optimizer param_optimizer__lr params split0_test_score split1_test_score split2_test_score split3_test_score split4_test_score mean_test_score std_test_score rank_test_score
0 613.343941 83.638880 7.724842 1.238205 <class 'torch.optim.adam.Adam'> 0.0001 {'optimizer': <class 'torch.optim.adam.Adam'>,... 0.320108 0.318683 0.299336 0.320679 0.326986 0.317158 0.009356 2
1 589.521977 101.374310 4.592106 2.383811 <class 'torch.optim.adam.Adam'> 0.00001 {'optimizer': <class 'torch.optim.adam.Adam'>,... 0.384808 0.333293 0.304256 0.308969 0.391570 0.344579 0.037008 1
2 404.556484 198.926591 4.247250 3.705419 <class 'torch.optim.sgd.SGD'> 0.0001 {'optimizer': <class 'torch.optim.sgd.SGD'>, '... 0.206964 0.236803 0.176908 0.173035 0.172753 0.193293 0.025222 3
3 124.940959 32.440937 1.119628 0.078850 <class 'torch.optim.sgd.SGD'> 0.00001 {'optimizer': <class 'torch.optim.sgd.SGD'>, '... 0.105875 0.017722 0.020000 0.128026 0.163816 0.087088 0.058700 4

Trouble plotting optimizers as non-string objects. Just skipping since the score was quite low anyway.

RNN¶

In [37]:
rnn_net = NeuralNetClassifier(
    module=RNN,
    module__input_size=13,
    module__hidden_size=13*4,
    module__num_layers=2,
    module__num_classes=10,
    optimizer=optim.Adam,
    optimizer__lr=0.01,
    optimizer__weight_decay=1e-7,
    max_epochs=10,
    criterion=nn.CrossEntropyLoss,
)
In [38]:
rnn_gs = grid_search(
    rnn_net,
    params={
        # "module__hidden_size": [128, 128*2, 128*3],  # this was tested, but took very. had to rerun, so commented out
        "module__hidden_size": [128*2],
        "module__num_layers": [2, 4],
    },
    name="RNN",
    reshape=[-1, measurements, features],
    verbose=2,
)
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        2.5345       0.2414        2.2735  82.2726
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        2.5644       0.2000        2.1648  93.6890
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        2.4373       0.2783        2.1152  93.1534
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        2.5388       0.2957        2.2550  92.3371
  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        2.5434       0.2957        2.0040  100.4151
  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        2.4403       0.2500        2.0954  101.2073
  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        2.4720       0.2696        2.0817  101.4083
  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        2.5218       0.2783        2.2030  118.2089
  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        2.5695       0.2870        2.0945  123.8702
  epoch    train_loss    valid_acc    valid_loss       dur
-------  ------------  -----------  ------------  --------
      1        2.5273       0.2783        2.2367  125.8745
      2        2.0470       0.3362        1.8447  81.0597
      2        1.9725       0.3478        1.7494  90.1354
      2        1.9374       0.3707        1.7632  90.7828
      2        1.9183       0.3478        1.7623  92.0393
      2        1.9541       0.3739        1.8473  96.2930
      2        1.9408       0.3739        1.8124  84.8313
      2        1.9354       0.3826        1.7385  105.2375
      2        1.8960       0.3217        1.8241  81.6591
      2        2.0058       0.3217        1.8419  114.9778
      2        1.8977       0.3565        1.8114  96.9768
      3        1.5350       0.4397        1.6256  109.3141
      3        1.4797       0.4348        1.6595  93.2377
      3        1.5328       0.3739        1.6606  73.9458
      3        1.5086       0.3652        1.6555  101.7210
      3        1.5599       0.4870        1.6202  95.6880
      3        1.4986       0.4000        1.6345  106.7779
      3        1.4662       0.4261        1.6565  101.8475
      3        1.4936       0.4000        1.6438  114.7037
      3        1.4820       0.4310        1.5537  116.2021
      3        1.5446       0.3913        1.7129  113.9270
      4        1.3178       0.4569        1.5212  92.5528
      4        1.2898       0.4696        1.5105  90.3154
      4        1.2860       0.4261        1.5334  88.4542
      4        1.2507       0.4087        1.5069  103.1782
      4        1.2203       0.4087        1.5482  115.7054
      4        1.2959       0.4261        1.5174  115.4407
      4        1.2438       0.4609        1.5431  104.0178
      4        1.2768       0.5043        1.4976  116.7335
      4        1.2592       0.4828        1.4517  102.8468
      4        1.3332       0.3913        1.6107  105.8596
      5        1.0918       0.4741        1.4477  94.6113
      5        1.0127       0.4828        1.3555  68.4263
      5        1.0493       0.4609        1.4419  68.9238
      5        1.0128       0.4609        1.4685  107.2677
      5        1.0664       0.5217        1.4471  107.7476
      5        0.9877       0.4261        1.4937  109.5112
      5        0.9937       0.4870        1.5122  109.5989
      5        1.0304       0.4609        1.4728  109.8205
      5        1.0464       0.4696        1.4510  109.7124
      5        1.1008       0.4000        1.5011  91.4517
      6        0.9084       0.4914        1.4073  78.4099
      6        0.8546       0.5043        1.3655  92.4999
      6        0.8164       0.5086        1.2987  94.9683
      6        0.8173       0.4609        1.3996  84.8938
      6        0.8835       0.4957        1.4116  97.9907
      6        0.8940       0.4174        1.4279  76.9430
      6        0.8731       0.4435        1.4275  95.1965
      6        0.8473       0.4870        1.4339  95.3319
      6        0.7842       0.4261        1.4608  95.5084
      7        0.7203       0.5086        1.3644  79.2748
      6        0.7951       0.4348        1.4595  102.1088
      7        0.6329       0.5345        1.2563  86.2738
      7        0.6250       0.4696        1.3759  79.2079
      7        0.6631       0.5130        1.3895  98.2096
      7        0.6621       0.5130        1.3904  84.4852
      7        0.6731       0.4870        1.3805  85.5710
      7        0.6827       0.5478        1.3506  104.9648
      7        0.7121       0.4261        1.4001  112.3767
      7        0.6014       0.4783        1.4349  99.7546
      8        0.5587       0.4914        1.3463  93.2165
      7        0.6065       0.5130        1.4446  97.0740
      8        0.4754       0.5862        1.2206  77.7356
      8        0.4974       0.5391        1.3242  68.2936
      8        0.4761       0.4609        1.3698  97.1594
      8        0.5294       0.4174        1.3981  83.6451
      8        0.5176       0.5391        1.3365  83.2688
      8        0.5031       0.5130        1.3587  84.7626
      8        0.4485       0.5043        1.4571  98.3014
      8        0.4418       0.4348        1.4648  107.7187
      9        0.4124       0.5172        1.3012  107.6907
      9        0.3452       0.5948        1.2003  93.0845
      8        0.5523       0.4609        1.3690  126.3014
      9        0.3673       0.5043        1.3688  102.1818
      9        0.3328       0.4609        1.3640  112.5498
     10        0.2980       0.5000        1.3242  66.6579
     10        0.2398       0.5862        1.1994  64.1645
      9        0.3749       0.5565        1.3284  113.5875
      9        0.3844       0.4522        1.3781  113.7713
      9        0.3671       0.5043        1.3616  114.6264
      9        0.4092       0.4783        1.3805  69.9123
      9        0.3151       0.5043        1.4697  93.9941
      9        0.3107       0.4870        1.4702  88.7709
     10        0.2241       0.4696        1.3761  52.8965
     10        0.2639       0.5304        1.3380  83.6427
     10        0.2765       0.4696        1.4220  43.4793
     10        0.2697       0.5391        1.3057  43.9570
     10        0.2596       0.5391        1.3461  42.8641
     10        0.2934       0.4870        1.3740  36.0700
     10        0.2132       0.5217        1.4834  35.9013
     10        0.2121       0.4696        1.5222  35.9666
  epoch    train_loss    valid_acc    valid_loss      dur
-------  ------------  -----------  ------------  -------
      1        2.6088       0.2778        2.0543  23.5194
      2        1.8811       0.3403        1.8143  22.2594
      3        1.5858       0.4167        1.6038  22.2372
      4        1.3012       0.4236        1.4980  22.5881
      5        1.0834       0.4792        1.4152  22.6296
      6        0.8792       0.4931        1.3751  22.5333
      7        0.6935       0.5069        1.2915  22.8427
      8        0.5321       0.5347        1.2698  22.3569
      9        0.3869       0.5417        1.2719  22.6662
     10        0.2824       0.5347        1.2778  22.8493
Best Parameters: {'module__hidden_size': 256, 'module__num_layers': 4}
(Prediction, Label):  [('rock', 'reggae'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('reggae', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('disco', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('metal', 'rock'), ('blues', 'jazz'), ('rock', 'rock'), ('rock', 'country'), ('disco', 'rock'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('country', 'country'), ('blues', 'reggae'), ('jazz', 'blues'), ('disco', 'disco'), ('classical', 'classical'), ('metal', 'blues'), ('country', 'country'), ('country', 'country'), ('classical', 'jazz'), ('pop', 'pop'), ('metal', 'metal'), ('jazz', 'disco'), ('rock', 'country'), ('blues', 'blues'), ('classical', 'classical'), ('classical', 'blues'), ('blues', 'rock'), ('classical', 'classical'), ('rock', 'rock'), ('jazz', 'rock'), ('hiphop', 'metal'), ('metal', 'metal'), ('metal', 'rock'), ('classical', 'classical'), ('blues', 'hiphop'), ('metal', 'blues'), ('jazz', 'jazz'), ('reggae', 'reggae'), ('metal', 'metal'), ('pop', 'classical'), ('disco', 'disco'), ('hiphop', 'hiphop'), ('pop', 'pop'), ('country', 'country'), ('metal', 'rock'), ('classical', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'pop'), ('disco', 'blues'), ('pop', 'pop'), ('classical', 'classical'), ('blues', 'classical'), ('reggae', 'hiphop'), ('pop', 'pop'), ('pop', 'rock'), ('country', 'country'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('blues', 'blues'), ('classical', 'classical'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('country', 'country'), ('pop', 'pop'), ('metal', 'metal'), ('blues', 'rock'), ('metal', 'metal'), ('classical', 'classical'), ('jazz', 'jazz'), ('pop', 'pop'), ('pop', 'hiphop'), ('classical', 'classical'), ('hiphop', 'jazz'), ('pop', 'pop'), ('classical', 'classical'), ('pop', 'pop'), ('country', 'country'), ('blues', 'blues'), ('reggae', 'reggae'), ('country', 'rock'), ('classical', 'classical'), ('blues', 'reggae'), ('hiphop', 'reggae'), ('classical', 'blues'), ('hiphop', 'reggae'), ('country', 'country'), ('classical', 'classical'), ('metal', 'hiphop'), ('pop', 'pop'), ('pop', 'pop'), ('pop', 'country'), ('blues', 'rock'), ('country', 'country'), ('classical', 'classical'), ('reggae', 'reggae'), ('blues', 'disco'), ('disco', 'jazz'), ('country', 'classical'), ('reggae', 'jazz'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('jazz', 'jazz'), ('rock', 'rock'), ('rock', 'jazz'), ('metal', 'metal'), ('pop', 'pop'), ('hiphop', 'disco'), ('metal', 'blues'), ('metal', 'blues'), ('classical', 'classical'), ('metal', 'blues'), ('hiphop', 'hiphop'), ('rock', 'disco'), ('disco', 'disco'), ('pop', 'disco'), ('country', 'country'), ('country', 'country'), ('hiphop', 'hiphop'), ('pop', 'jazz'), ('disco', 'metal'), ('rock', 'disco'), ('blues', 'blues'), ('hiphop', 'country'), ('jazz', 'reggae'), ('blues', 'jazz'), ('classical', 'classical'), ('pop', 'pop'), ('metal', 'metal'), ('hiphop', 'hiphop'), ('hiphop', 'pop'), ('pop', 'pop'), ('hiphop', 'pop'), ('jazz', 'jazz'), ('classical', 'classical'), ('disco', 'disco'), ('classical', 'disco'), ('disco', 'disco'), ('metal', 'metal'), ('disco', 'country'), ('blues', 'blues'), ('rock', 'country'), ('hiphop', 'disco'), ('hiphop', 'hiphop'), ('disco', 'disco'), ('country', 'jazz'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('country', 'disco'), ('pop', 'pop'), ('jazz', 'rock'), ('jazz', 'classical'), ('hiphop', 'reggae'), ('blues', 'blues'), ('hiphop', 'hiphop'), ('classical', 'classical'), ('classical', 'classical'), ('rock', 'rock'), ('metal', 'blues'), ('jazz', 'jazz'), ('classical', 'classical'), ('blues', 'blues'), ('blues', 'disco'), ('metal', 'metal'), ('hiphop', 'reggae'), ('country', 'country'), ('metal', 'rock'), ('blues', 'jazz'), ('rock', 'rock'), ('disco', 'hiphop'), ('metal', 'jazz'), ('hiphop', 'reggae'), ('reggae', 'hiphop'), ('jazz', 'reggae'), ('rock', 'disco'), ('metal', 'metal'), ('rock', 'hiphop'), ('blues', 'blues'), ('disco', 'disco'), ('jazz', 'country'), ('disco', 'disco'), ('classical', 'classical'), ('country', 'jazz'), ('pop', 'country'), ('jazz', 'jazz'), ('country', 'country'), ('metal', 'blues'), ('pop', 'disco'), ('blues', 'blues'), ('pop', 'pop'), ('pop', 'pop'), ('rock', 'disco'), ('disco', 'disco'), ('hiphop', 'reggae')]
Accuracy: 0.58
F1 Score: 0.5700566998538545
Classification Report:
               precision    recall  f1-score   support

       blues       0.45      0.45      0.45        22
   classical       0.86      0.86      0.86        28
     country       0.72      0.59      0.65        22
       disco       0.61      0.46      0.52        24
      hiphop       0.50      0.65      0.57        20
        jazz       0.47      0.37      0.41        19
       metal       0.43      0.83      0.57        12
         pop       0.70      0.90      0.79        21
      reggae       0.50      0.27      0.35        15
        rock       0.33      0.29      0.31        17

    accuracy                           0.58       200
   macro avg       0.56      0.57      0.55       200
weighted avg       0.58      0.58      0.57       200

Validation¶

In [39]:
def evaluate(model, X_val=X_val, y_val=y_val, name="", reshape=None, plot=True, verbose=True):
    print("Evaluating model:", name) if verbose else None
    if reshape:
        X_val = X_val.reshape(*reshape)

    # Make predictions using the best estimator
    predictions = model.predict(X_val)
    predictions = np.array([idx_to_labels[pred] for pred in predictions])
    print("(Prediction, Label): ", list(zip(predictions, y_val))) if verbose>3 else None

    accuracy, f1, cm, cr = metrics(predictions, y_val, verbose=verbose)
    plot_confusion_matrix(cm, name=name) if plot else None

    return accuracy, f1, cm, cr
    
ffn_eval = evaluate(ffn_gs, name="FFN", verbose=2)
cnn_eval = evaluate(cnn_gs, name="CNN", reshape=[-1, 1, measurements, features], verbose=2)
rnn_eval = evaluate(rnn_gs, name="RNN", reshape=[-1, measurements, features], verbose=2)
Evaluating model: FFN
Accuracy: 0.5125
F1 Score: 0.4950345875529699
Classification Report:
               precision    recall  f1-score   support

       blues       0.38      0.60      0.46         5
   classical       0.82      0.82      0.82        11
     country       0.25      0.10      0.14        10
       disco       0.20      0.20      0.20         5
      hiphop       0.38      0.43      0.40         7
        jazz       0.55      0.67      0.60         9
       metal       0.67      0.75      0.71         8
         pop       0.56      0.71      0.63         7
      reggae       0.43      0.38      0.40         8
        rock       0.50      0.40      0.44        10

    accuracy                           0.51        80
   macro avg       0.47      0.51      0.48        80
weighted avg       0.49      0.51      0.50        80

Evaluating model: CNN
Accuracy: 0.3875
F1 Score: 0.38712636370337605
Classification Report:
               precision    recall  f1-score   support

       blues       0.22      0.40      0.29         5
   classical       0.88      0.64      0.74        11
     country       0.50      0.10      0.17        10
       disco       0.10      0.20      0.13         5
      hiphop       0.20      0.29      0.24         7
        jazz       0.33      0.56      0.42         9
       metal       0.56      0.62      0.59         8
         pop       0.57      0.57      0.57         7
      reggae       0.00      0.00      0.00         8
        rock       0.67      0.40      0.50        10

    accuracy                           0.39        80
   macro avg       0.40      0.38      0.36        80
weighted avg       0.45      0.39      0.39        80

Evaluating model: RNN
Accuracy: 0.6875
F1 Score: 0.6745319300489074
Classification Report:
               precision    recall  f1-score   support

       blues       0.60      0.60      0.60         5
   classical       0.92      1.00      0.96        11
     country       0.71      0.50      0.59        10
       disco       0.67      0.40      0.50         5
      hiphop       0.67      0.86      0.75         7
        jazz       0.78      0.78      0.78         9
       metal       0.64      0.88      0.74         8
         pop       0.60      0.86      0.71         7
      reggae       0.57      0.50      0.53         8
        rock       0.57      0.40      0.47        10

    accuracy                           0.69        80
   macro avg       0.67      0.68      0.66        80
weighted avg       0.69      0.69      0.67        80

Analysis¶

During hyper-parameter tuning, we again were able to bump performance. The model that performed best was the RNN, with an F1-score of 0.67. Our best traditional classifier was the SVM, with a similar, but slightly lower F1-score. The CNN model scored the worst here, which was interesting, as CNNs have lead the field of uadio classification for a while now, thanks to advances in image classifications with CNN based architectures. We should reiterate here, though, that these simple nets are not comparable to the state-of-the-art models, which are much larger and more complex. The CNNs that previously dominated the image and audio classification fields had many more layers with different sized convolution kernels.

The most noticable difference between these models and our traditional algorithms is the amount of time training and tuning takes. Here, we really see how the curse of dimensionality limits NNs. Even these "simple" neural nets require a lot of memory and time to train and fine-tune due to their increased parameter space. The RNN grid took almost 30 minutes to compare only 2 different configurations (with retraining), on my local machine. Since there are technically 3 models being fit in such a grid search, that equates to about 10 minutes to achieve roughly the same performance as the SVM, that took less than 20 seconds.

These models are very sensitive to hyperparameters, so there are likely even better configurations than the few we tested. Although, after the small amount of fine-tuning executed, we got an indication of the potential neural networks have when it comes to learning to classify music genres.

The confusion matrixes help see which genres are easier to predict than others, using the same code from our previous notebook. Across the board, the "classical" genre consistently acheives good results (low misclassifiucation rate), while "rock" is usually among the hardest. When considering our set-up of extracting the MFCCs from the audio files, this makes sense. Classical music is usually very structured, and has a limited set of instruments, and generally has a very clean sounding output, with very little "noise". Rock music on the other hand is usually more chaotic, and has a wider variety of instruments, and generally has a more "noisy" signal.

Conclusion¶

The results we were able to acheive here were very similar to our traditional classifiers. When considering the amount of training time and compute power required to train, tune, and host these models, one might argue that the traditional classifiers are best for this simplified task. Again, it is important to note that these simple nets are not completely representative of all neural architectures. Rather they serve as the baseline for what can be achieved with neural archiectures.

In our next note, we will see how our results compare against a pretrained, state-of-the-art model.