Skip to content

Commit

Permalink
Fix several warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
skyline75489 committed Sep 25, 2024
1 parent 2348dc9 commit 7df02a2
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/models/decoder_only_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct DecoderOnlyPipelineState : State {
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens,
RoamingArray<int32_t> next_indices) override;

OrtValue* GetOutput(const char* name);
OrtValue* GetOutput(const char* name) override;

private:
void UpdateInputsOutputs(const RoamingArray<int32_t>& next_tokens, RoamingArray<int32_t> next_indices,
Expand Down
8 changes: 4 additions & 4 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,16 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
// Update input_ids with next tokens, converting from 32-bit to 64-bit
if (type_ == Ort::TypeToTensorType<int64_t>) {
switch (model_.device_type_) {
#if USE_CUDA
case DeviceType::CUDA: {
#if USE_CUDA
auto* data = value_->GetTensorMutableData<int64_t>();
auto next_tokens = next_tokens_unk.GetGPU();
cuda::LaunchInt32ToInt64(next_tokens.data(), data, static_cast<int>(next_tokens.size()), model_.cuda_stream_);
} break;
#endif
} break;

#if USE_DML
case DeviceType::DML: {
#if USE_DML
ComPtr<ID3D12Resource> source_resource;
Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value_int32_->GetTensorMutableRawData(), &source_resource));

Expand All @@ -144,8 +144,8 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
model_.GetDmlDevice(),
model_.GetOrtDmlApi(),
input_ids_cast_command_list_state_);
} break;
#endif
} break;
case DeviceType::CPU: {
auto* data = value_->GetTensorMutableData<int64_t>();
auto next_tokens = next_tokens_unk.GetCPU();
Expand Down
4 changes: 2 additions & 2 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ RoamingArray<float> Logits::Get() {

for (int beam_index = 0; beam_index < num_beams; beam_index++) {
switch (model_.device_type_) {
#if USE_DML
case DeviceType::DML: {
#if USE_DML
ComPtr<ID3D12Resource> source_resource;
Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, output_raw_->GetTensorMutableRawData(), &source_resource));

Expand All @@ -96,8 +96,8 @@ RoamingArray<float> Logits::Get() {
source_offset,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
size_in_bytes);
} break;
#endif
} break;

case DeviceType::CPU:
case DeviceType::CUDA: {
Expand Down
2 changes: 1 addition & 1 deletion src/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct Whisper_Model : Model {
struct Whisper_State : State {
Whisper_State(const Whisper_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;
OrtValue* GetOutput(const char* name);
OrtValue* GetOutput(const char* name) override;

private:
void UpdateInputsOutputs(const RoamingArray<int32_t>& next_tokens, RoamingArray<int32_t> next_indices, int current_length, bool search_buffers);
Expand Down

0 comments on commit 7df02a2

Please sign in to comment.