diff --git a/docs/quantization_support.md b/docs/quantization_support.md index 5809e447..ea6a364c 100644 --- a/docs/quantization_support.md +++ b/docs/quantization_support.md @@ -22,6 +22,7 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite). | `log_softmax` | / | | `matmul` | / | | `mm` | / | +| `norm` | / | | `pad` | 1.7.0 | | `pow` | / | | `prelu` | / | @@ -46,6 +47,7 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite). | `torch.nn.LayerNorm` | / | | `torch.nn.LogSoftmax` | / | | `torch.nn.PReLU` | / | +| `torch.nn.RMSNorm` | / | | `torch.nn.RNN` | / | | `torch.nn.SiLU` | / | | `torch.nn.Softmax` | / | diff --git a/tinynn/graph/configs/gen_funcs_yml.py b/tinynn/graph/configs/gen_funcs_yml.py index 27680391..3d3b9360 100644 --- a/tinynn/graph/configs/gen_funcs_yml.py +++ b/tinynn/graph/configs/gen_funcs_yml.py @@ -160,7 +160,7 @@ def get_scope(ns): ver = '_'.join(ver.split('.')[:2]) # Stage 7: Functions in new versions may exist in current version -latest = '2_0' +latest = '2_4' if ver != latest: with open(f'torch_func_override_{latest}.yml', 'r') as f: d = yaml.load(f, yaml.SafeLoader) diff --git a/tinynn/graph/configs/torch_func_override_2_4.yml b/tinynn/graph/configs/torch_func_override_2_4.yml new file mode 100644 index 00000000..b4436e7d --- /dev/null +++ b/tinynn/graph/configs/torch_func_override_2_4.yml @@ -0,0 +1,1577 @@ +torch: +- _assert_async +- _conj_copy +- _functional_assert_async +- _fw_primal_copy +- _indices_copy +- _make_dual_copy +- _native_batch_norm_legit +- _neg_view_copy +- _reshape_alias_copy +- _rowwise_prune +- _segment_reduce +- _sparse_broadcast_to_copy +- _sym_acos +- _sym_asin +- _sym_atan +- _sym_cos +- _sym_cosh +- _sym_sin +- _sym_sinh +- _sym_sqrt +- _sym_tan +- _sym_tanh +- _values_copy +- abs +- absolute +- acos +- acosh +- adaptive_avg_pool1d +- adaptive_max_pool1d +- add +- addbmm +- addcdiv +- addcmul +- addmm +- addmv +- addr +- adjoint +- affine_grid_generator +- alias_copy +- all +- allclose +- alpha_dropout +- amax +- amin +- aminmax +- angle +- any +- arccos +- arccosh +- arcsin +- arcsinh +- arctan +- arctan2 +- arctanh +- argmax +- argmin +- argsort +- argwhere +- as_strided_copy +- as_strided_scatter +- asin +- asinh +- atan +- atan2 +- atanh +- atleast_1d +- atleast_2d +- atleast_3d +- autocast +- avg_pool1d +- baddbmm +- batch_norm +- batch_norm_backward_elemt +- batch_norm_backward_reduce +- batch_norm_elemt +- batch_norm_gather_stats +- batch_norm_gather_stats_with_counts +- batch_norm_stats +- batch_norm_update_stats +- bernoulli +- bilinear +- binary_cross_entropy_with_logits +- bincount +- binomial +- bitwise_and +- bitwise_left_shift +- bitwise_not +- bitwise_or +- bitwise_right_shift +- bitwise_xor +- block_diag +- bmm +- broadcast_tensors +- broadcast_to +- bucketize +- cartesian_prod +- cat +- ccol_indices_copy +- cdist +- ceil +- celu +- chain_matmul +- channel_shuffle +- cholesky +- cholesky_inverse +- cholesky_solve +- choose_qparams_optimized +- chunk +- clamp +- clamp_max +- clamp_min +- clip +- clone +- col_indices_copy +- column_stack +- combinations +- complex +- concat +- concatenate +- conj +- conj_physical +- constant_pad_nd +- conv1d +- conv2d +- conv3d +- conv_tbc +- conv_transpose1d +- conv_transpose2d +- conv_transpose3d +- convolution +- copysign +- corrcoef +- cos +- cosh +- cosine_embedding_loss +- cosine_similarity +- count_nonzero +- cov +- cross +- crow_indices_copy +- ctc_loss +- cummax +- cummin +- cumprod +- cumsum +- cumulative_trapezoid +- deg2rad +- dequantize +- det +- detach +- detach_copy +- diag +- diag_embed +- diagflat +- diagonal +- diagonal_copy +- diagonal_scatter +- diff +- digamma +- dist +- div +- divide +- dot +- dropout +- dsmm +- dsplit +- dstack +- einsum +- embedding +- embedding_bag +- empty_like +- eq +- equal +- erf +- erfc +- erfinv +- exp +- exp2 +- expand_copy +- expm1 +- fake_quantize_per_channel_affine +- fake_quantize_per_tensor_affine +- fbgemm_linear_fp16_weight +- fbgemm_linear_fp16_weight_fp32_activation +- fbgemm_linear_int8_weight +- fbgemm_linear_int8_weight_fp32_activation +- fbgemm_linear_quantize_weight +- fbgemm_pack_gemm_matrix_fp16 +- fbgemm_pack_quantized_matrix +- feature_alpha_dropout +- feature_dropout +- fix +- flatten +- flip +- fliplr +- flipud +- float_power +- floor +- floor_divide +- fmax +- fmin +- fmod +- frac +- frexp +- frobenius_norm +- full_like +- fused_moving_avg_obs_fake_quant +- gather +- gcd +- ge +- geqrf +- ger +- gradient +- greater +- greater_equal +- grid_sampler +- grid_sampler_2d +- grid_sampler_3d +- group_norm +- gru +- gru_cell +- gt +- hardshrink +- heaviside +- hinge_embedding_loss +- histc +- histogram +- histogramdd +- hsmm +- hsplit +- hspmm +- hstack +- hypot +- i0 +- igamma +- igammac +- imag +- index_add +- index_copy +- index_fill +- index_put +- index_reduce +- index_select +- indices_copy +- inner +- instance_norm +- int_repr +- inverse +- is_complex +- is_conj +- is_distributed +- is_floating_point +- is_inference +- is_neg +- is_nonzero +- is_same_size +- is_signed +- isclose +- isfinite +- isin +- isinf +- isnan +- isneginf +- isposinf +- isreal +- istft +- kl_div +- kron +- kthvalue +- layer_norm +- lcm +- ldexp +- le +- lerp +- less +- less_equal +- lgamma +- lobpcg +- log +- log10 +- log1p +- log2 +- log_softmax +- logaddexp +- logaddexp2 +- logcumsumexp +- logdet +- logical_and +- logical_not +- logical_or +- logical_xor +- logit +- logsumexp +- lstm +- lstm_cell +- lt +- lu_solve +- lu_unpack +- margin_ranking_loss +- masked_fill +- masked_scatter +- masked_select +- matmul +- matrix_exp +- matrix_power +- max +- max_pool1d +- max_pool1d_with_indices +- max_pool2d +- max_pool3d +- maximum +- mean +- median +- meshgrid +- min +- minimum +- miopen_batch_norm +- miopen_convolution +- miopen_convolution_add_relu +- miopen_convolution_relu +- miopen_convolution_transpose +- miopen_depthwise_convolution +- miopen_rnn +- mm +- mode +- moveaxis +- movedim +- msort +- mul +- multinomial +- multiply +- mv +- mvlgamma +- nan_to_num +- nanmean +- nanmedian +- nanquantile +- nansum +- narrow +- narrow_copy +- native_batch_norm +- native_channel_shuffle +- native_dropout +- native_group_norm +- native_layer_norm +- native_norm +- ne +- neg +- negative +- nextafter +- nonzero +- nonzero_static +- norm +- norm_except_dim +- not_equal +- nuclear_norm +- numel +- ones_like +- orgqr +- ormqr +- outer +- pairwise_distance +- pdist +- permute +- permute_copy +- pinverse +- pixel_shuffle +- pixel_unshuffle +- poisson +- poisson_nll_loss +- polar +- polygamma +- positive +- pow +- prelu +- prod +- put +- q_per_channel_axis +- q_per_channel_scales +- q_per_channel_zero_points +- q_scale +- q_zero_point +- qr +- quantile +- quantize_per_channel +- quantize_per_tensor +- quantize_per_tensor_dynamic +- quantized_batch_norm +- quantized_gru_cell +- quantized_lstm_cell +- quantized_max_pool1d +- quantized_max_pool2d +- quantized_max_pool3d +- quantized_rnn_relu_cell +- quantized_rnn_tanh_cell +- rad2deg +- rand_like +- randint_like +- randn_like +- ravel +- real +- reciprocal +- relu +- remainder +- renorm +- repeat_interleave +- reshape +- resolve_conj +- resolve_neg +- rms_norm +- rnn_relu +- rnn_relu_cell +- rnn_tanh +- rnn_tanh_cell +- roll +- rot90 +- round +- row_indices_copy +- row_stack +- rrelu +- rsqrt +- rsub +- saddmm +- scatter +- scatter_add +- scatter_reduce +- searchsorted +- segment_reduce +- select +- select_copy +- select_scatter +- selu +- sgn +- sigmoid +- sign +- signbit +- sin +- sinc +- sinh +- slice_copy +- slice_inverse +- slice_scatter +- slogdet +- smm +- softmax +- sort +- split +- split_copy +- split_with_sizes +- split_with_sizes_copy +- spmm +- sqrt +- square +- squeeze +- squeeze_copy +- sspaddmm +- stack +- std +- std_mean +- stft +- sub +- subtract +- sum +- svd +- swapaxes +- swapdims +- sym_float +- sym_int +- sym_ite +- sym_max +- sym_min +- sym_not +- sym_sqrt +- t +- t_copy +- take +- take_along_dim +- tan +- tanh +- tensor_split +- tensordot +- threshold +- tile +- topk +- trace +- transpose +- transpose_copy +- trapezoid +- trapz +- triangular_solve +- tril +- triplet_margin_loss +- triu +- true_divide +- trunc +- unbind +- unbind_copy +- unflatten +- unfold_copy +- unique_consecutive +- unravel_index +- unsafe_chunk +- unsafe_split +- unsafe_split_with_sizes +- unsqueeze +- unsqueeze_copy +- values_copy +- var +- var_mean +- vdot +- view_as_complex +- view_as_complex_copy +- view_as_real +- view_as_real_copy +- view_copy +- vsplit +- vstack +- where +- xlogy +- zeros_like +torch.Tensor: +- __abs__ +- __add__ +- __and__ +- __array__ +- __array_wrap__ +- __bool__ +- __complex__ +- __contains__ +- __deepcopy__ +- __delitem__ +- __div__ +- __dlpack__ +- __dlpack_device__ +- __eq__ +- __float__ +- __floordiv__ +- __format__ +- __ge__ +- __getitem__ +- __gt__ +- __iadd__ +- __iand__ +- __idiv__ +- __ifloordiv__ +- __ilshift__ +- __imod__ +- __imul__ +- __index__ +- __int__ +- __invert__ +- __ior__ +- __ipow__ +- __irshift__ +- __isub__ +- __itruediv__ +- __ixor__ +- __le__ +- __len__ +- __long__ +- __lshift__ +- __lt__ +- __matmul__ +- __mod__ +- __mul__ +- __ne__ +- __neg__ +- __nonzero__ +- __or__ +- __pos__ +- __pow__ +- __radd__ +- __rand__ +- __rdiv__ +- __reduce_ex__ +- __repr__ +- __reversed__ +- __rfloordiv__ +- __rlshift__ +- __rmatmul__ +- __rmod__ +- __rmul__ +- __ror__ +- __rpow__ +- __rrshift__ +- __rshift__ +- __rsub__ +- __rtruediv__ +- __rxor__ +- __setitem__ +- __setstate__ +- __sub__ +- __truediv__ +- __xor__ +- _autocast_to_full_precision +- _autocast_to_reduced_precision +- _coalesced_ +- _dimI +- _dimV +- _indices +- _is_view +- _nested_tensor_size +- _nested_tensor_storage_offsets +- _nested_tensor_strides +- _nnz +- _sparse_mask_projection +- _to_dense +- _update_names +- _values +- abs +- abs_ +- absolute +- absolute_ +- acos +- acos_ +- acosh +- acosh_ +- add +- add_ +- addbmm +- addbmm_ +- addcdiv +- addcdiv_ +- addcmul +- addcmul_ +- addmm +- addmm_ +- addmv +- addmv_ +- addr +- addr_ +- adjoint +- align_as +- align_to +- all +- allclose +- amax +- amin +- aminmax +- angle +- any +- apply_ +- arccos +- arccos_ +- arccosh +- arccosh_ +- arcsin +- arcsin_ +- arcsinh +- arcsinh_ +- arctan +- arctan2 +- arctan2_ +- arctan_ +- arctanh +- arctanh_ +- argmax +- argmin +- argsort +- argwhere +- as_strided +- as_strided_ +- as_strided_scatter +- asin +- asin_ +- asinh +- asinh_ +- atan +- atan2 +- atan2_ +- atan_ +- atanh +- atanh_ +- backward +- baddbmm +- baddbmm_ +- bernoulli +- bernoulli_ +- bfloat16 +- bincount +- bitwise_and +- bitwise_and_ +- bitwise_left_shift +- bitwise_left_shift_ +- bitwise_not +- bitwise_not_ +- bitwise_or +- bitwise_or_ +- bitwise_right_shift +- bitwise_right_shift_ +- bitwise_xor +- bitwise_xor_ +- bmm +- bool +- broadcast_to +- byte +- cauchy_ +- ccol_indices +- cdouble +- ceil +- ceil_ +- cfloat +- chalf +- char +- cholesky +- cholesky_inverse +- cholesky_solve +- chunk +- clamp +- clamp_ +- clamp_max +- clamp_max_ +- clamp_min +- clamp_min_ +- clip +- clip_ +- clone +- coalesce +- col_indices +- conj +- conj_physical +- conj_physical_ +- contiguous +- copy_ +- copysign +- copysign_ +- corrcoef +- cos +- cos_ +- cosh +- cosh_ +- count_nonzero +- cov +- cpu +- cross +- crow_indices +- cuda +- cummax +- cummin +- cumprod +- cumprod_ +- cumsum +- cumsum_ +- data_ptr +- deg2rad +- deg2rad_ +- dense_dim +- dequantize +- det +- detach +- detach_ +- diag +- diag_embed +- diagflat +- diagonal +- diagonal_scatter +- diff +- digamma +- digamma_ +- dim +- dim_order +- dist +- div +- div_ +- divide +- divide_ +- dot +- double +- dsplit +- element_size +- eq +- eq_ +- equal +- erf +- erf_ +- erfc +- erfc_ +- erfinv +- erfinv_ +- exp +- exp2 +- exp2_ +- exp_ +- expand +- expand_as +- expm1 +- expm1_ +- exponential_ +- fill_ +- fill_diagonal_ +- fix +- fix_ +- flatten +- flip +- fliplr +- flipud +- float +- float_power +- float_power_ +- floor +- floor_ +- floor_divide +- floor_divide_ +- fmax +- fmin +- fmod +- fmod_ +- frac +- frac_ +- frexp +- gather +- gcd +- gcd_ +- ge +- ge_ +- geometric_ +- geqrf +- ger +- get_device +- greater +- greater_ +- greater_equal +- greater_equal_ +- gt +- gt_ +- half +- hardshrink +- has_names +- heaviside +- heaviside_ +- histc +- histogram +- hsplit +- hypot +- hypot_ +- i0 +- i0_ +- igamma +- igamma_ +- igammac +- igammac_ +- index_add +- index_add_ +- index_copy +- index_copy_ +- index_fill +- index_fill_ +- index_put +- index_put_ +- index_reduce +- index_reduce_ +- index_select +- indices +- inner +- int +- int_repr +- inverse +- ipu +- is_coalesced +- is_complex +- is_conj +- is_contiguous +- is_distributed +- is_floating_point +- is_inference +- is_neg +- is_nonzero +- is_pinned +- is_same_size +- is_set_to +- is_shared +- is_signed +- isclose +- isfinite +- isinf +- isnan +- isneginf +- isposinf +- isreal +- istft +- item +- kron +- kthvalue +- lcm +- lcm_ +- ldexp +- ldexp_ +- le +- le_ +- lerp +- lerp_ +- less +- less_ +- less_equal +- less_equal_ +- lgamma +- lgamma_ +- log +- log10 +- log10_ +- log1p +- log1p_ +- log2 +- log2_ +- log_ +- log_normal_ +- log_softmax +- logaddexp +- logaddexp2 +- logcumsumexp +- logdet +- logical_and +- logical_and_ +- logical_not +- logical_not_ +- logical_or +- logical_or_ +- logical_xor +- logical_xor_ +- logit +- logit_ +- logsumexp +- long +- lt +- lt_ +- lu +- lu_solve +- map2_ +- map_ +- masked_fill +- masked_fill_ +- masked_scatter +- masked_scatter_ +- masked_select +- matmul +- matrix_exp +- matrix_power +- max +- maximum +- mean +- median +- min +- minimum +- mm +- mode +- module_load +- moveaxis +- movedim +- msort +- mul +- mul_ +- multinomial +- multiply +- multiply_ +- mv +- mvlgamma +- mvlgamma_ +- nan_to_num +- nan_to_num_ +- nanmean +- nanmedian +- nanquantile +- nansum +- narrow +- narrow_copy +- ndimension +- ne +- ne_ +- neg +- neg_ +- negative +- negative_ +- nelement +- nextafter +- nextafter_ +- nonzero +- nonzero_static +- norm +- normal_ +- not_equal +- not_equal_ +- numel +- numpy +- orgqr +- ormqr +- outer +- permute +- pin_memory +- pinverse +- polygamma +- polygamma_ +- positive +- pow +- pow_ +- prelu +- prod +- put +- put_ +- q_per_channel_axis +- q_per_channel_scales +- q_per_channel_zero_points +- q_scale +- q_zero_point +- qr +- qscheme +- quantile +- rad2deg +- rad2deg_ +- random_ +- ravel +- reciprocal +- reciprocal_ +- record_stream +- refine_names +- register_hook +- register_post_accumulate_grad_hook +- relu +- relu_ +- remainder +- remainder_ +- rename +- rename_ +- renorm +- renorm_ +- repeat +- repeat_interleave +- requires_grad_ +- reshape +- reshape_as +- resize +- resize_ +- resize_as +- resize_as_ +- resize_as_sparse_ +- resolve_conj +- resolve_neg +- retain_grad +- roll +- rot90 +- round +- round_ +- row_indices +- rsqrt +- rsqrt_ +- scatter +- scatter_ +- scatter_add +- scatter_add_ +- scatter_reduce +- scatter_reduce_ +- select +- select_scatter +- set_ +- sgn +- sgn_ +- share_memory_ +- short +- sigmoid +- sigmoid_ +- sign +- sign_ +- signbit +- sin +- sin_ +- sinc +- sinc_ +- sinh +- sinh_ +- size +- slice_inverse +- slice_scatter +- slogdet +- smm +- softmax +- sort +- sparse_dim +- sparse_mask +- sparse_resize_ +- sparse_resize_and_clear_ +- split +- split_with_sizes +- sqrt +- sqrt_ +- square +- square_ +- squeeze +- squeeze_ +- sspaddmm +- std +- stft +- storage +- storage_offset +- storage_type +- sub +- sub_ +- subtract +- subtract_ +- sum +- sum_to_size +- svd +- swapaxes +- swapaxes_ +- swapdims +- swapdims_ +- t +- t_ +- take +- take_along_dim +- tan +- tan_ +- tanh +- tanh_ +- tensor_split +- tile +- to +- to_dense +- to_mkldnn +- to_sparse +- tolist +- topk +- trace +- transpose +- transpose_ +- triangular_solve +- tril +- tril_ +- triu +- triu_ +- true_divide +- true_divide_ +- trunc +- trunc_ +- type +- type_as +- unbind +- unfold +- uniform_ +- unique +- unique_consecutive +- unsafe_chunk +- unsafe_split +- unsafe_split_with_sizes +- unsqueeze +- unsqueeze_ +- untyped_storage +- values +- var +- vdot +- view +- view_as +- vsplit +- where +- xlogy +- xlogy_ +- xpu +- zero_ +torch._lobpcg: +- lobpcg +torch._lowrank: +- pca_lowrank +- svd_lowrank +torch.fft: +- fft +- fft2 +- fft_fft +- fft_fft2 +- fft_fftn +- fft_fftshift +- fft_hfft +- fft_hfft2 +- fft_hfftn +- fft_ifft +- fft_ifft2 +- fft_ifftn +- fft_ifftshift +- fft_ihfft +- fft_ihfft2 +- fft_ihfftn +- fft_irfft +- fft_irfft2 +- fft_irfftn +- fft_rfft +- fft_rfft2 +- fft_rfftn +- fftn +- fftshift +- hfft +- hfft2 +- hfftn +- ifft +- ifft2 +- ifftn +- ifftshift +- ihfft +- ihfft2 +- ihfftn +- irfft +- irfft2 +- irfftn +- rfft +- rfft2 +- rfftn +torch.functional: +- _consecutive_return_inverse_false +- _consecutive_return_inverse_true +- _return_inverse_false +- _return_inverse_true +- atleast_1d +- atleast_2d +- atleast_3d +- block_diag +- broadcast_tensors +- cartesian_prod +- cdist +- chain_matmul +- einsum +- istft +- lu +- meshgrid +- norm +- pca_lowrank +- split +- stft +- svd_lowrank +- tensordot +- unique +- unique_consecutive +- unravel_index +torch.linalg: +- cholesky +- cholesky_ex +- cond +- cross +- det +- diagonal +- eig +- eigh +- eigvals +- eigvalsh +- householder_product +- inv +- inv_ex +- ldl_factor +- ldl_factor_ex +- ldl_solve +- linalg_cholesky +- linalg_cholesky_ex +- linalg_cond +- linalg_cross +- linalg_det +- linalg_diagonal +- linalg_eig +- linalg_eigh +- linalg_eigvals +- linalg_eigvalsh +- linalg_householder_product +- linalg_inv +- linalg_inv_ex +- linalg_ldl_factor +- linalg_ldl_factor_ex +- linalg_ldl_solve +- linalg_lstsq +- linalg_lu +- linalg_lu_factor +- linalg_lu_factor_ex +- linalg_lu_solve +- linalg_matmul +- linalg_matrix_exp +- linalg_matrix_norm +- linalg_matrix_power +- linalg_matrix_rank +- linalg_multi_dot +- linalg_norm +- linalg_pinv +- linalg_qr +- linalg_slogdet +- linalg_solve +- linalg_solve_ex +- linalg_solve_triangular +- linalg_svd +- linalg_svdvals +- linalg_tensorinv +- linalg_tensorsolve +- linalg_vander +- linalg_vecdot +- linalg_vector_norm +- lstsq +- lu +- lu_factor +- lu_factor_ex +- lu_solve +- matmul +- matrix_exp +- matrix_norm +- matrix_power +- matrix_rank +- multi_dot +- norm +- pinv +- qr +- slogdet +- solve +- solve_ex +- solve_triangular +- svd +- svdvals +- tensorinv +- tensorsolve +- vander +- vecdot +- vector_norm +torch.nn.functional: +- _threshold +- adaptive_avg_pool1d +- adaptive_avg_pool2d +- adaptive_avg_pool3d +- adaptive_max_pool1d +- adaptive_max_pool1d_with_indices +- adaptive_max_pool2d +- adaptive_max_pool2d_with_indices +- adaptive_max_pool3d +- adaptive_max_pool3d_with_indices +- affine_grid +- alpha_dropout +- avg_pool1d +- avg_pool2d +- avg_pool3d +- batch_norm +- bilinear +- binary_cross_entropy +- binary_cross_entropy_with_logits +- celu +- channel_shuffle +- conv1d +- conv2d +- conv3d +- conv_tbc +- conv_transpose1d +- conv_transpose2d +- conv_transpose3d +- cosine_embedding_loss +- cosine_similarity +- cross_entropy +- ctc_loss +- dropout +- dropout1d +- dropout2d +- dropout3d +- elu +- embedding +- embedding_bag +- feature_alpha_dropout +- fold +- fractional_max_pool2d +- fractional_max_pool2d_with_indices +- fractional_max_pool3d +- fractional_max_pool3d_with_indices +- gaussian_nll_loss +- gelu +- glu +- grid_sample +- group_norm +- gumbel_softmax +- hardshrink +- hardsigmoid +- hardswish +- hardtanh +- hinge_embedding_loss +- huber_loss +- instance_norm +- interpolate +- kl_div +- l1_loss +- layer_norm +- leaky_relu +- linear +- local_response_norm +- log_sigmoid +- log_softmax +- logsigmoid +- lp_pool1d +- lp_pool2d +- lp_pool3d +- margin_ranking_loss +- max_pool1d +- max_pool1d_with_indices +- max_pool2d +- max_pool2d_with_indices +- max_pool3d +- max_pool3d_with_indices +- max_unpool1d +- max_unpool2d +- max_unpool3d +- mish +- mse_loss +- multi_head_attention_forward +- multi_margin_loss +- multilabel_margin_loss +- multilabel_soft_margin_loss +- native_channel_shuffle +- nll_loss +- normalize +- one_hot +- pad +- pairwise_distance +- pdist +- pixel_shuffle +- pixel_unshuffle +- poisson_nll_loss +- prelu +- relu +- relu6 +- rms_norm +- rrelu +- scaled_dot_product_attention +- selu +- sigmoid +- silu +- smooth_l1_loss +- soft_margin_loss +- softmax +- softmin +- softplus +- softshrink +- softsign +- tanh +- tanhshrink +- threshold +- triplet_margin_loss +- triplet_margin_with_distance_loss +- unfold +- upsample +- upsample_bilinear +- upsample_nearest +torch.nn.init: +- constant_ +- kaiming_uniform_ +- normal_ +- uniform_ +torch.special: +- airy_ai +- bessel_j0 +- bessel_j1 +- bessel_y0 +- bessel_y1 +- chebyshev_polynomial_t +- chebyshev_polynomial_u +- chebyshev_polynomial_v +- chebyshev_polynomial_w +- digamma +- entr +- erf +- erfc +- erfcx +- erfinv +- exp2 +- expit +- expm1 +- gammainc +- gammaincc +- gammaln +- hermite_polynomial_h +- hermite_polynomial_he +- i0 +- i0e +- i1 +- i1e +- laguerre_polynomial_l +- legendre_polynomial_p +- log1p +- log_ndtr +- log_softmax +- logit +- logsumexp +- modified_bessel_i0 +- modified_bessel_i1 +- modified_bessel_k0 +- modified_bessel_k1 +- multigammaln +- ndtr +- ndtri +- polygamma +- psi +- round +- scaled_modified_bessel_k0 +- scaled_modified_bessel_k1 +- shifted_chebyshev_polynomial_t +- shifted_chebyshev_polynomial_u +- shifted_chebyshev_polynomial_v +- shifted_chebyshev_polynomial_w +- sinc +- softmax +- special_airy_ai +- special_bessel_j0 +- special_bessel_j1 +- special_bessel_y0 +- special_bessel_y1 +- special_chebyshev_polynomial_t +- special_chebyshev_polynomial_u +- special_chebyshev_polynomial_v +- special_chebyshev_polynomial_w +- special_digamma +- special_entr +- special_erf +- special_erfc +- special_erfcx +- special_erfinv +- special_exp2 +- special_expit +- special_expm1 +- special_gammainc +- special_gammaincc +- special_gammaln +- special_hermite_polynomial_h +- special_hermite_polynomial_he +- special_i0 +- special_i0e +- special_i1 +- special_i1e +- special_laguerre_polynomial_l +- special_legendre_polynomial_p +- special_log1p +- special_log_ndtr +- special_log_softmax +- special_logit +- special_logsumexp +- special_modified_bessel_i0 +- special_modified_bessel_i1 +- special_modified_bessel_k0 +- special_modified_bessel_k1 +- special_multigammaln +- special_ndtr +- special_ndtri +- special_polygamma +- special_psi +- special_round +- special_scaled_modified_bessel_k0 +- special_scaled_modified_bessel_k1 +- special_shifted_chebyshev_polynomial_t +- special_shifted_chebyshev_polynomial_u +- special_shifted_chebyshev_polynomial_v +- special_shifted_chebyshev_polynomial_w +- special_sinc +- special_softmax +- special_spherical_bessel_j0 +- special_xlog1py +- special_xlogy +- special_zeta +- spherical_bessel_j0 +- xlog1py +- xlogy +- zeta +torchvision.ops: +- batched_nms +- box_area +- box_convert +- box_iou +- clip_boxes_to_image +- complete_box_iou +- complete_box_iou_loss +- deform_conv2d +- distance_box_iou +- distance_box_iou_loss +- drop_block2d +- drop_block3d +- generalized_box_iou +- generalized_box_iou_loss +- masks_to_boxes +- nms +- ps_roi_align +- ps_roi_pool +- remove_small_boxes +- roi_align +- roi_pool +- sigmoid_focal_loss +- stochastic_depth diff --git a/tinynn/graph/configs/torch_module_override.yml b/tinynn/graph/configs/torch_module_override.yml index 1c2d59e2..2f55ac5a 100644 --- a/tinynn/graph/configs/torch_module_override.yml +++ b/tinynn/graph/configs/torch_module_override.yml @@ -65,6 +65,7 @@ torch.nn: - FractionalMaxPool3d - LPPool1d - LPPool2d +- LPPool3d - LocalResponseNorm - BatchNorm1d - BatchNorm2d @@ -74,8 +75,10 @@ torch.nn: - InstanceNorm3d - LayerNorm - GroupNorm +- RMSNorm - SyncBatchNorm - Dropout +- Dropout1d - Dropout2d - Dropout3d - AlphaDropout @@ -110,7 +113,9 @@ torch.nn: - AdaptiveAvgPool2d - AdaptiveAvgPool3d - TripletMarginLoss +- ZeroPad1d - ZeroPad2d +- ZeroPad3d - ConstantPad1d - ConstantPad2d - ConstantPad3d @@ -141,6 +146,9 @@ torch.nn: - Mish - TripletMarginWithDistanceLoss - ChannelShuffle +- CircularPad1d +- CircularPad2d +- CircularPad3d - UninitializedParameter - UninitializedBuffer torch.quantization: @@ -149,7 +157,13 @@ torch.quantization: - QuantWrapper torchvision.ops: - DeformConv2d +- DropBlock2d +- DropBlock3d +- Conv2dNormActivation +- Conv3dNormActivation - FrozenBatchNorm2d +- MLP +- Permute - MultiScaleRoIAlign - PSRoIAlign - PSRoIPool diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py index d8cdfdd9..a19975ae 100644 --- a/tinynn/graph/quantization/quantizer.py +++ b/tinynn/graph/quantization/quantizer.py @@ -200,6 +200,7 @@ 'log': None, 'std': None, 'var': None, + 'norm': None, nn.LSTM: '1.13.0', nn.ConvTranspose2d: '1.7.0', nn.ConstantPad1d: '1.7.0', @@ -261,6 +262,10 @@ Q_MODULES_MAPPING.update({nn.SiLU: QSiLU}) FUNCTIONAL_MODULE_MAPPING.update({'silu': nn.SiLU}) +if hasattr(nn, 'RMSNorm'): + UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.update({nn.RMSNorm: None}) + FUNCTIONAL_MODULE_MAPPING.update({'rms_norm': nn.RMSNorm}) + # Processed QAT fuse rules processed_qat_rules = {} processed_ptq_rules = {}