diff --git a/docs/src/devdocs/dofhandler.md b/docs/src/devdocs/dofhandler.md index 0210759ed7..531def5284 100644 --- a/docs/src/devdocs/dofhandler.md +++ b/docs/src/devdocs/dofhandler.md @@ -19,6 +19,7 @@ Ferrite.SurfaceOrientationInfo The main entry point for dof distribution is [`__close!`](@ref). ```@docs +Ferrite.get_grid Ferrite.find_field(dh::DofHandler, field_name::Symbol) Ferrite._find_field(fh::FieldHandler, field_name::Symbol) Ferrite._close_fieldhandler! diff --git a/src/Dofs/ConstraintHandler.jl b/src/Dofs/ConstraintHandler.jl index 3ea8faafab..aeb51ed467 100644 --- a/src/Dofs/ConstraintHandler.jl +++ b/src/Dofs/ConstraintHandler.jl @@ -315,14 +315,14 @@ function _local_face_dofs_for_bc(interpolation, field_dim, components, offset, b return local_face_dofs, local_face_dofs_offset end -# Dirichlet on nodeset -function _add!(ch::ConstraintHandler, dbc::Dirichlet, bcnodes::Set{Int}, interpolation::Interpolation, field_dim::Int, offset::Int, bcvalue::BCValues, cellset::Set{Int}=Set{Int}(1:getncells(ch.dh.grid))) - if interpolation !== default_interpolation(typeof(ch.dh.grid.cells[first(cellset)])) +function _add!(ch::ConstraintHandler, dbc::Dirichlet, bcnodes::Set{Int}, interpolation::Interpolation, field_dim::Int, offset::Int, bcvalue::BCValues, cellset::Set{Int}=Set{Int}(1:getncells(get_grid(ch.dh)))) + grid = get_grid(ch.dh) + if interpolation !== default_interpolation(getcelltype(grid, first(cellset))) @warn("adding constraint to nodeset is not recommended for sub/super-parametric approximations.") end ncomps = length(dbc.components) - nnodes = getnnodes(ch.dh.grid) + nnodes = getnnodes(grid) interpol_points = getnbasefunctions(interpolation) node_dofs = zeros(Int, ncomps, nnodes) visited = falses(nnodes) @@ -449,7 +449,7 @@ function _update!(inhomogeneities::Vector{Float64}, f::Function, ::Set{Int}, fie dofmapping::Dict{Int,Int}, dofcoefficients::Vector{Union{Nothing,DofCoefficients{T}}}, time::Real) where T counter = 1 for nodenumber in nodeidxs - x = dh.grid.nodes[nodenumber].x + x = getcoordinates(getnodes(get_grid(dh), nodenumber)) bc_value = f(x, time) @assert length(bc_value) == length(components) for v in bc_value @@ -477,13 +477,13 @@ function WriteVTK.vtk_point_data(vtkfile, ch::ConstraintHandler) for field in unique_fields nd = getfielddim(ch.dh, field) - data = zeros(Float64, nd, getnnodes(ch.dh.grid)) + data = zeros(Float64, nd, getnnodes(get_grid(ch.dh))) for dbc in ch.dbcs dbc.field_name != field && continue if eltype(dbc.faces) <: BoundaryIndex functype = boundaryfunction(eltype(dbc.faces)) for (cellidx, faceidx) in dbc.faces - for facenode in functype(ch.dh.grid.cells[cellidx])[faceidx] + for facenode in functype(getcells(get_grid(ch.dh), cellidx))[faceidx] for component in dbc.components data[component, facenode] = 1 end @@ -823,7 +823,7 @@ function add!(ch::ConstraintHandler, dbc::Dirichlet) dbc.field_name in fh.field_names || continue # Compute the intersection between dbc.set and the cellset of this # FieldHandler and skip if the set is empty - filtered_set = filter_dbc_set(ch.dh.grid, fh.cellset, dbc.faces) + filtered_set = filter_dbc_set(get_grid(ch.dh), fh.cellset, dbc.faces) isempty(filtered_set) && continue # Fetch information about the field on this FieldHandler field_idx = find_field(fh, dbc.field_name) @@ -844,7 +844,7 @@ function add!(ch::ConstraintHandler, dbc::Dirichlet) # BCValues are just dummy for nodesets so set to FaceIndex EntityType = FaceIndex end - CT = getcelltype(ch.dh.grid, first(fh.cellset)) # Same celltype enforced in FieldHandler constructor + CT = getcelltype(get_grid(ch.dh), fh) # Same celltype enforced in FieldHandler constructor bcvalues = BCValues(interpolation, default_interpolation(CT), EntityType) # Recreate the Dirichlet(...) struct with the filtered set and call internal add! filtered_dbc = Dirichlet(dbc.field_name, filtered_set, dbc.f, components) @@ -867,6 +867,7 @@ function filter_dbc_set(::AbstractGrid, fhset::AbstractSet{Int}, dbcset::Abstrac end return ret end + function filter_dbc_set(grid::AbstractGrid, fhset::AbstractSet{Int}, dbcset::AbstractSet{Int}) ret = empty(dbcset) nodes_in_fhset = Set{Int}() @@ -939,7 +940,7 @@ function add!(ch::ConstraintHandler, pdbc::PeriodicDirichlet) is_legacy = !isempty(pdbc.face_pairs) && isempty(pdbc.face_map) if is_legacy for (mset, iset) in pdbc.face_pairs - collect_periodic_faces!(pdbc.face_map, ch.dh.grid, mset, iset, identity) # TODO: Better transform + collect_periodic_faces!(pdbc.face_map, get_grid(ch.dh), mset, iset, identity) # TODO: Better transform end end field_idx = find_field(ch.dh, pdbc.field_name) @@ -980,9 +981,8 @@ end function _add!(ch::ConstraintHandler, pdbc::PeriodicDirichlet, interpolation::Interpolation, field_dim::Int, offset::Int, is_legacy::Bool, rotation_matrix::Union{Matrix{T},Nothing}, ::Type{dof_map_t}, iterator_f::F) where {T, dof_map_t, F <: Function} - grid = ch.dh.grid + grid = get_grid(ch.dh) face_map = pdbc.face_map - Tx = typeof(first(ch.dh.grid.nodes).x) # Vec{D,T} # Indices of the local dofs for the faces local_face_dofs, local_face_dofs_offset = @@ -1054,6 +1054,7 @@ function _add!(ch::ConstraintHandler, pdbc::PeriodicDirichlet, interpolation::In "Dirichlet boundary condition on the relevant nodeset.", :PeriodicDirichlet) all_node_idxs = Set{Int}() + Tx = get_coordinate_type(grid) min_x = Tx(i -> typemax(eltype(Tx))) max_x = Tx(i -> typemin(eltype(Tx))) for facepair in face_map, faceidx in (facepair.mirror, facepair.image) @@ -1061,14 +1062,14 @@ function _add!(ch::ConstraintHandler, pdbc::PeriodicDirichlet, interpolation::In nodes = faces(grid.cells[cellidx])[faceidx] union!(all_node_idxs, nodes) for n in nodes - x = grid.nodes[n].x + x = getcoordinates(getnodes(grid, n)) min_x = Tx(i -> min(min_x[i], x[i])) max_x = Tx(i -> max(max_x[i], x[i])) end end all_node_idxs_v = collect(all_node_idxs) points = construct_cornerish(min_x, max_x) - tree = KDTree(Tx[grid.nodes[i].x for i in all_node_idxs_v]) + tree = KDTree(Tx[getcoordinates(getnodes(grid, i)) for i in all_node_idxs_v]) idxs, _ = NearestNeighbors.nn(tree, points) corner_set = Set{Int}(all_node_idxs_v[i] for i in idxs) @@ -1342,12 +1343,12 @@ function __collect_periodic_faces_tree!(face_map::Vector{PeriodicFacePair}, grid if length(mset) != length(mset) error("different number of faces in mirror and image set") end - Tx = typeof(first(grid.nodes).x) + Tx = get_coordinate_type(grid) mirror_mean_x = Tx[] for (c, f) in mset fn = faces(grid.cells[c])[f] - push!(mirror_mean_x, sum(grid.nodes[i].x for i in fn) / length(fn)) + push!(mirror_mean_x, sum(getcoordinates(getnodes(grid,i)) for i in fn) / length(fn)) end # Same dance for the image @@ -1355,7 +1356,7 @@ function __collect_periodic_faces_tree!(face_map::Vector{PeriodicFacePair}, grid for (c, f) in iset fn = faces(grid.cells[c])[f] # Apply transformation to all coordinates - push!(image_mean_x, sum(transformation(grid.nodes[i].x)::Tx for i in fn) / length(fn)) + push!(image_mean_x, sum(transformation(getcoordinates(getnodes(grid,i)))::Tx for i in fn) / length(fn)) end # Use KDTree to find closest face @@ -1432,16 +1433,16 @@ function __periodic_options(::T) where T <: Vec{3} end function __outward_normal(grid::Grid{2}, nodes, transformation::F=identity) where F <: Function - n1::Vec{2} = transformation(grid.nodes[nodes[1]].x) - n2::Vec{2} = transformation(grid.nodes[nodes[2]].x) + n1::Vec{2} = transformation(getcoordinates(getnodes(grid, nodes[1]))) + n2::Vec{2} = transformation(getcoordinates(getnodes(grid, nodes[2]))) n = Vec{2}((n2[2] - n1[2], - n2[1] + n1[1])) return n / norm(n) end function __outward_normal(grid::Grid{3}, nodes, transformation::F=identity) where F <: Function - n1::Vec{3} = transformation(grid.nodes[nodes[1]].x) - n2::Vec{3} = transformation(grid.nodes[nodes[2]].x) - n3::Vec{3} = transformation(grid.nodes[nodes[3]].x) + n1::Vec{3} = transformation(getcoordinates(getnodes(grid, nodes[1]))) + n2::Vec{3} = transformation(getcoordinates(getnodes(grid, nodes[2]))) + n3::Vec{3} = transformation(getcoordinates(getnodes(grid, nodes[3]))) n = (n3 - n2) × (n1 - n2) return n / norm(n) end @@ -1467,10 +1468,10 @@ function __check_periodic_faces(grid::Grid, fi::FaceIndex, fj::FaceIndex, known_ end # 2. Find the periodic direction using the vector between the midpoint of the faces - xmi = sum(grid.nodes[i].x for i in nodes_i) / length(nodes_i) - xmj = sum(grid.nodes[i].x for i in nodes_j) / length(nodes_j) + xmi = sum(getcoordinates(getnodes(grid, i)) for i in nodes_i) / length(nodes_i) + xmj = sum(getcoordinates(getnodes(grid, j)) for j in nodes_j) / length(nodes_j) xmij = xmj - xmi - h = 2 * norm(xmj - grid.nodes[nodes_j[1]].x) # Approximate element size + h = 2 * norm(xmj - getcoordinates(getnodes(grid, nodes_j[1]))) # Approximate element size TOLh = TOL * h found = false local len @@ -1486,11 +1487,11 @@ function __check_periodic_faces(grid::Grid, fi::FaceIndex, fj::FaceIndex, known_ # 3. Check that the first node of fj have a corresponding node in fi # In this method faces are mirrored (opposite normal vectors) so reverse the nodes nodes_i = circshift_tuple(reverse(nodes_i), 1) - xj = grid.nodes[nodes_j[1]].x + xj = getcoordinates(getnodes(grid, nodes_j[1])) node_rot = 0 found = false for i in eachindex(nodes_i) - xi = grid.nodes[nodes_i[i]].x + xi = getcoordinates(getnodes(grid, nodes_i[i])) xij = xj - xi if norm(xij - xmij) < TOLh found = true @@ -1502,8 +1503,8 @@ function __check_periodic_faces(grid::Grid, fi::FaceIndex, fj::FaceIndex, known_ # 4. Check the remaining nodes for the same criteria, now with known node_rot for j in 2:length(nodes_j) - xi = grid.nodes[nodes_i[mod1(j + node_rot, end)]].x - xj = grid.nodes[nodes_j[j]].x + xi = getcoordinates(getnodes(grid, nodes_i[mod1(j + node_rot, end)])) + xj = getcoordinates(getnodes(grid, nodes_j[j])) xij = xj - xi if norm(xij - xmij) >= TOLh return nothing @@ -1549,14 +1550,14 @@ function __check_periodic_faces_f(grid::Grid, fi::FaceIndex, fj::FaceIndex, xmi, # 2. Compute the relative rotation xmij = xmj - xmi - h = 2 * norm(xmj - grid.nodes[nodes_j[1]].x) # Approximate element size + h = 2 * norm(xmj - getcoordinates(getnodes(grid, nodes_j[1]))) # Approximate element size TOLh = TOL * h nodes_i = mirror ? circshift_tuple(reverse(nodes_i), 1) : nodes_i # reverse if necessary - xj = transformation(grid.nodes[nodes_j[1]].x) + xj = transformation(getcoordinates(getnodes(grid, nodes_j[1]))) node_rot = 0 found = false for i in eachindex(nodes_i) - xi = grid.nodes[nodes_i[i]].x + xi = getcoordinates(getnodes(grid, nodes_i[i])) xij = xj - xi if norm(xij - xmij) < TOLh found = true diff --git a/src/Dofs/DofHandler.jl b/src/Dofs/DofHandler.jl index f62f56ae73..16945ecd80 100644 --- a/src/Dofs/DofHandler.jl +++ b/src/Dofs/DofHandler.jl @@ -1,5 +1,17 @@ abstract type AbstractDofHandler end +""" + get_grid(dh::AbstractDofHandler) + +Access some grid representation for the dof handler. + +!!! note + This API function is currently not well-defined. It acts as the interface between + distributed assembly and assembly on a single process, because most parts of the + functionality can be handled by only acting on the locally owned cell set. +""" +get_grid(dh::AbstractDofHandler) + """ Field(name::Symbol, interpolation::Interpolation, dim::Int) @@ -51,6 +63,9 @@ mutable struct FieldHandler end end +# Shortcut +@inline getcelltype(grid::AbstractGrid, fh::FieldHandler) = getcelltype(grid, first(fh.cellset)) + """ DofHandler(grid::Grid) @@ -95,6 +110,7 @@ function Base.show(io::IO, ::MIME"text/plain", dh::DofHandler) end isclosed(dh::AbstractDofHandler) = dh.closed[] +get_grid(dh::DofHandler) = dh.grid """ ndofs(dh::AbstractDofHandler) @@ -111,11 +127,11 @@ Return the number of degrees of freedom for the cell with index `cell`. See also [`ndofs`](@ref). """ function ndofs_per_cell(dh::DofHandler, cell::Int=1) - @boundscheck 1 <= cell <= getncells(dh.grid) + @boundscheck 1 <= cell <= getncells(get_grid(dh)) return @inbounds ndofs_per_cell(dh.fieldhandlers[dh.cell_to_fieldhandler[cell]]) end ndofs_per_cell(fh::FieldHandler) = fh.ndofs_per_cell -nnodes_per_cell(dh::DofHandler, cell::Int=1) = nnodes_per_cell(dh.grid, cell) # TODO: deprecate, shouldn't belong to DofHandler any longer +nnodes_per_cell(dh::DofHandler, cell::Int=1) = nnodes_per_cell(get_grid(dh), cell) # TODO: deprecate, shouldn't belong to DofHandler any longer """ celldofs!(global_dofs::Vector{Int}, dh::AbstractDofHandler, i::Int) @@ -144,11 +160,11 @@ end #TODO: perspectively remove in favor of `getcoordinates!(global_coords, grid, i)`? function cellcoords!(global_coords::Vector{Vec{dim,T}}, dh::DofHandler, i::Union{Int, <:AbstractCell}) where {dim,T} - cellcoords!(global_coords, dh.grid, i) + cellcoords!(global_coords, get_grid(dh), i) end function cellnodes!(global_nodes::Vector{Int}, dh::DofHandler, i::Union{Int, <:AbstractCell}) - cellnodes!(global_nodes, dh.grid, i) + cellnodes!(global_nodes, get_grid(dh), i) end """ @@ -188,11 +204,11 @@ Add all fields of the [`FieldHandler`](@ref) `fh` to `dh`. function add!(dh::DofHandler, fh::FieldHandler) # TODO: perhaps check that a field with the same name is the same field? @assert !isclosed(dh) - _check_same_celltype(dh.grid, collect(fh.cellset)) + _check_same_celltype(get_grid(dh), collect(fh.cellset)) _check_cellset_intersections(dh, fh) # the field interpolations should have the same refshape as the cells they are applied to # extract the celltype from the first cell as the celltypes are all equal - cell_type = typeof(dh.grid.cells[first(fh.cellset)]) + cell_type = getcelltype(get_grid(dh), fh) refshape_cellset = getrefshape(default_interpolation(cell_type)) for interpolation in fh.field_interpolations refshape = getrefshape(interpolation) @@ -219,11 +235,11 @@ celltypes, [`add!(dh::DofHandler, fh::FieldHandler)`](@ref) must be used instead function add!(dh::DofHandler, name::Symbol, ip::Interpolation) @assert !isclosed(dh) - celltype = getcelltype(dh.grid) + celltype = getcelltype(get_grid(dh)) @assert isconcretetype(celltype) if length(dh.fieldhandlers) == 0 - cellset = Set(1:getncells(dh.grid)) + cellset = Set(1:getncells(get_grid(dh))) push!(dh.fieldhandlers, FieldHandler(Field[], cellset)) elseif length(dh.fieldhandlers) > 1 error("If you have more than one FieldHandler, you must specify field") @@ -279,7 +295,7 @@ function __close!(dh::DofHandler{dim}) where {dim} # `vertexdict` keeps track of the visited vertices. The first dof added to vertex v is # stored in vertexdict[v]. # TODO: No need to allocate this vector for fields that don't have vertex dofs - vertexdicts = [zeros(Int, getnnodes(dh.grid)) for _ in 1:numfields] + vertexdicts = [zeros(Int, getnnodes(get_grid(dh))) for _ in 1:numfields] # `edgedict` keeps track of the visited edges, this will only be used for a 3D problem. # An edge is uniquely determined by two global vertices, with global direction going @@ -377,7 +393,7 @@ function _close_fieldhandler!(dh::DofHandler{sdim}, fh::FieldHandler, fh_index:: @assert dh.cell_to_fieldhandler[ci] == 0 dh.cell_to_fieldhandler[ci] = fh_index - cell = getcells(dh.grid, ci) + cell = getcells(get_grid(dh), ci) len_cell_dofs_start = length(dh.cell_dofs) dh.cell_dofs_offset[ci] = len_cell_dofs_start + 1 @@ -819,10 +835,10 @@ function _evaluate_at_grid_nodes(dh::DofHandler, u::Vector{T}, fieldname::Symbol # VTK output of solution field (or L2 projected scalar data) n_c = n_components(ip) vtk_dim = n_c == 2 ? 3 : n_c # VTK wants vectors padded to 3D - data = fill(NaN * zero(T), vtk_dim, getnnodes(dh.grid)) + data = fill(NaN * zero(T), vtk_dim, getnnodes(get_grid(dh))) else # Just evalutation at grid nodes - data = fill(NaN * zero(RT), getnnodes(dh.grid)) + data = fill(NaN * zero(RT), getnnodes(get_grid(dh))) end # Loop over the fieldhandlers for fh in dh.fieldhandlers @@ -830,7 +846,7 @@ function _evaluate_at_grid_nodes(dh::DofHandler, u::Vector{T}, fieldname::Symbol field_idx = _find_field(fh, fieldname) field_idx === nothing && continue # Set up CellValues with the local node coords as quadrature points - CT = getcelltype(dh.grid, first(fh.cellset)) + CT = getcelltype(get_grid(dh), fh) ip_geo = default_interpolation(CT) local_node_coords = reference_coordinates(ip_geo) qr = QuadratureRule{getrefshape(ip)}(zeros(length(local_node_coords)), local_node_coords) diff --git a/src/Dofs/apply_analytical.jl b/src/Dofs/apply_analytical.jl index fc03b884bb..351b744cf6 100644 --- a/src/Dofs/apply_analytical.jl +++ b/src/Dofs/apply_analytical.jl @@ -1,13 +1,12 @@ function _default_interpolations(dh::DofHandler) fhs = dh.fieldhandlers - getcelltype(i) = typeof(getcells(dh.grid, first(fhs[i].cellset))) - ntuple(i -> default_interpolation(getcelltype(i)), length(fhs)) + ntuple(i -> default_interpolation(getcelltype(get_grid(dh), fhs[i])), length(fhs)) end """ apply_analytical!( a::AbstractVector, dh::AbstractDofHandler, fieldname::Symbol, - f::Function, cellset=1:getncells(dh.grid)) + f::Function, cellset=1:getncells(get_grid(dh))) Apply a solution `f(x)` by modifying the values in the degree of freedom vector `a` pertaining to the field `fieldname` for all cells in `cellset`. @@ -27,7 +26,7 @@ This function can be used to apply initial conditions for time dependent problem """ function apply_analytical!( a::AbstractVector, dh::DofHandler, fieldname::Symbol, f::Function, - cellset = 1:getncells(dh.grid)) + cellset = 1:getncells(get_grid(dh))) fieldname ∉ getfieldnames(dh) && error("The fieldname $fieldname was not found in the dof handler") ip_geos = _default_interpolations(dh) @@ -38,8 +37,8 @@ function apply_analytical!( ip_fun = getfieldinterpolation(fh, field_idx) field_dim = getfielddim(fh, field_idx) celldofinds = dof_range(fh, fieldname) - set_intersection = if length(cellset) == length(fh.cellset) == getncells(dh.grid) - BitSet(1:getncells(dh.grid)) + set_intersection = if length(cellset) == length(fh.cellset) == getncells(get_grid(dh)) + BitSet(1:getncells(get_grid(dh))) else intersect(BitSet(fh.cellset), BitSet(cellset)) end @@ -52,7 +51,7 @@ function _apply_analytical!( a::AbstractVector, dh::AbstractDofHandler, celldofinds, field_dim, ip_fun::Interpolation{RefShape}, ip_geo::Interpolation, f::Function, cellset) where {dim, RefShape<:AbstractRefShape{dim}} - coords = getcoordinates(dh.grid, first(cellset)) + coords = getcoordinates(get_grid(dh), first(cellset)) ref_points = reference_coordinates(ip_fun) dummy_weights = zeros(length(ref_points)) qr = QuadratureRule{RefShape}(dummy_weights, ref_points) @@ -65,7 +64,7 @@ function _apply_analytical!( length(f(first(coords))) == field_dim || error("length(f(x)) must be equal to dimension of the field ($field_dim)") for cellnr in cellset - getcoordinates!(coords, dh.grid, cellnr) + getcoordinates!(coords, get_grid(dh), cellnr) celldofs!(c_dofs, dh, cellnr) for (i, celldofind) in enumerate(celldofinds) f_dofs[i] = c_dofs[celldofind] diff --git a/src/Export/VTK.jl b/src/Export/VTK.jl index f077ea349b..8ddb5e4cef 100644 --- a/src/Export/VTK.jl +++ b/src/Export/VTK.jl @@ -27,7 +27,7 @@ The keyword arguments are forwarded to `WriteVTK.vtk_grid`, see """ function WriteVTK.vtk_grid(filename::AbstractString, grid::Grid{dim,C,T}; kwargs...) where {dim,C,T} cls = MeshCell[] - for cell in grid.cells + for cell in getcells(grid) celltype = Ferrite.cell_to_vtkcell(typeof(cell)) push!(cls, MeshCell(celltype, nodes_to_vtkorder(cell))) end @@ -35,7 +35,7 @@ function WriteVTK.vtk_grid(filename::AbstractString, grid::Grid{dim,C,T}; kwargs return vtk_grid(filename, coords, cls; kwargs...) end function WriteVTK.vtk_grid(filename::AbstractString, dh::AbstractDofHandler; kwargs...) - vtk_grid(filename, dh.grid; kwargs...) + vtk_grid(filename, get_grid(dh); kwargs...) end function toparaview!(v, x::Vec{D}) where D diff --git a/src/Grid/grid.jl b/src/Grid/grid.jl index db945da170..81d874553c 100644 --- a/src/Grid/grid.jl +++ b/src/Grid/grid.jl @@ -13,10 +13,23 @@ struct Node{dim,T} x::Vec{dim,T} end Node(x::NTuple{dim,T}) where {dim,T} = Node(Vec{dim,T}(x)) + +""" + getcoordinates(::Node) + +Get the value of the node coordinate. +""" getcoordinates(n::Node) = n.x """ - Ferrite.get_coordinate_eltype(::Node) + get_coordinate_type(::Node) + +Get the data type of the the node coordinate. +""" +get_coordinate_type(::Node{dim,T}) where {dim,T} = Vec{dim,T} + +""" + get_coordinate_eltype(::Node) Get the data type of the components of the nodes coordinate. """ @@ -540,6 +553,13 @@ end ########################## # Grid utility functions # ########################## +""" + get_coordinate_type(::AbstractGrid) + +Get the datatype for a single point in the grid. +""" +get_coordinate_type(grid::Grid{dim,C,T}) where {dim,C,T} = Vec{dim,T} # Node is baked into the mesh type. + """ getneighborhood(top::ExclusiveTopology, grid::AbstractGrid, cellidx::CellIndex, include_self=false) getneighborhood(top::ExclusiveTopology, grid::AbstractGrid, faceidx::FaceIndex, include_self=false) @@ -723,7 +743,7 @@ Returns all vertex sets of the grid. """ @inline getvertexsets(grid::AbstractGrid) = grid.vertexsets -n_faces_per_cell(grid::Grid) = nfaces(eltype(grid.cells)) +n_faces_per_cell(grid::Grid) = nfaces(getcelltype(grid)) # Transformations """ diff --git a/src/L2_projection.jl b/src/L2_projection.jl index a41371ef96..7cfd954c42 100644 --- a/src/L2_projection.jl +++ b/src/L2_projection.jl @@ -37,7 +37,7 @@ function L2Projector( grid::AbstractGrid; qr_lhs::QuadratureRule = _mass_qr(func_ip), set = 1:getncells(grid), - geom_ip::Interpolation = default_interpolation(typeof(grid.cells[first(set)])), + geom_ip::Interpolation = default_interpolation(getcelltype(grid, first(set))), ) # TODO: Maybe this should not be allowed? We always assume to project scalar entries. @@ -73,7 +73,7 @@ _mass_qr(ip::VectorizedInterpolation) = _mass_qr(ip.ip) function Base.show(io::IO, ::MIME"text/plain", proj::L2Projector) println(io, typeof(proj)) - println(io, " projection on: ", length(proj.set), "/", getncells(proj.dh.grid), " cells in grid") + println(io, " projection on: ", length(proj.set), "/", getncells(get_grid(proj.dh)), " cells in grid") println(io, " function interpolation: ", proj.func_ip) println(io, " geometric interpolation: ", proj.geom_ip) end @@ -100,7 +100,7 @@ function _assemble_L2_matrix(fe_values, set, dh) celldofs!(cell_dofs, dh, cellnum) fill!(Me, 0) - Xe = getcoordinates(dh.grid, cellnum) + Xe = getcoordinates(get_grid(dh), cellnum) reinit!(fe_values, Xe) ## ∭( v ⋅ u )dΩ @@ -195,7 +195,7 @@ function _project(vars, proj::L2Projector, fe_values::AbstractValues, M::Integer for (ic,cellnum) in enumerate(proj.set) celldofs!(cell_dofs, proj.dh, cellnum) fill!(fe, 0) - Xe = getcoordinates(proj.dh.grid, cellnum) + Xe = getcoordinates(get_grid(proj.dh), cellnum) cell_vars = vars[ic] reinit!(fe_values, Xe) @@ -226,7 +226,7 @@ end function WriteVTK.vtk_point_data(vtk::WriteVTK.DatasetFile, proj::L2Projector, vals::Vector{T}, name::AbstractString) where T data = _evaluate_at_grid_nodes(proj, vals, #=vtk=# Val(true))::Matrix - @assert size(data, 2) == getnnodes(proj.dh.grid) + @assert size(data, 2) == getnnodes(get_grid(proj.dh)) vtk_point_data(vtk, data, name; component_names=component_names(T)) return vtk end @@ -248,9 +248,9 @@ function _evaluate_at_grid_nodes( @assert ndofs(dh) == length(vals) if vtk nout = S <: Vec{2} ? 3 : M # Pad 2D Vec to 3D - data = fill(T(NaN), nout, getnnodes(dh.grid)) + data = fill(T(NaN), nout, getnnodes(get_grid(dh))) else - data = fill(NaN * zero(S), getnnodes(dh.grid)) + data = fill(NaN * zero(S), getnnodes(get_grid(dh))) end ip, gip = proj.func_ip, proj.geom_ip refdim, refshape = getdim(ip), getrefshape(ip) diff --git a/src/iterators.jl b/src/iterators.jl index 12ba37f6fc..ccc0c8933b 100644 --- a/src/iterators.jl +++ b/src/iterators.jl @@ -58,12 +58,12 @@ function CellCache(grid::Grid{dim,C,T}, flags::UpdateFlags=UpdateFlags()) where end function CellCache(dh::DofHandler{dim}, flags::UpdateFlags=UpdateFlags()) where {dim} - N = nnodes_per_cell(dh.grid) + N = nnodes_per_cell(get_grid(dh)) nodes = zeros(Int, N) - coords = zeros(Vec{dim, get_coordinate_eltype(dh.grid)}, N) + coords = zeros(Vec{dim, get_coordinate_eltype(get_grid(dh))}, N) n = ndofs_per_cell(dh) celldofs = zeros(Int, n) - return CellCache(flags, dh.grid, ScalarWrapper(-1), nodes, coords, dh, celldofs) + return CellCache(flags, get_grid(dh), ScalarWrapper(-1), nodes, coords, dh, celldofs) end function reinit!(cc::CellCache, i::Int) @@ -141,13 +141,13 @@ function CellIterator(gridordh::Union{Grid,AbstractDofHandler}, set::Union{IntegerCollection,Nothing}=nothing, flags::UpdateFlags=UpdateFlags()) if set === nothing - grid = gridordh isa AbstractDofHandler ? gridordh.grid : gridordh + grid = gridordh isa AbstractDofHandler ? get_grid(gridordh) : gridordh set = 1:getncells(grid) end - if gridordh isa DofHandler && !isconcretetype(getcelltype(gridordh.grid)) + if gridordh isa DofHandler && !isconcretetype(getcelltype(get_grid(gridordh))) # TODO: Since the CellCache is resizeable this is not really necessary to check # here, but might be useful to catch slow code paths? - _check_same_celltype(gridordh.grid, set) + _check_same_celltype(get_grid(gridordh), set) end return CellIterator(CellCache(gridordh, flags), set) end @@ -170,8 +170,8 @@ Base.length(ci::CellIterator) = length(ci.set) function _check_same_celltype(grid::AbstractGrid, cellset) - celltype = typeof(grid.cells[first(cellset)]) - if !all(typeof(grid.cells[i]) == celltype for i in cellset) + celltype = getcelltype(grid, first(cellset)) + if !all(getcelltype(grid, i) == celltype for i in cellset) error("The cells in the cellset are not all of the same celltype.") end end