Summary
This tutorial introduces an adversarial autoencoder model for generating synthetic dance poses. The model is trained on motion capture data and uses conventional artificial neural network (ANN) layers to create poses. Apart from explaining how to create and train the corresponding models, the article also provides some examples of how the latent space of pose encodings can be navigated to discover potentially interesting synthetic poses.
This tutorial forms part of a series of tutorials on using PyTorch to create and train generative deep learning models. The code for these tutorials is available here.
After 100 epochs of training, the poses reconstructed by the autoencoder look like this when rendered as a skeleton.
Create Models
The code parts for importing poses from a motion capture recording and creating Datasets and Dataloaders for training and testing are identical to those in a previous article on Pose Generation with a GAN. These code parts are skipped here.
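For readers who do not have the previous article at hand, the following minimal sketch illustrates what this preparation could look like. It is not the original code: it assumes that the motion capture poses are already available as a NumPy array named poses of shape (pose_count, pose_dim), and the split ratio and batch size are only illustrative values.
# minimal sketch of dataset and dataloader creation (assumed split ratio and batch size)
from torch.utils.data import Dataset, DataLoader

class PoseDataset(Dataset):
    # wraps a NumPy array of shape (pose_count, pose_dim) and yields single poses as float tensors
    def __init__(self, poses):
        self.poses = torch.from_numpy(poses.astype(np.float32))

    def __len__(self):
        return self.poses.shape[0]

    def __getitem__(self, index):
        return self.poses[index]

batch_size = 32                      # illustrative value
test_count = poses.shape[0] // 10    # illustrative 90/10 train/test split

train_dataset = PoseDataset(poses[:-test_count])
test_dataset = PoseDataset(poses[-test_count:])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)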
An adversarial autoencoder (AAE) consists of three models: an Encoder that compresses input data into a latent vector, a Decoder that recreates the input data from the latent vector, and a Discriminator that distinguishes between samples drawn from a true normal distribution and latent vectors produced by the Encoder. A brief introduction to autoencoders is available here.
In this example, the models operate on dance poses obtained from motion capture recordings. All models employ conventional artificial neural networks (ANN).
The code for creating the Discriminator model is identical to that in the previous article on Pose Generation with a GAN. This code and its explanation are skipped here.
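Since the Discriminator is used throughout the training code below under the name disc_prior, a minimal sketch of what such a model could look like is given here for orientation. The layer sizes and structure are illustrative assumptions; the actual definition is the one from the previous article. The model maps a latent vector to a single probability, which is why binary cross-entropy can be used later.
# minimal sketch of a prior discriminator (illustrative structure, not the original definition)
class Discriminator(nn.Module):
    def __init__(self, latent_dim, dense_layer_sizes):
        super(Discriminator, self).__init__()

        dense_layers = []
        in_dim = latent_dim
        for layer_index, layer_size in enumerate(dense_layer_sizes):
            dense_layers.append(("disc_dense_{}".format(layer_index), nn.Linear(in_dim, layer_size)))
            dense_layers.append(("disc_dense_relu_{}".format(layer_index), nn.ReLU()))
            in_dim = layer_size

        # the final layer outputs a single probability, matching the use of nn.BCELoss below
        dense_layers.append(("disc_dense_out", nn.Linear(in_dim, 1)))
        dense_layers.append(("disc_sigmoid", nn.Sigmoid()))

        self.dense_layers = nn.Sequential(OrderedDict(dense_layers))

    def forward(self, x):
        return self.dense_layers(x)

# instantiated once latent_dim is defined (see below); the layer sizes are assumptions
# disc_prior = Discriminator(latent_dim, [32, 32]).to(device)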
Create Encoder Model
To compress a pose into a latent vector, the Encoder model passes the input pose as a one-dimensional feature vector through several ANN layers. These layers successively reduce the dimension of the feature vector down to the latent dimension. Each ANN layer is followed by a ReLU activation function.
The class definition of the Encoder model is as follows:
class Encoder(nn.Module):

    def __init__(self, pose_dim, latent_dim, dense_layer_sizes):
        super(Encoder, self).__init__()

        self.pose_dim = pose_dim
        self.latent_dim = latent_dim
        self.dense_layer_sizes = dense_layer_sizes

        # create dense layers
        dense_layers = []

        dense_layers.append(("encoder_dense_0", nn.Linear(self.pose_dim, self.dense_layer_sizes[0])))
        dense_layers.append(("encoder_dense_relu_0", nn.ReLU()))

        dense_layer_count = len(dense_layer_sizes)
        for layer_index in range(1, dense_layer_count):
            dense_layers.append(("encoder_dense_{}".format(layer_index), nn.Linear(self.dense_layer_sizes[layer_index - 1], self.dense_layer_sizes[layer_index])))
            dense_layers.append(("encoder_dense_relu_{}".format(layer_index), nn.ReLU()))

        # final layer maps to the latent dimension
        dense_layers.append(("encoder_dense_{}".format(len(self.dense_layer_sizes)), nn.Linear(self.dense_layer_sizes[-1], self.latent_dim)))
        dense_layers.append(("encoder_dense_relu_{}".format(len(self.dense_layer_sizes)), nn.ReLU()))

        self.dense_layers = nn.Sequential(OrderedDict(dense_layers))

    def forward(self, x):
        yhat = self.dense_layers(x)
        return yhat
The constructor of the Encoder model class takes three arguments: the pose dimension, the latent dimension, and a list of unit counts for the ANN layers. The list excludes the final layer, whose unit count equals the latent dimension, since that layer is added automatically. The Encoder model class can be instantiated as follows:
latent_dim = 8
ae_dense_layer_sizes = [ 64, 16 ]
encoder = Encoder(pose_dim, latent_dim, ae_dense_layer_sizes).to(device)
The shapes of the input and output tensors for this model are as follows:
- input tensor: batch_size x pose_dim
- output tensor: batch_size x latent_dim
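These shapes can be verified quickly by passing a random batch through the model; the batch size of 32 used here is only an example value.
# quick shape check (batch size 32 is only an example)
test_input = torch.randn((32, pose_dim), dtype=torch.float32).to(device)
test_latent = encoder(test_input)
print(test_latent.shape)   # expected: torch.Size([32, 8]) for latent_dim = 8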
Create Decoder Model
The Decoder model mirrors the task and network structure of the Encoder model. To decompress a one-dimensional latent vector, the Decoder passes the latent vector through several ANN layers. These layers successively increase the dimension of the latent vector until it reaches the dimension of a pose. Each ANN layer is followed by a ReLU activation function.
The class definition of the Decoder model is as follows:
class Decoder(nn.Module):

    def __init__(self, pose_dim, latent_dim, dense_layer_sizes):
        super(Decoder, self).__init__()

        self.pose_dim = pose_dim
        self.latent_dim = latent_dim
        self.dense_layer_sizes = dense_layer_sizes

        # create dense layers
        dense_layers = []

        dense_layers.append(("decoder_dense_0", nn.Linear(latent_dim, self.dense_layer_sizes[0])))
        dense_layers.append(("decoder_relu_0", nn.ReLU()))

        dense_layer_count = len(self.dense_layer_sizes)
        for layer_index in range(1, dense_layer_count):
            dense_layers.append(("decoder_dense_{}".format(layer_index), nn.Linear(self.dense_layer_sizes[layer_index - 1], self.dense_layer_sizes[layer_index])))
            dense_layers.append(("decoder_dense_relu_{}".format(layer_index), nn.ReLU()))

        # final layer maps back to the pose dimension and is not followed by an activation
        dense_layers.append(("decoder_dense_{}".format(len(self.dense_layer_sizes)), nn.Linear(self.dense_layer_sizes[-1], self.pose_dim)))

        self.dense_layers = nn.Sequential(OrderedDict(dense_layers))

    def forward(self, x):
        yhat = self.dense_layers(x)
        return yhat
The constructor of the Decoder model class takes three arguments: the pose dimension, the latent dimension, and a list of unit counts for the ANN layers. The list does not include the latent dimension itself, since the first layer, which maps from the latent dimension, is added automatically. The unit counts for the ANN layers are obtained by reversing the corresponding list that was used to create the Encoder model. The Decoder model class can be instantiated as follows:
ae_dense_layer_sizes_reversed = ae_dense_layer_sizes.copy()
ae_dense_layer_sizes_reversed.reverse()
decoder = Decoder(pose_dim, latent_dim, ae_dense_layer_sizes_reversed).to(device)
The shapes of the input and output tensors for this model are as follows:
- input tensor: batch_size x latent_dim
- output tensor: batch_size x pose_dim
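Since the Decoder mirrors the Encoder, a full encode/decode round trip should return a tensor with the original pose dimension. A minimal check, again with an example batch size of 32:
# round-trip shape check (batch size 32 is only an example)
test_input = torch.randn((32, pose_dim), dtype=torch.float32).to(device)
test_latent = encoder(test_input)     # batch_size x latent_dim
test_output = decoder(test_latent)    # batch_size x pose_dim
print(test_latent.shape, test_output.shape)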
Optimisers and Loss Functions
To update the weights of the three models during training, two separate Adam optimisers are used. One optimiser updates the combined weights of the Encoder and Decoder models. The other optimiser updates only the weights of the Discriminator model. In this example, the optimiser for the Discriminator uses a learning rate that is five times higher than that for the Encoder/Decoder. These optimisers are instantiated as follows:
dp_learning_rate = 5e-4
ae_learning_rate = 1e-4
disc_optimizer = torch.optim.Adam(disc_prior.parameters(), lr=dp_learning_rate)
ae_optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=ae_learning_rate)
Two different loss functions are used, one for the Discriminator model and one for the combined Encoder and Decoder models.
The loss function for the Discriminator is based on binary cross-entropy to quantify the classification error. This loss function is identical to the one used for the Critique model in the articles on generative adversarial networks. It is defined as follows:
cross_entropy = nn.BCELoss()
def disc_prior_loss(disc_real_output, disc_fake_output):
    ones = torch.ones_like(disc_real_output).to(device)
    zeros = torch.zeros_like(disc_fake_output).to(device)

    real_loss = cross_entropy(disc_real_output, ones)
    fake_loss = cross_entropy(disc_fake_output, zeros)

    total_loss = (real_loss + fake_loss) * 0.5
    return total_loss
The loss function for the Encoder/Decoder is more involved and employs a weighted combination of individual losses. These individual losses are:
- a loss based on the deviation of joint rotations from unit quaternions,
- a loss based on the difference between the positions of an original pose and those of a reconstructed pose,
- a loss based on the difference between the rotations of an original pose and those of a reconstructed pose,
- a loss based on the mistakes the Discriminator makes when classifying latent vectors generated by the Encoder.
The loss function based on the deviation of joint rotations from unit quaternions is defined as follows:
def ae_norm_loss(yhat):
    _yhat = yhat.view(-1, 4)
    _norm = torch.norm(_yhat, dim=1)
    _diff = (_norm - 1.0) ** 2
    _loss = torch.mean(_diff)
    return _loss
The loss function based on the difference between the positions of an original pose and those of a reconstructed pose is based on code published by Pavllo et al. in 2018. What is unusual about this loss function is that it is used for a model that operates on joint rotations rather than joint positions. For this reason, joint positions have to be derived from joint rotations using forward kinematics. In order for backpropagation to be able to calculate gradients, forward kinematics has to be conducted using tensor math only. The reason why the model developed by Pavllo et al., and also the model in this article, combines joint rotations with a position-based loss function is to benefit from both joint representations while avoiding their individual drawbacks. Representing a motion capture pose by joint positions in world coordinates has the benefit that all errors in these coordinates are equivalent, regardless of whether the joint is close to or far from the root joint. This is not the case when representing motion capture poses as joint rotations. Here, a rotation error in a joint close to the root joint has a much bigger effect on the overall pose than a rotation error in a joint that is far from the root joint (such as a hand joint). Joint positions, on the other hand, have the drawback that errors in them lead to changes in the lengths of the edges connecting the joints, a problem that joint rotations do not cause. Therefore, combining a rotation-based joint representation with an error measured on positions in world coordinates is more effective at detecting and correcting wrong poses, while at the same time avoiding changes in edge length.
The code for the loss function used here employs the forward kinematics function of the Skeleton class to derive joint positions from joint rotations. The Skeleton class is part of the "common" module and is briefly introduced here. In the dataset used here, the root position of each motion capture pose is always set to zero. For this reason, the trajectory that is passed as a tensor to the forward kinematics function contains only zeros.
def ae_pos_loss(y, yhat):
    # y and yhat shapes: batch_size x pose_dim

    # normalize the reconstructed joint rotations to unit quaternions
    _yhat = yhat.view(-1, 4)
    _yhat_norm = nn.functional.normalize(_yhat, p=2, dim=1)

    _y_rot = y.view((y.shape[0], 1, -1, 4))
    _yhat_rot = _yhat_norm.view((y.shape[0], 1, -1, 4))

    zero_trajectory = torch.zeros((y.shape[0], 1, 3), dtype=torch.float32, requires_grad=True).to(device)

    _y_pos = skeleton.forward_kinematics(_y_rot, zero_trajectory)
    _yhat_pos = skeleton.forward_kinematics(_yhat_rot, zero_trajectory)

    _pos_diff = torch.norm((_y_pos - _yhat_pos), dim=3)

    _loss = torch.mean(_pos_diff)
    return _loss
The loss function based on the difference between the rotations of an original pose and those of a reconstructed pose is defined as follows:
def ae_quat_loss(y, yhat):
    # y and yhat shapes: batch_size x pose_dim

    # normalize quaternions
    _y = y.view((-1, 4))
    _yhat = yhat.view((-1, 4))
    _yhat_norm = nn.functional.normalize(_yhat, p=2, dim=1)

    # inverse of quaternion: https://www.mathworks.com/help/aeroblks/quaternioninverse.html
    _yhat_inv = _yhat_norm * torch.tensor([[1.0, -1.0, -1.0, -1.0]], dtype=torch.float32).to(device)

    # calculate difference quaternion
    _diff = qmul(_yhat_inv, _y)
    # length of the imaginary part
    _len = torch.norm(_diff[:, 1:], dim=1)
    # angle of the difference rotation via atan2
    _atan = torch.atan2(_len, _diff[:, 0])
    # absolute value
    _abs = torch.abs(_atan)
    _loss = torch.mean(_abs)
    return _loss
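This loss is a measure of the angular difference between the original and the reconstructed joint rotations, so identical rotations should yield a value close to zero. A small sanity check along these lines (the random pose below is not taken from the dataset):
# sanity check: identical rotations should give a quaternion loss close to zero
test_pose = torch.randn((1, pose_dim), dtype=torch.float32).to(device)
test_pose = nn.functional.normalize(test_pose.view(-1, 4), p=2, dim=1).view(1, -1)
print(ae_quat_loss(test_pose, test_pose))   # expected: a value close to 0.0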
All of these loss functions are called by a loss function named "ae_loss", which calculates a single loss value as a weighted sum of the four individual loss values. The function is defined as follows:
ae_norm_loss_scale = 0.1
ae_pos_loss_scale = 0.1
ae_quat_loss_scale = 1.0
ae_prior_loss_scale = 0.01

def ae_loss(y, yhat, disc_fake_output):
    # function parameters
    # y: encoder input
    # yhat: decoder output (i.e. reconstructed encoder input)
    # disc_fake_output: discriminator output for the encoder-generated latent vectors

    _norm_loss = ae_norm_loss(yhat)
    _pos_loss = ae_pos_loss(y, yhat)
    _quat_loss = ae_quat_loss(y, yhat)

    # adversarial prior loss: the encoder tries to make the discriminator
    # classify its latent vectors as samples from the true prior (label 1)
    _fake_loss = cross_entropy(disc_fake_output, torch.ones_like(disc_fake_output))

    _total_loss = 0.0
    _total_loss += _norm_loss * ae_norm_loss_scale
    _total_loss += _pos_loss * ae_pos_loss_scale
    _total_loss += _quat_loss * ae_quat_loss_scale
    _total_loss += _fake_loss * ae_prior_loss_scale

    return _total_loss, _norm_loss, _pos_loss, _quat_loss, _fake_loss
Finally, a simple function is defined for sampling from a normal distribution. This function provides the Discriminator with genuine samples from the prior. The sampling function is defined as follows:
def sample_normal(shape):
    return torch.tensor(np.random.normal(size=shape), dtype=torch.float32).to(device)
Training and Testing Functions
A total of three functions are defined for conducting training and testing steps: a training step function for the Discriminator model, and a training step and a testing step function for the combined Encoder and Decoder models. These functions are largely identical to those used in the article on image generation with an adversarial autoencoder. They are defined as follows:
def disc_prior_train_step(target_poses):
    # have the normal distribution and the encoder produce real and fake outputs, respectively
    with torch.no_grad():
        encoder_output = encoder(target_poses)
        real_output = sample_normal(encoder_output.shape)

    # let the discriminator distinguish between real and fake outputs
    disc_real_output = disc_prior(real_output)
    disc_fake_output = disc_prior(encoder_output)
    _disc_loss = disc_prior_loss(disc_real_output, disc_fake_output)

    # backpropagation
    disc_optimizer.zero_grad()
    _disc_loss.backward()
    disc_optimizer.step()

    return _disc_loss

def ae_train_step(target_poses):
    # let the autoencoder reproduce the target poses (decoder output) and also return the encoder output
    encoder_output = encoder(target_poses)
    pred_poses = decoder(encoder_output)

    # let the discriminator assess the encoder output; gradients must flow through the
    # discriminator into the encoder here, so no torch.no_grad() is used; the discriminator
    # weights themselves are not updated in this step since only ae_optimizer steps
    disc_fake_output = disc_prior(encoder_output)

    _ae_loss, _ae_norm_loss, _ae_pos_loss, _ae_quat_loss, _ae_prior_loss = ae_loss(target_poses, pred_poses, disc_fake_output)

    # backpropagation
    ae_optimizer.zero_grad()
    _ae_loss.backward()
    ae_optimizer.step()

    return _ae_loss, _ae_norm_loss, _ae_pos_loss, _ae_quat_loss, _ae_prior_loss

def ae_test_step(target_poses):
    with torch.no_grad():
        # let the autoencoder reproduce the target poses (decoder output) and also return the encoder output
        encoder_output = encoder(target_poses)
        pred_poses = decoder(encoder_output)

        # let the discriminator assess the encoder output
        disc_fake_output = disc_prior(encoder_output)

        _ae_loss, _ae_norm_loss, _ae_pos_loss, _ae_quat_loss, _ae_prior_loss = ae_loss(target_poses, pred_poses, disc_fake_output)

    return _ae_loss, _ae_norm_loss, _ae_pos_loss, _ae_quat_loss, _ae_prior_loss
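Before running the full training loop, the step functions can be smoke-tested on a single batch. The lines below are not part of the original tutorial and perform one actual update step per model; they are only meant as a quick check that all shapes and devices match.
# quick smoke test on a single batch (performs one real update step per model)
example_batch = next(iter(train_dataloader)).to(device)
print("disc loss:", disc_prior_train_step(example_batch).item())
print("ae losses:", [loss.item() for loss in ae_train_step(example_batch)])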
The function named “train” performs the actual training of the three models by calling the train and test step functions repeatedly. This function is also almost identical to the one used for image generation. The function is defined as follows:
def train(train_dataloader, test_dataloader, epochs):

    loss_history = {}
    loss_history["ae train"] = []
    loss_history["ae test"] = []
    loss_history["ae norm"] = []
    loss_history["ae pos"] = []
    loss_history["ae quat"] = []
    loss_history["ae prior"] = []
    loss_history["disc prior"] = []

    for epoch in range(epochs):
        start = time.time()

        ae_train_loss_per_epoch = []
        ae_norm_loss_per_epoch = []
        ae_pos_loss_per_epoch = []
        ae_quat_loss_per_epoch = []
        ae_prior_loss_per_epoch = []
        disc_prior_loss_per_epoch = []

        for train_batch in train_dataloader:
            train_batch = train_batch.to(device)

            # start with discriminator training
            _disc_prior_train_loss = disc_prior_train_step(train_batch)
            _disc_prior_train_loss = _disc_prior_train_loss.detach().cpu().numpy()
            disc_prior_loss_per_epoch.append(_disc_prior_train_loss)

            # now train the autoencoder
            _ae_loss, _ae_norm_loss, _ae_pos_loss, _ae_quat_loss, _ae_prior_loss = ae_train_step(train_batch)

            _ae_loss = _ae_loss.detach().cpu().numpy()
            _ae_norm_loss = _ae_norm_loss.detach().cpu().numpy()
            _ae_pos_loss = _ae_pos_loss.detach().cpu().numpy()
            _ae_quat_loss = _ae_quat_loss.detach().cpu().numpy()
            _ae_prior_loss = _ae_prior_loss.detach().cpu().numpy()

            ae_train_loss_per_epoch.append(_ae_loss)
            ae_norm_loss_per_epoch.append(_ae_norm_loss)
            ae_pos_loss_per_epoch.append(_ae_pos_loss)
            ae_quat_loss_per_epoch.append(_ae_quat_loss)
            ae_prior_loss_per_epoch.append(_ae_prior_loss)

        ae_train_loss_per_epoch = np.mean(np.array(ae_train_loss_per_epoch))
        ae_norm_loss_per_epoch = np.mean(np.array(ae_norm_loss_per_epoch))
        ae_pos_loss_per_epoch = np.mean(np.array(ae_pos_loss_per_epoch))
        ae_quat_loss_per_epoch = np.mean(np.array(ae_quat_loss_per_epoch))
        ae_prior_loss_per_epoch = np.mean(np.array(ae_prior_loss_per_epoch))
        disc_prior_loss_per_epoch = np.mean(np.array(disc_prior_loss_per_epoch))

        ae_test_loss_per_epoch = []

        for test_batch in test_dataloader:
            test_batch = test_batch.to(device)

            _ae_loss, _, _, _, _ = ae_test_step(test_batch)

            _ae_loss = _ae_loss.detach().cpu().numpy()
            ae_test_loss_per_epoch.append(_ae_loss)

        ae_test_loss_per_epoch = np.mean(np.array(ae_test_loss_per_epoch))

        if epoch % model_save_interval == 0 and save_weights == True:
            torch.save(disc_prior.state_dict(), "disc_prior_weights epoch_{}".format(epoch))
            torch.save(encoder.state_dict(), "ae_encoder_weights epoch_{}".format(epoch))
            torch.save(decoder.state_dict(), "ae_decoder_weights epoch_{}".format(epoch))

        """
        if epoch % vis_save_interval == 0 and save_vis == True:
            create_epoch_visualisations(epoch)
        """

        loss_history["ae train"].append(ae_train_loss_per_epoch)
        loss_history["ae test"].append(ae_test_loss_per_epoch)
        loss_history["ae norm"].append(ae_norm_loss_per_epoch)
        loss_history["ae pos"].append(ae_pos_loss_per_epoch)
        loss_history["ae quat"].append(ae_quat_loss_per_epoch)
        loss_history["ae prior"].append(ae_prior_loss_per_epoch)
        loss_history["disc prior"].append(disc_prior_loss_per_epoch)

        print('epoch {} : ae train: {:01.4f} ae test: {:01.4f} disc prior {:01.4f} norm {:01.4f} pos {:01.4f} quat {:01.4f} prior {:01.4f} time {:01.2f}'.format(epoch + 1, ae_train_loss_per_epoch, ae_test_loss_per_epoch, disc_prior_loss_per_epoch, ae_norm_loss_per_epoch, ae_pos_loss_per_epoch, ae_quat_loss_per_epoch, ae_prior_loss_per_epoch, time.time() - start))

    return loss_history
The “train” function can be called as follows:
epochs = 100
loss_history = train(train_dataloader, test_dataloader, epochs)
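The loss_history dictionary returned by the train function can be used to inspect training progress, for example by plotting the individual loss curves. The plotting code below is a minimal sketch and not part of the original tutorial; the output path is hypothetical.
# minimal sketch for plotting the recorded loss curves (assumes matplotlib is installed)
import matplotlib.pyplot as plt

for loss_name, loss_values in loss_history.items():
    plt.plot(loss_values, label=loss_name)

plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("results/loss_history.png")   # hypothetical output path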
Generate and Visualise Poses
Once the autoencoder has been trained, it can be used to experiment with pose reconstruction. To visualise the reconstructed (or original) poses, an instance of the PoseRenderer class is used. This class forms part of the “common” module which is explained here. The PoseRenderer class is instantiated as follows:
skel_edge_list = utils.get_skeleton_edge_list(skeleton)
poseRenderer = PoseRenderer(skel_edge_list)
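The visualisation functions below also refer to several view settings (view_ele, view_azi, view_line_width and view_size) that are defined elsewhere in the tutorial code. If they are not available, values along the following lines can be used; the numbers and their interpretation are illustrative assumptions only.
# illustrative view settings (assumed values, adjust as needed)
view_ele = 0.0          # camera elevation for the rendered view
view_azi = 0.0          # camera azimuth for the rendered view
view_line_width = 1.0   # line width used for the skeleton edges
view_size = 4.0         # width and height passed to the pose renderer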
Several convenience functions are provided for reconstructing and visualising poses.
A function entitled “create_ref_pose_image” generates an image of an original pose. This function takes as arguments an index in the pose sequence of the original motion capture recording and the file name under which the visualised pose is saved. The function is defined as follows:
def create_ref_pose_image(pose_index, file_name):
    pose = poses[pose_index]
    pose = torch.tensor(np.reshape(pose, (1, 1, joint_count, joint_dim))).to(device)

    zero_trajectory = torch.tensor(np.zeros((1, 1, 3), dtype=np.float32)).to(device)

    skel_pose = skeleton.forward_kinematics(pose, zero_trajectory)
    skel_pose = skel_pose.detach().cpu().numpy()
    skel_pose = np.reshape(skel_pose, (joint_count, 3))

    view_min, view_max = utils.get_equal_mix_max_positions(skel_pose)
    pose_image = poseRenderer.create_pose_image(skel_pose, view_min, view_max, view_ele, view_azi, view_line_width, view_size, view_size)
    pose_image.save(file_name, optimize=False)
A function entitled “create_rec_pose_image” generates an image of a reconstructed pose. This function also takes as arguments an index in the pose sequence of the original motion capture recording and the file name under which the visualised pose is saved. The function is defined as follows:
def create_rec_pose_image(pose_index, file_name):
    encoder.eval()
    decoder.eval()

    pose = poses[pose_index]
    pose = torch.tensor(np.expand_dims(pose, axis=0)).to(device)

    with torch.no_grad():
        pose_enc = encoder(pose)
        rec_pose = decoder(pose_enc)

    rec_pose = torch.squeeze(rec_pose)
    rec_pose = rec_pose.view((-1, 4))
    rec_pose = nn.functional.normalize(rec_pose, p=2, dim=1)
    rec_pose = rec_pose.view((1, 1, joint_count, joint_dim))

    zero_trajectory = torch.tensor(np.zeros((1, 1, 3), dtype=np.float32))
    zero_trajectory = zero_trajectory.to(device)

    skel_pose = skeleton.forward_kinematics(rec_pose, zero_trajectory)
    skel_pose = skel_pose.detach().cpu().numpy()
    skel_pose = np.squeeze(skel_pose)

    view_min, view_max = utils.get_equal_mix_max_positions(skel_pose)
    pose_image = poseRenderer.create_pose_image(skel_pose, view_min, view_max, view_ele, view_azi, view_line_width, view_size, view_size)
    pose_image.save(file_name, optimize=False)

    encoder.train()
    decoder.train()
Another convenience function named “encode_poses” can be used to obtain a list of encodings of poses. This function takes as argument a list of pose indices. The function is defined as follows:
def encode_poses(pose_indices):
    encoder.eval()

    pose_encodings = []

    for pose_index in pose_indices:
        pose = poses[pose_index]
        pose = np.expand_dims(pose, axis=0)
        pose = torch.from_numpy(pose).to(device)

        with torch.no_grad():
            pose_enc = encoder(pose)

        pose_enc = torch.squeeze(pose_enc)
        pose_enc = pose_enc.detach().cpu().numpy()
        pose_encodings.append(pose_enc)

    encoder.train()

    return pose_encodings
The counterpart of the previously described function is a convenience function with the name “decode_pose_encodings”. This function decodes pose encodings into poses, visualises these poses and then exports the visualisations as animation. The function takes as arguments a list of pose encodings and a file name under which the animation is saved. The function is defined as follows:
def decode_pose_encodings(pose_encodings, file_name):
    decoder.eval()

    rec_poses = []

    for pose_encoding in pose_encodings:
        pose_encoding = np.expand_dims(pose_encoding, axis=0)
        pose_encoding = torch.from_numpy(pose_encoding).to(device)

        with torch.no_grad():
            rec_pose = decoder(pose_encoding)

        rec_pose = torch.squeeze(rec_pose)
        rec_pose = rec_pose.view((-1, 4))
        rec_pose = nn.functional.normalize(rec_pose, p=2, dim=1)
        rec_pose = rec_pose.view((1, joint_count, joint_dim))

        rec_poses.append(rec_pose)

    rec_poses = torch.cat(rec_poses, dim=0)
    rec_poses = torch.unsqueeze(rec_poses, dim=0)

    zero_trajectory = torch.tensor(np.zeros((1, len(pose_encodings), 3), dtype=np.float32))
    zero_trajectory = zero_trajectory.to(device)

    skel_poses = skeleton.forward_kinematics(rec_poses, zero_trajectory)
    skel_poses = skel_poses.detach().cpu().numpy()
    skel_poses = np.squeeze(skel_poses)

    view_min, view_max = utils.get_equal_mix_max_positions(skel_poses)
    pose_images = poseRenderer.create_pose_images(skel_poses, view_min, view_max, view_ele, view_azi, view_line_width, view_size, view_size)
    pose_images[0].save(file_name, save_all=True, append_images=pose_images[1:], optimize=False, duration=33.0, loop=0)

    decoder.train()
In the following, some examples of using the convenience functions to experiment with poses and pose encodings are presented.
Create an Image for a Single Original Pose
A single original pose can be obtained and saved as image as follows:
pose_index = 100
create_ref_pose_image(pose_index, "results/images/orig_pose_{}.gif".format(pose_index))
Create an Image for a Single Reconstructed Pose
A single pose can be reconstructed and saved as image as follows:
pose_index = 100
create_rec_pose_image(pose_index, "results/images/rec_pose_{}.gif".format(pose_index))
Create an Animation from Several Reconstructed Poses
A list of poses can be reconstructed and saved as animation as follows:
start_pose_index = 100
end_pose_index = 500
pose_indices = [ pose_index for pose_index in range(start_pose_index, end_pose_index)]
pose_encodings = encode_poses(pose_indices)
decode_pose_encodings(pose_encodings, "results/images/rec_pose_sequence_{}-{}.gif".format(start_pose_index, end_pose_index))
Create an Animation from a Random Walk in Latent Space
In a more interesting example, a single pose is encoded and the encoding is used as the starting point for a random walk in latent space. The random walk generates a list of increasingly randomised pose encodings, which are then decoded into poses and saved as an animation.
start_pose_index = 100
pose_count = 500

pose_indices = [start_pose_index]
pose_encodings = encode_poses(pose_indices)

for index in range(0, pose_count - 1):
    random_step = np.random.random((latent_dim)).astype(np.float32) * 2.0
    pose_encodings.append(pose_encodings[index] + random_step)

decode_pose_encodings(pose_encodings, "results/images/rec_poses_randwalk_{}_{}.gif".format(start_pose_index, pose_count))
Create an Animation by Following a Trajectory in Latent Space with an Offset
In this example, a list of poses that follow each other in the original motion capture recording is encoded into a list of latent vectors. These latent vectors represent a trajectory in latent space. This trajectory is followed at an offset by adding a vector to each pose encoding. The resulting encodings are then decoded into poses and saved as animation.
start_pose_index = 100
end_pose_index = 500

pose_indices = [ pose_index for pose_index in range(start_pose_index, end_pose_index)]
pose_encodings = encode_poses(pose_indices)

offset_pose_encodings = []

for index in range(len(pose_encodings)):
    sin_value = np.sin(index / (len(pose_encodings) - 1) * np.pi * 4.0)
    offset = np.ones(shape=(latent_dim), dtype=np.float32) * sin_value * 4.0
    offset_pose_encoding = pose_encodings[index] + offset
    offset_pose_encodings.append(offset_pose_encoding)

decode_pose_encodings(offset_pose_encodings, "results/images/rec_pose_sequence_offset_{}-{}.gif".format(start_pose_index, end_pose_index))
Create an Animation by Interpolating Between Pose Encodings
In this example, two sequences of poses are encoded and new encodings are created by gradually interpolating between the corresponding pose encodings. Each interpolated encoding is decoded into a pose to obtain a sequence of poses. This sequence of poses is saved as an animation.
start_pose1_index = 100
end_pose1_index = 500
start_pose2_index = 1100
end_pose2_index = 1500

pose1_indices = [ pose_index for pose_index in range(start_pose1_index, end_pose1_index)]
pose2_indices = [ pose_index for pose_index in range(start_pose2_index, end_pose2_index)]

pose1_encodings = encode_poses(pose1_indices)
pose2_encodings = encode_poses(pose2_indices)

mixed_pose_encodings = []

for index in range(len(pose1_indices)):
    mix_factor = index / (len(pose1_indices) - 1)
    mixed_pose_encoding = pose1_encodings[index] * (1.0 - mix_factor) + pose2_encodings[index] * mix_factor
    mixed_pose_encodings.append(mixed_pose_encoding)

decode_pose_encodings(mixed_pose_encodings, "results/images/rec_pose_sequence_mix_{}-{}_{}-{}.gif".format(start_pose1_index, end_pose1_index, start_pose2_index, end_pose2_index))