Model

The core classes for building and training GP-TEMPEST.

TEMPEST

gptempest.TEMPEST

Bases: Module

Gaussian Process Temporal Embedding for Protein Simulations and Transitions.

GP-TEMPEST is a GP-VAE that combines a variational autoencoder with a sparse Gaussian Process prior in the latent space. The GP prior enforces temporal smoothness and allows the model to recover kinetically relevant degrees of freedom from molecular dynamics trajectories.

The architecture consists of:

  • Encoder — inference network q(z|x) mapping frames to a Gaussian distribution in latent space.
  • Decoder — generative network p(x|z) reconstructing input features.
  • Sparse GP — Gaussian Process with inducing points acting as a temporal prior on the latent trajectories.

Parameters:

  • cuda (bool, required): If True, move the model to GPU.
  • kernel (MaternKernel, required): A MaternKernel instance defining the GP covariance.
  • dim_input (int, required): Dimensionality of the input features.
  • dim_latent (int, required): Dimensionality of the latent space (typically 2).
  • layers_hidden_encoder (list[int], required): Hidden layer sizes for the encoder, e.g. [32, 32, 32].
  • layers_hidden_decoder (list[int], required): Hidden layer sizes for the decoder, e.g. [32, 32, 32]. Usually the reverse of the encoder.
  • inducing_points (ndarray, required): 1-D array of inducing point timestamps. Should cover metastable states and transitions in the trajectory.
  • beta (float, required): Weight of the GP regularisation term in the ELBO.
  • N_data (int, required): Total number of frames in the dataset. Used to scale the mini-batch ELBO correctly.
  • dtype (dtype, required): PyTorch dtype, e.g. torch.float64.
Example
kernel = MaternKernel(nu=1.5, scale=1e3, dtype=torch.float64)
model = TEMPEST(
    cuda=False,
    kernel=kernel,
    dim_input=10,
    dim_latent=2,
    layers_hidden_encoder=[32, 32],
    layers_hidden_decoder=[32, 32],
    inducing_points=np.linspace(0, 999, 50),
    beta=50.0,
    N_data=1000,
    dtype=torch.float64,
)
model.train_model(dataset, train_size=1, learning_rate=1e-4,
                  weight_decay=1e-6, batch_size=512, n_epochs=100)
embedding = model.extract_latent_space(dataset, batch_size=512)
Source code in src/gptempest/model.py
class TEMPEST(nn.Module):
    """Gaussian Process Temporal Embedding for Protein Simulations and
    Transitions.

    GP-TEMPEST is a GP-VAE that combines a variational autoencoder with a
    sparse Gaussian Process prior in the latent space. The GP prior enforces
    temporal smoothness and allows the model to recover kinetically relevant
    degrees of freedom from molecular dynamics trajectories.

    The architecture consists of:

    - **Encoder** — inference network q(z|x) mapping frames to a Gaussian
      distribution in latent space.
    - **Decoder** — generative network p(x|z) reconstructing input features.
    - **Sparse GP** — Gaussian Process with inducing points acting as a
      temporal prior on the latent trajectories.

    Args:
        cuda: If ``True``, move the model to GPU.
        kernel: A :class:`MaternKernel` instance defining the GP covariance.
        dim_input: Dimensionality of the input features.
        dim_latent: Dimensionality of the latent space (typically 2).
        layers_hidden_encoder: Hidden layer sizes for the encoder,
            e.g. ``[32, 32, 32]``.
        layers_hidden_decoder: Hidden layer sizes for the decoder,
            e.g. ``[32, 32, 32]``. Usually the reverse of the encoder.
        inducing_points: 1-D array of inducing point timestamps. Should cover
            metastable states and transitions in the trajectory.
        beta: Weight of the GP regularisation term in the ELBO.
        N_data: Total number of frames in the dataset. Used to scale the
            mini-batch ELBO correctly.
        dtype: PyTorch dtype, e.g. ``torch.float64``.

    Example:
        ```python
        kernel = MaternKernel(nu=1.5, scale=1e3, dtype=torch.float64)
        model = TEMPEST(
            cuda=False,
            kernel=kernel,
            dim_input=10,
            dim_latent=2,
            layers_hidden_encoder=[32, 32],
            layers_hidden_decoder=[32, 32],
            inducing_points=np.linspace(0, 999, 50),
            beta=50.0,
            N_data=1000,
            dtype=torch.float64,
        )
        model.train_model(dataset, train_size=1, learning_rate=1e-4,
                          weight_decay=1e-6, batch_size=512, n_epochs=100)
        embedding = model.extract_latent_space(dataset, batch_size=512)
        ```
    """
    def __init__(
        self,
        cuda: bool,
        kernel: MaternKernel,
        dim_input: int,
        dim_latent: int,
        layers_hidden_encoder: list[int],
        layers_hidden_decoder: list[int],
        inducing_points: np.ndarray,
        beta: float,
        N_data: int,
        dtype: torch.dtype,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = torch.device('cuda' if cuda else 'cpu')
        self.kernel = kernel.to(self.device)
        self.inducing_points = torch.tensor(
            inducing_points, dtype=self.dtype
        ).to(self.device).unsqueeze(1)
        self.dim_latent = dim_latent
        self.layers_encoder = [dim_input, *layers_hidden_encoder]
        self.layers_decoder = [dim_latent, *layers_hidden_decoder, dim_input]
        self.encoder = InferenceNN(
            self.layers_encoder,
            layers_gaussian=[layers_hidden_encoder[-1], dim_latent],
        ).to(self.device).to(self.dtype)
        self.decoder = FeedForwardNN(
            self.layers_decoder, linear=False).to(self.device).to(self.dtype)
        self.decoder.add_layer('sigmoid', nn.Sigmoid())
        self.beta = beta
        self.N_data = N_data

    def compute_kernel_matrices(self, t):
        """Precompute all GP kernel matrices needed for the current batch.

        Stores the following attributes on ``self``:

        - ``kernel_mm``: GP kernel between inducing points, shape ``(M, M)``.
        - ``kernel_mm_inv``: Inverse of the stabilised ``kernel_mm``.
        - ``kernel_nn``: Diagonal kernel values for the batch, shape ``(N,)``.
        - ``kernel_nm``: Cross-kernel between batch and inducing points,
          shape ``(N, M)``.
        - ``kernel_mn``: Transpose of ``kernel_nm``, shape ``(M, N)``.

        Args:
            t: Batch timestamps of shape ``(N, 1)``.
        """
        self.kernel_mm = self.kernel.kernel_mat(
            self.inducing_points,
            self.inducing_points,
        )
        self.kernel_mm_inv = torch.linalg.inv(
            _num_stabilize_diag(self.kernel_mm),
        )
        self.kernel_nn = self.kernel.kernel_diag(t, t)
        self.kernel_nm = self.kernel.kernel_mat(t, self.inducing_points)
        self.kernel_mn = self.kernel_nm.transpose(0, 1)

    def _compute_diagonal_kernel(self, precision):
        """Compute the diagonal elements of the kernel matrix."""
        diagonal = torch.diagonal(torch.matmul(
            self.kernel_nm,
            torch.matmul(
                self.kernel_mm_inv,
                self.kernel_mn,
            ),
        ))
        return precision * (self.kernel_nn - diagonal)

    def _compute_Lambda(self):
        """Compute the Lambda matrix."""
        return torch.matmul(
            self.kernel_mm_inv,
            torch.matmul(
                torch.matmul(
                    self.kernel_nm.unsqueeze(2),
                    torch.transpose(self.kernel_nm.unsqueeze(2), 1, 2),
                ),
                self.kernel_mm_inv,
            ),
        )

    def compute_gp_params(self, t, qzx_mu, qzx_var):
        """Compute GP posterior parameters for one latent dimension.

        Updates the following attributes on ``self``:

        - ``mu_l``: GP posterior mean at inducing points, shape ``(M,)``.
        - ``A_l``: GP posterior covariance at inducing points, shape ``(M, M)``.
        - ``gp_mean_vector``: GP predictive mean at batch points, shape ``(N,)``.
        - ``gp_mean_sigma``: GP predictive variance at batch points, shape ``(N,)``.

        Follows Eq. (9) in Jazbec et al. (2021).

        Args:
            t: Batch timestamps of shape ``(N, 1)``.
            qzx_mu: Encoder posterior mean for one latent dim, shape ``(N,)``.
            qzx_var: Encoder posterior variance for one latent dim, shape ``(N,)``.
        """
        constant = self.N_data / t.shape[0]
        Sigma_l = self.kernel_mm + constant * torch.matmul(
            self.kernel_mn,
            self.kernel_nm / qzx_var.unsqueeze(1),
        )  # see Eq.(9) in Jazbec21
        Sigma_l_inv = torch.linalg.inv(_num_stabilize_diag(Sigma_l))
        self.mu_l = constant * torch.matmul(
            self.kernel_mm,
            torch.matmul(
                Sigma_l_inv,
                torch.matmul(
                    self.kernel_mn,
                    qzx_mu / qzx_var,
                )
            )
        )
        self.A_l = torch.matmul(
            self.kernel_mm,
            torch.matmul(
                Sigma_l_inv,
                self.kernel_mm,
            ),
        )
        self.gp_mean_vector = constant * torch.matmul(
            self.kernel_nm,
            torch.matmul(
                Sigma_l_inv,
                torch.matmul(
                    self.kernel_mn,
                    qzx_mu / qzx_var,
                )
            )
        )
        self.gp_mean_sigma = self.kernel_nn + torch.diagonal(
            -torch.matmul(
                self.kernel_nm,
                torch.matmul(
                    self.kernel_mm_inv,
                    self.kernel_mn,
                )
            ) + torch.matmul(
                self.kernel_nm,
                torch.matmul(
                    Sigma_l_inv,
                    self.kernel_mn,
                )
            )
        )

    def variational_loss(self, qzx_mu, qzx_var):
        """
        Compute the Hensman loss term for the current batch.
        Compare e.g. Eq. (7) and (10) in Jazbec21.
        More details in Jazbec21 SI B, Proposition B.1.
        """
        # log_det_kmm = _cholesky_log_determinant(self.kernel_mm)
        # log_det_A = _cholesky_log_determinant(self.A_l)
        log_det_kmm = _robust_log_determinant(self.kernel_mm)
        log_det_A = _robust_log_determinant(self.A_l)
        KL_div = 0.5 * (-self.inducing_points.shape[0] + torch.trace(
            torch.matmul(
                self.kernel_mm_inv,
                self.A_l,
            )
        ) + torch.sum(self.mu_l * torch.matmul(
            self.kernel_mm_inv,
            self.mu_l,
        )) + log_det_kmm - log_det_A)
        # compute L3 sum term
        mean_vec = torch.matmul(
            self.kernel_nm,
            torch.matmul(
                self.kernel_mm_inv,
                self.mu_l,
            )
        )  # first term in Jazbec21 SI, B.1 first eq.
        precision = 1 / qzx_var
        k_iitilde = self._compute_diagonal_kernel(precision)
        Lambda = self._compute_Lambda()
        tr_ALambda = precision * torch.einsum(
            'bii->b',
            torch.matmul(
                self.A_l,
                Lambda,
            ),
        )
        loss_L3 = -0.5 * (
            torch.sum(k_iitilde) +
            torch.sum(tr_ALambda) +
            torch.sum(torch.log(qzx_var)) +
            qzx_mu.shape[0] * 1.8379 +  # 1.8379 = log(2 * pi)
            torch.sum(precision * (qzx_mu - mean_vec)**2)
        )  # this is the L3 loss from Hensman
        return loss_L3, KL_div

    def gp_step(self, x, t):
        """Compute the full ELBO for a mini-batch.

        Runs the encoder, computes GP kernel matrices and posterior parameters
        for each latent dimension, evaluates the variational loss, decodes a
        latent sample, and accumulates the reconstruction loss. Results are
        stored as attributes:

        - ``self.elbo``: Total loss (reconstruction + beta * GP KL).
        - ``self.recon_loss``: MSE reconstruction term (scaled by 1e6).
        - ``self.gp_KL``: GP KL-divergence term.

        Args:
            x: Input feature batch of shape ``(N, dim_input)``.
            t: Timestamp batch of shape ``(N, 1)``.
        """
        qzx = self.encoder(x)
        qzx_mu = qzx['means']
        qzx_var = qzx['variances']
        self.compute_kernel_matrices(t)
        gp_mean, gp_var = [], []
        loss_L3, loss_KL = [], []
        for latent_dim in range(self.dim_latent):
            self.compute_gp_params(
                t,
                qzx_mu[:, latent_dim],
                qzx_var[:, latent_dim],
            )
            gp_mean.append(self.gp_mean_vector)
            gp_var.append(self.gp_mean_sigma)
            l_L3, l_KL = self.variational_loss(
                qzx_mu[:, latent_dim],
                qzx_var[:, latent_dim],
            )
            loss_L3.append(l_L3)
            loss_KL.append(l_KL)
        loss_L3 = torch.sum(torch.stack(loss_L3, dim=-1))
        loss_KL = torch.sum(torch.stack(loss_KL, dim=-1))
        elbo_gp = loss_L3 - (x.shape[0] / self.N_data) * loss_KL
        gp_mean = torch.stack(gp_mean, dim=1)
        gp_var = torch.stack(gp_var, dim=1)
        gp_cross_entropy = torch.sum(
            _gauss_cross_entropy(gp_mean, gp_var, qzx_mu, qzx_var)
        )
        self.gp_KL = gp_cross_entropy - elbo_gp
        latent_dist = Normal(qzx_mu, qzx_var)
        latent_samples = latent_dist.rsample()
        qxz = self.decoder(latent_samples)
        loss_L2 = nn.MSELoss(reduction='mean')
        self.recon_loss = loss_L2(qxz, x) * 1e6
        self.elbo = self.recon_loss + self.beta * self.gp_KL

    def train_model(
        self,
        dataset,
        train_size,
        learning_rate,
        weight_decay,
        batch_size,
        n_epochs,
    ):
        """Train the TEMPEST model.

        Args:
            dataset: A ``TensorDataset`` of ``(features, times)`` as returned
                by :func:`~gptempest.utils.load_prepare_data`.
            train_size: Fraction of the dataset to use for training. Pass
                ``1`` to use the full dataset.
            learning_rate: AdamW learning rate; typically between ``1e-3``
                and ``1e-5``.
            weight_decay: AdamW weight decay (L2 regularisation).
            batch_size: Number of frames per mini-batch. Larger batches give
                better GP estimator accuracy; values above ``512`` are
                recommended.
            n_epochs: Number of full passes over the training data.
        """
        train_dataset = dataset if train_size == 1 else random_split(
            dataset=dataset,
            lengths=[
                int(len(dataset) * train_size),
                len(dataset) - int(len(dataset) * train_size)
            ]
        )[0]
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
        )
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=10,
        )
        epoch_pbar = trange(n_epochs, desc='Training')
        for nr_epoch in epoch_pbar:
            train_pbar = tqdm(
                train_loader,
                desc=f'Epoch {nr_epoch + 1} (train)',
                leave=False,
            )
            l_train_elbo, l_train_recon, l_train_gp = self.train_epoch(
                train_pbar, optimizer, is_training=True
            )
            scheduler.step(l_train_elbo)
            epoch_pbar.set_postfix({
                'Train ELBO': f'{l_train_elbo:.5f}',
                'LR': f'{optimizer.param_groups[0]["lr"]:.2e}'
            })

    def train_epoch(
        self,
        pbar: tqdm,
        optimizer: torch.optim.Optimizer,
        is_training: bool = True,
    ) -> tuple[float, float, float]:
        """Train the model for one epoch.

        Args:
            pbar: Progress-bar-wrapped DataLoader for the current epoch.
            optimizer: AdamW optimizer for model parameters.
            is_training: If ``False``, run in eval mode without gradient
                updates. Defaults to ``True``.

        Returns:
            Tuple of average per-frame losses ``(ELBO, reconstruction, GP KL)``
            for the epoch.
        """
        nr_frames, loss_elbo, loss_recon, loss_gp = 0, 0, 0, 0
        if is_training:
            self.train()
        else:
            self.eval()

        for x_batch, t_batch in pbar:
            x_batch = x_batch.clone().detach().to(self.device)
            t_batch = t_batch.clone().detach().to(self.device)
            if is_training:
                optimizer.zero_grad()
            with torch.set_grad_enabled(is_training):
                self.gp_step(x_batch, t_batch)
                loss_elbo += self.elbo.item()
                loss_recon += self.recon_loss.item()
                loss_gp += self.gp_KL.item()
                if is_training:
                    self.elbo.backward()
                    optimizer.step()

            nr_frames += t_batch.shape[0]
            pbar.set_postfix({
                'ELBO': f'{loss_elbo/nr_frames:.5f}',
                'Recon': f'{loss_recon/nr_frames:.5f}',
                'GP': f'{loss_gp/nr_frames:.5f}'
            })
        return loss_elbo / nr_frames, loss_recon / nr_frames, \
            loss_gp / nr_frames

    def extract_latent_space(
        self,
        dataset: torch.utils.data.TensorDataset,
        batch_size: int,
    ) -> torch.Tensor:
        """Extract the GP-smoothed latent embedding for all frames.

        Runs the encoder and GP posterior in evaluation mode over the full
        dataset and returns the latent coordinates for each frame.

        Args:
            dataset: A ``TensorDataset`` of ``(features, times)`` as returned
                by :func:`~gptempest.utils.load_prepare_data`.
            batch_size: Number of frames per batch.

        Returns:
            Tensor of shape ``(N, dim_latent)`` containing the latent
            coordinates for all N frames.
        """
        self.eval()
        latent_samples = []

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
        for x_batch, t_batch in dataloader:
            x_batch = x_batch.clone().detach().to(self.device)
            t_batch = t_batch.clone().detach().to(self.device)

            qzx = self.encoder(x_batch)
            self.compute_kernel_matrices(t_batch)
            qzx_mu = qzx['means']
            qzx_var = qzx['variances']
            gp_mean, gp_var = [], []
            for latent_dim in range(self.dim_latent):
                self.compute_gp_params(
                    t_batch,
                    qzx_mu[:, latent_dim],
                    qzx_var[:, latent_dim],
                )
                gp_mean.append(self.gp_mean_vector)
                gp_var.append(self.gp_mean_sigma)
            gp_mean = torch.stack(gp_mean, dim=1)
            gp_var = torch.stack(gp_var, dim=1)
            latent_samples_batch = _reparameterize(gp_mean, gp_var)
            latent_samples.append(latent_samples_batch.detach().cpu())
        return torch.cat(latent_samples, dim=0)

compute_kernel_matrices(t)

Precompute all GP kernel matrices needed for the current batch.

Stores the following attributes on self:

  • kernel_mm: GP kernel between inducing points, shape (M, M).
  • kernel_mm_inv: Inverse of the stabilised kernel_mm.
  • kernel_nn: Diagonal kernel values for the batch, shape (N,).
  • kernel_nm: Cross-kernel between batch and inducing points, shape (N, M).
  • kernel_mn: Transpose of kernel_nm, shape (M, N).

Parameters:

  • t (required): Batch timestamps of shape (N, 1).
Source code in src/gptempest/model.py
def compute_kernel_matrices(self, t):
    """Precompute all GP kernel matrices needed for the current batch.

    Stores the following attributes on ``self``:

    - ``kernel_mm``: GP kernel between inducing points, shape ``(M, M)``.
    - ``kernel_mm_inv``: Inverse of the stabilised ``kernel_mm``.
    - ``kernel_nn``: Diagonal kernel values for the batch, shape ``(N,)``.
    - ``kernel_nm``: Cross-kernel between batch and inducing points,
      shape ``(N, M)``.
    - ``kernel_mn``: Transpose of ``kernel_nm``, shape ``(M, N)``.

    Args:
        t: Batch timestamps of shape ``(N, 1)``.
    """
    self.kernel_mm = self.kernel.kernel_mat(
        self.inducing_points,
        self.inducing_points,
    )
    self.kernel_mm_inv = torch.linalg.inv(
        _num_stabilize_diag(self.kernel_mm),
    )
    self.kernel_nn = self.kernel.kernel_diag(t, t)
    self.kernel_nm = self.kernel.kernel_mat(t, self.inducing_points)
    self.kernel_mn = self.kernel_nm.transpose(0, 1)
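
The shapes of these matrices follow the usual sparse-GP layout. A minimal sketch with a toy squared-exponential kernel standing in for MaternKernel (the function, sizes, and length scale below are illustrative only):

```python
import torch

# Stand-in RBF kernel for illustration; the real model uses MaternKernel.
def rbf_mat(a, b, scale=100.0):
    # Pairwise squared distances between (*, 1) timestamp columns.
    d2 = (a - b.transpose(0, 1)) ** 2
    return torch.exp(-0.5 * d2 / scale**2)

M, N = 5, 8
inducing = torch.linspace(0, 999, M, dtype=torch.float64).unsqueeze(1)  # (M, 1)
t = torch.rand(N, 1, dtype=torch.float64) * 999                         # (N, 1)

kernel_mm = rbf_mat(inducing, inducing)          # (M, M) inducing-inducing
kernel_nm = rbf_mat(t, inducing)                 # (N, M) batch-inducing
kernel_mn = kernel_nm.transpose(0, 1)            # (M, N) transpose
kernel_nn = torch.ones(N, dtype=torch.float64)   # (N,) diag of k(t, t); 1 for RBF

print(kernel_mm.shape, kernel_nm.shape, kernel_mn.shape, kernel_nn.shape)
```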

compute_gp_params(t, qzx_mu, qzx_var)

Compute GP posterior parameters for one latent dimension.

Updates the following attributes on self:

  • mu_l: GP posterior mean at inducing points, shape (M,).
  • A_l: GP posterior covariance at inducing points, shape (M, M).
  • gp_mean_vector: GP predictive mean at batch points, shape (N,).
  • gp_mean_sigma: GP predictive variance at batch points, shape (N,).

Follows Eq. (9) in Jazbec et al. (2021).

Parameters:

  • t (required): Batch timestamps of shape (N, 1).
  • qzx_mu (required): Encoder posterior mean for one latent dim, shape (N,).
  • qzx_var (required): Encoder posterior variance for one latent dim, shape (N,).
Source code in src/gptempest/model.py
def compute_gp_params(self, t, qzx_mu, qzx_var):
    """Compute GP posterior parameters for one latent dimension.

    Updates the following attributes on ``self``:

    - ``mu_l``: GP posterior mean at inducing points, shape ``(M,)``.
    - ``A_l``: GP posterior covariance at inducing points, shape ``(M, M)``.
    - ``gp_mean_vector``: GP predictive mean at batch points, shape ``(N,)``.
    - ``gp_mean_sigma``: GP predictive variance at batch points, shape ``(N,)``.

    Follows Eq. (9) in Jazbec et al. (2021).

    Args:
        t: Batch timestamps of shape ``(N, 1)``.
        qzx_mu: Encoder posterior mean for one latent dim, shape ``(N,)``.
        qzx_var: Encoder posterior variance for one latent dim, shape ``(N,)``.
    """
    constant = self.N_data / t.shape[0]
    Sigma_l = self.kernel_mm + constant * torch.matmul(
        self.kernel_mn,
        self.kernel_nm / qzx_var.unsqueeze(1),
    )  # see Eq.(9) in Jazbec21
    Sigma_l_inv = torch.linalg.inv(_num_stabilize_diag(Sigma_l))
    self.mu_l = constant * torch.matmul(
        self.kernel_mm,
        torch.matmul(
            Sigma_l_inv,
            torch.matmul(
                self.kernel_mn,
                qzx_mu / qzx_var,
            )
        )
    )
    self.A_l = torch.matmul(
        self.kernel_mm,
        torch.matmul(
            Sigma_l_inv,
            self.kernel_mm,
        ),
    )
    self.gp_mean_vector = constant * torch.matmul(
        self.kernel_nm,
        torch.matmul(
            Sigma_l_inv,
            torch.matmul(
                self.kernel_mn,
                qzx_mu / qzx_var,
            )
        )
    )
    self.gp_mean_sigma = self.kernel_nn + torch.diagonal(
        -torch.matmul(
            self.kernel_nm,
            torch.matmul(
                self.kernel_mm_inv,
                self.kernel_mn,
            )
        ) + torch.matmul(
            self.kernel_nm,
            torch.matmul(
                Sigma_l_inv,
                self.kernel_mn,
            )
        )
    )
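
The algebra behind these updates can be sketched on toy matrices. The names mirror the attributes above, but the kernel values are random stand-ins rather than MaternKernel output:

```python
import torch

torch.manual_seed(0)
M, N, N_data = 4, 6, 1000
# Toy symmetric positive-definite K_mm standing in for the Matern kernel.
L = torch.randn(M, M, dtype=torch.float64)
kernel_mm = L @ L.T + 1e-3 * torch.eye(M, dtype=torch.float64)  # (M, M)
kernel_nm = torch.rand(N, M, dtype=torch.float64)               # (N, M)
kernel_mn = kernel_nm.T                                         # (M, N)
qzx_mu = torch.randn(N, dtype=torch.float64)
qzx_var = torch.rand(N, dtype=torch.float64) + 0.1

# Mini-batch scaling and Sigma_l as in compute_gp_params.
constant = N_data / N
Sigma_l = kernel_mm + constant * kernel_mn @ (kernel_nm / qzx_var.unsqueeze(1))
Sigma_l_inv = torch.linalg.inv(Sigma_l)
mu_l = constant * kernel_mm @ (Sigma_l_inv @ (kernel_mn @ (qzx_mu / qzx_var)))
A_l = kernel_mm @ Sigma_l_inv @ kernel_mm

print(mu_l.shape, A_l.shape)  # (M,) and (M, M)
print(torch.allclose(A_l, A_l.T, atol=1e-6))  # posterior covariance is symmetric
```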

variational_loss(qzx_mu, qzx_var)

Compute the Hensman loss term for the current batch. Compare e.g. Eq. (7) and (10) in Jazbec21; more details in Jazbec21 SI B, Proposition B.1.

Source code in src/gptempest/model.py
def variational_loss(self, qzx_mu, qzx_var):
    """
    Compute the Hensman loss term for the current batch.
    Compare e.g. Eq. (7) and (10) in Jazbec21.
    More details in Jazbec21 SI B, Proposition B.1.
    """
    # log_det_kmm = _cholesky_log_determinant(self.kernel_mm)
    # log_det_A = _cholesky_log_determinant(self.A_l)
    log_det_kmm = _robust_log_determinant(self.kernel_mm)
    log_det_A = _robust_log_determinant(self.A_l)
    KL_div = 0.5 * (-self.inducing_points.shape[0] + torch.trace(
        torch.matmul(
            self.kernel_mm_inv,
            self.A_l,
        )
    ) + torch.sum(self.mu_l * torch.matmul(
        self.kernel_mm_inv,
        self.mu_l,
    )) + log_det_kmm - log_det_A)
    # compute L3 sum term
    mean_vec = torch.matmul(
        self.kernel_nm,
        torch.matmul(
            self.kernel_mm_inv,
            self.mu_l,
        )
    )  # first term in Jazbec21 SI, B.1 first eq.
    precision = 1 / qzx_var
    k_iitilde = self._compute_diagonal_kernel(precision)
    Lambda = self._compute_Lambda()
    tr_ALambda = precision * torch.einsum(
        'bii->b',
        torch.matmul(
            self.A_l,
            Lambda,
        ),
    )
    loss_L3 = -0.5 * (
        torch.sum(k_iitilde) +
        torch.sum(tr_ALambda) +
        torch.sum(torch.log(qzx_var)) +
        qzx_mu.shape[0] * 1.8379 +  # 1.8379 = log(2 * pi)
        torch.sum(precision * (qzx_mu - mean_vec)**2)
    )  # this is the L3 loss from Hensman
    return loss_L3, KL_div
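
The KL_div term above is the closed-form KL divergence between the variational distribution q(u) = N(mu_l, A_l) and the GP prior p(u) = N(0, K_mm). As a sanity check, the same formula can be compared against torch.distributions on toy matrices (all values below are illustrative):

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence

torch.manual_seed(1)
M = 4

# Toy symmetric positive-definite matrices standing in for K_mm and A_l.
def rand_spd(m):
    L = torch.randn(m, m, dtype=torch.float64)
    return L @ L.T + m * torch.eye(m, dtype=torch.float64)

kernel_mm = rand_spd(M)
A_l = rand_spd(M)
mu_l = torch.randn(M, dtype=torch.float64)
kernel_mm_inv = torch.linalg.inv(kernel_mm)

# Closed-form KL( N(mu_l, A_l) || N(0, K_mm) ), matching variational_loss.
KL = 0.5 * (
    -M
    + torch.trace(kernel_mm_inv @ A_l)
    + mu_l @ (kernel_mm_inv @ mu_l)
    + torch.logdet(kernel_mm)
    - torch.logdet(A_l)
)

# Cross-check against the library implementation in torch.distributions.
ref = kl_divergence(
    MultivariateNormal(mu_l, covariance_matrix=A_l),
    MultivariateNormal(torch.zeros(M, dtype=torch.float64),
                       covariance_matrix=kernel_mm),
)
print(torch.allclose(KL, ref, atol=1e-8))  # True
```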

gp_step(x, t)

Compute the full ELBO for a mini-batch.

Runs the encoder, computes GP kernel matrices and posterior parameters for each latent dimension, evaluates the variational loss, decodes a latent sample, and accumulates the reconstruction loss. Results are stored as attributes:

  • self.elbo: Total loss (reconstruction + beta * GP KL).
  • self.recon_loss: MSE reconstruction term (scaled by 1e6).
  • self.gp_KL: GP KL-divergence term.

Parameters:

  • x (required): Input feature batch of shape (N, dim_input).
  • t (required): Timestamp batch of shape (N, 1).
Source code in src/gptempest/model.py
def gp_step(self, x, t):
    """Compute the full ELBO for a mini-batch.

    Runs the encoder, computes GP kernel matrices and posterior parameters
    for each latent dimension, evaluates the variational loss, decodes a
    latent sample, and accumulates the reconstruction loss. Results are
    stored as attributes:

    - ``self.elbo``: Total loss (reconstruction + beta * GP KL).
    - ``self.recon_loss``: MSE reconstruction term (scaled by 1e6).
    - ``self.gp_KL``: GP KL-divergence term.

    Args:
        x: Input feature batch of shape ``(N, dim_input)``.
        t: Timestamp batch of shape ``(N, 1)``.
    """
    qzx = self.encoder(x)
    qzx_mu = qzx['means']
    qzx_var = qzx['variances']
    self.compute_kernel_matrices(t)
    gp_mean, gp_var = [], []
    loss_L3, loss_KL = [], []
    for latent_dim in range(self.dim_latent):
        self.compute_gp_params(
            t,
            qzx_mu[:, latent_dim],
            qzx_var[:, latent_dim],
        )
        gp_mean.append(self.gp_mean_vector)
        gp_var.append(self.gp_mean_sigma)
        l_L3, l_KL = self.variational_loss(
            qzx_mu[:, latent_dim],
            qzx_var[:, latent_dim],
        )
        loss_L3.append(l_L3)
        loss_KL.append(l_KL)
    loss_L3 = torch.sum(torch.stack(loss_L3, dim=-1))
    loss_KL = torch.sum(torch.stack(loss_KL, dim=-1))
    elbo_gp = loss_L3 - (x.shape[0] / self.N_data) * loss_KL
    gp_mean = torch.stack(gp_mean, dim=1)
    gp_var = torch.stack(gp_var, dim=1)
    gp_cross_entropy = torch.sum(
        _gauss_cross_entropy(gp_mean, gp_var, qzx_mu, qzx_var)
    )
    self.gp_KL = gp_cross_entropy - elbo_gp
    latent_dist = Normal(qzx_mu, qzx_var)
    latent_samples = latent_dist.rsample()
    qxz = self.decoder(latent_samples)
    loss_L2 = nn.MSELoss(reduction='mean')
    self.recon_loss = loss_L2(qxz, x) * 1e6
    self.elbo = self.recon_loss + self.beta * self.gp_KL
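
The gp_cross_entropy term relies on the closed-form cross-entropy between two Gaussians. A stand-in version of the helper (illustrative only; the argument order and internals of the package's _gauss_cross_entropy are assumptions here) can be checked against the identity H(q, p) = KL(q || p) + H(q):

```python
import torch
from torch.distributions import Normal, kl_divergence

# Hypothetical stand-in for _gauss_cross_entropy: H(q, p) = -E_q[log p(z)]
# for univariate Gaussians p = N(mu_p, var_p), q = N(mu_q, var_q).
def gauss_cross_entropy(mu_p, var_p, mu_q, var_q):
    return 0.5 * (
        torch.log(2 * torch.pi * var_p)
        + (var_q + (mu_q - mu_p) ** 2) / var_p
    )

mu_p, var_p = torch.tensor(0.3), torch.tensor(2.0)
mu_q, var_q = torch.tensor(-0.1), torch.tensor(0.5)

ce = gauss_cross_entropy(mu_p, var_p, mu_q, var_q)

# Identity check: cross-entropy = KL(q || p) + entropy of q.
q = Normal(mu_q, var_q.sqrt())  # torch Normal takes the std dev, not variance
p = Normal(mu_p, var_p.sqrt())
ref = kl_divergence(q, p) + q.entropy()
print(torch.allclose(ce, ref, atol=1e-6))  # True
```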

train_model(dataset, train_size, learning_rate, weight_decay, batch_size, n_epochs)

Train the TEMPEST model.

Parameters:

  • dataset (required): A TensorDataset of (features, times) as returned by gptempest.utils.load_prepare_data.
  • train_size (required): Fraction of the dataset to use for training. Pass 1 to use the full dataset.
  • learning_rate (required): AdamW learning rate; typically between 1e-3 and 1e-5.
  • weight_decay (required): AdamW weight decay (L2 regularisation).
  • batch_size (required): Number of frames per mini-batch. Larger batches give better GP estimator accuracy; values above 512 are recommended.
  • n_epochs (required): Number of full passes over the training data.
Source code in src/gptempest/model.py
def train_model(
    self,
    dataset,
    train_size,
    learning_rate,
    weight_decay,
    batch_size,
    n_epochs,
):
    """Train the TEMPEST model.

    Args:
        dataset: A ``TensorDataset`` of ``(features, times)`` as returned
            by :func:`~gptempest.utils.load_prepare_data`.
        train_size: Fraction of the dataset to use for training. Pass
            ``1`` to use the full dataset.
        learning_rate: AdamW learning rate; typically between ``1e-3``
            and ``1e-5``.
        weight_decay: AdamW weight decay (L2 regularisation).
        batch_size: Number of frames per mini-batch. Larger batches give
            better GP estimator accuracy; values above ``512`` are
            recommended.
        n_epochs: Number of full passes over the training data.
    """
    train_dataset = dataset if train_size == 1 else random_split(
        dataset=dataset,
        lengths=[
            int(len(dataset) * train_size),
            len(dataset) - int(len(dataset) * train_size)
        ]
    )[0]
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
    )
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, self.parameters()),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=10,
    )
    epoch_pbar = trange(n_epochs, desc='Training')
    for nr_epoch in epoch_pbar:
        train_pbar = tqdm(
            train_loader,
            desc=f'Epoch {nr_epoch + 1} (train)',
            leave=False,
        )
        l_train_elbo, l_train_recon, l_train_gp = self.train_epoch(
            train_pbar, optimizer, is_training=True
        )
        scheduler.step(l_train_elbo)
        epoch_pbar.set_postfix({
            'Train ELBO': f'{l_train_elbo:.5f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}'
        })
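The `random_split` call above derives the two subset lengths by integer truncation. A minimal stdlib sketch of that length computation (the helper name `split_lengths` is hypothetical, not part of the library):

```python
def split_lengths(n_total: int, train_size: float) -> list[int]:
    # Mirrors the lengths passed to random_split in train_model:
    # the training subset gets int(n_total * train_size) frames,
    # and whatever remains goes to the held-out subset.
    n_train = int(n_total * train_size)
    return [n_train, n_total - n_train]
```

Because the remainder is computed by subtraction, the two lengths always sum to the dataset size, so `random_split` never raises a length-mismatch error.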

train_epoch(pbar, optimizer, is_training=True)

Train the model for one epoch.

Parameters:

Name Type Description Default
pbar tqdm

Progress-bar-wrapped DataLoader for the current epoch.

required
optimizer Optimizer

AdamW optimizer for model parameters.

required
is_training bool

If False, run in eval mode without gradient updates. Defaults to True.

True

Returns:

Type Description
tuple[float, float, float]

Tuple of average per-frame losses (ELBO, reconstruction, GP KL) for the epoch.

Source code in src/gptempest/model.py
def train_epoch(
    self,
    pbar: tqdm,
    optimizer: torch.optim.Optimizer,
    is_training: bool = True,
) -> tuple[float, float, float]:
    """Train the model for one epoch.

    Args:
        pbar: Progress-bar-wrapped DataLoader for the current epoch.
        optimizer: AdamW optimizer for model parameters.
        is_training: If ``False``, run in eval mode without gradient
            updates. Defaults to ``True``.

    Returns:
        Tuple of average per-frame losses ``(ELBO, reconstruction, GP KL)``
        for the epoch.
    """
    nr_frames, loss_elbo, loss_recon, loss_gp = 0, 0, 0, 0
    if is_training:
        self.train()
    else:
        self.eval()

    for x_batch, t_batch in pbar:
        x_batch = x_batch.clone().detach().to(self.device)
        t_batch = t_batch.clone().detach().to(self.device)
        if is_training:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_training):
            self.gp_step(x_batch, t_batch)
            loss_elbo += self.elbo.item()
            loss_recon += self.recon_loss.item()
            loss_gp += self.gp_KL.item()
            if is_training:
                self.elbo.backward()
                optimizer.step()

        nr_frames += t_batch.shape[0]
        pbar.set_postfix({
            'ELBO': f'{loss_elbo/nr_frames:.5f}',
            'Recon': f'{loss_recon/nr_frames:.5f}',
            'GP': f'{loss_gp/nr_frames:.5f}'
        })
    return (
        loss_elbo / nr_frames,
        loss_recon / nr_frames,
        loss_gp / nr_frames,
    )
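Note that `train_epoch` accumulates per-batch loss sums and normalises once by the total frame count at the end, rather than averaging per batch. A stdlib sketch of that bookkeeping (the helper name `average_losses` is hypothetical, and it assumes each batch reports loss sums over its frames, as the accumulation above implies):

```python
def average_losses(batches):
    """batches: iterable of (elbo_sum, recon_sum, gp_sum, n_frames)
    tuples, one per mini-batch."""
    nr_frames = loss_elbo = loss_recon = loss_gp = 0
    for elbo, recon, gp, n in batches:
        loss_elbo += elbo
        loss_recon += recon
        loss_gp += gp
        nr_frames += n
    # A single division at the end gives exact per-frame averages even
    # when the last batch is smaller than the others.
    return loss_elbo / nr_frames, loss_recon / nr_frames, loss_gp / nr_frames
```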

extract_latent_space(dataset, batch_size)

Extract the GP-smoothed latent embedding for all frames.

Runs the encoder and GP posterior in evaluation mode over the full dataset and returns the latent coordinates for each frame.

Parameters:

Name Type Description Default
dataset TensorDataset

A TensorDataset of (features, times) as returned by :func:~gptempest.utils.load_prepare_data.

required
batch_size int

Number of frames per batch.

required

Returns:

Type Description
Tensor

Tensor of shape (N, dim_latent) containing the latent coordinates for all N frames.

Source code in src/gptempest/model.py
def extract_latent_space(
    self,
    dataset: torch.utils.data.TensorDataset,
    batch_size: int,
) -> torch.Tensor:
    """Extract the GP-smoothed latent embedding for all frames.

    Runs the encoder and GP posterior in evaluation mode over the full
    dataset and returns the latent coordinates for each frame.

    Args:
        dataset: A ``TensorDataset`` of ``(features, times)`` as returned
            by :func:`~gptempest.utils.load_prepare_data`.
        batch_size: Number of frames per batch.

    Returns:
        Tensor of shape ``(N, dim_latent)`` containing the latent
        coordinates for all N frames.
    """
    self.eval()
    latent_samples = []

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )
    for x_batch, t_batch in dataloader:
        x_batch = x_batch.clone().detach().to(self.device)
        t_batch = t_batch.clone().detach().to(self.device)

        qzx = self.encoder(x_batch)
        self.compute_kernel_matrices(t_batch)
        qzx_mu = qzx['means']
        qzx_var = qzx['variances']
        gp_mean, gp_var = [], []
        for latent_dim in range(self.dim_latent):
            self.compute_gp_params(
                t_batch,
                qzx_mu[:, latent_dim],
                qzx_var[:, latent_dim],
            )
            gp_mean.append(self.gp_mean_vector)
            gp_var.append(self.gp_mean_sigma)
        gp_mean = torch.stack(gp_mean, dim=1)
        gp_var = torch.stack(gp_var, dim=1)
        latent_samples_batch = _reparameterize(gp_mean, gp_var)
        latent_samples.append(latent_samples_batch.detach().cpu())
    return torch.cat(latent_samples, dim=0)
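`extract_latent_space` runs the GP once per latent dimension and then stacks the per-dimension results along `dim=1` so each frame becomes one row. A plain-Python sketch of that stacking step (the helper name `stack_dims` is hypothetical):

```python
def stack_dims(per_dim):
    # per_dim[d][i] holds the GP mean of latent dimension d for frame i.
    # The result has one row per frame, matching what
    # torch.stack(gp_mean, dim=1) produces for a batch.
    n_frames = len(per_dim[0])
    return [[per_dim[d][i] for d in range(len(per_dim))]
            for i in range(n_frames)]
```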

FeedForwardNN

gptempest.FeedForwardNN

Bases: Module

Feed-forward neural network with optional LeakyReLU activations.

Builds a fully-connected network with batch normalisation between layers. Used as the encoder and decoder backbone inside TEMPEST. The layer_sizes list defines the neuron count at each node, so len(layer_sizes) - 1 linear layers are created.

Parameters:

Name Type Description Default
layer_sizes list[int]

List of integers specifying the neuron count at each node, e.g. [128, 64, 32] creates two linear layers (128→64 and 64→32).

required
linear bool

If True, omit all activation functions (pure linear network). Defaults to False.

False
Example
net = FeedForwardNN([128, 64, 32])
out = net(torch.randn(16, 128))  # shape (16, 32)
Source code in src/gptempest/model.py
class FeedForwardNN(nn.Module):
    """Feed-forward neural network with optional LeakyReLU activations.

    Builds a fully-connected network with batch normalisation between layers.
    Used as the encoder and decoder backbone inside TEMPEST.
    The list defines neuron counts at each node, so ``len(layer_sizes) - 1``
    linear layers are created.

    Args:
        layer_sizes: List of integers specifying the neuron count at each
            node, e.g. ``[128, 64, 32]`` creates two linear layers
            (128→64 and 64→32).
        linear: If ``True``, omit all activation functions (pure linear
            network). Defaults to ``False``.

    Example:
        ```python
        net = FeedForwardNN([128, 64, 32])
        out = net(torch.randn(16, 128))  # shape (16, 32)
        ```
    """
    def __init__(self, layer_sizes: list[int], linear: bool = False):
        super().__init__()
        # Check if layer_sizes is a list of integers
        assert isinstance(layer_sizes, list), (
            'layer_sizes must be a list'
        )
        assert all(isinstance(size, int) for size in layer_sizes), (
            'All elements in layer_sizes must be integers'
        )
        layers = []
        for layer_nr in range(len(layer_sizes) - 1):
            layers.append(
                nn.Linear(layer_sizes[layer_nr], layer_sizes[layer_nr + 1]),
            )
            if layer_nr < len(layer_sizes) - 2:
                layers.append(nn.BatchNorm1d(layer_sizes[layer_nr + 1]))
                if not linear:
                    layers.append(nn.LeakyReLU())
        self.model = nn.Sequential(*layers)

    def add_layer(self, name, layer):
        """Append a layer (or a block of layers) to the network."""
        assert hasattr(self, 'model'), 'The model is not yet initialized.'
        self.model.add_module(name, layer)

    def forward(self, x):
        """Forward pass through the network"""
        for layer in self.model:
            x = layer(x)
        return x
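The constructor above only inserts BatchNorm1d and LeakyReLU between hidden layers, leaving the final Linear bare. A small sketch that enumerates the module sequence the constructor would build (the helper name `layer_plan` is hypothetical, for illustration only):

```python
def layer_plan(layer_sizes, linear=False):
    # Names of the modules FeedForwardNN.__init__ creates, in order.
    plan = []
    for i in range(len(layer_sizes) - 1):
        plan.append(f'Linear({layer_sizes[i]}->{layer_sizes[i + 1]})')
        if i < len(layer_sizes) - 2:  # no norm/activation after the last layer
            plan.append('BatchNorm1d')
            if not linear:
                plan.append('LeakyReLU')
    return plan
```

Leaving the output layer unnormalised and unactivated is what lets the same backbone serve as the mean and variance heads in `GaussianLayer`, where the range of the output must not be squashed.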

add_layer(name, layer)

Append a layer (or a block of layers) to the network.

Source code in src/gptempest/model.py
def add_layer(self, name, layer):
    """Append a layer (or a block of layers) to the network."""
    assert hasattr(self, 'model'), 'The model is not yet initialized.'
    self.model.add_module(name, layer)

forward(x)

Forward pass through the network.

Source code in src/gptempest/model.py
def forward(self, x):
    """Forward pass through the network"""
    for layer in self.model:
        x = layer(x)
    return x

GaussianLayer

gptempest.GaussianLayer

Bases: Module

Gaussian sampling layer.

Learns the mean and variance of a Gaussian distribution over the latent space and draws samples via the reparameterisation trick.

Parameters

layers_gaussian : list[int]
    Layer sizes for the mean and variance networks; the last entry is the latent dimension.

Returns

mu : array
    learned means of the samples in the latent space

var : array
    the corresponding variance of the Gaussian distributions

z : array
    sampled points from the learned distribution
Source code in src/gptempest/model.py
class GaussianLayer(nn.Module):
    """Class for Gaussian Sampling.

    Parameters
    ----------
        dim_input: int
            Integer defining the input dimensionality before the last layer

        dim_latent : int
            Integer defining the dimension of the latent space

    Returns
    -------
        mu : array
            learned means of the samples in the latent space

        var : array
            the corresponding log variance of the Gaussian distributions

        z : array
            sampled points from the learned distribution

    """

    def __init__(self, layers_gaussian):
        """Initialize Gaussian layer class."""
        super().__init__()
        self.mu = FeedForwardNN(layers_gaussian)
        self.var = FeedForwardNN(layers_gaussian)
        self.var.add_layer('softplus', nn.Softplus())

    def forward(self, x):
        """Learns latent space and samples from learned Gaussian."""
        mu = self.mu(x)
        variance = self.var(x)
        variance = torch.clamp(variance, min=1e-8)
        z = _reparameterize(mu, variance)
        return mu, variance, z

__init__(layers_gaussian)

Initialize Gaussian layer class.

Source code in src/gptempest/model.py
def __init__(self, layers_gaussian):
    """Initialize Gaussian layer class."""
    super().__init__()
    self.mu = FeedForwardNN(layers_gaussian)
    self.var = FeedForwardNN(layers_gaussian)
    self.var.add_layer('softplus', nn.Softplus())

forward(x)

Learns latent space and samples from learned Gaussian.

Source code in src/gptempest/model.py
def forward(self, x):
    """Learns latent space and samples from learned Gaussian."""
    mu = self.mu(x)
    variance = self.var(x)
    variance = torch.clamp(variance, min=1e-8)
    z = _reparameterize(mu, variance)
    return mu, variance, z
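The `_reparameterize` helper used above is not shown in this section. A stdlib sketch of the standard reparameterisation trick it presumably implements, z = mu + sqrt(var) * eps with eps ~ N(0, 1) (the function name and list-based interface here are illustrative, not the library's):

```python
import math
import random

def reparameterize(mu, variance):
    # z = mu + sigma * eps: sampling is rewritten as a deterministic
    # function of (mu, variance) plus external noise, which is what makes
    # the sample differentiable w.r.t. mu and variance under autograd.
    return [m + math.sqrt(v) * random.gauss(0.0, 1.0)
            for m, v in zip(mu, variance)]
```

With zero variance the noise term vanishes and the sample collapses to the mean, which is a convenient sanity check for any implementation.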