
PyTorch version of MUSE

MUSE-PyTorch

This repo is a PyTorch Lightning implementation of MUSE (multi-modality structured embedding for spatially resolved transcriptomics analysis), based on the official implementation described in:

Bao, F., Deng, Y., Wan, S. et al. Integrative spatial analysis of cell morphologies and transcriptional states with MUSE. Nat Biotechnol (2022). https://doi.org/10.1038/s41587-022-01251-z

Requirements

  • numpy==1.22.3
  • online_triplet_loss==0.0.6
  • pandas==1.4.2
  • PhenoGraph==1.5.7
  • pytorch_lightning==1.6.3
  • scipy==1.8.0
  • torch==1.11.0

Installation

To install the MUSE PyTorch package, run:

pip install muse_pytorch

Usage

import muse_pytorch as muse

The library exposes the same fit_predict method as the original implementation.

z, x_hat, y_hat, latent_x, latent_y = muse.fit_predict(trans_features,
                                                       morph_features,
                                                       trans_labels,
                                                       morph_labels,
                                                       init_epochs=3, 
                                                       refine_epochs=3, 
                                                       cluster_epochs=6, 
                                                       cluster_update_epoch=2, 
                                                       joint_latent_dim=50, 
                                                       batch_size=512)

The method accepts the same parameters as the original, plus several additional ones.

Parameters:

  trans_features:           input for the transcript modality; matrix of n * p, where n = number of cells, p = number of genes.
  morph_features:           input for the morphological modality; matrix of n * q, where n = number of cells, q = feature dimension.
  trans_labels:             initial reference cluster labels for the transcriptional modality.
  morph_labels:             initial reference cluster labels for the morphological modality.
  latent_dim:               size of the latent dimension for the single modalities
  joint_latent_dim:         size of the latent dimension of the joint representation
  lambda_reg:               factor for the regularisation term in the loss function
  lambda_sup:               factor for the self-supervised term in the loss function
  lr:                       learning rate for the optimizer
  init_epochs:              epochs for the initializing phase
  refine_epochs:            epochs for the refining phase
  cluster_epochs:           epochs for the clustering phase
  cluster_update_epoch:     interval, in epochs, after which the single-modality clusters are updated
  batch_size:               batch size for the dataloaders
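
To try the call above without real data, the inputs can be mocked up. The following is a minimal sketch, not part of the package: make_toy_inputs and all of its arguments are hypothetical names, the data is random, and the initial labels are simply the synthetic group assignments rather than the output of a real clustering step. It only illustrates the shapes described in the parameter list (both matrices share the same n rows, and each label vector has one entry per cell).

```python
import numpy as np


def make_toy_inputs(n_cells=200, n_genes=50, n_morph=30, n_groups=4, seed=0):
    """Hypothetical helper: synthetic inputs shaped like the parameters above."""
    rng = np.random.default_rng(seed)

    # One synthetic group per cell; both modalities are noisy views of it.
    groups = rng.integers(0, n_groups, size=n_cells)
    trans_centers = rng.normal(size=(n_groups, n_genes))
    morph_centers = rng.normal(size=(n_groups, n_morph))

    # n * p transcript matrix and n * q morphology matrix.
    trans_features = trans_centers[groups] + 0.1 * rng.normal(size=(n_cells, n_genes))
    morph_features = morph_centers[groups] + 0.1 * rng.normal(size=(n_cells, n_morph))

    # Stand-in initial labels (assumption: real use would take these
    # from a clustering of each modality).
    trans_labels = groups.copy()
    morph_labels = groups.copy()
    return trans_features, morph_features, trans_labels, morph_labels


trans_features, morph_features, trans_labels, morph_labels = make_toy_inputs()
print(trans_features.shape, morph_features.shape)
print(trans_labels.shape, morph_labels.shape)
```

The four returned arrays can then be passed to muse.fit_predict in the order shown in the usage example.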

Outputs:

  z:            joint latent representation learned by MUSE.
  x_hat:        reconstructed feature matrix corresponding to the input trans_features.
  y_hat:        reconstructed feature matrix corresponding to the input morph_features.
  h_x:          modality-specific latent representation of trans_features (latent_x in the call above).
  h_y:          modality-specific latent representation of morph_features (latent_y in the call above).
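
A quick sanity check on the returned values can catch mixed-up arguments early. The sketch below is hypothetical and not part of the package: check_muse_outputs is an invented helper, and the expected shapes (one row per cell, with widths taken from p, q, latent_dim and joint_latent_dim) are an assumption inferred from the parameter and output descriptions, not something confirmed against the library.

```python
import numpy as np


def check_muse_outputs(z, x_hat, y_hat, h_x, h_y, n, p, q, latent_dim, joint_latent_dim):
    """Hypothetical helper: compare output shapes with the assumed layout."""
    # Assumed layout: every output has one row per cell.
    expected = {
        "z": (n, joint_latent_dim),
        "x_hat": (n, p),
        "y_hat": (n, q),
        "h_x": (n, latent_dim),
        "h_y": (n, latent_dim),
    }
    actual = {"z": z, "x_hat": x_hat, "y_hat": y_hat, "h_x": h_x, "h_y": h_y}
    for name, shape in expected.items():
        got = tuple(actual[name].shape)
        if got != shape:
            raise ValueError(f"{name}: expected shape {shape}, got {got}")
    return True


# Placeholder arrays standing in for real outputs.
n, p, q, latent_dim, joint_latent_dim = 100, 40, 20, 16, 8
ok = check_muse_outputs(
    np.zeros((n, joint_latent_dim)),
    np.zeros((n, p)),
    np.zeros((n, q)),
    np.zeros((n, latent_dim)),
    np.zeros((n, latent_dim)),
    n, p, q, latent_dim, joint_latent_dim,
)
print(ok)
```

If the real outputs do not follow this layout, the expected dictionary is the only part that needs adjusting.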

Beyond this, training can be customized further by importing the PyTorch Lightning module and data module directly:

from muse_pytorch import MUSE, MUSEDataModule

For a complete description of the project, please refer to the original repo.
