Coverage for src/meshpy/core/vtk_writer.py: 95%

152 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-28 04:21 +0000

1# The MIT License (MIT) 

2# 

3# Copyright (c) 2018-2025 MeshPy Authors 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

21# THE SOFTWARE. 

22"""This module provides a class that is used to write VTK files.""" 

23 

24import numbers as _numbers 

25import os as _os 

26import warnings as _warnings 

27 

28import numpy as _np 

29import vtk as _vtk 

30 

31from meshpy.core.conf import mpy as _mpy 

32 

33 

def add_point_data_node_sets(point_data, nodes, *, extra_points=0):
    """Add the information if a node is part of a set to the point_data vector
    for all nodes in the list 'nodes'.

    For every geometry set that at least one of the given nodes is linked to,
    an int field is added to ``point_data``. The entry of a node is 1 if the
    node is part of the set and ``mpy.vtk_nan_int`` otherwise.

    The extra_points argument specifies how many additional
    visualization points there are, i.e., points that are not based on
    nodes, but are only used for visualization purposes. These entries are
    appended after the node entries and are marked as part of the set only
    for line geometry sets.

    Args
    ----
    point_data: dict
        Point data container, modified in place. Each new entry is a tuple
        ``(data_vector, mpy.vtk_type.int)``.
    nodes: list
        Nodes whose ``node_sets_link`` attribute lists the geometry sets they
        belong to.
    extra_points: int
        Number of additional visualization points after the nodes.

    Raises
    ------
    TypeError
        If a geometry set has an unknown geometry type.
    """
    geometry_names = {
        _mpy.geo.point: "geometry_point",
        _mpy.geo.line: "geometry_line",
        _mpy.geo.surface: "geometry_surface",
        _mpy.geo.volume: "geometry_volume",
    }

    # Collect the geometry sets of the given nodes without duplicates. A dict
    # keeps the first-seen order, so the fields are always added in the same
    # order (a set iterates in hash order, which for objects hashed by
    # identity is not reproducible between runs).
    geometry_sets = dict.fromkeys(
        geometry_set for node in nodes for geometry_set in node.node_sets_link
    )

    n_nodes = len(nodes)
    for geometry_set in geometry_sets:
        # Get the name of the geometry type.
        geometry_name = geometry_names.get(geometry_set.geometry_type)
        if geometry_name is None:
            raise TypeError("The geometry type is wrong!")

        # Check which nodes are connected to the geometry set.
        data_vector = _np.zeros(n_nodes + extra_points)
        for i, node in enumerate(nodes):
            data_vector[i] = (
                1 if geometry_set in node.node_sets_link else _mpy.vtk_nan_int
            )

        # The extra visualization points are only marked for line sets.
        data_vector[n_nodes:] = (
            1 if geometry_set.geometry_type is _mpy.geo.line else _mpy.vtk_nan_int
        )

        # Add the data vector.
        set_name = f"{geometry_name}_set_{_mpy.vtk_node_set_format.format(geometry_set.i_global)}"
        point_data[set_name] = (data_vector, _mpy.vtk_type.int)

81 

82 

83def _get_data_value_and_type(data): 

84 """Return the data and its type if one was given. 

85 

86 The default type, if none was given is float. 

87 """ 

88 if isinstance(data, tuple): 

89 return data[0], data[1] 

90 else: 

91 return data, _mpy.vtk_type.float 

92 

93 

94def _get_vtk_array_type(data): 

95 """Return the corresponding meshpy type.""" 

96 data_type = data.GetDataTypeAsString() 

97 if data_type == "int": 

98 return _mpy.vtk_type.int 

99 elif data_type == "double": 

100 return _mpy.vtk_type.float 

101 raise ValueError(f'Got unexpected type "{data_type}"!') 

102 

103 

class VTKWriter:
    """A class that manages VTK cells and data and can also create them.

    Points and cells are accumulated in a ``vtkUnstructuredGrid``. Data fields
    are kept in ``self.data`` (keyed by geometry type and tensor type) and are
    only attached to the grid by ``complete_data``.
    """

    def __init__(self):
        # Initialize VTK objects.
        self.points = _vtk.vtkPoints()
        self.points.SetDataTypeToDouble()
        self.grid = _vtk.vtkUnstructuredGrid()

        # Link points to grid.
        self.grid.SetPoints(self.points)

        # Container for output data: one dictionary {field name: VTK array}
        # for every combination of geometry type and tensor type.
        self.data = {}
        for key1 in _mpy.vtk_geo:
            for key2 in _mpy.vtk_tensor:
                self.data[key1, key2] = {}

    def add_points(self, points, *, point_data=None):
        """Add points to the data stored in this object.

        Args
        ----
        points: [3d vector]
            Coordinates of the points to add.
        point_data: dic
            A dictionary containing data that will be added for the newly added points.
            Each value is either the per-point data or a tuple ``(data, mpy.vtk_type)``,
            the data must have one entry per point. If a field exists in the global
            data but not in the one added here, that field will be set to mpy.vtk_nan
            for the newly added points.

        Return:
        ----
        indices: np.ndarray(int)
            The global indices of the added points.

        Raises
        ------
        IndexError
            If the length of a point data field does not match the number of points.
        ValueError
            If a point data field conflicts with already stored data.
        """
        n_points = len(points)

        # Check if point data containers are of the correct size.
        if point_data is not None:
            for key, item_value in point_data.items():
                value, _data_type = _get_data_value_and_type(item_value)
                if len(value) != n_points:
                    raise IndexError(
                        f"The length of coordinates is {n_points}, "
                        f"the length of {key} is {len(value)}, does not match!"
                    )

        # Add point data. This has to happen before the coordinates are inserted,
        # as _add_data uses the current number of grid points to back-fill new fields.
        self._add_data(point_data, _mpy.vtk_geo.point, n_new_items=n_points)

        # Add the coordinates to the global list of coordinates.
        n_grid_points = self.points.GetNumberOfPoints()
        for point in points:
            self.points.InsertNextPoint(*point)

        return _np.arange(n_grid_points, n_grid_points + n_points, dtype=int)

    def add_cell(self, cell_type, topology, *, cell_data=None):
        """Create a cell and add it to the global array.

        Args
        ----
        cell_type: VTK_type
            VTK cell class (e.g. ``vtk.vtkLine``) of the cell that will be created.
        topology: [int]
            The connectivity between the cell and the global points.
        cell_data: dic
            A dictionary containing data that will be added for the newly added cell.
            If a field exists in the global data but not in the one added here, that field
            will be set to mpy.vtk_nan for the newly added cell.
        """
        # Add the data entries (before the cell, so new fields are back-filled
        # for the already existing cells only).
        self._add_data(cell_data, _mpy.vtk_geo.cell)

        # Create the cell.
        geometry_item = cell_type()
        geometry_item.GetPointIds().SetNumberOfIds(len(topology))

        # Set the connectivity.
        for i_local, i_global in enumerate(topology):
            geometry_item.GetPointIds().SetId(i_local, i_global)

        # Add to global cells.
        self.grid.InsertNextCell(
            geometry_item.GetCellType(), geometry_item.GetPointIds()
        )

    def _add_data(self, data_container, vtk_geom_type, *, n_new_items=1):
        """Add a data container to the existing global data container of this
        object.

        Args
        ----
        data_container: see self.add_cell
        vtk_geom_type: mpy.vtk_geo
            Type of data container that is added
        n_new_items: int
            Number of new items added. This is needed to fill up data fields that are in the
            global data but not in the one that is added. Only used for point data, for
            cell data exactly one item is added.

        Raises
        ------
        ValueError
            If a field already exists with a different data type (int / float) or a
            different tensor type, or if the entries of a point data field have
            different tensor types. The stored data is not modified in that case.
        """
        if data_container is None:
            data_container = {}

        is_cell = vtk_geom_type == _mpy.vtk_geo.cell

        # Validate all given fields before anything is modified, so that an
        # invalid field does not leave partially created arrays behind.
        new_fields = []
        for key, item_value in data_container.items():
            # Get the data and the value type (int or float).
            value, data_type = _get_data_value_and_type(item_value)
            vtk_tensor_type = self._get_field_tensor_type(value, is_cell, key)
            if vtk_tensor_type is None:
                # Empty point data: nothing is added and the tensor type is unknown.
                continue

            existing_tensor_type = self._find_field_tensor_type(vtk_geom_type, key)
            if existing_tensor_type is None:
                new_fields.append((key, data_type, vtk_tensor_type))
            elif existing_tensor_type != vtk_tensor_type:
                # A second array with the same name would silently replace the
                # first one once both are added to the VTK grid.
                raise ValueError(
                    f'The existing data with the key "{key}" is of tensor type '
                    f'"{existing_tensor_type}", but the tensor type you tried to add '
                    f'is "{vtk_tensor_type}"!'
                )
            else:
                # The field exists, check that it has the same data type.
                data_array = self.data[vtk_geom_type, vtk_tensor_type][key]
                if _get_vtk_array_type(data_array) != data_type:
                    raise ValueError(
                        f'The existing data with the key "{key}"'
                        f' is of type "{data_array.GetDataTypeAsString()}", but the type you tried to add'
                        f' is "{data_type}"!'
                    )

        # Create the new fields and add empty values for all previous cells / points.
        if is_cell:
            n_items = self.grid.GetNumberOfCells()
        else:
            n_items = self.grid.GetNumberOfPoints()
        for key, data_type, vtk_tensor_type in new_fields:
            if data_type is _mpy.vtk_type.float:
                data = _vtk.vtkDoubleArray()
            else:
                data = _vtk.vtkIntArray()
            data.SetName(key)
            data.SetNumberOfComponents(
                1 if vtk_tensor_type == _mpy.vtk_tensor.scalar else 3
            )
            for _ in range(n_items):
                self._add_single_data_item(data, vtk_tensor_type)
            self.data[vtk_geom_type, vtk_tensor_type][key] = data

        # Add to global data. Fields that are not given here get empty values.
        for key_tensor in _mpy.vtk_tensor:
            for key, data_array in self.data[vtk_geom_type, key_tensor].items():
                if key in data_container:
                    data_values, _ = _get_data_value_and_type(data_container[key])
                    if is_cell:
                        self._add_single_data_item(
                            data_array, key_tensor, non_zero_data=data_values
                        )
                    else:
                        for item in data_values:
                            self._add_single_data_item(
                                data_array, key_tensor, non_zero_data=item
                            )
                else:
                    n_empty = 1 if is_cell else n_new_items
                    for _ in range(n_empty):
                        self._add_single_data_item(data_array, key_tensor)

    def _get_field_tensor_type(self, value, is_cell, key):
        """Return the tensor type of a data field, or None for empty point data.

        Cell data holds a single item. For point data every item is checked and
        all items must have the same tensor type.
        """
        if is_cell:
            return self._get_vtk_data_type(value)
        tensor_types = [self._get_vtk_data_type(item) for item in value]
        if not tensor_types:
            return None
        if any(tensor_type != tensor_types[0] for tensor_type in tensor_types):
            raise ValueError(
                f'The entries of the point data "{key}" have different tensor types!'
            )
        return tensor_types[0]

    def _find_field_tensor_type(self, vtk_geom_type, key):
        """Return the tensor type under which the field `key` is stored, or None."""
        for key_tensor in _mpy.vtk_tensor:
            if key in self.data[vtk_geom_type, key_tensor]:
                return key_tensor
        return None

    @staticmethod
    def _get_vtk_data_type(data):
        """Return the type of data.

        Check if data matches an expected case: a list / array of length 3 is a
        vector, a number is a scalar.

        Raises
        ------
        IndexError
            For a list / array whose length is not 3.
        ValueError
            For any other kind of data.
        """
        if isinstance(data, (list, _np.ndarray)):
            if len(data) == 3:
                return _mpy.vtk_tensor.vector
            raise IndexError(
                f"Only 3d vectors are implemented yet! Got len(data) = {len(data)}"
            )
        elif isinstance(data, _numbers.Number):
            return _mpy.vtk_tensor.scalar

        raise ValueError(f"Data {data} did not match any expected case!")

    @staticmethod
    def _add_single_data_item(data, vtk_tensor_type, non_zero_data=None):
        """Append one tuple to a VTK data array.

        The tuple is `non_zero_data` if given, otherwise the NaN placeholder
        matching the data type of the array.
        """
        # _get_vtk_array_type raises for anything that is not int or float.
        if _get_vtk_array_type(data) == _mpy.vtk_type.int:
            nan_value = _mpy.vtk_nan_int
        else:
            nan_value = _mpy.vtk_nan_float

        if vtk_tensor_type == _mpy.vtk_tensor.scalar:
            if non_zero_data is None:
                data.InsertNextTuple1(nan_value)
            else:
                data.InsertNextTuple1(non_zero_data)
        else:
            if non_zero_data is None:
                data.InsertNextTuple3(nan_value, nan_value, nan_value)
            else:
                data.InsertNextTuple3(
                    non_zero_data[0], non_zero_data[1], non_zero_data[2]
                )

    def complete_data(self):
        """Add the stored data to the vtk grid."""
        for (key_geom, _key_data), value in self.data.items():
            for vtk_data in value.values():
                if key_geom == _mpy.vtk_geo.cell:
                    self.grid.GetCellData().AddArray(vtk_data)
                else:
                    self.grid.GetPointData().AddArray(vtk_data)

    def write_vtk(self, filepath, *, binary=True):
        """Write the VTK geometry and data to a file.

        Args
        ----
        filepath: str or path-like
            Path to output file. The file extension should be vtu.
        binary: bool
            If the data should be written encoded in binary or in human readable text.

        Raises
        ------
        ValueError
            If the directory of the output file does not exist.
        """
        filepath = _os.fspath(filepath)

        # Check if the directory for the file exists. A bare file name has an
        # empty directory part and refers to the current working directory.
        file_directory = _os.path.dirname(filepath)
        if file_directory and not _os.path.isdir(file_directory):
            raise ValueError(f"Directory {file_directory} does not exist!")

        # Initialize VTK writer.
        writer = _vtk.vtkXMLUnstructuredGridWriter()

        # Set the ascii flag.
        if not binary:
            writer.SetDataModeToAscii()

        # Check the file extension.
        _filename, file_extension = _os.path.splitext(filepath)
        if file_extension.lower() != ".vtu":
            _warnings.warn(
                f'The extension should be "vtu", got {file_extension}!', stacklevel=2
            )

        # Write geometry and data to file.
        writer.SetFileName(filepath)
        writer.SetInputData(self.grid)
        writer.Write()