Note
Go to the end to download the full example code
Input-Output with brain tractograms
import argparse
import os.path
import sys
import pdb
import numpy as np
from dipy.tracking.streamline import set_number_of_points
import vtk
from vtk.util import numpy_support as ns
try:
# Python 2
from itertools import izip
except ImportError:
# Python 3
izip = zip
try:
# Python 2
xrange
except NameError:
# Python 3, xrange is now named range
xrange = range
def read_vtk(filename):
if filename.endswith("xml") or filename.endswith("vtp"):
polydata_reader = vtk.vtkXMLPolyDataReader()
else:
polydata_reader = vtk.vtkPolyDataReader()
polydata_reader.SetFileName(filename)
polydata_reader.Update()
polydata = polydata_reader.GetOutput()
return vtkPolyData_to_tracts(polydata)
def vtkPolyData_to_tracts(polydata):
result = {}
result["lines"] = ns.vtk_to_numpy(polydata.GetLines().GetData())
result["points"] = ns.vtk_to_numpy(polydata.GetPoints().GetData())
result["numberOfLines"] = polydata.GetNumberOfLines()
data = {}
if polydata.GetPointData().GetScalars():
data["ActiveScalars"] = polydata.GetPointData().GetScalars().GetName()
result["Scalars"] = polydata.GetPointData().GetScalars()
if polydata.GetPointData().GetVectors():
data["ActiveVectors"] = polydata.GetPointData().GetVectors().GetName()
if polydata.GetPointData().GetTensors():
data["ActiveTensors"] = polydata.GetPointData().GetTensors().GetName()
for i in xrange(polydata.GetPointData().GetNumberOfArrays()):
array = polydata.GetPointData().GetArray(i)
np_array = ns.vtk_to_numpy(array)
if np_array.ndim == 1:
np_array = np_array.reshape(len(np_array), 1)
data[polydata.GetPointData().GetArrayName(i)] = np_array
result["pointData"] = data
tracts, data = vtkPolyData_dictionary_to_tracts_and_data(result)
return tracts, data
def vtkPolyData_dictionary_to_tracts_and_data(dictionary):
dictionary_keys = set(("lines", "points", "numberOfLines"))
if not dictionary_keys.issubset(dictionary.keys()):
raise ValueError(
"Dictionary must have the keys lines and points" + repr(dictionary.keys())
)
tract_data = {}
tracts = []
lines = np.asarray(dictionary["lines"]).squeeze()
points = dictionary["points"]
actual_line_index = 0
number_of_tracts = dictionary["numberOfLines"]
original_lines = []
for l in xrange(number_of_tracts):
tracts.append(
points[
lines[
actual_line_index
+ 1 : actual_line_index
+ lines[actual_line_index]
+ 1
]
]
)
original_lines.append(
np.array(
lines[
actual_line_index
+ 1 : actual_line_index
+ lines[actual_line_index]
+ 1
],
copy=True,
)
)
actual_line_index += lines[actual_line_index] + 1
if "pointData" in dictionary:
point_data_keys = [
it[0]
for it in dictionary["pointData"].items()
if isinstance(it[1], np.ndarray)
]
for k in point_data_keys:
array_data = dictionary["pointData"][k]
if not k in tract_data:
tract_data[k] = [array_data[f] for f in original_lines]
else:
np.vstack(tract_data[k])
tract_data[k].extend(
[array_data[f] for f in original_lines[-number_of_tracts:]]
)
return tracts, tract_data
def save_vtk(filename, tracts, lines_indices=None, scalars=None):
lengths = [len(p) for p in tracts]
line_starts = ns.numpy.r_[0, ns.numpy.cumsum(lengths)]
if lines_indices is None:
lines_indices = [
ns.numpy.arange(length) + line_start
for length, line_start in izip(lengths, line_starts)
]
ids = ns.numpy.hstack(
[ns.numpy.r_[c[0], c[1]] for c in izip(lengths, lines_indices)]
)
vtk_ids = ns.numpy_to_vtkIdTypeArray(ids, deep=True)
cell_array = vtk.vtkCellArray()
cell_array.SetCells(len(tracts), vtk_ids)
points = ns.numpy.vstack(tracts).astype(
ns.get_vtk_to_numpy_typemap()[vtk.VTK_DOUBLE]
)
points_array = ns.numpy_to_vtk(points, deep=True)
poly_data = vtk.vtkPolyData()
vtk_points = vtk.vtkPoints()
vtk_points.SetData(points_array)
poly_data.SetPoints(vtk_points)
poly_data.SetLines(cell_array)
poly_data.BuildCells()
if filename.endswith(".xml") or filename.endswith(".vtp"):
writer = vtk.vtkXMLPolyDataWriter()
writer.SetDataModeToBinary()
else:
writer = vtk.vtkPolyDataWriter()
writer.SetFileTypeToBinary()
writer.SetFileName(filename)
if hasattr(vtk, "VTK_MAJOR_VERSION") and vtk.VTK_MAJOR_VERSION > 5:
writer.SetInputData(poly_data)
else:
writer.SetInput(poly_data)
writer.Write()
def save_vtk_labels(filename, tracts, scalars, lines_indices=None):
lengths = [len(p) for p in tracts]
line_starts = ns.numpy.r_[0, ns.numpy.cumsum(lengths)]
if lines_indices is None:
lines_indices = [
ns.numpy.arange(length) + line_start
for length, line_start in izip(lengths, line_starts)
]
ids = ns.numpy.hstack(
[ns.numpy.r_[c[0], c[1]] for c in izip(lengths, lines_indices)]
)
vtk_ids = ns.numpy_to_vtkIdTypeArray(ids, deep=True)
cell_array = vtk.vtkCellArray()
cell_array.SetCells(len(tracts), vtk_ids)
points = ns.numpy.vstack(tracts).astype(
ns.get_vtk_to_numpy_typemap()[vtk.VTK_DOUBLE]
)
points_array = ns.numpy_to_vtk(points, deep=True)
poly_data = vtk.vtkPolyData()
vtk_points = vtk.vtkPoints()
vtk_points.SetData(points_array)
poly_data.SetPoints(vtk_points)
poly_data.SetLines(cell_array)
poly_data.GetPointData().SetScalars(ns.numpy_to_vtk(scalars))
poly_data.BuildCells()
# poly_data.SetScalars(scalars)
if filename.endswith(".xml") or filename.endswith(".vtp"):
writer = vtk.vtkXMLPolyDataWriter()
writer.SetDataModeToBinary()
else:
writer = vtk.vtkPolyDataWriter()
writer.SetFileTypeToBinary()
writer.SetFileName(filename)
if hasattr(vtk, "VTK_MAJOR_VERSION") and vtk.VTK_MAJOR_VERSION > 5:
writer.SetInputData(poly_data)
else:
writer.SetInput(poly_data)
writer.Write()
def streamlines_resample(streamlines, perc=None, npoints=None):
if perc is not None:
resampled = [
set_number_of_points(s, int(len(s) * perc / 100.0)) for s in streamlines
]
else:
resampled = [set_number_of_points(s, npoints) for s in streamlines]
return resampled
def check_ext(value):
filename, file_extension = os.path.splitext(value)
if file_extension in (".vtk", ".xml", ".vtp"):
return value
else:
raise argparse.ArgumentTypeError(
"Invalid file extension (file format supported: vtk,xml,vtp): %r" % value
)
def check_resample(value):
try:
t = float(value)
if 0 <= t < 100:
return value
else:
raise argparse.ArgumentTypeError(
"Invalid resampling (must be between 0 and 100): %r" % value
)
except ValueError:
raise argparse.ArgumentTypeError("Invalid resampling value: %r" % value)
def setup():
parser = argparse.ArgumentParser()
parser.add_argument(
"Input_Tractogram", help="Name of the input tractography file", type=check_ext
)
parser.add_argument(
"Percentage", help="Resampling percentage value", type=check_resample
)
args = parser.parse_args()
return args.Input_Tractogram, float(args.Percentage)
if __name__ == "__main__":
in_file, perc = setup()
streamlines = read_vtk(in_file)[0]
resampled = streamlines_resample(streamlines, perc)
file_name, file_extension = os.path.splitext(in_file)
print(file_name + "_resampled" + file_extension)
save_vtk(file_name + "_resampled" + file_extension, np.array(resampled))
sys.exit()
def save_tract(x, fname, NPOINTS=20):
tract = x.view(len(x), -1, 3) * np.sqrt(NPOINTS)
save_vtk(fname, tract.detach().cpu().numpy())
def save_tract_numpy(x, fname, NPOINTS=20):
tract = x.view(len(x), -1, 3) * np.sqrt(NPOINTS)
np.save(
fname,
np.float16(tract.detach().cpu().numpy()),
allow_pickle=False,
fix_imports=False,
)
def save_tract_with_labels(fname, x, scalars, subsampling_fibers=None, NPOINTS=20):
# save tracts +label information on each fiber
tract = x.view(len(x), -1, 3) * np.sqrt(NPOINTS)
tract = tract[0::subsampling_fibers, :, :]
nf, ns, d = tract.shape
labels = scalars.view(-1, 1).repeat(1, ns)
save_vtk_labels(
fname,
tract.detach().cpu().numpy(),
labels.view(-1).detach().cpu().numpy().astype(int),
)
def save_tracts_labels_separate(fname, x, labels, start, end, NPOINTS=20):
# save tracts +label information on each fiber
tract = x.view(len(x), -1, 3) * np.sqrt(NPOINTS)
for l in range(start, end):
if (labels == l).nonzero().shape[0] != 0:
tract_l = tract[(labels == l).nonzero().view(-1), :, :]
save_vtk(fname + "_{:05d}.vtk".format(l), tract_l.detach().cpu().numpy())
else:
save_vtk(fname + "_{:05d}.vtk".format(l), np.array([[0, 0, 0]]))
Total running time of the script: ( 0 minutes 0.000 seconds)