Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gen_ai only build #2603

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,7 @@ def test_quantize_fp8_per_tensor_with_ub(

zq_ref = (x @ w.T).to(torch.bfloat16)
torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3)


if __name__ == "__main__":
unittest.main()
21 changes: 17 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,18 @@

try:
torch.ops.load_library(os.path.join(os.path.dirname(__file__), "fbgemm_gpu_py.so"))
except Exception as e:
print(e)
except Exception as error_ranking:
try:
torch.ops.load_library(
os.path.join(
os.path.dirname(__file__),
"experimental/gen_ai/fbgemm_gpu_experimental_gen_ai_py.so",
)
)
except Exception as error_gen_ai:
# When both ranking/gen_ai so files are not available, print the error logs
print(error_ranking)
print(error_gen_ai)

# Since __init__.py is only used in OSS context, we define `open_source` here
# and use its existence to determine whether or not we are in OSS context
Expand All @@ -24,5 +34,8 @@
# Export the version string from the version file auto-generated by setup.py
from fbgemm_gpu.docs.version import __version__ # noqa: F401, E402

# Trigger meta operator registrations
from . import sparse_ops # noqa: F401, E402
try:
# Trigger meta operator registrations
from . import sparse_ops # noqa: F401, E402
except Exception:
pass
5 changes: 4 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@
# LICENSE file in the root directory of this source tree.

# Trigger the manual addition of docstrings to pybind11-generated operators
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
try:
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
except Exception:
pass
Loading