clustra: clustering trajectories

George Ostrouchov, Hanna Gerlovin, and David Gagnon

2024-01-08

The clustra package was built to cluster longitudinal trajectories (time series) on a common time axis. For example, a number of individuals are started on a specific drug regimen and their blood pressure data is collected for a varying amount of time before and after the start of the medication. Observations can be unequally spaced, and trajectories can be of unequal length and only partially overlapping.

Clustering proceeds by an EM algorithm that alternates between fitting a thin plate spline (TPS) to the combined responses within each cluster (M-step) and reassigning cluster membership based on the nearest fitted spline (E-step). The fitting is done with the mgcv package function bam, which scales well to very large data sets.
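The alternation described above can be illustrated with a toy base-R sketch. It is not the package's implementation: a cubic polynomial fitted with lm() stands in for the thin plate spline fitted by bam, the loss is assumed to be the mean squared residual per id, and all names (toy, fits, loss, group) are hypothetical.

```r
# Toy illustration of the alternating scheme; NOT clustra's implementation.
# Assumption: a cubic polynomial via lm() stands in for the thin plate spline.
set.seed(1)
n_id <- 40
toy <- do.call(rbind, lapply(seq_len(n_id), function(i) {
  t <- sort(runif(sample(5:15, 1), -1, 1))   # unequal lengths and spacing
  mu <- if (i <= n_id / 2) 2 * t else -2 * t # two different mean curves
  data.frame(id = i, time = t, response = mu + rnorm(length(t), sd = 0.3))
}))

k <- 2
ids <- unique(toy$id)
group <- sample(rep_len(seq_len(k), length(ids)))  # random start, one label per id
names(group) <- ids

for (iter in 1:10) {
  # M-step: fit one curve per cluster to the pooled responses
  fits <- lapply(seq_len(k), function(g) {
    lm(response ~ poly(time, 3), data = toy[group[as.character(toy$id)] == g, ])
  })
  # E-step: mean squared residual of each id to each cluster curve
  loss <- sapply(fits, function(f) {
    tapply((toy$response - predict(f, newdata = toy))^2, toy$id, mean)
  })
  new_group <- apply(loss, 1, which.min)     # nearest curve for each id
  changes <- sum(new_group != group)
  group <- new_group
  # stop when no id changed cluster, or if a cluster became empty
  if (changes == 0 || length(unique(group)) < k) break
}
table(group)
```

Assigning whole ids rather than single observations is what makes this a trajectory clustering: every observation of an id follows that id's label into the next M-step.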

For this vignette, we begin by generating a data set with the gen_traj_data() function. Given its parameters, the function generates groups of ids (their sizes given by the vector n_id) and, for each id, a random number of observations based on the Poisson(\(\lambda =\) m_obs) distribution plus 3. The 3 additional observations guarantee one before the intervention at time start, one at the intervention time 0, and one after the intervention at time end. The start time is Uniform(s_range) and the end time is Uniform(e_range). The remaining times are drawn from Uniform(start, end). The time units are arbitrary and depend on your application. Three trajectory types are implemented so far: sine, sigmoid, and constant.
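As a rough sketch of the sampling scheme just described, the observation times for a single id could be drawn as follows. The helper gen_times() is hypothetical and only mirrors the text; gen_traj_data() itself may differ in details such as rounding of times.

```r
# Hypothetical helper mirroring the description above; gen_traj_data()
# itself may differ in details (for example, rounding of times).
gen_times <- function(m_obs, s_range, e_range) {
  start <- runif(1, s_range[1], s_range[2])   # one time before the intervention
  end   <- runif(1, e_range[1], e_range[2])   # one time after the intervention
  n_extra <- rpois(1, lambda = m_obs)         # Poisson(m_obs) additional times
  sort(c(start, 0, end, runif(n_extra, start, end)))
}

set.seed(12345)
times <- gen_times(m_obs = 25, s_range = c(-365, -14),
                   e_range = c(0.5 * 365, 2 * 365))
length(times)   # at least 3: start, 0, and end are always present
range(times)    # start is negative, end is positive
```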

The code below generates the data and shows the first few observations. The mc variable sets the number of cores and is passed to the mccores parameter throughout the rest of the vignette. By default, 1 core is assigned. Parallel sections are implemented with parallel::mclapply(), so on Unix and Mac platforms it is recommended to use the full number of cores available for faster performance. Default initialization of the clusters is set to "random" (see the clustra help file for the other option, "distant"). We also set the seed for reproducibility.

library(clustra)
mc = 2 # If running on a unix or a Mac platform, increase up to # cores
if (.Platform$OS.type == "windows") mc = 1
init = "random"
set.seed(12345)
data = gen_traj_data(n_id = c(500, 1000, 1500, 2000), types = c(2, 1, 3, 2), 
                     intercepts = c(70, 130, 120, 130), m_obs = 25, 
                     s_range = c(-365, -14), e_range = c(0.5*365, 2*365),
                     noise = c(0, 15))
head(data)
##       id time  response true_group
## 1: 12148 -112 108.09095          2
## 2: 12148  -78 108.09054          2
## 3: 12148  -75 106.52212          2
## 4: 12148  -54 103.07173          2
## 5: 12148  -48 138.76058          2
## 6: 12148    0  84.15391          2
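The core count used above is fixed at 2 for the vignette. One possible way to derive it from the machine instead is sketched below; the name mc_auto is hypothetical, and the fallback to 1 covers systems where the number of cores cannot be determined.

```r
# A sketch for choosing a value for mccores automatically.
# detectCores() may return NA when the count cannot be determined.
n_cores <- parallel::detectCores()
if (is.na(n_cores)) n_cores <- 1L
# mclapply() relies on forking, which is not available on Windows
mc_auto <- if (.Platform$OS.type == "windows") 1L else n_cores
mc_auto
```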

The histogram shows the distribution of generated lengths. The short ones will be the most difficult to cluster correctly.
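The trajectory lengths can be tabulated directly from long-format data. The sketch below uses a small made-up data frame in place of the generated data so that it runs on its own.

```r
# Toy long-format data standing in for the generated `data` (made-up values)
toy <- data.frame(id = rep(1:4, times = c(3, 10, 6, 3)),
                  time = 0, response = 0)
n_obs <- as.vector(table(toy$id))   # number of observations per id
h <- hist(n_obs, plot = FALSE)      # set plot = TRUE to draw the histogram
n_obs
```

Assuming data is a data.table, as the indexing used in this vignette suggests, the same counts would be obtained with data[, .N, by = id]$N.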

Select a few random ids and show their scatterplots.

plot_sample(data[id %in% sample(unique(data[, id]), 9)], group = "true_group")

Next, cluster the trajectories. Set k=4 (we will consider selection of k later), set the spline maximum degrees of freedom to 10 with maxdf, and use conv to set the maximum number of iterations to 10 and convergence to when 0 changes occur. mccores sets the number of cores to use in various components of the code. Parallel execution does not work on Windows operating systems, where mccores should be set to 1 (the default). In the code that follows, we use verbose output to get information from each iteration.

set.seed(12345)
cl4 = clustra(data, k = 4, maxdf = 10, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
##   Start counts 1276 1240 1264 1220   Starts time: 0
##  1 (M-123)-0.12 (E-12345)-0.16 Changes: 3672 Counts: 1518 993 435 2054 Deviance: 253406224
##  2 (M-123)-0.12 (E-12345)-0.18 Changes: 1024 Counts: 981 1634 401 1984 Deviance: 89169996
##  3 (M-123)-0.13 (E-12345)-0.25 Changes: 381 Counts: 1003 1834 196 1967 Deviance: 53538461
##  4 (M-123)-0.17 (E-12345)-0.14 Changes: 264 Counts: 1003 1969 190 1838 Deviance: 51083879
##  5 (M-123)-0.11 (E-12345)-0.16 Changes: 308 Counts: 1003 1970 496 1531 Deviance: 39576314
##  6 (M-123)-0.12 (E-12345)-0.14 Changes: 0 Counts: 1003 1970 496 1531 Deviance: 31683916
##  AIC:289.44 BIC:599.07 edf:31.44  Total time: -1.88  converged

Each iteration displays the components of the M-step and the E-step, each followed by its duration in seconds, then the number of classification changes in the E-step, the current counts in each cluster, and the deviance.
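The way the Changes column interacts with conv = c(10, 0) can be sketched with a small helper. The function stop_reason() is hypothetical, and the exact comparison used inside clustra is an assumption; the rule below is merely consistent with the output shown above.

```r
# Hypothetical helper: stop when the percentage of ids that changed cluster
# is at or below conv[2], or when conv[1] iterations have been used.
stop_reason <- function(iter, changes, n_id, conv) {
  if (100 * changes / n_id <= conv[2]) return("converged")
  if (iter >= conv[1]) return("max-iter")
  "continue"
}

n_id <- 5000   # 500 + 1000 + 1500 + 2000 ids in the generated data
stop_reason(5, changes = 308, n_id = n_id, conv = c(10, 0))   # "continue"
stop_reason(6, changes = 0,   n_id = n_id, conv = c(10, 0))   # "converged"
```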

Next, plot the raw data (a sample is used if there are more than 10,000 points) with the resulting spline fit, colored by the cluster value.

plot_smooths(data, group = NULL)

plot_smooths(data, cl4$tps)

The Rand index for comparing with true_group is

MixSim::RandIndex(cl4$data_group, data[, true_group])
## $R
## [1] 0.9221675
## 
## $AR
## [1] 0.8274405
## 
## $F
## [1] 0.8910276
## 
## $M
## [1] 1522217190
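The $R and $AR entries in this output are the Rand index and the adjusted Rand index. For readers who want to see what they measure, here is a minimal base-R sketch that computes both from a contingency table of two labelings, using the standard pair-counting formulas. The function rand_indices() is hypothetical and is not part of MixSim.

```r
# Rand and adjusted Rand indices from a contingency table of two labelings
rand_indices <- function(a, b) {
  tab <- table(a, b)
  n <- sum(tab)
  sum_ij <- sum(choose(tab, 2))            # pairs together in both labelings
  sum_a <- sum(choose(rowSums(tab), 2))    # pairs together in a
  sum_b <- sum(choose(colSums(tab), 2))    # pairs together in b
  total <- choose(n, 2)                    # all pairs
  expected <- sum_a * sum_b / total        # agreement expected by chance
  c(R = (total + 2 * sum_ij - sum_a - sum_b) / total,
    AR = (sum_ij - expected) / ((sum_a + sum_b) / 2 - expected))
}

truth <- c(1, 1, 1, 2, 2, 2)
rand_indices(truth, c(2, 2, 2, 1, 1, 1))   # relabeled but identical: R = 1, AR = 1
rand_indices(truth, c(1, 1, 2, 2, 3, 3))   # partial agreement
```

Both indices ignore the label values themselves, which is why a clustering can be compared with true_group even though cluster numbers are arbitrary.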

AR stands for the Adjusted Rand index, which adjusts for random agreement. A value of 0.827 against the true groups used to generate the data is quite good, considering that the short series are easily misclassified and that k-means often finds a local minimum. Let’s double the error standard deviation in data generation and repeat…

set.seed(12345)
data2 = gen_traj_data(n_id = c(500, 1000, 1500, 2000), types = c(2, 1, 3, 2), 
                     intercepts = c(70, 130, 120, 130), m_obs = 25,
                     s_range = c(-365, -14), e_range = c(60, 2*365), 
                     noise = c(0, 30))
plot_sample(data2[id %in% sample(unique(data2[, id]), 9)], group = "true_group")

cl4a = clustra(data2, k = 4, maxdf = 10, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
##   Start counts 1237 1283 1270 1210   Starts time: 0
##  1 (M-123)-0.11 (E-12345)-0.15 Changes: 3676 Counts: 1785 397 1933 885 Deviance: 339661857
##  2 (M-123)-0.11 (E-12345)-0.19 Changes: 1008 Counts: 1434 873 1816 877 Deviance: 193933047
##  3 (M-123)-0.13 (E-12345)-0.17 Changes: 562 Counts: 1185 1297 1845 673 Deviance: 165887776
##  4 (M-123)-0.12 (E-12345)-0.25 Changes: 709 Counts: 751 1864 1894 491 Deviance: 157838726
##  5 (M-123)-0.17 (E-12345)-0.14 Changes: 346 Counts: 655 2029 1934 382 Deviance: 149479974
##  6 (M-123)-0.11 (E-12345)-0.15 Changes: 197 Counts: 777 2017 1957 249 Deviance: 145721621
##  7 (M-123)-0.11 (E-12345)-0.15 Changes: 94 Counts: 828 1992 1961 219 Deviance: 145037080
##  8 (M-123)-0.12 (E-12345)-0.15 Changes: 118 Counts: 892 1973 1939 196 Deviance: 144332295
##  9 (M-123)-0.12 (E-12345)-0.15 Changes: 208 Counts: 961 1969 1806 264 Deviance: 142293482
##  10 (M-123)-0.12 (E-12345)-0.15 Changes: 284 Counts: 998 1970 1561 471 Deviance: 133518715
##  AIC:977.32 BIC:1279.27 edf:30.66  Total time: -2.94  max-iter
MixSim::RandIndex(cl4a$data_group, data2[, true_group])
## $R
## [1] 0.9165866
## 
## $AR
## [1] 0.8152816
## 
## $F
## [1] 0.8827711
## 
## $M
## [1] 1631366350

This time the AR of 0.815 is lower but still respectable. The clustering recovers the trajectory means quite well, as we see in the following plots. The first, without cluster colors (obtained by setting group = NULL), shows the mass of points, and the second shows cluster means and cluster colors.

plot_smooths(data2, group = NULL)

plot_smooths(data2, cl4a$tps)

Average silhouette width is one way to select the number of clusters, and a silhouette plot allows a deeper evaluation (Rousseeuw 1987). A silhouette requires distances between individual subjects, which are not available here without fitting a separate trajectory model for each subject id, because subjects are sampled at different times. As a proxy, the clustra_sil() function uses distances from each subject to the cluster mean trajectories. The structure returned from the clustra() function contains the matrix loss, which has all the information needed to construct these proxy silhouette plots. The function clustra_sil() performs clustering for a number of k values and outputs information for the silhouette plots that are displayed next. We relax the convergence criterion in conv to 1% of changes (instead of the 0 used earlier) for faster processing. We use the first data set with noise = c(0, 15).

set.seed(12345)
sil = clustra_sil(data, kv = c(2, 3, 4, 5), mccores = mc, maxdf = 10,
                  conv = c(7, 1), verbose = TRUE)
##   Start counts 2540 2460   Starts time: 0
##  1 (M-123)-0.14 (E-12345)-0.1 Changes: 2500 Counts: 2350 2650 Deviance: 253543113
##  2 (M-123)-0.09 (E-12345)-0.1 Changes: 448 Counts: 2580 2420 Deviance: 165465516
##  3 (M-123)-0.08 (E-12345)-0.1 Changes: 179 Counts: 2715 2285 Deviance: 144541735
##  4 (M-123)-0.09 (E-12345)-0.2 Changes: 143 Counts: 2858 2142 Deviance: 141197700
##  5 (M-123)-0.14 (E-12345)-0.11 Changes: 89 Counts: 2947 2053 Deviance: 138184428
##  6 (M-123)-0.08 (E-12345)-0.11 Changes: 20 Counts: 2967 2033 Deviance: 136345523
##  AIC:1012.45 BIC:1204.1 edf:19.46  Total time: -1.38  converged 
##   Start counts 1686 1647 1667   Starts time: 0
##  1 (M-123)-0.1 (E-12345)-0.14 Changes: 3262 Counts: 1547 1649 1804 Deviance: 253485140
##  2 (M-123)-0.1 (E-12345)-0.17 Changes: 1617 Counts: 2029 1105 1866 Deviance: 153190295
##  3 (M-123)-0.1 (E-12345)-0.14 Changes: 235 Counts: 1971 1003 2026 Deviance: 68405265
##  4 (M-123)-0.1 (E-12345)-0.15 Changes: 2 Counts: 1973 1003 2024 Deviance: 53014331
##  AIC:422.2 BIC:634.63 edf:21.57  Total time: -1.06  converged 
##   Start counts 1276 1240 1264 1220   Starts time: 0
##  1 (M-123)-0.11 (E-12345)-0.15 Changes: 3672 Counts: 1518 993 435 2054 Deviance: 253406224
##  2 (M-123)-0.12 (E-12345)-0.15 Changes: 1024 Counts: 981 1634 401 1984 Deviance: 89169996
##  3 (M-123)-0.12 (E-12345)-0.16 Changes: 381 Counts: 1003 1834 196 1967 Deviance: 53538461
##  4 (M-123)-0.12 (E-12345)-0.16 Changes: 264 Counts: 1003 1969 190 1838 Deviance: 51083879
##  5 (M-123)-0.12 (E-12345)-0.15 Changes: 308 Counts: 1003 1970 496 1531 Deviance: 39576314
##  6 (M-123)-0.12 (E-12345)-0.15 Changes: 0 Counts: 1003 1970 496 1531 Deviance: 31683916
##  AIC:289.44 BIC:599.07 edf:31.44  Total time: -1.69  converged 
##   Start counts 1018 979 994 1007 1002   Starts time: 0
##  1 (M-123)-0.13 (E-12345)-0.19 Changes: 3948 Counts: 30 584 1568 1006 1812 Deviance: 253198999
##  2 (M-123)-0.12 (E-12345)-0.18 Changes: 1784 Counts: 480 1003 1003 983 1531 Deviance: 70403247
##  3 (M-123)-0.13 (E-12345)-0.43 Changes: 267 Counts: 496 1003 1096 874 1531 Deviance: 32639241
##  4 (M-123)-0.13 (E-12345)-0.18 Changes: 100 Counts: 496 1003 1060 910 1531 Deviance: 31422760
##  5 (M-123)-0.13 (E-12345)-0.28 Changes: 62 Counts: 496 1003 1036 934 1531 Deviance: 31406054
##  6 (M-123)-0.19 (E-12345)-0.19 Changes: 27 Counts: 496 1003 1019 951 1531 Deviance: 31401231
##  AIC:307.29 BIC:714.81 edf:41.38  Total time: -2.38  converged
lapply(sil, plot_silhouette)

## [[1]]
## [1] 0.79
## 
## [[2]]
## [1] 0.88
## 
## [[3]]
## [1] 0.88
## 
## [[4]]
## [1] 0.56
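The values printed above are average proxy silhouette widths. A minimal sketch of how such a proxy could be computed from a loss matrix with one row per id and one column per cluster is given below. It applies the usual silhouette form to id-to-cluster distances; whether clustra_sil() uses exactly this formula is an assumption, and the loss values are made up.

```r
# Hypothetical loss matrix: rows are ids, columns are clusters
loss <- rbind(c(1.0, 4.0, 9.0),
              c(5.0, 2.0, 6.0),
              c(8.0, 7.5, 7.0))
own <- apply(loss, 1, which.min)                # assigned (nearest) cluster
a <- loss[cbind(seq_len(nrow(loss)), own)]      # distance to own cluster mean
b <- apply(loss, 1, function(x) sort(x)[2])     # distance to next nearest mean
sil_width <- (b - a) / pmax(a, b)               # close to 1: well separated
round(sil_width, 3)
mean(sil_width)                                 # average width
```

The third id sits almost equally far from all three cluster means, so its width is close to 0, which is how poorly separated subjects show up in a silhouette plot.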

The plots for 3 and 4 clusters give the best Average Width. Usually we take the larger one, 4, which is supported here also by the minimum AIC and BIC scores. We also note that the final deviance drops substantially at 4 clusters and barely moves when 5 clusters are fit, further corroborating k = 4.
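To make this comparison concrete, the final AIC, BIC, deviance, and average width for each k can be collected in one table. The numbers are copied from the output above; the data frame fit and the column rel_drop are names introduced only for this illustration.

```r
# Values copied from the verbose output and average widths reported above
fit <- data.frame(k = 2:5,
                  AIC = c(1012.45, 422.20, 289.44, 307.29),
                  BIC = c(1204.10, 634.63, 599.07, 714.81),
                  deviance = c(136345523, 53014331, 31683916, 31401231),
                  avg_width = c(0.79, 0.88, 0.88, 0.56))
fit$k[which.min(fit$AIC)]   # 4
fit$k[which.min(fit$BIC)]   # 4
# Relative drop in deviance when moving from k - 1 to k clusters
fit$rel_drop <- c(NA, -diff(fit$deviance) / head(fit$deviance, -1))
fit
```

The relative drop is about 40% when going from 3 to 4 clusters and under 1% when going from 4 to 5.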

If we don’t want to recluster the data, we can reuse a previous clustra run and produce a silhouette plot for it directly, as we now do for the k = 4 run on the first data set, stored in cl4 (the run on the doubled-noise data is stored in cl4a).

sil = clustra_sil(cl4)
lapply(sil, plot_silhouette)

## [[1]]
## [1] 0.88

Another way to select the number of clusters is the Rand Index comparing different random starts and different numbers of clusters. When we replicate clustering with different random seeds, the “replicability” is an indicator of how stable the results are for a given k, the number of clusters. For this demonstration, we look at k = c(2, 3, 4, 5), and 10 replicates for each k. To run this long-running chunk, set eval = TRUE.

set.seed(12345)
ran = clustra_rand(data, k = c(2, 3, 4, 5), mccores = mc,
                   replicates = 10, maxdf = 10, conv = c(7, 1), verbose = TRUE)
rand_plot(ran)

The plot shows AR similarity level between all pairs of 40 clusterings (10 random starts for each of 2, 3, 4, and 5 clusters). It is difficult to distinguish between the 3, 4, and 5 results but the 4 result has the largest block of complete agreement.

Here, we can try running clustra with the “distant” starts option. The sequential selection of above-median length series that are most distant from previous selections introduces less initial variability. To run this long-running chunk, set eval = TRUE.

set.seed(12345)
ran = clustra_rand(data, k = c(2, 3, 4, 5), starts = "distant", mccores = mc,
                   replicates = 10, maxdf = 10, conv = c(7, 1), verbose = TRUE)
rand_plot(ran)

In this case, k = 4 comes with complete agreement between the 10 starts.
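The idea of sequentially picking starts that are far from earlier picks can be sketched as a generic farthest-first selection. This is not clustra's code: the restriction to above-median length series is omitted, the first pick is arbitrary, and the rows of x are made-up points standing in for per-id trajectory summaries.

```r
# Farthest-first selection of k starting rows; a generic sketch, not
# clustra's code. Rows of x stand in for per-id trajectory summaries.
farthest_first <- function(x, k) {
  d <- as.matrix(dist(x))
  chosen <- 1L                               # arbitrary first pick
  while (length(chosen) < k) {
    # distance from every row to its nearest already chosen row
    nearest <- apply(d[, chosen, drop = FALSE], 1, min)
    chosen <- c(chosen, which.max(nearest))  # add the most distant row
  }
  unname(chosen)
}

x <- rbind(c(0, 0), c(0.1, 0), c(5, 5), c(5.1, 5), c(10, 0))
farthest_first(x, 3)   # rows 1, 5, 3
```

Because each pick depends only on the data and the earlier picks, repeated runs start from similar configurations, which fits the lower initial variability noted above.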

Another possible evaluation of the number of clusters is to first ask clustra for a large number of clusters, evaluate the cluster centers on a common set of time points, and feed the resulting matrix to a hierarchical clustering function. Below, we ask for 40 clusters on the first data set (data) but get back only 17, because several become empty or too small for maxdf. The hclust() function then clusters the 17 resulting cluster means, each evaluated on 100 time points.

set.seed(12345)
cl30 = clustra(data, k = 40, maxdf = 10, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
##   Start counts 142 111 131 125 120 112 143 124 130 114 112 130 119 119 112 139 125 115 134 125 130 128 137 113 135 118 113 125 127 125 129 132 115 123 133 131 133 139 107 125   Starts time: 0
##  1 (M-123)-0.62 (E-12345)-1.03 Changes: 4814 Counts: 0 1410 0 0 0 439 996 0 0 0 26 0 0 18 0 49 20 0 0 0 19 2 0 137 37 0 0 1151 1 275 0 0 392 19 0 0 0 0 0 9 Deviance: 250429060
##  2 (M-123)-0.22 (E-12345)-0.41 Changes: 4956 Counts: 1824 38 3 102 226 144 109 435 263 87 235 204 8 91 621 458 152 Deviance: 55222343
##  3 (M-123)-0.27 (E-12345)-0.61 Changes: 728 Counts: 1701 85 6 102 172 264 109 428 330 138 181 216 51 155 506 411 145 Deviance: 30895215
##  4 (M-123)-0.35 (E-12345)-0.51 Changes: 558 Counts: 1540 127 8 112 163 423 117 386 361 129 168 212 125 167 447 369 146 Deviance: 30651083
##  5 (M-123)-0.35 (E-12345)-0.44 Changes: 440 Counts: 1381 154 9 121 159 581 119 348 374 135 168 227 182 158 400 342 142 Deviance: 30498659
##  6 (M-123)-0.3 (E-12345)-0.61 Changes: 346 Counts: 1240 173 10 124 156 721 119 316 371 135 168 259 218 158 367 323 142 Deviance: 30406821
##  7 (M-123)-0.3 (E-12345)-0.59 Changes: 282 Counts: 1139 191 10 128 154 822 121 305 366 133 167 275 233 156 352 305 143 Deviance: 30346071
##  8 (M-123)-0.3 (E-12345)-0.53 Changes: 179 Counts: 1073 200 10 133 151 888 123 304 361 134 160 284 244 156 338 296 145 Deviance: 30310869
##  9 (M-123)-0.3 (E-12345)-0.62 Changes: 131 Counts: 1026 209 10 136 149 935 125 303 358 137 154 276 258 156 336 287 145 Deviance: 30292690
##  10 (M-123)-0.37 (E-12345)-0.52 Changes: 104 Counts: 988 211 9 134 149 974 129 304 356 139 150 277 263 155 331 285 146 Deviance: 30282241
##  AIC:519.31 BIC:2010.35 edf:151.4  Total time: -9.43  zerocluster max-iter
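# Predict a fitted cluster spline on new time points
gpred = function(tps, newdata) 
  as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
                               newdata.guaranteed = TRUE))
# Evaluate every cluster mean on a common grid of 100 time points
# (one row per cluster), then cluster the rows hierarchically
resp = do.call(rbind, lapply(cl30$tps, gpred, newdata = data.frame(
  time = seq(min(data$time), max(data$time), length.out = 100))))
plot(hclust(dist(resp)))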

The dendrogram clearly indicates 4 clusters.
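Reading the number of clusters off a dendrogram can be complemented by cutting the tree at a chosen k, which gives a group label for each cluster mean. The sketch below uses a made-up matrix in place of resp (8 noisy copies of 4 curves, each on 100 time points) so that it runs on its own; the names toy_resp and hc are hypothetical.

```r
# Toy matrix standing in for `resp`: 8 cluster means on 100 time points,
# built as noisy copies of 4 underlying curves (made-up data)
set.seed(1)
tt <- seq(0, 1, length.out = 100)
base <- rbind(sin(2 * pi * tt), 2 * tt, rep(0.5, 100), -tt)
toy_resp <- base[rep(1:4, each = 2), ] + rnorm(8 * 100, sd = 0.02)
hc <- hclust(dist(toy_resp))     # same call pattern as used for resp
cutree(hc, k = 4)                # group label for each of the 8 rows
```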

When we use starts = "distant", the selected distant starts are more likely to persist into a nearby local minimum, retaining the full 40 specified clusters.

set.seed(12345)
cl30 = clustra(data, k = 40, starts = "distant", maxdf = 10, conv = c(10, 0), mccores = mc,
             verbose = TRUE)
## 
##  distant ids:  2037 1462 2040 2573 284 3704 4616 3243 3079 1109 1154 908 1590 2233 2266 4967 2348 1453 125 3714 2566 2964 4878 2270 1786 923 95 1985 1526 910 2948 4833 1059 2806 3782 616 3654 3801 1460 1392   Start counts 123 60 15 470 14 293 159 33 55 4 66 72 276 215 359 60 226 97 165 61 51 64 7 29 242 422 65 69 60 20 113 222 482 69 6 31 63 41 16 105   Starts time: -31.06
##  1 (M-123)-0.47 (E-12345)-1 Changes: 1355 Counts: 165 58 28 351 33 348 174 36 65 10 109 100 284 214 256 77 173 144 131 157 71 54 13 36 200 265 67 93 82 37 153 144 350 171 26 70 56 64 21 114 Deviance: 29973903
##  2 (M-123)-0.53 (E-12345)-1.13 Changes: 725 Counts: 161 60 47 312 38 319 178 35 71 15 130 105 273 188 215 71 127 160 117 182 94 51 19 34 179 257 66 122 98 49 179 147 281 201 41 85 61 80 30 122 Deviance: 29574078
##  3 (M-123)-0.54 (E-12345)-1.07 Changes: 525 Counts: 159 55 60 300 40 300 172 34 72 20 143 103 268 182 205 64 108 162 112 192 120 58 19 26 159 254 72 142 99 51 183 162 238 210 47 96 64 86 39 124 Deviance: 29408988
##  4 (M-123)-0.55 (E-12345)-1.01 Changes: 429 Counts: 154 54 71 287 39 290 170 33 71 22 153 111 255 174 195 61 101 179 112 189 136 54 25 27 148 238 81 157 98 53 178 174 208 216 52 115 63 90 44 122 Deviance: 29311193
##  5 (M-123)-0.56 (E-12345)-1.12 Changes: 436 Counts: 145 54 82 278 40 275 162 34 75 22 158 116 242 165 185 61 93 193 108 190 147 58 31 25 137 226 93 167 98 52 176 183 179 226 53 132 59 100 54 126 Deviance: 29233283
##  6 (M-123)-0.58 (E-12345)-1.1 Changes: 391 Counts: 136 56 92 273 48 264 158 30 78 23 159 120 225 157 187 63 84 203 105 194 150 62 28 25 127 231 99 167 96 55 170 187 175 222 58 146 52 109 62 124 Deviance: 29161061
##  7 (M-123)-0.58 (E-12345)-1.06 Changes: 304 Counts: 132 56 103 263 50 254 164 29 78 28 157 119 213 145 190 64 78 217 102 194 154 62 30 25 124 225 102 159 89 53 166 193 170 220 58 166 52 117 73 126 Deviance: 29102147
##  8 (M-123)-0.59 (E-12345)-1.14 Changes: 255 Counts: 128 55 106 251 50 248 168 27 78 32 156 118 209 139 179 63 74 235 101 195 160 62 31 27 122 215 101 156 90 54 163 194 163 214 56 182 55 128 88 127 Deviance: 29058466
##  9 (M-123)-0.57 (E-12345)-1.26 Changes: 212 Counts: 126 56 113 246 49 233 168 28 77 35 152 116 206 139 177 61 69 239 106 198 167 64 33 26 117 215 98 150 88 51 155 195 165 216 56 196 55 135 101 123 Deviance: 29028546
##  10 (M-123)-0.57 (E-12345)-0.91 Changes: 196 Counts: 121 56 113 241 49 224 163 28 78 34 149 116 204 142 174 62 67 241 109 206 162 66 33 25 117 219 98 145 92 53 157 202 158 209 56 199 54 143 113 122 Deviance: 29005589
##  AIC:890.82 BIC:4256.48 edf:341.75  Total time: -16.74  max-iter
# Predict a fitted cluster spline on new time points
gpred = function(tps, newdata) 
  as.numeric(mgcv::predict.bam(tps, newdata, type = "response",
                               newdata.guaranteed = TRUE))
# Evaluate all 40 cluster means on a common grid of 100 time points
resp = do.call(rbind, lapply(cl30$tps, gpred, newdata = data.frame(
  time = seq(min(data$time), max(data$time), length.out = 100))))
plot(hclust(dist(resp)))

Here again (if we consider cluster 24 as an outlier) we get 4 clusters.

cat("clustra vignette run time:\n")
## clustra vignette run time:
print(proc.time() - start_knit)
##    user  system elapsed 
##  80.800  10.325  70.602