diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index cf23828c420b..2e3803446e42 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3190,6 +3190,7 @@ def write_excel( row_totals=row_totals, sparklines=sparklines, formulas=formulas, + autofilter=autofilter, ) # normalise cell refs (eg: "B3" => (2,1)) and establish table start/finish, diff --git a/py-polars/polars/io/spreadsheet/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py index 6509eaa54390..c659329d36cc 100644 --- a/py-polars/polars/io/spreadsheet/_write_utils.py +++ b/py-polars/polars/io/spreadsheet/_write_utils.py @@ -330,6 +330,7 @@ def _xl_setup_table_columns( row_totals: RowTotalsDefinition | None = None, float_precision: int = 3, table_style: dict[str, Any] | str | None = None, + autofilter: bool = True, # noqa: FBT001 ) -> tuple[list[dict[str, Any]], dict[str | tuple[str, ...], str], DataFrame]: """Setup and unify all column-related formatting/defaults.""" @@ -450,14 +451,25 @@ def _map_str(s: Series) -> Series: dtype_formats[tp] = fmt # associate formats/functions with specific columns + header_dict = header_format or {} + add_align = "align" not in header_dict + header_alignment = {"align": "right"} + if autofilter: + header_alignment["indent"] = 2 # type: ignore[assignment] + col_header_format = {} for col, tp in df.schema.items(): base_type = tp.base_type() + header_fmt = header_format if base_type in dtype_formats: fmt = dtype_formats.get(tp, dtype_formats[base_type]) column_formats.setdefault(col, fmt) + if add_align and base_type in [*FLOAT_DTYPES, *INTEGER_DTYPES]: + header_fmt = {**header_dict, **header_alignment} if col not in column_formats: column_formats[col] = fmt_default + col_header_format[col] = format_cache.get(header_fmt) if header_fmt else None + # ensure externally supplied formats are made available for col, fmt in column_formats.items(): # type: ignore[assignment] if isinstance(fmt, str): @@ -473,9 +485,6 @@ def _map_str(s: Series) -> Series: fmt["valign"] = "vcenter" column_formats[col] = format_cache.get(fmt) - # optional custom header format - col_header_format = format_cache.get(header_format) if header_format else None - # assemble table columns table_columns = [ { @@ -483,7 +492,7 @@ def _map_str(s: Series) -> Series: for k, v in { "header": col, "format": column_formats[col], - "header_format": col_header_format, + "header_format": col_header_format.get(col), "total_function": column_total_funcs.get(col), "formula": ( row_total_funcs.get(col)