Goa or psytrance music has several subgenres, and some of them are difficult to tell apart. I want to build a classifier that predicts the genre from music features. First, I need an expert opinion to use as labels for supervised classification. For that, I’ll use the absolutely fantastic http://psytranceguide.com/ (made by @DanielLesden), which lists 20 psytrance genres and, fortunately, also has example playlists that can be scraped and analyzed.
Highlights
Webscraping using rvest
Using the spotifyr package to get music features
Exploring the music features
Building a random forest classifier on bootstrap resamples using tidymodels
Estimating the importance of each feature using the vip package
Creating a network of subgenres from the confusion matrix
Misc: using tidyverse for data wrangling, rectangling, and visualization.
# Load packages
library(tidyverse)
library(rvest)
library(spotifyr)
library(tidytext)
library(tidymodels)
library(vip)
library(igraph)
library(ggraph)
# Default ggplot theme
theme_set(theme_light())
# Set up parallel processing
doParallel::registerDoParallel()
# Authenticate the spotify API through spotifyr.
spotify_id <- read_csv(here::here("spotify_client_id.txt"))
Sys.setenv(SPOTIFY_CLIENT_ID = spotify_id$client_id)
Sys.setenv(SPOTIFY_CLIENT_SECRET = spotify_id$secret)
access_token <- get_spotify_access_token()
Create the dataset
The first task is to scrape data that we can use for machine learning. The guide at http://psytranceguide.com/ lists genres with example tracks and, what is more, its links lead to Spotify playlists. The spotifyr package makes it possible to retrieve the audio features of every track in each playlist, which we can then use for classification.
Create a dataset that contains genre data and track features.
# Scrape the homepage.
psytrance_link <- "http://psytranceguide.com/"
psytrance_page <- read_html(psytrance_link)
# Get the urls and other data from the psytrance guide
# I used the selectorGadget to find the proper css tags
goa_data <-
psytrance_page %>%
html_nodes("h2 a , .jouele-info-control-text")
# Create a table for all data (genres + playlist links)
goa_genres <-
tibble(genre = map_chr(goa_data, html_text),
playlist_url = html_attr(goa_data, "href")) %>%
extract(playlist_url,
into = "playlist_id",
regex = ".*playlist/(.*)\\?si.*")
# Get all music features for all tracks for each playlist
features <- map_dfr(goa_genres$playlist_id,
~get_playlist_audio_features(playlist_uris = .x))
# Put playlist information and features together
goa <-
left_join(goa_genres, features, by = "playlist_id") %>%
hoist(track.album.artists, "name", .remove = TRUE) %>%
select(genre:playlist_name,
track_id = track.id, track_name = track.name,
artist_names = name,
danceability:tempo) %>%
rowwise() %>%
mutate(artist_names = str_c(artist_names, collapse = ", ")) %>%
ungroup()
# Write data
write_csv(goa, here::here("data/goa_tracks.csv"))
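The `extract()` step above relies on a regular expression to pull the playlist ID out of each link. Here is a minimal base R sketch of the same pattern. Both URLs are made up for illustration (the ID `abc123XYZ`, the `si` value, and the `#goa` anchor are hypothetical, not real links from the guide).

```r
# Hypothetical playlist link; the ID and the `si` parameter are invented
url <- "https://open.spotify.com/playlist/abc123XYZ?si=0000"

# Same pattern as in the extract() call: capture everything between
# "playlist/" and "?si"
pattern <- ".*playlist/(.*)\\?si.*"
playlist_id <- sub(pattern, "\\1", url)
playlist_id

# A link without a playlist part (hypothetical anchor) does not match.
# tidyr::extract() returns NA in that case; sub() would return the input
# unchanged, so the match is checked explicitly here.
no_match <- "http://psytranceguide.com/#goa"
matched <- grepl(pattern, no_match)
matched
```

Rows that end up with a missing `playlist_id` would need to be dropped or inspected before requesting audio features.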
Explore goa data
First, let’s see if there is enough variability in the features. I’m using a trick to calculate confidence intervals for the means: a one-sample t.test with 0 as the reference value (it could be any number, as we are not interested in the p value).
# Use the saved dataset so we don't have to reassemble the data.
goa <-
read_csv(here::here("data/goa_tracks.csv")) %>%
select(genre, track_name, danceability:tempo) %>%
mutate(genre = factor(genre))
# Balance the data by sampling exactly 10 tracks from each genre
set.seed(1)
goa <-
goa %>%
group_by(genre) %>%
sample_n(10) %>%
ungroup()
# Put data into long format, calculate confidence intervals
goa_long <-
goa %>%
pivot_longer(danceability:tempo,
names_to = "feature") %>%
group_nest(genre, feature) %>%
mutate(t_test = map(data,
~t.test(.x$value) %>%
tidy())) %>%
unnest(c(t_test, data))
# Distributions
goa_long %>%
ggplot() +
aes(x = value, fill = feature) +
geom_histogram(show.legend = FALSE) +
facet_wrap(~feature, scales = "free")
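The confidence interval trick used above can be checked on simulated data. This is a small sketch with invented numbers (not values from the goa dataset): the interval reported by `t.test()` is the usual mean plus or minus a t quantile times the standard error, and it does not depend on the reference value `mu`.

```r
set.seed(42)
x <- rnorm(10, mean = 140, sd = 5)  # simulated values, e.g. tempo of 10 tracks

# Confidence interval reported by a one-sample t-test (default mu = 0)
ci_ttest <- as.numeric(t.test(x)$conf.int)

# The same interval computed by hand: mean +/- t quantile * standard error
n <- length(x)
se <- sd(x) / sqrt(n)
ci_manual <- mean(x) + c(-1, 1) * qt(0.975, df = n - 1) * se

# Changing the reference value only affects the test statistic and p value
ci_other_mu <- as.numeric(t.test(x, mu = 100)$conf.int)

all.equal(ci_ttest, ci_manual)
all.equal(ci_ttest, ci_other_mu)
```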
Let’s order genres by features. Some features seem legit on face validity, although I have some difficulty telling psytrance genres apart. Also, I’m not fully aware of the meaning of all features.
goa_long %>%
mutate(genre = reorder_within(genre, estimate, feature)) %>%
ggplot() +
aes(x = estimate, y = genre,
xmin = conf.low, xmax = conf.high,
color = feature) +
geom_pointrange(size = 0.5, show.legend = FALSE) +
scale_y_reordered() +
facet_wrap(~feature, scales = "free") +
labs(title = "Average feature values by genre (95%CI)",
subtitle = "Several features vary considerably across genres, which makes them plausible predictors of genre.",
x = NULL, y = NULL)
Classify goa subgenres using random forest
Create a random forest model to find the defining features of the different subgenres. As we have a very small dataset, we are not splitting it into a training and a test set; instead, we use bootstrap resamples to make the evaluation more robust.
Bootstrap data, specify model and workflow.
# Create bootstraps
set.seed(123)
goa_boot <- bootstraps(goa)
# Specify random forest model
rf_spec <-
rand_forest(trees = 1000) %>%
set_mode("classification") %>%
set_engine("ranger", importance = "permutation")
# Create workflow
goa_wf <-
workflow() %>%
add_formula(genre ~ . -track_name) %>%
add_model(rf_spec)
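To make the bootstrap idea above more concrete, here is a base R sketch of a single resample. It assumes a dataset of 200 rows (20 genres with 10 tracks each). `bootstraps()` repeats this kind of draw 25 times by default, which matches the `n = 25` in the metrics further down.

```r
set.seed(123)
n <- 200  # assumed size: 20 genres x 10 tracks

# One bootstrap resample: draw n row indices with replacement
in_bag <- sample(n, size = n, replace = TRUE)

# Rows that were never drawn are "out-of-bag" and serve as the assessment set
out_of_bag <- setdiff(seq_len(n), in_bag)

length(in_bag)          # the analysis set has as many rows as the data
length(out_of_bag) / n  # on average about 1/e (roughly 37%) of rows are held out
```

Because the model is always assessed on rows it has not seen in that resample, the averaged metrics are less optimistic than metrics computed on the training data itself.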
Train the model and evaluate it. It seems like accuracy is poor, and sensitivity is also not great, which is typical of multiclass classification (too many categories to choose from). Overall, the model has a ROC AUC of about 0.90.
goa_rs <-
goa_wf %>%
fit_resamples(resamples = goa_boot,
metrics = metric_set(roc_auc, accuracy,
specificity, sensitivity),
control = control_resamples(verbose = TRUE,
save_pred = TRUE,
save_workflow = TRUE))
goa_rs %>%
collect_metrics()
## # A tibble: 4 x 5
## .metric .estimator mean n std_err
## <chr> <chr> <dbl> <int> <dbl>
## 1 accuracy multiclass 0.317 25 0.0102
## 2 roc_auc hand_till 0.900 25 0.00326
## 3 sens macro 0.372 25 0.0117
## 4 spec macro 0.964 25 0.000517
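To put these numbers in context, the sketch below computes accuracy and the macro-averaged sensitivity and specificity by hand from a toy confusion matrix. The counts are invented for three classes and are not results from the model above. Macro averaging means that the metric is computed for each class in a one-vs-rest fashion and then averaged without weights.

```r
# Toy confusion matrix (invented counts): rows are predictions, columns are the truth
cm <- matrix(c(8, 1, 1,
               3, 5, 2,
               1, 1, 3),
             nrow = 3,
             dimnames = list(Prediction = c("A", "B", "C"),
                             Truth = c("A", "B", "C")))

# Accuracy: share of all cases on the diagonal
accuracy <- sum(diag(cm)) / sum(cm)

# Per-class sensitivity: correct predictions / all true members of the class
sens <- diag(cm) / colSums(cm)
macro_sens <- mean(sens)

# Per-class specificity: true negatives / all cases not in the class
fp <- rowSums(cm) - diag(cm)
tn <- sum(cm) - colSums(cm) - fp
macro_spec <- mean(tn / (tn + fp))

c(accuracy = accuracy, macro_sens = macro_sens, macro_spec = macro_spec)

# Chance-level accuracy, assuming 20 equally frequent genres
1 / 20
```

Assuming all 20 genres are present and roughly balanced, random guessing would be right about 5% of the time, so an accuracy of around 32% is weak in absolute terms but clearly better than chance.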
Which goa genres are the most difficult to classify?
It seems like some styles are more difficult to classify than others. As the individual ROC curves show, Goa Trance, Psybreaks, and Psytechno can be easily confused with other styles. In contrast, Chillout, Hi-Tech, and Offbeat are easier to identify.
goa_rs %>%
collect_predictions() %>%
roc_curve(genre, .pred_Chillout:.pred_Tribal) %>%
ggplot() +
aes(x = 1 - specificity, y = sensitivity, color = .level) +
geom_abline(lty = 2, color = "gray80", size = 1) +
geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
facet_wrap(~.level, ncol = 5) +
scale_color_viridis_d(option = "plasma") +
scale_x_continuous(labels = scales::percent_format()) +
scale_y_continuous(labels = scales::percent_format()) +
labs(title = "Individual ROC curves for each goa genre.",
subtitle = "Some genres are easier to classify than others.")
Importance of the features
We can evaluate the importance of the features. It seems like tempo is by far the most important feature, followed by danceability and energy. Mode is not especially helpful (I’m not even sure what it means).
rf_spec %>%
set_engine("ranger", importance = "permutation") %>%
fit(genre ~ ., data = select(goa, -track_name)) %>%
vip(geom = "col") +
aes(fill = Variable) +
scale_fill_viridis_d(option = "viridis") +
labs(title = "Variable importance for classification")
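The idea behind permutation importance can be sketched in a few lines of base R. This is a simplified illustration and not what ranger does internally: it uses the built-in `iris` data and a nearest-centroid classifier as a stand-in for the random forest, and it permutes features on the full dataset, whereas ranger works with the out-of-bag samples of each tree.

```r
set.seed(1)
data(iris)
features <- names(iris)[1:4]

# A deliberately simple stand-in classifier (nearest class centroid),
# used instead of a random forest to keep the sketch dependency-free
centroids <- aggregate(iris[features], list(Species = iris$Species), mean)

predict_centroid <- function(newdata) {
  x <- as.matrix(newdata[features])
  ctr <- as.matrix(centroids[features])
  # Squared Euclidean distance from every row to every centroid
  d <- sapply(seq_len(nrow(ctr)), function(k) {
    rowSums((x - matrix(ctr[k, ], nrow(x), ncol(x), byrow = TRUE))^2)
  })
  centroids$Species[apply(d, 1, which.min)]
}

accuracy <- function(data) mean(predict_centroid(data) == data$Species)
baseline <- accuracy(iris)

# Permutation importance: shuffle one feature at a time and record
# how much the accuracy drops compared with the baseline
importance <- sapply(features, function(f) {
  shuffled <- iris
  shuffled[[f]] <- sample(shuffled[[f]])
  baseline - accuracy(shuffled)
})
sort(importance, decreasing = TRUE)
```

A feature whose shuffling barely changes the accuracy carries little information that the model relies on, which is how a feature such as mode can end up at the bottom of the ranking.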
Explore which genres are easier to confuse with each other
Because of the many classes and the small dataset, the confusion matrix shows that only a few genres are correctly classified at least half of the time. It also shows which genres are the easiest to confuse with each other.
goa_confusion <-
goa_rs %>%
collect_predictions() %>%
conf_mat(genre, .pred_class)
goa_conf_perc <-
goa_confusion$table %>%
as_tibble() %>%
group_by(Truth) %>%
mutate(all = sum(n),
perc = n/all)
goa_conf_perc %>%
ggplot() +
aes(x = Truth, y = Prediction, fill = perc,
label = scales::percent(perc, accuracy = 1)) +
geom_tile(show.legend = FALSE) +
geom_text(size = 3) +
scale_fill_viridis_c(option = "plasma") +
labs(title = "Confusion matrix for the random forest model",
subtitle = "Percent of tracks in each true genre assigned to each predicted genre",
x = "Truth") +
theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust=1))
Let’s visualize how easy it is to confuse certain genres with others by building a network graph!
goa_graph <-
goa_conf_perc %>%
filter(Prediction != Truth & n > 0) %>%
select(Prediction, Truth, perc) %>%
graph_from_data_frame(directed = FALSE)
goa_graph %>%
ggraph() +
geom_edge_link(aes(edge_width = perc),
alpha = .4) +
scale_edge_width(range = c(.01, 3)) +
geom_node_text(aes(label = name,
size = betweenness(goa_graph),
color = name),
show.legend = FALSE) +
scale_size(range = c(3,6)) +
scale_color_viridis_d(option = "plasma") +
guides(fill = "none", alpha = "none", edge_alpha = "none", edge_width = "none") +
theme_graph() +
labs(title = "Confusion map of psytrance genres",
subtitle = "The wider the connection between genre names, the easier to confuse.\nThe size of the node shows the centrality (betweenness)")
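The step from confusion matrix to network can be sketched without any graph package. The example below uses invented counts for three genres (not the real results) and builds the same kind of edge list that is passed to `graph_from_data_frame()` above: the off-diagonal, non-zero cells of the truth-normalized confusion matrix.

```r
# Toy confusion counts (invented): rows are predictions, columns are the truth
cm <- matrix(c(8, 1, 1,
               3, 5, 2,
               0, 2, 8),
             nrow = 3,
             dimnames = list(Prediction = c("A", "B", "C"),
                             Truth = c("A", "B", "C")))

# Share of each true genre assigned to each predicted genre
perc <- prop.table(cm, margin = 2)

# Long format, then keep only the off-diagonal, non-zero cells as edges
edges <- as.data.frame(as.table(perc), responseName = "perc",
                       stringsAsFactors = FALSE)
edges <- edges[edges$Prediction != edges$Truth & edges$perc > 0, ]
edges
```

Note that a pair of genres can appear twice in such an edge list, once for each direction of confusion. In an undirected graph these become two parallel edges between the same nodes, so they may be drawn on top of each other.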
That’s all!
