From b77e76855c735b95a5a9f6f45dbdc54b16003db0 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Wed, 21 Aug 2024 14:56:17 -0700 Subject: [PATCH] Cherry Pick v4 (#801) Cherry pick PRs and update version to 0.4.0 --- .github/workflows/linux-cpu-arm64-build.yml | 4 +- .github/workflows/linux-gpu-x64-build.yml | 4 +- .github/workflows/win-cpu-x64-build.yml | 2 - .github/workflows/win-cuda-x64-build.yml | 3 +- .../stages/jobs/nuget-validation-job.yml | 21 ++-- .pipelines/stages/jobs/py-validation-job.yml | 27 ++--- VERSION_INFO | 2 +- cmake/global_variables.cmake | 8 +- examples/csharp/HelloPhi/HelloPhi.csproj | 6 +- examples/csharp/HelloPhi/Program.cs | 18 +-- nuget/MANAGED_PACKAGE.md | 3 + ...crosoft.ML.OnnxRuntimeGenAI.Managed.nuspec | 2 +- src/csharp/README.md => nuget/PACKAGE.md | 0 .../Microsoft.ML.OnnxRuntimeGenAI.csproj | 7 +- src/csharp/Utils.cs | 24 +++- src/ort_genai.h | 6 + src/ort_genai_c.cpp | 44 +++++++ src/ort_genai_c.h | 8 ++ src/python/py/models/builder.py | 8 +- src/python/py/models/quantized_model.py | 41 +++---- test/c_api_tests.cpp | 59 ++++++++++ test/python/_test_utils.py | 109 +++++++++++++----- test/python/conftest.py | 15 ++- test/python/test_onnxruntime_genai.py | 43 +++---- test/python/test_onnxruntime_genai_e2e.py | 67 ++++++----- .../nuget/generate_nuspec_for_native_nuget.py | 4 +- 26 files changed, 373 insertions(+), 162 deletions(-) create mode 100644 nuget/MANAGED_PACKAGE.md rename src/csharp/README.md => nuget/PACKAGE.md (100%) diff --git a/.github/workflows/linux-cpu-arm64-build.yml b/.github/workflows/linux-cpu-arm64-build.yml index 5d8fe12b4..b58649de0 100644 --- a/.github/workflows/linux-cpu-arm64-build.yml +++ b/.github/workflows/linux-cpu-arm64-build.yml @@ -72,7 +72,7 @@ jobs: --container-registry onnxruntimebuildcache \ --repository ort_genai_linux_arm64_gha - - name: Doker -- Configure with CMake and GCC + - name: Docker -- Configure with CMake and GCC run: | docker run --rm \ --volume $GITHUB_WORKSPACE:/onnxruntime_src \ @@ -84,7 +84,7 @@ jobs: --volume $GITHUB_WORKSPACE:/onnxruntime_src \ -w /onnxruntime_src ort_genai_linux_arm64_gha bash -c "/usr/bin/cmake --build --preset linux_gcc_cpu_release" - - name: Dokcer -- check test directory + - name: Docker -- Check test directory run: | docker run --rm \ --volume $GITHUB_WORKSPACE:/onnxruntime_src \ diff --git a/.github/workflows/linux-gpu-x64-build.yml b/.github/workflows/linux-gpu-x64-build.yml index e24fac8dd..0af22cbb2 100644 --- a/.github/workflows/linux-gpu-x64-build.yml +++ b/.github/workflows/linux-gpu-x64-build.yml @@ -129,13 +129,14 @@ jobs: docker run \ --gpus all \ --rm \ + --volume /data/ortgenai_pytorch_models:/data/ortgenai_pytorch_models \ --volume $GITHUB_WORKSPACE:/ort_genai_src \ -e HF_TOKEN=$HF_TOKEN \ -w /ort_genai_src onnxruntimecudabuildx64 bash -c " \ ${{ env.PYTHON_EXECUTABLE }} -m pip install -r test/python/requirements.txt --user && \ ${{ env.PYTHON_EXECUTABLE }} -m pip install -r test/python/requirements-cuda.txt --user && \ ${{ env.PYTHON_EXECUTABLE }} -m pip install /ort_genai_src/build/cuda/wheel/onnxruntime_genai*manylinux*.whl --user && \ - ${{ env.PYTHON_EXECUTABLE }} test/python/test_onnxruntime_genai.py --cwd test/python --test_models test/test_models" + ${{ env.PYTHON_EXECUTABLE }} test/python/test_onnxruntime_genai.py --cwd test/python --test_models test/test_models --e2e" - name: Docker -- Run unit tests run: | @@ -143,5 +144,6 @@ jobs: docker run \ --gpus all \ --rm \ + --volume /data/ortgenai_pytorch_models:/data/ortgenai_pytorch_models \ --volume 
$GITHUB_WORKSPACE:/ort_genai_src \ -w /ort_genai_src onnxruntimecudabuildx64 bash -c "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/ort_genai_src/build/cuda/ /ort_genai_src/build/cuda/test/unit_tests" diff --git a/.github/workflows/win-cpu-x64-build.yml b/.github/workflows/win-cpu-x64-build.yml index 9b6415550..c4c8855a5 100644 --- a/.github/workflows/win-cpu-x64-build.yml +++ b/.github/workflows/win-cpu-x64-build.yml @@ -94,8 +94,6 @@ jobs: run: | python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" - - - name: Verify Build Artifacts if: always() continue-on-error: true diff --git a/.github/workflows/win-cuda-x64-build.yml b/.github/workflows/win-cuda-x64-build.yml index 83dca83b9..36ec38d04 100644 --- a/.github/workflows/win-cuda-x64-build.yml +++ b/.github/workflows/win-cuda-x64-build.yml @@ -93,8 +93,7 @@ jobs: - name: Run the Python Tests run: | - python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" - + python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" --e2e - name: Verify Build Artifacts if: always() diff --git a/.pipelines/stages/jobs/nuget-validation-job.yml b/.pipelines/stages/jobs/nuget-validation-job.yml index 6d69930c0..656a1dc32 100644 --- a/.pipelines/stages/jobs/nuget-validation-job.yml +++ b/.pipelines/stages/jobs/nuget-validation-job.yml @@ -116,15 +116,16 @@ jobs: inputs: version: '8.x' - - template: steps/utils/download-huggingface-model.yml - parameters: - StepName: 'Download Model from HuggingFace' - HuggingFaceRepo: 'microsoft/Phi-3-mini-4k-instruct-onnx' - RepoFolder: $(prebuild_phi3_mini_model_folder) - LocalFolder: 'models' - WorkingDirectory: '$(Build.Repository.LocalPath)/examples/csharp/HelloPhi' - HuggingFaceToken: $(HF_TOKEN) - os: ${{ parameters.os }} + - ${{ if ne(parameters.arch, 'arm64') }}: + - template: steps/utils/download-huggingface-model.yml + parameters: + StepName: 'Download Model from HuggingFace' + HuggingFaceRepo: 'microsoft/Phi-3-mini-4k-instruct-onnx' + RepoFolder: $(prebuild_phi3_mini_model_folder) + LocalFolder: 'models' + WorkingDirectory: '$(Build.Repository.LocalPath)/examples/csharp/HelloPhi' + HuggingFaceToken: $(HF_TOKEN) + os: ${{ parameters.os }} - template: steps/utils//flex-download-pipeline-artifact.yml parameters: @@ -134,7 +135,7 @@ jobs: SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} - - ${{ if eq(parameters.os, 'win') }}: + - ${{ if and(eq(parameters.os, 'win'), ne(parameters.arch, 'arm64')) }}: - ${{ if eq(parameters.ep, 'cuda') }}: - powershell: | $env:AZCOPY_MSI_CLIENT_ID = "63b63039-6328-442f-954b-5a64d124e5b4"; diff --git a/.pipelines/stages/jobs/py-validation-job.yml b/.pipelines/stages/jobs/py-validation-job.yml index 6e3bd6625..53b8dffc7 100644 --- a/.pipelines/stages/jobs/py-validation-job.yml +++ b/.pipelines/stages/jobs/py-validation-job.yml @@ -164,15 +164,16 @@ jobs: SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} - - template: steps/utils/download-huggingface-model.yml - parameters: - StepName: 'Download Model from HuggingFace' - HuggingFaceRepo: 'microsoft/Phi-3-mini-4k-instruct-onnx' - RepoFolder: $(prebuild_phi3_mini_model_folder) - LocalFolder: 'models' - WorkingDirectory: '$(Build.Repository.LocalPath)/examples/python' - HuggingFaceToken: $(HF_TOKEN) - os: ${{ parameters.os }} + - ${{ if ne(parameters.arch, 'arm64') }}: + - template: steps/utils/download-huggingface-model.yml + parameters: + StepName: 
'Download Model from HuggingFace' + HuggingFaceRepo: 'microsoft/Phi-3-mini-4k-instruct-onnx' + RepoFolder: $(prebuild_phi3_mini_model_folder) + LocalFolder: 'models' + WorkingDirectory: '$(Build.Repository.LocalPath)/examples/python' + HuggingFaceToken: $(HF_TOKEN) + os: ${{ parameters.os }} - ${{ if eq(parameters.os, 'linux') }}: - ${{ if eq(parameters.ep, 'cuda') }}: @@ -195,7 +196,7 @@ jobs: $python_exe -m pip install -r /ort_genai_src/test/python/requirements.txt && \ $python_exe -m pip install -r /ort_genai_src/test/python/requirements-cuda.txt && \ cd /ort_genai_src/examples/python && \ - $python_exe -m pip install --no-index --find-links=/ort_genai_binary/wheel $(pip_package_name) && \ + $python_exe -m pip install --find-links=/ort_genai_binary/wheel $(pip_package_name) && \ $python_exe model-generate.py -m ./models/$(prebuild_phi3_mini_model_folder) --min_length 25 --max_length 50 --verbose" displayName: 'Run Example With Artifact' @@ -206,12 +207,12 @@ jobs: python -m pip install -r test/python/requirements.txt python -m pip install -r test/python/requirements-cpu.txt cd examples/python - python -m pip install --no-index --find-links=$(Build.BinariesDirectory)/wheel $(pip_package_name) + python -m pip install --find-links=$(Build.BinariesDirectory)/wheel $(pip_package_name) python model-generate.py -m ./models/$(prebuild_phi3_mini_model_folder) --min_length 25 --max_length 50 --verbose displayName: 'Run Example With Artifact' workingDirectory: '$(Build.Repository.LocalPath)' - - ${{ if eq(parameters.os, 'win') }}: + - ${{ if and(eq(parameters.os, 'win'), ne(parameters.arch, 'arm64'), ne(parameters.ep, 'directml')) }}: - ${{ if eq(parameters.ep, 'cuda') }}: - powershell: | $env:AZCOPY_MSI_CLIENT_ID = "63b63039-6328-442f-954b-5a64d124e5b4"; @@ -233,7 +234,7 @@ jobs: python -m pip install -r test/python/requirements-cpu.txt } cd examples\python - python -m pip install --no-index --find-links=$(Build.BinariesDirectory)/wheel $(pip_package_name) + python -m pip install --find-links=$(Build.BinariesDirectory)/wheel $(pip_package_name) python model-generate.py -m .\models\$(prebuild_phi3_mini_model_folder) --min_length 25 --max_length 50 --verbose displayName: 'Run Example With Artifact' diff --git a/VERSION_INFO b/VERSION_INFO index e33438f1c..60a2d3e96 100644 --- a/VERSION_INFO +++ b/VERSION_INFO @@ -1 +1 @@ -0.4.0-rc1 \ No newline at end of file +0.4.0 \ No newline at end of file diff --git a/cmake/global_variables.cmake b/cmake/global_variables.cmake index 3081d7915..1cced02ec 100644 --- a/cmake/global_variables.cmake +++ b/cmake/global_variables.cmake @@ -13,7 +13,13 @@ set(VERSION_INFO ${ver}) # VERSION_PATCH: 0 string(REPLACE "-" ";" VERSION_LIST ${VERSION_INFO}) list(GET VERSION_LIST 0 VERSION_STR) -list(GET VERSION_LIST 1 VERSION_SUFFIX) +# Check if it is a stable or dev version +list(LENGTH VERSION_LIST VERSION_LIST_LENGTH) +if(VERSION_LIST_LENGTH GREATER 1) + list(GET VERSION_LIST 1 VERSION_SUFFIX) +else() + set(VERSION_SUFFIX "") # Set VERSION_SUFFIX to empty if stable version +endif() string(REPLACE "." 
";" VERSION_LIST ${VERSION_STR}) list(GET VERSION_LIST 0 VERSION_MAJOR) list(GET VERSION_LIST 1 VERSION_MINOR) diff --git a/examples/csharp/HelloPhi/HelloPhi.csproj b/examples/csharp/HelloPhi/HelloPhi.csproj index 43dfa6838..a49d8c081 100644 --- a/examples/csharp/HelloPhi/HelloPhi.csproj +++ b/examples/csharp/HelloPhi/HelloPhi.csproj @@ -9,9 +9,9 @@ - - - + + + diff --git a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs index 5d554ead6..49e9b945b 100644 --- a/examples/csharp/HelloPhi/Program.cs +++ b/examples/csharp/HelloPhi/Program.cs @@ -5,10 +5,10 @@ void PrintUsage() { Console.WriteLine("Usage:"); Console.WriteLine(" -m model_path"); - Console.WriteLine(" -i (optional): Intereactive mode"); + Console.WriteLine(" -i (optional): Interactive mode"); } -OgaHandle ogaHandle = new OgaHandle(); +using OgaHandle ogaHandle = new OgaHandle(); if (args.Length < 1) { @@ -16,7 +16,7 @@ void PrintUsage() Environment.Exit(-1); } -bool intereactive = false; +bool interactive = false; string modelPath = string.Empty; uint i = 0; @@ -25,7 +25,7 @@ void PrintUsage() var arg = args[i]; if (arg == "-i") { - intereactive = true; + interactive = true; } else if (arg == "-m") { @@ -47,13 +47,13 @@ void PrintUsage() Console.WriteLine("-------------"); Console.WriteLine("Model path: " + modelPath); -Console.WriteLine("Intereactive: " + intereactive); +Console.WriteLine("Interactive: " + interactive); using Model model = new Model(modelPath); using Tokenizer tokenizer = new Tokenizer(model); var option = 2; -if (intereactive) +if (interactive) { Console.WriteLine("Please enter option number:"); Console.WriteLine("1. Complete Output"); @@ -64,7 +64,7 @@ void PrintUsage() do { string prompt = "def is_prime(num):"; // Example prompt - if (intereactive) + if (interactive) { Console.WriteLine("Prompt:"); prompt = Console.ReadLine(); @@ -72,7 +72,7 @@ void PrintUsage() if (string.IsNullOrEmpty(prompt)) { continue; - } + } var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>"); using GeneratorParams generatorParams = new GeneratorParams(model); @@ -99,4 +99,4 @@ void PrintUsage() } Console.WriteLine(); } -} while (intereactive); \ No newline at end of file +} while (interactive); diff --git a/nuget/MANAGED_PACKAGE.md b/nuget/MANAGED_PACKAGE.md new file mode 100644 index 000000000..8d3dc0fb1 --- /dev/null +++ b/nuget/MANAGED_PACKAGE.md @@ -0,0 +1,3 @@ +## About + +This package is a dependency of [Microsoft.ML.OnnxRuntimeGenAI](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntimeGenAI) and does not need to be installed directly. 
diff --git a/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec b/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec index 74b43a437..f4d06bc0a 100644 --- a/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec +++ b/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec @@ -19,7 +19,7 @@ - + diff --git a/src/csharp/README.md b/nuget/PACKAGE.md similarity index 100% rename from src/csharp/README.md rename to nuget/PACKAGE.md diff --git a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj index 1cd52e2e2..e1f8a66b8 100644 --- a/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj +++ b/src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj @@ -36,10 +36,15 @@ - + $(VersionInfoStr.Split(-)[0]) $(VersionInfoStr.Split(-)[1]) + + + $(VersionInfoStr) + + diff --git a/src/csharp/Utils.cs b/src/csharp/Utils.cs index b84f1d407..90d007bc7 100644 --- a/src/csharp/Utils.cs +++ b/src/csharp/Utils.cs @@ -7,11 +7,33 @@ namespace Microsoft.ML.OnnxRuntimeGenAI { - public class OgaHandle + public class OgaHandle: IDisposable { + private bool _disposed = false; + + public OgaHandle() + { + } + ~OgaHandle() { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } NativeMethods.OgaShutdown(); + _disposed = true; } } diff --git a/src/ort_genai.h b/src/ort_genai.h index d0c1d0c75..4a83b69e2 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -232,6 +232,12 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_GetSequenceData(this, index); } + std::unique_ptr GetOutput(const char* name) { + OgaTensor* out; + OgaCheckResult(OgaGenerator_GetOutput(this, name, &out)); + return std::unique_ptr(out); + } + #if __cplusplus >= 202002L std::span GetSequence(size_t index) const { return {GetSequenceData(index), GetSequenceCount(index)}; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 6f26d2857..a40807845 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -208,6 +208,50 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) { + OGA_TRY + auto& generator = *reinterpret_cast(oga_generator); + auto* ortvalue_output = generator.state_->GetOutput(name); + auto type_info = ortvalue_output->GetTensorTypeAndShapeInfo(); + std::unique_ptr ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_, + type_info->GetShape(), + type_info->GetElementType()); + // Copy data to ortvalue_clone + auto element_size = Generators::SizeOf(type_info->GetElementType()); + auto data_size = type_info->GetElementCount() * element_size; + if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) { +#if USE_CUDA + cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost); +#endif + } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) { +#if USE_DML + ComPtr gpu_resource; + Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation( + generator.model_->allocator_device_, + ortvalue_output->GetTensorMutableRawData(), + &gpu_resource)); + auto cpu_tensor = 
ortvalue_clone->GetTensorMutableRawData(); + generator.model_->GetDmlReadbackHeap()->ReadbackFromGpu( + std::span(reinterpret_cast(cpu_tensor), data_size), + gpu_resource.Get(), + 0, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS); +#endif + } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) { + std::copy(static_cast(ortvalue_output->GetTensorMutableRawData()), + static_cast(ortvalue_output->GetTensorMutableRawData()) + data_size, + static_cast(ortvalue_clone->GetTensorMutableRawData())); + } else { + throw std::runtime_error("Unsupported Device type: " + ortvalue_output->GetTensorMemoryInfo().GetDeviceType()); + } + + auto tensor = std::make_shared(std::move(ortvalue_clone)); + tensor->external_owner_ = tensor; + *out = reinterpret_cast(tensor.get()); + return nullptr; + OGA_CATCH +} + size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* oga_generator, size_t index) { auto& generator = *reinterpret_cast(oga_generator); return generator.GetSequence(static_cast(index)).GetCPU().size(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index ec97ce4e5..7b1f084c2 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -224,6 +224,14 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); +/* + * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor + * and will be released when the OgaTensor is destroyed + * \param[in] generator The generator to run the GetOutput on the name provided and the out pointer to store the output + * \return OgaResult containing the error message if the computation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out); + /* * \brief Returns the number of tokens in the sequence at the given index. * \param[in] generator The generator to get the count of the tokens for the sequence at the given index. 
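Note: the easiest way to exercise the new OgaGenerator_GetOutput entry point declared above is the C++ convenience wrapper added in src/ort_genai.h (OgaGenerator::GetOutput). The sketch below mirrors the GetOutputCAPI test added later in this patch; the model path and input IDs are placeholders rather than anything shipped with the patch, and errors surface as exceptions via OgaCheckResult inside the wrappers.

    #include <cstdint>
    #include <vector>
    #include "ort_genai.h"

    void InspectLogits() {
      auto model = OgaModel::Create("path/to/model");   // placeholder model directory
      auto params = OgaGeneratorParams::Create(*model);
      params->SetSearchOption("max_length", 10);

      std::vector<int32_t> input_ids{0, 0, 0, 52};      // placeholder tokens, batch size 1
      params->SetInputIDs(input_ids.data(), input_ids.size(), input_ids.size(), 1);

      auto generator = OgaGenerator::Create(*model, *params);
      generator->ComputeLogits();

      // GetOutput returns a CPU copy of the named model output; the buffer stays
      // valid for the lifetime of the returned OgaTensor.
      auto logits = generator->GetOutput("logits");
      auto* data = static_cast<float*>(logits->Data());
      (void)data;  // ... inspect or compare logits here ...
    }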
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 90a1cd2fb..fb6523b33 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -32,7 +32,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers self.vocab_size = config.vocab_size - self.activation = config.hidden_activation if hasattr(config, "hidden_activation") else config.hidden_act + self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act self.model_name_or_path = config._name_or_path self.model_type = config.architectures[0] @@ -1608,11 +1608,11 @@ def make_model(self, input_path): from onnxruntime_genai.models.quantized_model import QuantModel q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size - model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size) + model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers) else: # Load PyTorch model - extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir} - model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs) + extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} + model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, cache_dir=self.cache_dir, use_auth_token=True, trust_remote_code=True, **extra_kwargs) # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py index 48c4ec7bd..5452fe563 100644 --- a/src/python/py/models/quantized_model.py +++ b/src/python/py/models/quantized_model.py @@ -83,17 +83,18 @@ def is_empty(self): class QuantizedModel: - def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size): + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers): self.quant_type = quant_type self.embedding = TensorModule() self.final_norm = TensorModule() self.lm_head = TensorModule() - self.layers = [] + self.layers = {} + self.num_layers = num_layers layer_id = 0 for weight_file in os.listdir(input_path): if weight_file.endswith(".safetensors"): - module = QuantizedDecoderLayer(layer_id, bits, group_size) + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size)) weights = load_file(os.path.join(input_path, weight_file)) # Map weights to modules @@ -115,10 +116,9 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in else: curr_layer_id = int(name.split(".")[2]) if curr_layer_id != layer_id: - # Add layer to list of modules - self.layers.append(module) + # Switch layer module used layer_id = curr_layer_id - module 
= QuantizedDecoderLayer(layer_id, bits, group_size) + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size)) # Map weights and biases of norm, attention, and feed-forward network # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj @@ -288,11 +288,7 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in module.mlp.up_proj.g_idx = tensor else: raise NotImplementedError(f"{name} in your quantized model is not recognized.") - - if not module.is_empty(): - # Append final layer to list of layers - self.layers.append(module) - + # Set LM head weights + biases if not already set if self.lm_head.weight is None: # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias) @@ -301,6 +297,7 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in self.lm_head.bias = self.embedding.bias # Sort list of layers by layer id + self.layers = list(self.layers.values()) self.layers.sort(key=lambda m: m.layer_id) # Set properties of each layer based on quantization type @@ -487,7 +484,7 @@ def pack_ort_format(self, module, intweight): Pack `scales`, `qzeros`, and `qweight` to ORT format """ if module.bits != 4: - raise NotImplementedError(f"{modue.bits}-bit quantization in ORT is not currently supported by this tool.") + raise NotImplementedError(f"{module.bits}-bit quantization in ORT is not currently supported by this tool.") intzeros_pt = module.qzeros.T if module.qzeros.dtype == module.scales.dtype else module.qzeros.T.byte() intweight_pt = intweight.byte() @@ -521,11 +518,13 @@ def pack_ort_format(self, module, intweight): class AWQModel(QuantizedModel): - def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size): - super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) # Unpack and repack all `QuantizedTensorModule` classes in model for i, layer in enumerate(self.layers): + if i >= self.num_layers: + break print(f"Unpacking and repacking layer {i}") # Unpack and repack all `QuantizedTensorModule` classes in attention @@ -586,14 +585,16 @@ def reverse_reorder_tensor(self, tensor, bits): class GPTQModel(QuantizedModel): - def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size): - super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) # Unpack and repack all `QuantizedTensorModule` classes in model for i, layer in enumerate(self.layers): + if i >= self.num_layers: + break print(f"Unpacking and repacking layer {i}") - # Unpack and repack all `QuantizedTensorModule` classes in attention + # Unpack and repack all `QuantizedTensorModule` classes in attention for name, q_tensors in layer.self_attn.__dict__.items(): if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None: self.handle_qzeros(q_tensors) @@ -642,16 +643,16 @@ def __init__(self, 
module): class QuantModel: @staticmethod - def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size): + def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers): """ Unpack quantized weights in PyTorch models, store them in a standard format, and repack them into ONNX Runtime's format. Also performs any pre-processing and post-processing when unpacking the quantized weights. """ if quant_type == "awq": - model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) elif quant_type == "gptq": - model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size) + model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers) else: raise NotImplementedError(f"The {quant_type} quantized model is not currently supported.") diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index eba5aff15..8e8cc13cb 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -173,6 +173,65 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { } #endif +TEST(CAPITests, GetOutputCAPI) { + std::vector input_ids_shape{2, 4}; + std::vector input_ids{0, 0, 0, 52, 0, 0, 195, 731}; + + auto input_sequence_length = input_ids_shape[1]; + auto batch_size = input_ids_shape[0]; + int max_length = 10; + + // To generate this file: + // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20 + // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2) + + auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); + + auto params = OgaGeneratorParams::Create(*model); + params->SetSearchOption("max_length", max_length); + params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); + + auto generator = OgaGenerator::Create(*model, *params); + + // check prompt + // full logits has shape [2, 4, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 4, 5] + std::vector expected_sampled_logits_prompt{0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f, + 0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f, + 0.30323744f, 0.0545997f, 0.03894716f, 0.11702324f, 0.0410665f, + -0.12675379f, -0.04443946f, 0.14492269f, 0.03021223f, -0.03212897f, + 0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f, + 0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f, + -0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f, + 0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f}; + + generator->ComputeLogits(); + auto prompt_logits_ptr = generator->GetOutput("logits"); + auto prompt_logits = static_cast(prompt_logits_ptr->Data()); + int num_prompt_outputs_to_check = 40; + int sample_size = 200; + float tolerance = 0.001f; + // Verify outputs match expected outputs + for (int i = 0; i < num_prompt_outputs_to_check; i++) { + EXPECT_NEAR(expected_sampled_logits_prompt[i], prompt_logits[i*sample_size], tolerance); + } + + generator->GenerateNextToken(); + // check for the 1st token generation + // full logits has shape [2, 1, 1000]. 
Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 1, 5] + std::vector expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f, + 0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f}; + + generator->ComputeLogits(); + auto token_gen_logits_ptr = generator->GetOutput("logits"); + auto token_gen_logits = static_cast(token_gen_logits_ptr->Data()); + int num_token_gen_outputs_to_check = 10; + + for (int i = 0; i < num_token_gen_outputs_to_check; i++) { + EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance); + } + generator->GenerateNextToken(); +} + #if TEST_PHI2 struct Phi2Test { diff --git a/test/python/_test_utils.py b/test/python/_test_utils.py index a314454ba..808f8930e 100644 --- a/test/python/_test_utils.py +++ b/test/python/_test_utils.py @@ -52,32 +52,85 @@ def run_subprocess( return completed_process -def download_models(download_path, device): - # python -m onnxruntime_genai.models.builder -m -p int4 -e cpu -o --extra_options num_hidden_layers=1 - model_names = { - "cpu": { - "phi-2": "microsoft/phi-2", - }, - "cuda": { - "phi-2": "microsoft/phi-2", - }, +def get_model_paths(): + hf_paths = { + "phi-2": "microsoft/phi-2", + # "phi-3-mini": "microsoft/Phi-3-mini-128k-instruct", } - for model_name, model_identifier in model_names[device].items(): - model_path = os.path.join(download_path, device, model_name) - if not os.path.exists(model_path): - command = [ - sys.executable, - "-m", - "onnxruntime_genai.models.builder", - "-m", - model_identifier, - "-p", - "int4", - "-e", - device, - "-o", - model_path, - "--extra_options", - "num_hidden_layers=1", - ] - run_subprocess(command).check_returncode() + + ci_data_path = os.path.join("/", "data", "ortgenai_pytorch_models") + if not os.path.exists(ci_data_path): + return {}, hf_paths + + # Note: If a model has over 4B parameters, please add a quantized version + # to `ci_paths` instead of `hf_paths` to reduce file size and testing time. 
+ ci_paths = { + "llama-2": os.path.join(ci_data_path, "Llama-2-7B-Chat-GPTQ"), + "llama-3": os.path.join(ci_data_path, "Meta-Llama-3-8B-AWQ"), + "mistral-v0.2": os.path.join(ci_data_path, "Mistral-7B-Instruct-v0.2-GPTQ"), + # "phi-2": os.path.join(ci_data_path, "phi2"), + # "gemma-2b": os.path.join(ci_data_path, "gemma-1.1-2b-it"), + "gemma-7b": os.path.join(ci_data_path, "gemma-7b-it-awq"), + # "phi-3-mini": os.path.join(ci_data_path, "phi3-mini-128k-instruct"), + } + + return ci_paths, hf_paths + + +def download_model(model_name, input_path, output_path, precision, device, one_layer=True): + command = [ + sys.executable, + "-m", + "onnxruntime_genai.models.builder", + ] + + if model_name is not None: + # If model_name is provided: + # python -m onnxruntime_genai.models.builder -m -o -p -e + command += ["-m", model_name] + elif input_path != "": + # If input_path is provided: + # python -m onnxruntime_genai.models.builder -i -o -p -e + command += ["-i", input_path] + else: + raise Exception("Either `model_name` or `input_path` can be provided for PyTorch models, not both.") + + command += [ + "-o", + output_path, + "-p", + precision, + "-e", + device, + ] + + extra_options = ["--extra_options"] + if device == "cpu" and precision == "int4": + extra_options += ["int4_accuracy_level=4"] + if one_layer: + extra_options += ["num_hidden_layers=1"] + if len(extra_options) > 1: + command += extra_options + + run_subprocess(command).check_returncode() + + +def download_models(download_path, precision, device): + ci_paths, hf_paths = get_model_paths() + output_paths = [] + + # python -m onnxruntime_genai.models.builder -i -o -p -e + for model_name, input_path in ci_paths.items(): + output_path = os.path.join(download_path, model_name, precision, device) + if not os.path.exists(output_path): + download_model(None, input_path, output_path, precision, device) + output_paths.append(output_path) + + # python -m onnxruntime_genai.models.builder -m -o -p -e + for model_name, hf_name in hf_paths.items(): + output_path = os.path.join(download_path, model_name, precision, device) + if not os.path.exists(output_path): + download_model(hf_name, "", output_path, precision, device) + output_paths.append(output_path) + + return output_paths diff --git a/test/python/conftest.py b/test/python/conftest.py index 08498d184..d3a08df69 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -18,41 +18,44 @@ def pytest_addoption(parser): ) -def get_path_for_model_and_device(data_path, model_name, device): - return os.path.join(data_path, device, model_name) +def get_path_for_model(data_path, model_name, precision, device): + return os.path.join(data_path, model_name, precision, device) @pytest.fixture def phi2_for(request): return functools.partial( - get_path_for_model_and_device, + get_path_for_model, request.config.getoption("--test_models"), "phi-2", + "int4", ) @pytest.fixture def gemma_for(request): return functools.partial( - get_path_for_model_and_device, + get_path_for_model, request.config.getoption("--test_models"), "gemma", + "int4", ) @pytest.fixture def llama_for(request): return functools.partial( - get_path_for_model_and_device, + get_path_for_model, request.config.getoption("--test_models"), "llama", + "int4", ) @pytest.fixture def path_for_model(request): return functools.partial( - get_path_for_model_and_device, request.config.getoption("--test_models") + get_path_for_model, request.config.getoption("--test_models") ) diff --git a/test/python/test_onnxruntime_genai.py 
b/test/python/test_onnxruntime_genai.py index 41d615e51..212de1cfd 100644 --- a/test/python/test_onnxruntime_genai.py +++ b/test/python/test_onnxruntime_genai.py @@ -1,13 +1,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - import argparse +import json import logging import os import pathlib import sys import sysconfig -from typing import Union +from typing import Union, List import onnxruntime_genai as og from _test_utils import download_models, run_subprocess @@ -34,17 +34,22 @@ def run_onnxruntime_genai_api_tests( "--test_models", test_models, ] - run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_onnxruntime_genai_e2e_tests( cwd: Union[str, bytes, os.PathLike], log: logging.Logger, + output_paths: List[Union[str, bytes, os.PathLike]], ): log.debug("Running: ONNX Runtime GenAI E2E Tests") - command = [sys.executable, "test_onnxruntime_genai_e2e.py"] + command = [ + sys.executable, + "test_onnxruntime_genai_e2e.py", + "--models", + json.dumps(output_paths), + ] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -74,23 +79,19 @@ def main(): log.info("Running onnxruntime-genai tests pipeline") - if not args.e2e: - if not ( - sysconfig.get_platform().endswith("arm64") or sys.version_info.minor < 8 - ): - download_models(os.path.abspath(args.test_models), "cpu") - if og.is_cuda_available(): - download_models( - os.path.abspath(args.test_models), - "cuda", - ) - - run_onnxruntime_genai_api_tests( - os.path.abspath(args.cwd), log, os.path.abspath(args.test_models) - ) - - else: - run_onnxruntime_genai_e2e_tests(os.path.abspath(args.cwd), log) + # Get INT4 ONNX models + output_paths = [] + if not ( + sysconfig.get_platform().endswith("arm64") or sys.version_info.minor < 8 + ): + output_paths += download_models(os.path.abspath(args.test_models), "int4", "cpu") + if og.is_cuda_available(): + output_paths += download_models(os.path.abspath(args.test_models), "int4", "cuda") + + # Run ONNX Runtime GenAI tests + run_onnxruntime_genai_api_tests(os.path.abspath(args.cwd), log, os.path.abspath(args.test_models)) + if args.e2e: + run_onnxruntime_genai_e2e_tests(os.path.abspath(args.cwd), log, output_paths) return 0 diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py index eaac1e087..9939242d2 100644 --- a/test/python/test_onnxruntime_genai_e2e.py +++ b/test/python/test_onnxruntime_genai_e2e.py @@ -1,37 +1,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
+from __future__ import annotations +import argparse +import json import os -import sys -import tempfile +import logging import onnxruntime_genai as og -from _test_utils import run_subprocess - -def download_model( - download_path: str | bytes | os.PathLike, device: str, model_identifier: str, precision: str -): - # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cpu -o download_path - # Or with cuda graph enabled: - # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cuda --extra_options enable_cuda_graph=1 -o download_path - command = [ - sys.executable, - "-m", - "onnxruntime_genai.models.builder", - "-m", - model_identifier, - "-p", - precision, - "-e", - device, - "-o", - download_path, - ] - if device == "cuda": - command.append("--extra_options") - command.append("enable_cuda_graph=1") - run_subprocess(command).check_returncode() +logging.basicConfig( + format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG +) +log = logging.getLogger("onnxruntime-genai-tests") def run_model(model_path: str | bytes | os.PathLike): @@ -47,7 +28,7 @@ def run_model(model_path: str | bytes | os.PathLike): sequences = tokenizer.encode_batch(prompts) params = og.GeneratorParams(model) params.set_search_options(max_length=200) - params.try_graph_capture_with_max_batch_size(16) + params.try_graph_capture_with_max_batch_size(4) params.input_ids = sequences output_sequences = model.generate(params) @@ -55,10 +36,28 @@ def run_model(model_path: str | bytes | os.PathLike): assert output +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--models", + type=str, + required=True, + help="List of model paths to run. Pass as `json.dumps(model_paths)` to this argument.", + ) + + args = parser.parse_args() + args.models = json.loads(args.models) + return args + + if __name__ == "__main__": - for model_name in ["microsoft/phi-2"]: - for precision in ["int4", "fp32"]: - with tempfile.TemporaryDirectory() as temp_dir: - device = "cuda" if og.is_cuda_available() else "cpu" - download_model(temp_dir, device, model_name, precision) - run_model(temp_dir) + args = get_args() + for model_path in args.models: + try: + log.info(f"Running {model_path}") + run_model(model_path) + except Exception as e: + log.error(e) + log.error(f"Failed to run {model_path}", exc_info=True) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index d09eaee89..5ea725e17 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -74,7 +74,7 @@ def generate_license(line_list): line_list.append('LICENSE') def generate_readme(line_list): - line_list.append('README.md') + line_list.append('PACKAGE.md') def generate_project_url(line_list, project_url): line_list.append("" + project_url + "") @@ -112,7 +112,7 @@ def generate_files(lines, args): lines.append('') lines.append(f'') - lines.append(f'') + lines.append(f'') lines.append(f'') def add_native_artifact_if_exists(xml_lines, runtime, artifact):
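Note: the new doc comment for OgaGenerator_GetOutput states that the returned tensor owns the copied CPU buffer, so callers using the raw C API must manage result and tensor lifetimes themselves. A rough sketch of that pattern follows; OgaResultGetError, OgaDestroyResult, OgaTensorGetData, and OgaDestroyTensor are assumed to be the pre-existing C-API helpers (they are not touched by this patch), so treat the snippet as illustrative rather than authoritative.

    #include <cstdio>
    #include "ort_genai_c.h"

    // Assumed existing helpers: OgaResultGetError, OgaDestroyResult,
    // OgaTensorGetData, OgaDestroyTensor.
    int ReadLogits(OgaGenerator* generator) {
      OgaTensor* logits = nullptr;
      OgaResult* result = OgaGenerator_GetOutput(generator, "logits", &logits);
      if (result) {
        std::fprintf(stderr, "GetOutput failed: %s\n", OgaResultGetError(result));
        OgaDestroyResult(result);
        return -1;
      }

      void* data = nullptr;
      result = OgaTensorGetData(logits, &data);  // CPU buffer owned by `logits`
      if (result) {
        OgaDestroyResult(result);
        OgaDestroyTensor(logits);
        return -1;
      }

      const float* values = static_cast<const float*>(data);
      (void)values;  // ... read values while `logits` is alive ...

      OgaDestroyTensor(logits);  // releases the copied buffer, per the new doc comment in ort_genai_c.h
      return 0;
    }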