Skip to content

Commit

Permalink
repo-sync-2024-03-08T18:25:24+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Mar 8, 2024
1 parent 9553904 commit cc55fc2
Show file tree
Hide file tree
Showing 115 changed files with 3,690 additions and 1,929 deletions.
40 changes: 0 additions & 40 deletions .circleci/release-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,6 @@ parameters:
# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
jobs:
macOS_x64_publish:
macos:
xcode: 15.1
resource_class: macos.x86.medium.gen2
parameters:
python_ver:
type: string
steps:
- checkout
- run:
name: "Install homebrew dependencies"
command: |
brew install bazelisk cmake ninja nasm libomp wget go
- run:
name: "Install Miniconda"
command: |
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O ~/miniconda.sh
bash ~/miniconda.sh -b -p $HOME/miniconda
source $HOME/miniconda/bin/activate
conda init zsh bash
- run:
name: "build package and publish"
command: |
set +e
conda create -n build python=<< parameters.python_ver >> -y
conda activate build
sh ./build_wheel_entrypoint.sh
python3 -m pip install twine
ls dist/*.whl
python3 -m twine upload -r pypi -u __token__ -p ${PYPI_TWINE_TOKEN} dist/*.whl
macOS_arm64_publish:
macos:
xcode: 15.1
Expand Down Expand Up @@ -158,14 +126,6 @@ workflows:
filters:
tags:
only: /.*/
- macOS_x64_publish:
matrix:
parameters:
python_ver: ["3.9", "3.10", "3.11"]
# This is mandatory to trigger a pipeline when pushing a tag
filters:
tags:
only: /.*/
- macOS_arm64_publish:
matrix:
parameters:
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
>
> please add your unreleased change here.
- [Feature] Add minimax approximation for log
- [Feature] Support jax.lax.top_k
- [Improvement] Default log approximation to minmax
- [Improvement] Improve median performance

## 20240306

- [Feature] Support more generic Torch model inference
- [Improvement] Optimize one-time setup for yacl ot
- [Improvement] Optimize sort performance
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ This documentation also contains instructions for [build and testing](CONTRIBUTI
| | Linux x86_64 | Linux aarch64 | macOS x64 | macOS Apple Silicon | Windows x64 | Windows WSL2 x64 |
|------------|--------------|---------------|----------------|---------------------|----------------|---------------------|
| CPU | yes | yes | yes<sup>1</sup>| yes | no | yes |
| NVIDIA GPU | experimental | no | no | n/a | no | experimental |
| NVIDIA GPU | experimental | no | no | n/a | no | experimental |

1. Due to CI resource limitation, macOS x64 prebuild binary will no longer available since next release (0.9.x).
1. Due to CI resource limitation, macOS x64 prebuild binary is no longer available.

### Instructions

Expand Down
16 changes: 15 additions & 1 deletion bazel/patches/seal.patch
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,18 @@ index dabd3bab..afaa71dc 100644
+ else
{
inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J));
}
}

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1a7a2bfd..bc4ad9d9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -223,7 +223,7 @@ if(SEAL_USE_INTEL_HEXL)
message(STATUS "Intel HEXL: download ...")
seal_fetch_thirdparty_content(ExternalIntelHEXL)
else()
- find_package(HEXL 1.2.4)
+ find_package(HEXL 1.2.5)
if (NOT TARGET HEXL::hexl)
message(FATAL_ERROR "Intel HEXL: not found")
endif()
4 changes: 2 additions & 2 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "d1cf2382e57b1efba3bb17d6dd9d8657453405ca"
OPENXLA_SHA256 = "a7f439d54a4e35c7977c2ea17b3a2493b306c9629ccc8071b4962c905ac9f692"
OPENXLA_COMMIT = "495516d2d0b4453d5831905e152594614c8b4797"
OPENXLA_SHA256 = "13f6490065db594c6a7f9914e59213b6785ceb81af1f2cb28d5409f3f18aac8e"

maybe(
http_archive,
Expand Down
10 changes: 8 additions & 2 deletions bazel/seal.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,17 @@ x64_hexl_config = {

spu_cmake_external(
name = "seal",
cache_entries = default_config,
cache_entries = select({
":can_use_hexl": x64_hexl_config,
"//conditions:default": default_config,
}),
lib_source = "@com_github_microsoft_seal//:all",
out_include_dir = "include/SEAL-4.1",
out_static_libs = ["libseal-4.1.a"],
deps = [
"@com_github_facebook_zstd//:zstd",
],
] + select({
"@platforms//cpu:x86_64": ["@com_intel_hexl//:hexl"],
"//conditions:default": [],
}),
)
26 changes: 12 additions & 14 deletions examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

# Start nodes.
# > bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama_split/3pc.json" up
# Run this example script.
# > bazel run -c opt //examples/python/ml/flax_llama7b -- --config `pwd`/examples/python/ml/flax_llama_split/3pc.json
import time
import argparse
import json
from contextlib import contextmanager
from typing import Any, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.nn as jnn
import flax.linen as nn
from flax.linen.linear import Array
from typing import Any, Optional, Tuple, Union
from transformers import LlamaTokenizer
import jax.numpy as jnp
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLM
from EasyLM.models.llama.llama_model_splited_transformer import (
FlaxLLaMAForCausalLMClient,
FlaxLLaMAForCausalLMMid,
FlaxLLaMAForCausalLMServer,
FlaxLLaMAModule,
FlaxLLaMAForCausalLMMid,
LLaMAConfig,
)
from flax.linen.linear import Array
from transformers import LlamaTokenizer


import spu.utils.distributed as ppd
from contextlib import contextmanager
import spu.spu_pb2 as spu_pb2

from flax.linen.linear import Array
from typing import Any, Optional, Tuple, Union
import spu.utils.distributed as ppd

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,55 +16,49 @@
# Original Source Code Form
# [EasyLM](https://github.com/young-geng/EasyLM/tree/main)

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import os
import tempfile
from functools import partial
from jax import jit
import numpy as np
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import PartitionSpec as PS
import flax.linen as nn
import numpy as np
import sentencepiece as spm
from EasyLM.bpt import blockwise_attn, blockwise_ffn
from EasyLM.jax_utils import (
get_gradient_checkpoint_policy,
get_jax_mesh,
with_sharding_constraint,
)
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.linen import partitioning as nn_partitioning
import einops

import sentencepiece as spm
from jax import jit, lax
from jax.sharding import PartitionSpec as PS
from ml_collections import ConfigDict
from ml_collections.config_dict import config_dict
from mlxu import function_args_to_config, load_pickle, open_file
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
)
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)


from ml_collections import ConfigDict
from ml_collections.config_dict import config_dict
from mlxu import function_args_to_config, load_pickle, open_file

from EasyLM.bpt import blockwise_ffn, blockwise_attn
from EasyLM.jax_utils import (
with_sharding_constraint,
get_jax_mesh,
get_gradient_checkpoint_policy,
)


LLAMA_STANDARD_CONFIGS = {
'7b': {
'vocab_size': 32000,
Expand Down
7 changes: 3 additions & 4 deletions examples/python/ml/flax_resnet/flax_resnet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import json

import jax
import spu.utils.distributed as ppd

import spu.utils.distributed as ppd

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument("-c", "--config", default="3pc.json")
Expand All @@ -29,14 +29,13 @@
ppd.init(conf["nodes"], conf["devices"])


from datasets import load_dataset
from transformers import (
AutoImageProcessor,
AutoConfig,
AutoImageProcessor,
FlaxResNetForImageClassification,
)

from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

Expand Down
2 changes: 1 addition & 1 deletion examples/python/ml/flax_whisper/flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
import os

import jax.numpy as jnp
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
from datasets import load_dataset
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

import spu.utils.distributed as ppd
from spu import spu_pb2
Expand Down
8 changes: 3 additions & 5 deletions examples/python/ml/ml_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import unittest
from time import perf_counter

import multiprocess
import numpy.testing as npt
import pandas as pd

import spu.utils.distributed as ppd
from spu.utils.polyfill import Process

with open("examples/python/conf/3pc.json", 'r') as file:
conf = json.load(file)
Expand Down Expand Up @@ -70,15 +70,13 @@ class UnitTests(unittest.TestCase):
def setUpClass(cls):
cls.workers = []
for node_id in conf["nodes"].keys():
worker = multiprocess.Process(
target=ppd.RPC.serve, args=(node_id, conf["nodes"])
)
worker = Process(target=ppd.RPC.serve, args=(node_id, conf["nodes"]))
worker.start()
cls.workers.append(worker)
import time

# wait for all process serving.
time.sleep(0.05)
time.sleep(2)

rt_config = conf["devices"]["SPU"]["config"]["runtime_config"]
rt_config["enable_pphlo_profile"] = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import spu.utils.distributed as ppd


# Start nodes.
# > bazel run -c opt //examples/python/utils:nodectl -- up
#
Expand Down Expand Up @@ -114,6 +113,7 @@ def run_inference_on_cpu(model):
ppd.init(conf["nodes"], conf["devices"], framework=ppd.Framework.EXP_TORCH)

from collections import OrderedDict

from jax.tree_util import tree_map


Expand Down
7 changes: 2 additions & 5 deletions examples/python/utils/nodectl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import argparse
import json

import multiprocess

import spu.utils.distributed as ppd
from spu.utils.polyfill import Process

parser = argparse.ArgumentParser(description='SPU node service.')
parser.add_argument(
Expand All @@ -44,9 +43,7 @@
elif args.command == 'up':
workers = []
for node_id in nodes_def.keys():
worker = multiprocess.Process(
target=ppd.RPC.serve, args=(node_id, nodes_def)
)
worker = Process(target=ppd.RPC.serve, args=(node_id, nodes_def))
worker.start()
workers.append(worker)

Expand Down
4 changes: 4 additions & 0 deletions libspu/compiler/core/core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ void Core::buildPipeline(mlir::PassManager *pm) {
optPM.addPass(mlir::spu::pphlo::createDecomposeMinMaxPass());
optPM.addPass(mlir::spu::pphlo::createSortLowering());

if (!options.disable_partial_sort_optimization()) {
optPM.addPass(mlir::spu::pphlo::createPartialSortToTopK());
}

if (!options.disable_sqrt_plus_epsilon_rewrite()) {
optPM.addPass(mlir::spu::pphlo::createOptimizeSqrtPlusEps());
}
Expand Down
Loading

0 comments on commit cc55fc2

Please sign in to comment.