Coverage for src/meshpy/core/vtk_writer.py: 95%
152 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-28 04:21 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-28 04:21 +0000
1# The MIT License (MIT)
2#
3# Copyright (c) 2018-2025 MeshPy Authors
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21# THE SOFTWARE.
22"""This module provides a class that is used to write VTK files."""
24import numbers as _numbers
25import os as _os
26import warnings as _warnings
28import numpy as _np
29import vtk as _vtk
31from meshpy.core.conf import mpy as _mpy
def add_point_data_node_sets(point_data, nodes, *, extra_points=0):
    """Add the information if a node is part of a set to the point_data vector
    for all nodes in the list 'nodes'.

    For every geometry set that at least one of the given nodes belongs to, a
    field ``<geometry_name>_set_<formatted i_global>`` is added to
    ``point_data``. Its value is the tuple ``(data_vector, mpy.vtk_type.int)``,
    where ``data_vector`` holds 1 for nodes in the set and
    ``mpy.vtk_nan_int`` otherwise.

    The extra_points argument specifies how many additional
    visualization points there are, i.e., points that are not based on
    nodes, but are only used for visualization purposes. These extra points
    are flagged as part of the set only for line geometry sets.

    Args
    ----
    point_data: dict
        Dictionary the new point data fields are written into (modified in place).
    nodes: [Node]
        Nodes whose ``node_sets_link`` entries are evaluated.
    extra_points: int
        Number of additional visualization points appended after the nodes.

    Raises
    ------
    TypeError
        If a geometry set has an unknown geometry type.
    """

    # Collect the geometry sets of the given nodes. dict.fromkeys removes double
    # entries while keeping the first-seen order, so the order of the created
    # point data fields (and thus of the written arrays) is reproducible.
    geometry_set_list = list(
        dict.fromkeys(
            geometry_set for node in nodes for geometry_set in node.node_sets_link
        )
    )

    n_nodes = len(nodes)
    for geometry_set in geometry_set_list:
        # Get the name of the geometry type.
        if geometry_set.geometry_type is _mpy.geo.point:
            geometry_name = "geometry_point"
        elif geometry_set.geometry_type is _mpy.geo.line:
            geometry_name = "geometry_line"
        elif geometry_set.geometry_type is _mpy.geo.surface:
            geometry_name = "geometry_surface"
        elif geometry_set.geometry_type is _mpy.geo.volume:
            geometry_name = "geometry_volume"
        else:
            raise TypeError("The geometry type is wrong!")

        # Check which nodes are connected to the geometry set.
        data_vector = _np.zeros(n_nodes + extra_points)
        for i, node in enumerate(nodes):
            data_vector[i] = (
                1 if geometry_set in node.node_sets_link else _mpy.vtk_nan_int
            )

        # The extra visualization points are only part of line sets.
        data_vector[n_nodes:] = (
            1 if geometry_set.geometry_type is _mpy.geo.line else _mpy.vtk_nan_int
        )

        # Add the data vector.
        set_name = f"{geometry_name}_set_{_mpy.vtk_node_set_format.format(geometry_set.i_global)}"
        point_data[set_name] = (data_vector, _mpy.vtk_type.int)
def _get_data_value_and_type(data):
    """Split a data entry into its value and its VTK value type.

    A data entry is either a plain value, or a tuple whose first element is
    the value and whose second element is the type. Plain values are
    treated as float data.
    """
    if not isinstance(data, tuple):
        # No explicit type given -> default to float.
        return data, _mpy.vtk_type.float

    value = data[0]
    value_type = data[1]
    return value, value_type
def _get_vtk_array_type(data):
    """Map the value type of a VTK data array to the matching meshpy type.

    VTK ``"int"`` arrays map to ``mpy.vtk_type.int`` and ``"double"`` arrays
    map to ``mpy.vtk_type.float``; any other type raises a ValueError.
    """
    type_name = data.GetDataTypeAsString()
    known_types = {
        "int": _mpy.vtk_type.int,
        "double": _mpy.vtk_type.float,
    }
    if type_name in known_types:
        return known_types[type_name]
    raise ValueError(f'Got unexpected type "{type_name}"!')
class VTKWriter:
    """A class that manages VTK cells and data and can also create them."""

    def __init__(self):
        # Initialize VTK objects.
        self.points = _vtk.vtkPoints()
        self.points.SetDataTypeToDouble()
        self.grid = _vtk.vtkUnstructuredGrid()

        # Link points to grid, so points inserted into self.points are the
        # points of self.grid.
        self.grid.SetPoints(self.points)

        # Container for output data, keyed by (mpy.vtk_geo, mpy.vtk_tensor).
        # Each value maps a field name to its VTK data array.
        self.data = {}
        for key_geom in _mpy.vtk_geo:
            for key_tensor in _mpy.vtk_tensor:
                self.data[key_geom, key_tensor] = {}

    def add_points(self, points, *, point_data=None):
        """Add points to the data stored in this object.

        Args
        ----
        points: [3d vector]
            Coordinates of the points to add.
        point_data: dic
            A dictionary containing data that will be added for the newly added points.
            If a field exists in the global data but not in the one added here, that field
            will be set to mpy.vtk_nan for the newly added points.

        Return:
        ----
        indices: [int]
            A list with the global indices of the added points.

        Raises
        ------
        IndexError
            If the length of a point data field differs from the number of points.
        """

        n_points = len(points)

        # Check if point data containers are of the correct size.
        if point_data is not None:
            for key, item_value in point_data.items():
                value, _data_type = _get_data_value_and_type(item_value)
                if len(value) != n_points:
                    raise IndexError(
                        f"The length of coordinates is {n_points}, "
                        f"the length of {key} is {len(value)}, does not match!"
                    )

        # Add point data. This has to happen before the points are inserted, as
        # the current number of grid points is used to back-fill new fields.
        self._add_data(point_data, _mpy.vtk_geo.point, n_new_items=n_points)

        # Add point coordinates to the global list of coordinates.
        n_grid_points = self.points.GetNumberOfPoints()
        for point in points:
            self.points.InsertNextPoint(*point)

        return _np.arange(n_grid_points, n_grid_points + n_points, dtype=int)

    def add_cell(self, cell_type, topology, *, cell_data=None):
        """Create a cell and add it to the global array.

        Args
        ----
        cell_type: VTK_type
            Type of cell that will be created.
        topology: [int]
            The connectivity between the cell and the global points.
        cell_data: dic
            A dictionary containing data that will be added for the newly added cell.
            If a field exists in the global data but not in the one added here, that field
            will be set to mpy.vtk_nan for the newly added cell.
        """

        # Add the data entries. This has to happen before the cell is inserted,
        # as the current number of cells is used to back-fill new fields.
        self._add_data(cell_data, _mpy.vtk_geo.cell)

        # Create the cell and set the connectivity.
        geometry_item = cell_type()
        point_ids = geometry_item.GetPointIds()
        point_ids.SetNumberOfIds(len(topology))
        for i_local, i_global in enumerate(topology):
            point_ids.SetId(i_local, i_global)

        # Add to global cells.
        self.grid.InsertNextCell(geometry_item.GetCellType(), point_ids)

    def _add_data(self, data_container, vtk_geom_type, *, n_new_items=1):
        """Add a data container to the existing global data container of this
        object.

        Fields in data_container that do not exist yet are created and filled
        with NaN values for all previously added items. Then every global field
        of this geometry type gets the given values. Fields that are missing in
        data_container get NaN values for the new items.

        Args
        ----
        data_container: see self.add_cell
        vtk_geom_type: mpy.vtk_geo
            Type of data container that is added
        n_new_items: int
            Number of new points added (cells always add exactly one item). This is
            needed to fill up data fields that are in the global data but not in the
            one that is added.

        Raises
        ------
        ValueError
            If a field is given with a value type (int / float) or tensor type
            (scalar / vector) that differs from the already existing field, or if
            point data mixes scalar and vector entries.
        """
        if data_container is None:
            data_container = {}
        is_cell = vtk_geom_type == _mpy.vtk_geo.cell

        # Number of already existing items; new fields are back-filled for these.
        if is_cell:
            n_items = self.grid.GetNumberOfCells()
        else:
            n_items = self.grid.GetNumberOfPoints()

        # Check if the data fields already exist. If not, add them and fill in
        # empty values for the previous entries.
        for key, item_value in data_container.items():
            # Get the data and the value type (int or float).
            value, data_type = _get_data_value_and_type(item_value)

            # Get the tensor type (scalar or vector).
            if is_cell:
                vtk_tensor_type = self._get_vtk_data_type(value)
            else:
                item_tensor_types = [self._get_vtk_data_type(item) for item in value]
                if not item_tensor_types:
                    # No points are added, so there is neither a type to infer
                    # nor a value to store for this field.
                    continue
                vtk_tensor_type = item_tensor_types[0]
                if any(
                    tensor_type != vtk_tensor_type for tensor_type in item_tensor_types
                ):
                    raise ValueError(
                        f'The data "{key}" mixes scalar and vector entries!'
                    )

            # A field name may only exist for one tensor type, otherwise values
            # would be written to an array with the wrong number of components.
            for other_tensor_type in _mpy.vtk_tensor:
                if (
                    other_tensor_type != vtk_tensor_type
                    and key in self.data[vtk_geom_type, other_tensor_type]
                ):
                    raise ValueError(
                        f'The existing data with the key "{key}" is of tensor type '
                        f'"{other_tensor_type}", but the tensor type you tried to '
                        f'add is "{vtk_tensor_type}"!'
                    )

            existing_data = self.data[vtk_geom_type, vtk_tensor_type]
            if key not in existing_data:
                # Set up the VTK data array.
                if data_type is _mpy.vtk_type.float:
                    data = _vtk.vtkDoubleArray()
                else:
                    data = _vtk.vtkIntArray()
                data.SetName(key)
                if vtk_tensor_type == _mpy.vtk_tensor.scalar:
                    data.SetNumberOfComponents(1)
                else:
                    data.SetNumberOfComponents(3)

                # Add the empty values for all previous cells / points.
                for _ in range(n_items):
                    self._add_single_data_item(data, vtk_tensor_type)
                existing_data[key] = data
            else:
                # The field already exists, check that the value type matches.
                data_array = existing_data[key]
                if _get_vtk_array_type(data_array) != data_type:
                    raise ValueError(
                        f'The existing data with the key "{key}"'
                        f' is of type "{data_array.GetDataTypeAsString()}", but the'
                        f' type you tried to add is "{data_type}"!'
                    )

        # Add to global data. If a field is not given here, an empty value is added.
        n_missing = 1 if is_cell else n_new_items
        for key_tensor in _mpy.vtk_tensor:
            for key, data_array in self.data[vtk_geom_type, key_tensor].items():
                if key in data_container:
                    data_values, _ = _get_data_value_and_type(data_container[key])
                    # A cell gets a single value; points get one value per point.
                    new_values = [data_values] if is_cell else data_values
                    for item in new_values:
                        self._add_single_data_item(
                            data_array, key_tensor, non_zero_data=item
                        )
                else:
                    for _ in range(n_missing):
                        self._add_single_data_item(data_array, key_tensor)

    @staticmethod
    def _get_vtk_data_type(data):
        """Return the tensor type of data.

        Lists / arrays of length 3 are vectors, numbers are scalars.

        Raises
        ------
        IndexError
            If a list / array does not have length 3.
        ValueError
            If data is neither a list / array nor a number.
        """

        if isinstance(data, (list, _np.ndarray)):
            if len(data) == 3:
                return _mpy.vtk_tensor.vector
            raise IndexError(
                f"Only 3d vectors are implemented yet! Got len(data) = {len(data)}"
            )
        elif isinstance(data, _numbers.Number):
            return _mpy.vtk_tensor.scalar

        raise ValueError(f"Data {data} did not match any expected case!")

    @staticmethod
    def _add_single_data_item(data, vtk_tensor_type, non_zero_data=None):
        """Add one tuple to a VTK data array.

        If non_zero_data is None, the NaN value matching the array's value type
        (mpy.vtk_nan_int or mpy.vtk_nan_float) is inserted instead.
        """

        if _get_vtk_array_type(data) == _mpy.vtk_type.int:
            nan_value = _mpy.vtk_nan_int
        else:
            # _get_vtk_array_type only returns int or float (or raises).
            nan_value = _mpy.vtk_nan_float

        if vtk_tensor_type == _mpy.vtk_tensor.scalar:
            if non_zero_data is None:
                data.InsertNextTuple1(nan_value)
            else:
                data.InsertNextTuple1(non_zero_data)
        else:
            if non_zero_data is None:
                data.InsertNextTuple3(nan_value, nan_value, nan_value)
            else:
                data.InsertNextTuple3(
                    non_zero_data[0], non_zero_data[1], non_zero_data[2]
                )

    def complete_data(self):
        """Add the stored data to the vtk grid."""
        for (key_geom, _key_data), value in self.data.items():
            for vtk_data in value.values():
                if key_geom == _mpy.vtk_geo.cell:
                    self.grid.GetCellData().AddArray(vtk_data)
                else:
                    self.grid.GetPointData().AddArray(vtk_data)

    def write_vtk(self, filepath, *, binary=True):
        """Write the VTK geometry and data to a file.

        Args
        ----
        filepath: str or os.PathLike
            Path to output file. The file extension should be vtu.
        binary: bool
            If the data should be written encoded in binary or in human readable text.

        Raises
        ------
        ValueError
            If the directory of the output file does not exist.
        """

        # VTK expects a plain string path.
        filepath = _os.fspath(filepath)

        # Check if directory for file exists. An empty directory component means
        # the file is written to the current working directory.
        file_directory = _os.path.dirname(filepath)
        if file_directory and not _os.path.isdir(file_directory):
            raise ValueError(f"Directory {file_directory} does not exist!")

        # Initialize VTK writer.
        writer = _vtk.vtkXMLUnstructuredGridWriter()

        # Set the ascii flag.
        if not binary:
            writer.SetDataModeToAscii()

        # Check the file extension.
        _filename, file_extension = _os.path.splitext(filepath)
        if file_extension.lower() != ".vtu":
            _warnings.warn(f'The extension should be "vtu", got {file_extension}!')

        # Write geometry and data to file.
        writer.SetFileName(filepath)
        writer.SetInputData(self.grid)
        writer.Write()