.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "_auto_examples/plot_tree_growth.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note Click :ref:`here <sphx_glr_download__auto_examples_plot_tree_growth.py>` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr__auto_examples_plot_tree_growth.py: Analyzing Differences Between Tree Images ========================================= Image registration with an implicit module of order 1. Segmentations given by the data are used to initialize its points. .. GENERATED FROM PYTHON SOURCE LINES 10-15 Initialization -------------- Import relevant Python modules. .. GENERATED FROM PYTHON SOURCE LINES 15-32 .. code-block:: default import time import pickle import sys sys.path.append("../") import matplotlib.pyplot as plt import torch import imodal device = 'cuda:1' torch.set_default_dtype(torch.float64) imodal.Utilities.set_compute_backend('keops') .. GENERATED FROM PYTHON SOURCE LINES 33-35 Load source and target images, along with the source curve. .. GENERATED FROM PYTHON SOURCE LINES 35-49 .. code-block:: default with open("../data/tree_growth.pickle", 'rb') as f: data = pickle.load(f) source_shape = data['source_shape'].to(torch.get_default_dtype()) source_image = data['source_image'].to(torch.get_default_dtype()) target_image = data['target_image'].to(torch.get_default_dtype()) # Segmentations as Axis Aligned Bounding Boxes (AABB) aabb_trunk = data['aabb_trunk'] aabb_crown = data['aabb_leaves'] extent = data['extent'] .. GENERATED FROM PYTHON SOURCE LINES 50-53 Display source and target images, along with the segmented source curve (orange for the trunk, green for the crown). .. GENERATED FROM PYTHON SOURCE LINES 53-71 .. code-block:: default shape_is_trunk = aabb_trunk.is_inside(source_shape) shape_is_crown = aabb_crown.is_inside(source_shape) plt.subplot(1, 2, 1) plt.title("Source") plt.imshow(source_image, cmap='gray', origin='lower', extent=extent.totuple()) plt.plot(source_shape[shape_is_trunk, 0].numpy(), source_shape[shape_is_trunk, 1].numpy(), lw=2., color='orange') plt.plot(source_shape[shape_is_crown, 0].numpy(), source_shape[shape_is_crown, 1].numpy(), lw=2., color='green') plt.axis('off') plt.subplot(1, 2, 2) plt.title("Target") plt.imshow(target_image, cmap='gray', origin='lower', extent=extent.totuple()) plt.axis('off') plt.show() .. GENERATED FROM PYTHON SOURCE LINES 72-74 Generating implicit modules of order 1 points and growth model tensor. .. GENERATED FROM PYTHON SOURCE LINES 74-110 .. code-block:: default implicit1_density = 500. # Lambda function defining the area in and around the tree shape area = lambda x, **kwargs: imodal.Utilities.area_shape(x, **kwargs) | imodal.Utilities.area_polyline_outline(x, **kwargs) polyline_width = 0.07 # Generation of the points of the initial geometrical descriptor implicit1_points = imodal.Utilities.fill_area_uniform_density(area, imodal.Utilities.AABB(xmin=0., xmax=1., ymin=0., ymax=1.), implicit1_density, shape=source_shape, polyline=source_shape, width=polyline_width) # Masks that flag points into either the trunk or the crown implicit1_trunk_points = aabb_trunk.is_inside(implicit1_points) implicit1_crown_points = aabb_crown.is_inside(implicit1_points) implicit1_points = implicit1_points[implicit1_trunk_points | implicit1_crown_points] implicit1_trunk_points = aabb_trunk.is_inside(implicit1_points) implicit1_crown_points = aabb_crown.is_inside(implicit1_points) assert implicit1_points[implicit1_trunk_points].shape[0] + implicit1_points[implicit1_crown_points].shape[0] == implicit1_points.shape[0] # Initial normal frames implicit1_r = torch.eye(2).repeat(implicit1_points.shape[0], 1, 1) # Growth model tensor implicit1_c = torch.zeros(implicit1_points.shape[0], 2, 4) # Horizontal stretching for the trunk implicit1_c[implicit1_trunk_points, 0, 0] = 1. # Vertical stretching for the trunk implicit1_c[implicit1_trunk_points, 1, 1] = 1. # Horizontal stretching for the crown implicit1_c[implicit1_crown_points, 0, 2] = 1. # Vertical stretching for the crown implicit1_c[implicit1_crown_points, 1, 3] = 1. .. GENERATED FROM PYTHON SOURCE LINES 111-113 Plot the 4 dimensional growth model tensor. .. GENERATED FROM PYTHON SOURCE LINES 113-126 .. code-block:: default plt.figure(figsize=[20., 5.]) for i in range(4): ax = plt.subplot(1, 4, i + 1) plt.imshow(source_image, origin='lower', extent=extent, cmap='gray') imodal.Utilities.plot_C_ellipses(ax, implicit1_points, implicit1_c, c_index=i, color='blue', scale=0.03) plt.xlim(0., 1.) plt.ylim(0., 1.) plt.axis('off') plt.show() .. GENERATED FROM PYTHON SOURCE LINES 127-129 Create the deformation model with a combination of 2 modules : a global translation and the implicit module of order 1. .. GENERATED FROM PYTHON SOURCE LINES 132-134 Create and initialize the global translation module **global_translation**. .. GENERATED FROM PYTHON SOURCE LINES 134-139 .. code-block:: default global_translation_coeff = 1. global_translation = imodal.DeformationModules.GlobalTranslation(2, coeff=global_translation_coeff) .. GENERATED FROM PYTHON SOURCE LINES 140-142 Create and initialize the implicit module of order 1 **implicit1**. .. GENERATED FROM PYTHON SOURCE LINES 142-150 .. code-block:: default sigma1 = 2./implicit1_density**(1/2) implicit1_coeff = 0.1 implicit1_nu = 100. implicit1 = imodal.DeformationModules.ImplicitModule1(2, implicit1_points.shape[0], sigma1, implicit1_c, nu=implicit1_nu, gd=(implicit1_points, implicit1_r), coeff=implicit1_coeff) implicit1.eps = 1e-2 .. GENERATED FROM PYTHON SOURCE LINES 151-153 Define deformables used by the registration model. .. GENERATED FROM PYTHON SOURCE LINES 153-160 .. code-block:: default source_image_deformable = imodal.Models.DeformableImage(source_image, output='bitmap', extent=extent) target_image_deformable = imodal.Models.DeformableImage(target_image, output='bitmap', extent=extent) source_image_deformable.to_device(device) target_image_deformable.to_device(device) .. GENERATED FROM PYTHON SOURCE LINES 161-165 Registration ------------ Define the registration model. .. GENERATED FROM PYTHON SOURCE LINES 165-171 .. code-block:: default attachment_image = imodal.Attachment.L2NormAttachment(weight=1e0) model = imodal.Models.RegistrationModel([source_image_deformable], [implicit1, global_translation], [attachment_image], lam=1.) model.to_device(device) .. GENERATED FROM PYTHON SOURCE LINES 172-174 Fitting using Torch LBFGS optimizer. .. GENERATED FROM PYTHON SOURCE LINES 174-183 .. code-block:: default shoot_solver = 'rk4' shoot_it = 10 costs = {} fitter = imodal.Models.Fitter(model, optimizer='torch_lbfgs') fitter.fit([target_image_deformable], 500, costs=costs, options={'shoot_solver': shoot_solver, 'shoot_it': shoot_it, 'line_search_fn': 'strong_wolfe', 'history_size': 500}) .. GENERATED FROM PYTHON SOURCE LINES 184-188 Visualization ------------- Compute optimized deformation trajectory. .. GENERATED FROM PYTHON SOURCE LINES 188-196 .. code-block:: default deformed_intermediates = {} start = time.perf_counter() with torch.autograd.no_grad(): deformed_image = model.compute_deformed(shoot_solver, shoot_it, intermediates=deformed_intermediates)[0][0].detach().cpu() print("Elapsed={elapsed}".format(elapsed=time.perf_counter()-start)) .. GENERATED FROM PYTHON SOURCE LINES 197-199 Display deformed source image and target. .. GENERATED FROM PYTHON SOURCE LINES 199-219 .. code-block:: default plt.figure(figsize=[15., 5.]) plt.subplot(1, 3, 1) plt.title("Source") plt.imshow(source_image, extent=extent.totuple(), origin='lower') plt.axis('off') plt.subplot(1, 3, 2) plt.title("Deformed") plt.imshow(deformed_image, extent=extent.totuple(), origin='lower') plt.axis('off') plt.subplot(1, 3, 3) plt.title("Target") plt.imshow(target_image, extent=extent.totuple(), origin='lower') plt.axis('off') plt.show() .. GENERATED FROM PYTHON SOURCE LINES 220-222 We can follow the action of each part of the total deformation by setting all the controls components to zero but one. .. GENERATED FROM PYTHON SOURCE LINES 224-226 Functions generating controls to follow one part of the deformation. .. GENERATED FROM PYTHON SOURCE LINES 226-244 .. code-block:: default def generate_implicit1_controls(table): outcontrols = [] for control in deformed_intermediates['controls']: outcontrols.append(control[1]*torch.tensor(table, dtype=torch.get_default_dtype(), device=device)) return outcontrols def generate_controls(implicit1_table, trans): outcontrols = [] implicit1_controls = generate_implicit1_controls(implicit1_table) for control, implicit1_control in zip(deformed_intermediates['controls'], implicit1_controls): outcontrols.append([implicit1_control, control[2]*torch.tensor(trans, dtype=torch.get_default_dtype(), device=device)]) return outcontrols .. GENERATED FROM PYTHON SOURCE LINES 245-247 Function to compute a deformation given a set of controls up to some time point. .. GENERATED FROM PYTHON SOURCE LINES 247-279 .. code-block:: default grid_resolution = [16, 16] def compute_intermediate_deformed(it, controls, t1, intermediates=None): implicit1_points = deformed_intermediates['states'][0][1].gd[0] implicit1_r = deformed_intermediates['states'][0][1].gd[1] implicit1_cotan_points = deformed_intermediates['states'][0][1].cotan[0] implicit1_cotan_r = deformed_intermediates['states'][0][1].cotan[1] silent_cotan = deformed_intermediates['states'][0][0].cotan implicit1 = imodal.DeformationModules.ImplicitModule1(2, implicit1_points.shape[0], sigma1, implicit1_c.clone(), nu=implicit1_nu, gd=(implicit1_points.clone(), implicit1_r.clone()), cotan=(implicit1_cotan_points, implicit1_cotan_r), coeff=implicit1_coeff) global_translation = imodal.DeformationModules.GlobalTranslation(2, coeff=global_translation_coeff) implicit1.to_(device=device) global_translation.to_(device=device) source_deformable = imodal.Models.DeformableImage(source_image, output='bitmap', extent=extent) source_deformable.silent_module.manifold.cotan = silent_cotan grid_deformable = imodal.Models.DeformableGrid(extent, grid_resolution) source_deformable.to_device(device) grid_deformable.to_device(device) costs = {} with torch.autograd.no_grad(): deformed = imodal.Models.deformables_compute_deformed([source_deformable, grid_deformable], [implicit1, global_translation], shoot_solver, it, controls=controls, t1=t1, intermediates=intermediates, costs=costs) return deformed[0][0] .. GENERATED FROM PYTHON SOURCE LINES 280-282 Functions to generate the deformation trajectory given a set of controls. .. GENERATED FROM PYTHON SOURCE LINES 282-316 .. code-block:: default def generate_images(table, trans, outputfilename): incontrols = generate_controls(table, trans) intermediates_shape = {} deformed = compute_intermediate_deformed(10, incontrols, 1., intermediates=intermediates_shape) trajectory_grid = [imodal.Utilities.vec2grid(state[1].gd, grid_resolution[0], grid_resolution[1]) for state in intermediates_shape['states']] trajectory = [source_image] t = torch.linspace(0., 1., 11) indices = [0, 3, 7, 10] print("Computing trajectories...") for index in indices[1:]: print("{}, t={}".format(index, t[index])) deformed = compute_intermediate_deformed(index, incontrols[:index], t[index]) trajectory.append(deformed) print("Generating images...") plt.figure(figsize=[5.*len(indices), 5.]) for deformed, i in zip(trajectory, range(len(indices))): ax = plt.subplot(1, len(indices), i + 1) grid = trajectory_grid[indices[i]] plt.imshow(deformed.cpu(), origin='lower', extent=extent, cmap='gray') imodal.Utilities.plot_grid(ax, grid[0].cpu(), grid[1].cpu(), color='xkcd:light blue', lw=1) plt.xlim(0., 1.) plt.ylim(0., 1.) plt.axis('off') plt.tight_layout() plt.show() .. GENERATED FROM PYTHON SOURCE LINES 317-319 Generate trajectory of the total optimized deformation. .. GENERATED FROM PYTHON SOURCE LINES 319-322 .. code-block:: default generate_images([True, True, True, True], True, "deformed_all") .. GENERATED FROM PYTHON SOURCE LINES 323-325 Generate trajectory following vertical elongation of the trunk. .. GENERATED FROM PYTHON SOURCE LINES 325-328 .. code-block:: default generate_images([False, True, False, False], False, "deformed_trunk_vertical") .. GENERATED FROM PYTHON SOURCE LINES 329-331 Generate trajectory following horizontal elongation of the crown. .. GENERATED FROM PYTHON SOURCE LINES 331-335 .. code-block:: default generate_images([False, False, True, False], False, "deformed_crown_horizontal") .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 0.000 seconds) .. _sphx_glr_download__auto_examples_plot_tree_growth.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_tree_growth.py <plot_tree_growth.py>` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_tree_growth.ipynb <plot_tree_growth.ipynb>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_