import math
import matplotlib.image
import torch
from imodal.Utilities.usefulfunctions import points2pixels, points2voxels_affine
from imodal.Utilities.aabb import AABB


def load_greyscale_image(filename, origin='lower', dtype=None, device=None):
    """Load a greyscale image from disk as an array of normalised float values.

    Parameters
    ----------
    filename : str
        Filename of the image to load.
    origin : str, default='lower'
        Origin convention of the loaded image, either 'lower' or 'upper'.
    dtype : torch.dtype
        dtype of the returned image tensor.
    device : torch.device
        Device on which the image tensor will be loaded.

    Returns
    -------
    torch.Tensor
        [width, height] shaped tensor representing the loaded image.
    """
    if origin != 'upper' and origin != 'lower':
        raise RuntimeError("Origin type {origin} not implemented!".format(origin=origin))

    def _set_origin(bitmap, origin):
        # With origin='lower', flip the vertical axis so that the first row
        # corresponds to the bottom of the image.
        if origin == 'upper':
            return bitmap
        else:
            return bitmap.flip(0)

    image = matplotlib.image.imread(filename)

    # Greyscale values are inverted so that dark pixels get high values.
    if image.ndim == 2:
        return _set_origin(torch.tensor(1. - image, dtype=dtype, device=device), origin)
    elif image.ndim == 3:
        # For multi-channel images, only the first channel is used.
        return _set_origin(torch.tensor(1. - image[:, :, 0], dtype=dtype, device=device), origin)
    else:
        raise NotImplementedError
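
# Usage sketch (illustrative, not executed by the module; the filename is hypothetical):
#
#   img = load_greyscale_image('shape.png', origin='lower', dtype=torch.float64)
#   # For a PNG read as floats in [0, 1], dark pixels map to values close to 1.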


def sample_from_greyscale(image, threshold, centered=False, normalise_weights=False, normalise_position=True):
    r"""Sample points from a greyscale image.

    Points are defined as a (position, weight) tuple.

    Parameters
    ----------
    image : torch.Tensor
        Tensor of shape [width, height] representing the image from which the points will be sampled.
    threshold : float
        Minimum pixel value (i.e. point weight) for a pixel to be kept.
    centered : bool, default=False
        If True, center the sampled points around their mean.
    normalise_weights : bool, default=False
        If True, normalise the weight values such that :math:`\alpha_i = \frac{\alpha_i}{\sum_k \alpha_k}`.
    normalise_position : bool, default=True
        If True, normalise the point positions such that all points live in the unit square.

    Returns
    -------
    torch.Tensor, torch.Tensor
        Two tensors representing point positions (of shape [N, dim]) and weights (of shape [N]), in this order, with :math:`N` the number of points.
    """
    # Number of output points, i.e. number of pixels at or above the threshold.
    length = torch.sum(image >= threshold)

    pos = torch.zeros([length, 2])
    alpha = torch.zeros([length])

    # Scaling factors used to map positions into the unit square.
    width_weight = 1.
    height_weight = 1.
    if normalise_position:
        width_weight = 1. / image.shape[0]
        height_weight = 1. / image.shape[1]

    count = 0
    pixels = AABB(0., image.shape[0], 0., image.shape[1]).fill_count(image.shape)
    for pixel in pixels:
        pixel_value = image[math.floor(pixel[0]), math.floor(pixel[1])]
        if pixel_value < threshold:
            continue
        # Apply unit-square normalisation (no-op when normalise_position is False).
        pos[count] = pixel * torch.tensor([width_weight, height_weight])
        alpha[count] = pixel_value
        count = count + 1

    if centered:
        pos = pos - torch.mean(pos, dim=0)

    if normalise_weights:
        alpha = alpha / torch.sum(alpha)

    return pos, alpha
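
# Usage sketch (illustrative, using a small synthetic image):
#
#   img = torch.zeros(8, 8)
#   img[2:5, 3:6] = 0.8
#   pos, alpha = sample_from_greyscale(img, threshold=0.5, normalise_weights=True)
#   # pos has shape [9, 2] and alpha has shape [9], with alpha summing to 1.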


def load_and_sample_greyscale(filename, threshold=0., centered=False, normalise_weights=True):
    """Load a greyscale image and sample points from it."""
    image = load_greyscale_image(filename)
    return sample_from_greyscale(image, threshold, centered, normalise_weights)
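
# Usage sketch (illustrative, hypothetical filename): equivalent to calling
# load_greyscale_image() followed by sample_from_greyscale().
#
#   pos, alpha = load_and_sample_greyscale('shape.png', threshold=0.1)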


def mask_to_indices(mask):
    """Convert a boolean mask into a [N, dim] tensor holding the indices of its True entries."""
    indices = torch.meshgrid([torch.arange(size) for size in mask.shape])
    return torch.stack([index[mask] for index in indices]).T
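
# Usage sketch (illustrative):
#
#   mask = torch.tensor([[True, False], [False, True]])
#   mask_to_indices(mask)
#   # tensor([[0, 0],
#   #         [1, 1]])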


def interpolate_image(image, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
    """
    Simple wrapper around torch.nn.functional.interpolate() for 2D images.
    """
    # Add the batch and channel dimensions expected by interpolate(), then strip them back off.
    interpolated = torch.nn.functional.interpolate(image.view((1, 1) + image.shape), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor)
    return interpolated.view(interpolated.shape[2:])
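
# Usage sketch (illustrative): upscaling a 2D image by a factor of two with
# bilinear interpolation.
#
#   img = torch.rand(16, 16)
#   big = interpolate_image(img, scale_factor=2., mode='bilinear', align_corners=False)
#   # big.shape == torch.Size([32, 32])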