Multilabel classification

Intro

First, install the blurr module, which provides the Hugging Face Transformers integration:

# Install blurr (Hugging Face Transformers integration) from GitHub via pip.
reticulate::py_install("https://github.com/ohmeow/blurr", pip = TRUE)

Multilabel

Load the `civil_comments` dataset, keeping only the first 1% of the training split so that training is fast:

library(fastai)
library(magrittr)
library(zeallot)
df = HF_load_dataset('civil_comments', split='train[:1%]')

Preprocess

Select the label columns (one model output per column) and round their scores to integer targets:

# Convert to a data.table so the label columns can be updated by reference.
df <- data.table::as.data.table(df)

# Target columns; each becomes one independent output of the model.
lbl_cols <- c(
  "severe_toxicity",
  "obscene",
  "threat",
  "insult",
  "identity_attack",
  "sexual_explicit"
)

# Fail early with a clear message if the dataset schema is not as expected.
missing_cols <- setdiff(lbl_cols, names(df))
if (length(missing_cols) > 0) {
  stop(
    "Missing label columns: ", paste(missing_cols, collapse = ", "),
    call. = FALSE
  )
}

# Round each label score to the nearest whole number and store it as integer,
# in a single pass. `:=` modifies `df` in place, so no reassignment is needed.
# Note: round() rounds halves to even, so a score of exactly 0.5 becomes 0.
df[,
  (lbl_cols) := lapply(.SD, function(x) as.integer(round(x))),
  .SDcols = lbl_cols
]

Pretrained model

Load DistilRoBERTa (`distilroberta-base`), configured with one output per label column:

# Use a sequence-classification head.
task <- HF_TASKS_ALL()$SequenceClassification

pretrained_model_name <- "distilroberta-base"

# Size the classification head to one output per label column.
config <- AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels <- length(lbl_cols)

# get_hf_objects() returns four objects; zeallot's %<-% unpacks them into
# hf_arch, hf_config, hf_tokenizer and hf_model.
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(
  pretrained_model_name,
  task = task,
  config = config
)
Downloading: 100%|██████████| 899k/899k [00:00<00:00, 961kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 597kB/s]
Downloading: 100%|██████████| 331M/331M [03:26<00:00, 1.61MB/s]

Datablock

Create the data blocks — tokenized text as the input, and the already-encoded label columns as the multi-label target:

# Input: tokenized text. Target: the label columns, already encoded per column
# (encoded = TRUE), with the column names as the label vocabulary.
text_block <- HF_TextBlock(hf_arch = hf_arch, hf_tokenizer = hf_tokenizer)
label_block <- MultiCategoryBlock(encoded = TRUE, vocab = lbl_cols)
blocks <- list(text_block, label_block)

# Read text from the "text" column and targets from the label columns;
# split train/validation at random.
dblock <- DataBlock(
  blocks = blocks,
  get_x = ColReader("text"),
  get_y = ColReader(lbl_cols),
  splitter = RandomSplitter()
)

dls <- dblock %>%
  dataloaders(df, bs = 8)

# Preview one batch to sanity-check inputs and targets.
dls %>%
  one_batch()
[[1]]
[[1]]$input_ids
tensor([[    0, 24268,  5257,  ...,     1,     1,     1],
        [    0,   287,  4505,  ...,     1,     1,     1],
        [    0,    38,   437,  ...,     1,     1,     1],
        ...,
        [    0,   152,  1129,  ...,     1,     1,     1],
        [    0,    85,    18,  ...,     1,     1,     1],
        [    0, 22014,    31,  ...,     1,     1,     1]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')


[[2]]
TensorMultiCategory([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], device='cuda:0')

Model

# Wrap the Hugging Face model for use with a fastai Learner.
model <- HF_BaseModelWrapper(hf_model)

# Single decision threshold shared by the accuracy metric and the loss
# function's `thresh` (the value predict() uses to turn probabilities into
# labels), so the two cannot drift apart.
label_thresh <- 0.2

learn <- Learner(
  dls,
  model,
  opt_func = partial(Adam),
  loss_func = BCEWithLogitsLossFlat(),
  metrics = partial(accuracy_multi(), thresh = label_thresh),
  cbs = HF_BaseModelCallback(),
  splitter = hf_splitter()
)

learn$loss_func$thresh <- label_thresh
# Creates the layer groups based on the `splitter` function.
learn$create_opt()
# Freeze all but the last parameter group before the first fit.
learn$freeze()

learn %>% summary()

The resulting model summary:

epoch   train_loss   valid_loss   accuracy_multi   time  
------  -----------  -----------  ---------------  ------
HF_BaseModelWrapper (Input shape: 8 x 391)
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Embedding            8 x 391 x 768        38,603,520 False     
________________________________________________________________
Embedding            8 x 391 x 768        394,752    False     
________________________________________________________________
Embedding            8 x 391 x 768        768        False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 768              590,592    True      
________________________________________________________________
Dropout              8 x 768              0          False     
________________________________________________________________
Linear               8 x 6                4,614      True      
________________________________________________________________

Total params: 82,123,014
Total trainable params: 615,174
Total non-trainable params: 81,507,840

Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fee7e8166a8>)
Loss function: FlattenedLoss of BCEWithLogitsLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback

Conclusion

Finally, run the learning-rate finder and fit the model for one epoch:

# Run the learning-rate finder. Its result is kept in `lrs`, but the fit below
# sets lr_max by hand.
lrs <- lr_find(learn, suggestions = TRUE)

# One epoch with fit_one_cycle, peaking at a learning rate of 1e-2.
learn %>%
  fit_one_cycle(1, lr_max = 1e-2)
epoch   train_loss   valid_loss   accuracy_multi   time  
------  -----------  -----------  ---------------  ------
0       0.040617     0.034286     0.993257         01:21 

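Predict on a new comment. The decoding threshold is lowered to 0.02 so that labels with small predicted probabilities are still reported: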

# Lower the threshold used to decode probabilities into labels.
learn$loss_func$thresh <- 0.02

# Example comment to classify (the literal deliberately spans two lines).
sample_text <- "Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes.
No enchiladas for them!"

learn %>%
  predict(sample_text)
$probabilities
  severe_toxicity     obscene       threat     insult identity_attack sexual_explicit
1    9.302437e-07 0.004268706 0.0007849637 0.02687055     0.003282947      0.00232468

$labels
[1] "insult"