diff --git a/foyer/atomtyper.py b/foyer/atomtyper.py index d3c525ce..1117781b 100644 --- a/foyer/atomtyper.py +++ b/foyer/atomtyper.py @@ -77,14 +77,19 @@ def find_atomtypes(structure, forcefield, max_iter=10): elif isinstance(structure, Topology): topology_graph = TopologyGraph.from_gmso_topology(structure) - forcefield = AtomTypingRulesProvider.from_foyer_forcefield(forcefield) + if isinstance(forcefield, Forcefield): + atomtype_rules = AtomTypingRulesProvider.from_foyer_forcefield( + forcefield + ) + elif isinstance(forcefield, AtomTypingRulesProvider): + atomtype_rules = forcefield typemap = { atom_index: {"whitelist": set(), "blacklist": set(), "atomtype": None} for atom_index in topology_graph.atoms(data=False) } - rules = _load_rules(forcefield, typemap) + rules = _load_rules(atomtype_rules, typemap) # Only consider rules for elements found in topology subrules = dict() @@ -94,7 +99,7 @@ def find_atomtypes(structure, forcefield, max_iter=10): # First add non-element types, which are strings, then elements name = atom_data.name if name.startswith("_"): - if name in forcefield.non_element_types: + if name in atomtype_rules.non_element_types: system_elements.add(name) else: atomic_number = atom_data.atomic_number @@ -143,7 +148,7 @@ def find_atomtypes(structure, forcefield, max_iter=10): def _load_rules(rules_provider, typemap): - """Load atomtyping rules from a forcefield into SMARTSGraphs.""" + """Load atomtyping rules from a AtomTypingRulesProvider into SMARTSGraphs.""" rules = dict() # For every SMARTS string in the force field, # create a SMARTSGraph object