Summary

This tutorial introduces a simple autoregressive model for generating sequences of dance poses. The model can be trained on motion capture data and uses a long short-term memory (LSTM) network to predict the continuation of a sequence of poses.

This tutorial forms part of a series of tutorials on using PyTorch to create and train generative deep learning models. The code for these tutorials is available here.

After 1000 epochs of training, the model generates pose sequences that look like this when rendered as skeleton animations.

Original Versus Predicted Dance Pose Sequence. The predicted dance sequence has been created by the autoregressive model described in this article. The model has been trained for 1000 epochs on the following motion capture recording: MUR_Fluidity_Body_Take1_mb_proc_rh.p

Imports

The following modules need to be available and imported for this example. The “common” module with all its submodules is included when downloading the tutorial files.

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict

import os, sys, time, subprocess
import numpy as np
import math
sys.path.append("../..")

from common import utils
from common.skeleton import Skeleton
from common.mocap_dataset import MocapDataset
from common.quaternion import qmul, qnormalize_np, slerp
from common.pose_renderer import PoseRenderer

Compute Device

If a GPU is available for running the model, this device can be selected as follows.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Read Motion Capture Data

The file from which data is loaded is a “pickled” dictionary of motion capture data. More information about this format is available here. The MocapDataset class is used to import such a file. Information about the MocapDataset class is available here. In the following code excerpt, a motion capture file is loaded and information about the motion capture data is collected.

mocap_data_path = "../../data/Mocap/MUR_Nov_2021/MUR_Fluidity_Body_Take1_mb_proc_rh.p"
mocap_fps = 50

# load mocap data
mocap_data = MocapDataset(mocap_data_path, fps=mocap_fps)
if device == 'cuda':
    mocap_data.cuda()
mocap_data.compute_positions()

# gather skeleton info
skeleton = mocap_data.skeleton()
skeleton_joint_count = skeleton.num_joints()
skel_edge_list = utils.get_skeleton_edge_list(skeleton)

# obtain pose sequence
subject = "S1"
action = "A1"
pose_sequence = mocap_data[subject][action]["rotations"]

pose_sequence_length = pose_sequence.shape[0]
joint_count = pose_sequence.shape[1]
joint_dim = pose_sequence.shape[2]
pose_dim = joint_count * joint_dim
pose_sequence = np.reshape(pose_sequence, (-1, pose_dim))

Create Dataset

To create a dataset from the motion capture data that can be used for training, several steps are undertaken:

  • remove sequence excerpts in which poses are invalid or otherwise unsuitable for training
  • split the pose sequences into input pose sequences and output poses
  • declare and define a Dataset class to hold the data
  • split the data into a training set and a test set
  • instantiate DataLoaders from the training and test sets

The removal of sequence excerpts and the splitting into input sequences and output poses are carried out in the same loop. The input sequences are the sequences that are fed into the model and for which the model is supposed to predict the next pose. The output poses are those next poses, which the model has to learn to predict.

sequence_length = 128
sequence_offset = 2
mocap_valid_frame_ranges = [ [ 500, 6500 ] ]

# prepare training data
# split data into input sequence(s) and output pose(s)
input_pose_sequences = []
output_poses = []

for valid_frame_range in mocap_valid_frame_ranges:
    frame_range_start = valid_frame_range[0]
    frame_range_end = valid_frame_range[1]
    
    for seq_excerpt_start in np.arange(frame_range_start, frame_range_end - sequence_length - 1, sequence_offset):
        #print("valid: start ", frame_range_start, " end ", frame_range_end, " exc: start ", seq_excerpt_start, " end ", (seq_excerpt_start + sequence_length) )
        input_pose_sequences.append( pose_sequence[ seq_excerpt_start : seq_excerpt_start + sequence_length ] )
        output_poses.append( pose_sequence[ seq_excerpt_start + sequence_length : seq_excerpt_start + sequence_length + 1 ] )

input_pose_sequences = np.array(input_pose_sequences)
output_poses = np.array(output_poses)

A custom dataset class for the input pose sequences and output poses is created by subclassing the Dataset class.

class SequencePoseDataset(Dataset):
    def __init__(self, input_poses_sequences, output_poses):
        self.input_poses_sequences = input_poses_sequences
        self.output_poses = output_poses
    
    def __len__(self):
        return self.input_poses_sequences.shape[0]
    
    def __getitem__(self, idx):
        return self.input_poses_sequences[idx, ...], self.output_poses[idx, ...]

This custom dataset class is instantiated as follows:

full_dataset = SequencePoseDataset(input_pose_sequences, output_poses)

This dataset contains all data. The dataset can be split into two datasets, one for training and one for testing, as follows:

test_percentage = 0.2

dataset_size = len(full_dataset)

test_size = int(test_percentage * dataset_size)
train_size = dataset_size - test_size

train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

DataLoaders are created from these two datasets as follows:

batch_size = 16

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
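As a quick check, one batch can be drawn from the training DataLoader to verify the tensor shapes (a minimal sketch):

# draw a single batch and inspect the tensor shapes
sample_input, sample_target = next(iter(train_dataloader))
print(sample_input.shape)   # (batch_size, sequence_length, pose_dim)
print(sample_target.shape)  # (batch_size, 1, pose_dim)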

Create Model

An autoregressive model is created that takes as input a sequence of poses and outputs a single pose representing the continuation of the input sequence. The model employs an LSTM network for autoregression. The output of the last LSTM layer at the final time step is passed through a stack of fully connected (dense) layers, with the ReLU function used as activation function between them. The last of these layers outputs the predicted pose. The model is implemented by subclassing the nn.Module class. The class definition is as follows:

class AutoRegressor(nn.Module):
    def __init__(self, pose_dim, rnn_layer_count, rnn_layer_size, dense_layer_sizes):
        super(AutoRegressor, self).__init__()
        
        self.pose_dim = pose_dim
        self.rnn_layer_count = rnn_layer_count
        self.rnn_layer_size = rnn_layer_size
        self.dense_layer_sizes = dense_layer_sizes
        
        # create recurrent layers
        rnn_layers = []
        
        rnn_layers.append(("autoreg_rnn_0", nn.LSTM(self.pose_dim, self.rnn_layer_size, self.rnn_layer_count, batch_first=True)))
        self.rnn_layers = nn.Sequential(OrderedDict(rnn_layers))
        
        # create dense layers
        dense_layers = []
        dense_layer_count = len(self.dense_layer_sizes)
        
        if dense_layer_count > 0:
            dense_layers.append(("autoreg_dense_0", nn.Linear(self.rnn_layer_size, self.dense_layer_sizes[0])))
            dense_layers.append(("autoregr_dense_relu_0", nn.ReLU()))

            for layer_index in range(1, dense_layer_count):
                dense_layers.append(("autoreg_dense_{}".format(layer_index), nn.Linear(self.dense_layer_sizes[layer_index-1], self.dense_layer_sizes[layer_index])))
                dense_layers.append(("autoregr_dense_relu_{}".format(layer_index), nn.ReLU()))
        
            dense_layers.append(("autoregr_dense_{}".format(len(self.dense_layer_sizes)), nn.Linear(self.dense_layer_sizes[-1], self.pose_dim)))
        else:
            dense_layers.append(("autoreg_dense_0", nn.Linear(self.rnn_layer_size, self.pose_dim)))
        
        self.dense_layers = nn.Sequential(OrderedDict(dense_layers))
    
    def forward(self, x):
        #print("x 1 ", x.shape)
        x, (_, _) = self.rnn_layers(x)
        #print("x 2 ", x.shape)
        x = x[:, -1, :] # only last time step 
        #print("x 3 ", x.shape)
        yhat = self.dense_layers(x)
        #print("yhat ", yhat.shape)
        return yhat

The constructor of the model class takes the following arguments: the dimension of a single pose, the number of recurrent layers to create, the number of units in each LSTM layer, and a list of units per layer in the artificial neural network that follows the LSTM network. The model class can be instantiated as follows:

ar_rnn_layer_count = 2
ar_rnn_layer_size = 512
ar_dense_layer_sizes = [ ]

autoreg = AutoRegressor(pose_dim, ar_rnn_layer_count, ar_rnn_layer_size, ar_dense_layer_sizes).to(device)
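Printing the model instance provides an overview of its layers:

# print an overview of the model architecture
print(autoreg)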

The shapes of the input and output tensors for this model are as follows:

  • input tensor: batch_size x sequence_length x pose_dim
  • output tensor: batch_size x pose_dim
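These shapes can be verified by passing a batch of random values through the model (a minimal sketch):

# pass a dummy batch through the model to verify the output shape
dummy_input = torch.rand((batch_size, sequence_length, pose_dim)).to(device)
with torch.no_grad():
    dummy_output = autoreg(dummy_input)
print(dummy_output.shape)  # (batch_size, pose_dim)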

Optimiser and Loss Functions

To update the model weights during training, the Adam optimiser is used. This optimiser is instantiated as follows:

ar_learning_rate = 1e-4

ar_optimizer = torch.optim.Adam(autoreg.parameters(), lr=ar_learning_rate)

The overall loss of the model is obtained from a weighted sum of two individual losses. The two losses quantify the following: the deviation of the generated joint rotations from unit quaternions, and the deviation of the joint rotations of the predicted poses from those of the target poses. The two loss functions are defined as follows:

def ar_norm_loss(yhat):
    _yhat = yhat.view(-1, 4)
    _norm = torch.norm(_yhat, dim=1)
    _diff = (_norm - 1.0) ** 2
    _loss = torch.mean(_diff)
    return _loss

def ar_quat_loss(y, yhat):
    # y shape: (batch_size, 1, pose_dim), yhat shape: (batch_size, pose_dim)
    
    # normalize quaternion

    _y = y.view((-1, 4))
    _yhat = yhat.view((-1, 4))
    _yhat_norm = nn.functional.normalize(_yhat, p=2, dim=1)
    
    # inverse of quaternion: 
    _yhat_inv = _yhat_norm * torch.tensor([[1.0, -1.0, -1.0, -1.0]], dtype=torch.float32).to(device)
    # calculate difference quaternion
    _diff = qmul(_yhat_inv, _y)
    # length of complex part
    _len = torch.norm(_diff[:, 1:], dim=1)
    # atan2
    _atan = torch.atan2(_len, _diff[:, 0])
    # abs
    _abs = torch.abs(_atan)
    _loss = torch.mean(_abs)   
    return _loss
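Expressed mathematically, with $\hat{q}$ denoting a predicted joint rotation and $q$ the corresponding target rotation (both quaternions), the two losses implemented above correspond to:

$L_{\mathrm{norm}} = \frac{1}{N} \sum_{i=1}^{N} \left( \lVert \hat{q}_i \rVert - 1 \right)^2$

$L_{\mathrm{quat}} = \frac{1}{N} \sum_{i=1}^{N} \left| \operatorname{atan2}\!\left( \lVert \operatorname{Im}(\bar{\hat{q}}_i \otimes q_i) \rVert,\ \operatorname{Re}(\bar{\hat{q}}_i \otimes q_i) \right) \right|$

Here $\bar{\hat{q}}$ is the conjugate of the normalised prediction, $\otimes$ denotes quaternion multiplication, $\operatorname{Im}$ and $\operatorname{Re}$ are the imaginary and real parts of a quaternion, and $N$ is the number of quaternions in a batch.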

The overall loss is calculated by a function named “ar_loss”. This function is defined as follows:

ar_norm_loss_scale = 0.1
ar_quat_loss_scale = 0.9

def ar_loss(y, yhat):
    _norm_loss = ar_norm_loss(yhat)
    _quat_loss = ar_quat_loss(y, yhat)
    
    _total_loss = 0.0
    _total_loss += _norm_loss * ar_norm_loss_scale
    _total_loss += _quat_loss * ar_quat_loss_scale
    
    return _total_loss, _norm_loss, _quat_loss

Training and Testing Functions

For conducting a single training and testing step, two separate functions are declared: “ar_train_step” and “ar_test_step”. These functions take as input two tensors, one representing a batch of input pose sequences, the other a batch of output poses. Both functions compute and return loss values. The “ar_train_step” function also calculates gradients from the loss functions and updates the trainable parameters of the model. The “ar_test_step” function suppresses gradient calculation and leaves the trainable model parameters unchanged.

def ar_train_step(pose_sequences, target_poses):

    pred_poses = autoreg(pose_sequences)

    _ar_loss, _ar_norm_loss, _ar_quat_loss = ar_loss(target_poses, pred_poses) 

    #print("_ae_pos_loss ", _ae_pos_loss)
    
    # Backpropagation
    ar_optimizer.zero_grad()
    _ar_loss.backward()

    ar_optimizer.step()
    
    return _ar_loss, _ar_norm_loss, _ar_quat_loss

def ar_test_step(pose_sequences, target_poses):
    
    autoreg.eval()
 
    with torch.no_grad():
        pred_poses = autoreg(pose_sequences)
        _ar_loss, _ar_norm_loss, _ar_quat_loss = ar_loss(target_poses, pred_poses) 
    
    autoreg.train()
    
    return _ar_loss, _ar_norm_loss, _ar_quat_loss

The function named “train” performs the actual training by calling the train and test step functions repeatedly. This function takes as arguments the train and test DataLoaders and the number of epochs. It then runs through two loops. The outer loop iterates over all epochs. The inner loop iterates over all batches and exists in two versions: one iterates over the batches provided by the train DataLoader, the other over the batches provided by the test DataLoader. In each of these inner loops, the loss values returned by the loss functions are added to a dictionary. This dictionary contains the history of the training process. The definition of the function is as follows:

def train(train_dataloader, test_dataloader, epochs):
    
    loss_history = {}
    loss_history["ar train"] = []
    loss_history["ar test"] = []
    loss_history["ar norm"] = []
    loss_history["ar quat"] = []

    for epoch in range(epochs):
        start = time.time()
        
        ar_train_loss_per_epoch = []
        ar_norm_loss_per_epoch = []
        ar_quat_loss_per_epoch = []

        for train_batch in train_dataloader:
            input_pose_sequences = train_batch[0].to(device)
            target_poses = train_batch[1].to(device)
            
            _ar_loss, _ar_norm_loss, _ar_quat_loss = ar_train_step(input_pose_sequences, target_poses)
            
            _ar_loss = _ar_loss.detach().cpu().numpy()
            _ar_norm_loss = _ar_norm_loss.detach().cpu().numpy()
            _ar_quat_loss = _ar_quat_loss.detach().cpu().numpy()
            
            ar_train_loss_per_epoch.append(_ar_loss)
            ar_norm_loss_per_epoch.append(_ar_norm_loss)
            ar_quat_loss_per_epoch.append(_ar_quat_loss)

        ar_train_loss_per_epoch = np.mean(np.array(ar_train_loss_per_epoch))
        ar_norm_loss_per_epoch = np.mean(np.array(ar_norm_loss_per_epoch))
        ar_quat_loss_per_epoch = np.mean(np.array(ar_quat_loss_per_epoch))

        ar_test_loss_per_epoch = []
        
        for test_batch in test_dataloader:
            input_pose_sequences = test_batch[0].to(device)
            target_poses = test_batch[1].to(device)
            
            _ar_loss, _, _ = ar_test_step(input_pose_sequences, target_poses)
            
            _ar_loss = _ar_loss.detach().cpu().numpy()
            
            ar_test_loss_per_epoch.append(_ar_loss)
        
        ar_test_loss_per_epoch = np.mean(np.array(ar_test_loss_per_epoch))
        
        if epoch % model_save_interval == 0 and save_weights == True:
            torch.save(autoreg.state_dict(), "results/weights/autoreg_weights_epoch_{}".format(epoch))
        
        loss_history["ar train"].append(ar_train_loss_per_epoch)
        loss_history["ar test"].append(ar_test_loss_per_epoch)
        loss_history["ar norm"].append(ar_norm_loss_per_epoch)
        loss_history["ar quat"].append(ar_quat_loss_per_epoch)
        
        print ('epoch {} : ar train: {:01.4f} ar test: {:01.4f} norm {:01.4f} quat {:01.4f} time {:01.2f}'.format(epoch + 1, ar_train_loss_per_epoch, ar_test_loss_per_epoch, ar_norm_loss_per_epoch, ar_quat_loss_per_epoch, time.time()-start))
    
    return loss_history

The “train” function can be called as follows:

epochs = 100
model_save_interval = 100
save_weights = False

# fit model
loss_history = train(train_dataloader, test_dataloader, epochs)

Save Training History and Model Parameters

The loss history can be exported as a CSV file and as a graphic plot by calling the corresponding functions of the common.utils module.

# save history
utils.save_loss_as_csv(loss_history, "results/histories/history_{}.csv".format(epochs))
utils.save_loss_as_image(loss_history, "results/histories/history_{}.png".format(epochs))

The parameters of the trained model can be saved as follows:

# save model weights
torch.save(autoreg.state_dict(), "results/weights/autoreg_weights_epoch_{}".format(epochs))
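The saved state dictionary can later be restored into a model instance with the same architecture, for example before generating new pose sequences (a minimal sketch):

# restore the trained parameters from the saved state dictionary
autoreg.load_state_dict(torch.load("results/weights/autoreg_weights_epoch_{}".format(epochs), map_location=device))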

Generate and Visualise Predicted Pose Sequences

Once the model has been trained, it can be used to generate new pose sequences. These pose sequences are created by starting with an existing pose sequence and then extending this sequence by predicting one pose at a time. Once a predicted pose sequence has been obtained, it can be used to generate a skeleton animation. The PoseRenderer class can be used for visualisation purposes. More information about the PoseRenderer class is available here.

An instance of the PoseRenderer class can be created as follows:

skel_edge_list = utils.get_skeleton_edge_list(skeleton)
poseRenderer = PoseRenderer(skel_edge_list)
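The rendering functions below make use of several view settings (elevation, azimuth, line width, and image size) that are not defined elsewhere in this excerpt. The following values are placeholder assumptions and can be adjusted as needed.

# view settings for rendering (placeholder values, adjust as needed)
view_ele = 0.0
view_azi = 0.0
view_line_width = 1.0
view_size = 4.0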

Two convenience functions are declared for rendering a sequence of poses as skeleton animation. The animation is exported in “.gif” format.

The first function, named “create_ref_sequence_anim”, creates animations from excerpts of the original motion capture data. This function takes as arguments the index of the first pose in the excerpt, the number of poses in the excerpt, and the name of the file the animation is exported to.

def create_ref_sequence_anim(start_pose_index, pose_count, file_name):
    
    start_pose_index = max(start_pose_index, sequence_length)
    pose_count = min(pose_count, pose_sequence_length - start_pose_index)
    
    sequence_excerpt = pose_sequence[start_pose_index:start_pose_index + pose_count, :]
    sequence_excerpt = np.reshape(sequence_excerpt, (pose_count, joint_count, joint_dim))

    sequence_excerpt = torch.tensor(np.expand_dims(sequence_excerpt, axis=0)).to(device)
    zero_trajectory = torch.tensor(np.zeros((1, pose_count, 3), dtype=np.float32)).to(device)
    
    skel_sequence = skeleton.forward_kinematics(sequence_excerpt, zero_trajectory)

    skel_sequence = np.squeeze(skel_sequence.cpu().numpy())
    view_min, view_max = utils.get_equal_mix_max_positions(skel_sequence)
    skel_images = poseRenderer.create_pose_images(skel_sequence, view_min, view_max, view_ele, view_azi, view_line_width, view_size, view_size)
    skel_images[0].save(file_name, save_all=True, append_images=skel_images[1:], optimize=False, duration=33.0, loop=0)

The second function, named “create_pred_sequence_anim”, creates animations from predicted pose sequences. This function takes the same arguments as the previous function. The start_pose_index argument marks the beginning of the predicted pose sequence. To create the first predicted pose, an excerpt of the original motion capture sequence that immediately precedes this pose is used as input for the model. Accordingly, start_pose_index must be at least as large as the length of the sequence that is fed into the model. All successive poses are predicted by extending the input sequence by one predicted pose at a time. The animation is created from predicted poses only.

def create_pred_sequence_anim(start_pose_index, pose_count, file_name):
    autoreg.eval()
    
    start_pose_index = max(start_pose_index, sequence_length)
    pose_count = min(pose_count, pose_sequence_length - start_pose_index)
    
    start_seq = pose_sequence[start_pose_index - sequence_length:start_pose_index, :]
    start_seq = torch.from_numpy(start_seq).to(device)
    
    next_seq = start_seq
    
    pred_poses = []
    
    for i in range(pose_count):
        with torch.no_grad():
            pred_pose = autoreg(torch.unsqueeze(next_seq, dim=0))
    
        # normalize pred pose
        pred_pose = torch.squeeze(pred_pose)
        pred_pose = pred_pose.view((-1, 4))
        pred_pose = nn.functional.normalize(pred_pose, p=2, dim=1)
        pred_pose = pred_pose.view((1, pose_dim))

        pred_poses.append(pred_pose)
    
        #print("next_seq s ", next_seq.shape)
        #print("pred_pose s ", pred_pose.shape)

        next_seq = torch.cat([next_seq[1:,:], pred_pose], dim=0)
    
        print("predict time step ", i)

    pred_poses = torch.cat(pred_poses, dim=0)
    pred_poses = pred_poses.view((1, pose_count, joint_count, joint_dim))


    zero_trajectory = torch.tensor(np.zeros((1, pose_count, 3), dtype=np.float32))
    zero_trajectory = zero_trajectory.to(device)
    
    skel_poses = skeleton.forward_kinematics(pred_poses, zero_trajectory)
    
    skel_poses = skel_poses.detach().cpu().numpy()
    skel_poses = np.squeeze(skel_poses)
    
    view_min, view_max = utils.get_equal_mix_max_positions(skel_poses)
    pose_images = poseRenderer.create_pose_images(skel_poses, view_min, view_max, view_ele, view_azi, view_line_width, view_size, view_size)

    pose_images[0].save(file_name, save_all=True, append_images=pose_images[1:], optimize=False, duration=33.0, loop=0) 

    autoreg.train()

These two functions can be called as follows:

seq_start_pose_index = 1000
seq_pose_count = 200

create_ref_sequence_anim(seq_start_pose_index, seq_pose_count, "ref_{}_{}.gif".format(seq_start_pose_index, seq_pose_count))
create_pred_sequence_anim(seq_start_pose_index, seq_pose_count, "pred_{}_{}.gif".format(seq_start_pose_index, seq_pose_count))