Skip to content

Commit

Permalink
feat(python): Enable collection with gpu engine (#17550)
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored Jul 25, 2024
1 parent 8373cdb commit 4739460
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 5 deletions.
4 changes: 3 additions & 1 deletion py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
scan_parquet,
scan_pyarrow_dataset,
)
from polars.lazyframe import LazyFrame
from polars.lazyframe import GPUEngine, LazyFrame
from polars.meta import (
build_info,
get_index_type,
Expand Down Expand Up @@ -206,6 +206,8 @@
"Expr",
"LazyFrame",
"Series",
# Engine configuration
"GPUEngine",
# schema
"Schema",
# datatypes
Expand Down
4 changes: 4 additions & 0 deletions py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.lazyframe.engine_config import GPUEngine
from polars.selectors import _selector_proxy_

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -293,3 +294,6 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
]
SingleColSelector: TypeAlias = Union[SingleIndexSelector, SingleNameSelector]
MultiColSelector: TypeAlias = Union[MultiIndexSelector, MultiNameSelector, BooleanMask]

# LazyFrame engine selection
EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"]
2 changes: 2 additions & 0 deletions py-polars/polars/lazyframe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from polars.lazyframe.engine_config import GPUEngine
from polars.lazyframe.frame import LazyFrame

# Names re-exported as the public API of the `polars.lazyframe` package.
__all__ = [
"GPUEngine",
"LazyFrame",
]
40 changes: 40 additions & 0 deletions py-polars/polars/lazyframe/engine_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Mapping

from rmm.mr import DeviceMemoryResource # type: ignore[import-not-found]


class GPUEngine:
    """
    Configuration options for the GPU execution engine.

    Use this if you want control over details of the execution.

    Supported options

    - `device`: Select the device to run the query on.
    - `memory_resource`: Set an RMM memory resource for
      device-side allocations.

    Any other keyword arguments are stored, unmodified, in `config`.
    All arguments are keyword-only.
    """

    device: int | None
    """Device on which to run query."""
    memory_resource: DeviceMemoryResource | None
    """Memory resource to use for device allocations."""
    config: Mapping[str, Any]
    """Additional configuration options for the engine."""

    def __init__(
        self,
        *,
        device: int | None = None,
        memory_resource: Any | None = None,
        **kwargs: Any,
    ) -> None:
        self.device = device
        self.memory_resource = memory_resource
        # Extra options are kept as given; no validation happens here.
        self.config = kwargs

    def __repr__(self) -> str:
        """Return a constructor-style summary of the configured options."""
        options = {
            "device": self.device,
            "memory_resource": self.memory_resource,
            **self.config,
        }
        rendered = ", ".join(f"{name}={value!r}" for name, value in options.items())
        return f"{type(self).__name__}({rendered})"
92 changes: 88 additions & 4 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import os
from datetime import date, datetime, time, timedelta
from functools import lru_cache, reduce
from functools import lru_cache, partial, reduce
from io import BytesIO, StringIO
from operator import and_
from pathlib import Path
Expand Down Expand Up @@ -78,6 +78,7 @@
from polars.datatypes.group import DataTypeGroup
from polars.dependencies import import_optional, subprocess
from polars.exceptions import PerformanceWarning
from polars.lazyframe.engine_config import GPUEngine
from polars.lazyframe.group_by import LazyGroupBy
from polars.lazyframe.in_process import InProcessQuery
from polars.schema import Schema
Expand All @@ -99,6 +100,7 @@
ClosedInterval,
ColumnNameOrSelector,
CsvQuoteStyle,
EngineType,
ExplainFormat,
FillNullStrategy,
FrameInitTypes,
Expand Down Expand Up @@ -1771,6 +1773,7 @@ def collect(
cluster_with_columns: bool = True,
no_optimization: bool = False,
streaming: bool = False,
engine: EngineType = "cpu",
background: Literal[True],
_eager: bool = False,
) -> InProcessQuery: ...
Expand All @@ -1789,6 +1792,7 @@ def collect(
cluster_with_columns: bool = True,
no_optimization: bool = False,
streaming: bool = False,
engine: EngineType = "cpu",
background: Literal[False] = False,
_eager: bool = False,
) -> DataFrame: ...
Expand All @@ -1806,6 +1810,7 @@ def collect(
cluster_with_columns: bool = True,
no_optimization: bool = False,
streaming: bool = False,
engine: EngineType = "cpu",
background: bool = False,
_eager: bool = False,
**_kwargs: Any,
Expand Down Expand Up @@ -1848,6 +1853,27 @@ def collect(
.. note::
Use :func:`explain` to see if Polars can process the query in streaming
mode.
engine
Select the engine used to process the query, optional.
If set to `"cpu"` (default), the query is run using the
polars CPU engine. If set to `"gpu"`, the GPU engine is
used. Fine-grained control over the GPU engine, for
example which device to use on a system with multiple
devices, is possible by providing a :class:`GPUEngine` object
with configuration options.
.. note::
GPU mode is considered **unstable**. Not all queries will run
successfully on the GPU; however, they should fall back transparently
to the default engine if execution is not supported.
Running with `POLARS_VERBOSE=1` will provide information if a query
falls back (and why).
.. note::
The GPU engine does not support streaming or running in the
background. If either is enabled, then GPU execution is switched off.
background
Run the query in the background and get a handle to the query.
This handle can be used to fetch the result or cancel the query.
Expand Down Expand Up @@ -1904,6 +1930,36 @@ def collect(
│ b ┆ 11 ┆ 10 │
│ c ┆ 6 ┆ 1 │
└─────┴─────┴─────┘
Collect in GPU mode
>>> lf.group_by("a").agg(pl.all().sum()).collect(engine="gpu") # doctest: +SKIP
shape: (3, 3)
┌─────┬─────┬─────┐
│ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ b ┆ 11 ┆ 10 │
│ a ┆ 4 ┆ 10 │
│ c ┆ 6 ┆ 1 │
└─────┴─────┴─────┘
With control over the device used
>>> lf.group_by("a").agg(pl.all().sum()).collect(
... engine=pl.GPUEngine(device=1)
... ) # doctest: +SKIP
shape: (3, 3)
┌─────┬─────┬─────┐
│ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ b ┆ 11 ┆ 10 │
│ a ┆ 4 ┆ 10 │
│ c ┆ 6 ┆ 1 │
└─────┴─────┴─────┘
"""
new_streaming = _kwargs.get("new_streaming", False)

Expand All @@ -1918,6 +1974,21 @@ def collect(
if streaming:
issue_unstable_warning("Streaming mode is considered unstable.")

is_gpu = (is_config_obj := isinstance(engine, GPUEngine)) or engine == "gpu"
if not (is_config_obj or engine in ("cpu", "gpu")):
msg = f"Invalid engine argument {engine=}"
raise ValueError(msg)
if (streaming or background or new_streaming) and is_gpu:
issue_warning(
"GPU engine does not support streaming or background collection, "
"disabling GPU engine.",
category=UserWarning,
)
is_gpu = False
if _eager:
# Don't run on GPU in _eager mode (but don't warn)
is_gpu = False

ldf = self._ldf.optimization_toggle(
type_coercion,
predicate_pushdown,
Expand All @@ -1936,9 +2007,22 @@ def collect(
issue_unstable_warning("Background mode is considered unstable.")
return InProcessQuery(ldf.collect_concurrently())

# Only for testing purposes atm.
callback = _kwargs.get("post_opt_callback")

callback = None
if is_gpu:
cudf_polars = import_optional(
"cudf_polars",
err_prefix="GPU engine requested, but required package",
install_message=(
"Please install using the command `pip install cudf-polars-cu12` "
"(or `pip install cudf-polars-cu11` if your system has a "
"CUDA 11 driver)."
),
)
if not is_config_obj:
engine = GPUEngine()
callback = partial(cudf_polars.execute_with_cudf, config=engine)
# Only for testing purposes
callback = _kwargs.get("post_opt_callback", callback)
return wrap_df(ldf.collect(callback))

@overload
Expand Down
3 changes: 3 additions & 0 deletions py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ plot = ["hvplot >= 0.9.1", "polars[pandas]"]
style = ["great-tables >= 0.8.0"]
timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"]

# GPU Engine
gpu = ["cudf-polars-cu12"]

# All
all = [
"polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]",
Expand Down
59 changes: 59 additions & 0 deletions py-polars/tests/unit/lazyframe/test_engine_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from polars._typing import EngineType


@pytest.fixture()
def df() -> pl.LazyFrame:
    """Provide a small single-column LazyFrame for the tests to collect."""
    data = {"a": [1, 2, 3]}
    return pl.LazyFrame(data)


@pytest.fixture(params=["gpu", pl.GPUEngine()])
def engine(request: pytest.FixtureRequest) -> EngineType:
    """Provide each GPU engine request form: the string and a config object."""
    # The annotated local keeps the return type precise for type checkers.
    selected: EngineType = request.param
    return selected


def test_engine_selection_invalid_raises(df: pl.LazyFrame) -> None:
    """An unrecognised `engine` value is rejected with a `ValueError`."""
    # Pin the message so that an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="Invalid engine argument"):
        df.collect(engine="unknown")  # type: ignore[call-overload]


def test_engine_selection_streaming_warns(df: pl.LazyFrame, engine: EngineType) -> None:
    """Streaming plus a GPU engine warns and still yields the default result."""
    expected = df.collect()
    warning_text = "GPU engine does not support streaming or background"
    with pytest.warns(UserWarning, match=warning_text):
        result = df.collect(engine=engine, streaming=True)
    assert_frame_equal(expected, result)


def test_engine_selection_background_warns(
    df: pl.LazyFrame, engine: EngineType
) -> None:
    """Background mode plus a GPU engine warns and still yields the default result."""
    expected = df.collect()
    warning_text = "GPU engine does not support streaming or background"
    with pytest.warns(UserWarning, match=warning_text):
        handle = df.collect(engine=engine, background=True)
    # Background collection returns a query handle; block until the frame is ready.
    assert_frame_equal(expected, handle.fetch_blocking())


def test_engine_selection_eager_quiet(df: pl.LazyFrame, engine: EngineType) -> None:
    """With `_eager=True` a GPU engine request still yields the default result."""
    expected = df.collect()
    # _eager collection turns off the GPU engine quietly
    result = df.collect(engine=engine, _eager=True)
    assert_frame_equal(expected, result)


def test_engine_import_error_raises(df: pl.LazyFrame, engine: EngineType) -> None:
    """Requesting the GPU engine without `cudf_polars` raises `ImportError`."""
    from importlib.util import find_spec

    # The error is only observable when the optional GPU package is absent;
    # with it installed the query would run instead of raising.
    if find_spec("cudf_polars") is not None:
        pytest.skip("cudf_polars is installed, so no ImportError is raised")
    with pytest.raises(ImportError, match="GPU engine requested"):
        df.collect(engine=engine)

0 comments on commit 4739460

Please sign in to comment.