-
Notifications
You must be signed in to change notification settings - Fork 48
/
config.py
28 lines (21 loc) · 899 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from ...import_utils import torch_ort_version
from ..config import BackendConfig
@dataclass
class TorchORTConfig(BackendConfig):
    """Configuration for the Torch-ORT backend.

    Adds model-loading and PEFT options on top of ``BackendConfig`` and, after
    the parent's post-init step, rejects any device other than ``"cuda"``.
    """

    # backend identity; `version` is filled from torch_ort_version() when the
    # class body is evaluated
    name: str = "torch-ort"
    version: Optional[str] = torch_ort_version()
    _target_: str = "optimum_benchmark.backends.torch_ort.backend.TorchORTBackend"

    # load options
    no_weights: bool = False
    torch_dtype: Optional[str] = None
    # sdpa, which has become the default for many architectures, fails with
    # torch-ort, so eager attention is the default here
    attn_implementation: Optional[str] = "eager"

    # peft options
    peft_type: Optional[str] = None
    peft_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Run the parent's post-init step, then require a CUDA device.

        Raises:
            ValueError: If ``self.device`` is anything other than ``"cuda"``.
        """
        super().__post_init__()

        # Guard clause: the only accepted device needs no further handling.
        if self.device == "cuda":
            return

        raise ValueError(f"TorchORTBackend only supports CUDA devices, got {self.device}")