diff --git a/dtreeviz/utils.py b/dtreeviz/utils.py index 73ceb74a..5d340f75 100644 --- a/dtreeviz/utils.py +++ b/dtreeviz/utils.py @@ -89,7 +89,7 @@ def scale_SVG(svg:str, scale:float) -> str: Convert: - + To: @@ -109,7 +109,7 @@ def scale_SVG(svg:str, scale:float) -> str: ns = {"svg": "http://www.w3.org/2000/svg"} graph = root.find(".//svg:g", ns) # get first node, which is graph transform = graph.attrib['transform'] - transform = transform.replace('scale(1.0 1.0)', f'scale({scale} {scale})') + transform = transform.replace('scale(1 1)', f'scale({scale} {scale})') graph.set("transform", transform) ET.register_namespace('', "http://www.w3.org/2000/svg")