TensorDot

This is a test script showcasing the KeOps tensordot syntax.

import numpy as np
import torch

from pykeops.torch import LazyTensor

M, N = 2, 10

Matrix multiplication as a special case of Tensordot

a = torch.randn(4 * 7, requires_grad=True, dtype=torch.float64)
b = torch.randn(7, requires_grad=True, dtype=torch.float64)
c = a.reshape(4, 7) @ b

A single matrix multiplication

In this case there is no need to use KeOps: this is just a sanity check.

A = LazyTensor(a[None, None, :])
B = LazyTensor(b[None, None, :])
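# Arguments of keops_tensordot, as used here: the other operand, the inner shape of A (4, 7),
# the inner shape of B (7,), then the contraction axes of A (1,) and of B (0,),
# i.e. a (4, 7) matrix multiplied by a vector of size 7.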
C = A.keops_tensordot(B, (4, 7), (7,), (1,), (0,)).sum_reduction(dim=1)

#  print(C, c)
print("Compare the two MatVecMul implementations. All good?", torch.allclose(c.flatten(), C.flatten()))

xi = torch.randn(4, dtype=torch.float64)
dC = torch.autograd.grad(C, a, xi.reshape(1, 4), retain_graph=True)[0].view(-1)
dc = torch.autograd.grad(c, a, xi, retain_graph=True)[0].view(-1)

#  print(dC, dc)
print("Compare the two MatVecMul gradient wrt a implementations. All good?", torch.allclose(dc.flatten(), dC.flatten()))

dC = torch.autograd.grad(C, b, xi.reshape(1, 4))[0].view(-1)
dc = torch.autograd.grad(c, b, xi)[0].view(-1)

#  print(dC, dc)
print("Compare the two MatVecMul gradient wrt b implementations. All good?", torch.allclose(dc.flatten(), dC.flatten()))

print('-------------------------------')

Out:

Compare the two MatVecMul implementations. All good? True
Compare the two MatVecMul gradient wrt a implementations. All good? True
Compare the two MatVecMul gradient wrt b implementations. All good? True
-------------------------------

Matrix multiplication with a sum reduction

That is where KeOps comes into play.

a = torch.randn(M, 4 * 7, requires_grad=True, dtype=torch.float64)
b = torch.randn(N, 7, requires_grad=True, dtype=torch.float64)
c = torch.tensordot(a.reshape(M, 4, 7), b.reshape(N, 7), dims=([2], [1])).sum(2)
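# The torch reference: contracting (M, 4, 7) with (N, 7) over the common axis of size 7 gives an
# (M, 4, N) tensor; summing over axis 2 (the N axis) yields the (M, 4) result that the KeOps
# sum_reduction below reproduces.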

A = LazyTensor(a[:, None, :])
B = LazyTensor(b[None, :, :])
C = A.keops_tensordot(B, (4, 7), (7,), (1,), (0,)).sum_reduction(dim=1)

# print(C, c)
print("Compare the two MatVecMul with sum implementations. All good ?", torch.allclose(c.flatten(), C.flatten()))

xi = torch.randn(M, 4, dtype=torch.float64)
dCa = torch.autograd.grad(C, a, xi, retain_graph=True)[0].view(-1)
dca = torch.autograd.grad(c, a, xi, retain_graph=True)[0].view(-1)

# print(dC, dc)
print("Compare the two MatVecMul with sum gradient wrt a implementations. All good ?",
      torch.allclose(dca.flatten(), dCa.flatten()))

dCb = torch.autograd.grad(C, b, xi)[0].view(-1)
dcb = torch.autograd.grad(c, b, xi)[0].view(-1)

#  print(dC, dc)
print("Compare the two MatVecMul with sum gradient wrt b implementations. All good ?",
      torch.allclose(dcb.flatten(), dCb.flatten()))

print('-------------------------------')

Out:

Compare the two MatVecMul with sum implementations. All good ? True
Compare the two MatVecMul with sum gradient wrt a implementations. All good ? True
Compare the two MatVecMul with sum gradient wrt b implementations. All good ? True
-------------------------------

Matrix-Matrix multiplication as a special case of Tensordot

a = torch.randn(4 * 7, requires_grad=True, dtype=torch.float64)
b = torch.randn(7 * 2, requires_grad=True, dtype=torch.float64)
c = a.reshape(4, 7) @ b.reshape(7, 2)

A = LazyTensor(a[None, None, :])
B = LazyTensor(b[None, None, :])
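# A (4, 7) matrix multiplied by a (7, 2) matrix, contracting the common axis of size 7: the KeOps
# result is the flattened counterpart of the (4, 2) matrix c computed above.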
C = A.keops_tensordot(B, (4, 7), (7, 2), (1,), (0,)).sum_reduction(dim=1)

#  print(C, c)
print("Compare the two MatMul implementations. All good?", torch.allclose(c.flatten(), C.flatten()))

xi = torch.randn(4 * 2, dtype=torch.float64)
dC = torch.autograd.grad(C, a, xi.reshape(1, 4 * 2), retain_graph=True)[0].view(-1)
dc = torch.autograd.grad(c, a, xi.reshape(4, 2), retain_graph=True)[0].view(-1)

#  print(dC, dc)
print("Compare the two MatMul gradient wrt a implementations. All good?", torch.allclose(dc.flatten(), dC.flatten()))

dCb = torch.autograd.grad(C, b, xi.reshape(1, 4 * 2))[0].view(-1)
dcb = torch.autograd.grad(c, b, xi.reshape(4, 2))[0].view(-1)

# print(dCb, dcb)
print("Compare the two MatMul gradient wrt b implementations. All good?", torch.allclose(dcb.flatten(), dCb.flatten()))

print('-------------------------------')

Out:

Compare the two MatMul implementations. All good? True
Compare the two MatMul gradient wrt a implementations. All good? True
Compare the two MatMul gradient wrt b implementations. All good? True
-------------------------------

Tensordot in KeOps (generic case)

A First example

First, let us start with a standard torch implementation. We contract the two tensors along a common axis of size 7. Then, a reduction is performed along the dimension of size N.

x = torch.randn(M, 4, 7, 3, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 7, 2, requires_grad=True, dtype=torch.float64)

f_torch = torch.tensordot(x, y, dims=([2], [1]))  # now has shape (M, 4, 3, N, 2)
sum_f_torch2 = f_torch.sum(3)  # ... yielding a result of shape (M, 4, 3, 2)
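
# As a sanity check, the same contraction followed by the sum over the N axis can be written with torch.einsum:
sum_f_einsum = torch.einsum("makb,nkc->mabc", x, y)  # contract k (size 7), sum over n: shape (M, 4, 3, 2)
assert torch.allclose(sum_f_einsum, sum_f_torch2)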

# In KeOps, we drop the first (reduction) axes of x and y (of size M and N respectively). We then need to tell the
# compiler not only the contraction axes (1 and 0 respectively, both of dimension 7) but the shapes (4, 7, 3) and
# (7, 2) as well, keeping in mind that the actual first axes of x and y (the reduction axes) are ignored, so the
# result has shape (M, 4*3*2) or (N, 4*3*2) depending on the chosen reduction axis.

f_keops = LazyTensor(x.reshape(M, 1, 4 * 7 * 3)).keops_tensordot(LazyTensor(y.reshape(1, N, 7 * 2)), (4, 7, 3), (7, 2),
                                                                 (1,), (0,))
sum_f_keops = f_keops.sum_reduction(dim=1)  # reduction is performed along the second axis
# print(sum_f_keops.flatten())                        # ... yielding a result of dimension (M,4*3*2)

print("Compare the two tensordot implementation. All good ?",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten(), rtol=1e-4))

Out:

Compare the two tensordot implementation. All good ? True

As before, let us check the gradients

e = torch.randn(M, 4 * 3 * 2, dtype=torch.float64)
Ee = e.reshape(M, 4, 3, 2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e, retain_graph=True)[0].squeeze().numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, Ee, retain_graph=True)[0].squeeze().numpy()

# print(grad_keops[0,:,:,:])
# print(grad_torch[0,:,:,:])
print("Check gradient wrt x. All good ?", np.allclose(grad_keops.flatten(), grad_torch.flatten()))

# tmp = torch.tensordot(Ee,y, dims=([3], [2])).sum(3).detach().numpy()
# print("grad_keops and tmp are the same? ", np.allclose(tmp.flatten(), grad_keops.flatten()))

# print("grad_torch and tmp are the same? ",  np.allclose(grad_torch , np.moveaxis(tmp, [0,1,2,3], [0,1,3,2])))
grad_keops = torch.autograd.grad(sum_f_keops, y, e)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, Ee)[0].numpy()

#  print(grad_keops[:1])
#  print(grad_torch[:1])
print("Check gradient wrt y. All good ?", np.allclose(grad_keops.flatten(), grad_torch.flatten()))

print('-------------------------------')

Out:

Check gradient wrt x. All good ? True
Check gradient wrt y. All good ? True
-------------------------------

A Second example

Torch version

x = torch.randn(M, 4, 3, 7, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 7, 2, requires_grad=True, dtype=torch.float64)

f_torch = torch.tensordot(x, y, dims=([3], [1]))  # now has shape (M, 4, 3, N, 2)
sum_f_torch2 = f_torch.sum(3)  # ... yielding a result of dimension (M,4,3,2)

And the corresponding KeOps version

f_keops = LazyTensor(x.reshape(M, 1, 4 * 3 * 7)).keops_tensordot(LazyTensor(y.reshape(1, N, 7 * 2)), (4, 3, 7), (7, 2),
                                                                 (2,), (0,))
sum_f_keops = f_keops.sum_reduction(dim=1)  # reduction is performed along the second axis
# print(sum_f_keops.shape)                        # ... yielding a result of dimension (M,4*3*2)

print("Compare the two tensordot implementation. All good ?",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten(), rtol=1e-4))

# checking gradients
e = torch.randn(M, 4 * 3 * 2, dtype=torch.float64)
Ee = e.reshape(M, 4, 3, 2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e, retain_graph=True)[0].squeeze().numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, Ee, retain_graph=True)[0].squeeze().numpy()

#  print(grad_keops[0,:,:,:])
#  print(grad_torch[0,:,:,:])
print("Compare the two gradient x tensordot implementation. All good ?",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e, retain_graph=True)[0].squeeze().numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, Ee, retain_graph=True)[0].squeeze().numpy()

#  print(grad_keops[0,:,:,:])
#  print(grad_torch[0,:,:,:])
print("Compare the two gradient y tensordot implementation. All good ?",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ? True
Compare the two gradient x tensordot implementation. All good ? True
Compare the two gradient y tensordot implementation. All good ? True
------------------------------------------

A Third example

x = torch.randn(M, 4, 3, 2, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 4, 2, requires_grad=True, dtype=torch.float64)
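# Two pairs of axes are contracted at once: x's inner axes (0, 2), of sizes 4 and 2, with y's inner
# axes (0, 1) of the same sizes (the torch reference below uses dims=([1, 3], [1, 2])).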

xshape, yshape = x.shape[1:], y.shape[1:]
f_keops = LazyTensor(x.reshape(M, 1, int(np.array((xshape)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
    xshape,
    yshape,
    (0, 2),
    (0, 1)
)
sum_f_keops = f_keops.sum_reduction(dim=1)
sum_f_torch2 = torch.tensordot(x, y, dims=([1, 3], [1, 2])).sum(2)
# sum_f_torch2 = torch.tensordot(x, y, dims=([3], [1])).sum(3)

print("Compare the two tensordot implementation. All good ????",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()
print("Compare the two gradient x tensordot implementation. is All good ????",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0].numpy()
print("Compare the two gradient y tensordot implementation. is All good ????",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ???? True
Compare the two gradient x tensordot implementation. is All good ???? True
Compare the two gradient y tensordot implementation. is All good ???? True
------------------------------------------

A Fourth example

x = torch.randn(M, 2, 3, 4, 2, 2, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 4, 5, 3, 2, requires_grad=True, dtype=torch.float64)

xshape, yshape = x.shape[1:], y.shape[1:]
f_keops = LazyTensor(x.reshape(M, 1, int(np.array((xshape)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
    xshape,
    yshape,
    (0, 1, 4),
    (0, 3, 4)
)
sum_f_keops = f_keops.sum_reduction(dim=1)
sum_f_torch2 = torch.tensordot(x, y, dims=([1, 2, 5], [1, 4, 5])).sum(3)
# sum_f_torch2 = torch.tensordot(x, y, dims=([3], [1])).sum(3)

print("Compare the two tensordot implementation. All good ????!",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()

print("Compare the two gradient x tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0].numpy()
print("Compare the two gradient y tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ????! True
Compare the two gradient x tensordot implementation. All good ????! True
Compare the two gradient y tensordot implementation. All good ????! True
------------------------------------------

A Fifth example

x = torch.randn(M, 2, 3, 4, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 4, 5, requires_grad=True, dtype=torch.float64)

xshape, yshape = x.shape[1:], y.shape[1:]
f_keops = LazyTensor(x.reshape(M, 1, int(np.array((xshape)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
    xshape,
    yshape,
    (2, 0),
    (1, 0)
)
sum_f_keops = f_keops.sum_reduction(dim=1)
sum_f_torch2 = torch.tensordot(x, y, dims=([3, 1], [2, 1])).sum(2)
# sum_f_torch2 = torch.tensordot(x, y, dims=([3], [1])).sum(3)

print("Compare the two tensordot implementation. All good ????!",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()

print("Compare the two gradient x tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0].numpy()
print("Compare the two gradient y tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ????! True
Compare the two gradient x tensordot implementation. All good ????! True
Compare the two gradient y tensordot implementation. All good ????! True
------------------------------------------

A Sixth example

x = torch.randn(M, 2, 3, 4, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 4, 2, requires_grad=True, dtype=torch.float64)

xshape, yshape = x.shape[1:], y.shape[1:]
f_keops = LazyTensor(x.reshape(M, 1, int(np.array((xshape)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
    xshape,
    yshape,
    (2, 0),
    (0, 1)
)
sum_f_keops = f_keops.sum_reduction(dim=1)
sum_f_torch2 = torch.tensordot(x, y, dims=([3, 1], [1, 2])).sum(2)
# sum_f_torch2 = torch.tensordot(x, y, dims=([3], [1])).sum(3)

print("Compare the two tensordot implementation. All good ????",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()

print("Compare the two gradient x tensordot implementation. All good ????",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1))[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e)[0].numpy()
# print(grad_keops)
#  print(grad_torch)
print("Compare the two gradient y tensordot implementation. All good ????",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ???? True
Compare the two gradient x tensordot implementation. All good ???? True
Compare the two gradient y tensordot implementation. All good ???? True
------------------------------------------

A Seventh example

x = torch.randn(M, 2, 3, 2, 2, 4, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 4, 2, 3, 2, 3, requires_grad=True, dtype=torch.float64)

xshape, yshape = x.shape[1:], y.shape[1:]
f_keops = LazyTensor(x.reshape(M, 1, int(np.array((xshape)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
    xshape,
    yshape,
    (4, 0, 2),
    (1, 4, 2)
)
sum_f_keops = f_keops.sum_reduction(dim=1)
sum_f_torch2 = torch.tensordot(x, y, dims=([5, 1, 3], [2, 5, 3])).sum(3)
# sum_f_torch2 = torch.tensordot(x, y, dims=([3], [1])).sum(3)

print("Compare the two tensordot implementation. All good ????!",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()

print("Compare the two gradient x tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0].numpy()
print("Compare the two gradient y tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ????! True
Compare the two gradient x tensordot implementation. All good ????! True
Compare the two gradient y tensordot implementation. All good ????! True
------------------------------------------
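
# The examples below also pass an optional sixth argument to keops_tensordot: a permutation of the
# output dimensions of the contraction. The helper my_tensordort_perm builds the torch reference
# (tensordot, sum over the reduction axis, then permute), while invert_permutation_numpy returns the
# inverse permutation, used to translate between the KeOps and torch axis conventions.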
def my_tensordort_perm(a, b, dims=None, perm=None):
    # print(torch.tensordot(a, b, dims=dims).sum(3).shape)
    return torch.tensordot(a, b, dims=dims).sum(3).permute(perm)


def invert_permutation_numpy(permutation):
    return np.arange(len(permutation))[np.argsort(permutation)]
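# e.g. invert_permutation_numpy([2, 0, 1]) == array([1, 2, 0]): composing a permutation with its
# inverse restores the original axis order.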


x = torch.randn(M, 2, 3, 4, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 4, requires_grad=True, dtype=torch.float64)

dimfa, dimfb = x.shape[1:], y.shape[1:]
contfa, contfb = [3], [2]
keepfa, keepfb = [item - 1 for item in [1, 2, 3] if item not in contfa], [item for item in [1, 2] if item not in contfb]
#  contfa, contfb = [2, 3], [1, 2]
n = len(dimfa) + len(dimfb) - 2 * len(contfa)
# perm = [int(i) for i in torch.randperm(n)]
perm = [2, 0, 1]
# perm = [2, 1, 3, 0]
#  perm = [1, 0]


perm_torch = (0,) + tuple([(i + 1) for i in invert_permutation_numpy(perm)])
sum_f_torch2 = my_tensordort_perm(x, y, dims=(contfa, contfb), perm=perm_torch)  # 1, 2,3,5,4 -> 1, 5,3,4,2

f_keops = LazyTensor(x.reshape(M, 1, int(np.array((dimfa)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(dimfb).prod()))),
    dimfa,
    dimfb,
    tuple(np.array(contfa) - 1),
    tuple(np.array(contfb) - 1),
    tuple(perm)
)
sum_f_keops = f_keops.sum_reduction(dim=1)

print("Compare the two tensordot implementation. All good ????!!",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
# grad_torch2 = my_tensordort_perm(e, y, dims=([4,2], keepfb), perm=[0,1,2])

grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()

print("Compare the two gradient x tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))
# print("Compare the two gradient x tensordot implementation. All good ????!",
#        np.allclose(grad_torch2.detach().numpy(), grad_torch, rtol=1e-4))
grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0].numpy()

#  grad_torch2 = my_tensordort_perm(e, x, dims=([1,3], [1,2]), perm=[0,1,2,3]).permute(perm)
# grad_torch2 = my_tensordort_perm(e, x, dims=([1,3], [1,2]), perm=perm)

print("Compare the two gradient y tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))
#  print("Compare the two gradient y tensordot implementation. All good ????!",
#        np.allclose(grad_torch2.detach().numpy(), grad_torch, rtol=1e-4))
print('------------------------------------------')

x = torch.randn(M, 2, 3, 4, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 4, 5, requires_grad=True, dtype=torch.float64)

dimfa, dimfb = x.shape[1:], y.shape[1:]
contfa, contfb = [3], [2]
keepfa, keepfb = [item - 1 for item in [1, 2, 3] if item not in contfa], [item for item in [1, 2, 3] if
                                                                          item not in contfb]
#  contfa, contfb = [2, 3], [1, 2]
n = len(dimfa) + len(dimfb) - 2 * len(contfa)
# perm = [int(i) for i in torch.randperm(n)]
perm = [0, 2, 3, 1]
# perm = [2, 1, 3, 0]
#  perm = [1, 0]


perm_torch = (0,) + tuple([(i + 1) for i in invert_permutation_numpy(perm)])
sum_f_torch2 = my_tensordort_perm(x, y, dims=(contfa, contfb), perm=perm_torch)  # 1, 2,3,5,4 -> 1, 5,3,4,2

f_keops = LazyTensor(x.reshape(M, 1, int(np.array((dimfa)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(dimfb).prod()))),
    dimfa,
    dimfb,
    tuple(np.array(contfa) - 1),
    tuple(np.array(contfb) - 1),
    tuple(perm)
)
sum_f_keops = f_keops.sum_reduction(dim=1)

print("Compare the two tensordot implementation. All good ????!!",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
#  grad_torch2 = my_tensordort_perm(e, y, dims=([4,2], keepfb), perm=[0,1,2,3])

grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()

print("Compare the two gradient x tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))
#  print("Compare the two gradient x tensordot implementation. All good ????!",
#        np.allclose(grad_torch2.detach().numpy(), grad_torch, rtol=1e-4))


grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0]
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0]
# grad_torch2 = my_tensordort_perm(e, x, dims=([1,3], [1,2]), perm=perm)

print("Compare the two gradient y tensordot implementation. All good ????!",
      np.allclose(grad_keops.numpy().flatten(), grad_torch.numpy().flatten(), rtol=1e-4))
# print("Compare the two gradient y tensordot implementation. All good ????!",
#      np.allclose(grad_torch2.detach().numpy(), grad_torch, rtol=1e-4))
print('------------------------------------------')

x = torch.randn(M, 2, 3, 2, 2, 4, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 4, 2, 3, 2, 3, requires_grad=True, dtype=torch.float64)

dimfa, dimfb = x.shape[1:], y.shape[1:]
contfa, contfb = [5, 1, 3], [2, 5, 3]
n = len(dimfa) + len(dimfb) - 2 * len(contfa)
# perm_id = [int(i) for i in range(n+1)]

# perm = [int(i) for i in torch.randperm(n)]
# perm = [0,2,1,4,3]
perm = [4, 3, 2, 0, 1]
perm_torch = (0,) + tuple([(i + 1) for i in invert_permutation_numpy(perm)])
sum_f_torch2 = my_tensordort_perm(x, y, dims=(contfa, contfb), perm=perm_torch)

# print(sum_f_torch2.shape)

f_keops = LazyTensor(x.reshape(M, 1, int(np.array((dimfa)).prod()))).keops_tensordot(
    LazyTensor(y.reshape(1, N, int(np.array(dimfb).prod()))),
    dimfa,
    dimfb,
    tuple(np.array(contfa) - 1),
    tuple(np.array(contfb) - 1),
    tuple(perm)
)
sum_f_keops = f_keops.sum_reduction(dim=1)

print("Compare the two tensordot implementation. All good ????!!",
      torch.allclose(sum_f_keops.flatten(), sum_f_torch2.flatten()))

# checking gradients
e = torch.randn_like(sum_f_torch2)
grad_keops = torch.autograd.grad(sum_f_keops, x, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, x, e, retain_graph=True)[0].numpy()
# grad_torch2 = my_tensordort_perm(e, y, dims=([1,2,3], [1,4,6]), perm=[0,1,2,3,4,5]).permute(3)

print("Compare the two gradient x tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

grad_keops = torch.autograd.grad(sum_f_keops, y, e.reshape(M, -1), retain_graph=True)[0].numpy()
grad_torch = torch.autograd.grad(sum_f_torch2, y, e, retain_graph=True)[0].numpy()
print("Compare the two gradient y tensordot implementation. All good ????!",
      np.allclose(grad_keops.flatten(), grad_torch.flatten(), rtol=1e-4))

print('------------------------------------------')

Out:

Compare the two tensordot implementation. All good ????!! True
Compare the two gradient x tensordot implementation. All good ????! True
Compare the two gradient y tensordot implementation. All good ????! True
------------------------------------------
Compare the two tensordot implementation. All good ????!! True
Compare the two gradient x tensordot implementation. All good ????! True
Compare the two gradient y tensordot implementation. All good ????! True
------------------------------------------
Compare the two tensordot implementation. All good ????!! True
Compiling libKeOpstorch376f25a6a0 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch376f25a6a0:
       formula: Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,96,0), Var(1,288,1), Ind(2,3,2,2,4), Ind(2,4,2,3,2,3), Ind(4,0,2), Ind(1,4,2), Ind(4,3,2,0,1)),0), Var(1,288,1), Var(2,108,0), Var(3,108,0))
       aliases: Var(0,96,0); Var(1,288,1); Var(2,108,0); Var(3,108,0);
       dtype  : float64
... Done.
Compare the two gradient x tensordot implementation. All good ????! True
Compare the two gradient y tensordot implementation. All good ????! True
------------------------------------------

Using gradcheck

# def my_tensordot(x,y):
# f_keops = LazyTensor(x.reshape(M, 1, 4 * 3 * 7)).keops_tensordot(LazyTensor(y.reshape(1, N, 7 * 2)), (4, 3, 7),
# (7, 2), (2,), (0,))
# return f_keops.sum_reduction(dim=1)

# print(torch.autograd.gradcheck(my_tensordot, [x,y]))

def my_tensordot2(x, y):
    xshape, yshape = x.shape[1:], y.shape[1:]
    f_keops = LazyTensor(x.reshape(M, 1, int(np.array((xshape)).prod()))
                         ).keops_tensordot(LazyTensor(y.reshape(1, N, int(np.array(yshape).prod()))),
                                           xshape,
                                           yshape,
                                           (2, 0),  # (2,0,1),
                                           (0, 1)  #  (0,3,2)
                                           )
    return f_keops.sum_reduction(dim=1)


x = torch.randn(M, 2, 2, 2, requires_grad=True, dtype=torch.float64)
y = torch.randn(N, 2, 2, requires_grad=True, dtype=torch.float64)
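# gradcheck / gradgradcheck compare the first- and second-order derivatives computed by KeOps
# against finite-difference approximations, in double precision.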
print(torch.autograd.gradcheck(my_tensordot2, [x, y], atol=1e-5, rtol=1e-5))
print(torch.autograd.gradgradcheck(my_tensordot2, [x, y], atol=1e-5, rtol=1e-5))

Out:

Compiling libKeOpstorch997ac076d4 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch997ac076d4:
       formula: Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0)
       aliases: Var(0,8,0); Var(1,4,1);
       dtype  : float64
... Done.
Compiling libKeOpstorch3fe4aed59f in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch3fe4aed59f:
       formula: Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(0,8,0), Var(2,2,0), Var(3,2,0))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0);
       dtype  : float64
... Done.
Compiling libKeOpstorch065b387b68 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch065b387b68:
       formula: Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(1,4,1), Var(2,2,0), Var(3,2,0))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0);
       dtype  : float64
... Done.
True
Compiling libKeOpstorch5dcdf4e702 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch5dcdf4e702:
       formula: Grad_WithSavedForward(Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(0,8,0), Var(2,2,0), Var(3,2,0)), Var(0,8,0), Var(4,8,0), Var(5,8,0))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0); Var(4,8,0); Var(5,8,0);
       dtype  : float64
... Done.
Compiling libKeOpstorch83c543dde5 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch83c543dde5:
       formula: Grad_WithSavedForward(Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(0,8,0), Var(2,2,0), Var(3,2,0)), Var(1,4,1), Var(4,8,0), Var(5,8,0))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0); Var(4,8,0); Var(5,8,0);
       dtype  : float64
... Done.
Compiling libKeOpstorch5423f96494 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch5423f96494:
       formula: Grad_WithSavedForward(Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(0,8,0), Var(2,2,0), Var(3,2,0)), Var(2,2,0), Var(4,8,0), Var(5,8,0))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0); Var(4,8,0); Var(5,8,0);
       dtype  : float64
... Done.
Compiling libKeOpstorchd122daadf9 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorchd122daadf9:
       formula: Grad_WithSavedForward(Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(1,4,1), Var(2,2,0), Var(3,2,0)), Var(0,8,0), Var(4,4,1), Var(5,4,1))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0); Var(4,4,1); Var(5,4,1);
       dtype  : float64
... Done.
Compiling libKeOpstorch08e0a85e43 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch08e0a85e43:
       formula: Grad_WithSavedForward(Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(1,4,1), Var(2,2,0), Var(3,2,0)), Var(1,4,1), Var(4,4,1), Var(5,4,1))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0); Var(4,4,1); Var(5,4,1);
       dtype  : float64
... Done.
Compiling libKeOpstorch339cc6b358 in /home/bcharlier/tmp/libkeops/pykeops/common/../build//build-libKeOpstorch339cc6b358:
       formula: Grad_WithSavedForward(Grad_WithSavedForward(Sum_Reduction(TensorDot(Var(0,8,0), Var(1,4,1), Ind(2,2,2), Ind(2,2), Ind(2,0), Ind(0,1)),0), Var(1,4,1), Var(2,2,0), Var(3,2,0)), Var(2,2,0), Var(4,4,1), Var(5,4,1))
       aliases: Var(0,8,0); Var(1,4,1); Var(2,2,0); Var(3,2,0); Var(4,4,1); Var(5,4,1);
       dtype  : float64
... Done.
True

Total running time of the script: ( 5 minutes 18.071 seconds)
