Pairwise Metrics for PyTorch
TorchPairwise
This package provides highly efficient pairwise metrics for PyTorch.
News
- v0.1.1: Added SNR distance (`torchpairwise.snr_distances`) presented in https://arxiv.org/abs/1904.02616.
Highlights
`torchpairwise` is a collection of general-purpose pairwise metric functions that behave similarly to `torch.cdist` (which only implements the $L_p$ distance). Unlike `torch.cdist`, we offer many more metrics ported from other packages such as `scipy.spatial.distance` and `sklearn.metrics.pairwise`.
For task-specific metrics (e.g. for evaluation of classification, regression, clustering, ...), you are in the wrong place; please head to the TorchMetrics repo.
Written in `torch`'s C++ API, our metrics differ in that they:
- are all differentiable (except some boolean distances), with backward formulas manually derived, implemented, and verified with `torch.autograd.gradcheck`.
- are batched and can exploit GPU parallelization.
- can be integrated seamlessly into PyTorch-based projects; all functions are `torch.jit.script`-able.
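As a rough illustration of the verification step mentioned above, the sketch below runs `torch.autograd.gradcheck` on a stand-in pairwise function written with plain `torch` ops. `pairwise_sqeuclidean` is a hypothetical helper defined only for this example and is not part of torchpairwise; presumably the same check applies to the differentiable ops of this package, but that is an assumption.

```python
import torch


def pairwise_sqeuclidean(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: squared Euclidean distance between all row pairs."""
    # (n, 1, d) - (1, m, d) -> (n, m, d), then reduce over the feature axis.
    return (x1.unsqueeze(1) - x2.unsqueeze(0)).pow(2).sum(dim=-1)


# gradcheck compares analytical gradients against finite differences,
# so it expects double-precision inputs that require grad.
torch.manual_seed(0)
x1 = torch.rand(4, 3, dtype=torch.float64, requires_grad=True)
x2 = torch.rand(5, 3, dtype=torch.float64, requires_grad=True)

ok = torch.autograd.gradcheck(pairwise_sqeuclidean, (x1, x2))
print(ok)
```

`gradcheck` raises an error by default when the two gradients disagree, which makes it convenient inside a test suite.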
List of pairwise distance metrics
| torchpairwise ops | Equivalences in other libraries | Differentiable |
|---|---|---|
| `euclidean_distances` | `sklearn.metrics.pairwise.euclidean_distances` | ✔️ |
| `haversine_distances` | `sklearn.metrics.pairwise.haversine_distances` | ✔️ |
| `manhattan_distances` | `sklearn.metrics.pairwise.manhattan_distances` | ✔️ |
| `cosine_distances` | `sklearn.metrics.pairwise.cosine_distances` | ✔️ |
| `l1_distances` | (Alias of `manhattan_distances`) | ✔️ |
| `l2_distances` | (Alias of `euclidean_distances`) | ✔️ |
| `lp_distances` | (Alias of `minkowski_distances`) | ✔️ |
| `linf_distances` | (Alias of `chebyshev_distances`) | ✔️ |
| `directed_hausdorff_distances` | `scipy.spatial.distance.directed_hausdorff` [^1] | ✔️ |
| `minkowski_distances` | `scipy.spatial.distance.minkowski` [^1] | ✔️ |
| `wminkowski_distances` | `scipy.spatial.distance.wminkowski` [^1] | ✔️ |
| `sqeuclidean_distances` | `scipy.spatial.distance.sqeuclidean` [^1] | ✔️ |
| `correlation_distances` | `scipy.spatial.distance.correlation` [^1] | ✔️ |
| `hamming_distances` | `scipy.spatial.distance.hamming` [^1] | ❌[^2] |
| `jaccard_distances` | `scipy.spatial.distance.jaccard` [^1] | ❌[^2] |
| `kulsinski_distances` | `scipy.spatial.distance.kulsinski` [^1] | ❌[^2] |
| `kulczynski1_distances` | `scipy.spatial.distance.kulczynski1` [^1] | ❌[^2] |
| `seuclidean_distances` | `scipy.spatial.distance.seuclidean` [^1] | ✔️ |
| `cityblock_distances` | `scipy.spatial.distance.cityblock` [^1] (Alias of `manhattan_distances`) | ✔️ |
| `mahalanobis_distances` | `scipy.spatial.distance.mahalanobis` [^1] | ✔️ |
| `chebyshev_distances` | `scipy.spatial.distance.chebyshev` [^1] | ✔️ |
| `braycurtis_distances` | `scipy.spatial.distance.braycurtis` [^1] | ✔️ |
| `canberra_distances` | `scipy.spatial.distance.canberra` [^1] | ✔️ |
| `jensenshannon_distances` | `scipy.spatial.distance.jensenshannon` [^1] | ✔️ |
| `yule_distances` | `scipy.spatial.distance.yule` [^1] | ❌[^2] |
| `dice_distances` | `scipy.spatial.distance.dice` [^1] | ❌[^2] |
| `rogerstanimoto_distances` | `scipy.spatial.distance.rogerstanimoto` [^1] | ❌[^2] |
| `russellrao_distances` | `scipy.spatial.distance.russellrao` [^1] | ❌[^2] |
| `sokalmichener_distances` | `scipy.spatial.distance.sokalmichener` [^1] | ❌[^2] |
| `sokalsneath_distances` | `scipy.spatial.distance.sokalsneath` [^1] | ❌[^2] |
| `snr_distances` | `pytorch_metric_learning.distances.SNRDistance` [^1] | ✔️ |
[^1]: These metrics are not pairwise, but a pairwise form can be computed by calling `scipy.spatial.distance.cdist(x1, x2, metric="[metric_name_or_callable]")`.
[^2]: These are boolean distances. `hamming_distances` can be applied to floating-point inputs but involves comparisons.
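To make footnote [^1] concrete, here is a minimal sketch of how a vector-to-vector metric can be turned into its pairwise form with broadcasting. Both `canberra` and `pairwise` are hypothetical helpers written in plain `torch` for this illustration; they are not the implementation used by this package.

```python
import torch


def canberra(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Canberra distance along the last axis: sum_k |u_k - v_k| / (|u_k| + |v_k|)."""
    num = (u - v).abs()
    den = u.abs() + v.abs()
    # Terms whose denominator is zero (both entries are 0) contribute 0.
    terms = torch.where(den > 0, num / den.clamp_min(1e-30), torch.zeros_like(num))
    return terms.sum(dim=-1)


def pairwise(metric, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Pairwise form: out[i, j] = metric(x1[i], x2[j])."""
    # Broadcasting (n, 1, d) against (1, m, d) evaluates every pair at once.
    return metric(x1.unsqueeze(1), x2.unsqueeze(0))


x1 = torch.tensor([[1.0, 2.0], [0.0, 4.0]])
x2 = torch.tensor([[3.0, 2.0], [0.0, 1.0], [1.0, 2.0]])
out = pairwise(canberra, x1, x2)
print(out.shape)  # torch.Size([2, 3])
```

This materializes an intermediate tensor of shape `(n, m, d)`, so it is only meant to show the definition, not to be memory-efficient.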
Other pairwise metrics or kernel functions
These metrics are usually used to compute kernels for machine learning algorithms.
| torchpairwise ops | Equivalences in other libraries | Differentiable |
|---|---|---|
| `linear_kernel` | `sklearn.metrics.pairwise.linear_kernel` | ✔️ |
| `polynomial_kernel` | `sklearn.metrics.pairwise.polynomial_kernel` | ✔️ |
| `sigmoid_kernel` | `sklearn.metrics.pairwise.sigmoid_kernel` | ✔️ |
| `rbf_kernel` | `sklearn.metrics.pairwise.rbf_kernel` | ✔️ |
| `laplacian_kernel` | `sklearn.metrics.pairwise.laplacian_kernel` | ✔️ |
| `cosine_similarity` | `sklearn.metrics.pairwise.cosine_similarity` | ✔️ |
| `additive_chi2_kernel` | `sklearn.metrics.pairwise.additive_chi2_kernel` | ✔️ |
| `chi2_kernel` | `sklearn.metrics.pairwise.chi2_kernel` | ✔️ |
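For readers unfamiliar with these kernels, the sketch below spells out the RBF kernel as defined in scikit-learn, `K[i, j] = exp(-gamma * ||x1[i] - x2[j]||^2)` with `gamma` defaulting to `1 / n_features`. `rbf_kernel_sketch` is a hypothetical helper written in plain `torch`; whether this package uses the same default for `gamma` is an assumption carried over from scikit-learn.

```python
import torch


def rbf_kernel_sketch(x1: torch.Tensor, x2: torch.Tensor, gamma=None) -> torch.Tensor:
    """K[i, j] = exp(-gamma * ||x1[i] - x2[j]||^2)."""
    if gamma is None:
        # scikit-learn's default; assumed here, not confirmed for torchpairwise.
        gamma = 1.0 / x1.shape[-1]
    # Squared Euclidean distance between every pair of rows: (n, m).
    sq_dists = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(dim=-1)
    return torch.exp(-gamma * sq_dists)


x1 = torch.rand(10, 5)
x2 = torch.rand(12, 5)
K = rbf_kernel_sketch(x1, x2)
print(K.shape)  # torch.Size([10, 12])
```

Unlike a distance, the kernel value is largest (exactly 1) for identical inputs and decays towards 0 as the inputs move apart.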
Custom `cdist` and `pdist`
Furthermore, we provide a convenient wrapper function analogous to `torch.cdist`, except that it takes a string `metric: str = "minkowski"` as the third argument to indicate the desired metric, and extra metric-specific arguments are passed as keywords.
```python
import torch, torchpairwise

# directed_hausdorff_distances is a pairwise 2d metric
x1 = torch.rand(10, 6, 3)
x2 = torch.rand(8, 5, 3)

generator = torch.Generator().manual_seed(1)
output = torchpairwise.cdist(x1, x2,
                             metric="directed_hausdorff",
                             shuffle=True,  # kwargs exclusive to directed_hausdorff
                             generator=generator)
```
Note that the pairwise metrics in the second table are currently not valid keys for `cdist` because they are not distances.
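To clarify what "pairwise 2d metric" means in the example above, the following sketch computes the directed Hausdorff distance between every pair of point sets with plain `torch`. `directed_hausdorff_sketch` is a hypothetical reference written for illustration: it follows the textbook definition (largest nearest-neighbor distance) and ignores the `shuffle` and `generator` arguments. The resulting `(10, 8)` shape is what this definition yields, not a statement about the package's internals.

```python
import torch


def directed_hausdorff_sketch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """out[i, j] = max over a in x1[i] of min over b in x2[j] of ||a - b||."""
    # x1: (n, p, d), x2: (m, q, d) -> point-to-point distances of shape (n, m, p, q).
    diff = x1[:, None, :, None, :] - x2[None, :, None, :, :]
    dists = diff.norm(dim=-1)
    # Nearest point of x2[j] for each point of x1[i], then the worst case over x1[i].
    return dists.min(dim=-1).values.max(dim=-1).values


x1 = torch.rand(10, 6, 3)  # 10 sets of 6 points in 3D
x2 = torch.rand(8, 5, 3)   # 8 sets of 5 points in 3D
out = directed_hausdorff_sketch(x1, x2)
print(out.shape)  # torch.Size([10, 8])
```

The measure is directed: swapping the two arguments generally gives a different value, not just a transposed matrix.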
We have a similar plan for `pdist` (which is equivalent to calling `cdist(x1, x1)` but avoids storing duplicated positions).
However, that requires a total overhaul of the existing C++/CUDA kernels and won't be available soon.
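In the meantime, the condensed layout that a `pdist` would return can be illustrated with plain `torch`. The sketch below uses `torch.cdist` as a stand-in for any pairwise distance (an assumption made so that it runs without this package) and keeps only the strict upper triangle, which matches the layout of `torch.pdist`.

```python
import torch

x = torch.rand(6, 4)
n = x.shape[0]

# Full (n, n) matrix: symmetric, with a zero diagonal.
full = torch.cdist(x, x)

# Row-major indices of the strict upper triangle (i < j).
rows, cols = torch.triu_indices(n, n, offset=1)
condensed = full[rows, cols]

print(condensed.shape)  # torch.Size([15]), i.e. n * (n - 1) / 2 entries
print(torch.allclose(condensed, torch.pdist(x), atol=1e-5))
```

Note that this still allocates the full matrix first; a dedicated kernel would avoid that, which is why it needs the overhaul mentioned above.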
Future Improvements
- Add more metrics (contact me or create a feature request issue).
- Add memory-efficient `argkmin` for retrieving pairwise neighbors' distances and indices without storing the whole pairwise distance matrix.
- Add an equivalent of `torch.pdist` with a `metric: str = "minkowski"` argument.
- (Unlikely) Support sparse layouts.
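Until a native `argkmin` exists, one possible workaround is to process the queries in chunks so that only a slice of the distance matrix lives in memory at any time. The sketch below is a minimal version of that idea; `argkmin_chunked` and its `chunk_size` parameter are hypothetical names, and `torch.cdist` stands in for whichever pairwise distance is needed.

```python
import torch


def argkmin_chunked(x1: torch.Tensor, x2: torch.Tensor, k: int, chunk_size: int = 1024):
    """Return the k smallest distances from each row of x1 to x2, and their indices."""
    dists, idxs = [], []
    for start in range(0, x1.shape[0], chunk_size):
        # Only a (chunk_size, m) block of the distance matrix exists at a time.
        block = torch.cdist(x1[start:start + chunk_size], x2)
        d, i = block.topk(k, dim=1, largest=False)
        dists.append(d)
        idxs.append(i)
    return torch.cat(dists), torch.cat(idxs)


torch.manual_seed(0)
x1 = torch.rand(10, 3)
x2 = torch.rand(7, 3)
dists, idxs = argkmin_chunked(x1, x2, k=3, chunk_size=4)
print(dists.shape, idxs.shape)  # torch.Size([10, 3]) torch.Size([10, 3])
```

Peak memory then scales with `chunk_size * m` instead of `n * m`, at the cost of a Python-level loop.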
Requirements
- `torch>=2.1.0` (`torch>=1.9.0` if compiled from source)
Installation
From PyPI:
To install the prebuilt wheels of torchpairwise, simply run:
```bash
pip install torchpairwise
```
Note that the Linux and Windows wheels on PyPI are compiled with `torch==2.1.0` and CUDA 12.1.
We only perform a non-strict version check, and a warning will be raised if the CUDA versions of `torch` and `torchpairwise` do not match.
From Source:
Make sure your machine has a C++17 compiler and a CUDA compiler installed, then clone the repo and run:
```bash
pip install .
```
Usage
The basic use case is straightforward if you are familiar with `sklearn.metrics.pairwise` and `scipy.spatial.distance`:
scikit-learn:
```python
import numpy as np
import sklearn.metrics.pairwise as sklearn_pairwise

x1 = np.random.rand(10, 5)
x2 = np.random.rand(12, 5)
output = sklearn_pairwise.cosine_similarity(x1, x2)
print(output)
```
TorchPairwise:
```python
import torch
import torchpairwise

x1 = torch.rand(10, 5, device='cuda')
x2 = torch.rand(12, 5, device='cuda')
output = torchpairwise.cosine_similarity(x1, x2)
print(output)
```
SciPy:
```python
import numpy as np
import scipy.spatial.distance as distance

x1 = np.random.binomial(
    1, p=0.6, size=(10, 5)).astype(np.bool_)
x2 = np.random.binomial(
    1, p=0.7, size=(12, 5)).astype(np.bool_)
output = distance.cdist(x1, x2, metric='jaccard')
print(output)
```
TorchPairwise:
```python
import torch
import torchpairwise

x1 = torch.bernoulli(
    torch.full((10, 5), fill_value=0.6, device='cuda')).to(torch.bool)
x2 = torch.bernoulli(
    torch.full((12, 5), fill_value=0.7, device='cuda')).to(torch.bool)
output = torchpairwise.jaccard_distances(x1, x2)
print(output)
```
Please check the tests folder where we will add more examples.
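As a small reference for the boolean example above, the sketch below writes out the Jaccard distance on boolean rows with plain `torch`. `jaccard_sketch` is a hypothetical helper, and returning 0 for two all-False rows is a convention chosen here (recent SciPy versions do the same); this package may handle that edge case differently. It also hints at why boolean distances are marked non-differentiable: they are built from logical operations and counts.

```python
import torch


def jaccard_sketch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """out[i, j] = |x1[i] XOR x2[j]| / |x1[i] OR x2[j]| for boolean rows."""
    a = x1.unsqueeze(1)  # (n, 1, d)
    b = x2.unsqueeze(0)  # (1, m, d)
    differ = (a ^ b).sum(dim=-1).to(torch.float32)
    union = (a | b).sum(dim=-1).to(torch.float32)
    # An empty union (two all-False rows) is reported as distance 0 in this sketch.
    return torch.where(union > 0, differ / union.clamp_min(1), torch.zeros_like(union))


x1 = torch.tensor([[True, False, True], [False, False, False]])
x2 = torch.tensor([[True, True, False], [False, False, False]])
out = jaccard_sketch(x1, x2)
print(out)
```

For the first pair, two of the three positions differ and all three positions are set in at least one row, giving a distance of 2/3.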
License
The code is released under the MIT license. See `LICENSE.txt` for details.