diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml deleted file mode 100644 index 89bc747c..00000000 --- a/.github/workflows/publish.yml +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2023 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. ---- -name: "Publish Python Package" -on: - release: - types: [published] -permissions: read-all -jobs: - trigger-circleci: - runs-on: ubuntu-latest - steps: - - name: secretflow-spu-deploy - id: secretflow-spu-deploy - uses: CircleCI-Public/trigger-circleci-pipeline-action@v1.2.0 - env: - CCI_TOKEN: ${{ secrets.CCI_TOKEN }} diff --git a/WORKSPACE b/WORKSPACE index b3d86d68..207eb359 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -35,6 +35,13 @@ load("@rules_python//python:repositories.bzl", "py_repositories") py_repositories() +load("@pybind11_bazel//:python_configure.bzl", "python_configure") + +python_configure( + name = "local_config_python", + python_version = "3", +) + load( "@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies", @@ -72,13 +79,6 @@ load("@xla//:workspace0.bzl", "xla_workspace0") xla_workspace0() -load("@pybind11_bazel//:python_configure.bzl", "python_configure") - -python_configure( - name = "local_config_python", - python_version = "3", -) - load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_repos", "rules_proto_grpc_toolchains") rules_proto_grpc_toolchains() diff --git a/bazel/patches/xla-non-hermetic-python.patch b/bazel/patches/xla-non-hermetic-python.patch new file mode 100644 index 00000000..ac1b0cc0 --- /dev/null +++ b/bazel/patches/xla-non-hermetic-python.patch @@ -0,0 +1,786 @@ +diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl +index 7cc1e08568..45480bd4a3 100644 +--- a/third_party/py/BUILD.tpl ++++ b/third_party/py/BUILD.tpl +@@ -5,17 +5,16 @@ package(default_visibility = ["//visibility:public"]) + # Point both runtimes to the same python binary to ensure we always + # use the python binary specified by ./configure.py script. + load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair") +-load("@python//:defs.bzl", "interpreter") + + py_runtime( + name = "py2_runtime", +- interpreter_path = interpreter, ++ interpreter_path = "%{PYTHON_BIN_PATH}", + python_version = "PY2", + ) + + py_runtime( + name = "py3_runtime", +- interpreter_path = interpreter, ++ interpreter_path = "%{PYTHON_BIN_PATH}", + python_version = "PY3", + ) + +@@ -33,8 +32,27 @@ toolchain( + exec_compatible_with = [%{PLATFORM_CONSTRAINT}], + ) + +-alias(name = "python_headers", +- actual = "@python//:python_headers") ++# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib ++# See https://docs.python.org/3/extending/windows.html ++cc_import( ++ name = "python_lib", ++ interface_library = select({ ++ ":windows": ":python_import_lib", ++ # A placeholder for Unix platforms which makes --no_build happy. ++ "//conditions:default": "not-existing.lib", ++ }), ++ system_provided = 1, ++) ++ ++cc_library( ++ name = "python_headers", ++ hdrs = [":python_include"], ++ deps = select({ ++ ":windows": [":python_lib"], ++ "//conditions:default": [], ++ }), ++ includes = ["python_include"], ++) + + # This alias is exists for the use of targets in the @llvm-project dependency, + # which expect a python_headers target called @python_runtime//:headers. We use +@@ -45,9 +63,18 @@ alias( + actual = ":python_headers", + ) + ++cc_library( ++ name = "numpy_headers", ++ hdrs = [":numpy_include"], ++ includes = ["numpy_include"], ++) + + config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +-) +\ No newline at end of file ++) ++ ++%{PYTHON_INCLUDE_GENRULE} ++%{NUMPY_INCLUDE_GENRULE} ++%{PYTHON_IMPORT_LIB_GENRULE} +\ No newline at end of file +diff --git a/third_party/py/numpy/BUILD b/third_party/py/numpy/BUILD +index 97c7907fc3..c80cc5287b 100644 +--- a/third_party/py/numpy/BUILD ++++ b/third_party/py/numpy/BUILD +@@ -2,14 +2,15 @@ licenses(["restricted"]) + + package(default_visibility = ["//visibility:public"]) + +-alias( ++py_library( + name = "numpy", +- actual = "@pypi_numpy//:pkg", ++ srcs = ["tf_numpy_dummy.py"], ++ srcs_version = "PY3", + ) + + alias( + name = "headers", +- actual = "@pypi_numpy//:numpy_headers", ++ actual = "@local_config_python//:numpy_headers", + ) + + genrule( +diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl +index 3728a91b93..89732c3e33 100644 +--- a/third_party/py/python_configure.bzl ++++ b/third_party/py/python_configure.bzl +@@ -1,4 +1,9 @@ + """Repository rule for Python autoconfiguration. ++ ++`python_configure` depends on the following environment variables: ++ ++ * `PYTHON_BIN_PATH`: location of python binary. ++ * `PYTHON_LIB_PATH`: Location of python libraries. + """ + + load( +@@ -6,8 +11,192 @@ load( + "BAZEL_SH", + "PYTHON_BIN_PATH", + "PYTHON_LIB_PATH", ++ "TF_PYTHON_CONFIG_REPO", ++ "auto_config_fail", ++ "config_repo_label", ++ "execute", ++ "get_bash_bin", ++ "get_host_environ", ++ "get_python_bin", ++ "is_windows", ++ "raw_exec", ++ "read_dir", + ) + ++def _genrule(src_dir, genrule_name, command, outs): ++ """Returns a string with a genrule. ++ ++ Genrule executes the given command and produces the given outputs. ++ """ ++ return ( ++ "genrule(\n" + ++ ' name = "' + ++ genrule_name + '",\n' + ++ " outs = [\n" + ++ outs + ++ "\n ],\n" + ++ ' cmd = """\n' + ++ command + ++ '\n """,\n' + ++ ")\n" ++ ) ++ ++def _norm_path(path): ++ """Returns a path with '/' and remove the trailing slash.""" ++ path = path.replace("\\", "/") ++ if path[-1] == "/": ++ path = path[:-1] ++ return path ++ ++def _symlink_genrule_for_dir( ++ repository_ctx, ++ src_dir, ++ dest_dir, ++ genrule_name, ++ src_files = [], ++ dest_files = []): ++ """Returns a genrule to symlink(or copy if on Windows) a set of files. ++ ++ If src_dir is passed, files will be read from the given directory; otherwise ++ we assume files are in src_files and dest_files ++ """ ++ if src_dir != None: ++ src_dir = _norm_path(src_dir) ++ dest_dir = _norm_path(dest_dir) ++ files = "\n".join(read_dir(repository_ctx, src_dir)) ++ ++ # Create a list with the src_dir stripped to use for outputs. ++ dest_files = files.replace(src_dir, "").splitlines() ++ src_files = files.splitlines() ++ command = [] ++ outs = [] ++ for i in range(len(dest_files)): ++ if dest_files[i] != "": ++ # If we have only one file to link we do not want to use the dest_dir, as ++ # $(@D) will include the full path to the file. ++ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] ++ ++ # Copy the headers to create a sandboxable setup. ++ cmd = "cp -f" ++ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) ++ outs.append(' "' + dest_dir + dest_files[i] + '",') ++ genrule = _genrule( ++ src_dir, ++ genrule_name, ++ " && ".join(command), ++ "\n".join(outs), ++ ) ++ return genrule ++ ++def _get_python_lib(repository_ctx, python_bin): ++ """Gets the python lib path.""" ++ python_lib = get_host_environ(repository_ctx, PYTHON_LIB_PATH) ++ if python_lib != None: ++ return python_lib ++ ++ # The interesting program to execute. ++ print_lib = [ ++ "from __future__ import print_function", ++ "import site", ++ "import os", ++ "python_paths = []", ++ "if os.getenv('PYTHONPATH') is not None:", ++ " python_paths = os.getenv('PYTHONPATH').split(':')", ++ "try:", ++ " library_paths = site.getsitepackages()", ++ "except AttributeError:", ++ " from distutils.sysconfig import get_python_lib", ++ " library_paths = [get_python_lib()]", ++ "all_paths = set(python_paths + library_paths)", ++ "paths = []", ++ "for path in all_paths:", ++ " if os.path.isdir(path):", ++ " paths.append(path)", ++ "if len(paths) >=1:", ++ " print(paths[0])", ++ ] ++ ++ # The below script writes the above program to a file ++ # and executes it. This is to work around the limitation ++ # of not being able to upload files as part of execute. ++ cmd = "from os import linesep;" ++ cmd += "f = open('script.py', 'w');" ++ for line in print_lib: ++ cmd += "f.write(\"%s\" + linesep);" % line ++ cmd += "f.close();" ++ cmd += "from subprocess import call;" ++ cmd += "call([\"%s\", \"script.py\"]);" % python_bin ++ ++ result = execute(repository_ctx, [python_bin, "-c", cmd]) ++ return result.stdout.strip() ++ ++def _check_python_lib(repository_ctx, python_lib): ++ """Checks the python lib path.""" ++ cmd = 'test -d "%s" -a -x "%s"' % (python_lib, python_lib) ++ result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd]) ++ if result.return_code == 1: ++ auto_config_fail("Invalid python library path: %s" % python_lib) ++ ++def _check_python_bin(repository_ctx, python_bin): ++ """Checks the python bin path.""" ++ cmd = '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin) ++ result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd]) ++ if result.return_code == 1: ++ auto_config_fail("--define %s='%s' is not executable. Is it the python binary?" % ( ++ PYTHON_BIN_PATH, ++ python_bin, ++ )) ++ ++def _get_python_include(repository_ctx, python_bin): ++ """Gets the python include path.""" ++ result = execute( ++ repository_ctx, ++ [ ++ python_bin, ++ "-Wignore", ++ "-c", ++ "import sysconfig; " + ++ "print(sysconfig.get_path('include'))", ++ ], ++ error_msg = "Problem getting python include path.", ++ error_details = ("Is the Python binary path set up right? " + ++ "(See ./configure or " + PYTHON_BIN_PATH + ".) " + ++ "Is distutils installed?"), ++ ) ++ return result.stdout.splitlines()[0] ++ ++def _get_python_import_lib_name(repository_ctx, python_bin): ++ """Get Python import library name (pythonXY.lib) on Windows.""" ++ result = execute( ++ repository_ctx, ++ [ ++ python_bin, ++ "-c", ++ "import sys;" + ++ 'print("python" + str(sys.version_info[0]) + ' + ++ ' str(sys.version_info[1]) + ".lib")', ++ ], ++ error_msg = "Problem getting python import library.", ++ error_details = ("Is the Python binary path set up right? " + ++ "(See ./configure or " + PYTHON_BIN_PATH + ".) "), ++ ) ++ return result.stdout.splitlines()[0] ++ ++def _get_numpy_include(repository_ctx, python_bin): ++ """Gets the numpy include path.""" ++ return execute( ++ repository_ctx, ++ [ ++ python_bin, ++ "-c", ++ "from __future__ import print_function;" + ++ "import numpy;" + ++ " print(numpy.get_include());", ++ ], ++ error_msg = "Problem getting numpy include path.", ++ error_details = "Is numpy installed?", ++ ).stdout.splitlines()[0] ++ + def _create_local_python_repository(repository_ctx): + """Creates the repository containing files set up to build with Python.""" + +@@ -15,14 +204,68 @@ def _create_local_python_repository(repository_ctx): + # function to be restarted with all previous state being lost. This + # can easily lead to a O(n^2) runtime in the number of labels. + build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl")) ++ ++ python_bin = get_python_bin(repository_ctx) ++ _check_python_bin(repository_ctx, python_bin) ++ python_lib = _get_python_lib(repository_ctx, python_bin) ++ _check_python_lib(repository_ctx, python_lib) ++ python_include = _get_python_include(repository_ctx, python_bin) ++ numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy" ++ python_include_rule = _symlink_genrule_for_dir( ++ repository_ctx, ++ python_include, ++ "python_include", ++ "python_include", ++ ) ++ python_import_lib_genrule = "" ++ ++ # To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib ++ # See https://docs.python.org/3/extending/windows.html ++ if is_windows(repository_ctx): ++ python_bin = python_bin.replace("\\", "/") ++ python_include = _norm_path(python_include) ++ python_import_lib_name = _get_python_import_lib_name(repository_ctx, python_bin) ++ python_import_lib_src = python_include.rsplit("/", 1)[0] + "/libs/" + python_import_lib_name ++ python_import_lib_genrule = _symlink_genrule_for_dir( ++ repository_ctx, ++ None, ++ "", ++ "python_import_lib", ++ [python_import_lib_src], ++ [python_import_lib_name], ++ ) ++ numpy_include_rule = _symlink_genrule_for_dir( ++ repository_ctx, ++ numpy_include, ++ "numpy_include/numpy", ++ "numpy_include", ++ ) ++ + platform_constraint = "" + if repository_ctx.attr.platform_constraint: + platform_constraint = "\"%s\"" % repository_ctx.attr.platform_constraint +- repository_ctx.template("BUILD", build_tpl, {"%{PLATFORM_CONSTRAINT}": platform_constraint}) ++ repository_ctx.template("BUILD", build_tpl, { ++ "%{PYTHON_BIN_PATH}": python_bin, ++ "%{PYTHON_INCLUDE_GENRULE}": python_include_rule, ++ "%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule, ++ "%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule, ++ "%{PLATFORM_CONSTRAINT}": platform_constraint, ++ }) ++ ++def _create_remote_python_repository(repository_ctx, remote_config_repo): ++ """Creates pointers to a remotely configured repo set up to build with Python. ++ """ ++ repository_ctx.template("BUILD", config_repo_label(remote_config_repo, ":BUILD"), {}) + + def _python_autoconf_impl(repository_ctx): + """Implementation of the python_autoconf repository rule.""" +- _create_local_python_repository(repository_ctx) ++ if get_host_environ(repository_ctx, TF_PYTHON_CONFIG_REPO) != None: ++ _create_remote_python_repository( ++ repository_ctx, ++ get_host_environ(repository_ctx, TF_PYTHON_CONFIG_REPO), ++ ) ++ else: ++ _create_local_python_repository(repository_ctx) + + _ENVIRONS = [ + BAZEL_SH, +@@ -32,6 +275,7 @@ _ENVIRONS = [ + + local_python_configure = repository_rule( + implementation = _create_local_python_repository, ++ environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "platform_constraint": attr.string(), +@@ -50,6 +294,7 @@ remote_python_configure = repository_rule( + + python_configure = repository_rule( + implementation = _python_autoconf_impl, ++ environ = _ENVIRONS + [TF_PYTHON_CONFIG_REPO], + attrs = { + "platform_constraint": attr.string(), + }, +diff --git a/third_party/tsl/third_party/py/BUILD.tpl b/third_party/tsl/third_party/py/BUILD.tpl +index 7cc1e08568..45480bd4a3 100644 +--- a/third_party/tsl/third_party/py/BUILD.tpl ++++ b/third_party/tsl/third_party/py/BUILD.tpl +@@ -5,17 +5,16 @@ package(default_visibility = ["//visibility:public"]) + # Point both runtimes to the same python binary to ensure we always + # use the python binary specified by ./configure.py script. + load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair") +-load("@python//:defs.bzl", "interpreter") + + py_runtime( + name = "py2_runtime", +- interpreter_path = interpreter, ++ interpreter_path = "%{PYTHON_BIN_PATH}", + python_version = "PY2", + ) + + py_runtime( + name = "py3_runtime", +- interpreter_path = interpreter, ++ interpreter_path = "%{PYTHON_BIN_PATH}", + python_version = "PY3", + ) + +@@ -33,8 +32,27 @@ toolchain( + exec_compatible_with = [%{PLATFORM_CONSTRAINT}], + ) + +-alias(name = "python_headers", +- actual = "@python//:python_headers") ++# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib ++# See https://docs.python.org/3/extending/windows.html ++cc_import( ++ name = "python_lib", ++ interface_library = select({ ++ ":windows": ":python_import_lib", ++ # A placeholder for Unix platforms which makes --no_build happy. ++ "//conditions:default": "not-existing.lib", ++ }), ++ system_provided = 1, ++) ++ ++cc_library( ++ name = "python_headers", ++ hdrs = [":python_include"], ++ deps = select({ ++ ":windows": [":python_lib"], ++ "//conditions:default": [], ++ }), ++ includes = ["python_include"], ++) + + # This alias is exists for the use of targets in the @llvm-project dependency, + # which expect a python_headers target called @python_runtime//:headers. We use +@@ -45,9 +63,18 @@ alias( + actual = ":python_headers", + ) + ++cc_library( ++ name = "numpy_headers", ++ hdrs = [":numpy_include"], ++ includes = ["numpy_include"], ++) + + config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +-) +\ No newline at end of file ++) ++ ++%{PYTHON_INCLUDE_GENRULE} ++%{NUMPY_INCLUDE_GENRULE} ++%{PYTHON_IMPORT_LIB_GENRULE} +\ No newline at end of file +diff --git a/third_party/tsl/third_party/py/numpy/BUILD b/third_party/tsl/third_party/py/numpy/BUILD +index 97c7907fc3..c80cc5287b 100644 +--- a/third_party/tsl/third_party/py/numpy/BUILD ++++ b/third_party/tsl/third_party/py/numpy/BUILD +@@ -2,14 +2,15 @@ licenses(["restricted"]) + + package(default_visibility = ["//visibility:public"]) + +-alias( ++py_library( + name = "numpy", +- actual = "@pypi_numpy//:pkg", ++ srcs = ["tf_numpy_dummy.py"], ++ srcs_version = "PY3", + ) + + alias( + name = "headers", +- actual = "@pypi_numpy//:numpy_headers", ++ actual = "@local_config_python//:numpy_headers", + ) + + genrule( +diff --git a/third_party/tsl/third_party/py/python_configure.bzl b/third_party/tsl/third_party/py/python_configure.bzl +index 3728a91b93..89732c3e33 100644 +--- a/third_party/tsl/third_party/py/python_configure.bzl ++++ b/third_party/tsl/third_party/py/python_configure.bzl +@@ -1,4 +1,9 @@ + """Repository rule for Python autoconfiguration. ++ ++`python_configure` depends on the following environment variables: ++ ++ * `PYTHON_BIN_PATH`: location of python binary. ++ * `PYTHON_LIB_PATH`: Location of python libraries. + """ + + load( +@@ -6,8 +11,192 @@ load( + "BAZEL_SH", + "PYTHON_BIN_PATH", + "PYTHON_LIB_PATH", ++ "TF_PYTHON_CONFIG_REPO", ++ "auto_config_fail", ++ "config_repo_label", ++ "execute", ++ "get_bash_bin", ++ "get_host_environ", ++ "get_python_bin", ++ "is_windows", ++ "raw_exec", ++ "read_dir", + ) + ++def _genrule(src_dir, genrule_name, command, outs): ++ """Returns a string with a genrule. ++ ++ Genrule executes the given command and produces the given outputs. ++ """ ++ return ( ++ "genrule(\n" + ++ ' name = "' + ++ genrule_name + '",\n' + ++ " outs = [\n" + ++ outs + ++ "\n ],\n" + ++ ' cmd = """\n' + ++ command + ++ '\n """,\n' + ++ ")\n" ++ ) ++ ++def _norm_path(path): ++ """Returns a path with '/' and remove the trailing slash.""" ++ path = path.replace("\\", "/") ++ if path[-1] == "/": ++ path = path[:-1] ++ return path ++ ++def _symlink_genrule_for_dir( ++ repository_ctx, ++ src_dir, ++ dest_dir, ++ genrule_name, ++ src_files = [], ++ dest_files = []): ++ """Returns a genrule to symlink(or copy if on Windows) a set of files. ++ ++ If src_dir is passed, files will be read from the given directory; otherwise ++ we assume files are in src_files and dest_files ++ """ ++ if src_dir != None: ++ src_dir = _norm_path(src_dir) ++ dest_dir = _norm_path(dest_dir) ++ files = "\n".join(read_dir(repository_ctx, src_dir)) ++ ++ # Create a list with the src_dir stripped to use for outputs. ++ dest_files = files.replace(src_dir, "").splitlines() ++ src_files = files.splitlines() ++ command = [] ++ outs = [] ++ for i in range(len(dest_files)): ++ if dest_files[i] != "": ++ # If we have only one file to link we do not want to use the dest_dir, as ++ # $(@D) will include the full path to the file. ++ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] ++ ++ # Copy the headers to create a sandboxable setup. ++ cmd = "cp -f" ++ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) ++ outs.append(' "' + dest_dir + dest_files[i] + '",') ++ genrule = _genrule( ++ src_dir, ++ genrule_name, ++ " && ".join(command), ++ "\n".join(outs), ++ ) ++ return genrule ++ ++def _get_python_lib(repository_ctx, python_bin): ++ """Gets the python lib path.""" ++ python_lib = get_host_environ(repository_ctx, PYTHON_LIB_PATH) ++ if python_lib != None: ++ return python_lib ++ ++ # The interesting program to execute. ++ print_lib = [ ++ "from __future__ import print_function", ++ "import site", ++ "import os", ++ "python_paths = []", ++ "if os.getenv('PYTHONPATH') is not None:", ++ " python_paths = os.getenv('PYTHONPATH').split(':')", ++ "try:", ++ " library_paths = site.getsitepackages()", ++ "except AttributeError:", ++ " from distutils.sysconfig import get_python_lib", ++ " library_paths = [get_python_lib()]", ++ "all_paths = set(python_paths + library_paths)", ++ "paths = []", ++ "for path in all_paths:", ++ " if os.path.isdir(path):", ++ " paths.append(path)", ++ "if len(paths) >=1:", ++ " print(paths[0])", ++ ] ++ ++ # The below script writes the above program to a file ++ # and executes it. This is to work around the limitation ++ # of not being able to upload files as part of execute. ++ cmd = "from os import linesep;" ++ cmd += "f = open('script.py', 'w');" ++ for line in print_lib: ++ cmd += "f.write(\"%s\" + linesep);" % line ++ cmd += "f.close();" ++ cmd += "from subprocess import call;" ++ cmd += "call([\"%s\", \"script.py\"]);" % python_bin ++ ++ result = execute(repository_ctx, [python_bin, "-c", cmd]) ++ return result.stdout.strip() ++ ++def _check_python_lib(repository_ctx, python_lib): ++ """Checks the python lib path.""" ++ cmd = 'test -d "%s" -a -x "%s"' % (python_lib, python_lib) ++ result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd]) ++ if result.return_code == 1: ++ auto_config_fail("Invalid python library path: %s" % python_lib) ++ ++def _check_python_bin(repository_ctx, python_bin): ++ """Checks the python bin path.""" ++ cmd = '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin) ++ result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd]) ++ if result.return_code == 1: ++ auto_config_fail("--define %s='%s' is not executable. Is it the python binary?" % ( ++ PYTHON_BIN_PATH, ++ python_bin, ++ )) ++ ++def _get_python_include(repository_ctx, python_bin): ++ """Gets the python include path.""" ++ result = execute( ++ repository_ctx, ++ [ ++ python_bin, ++ "-Wignore", ++ "-c", ++ "import sysconfig; " + ++ "print(sysconfig.get_path('include'))", ++ ], ++ error_msg = "Problem getting python include path.", ++ error_details = ("Is the Python binary path set up right? " + ++ "(See ./configure or " + PYTHON_BIN_PATH + ".) " + ++ "Is distutils installed?"), ++ ) ++ return result.stdout.splitlines()[0] ++ ++def _get_python_import_lib_name(repository_ctx, python_bin): ++ """Get Python import library name (pythonXY.lib) on Windows.""" ++ result = execute( ++ repository_ctx, ++ [ ++ python_bin, ++ "-c", ++ "import sys;" + ++ 'print("python" + str(sys.version_info[0]) + ' + ++ ' str(sys.version_info[1]) + ".lib")', ++ ], ++ error_msg = "Problem getting python import library.", ++ error_details = ("Is the Python binary path set up right? " + ++ "(See ./configure or " + PYTHON_BIN_PATH + ".) "), ++ ) ++ return result.stdout.splitlines()[0] ++ ++def _get_numpy_include(repository_ctx, python_bin): ++ """Gets the numpy include path.""" ++ return execute( ++ repository_ctx, ++ [ ++ python_bin, ++ "-c", ++ "from __future__ import print_function;" + ++ "import numpy;" + ++ " print(numpy.get_include());", ++ ], ++ error_msg = "Problem getting numpy include path.", ++ error_details = "Is numpy installed?", ++ ).stdout.splitlines()[0] ++ + def _create_local_python_repository(repository_ctx): + """Creates the repository containing files set up to build with Python.""" + +@@ -15,14 +204,68 @@ def _create_local_python_repository(repository_ctx): + # function to be restarted with all previous state being lost. This + # can easily lead to a O(n^2) runtime in the number of labels. + build_tpl = repository_ctx.path(Label("//third_party/py:BUILD.tpl")) ++ ++ python_bin = get_python_bin(repository_ctx) ++ _check_python_bin(repository_ctx, python_bin) ++ python_lib = _get_python_lib(repository_ctx, python_bin) ++ _check_python_lib(repository_ctx, python_lib) ++ python_include = _get_python_include(repository_ctx, python_bin) ++ numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy" ++ python_include_rule = _symlink_genrule_for_dir( ++ repository_ctx, ++ python_include, ++ "python_include", ++ "python_include", ++ ) ++ python_import_lib_genrule = "" ++ ++ # To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib ++ # See https://docs.python.org/3/extending/windows.html ++ if is_windows(repository_ctx): ++ python_bin = python_bin.replace("\\", "/") ++ python_include = _norm_path(python_include) ++ python_import_lib_name = _get_python_import_lib_name(repository_ctx, python_bin) ++ python_import_lib_src = python_include.rsplit("/", 1)[0] + "/libs/" + python_import_lib_name ++ python_import_lib_genrule = _symlink_genrule_for_dir( ++ repository_ctx, ++ None, ++ "", ++ "python_import_lib", ++ [python_import_lib_src], ++ [python_import_lib_name], ++ ) ++ numpy_include_rule = _symlink_genrule_for_dir( ++ repository_ctx, ++ numpy_include, ++ "numpy_include/numpy", ++ "numpy_include", ++ ) ++ + platform_constraint = "" + if repository_ctx.attr.platform_constraint: + platform_constraint = "\"%s\"" % repository_ctx.attr.platform_constraint +- repository_ctx.template("BUILD", build_tpl, {"%{PLATFORM_CONSTRAINT}": platform_constraint}) ++ repository_ctx.template("BUILD", build_tpl, { ++ "%{PYTHON_BIN_PATH}": python_bin, ++ "%{PYTHON_INCLUDE_GENRULE}": python_include_rule, ++ "%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule, ++ "%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule, ++ "%{PLATFORM_CONSTRAINT}": platform_constraint, ++ }) ++ ++def _create_remote_python_repository(repository_ctx, remote_config_repo): ++ """Creates pointers to a remotely configured repo set up to build with Python. ++ """ ++ repository_ctx.template("BUILD", config_repo_label(remote_config_repo, ":BUILD"), {}) + + def _python_autoconf_impl(repository_ctx): + """Implementation of the python_autoconf repository rule.""" +- _create_local_python_repository(repository_ctx) ++ if get_host_environ(repository_ctx, TF_PYTHON_CONFIG_REPO) != None: ++ _create_remote_python_repository( ++ repository_ctx, ++ get_host_environ(repository_ctx, TF_PYTHON_CONFIG_REPO), ++ ) ++ else: ++ _create_local_python_repository(repository_ctx) + + _ENVIRONS = [ + BAZEL_SH, +@@ -32,6 +275,7 @@ _ENVIRONS = [ + + local_python_configure = repository_rule( + implementation = _create_local_python_repository, ++ environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "platform_constraint": attr.string(), +@@ -50,6 +294,7 @@ remote_python_configure = repository_rule( + + python_configure = repository_rule( + implementation = _python_autoconf_impl, ++ environ = _ENVIRONS + [TF_PYTHON_CONFIG_REPO], + attrs = { + "platform_constraint": attr.string(), + }, diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 927e0eea..7ad40622 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -16,6 +16,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") def spu_deps(): + _bazel_skylib() _rules_cuda() _rules_proto_grpc() _bazel_platform() @@ -123,9 +124,20 @@ def _com_github_xtensor_xtl(): ], ) +def _bazel_skylib(): + maybe( + http_archive, + name = "bazel_skylib", + sha256 = "9f38886a40548c6e96c106b752f242130ee11aaa068a56ba7e56f4511f33e4f2", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.6.1/bazel-skylib-1.6.1.tar.gz", + "https://github.com/bazelbuild/bazel-skylib/releases/download/1.6.1/bazel-skylib-1.6.1.tar.gz", + ], + ) + def _com_github_openxla_xla(): - OPENXLA_COMMIT = "5f70248ff0e9702544c8eeea0ab9b03e1ef144b0" - OPENXLA_SHA256 = "e2db58c41b7160259e0ec109ecbfc9c4b07c0889312719a19796ab30a970ba9e" + OPENXLA_COMMIT = "d9d0e780ff6a37c4d501c8e0e4f4a9fdca30cbd4" + OPENXLA_SHA256 = "77ef83491f409afbe549a2bd695d710a70fdf7f04db35eeb1fba3e97ef767113" # We need openxla to handle xla/mhlo/stablehlo maybe( @@ -137,6 +149,8 @@ def _com_github_openxla_xla(): urls = [ "https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = OPENXLA_COMMIT), ], + patch_args = ["-p1", "-l"], + patches = ["@spulib//bazel:patches/xla-non-hermetic-python.patch"], ) def _com_github_pybind11_bazel(): diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 1d824aaa..0865e964 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -1235,14 +1235,16 @@ static void dispatchOp(OpExecutor *executor, SPUContext *sctx, const auto fn_name = op.getName().getStringRef().str(); if constexpr (std::is_same_v) { + // trace action holds RAII, we can not put it in a single scope SPU_TRACE_ACTION( GET_TRACER(sctx), sctx->lctx(), (TR_HLO | TR_LAR), ~TR_HLO, fmt::format("{}: {}", fn_name, casted.getCallTargetName().str())); + execute(executor, sctx, sscope, casted, opts); } else { SPU_TRACE_ACTION(GET_TRACER(sctx), sctx->lctx(), (TR_HLO | TR_LAR), ~TR_HLO, fn_name); + execute(executor, sctx, sscope, casted, opts); } - execute(executor, sctx, sscope, casted, opts); } // currently we only support config verifier statically. diff --git a/libspu/dialect/pphlo/print_parse.cc b/libspu/dialect/pphlo/print_parse.cc index 27b02c35..2865db59 100644 --- a/libspu/dialect/pphlo/print_parse.cc +++ b/libspu/dialect/pphlo/print_parse.cc @@ -333,7 +333,7 @@ ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { StringRef innerOpName = innerOpNameInfo->getStringRef(); Dialect* innerOpDialect = innerOpNameInfo->getDialect(); if ((innerOpDialect == nullptr) || - !innerOpDialect->getNamespace().equals("pphlo") || + !(innerOpDialect->getNamespace() == "pphlo") || !innerOpNameInfo->hasTrait::Impl>() || !innerOpNameInfo->hasTrait() || !innerOpNameInfo->hasTrait() || diff --git a/requirements-dev.txt b/requirements-dev.txt index 906b235f..3f5993ec 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,5 +3,5 @@ flax scikit-learn # for tests absl-py>=1.1.0 -tensorflow>=2.12.0 +tensorflow-cpu>=2.12.0 h5py!=3.11.0; platform_machine == 'aarch64' diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index 1369c40b..a900040b 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -276,8 +276,10 @@ def torch_compile( assert isinstance( fn, torch.export.ExportedProgram ), "input should be an exported torch model" - os.environ['PJRT_DEVICE'] = 'CPU' + # remove xla flags imported by torch-xla + os.unsetenv("XLA_FLAGS") + options = stablehlo.StableHLOExportOptions() options.override_tracing_arguments = m_args_flat shlo = stablehlo.exported_program_to_stablehlo(fn, options)