diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index 948291612..520a0715e 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -56,7 +56,7 @@ struct DecoderOnlyPipelineState : State { RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; - OrtValue* GetOutput(const char* name); + OrtValue* GetOutput(const char* name) override; private: void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 2c80aad2f..85694e0e3 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -113,16 +113,16 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { // Update input_ids with next tokens, converting from 32-bit to 64-bit if (type_ == Ort::TypeToTensorType) { switch (model_.device_type_) { -#if USE_CUDA case DeviceType::CUDA: { +#if USE_CUDA auto* data = value_->GetTensorMutableData(); auto next_tokens = next_tokens_unk.GetGPU(); cuda::LaunchInt32ToInt64(next_tokens.data(), data, static_cast(next_tokens.size()), model_.cuda_stream_); - } break; #endif + } break; -#if USE_DML case DeviceType::DML: { +#if USE_DML ComPtr source_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value_int32_->GetTensorMutableRawData(), &source_resource)); @@ -144,8 +144,8 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { model_.GetDmlDevice(), model_.GetOrtDmlApi(), input_ids_cast_command_list_state_); - } break; #endif + } break; case DeviceType::CPU: { auto* data = value_->GetTensorMutableData(); auto next_tokens = next_tokens_unk.GetCPU(); diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 5d371a74c..5eb2496de 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -76,8 +76,8 @@ RoamingArray Logits::Get() { for (int beam_index = 0; beam_index < num_beams; beam_index++) { switch (model_.device_type_) { -#if USE_DML case DeviceType::DML: { +#if USE_DML ComPtr source_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, output_raw_->GetTensorMutableRawData(), &source_resource)); @@ -96,8 +96,8 @@ RoamingArray Logits::Get() { source_offset, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, size_in_bytes); - } break; #endif + } break; case DeviceType::CPU: case DeviceType::CUDA: { diff --git a/src/models/whisper.h b/src/models/whisper.h index 6560d3a2c..d0bf8dc3a 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -20,7 +20,7 @@ struct Whisper_Model : Model { struct Whisper_State : State { Whisper_State(const Whisper_Model& model, RoamingArray sequence_lengths, const GeneratorParams& params); RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; - OrtValue* GetOutput(const char* name); + OrtValue* GetOutput(const char* name) override; private: void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, int current_length, bool search_buffers);