From c4c03efecf40ab7416cba97f78ab052eadbae34f Mon Sep 17 00:00:00 2001
From: Bowen Bao
Date: Fri, 2 Aug 2024 16:53:51 +0000
Subject: [PATCH 01/13] Introduce speculative decoding initial version

Results are validated with model-generate.py by using an int4 quantized
model as the original model's assistant. The output sequence is identical,
and an increased tokens-per-second rate is observed.

NOTE: Only MHA decoder-only models, batch size 1, CPU, and greedy top-1
selection are supported in this initial version. GQA needs
https://github.com/microsoft/onnxruntime/pull/21523 to support seqlen > 1
in the token phase.

* Updated builder.py to produce an MHA graph that supports seqlen > 1 in
  the token phase.
* Introduced speculative decoding, currently through a separate Generator
  class. This can potentially be merged with the existing Generator at
  either the API or the implementation level.
* Extended various components with the functionality needed to support
  speculative search. Previously, most methods were hardcoded to assume
  seqlen == 1 in the token phase.
---
 examples/python/model-generate.py |  24 +++-
 examples/python/model-qa.py       |  35 ++++--
 src/generators.cpp                | 199 ++++++++++++++++++++++++++++++
 src/generators.h                  |  42 ++++++-
 src/logging.cpp                   |   2 +
 src/logging.h                     |   1 +
 src/models/decoder_only.cpp       |  38 ++++++
 src/models/decoder_only.h         |  10 +-
 src/models/input_ids.cpp          |  32 +++++
 src/models/input_ids.h            |   6 +
 src/models/kv_cache.cpp           |  48 +++++++
 src/models/kv_cache.h             |   8 ++
 src/models/logits.cpp             |  43 +++++++
 src/models/logits.h               |   7 ++
 src/models/model.h                |   3 +
 src/models/position_inputs.cpp    |  55 +++++++++
 src/models/position_inputs.h      |  10 ++
 src/python/py/models/builder.py   |  94 +++++++-----
 src/python/python.cpp             |  40 ++++++
 src/search.cpp                    |  94 ++++++++++++++
 src/search.h                      |  31 ++++-
 src/sequences.cpp                 |   5 +
 src/sequences.h                   |   3 +
 23 files changed, 780 insertions(+), 50 deletions(-)

diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py
index 0a97f25b4..d78b8ac6f 100644
--- a/examples/python/model-generate.py
+++ b/examples/python/model-generate.py
@@ -2,9 +2,15 @@
 import argparse
 import time
 
+
 def main(args):
     if args.verbose: print("Loading model...")
     model = og.Model(f'{args.model}')
+    assistant_model = (
+        og.Model(f"{args.assistant_model}")
+        if hasattr(args, "assistant_model")
+        else None
+    )
     if args.verbose: print("Model loaded")
     tokenizer = og.Tokenizer(model)
     if args.verbose: print("Tokenizer created")
@@ -15,13 +21,13 @@ def main(args):
         prompts = ["I like walking my cute dog",
                    "What is the best restaurant in town?",
                    "Hello, how are you today?"]
-
+
     if args.chat_template:
         if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
             print("Error, chat template must have exactly one pair of curly braces, e.g. 
'<|user|>\n{input} <|end|>\n<|assistant|>'") exit(1) prompts[:] = [f'{args.chat_template.format(input=text)}' for text in prompts] - + input_tokens = tokenizer.encode_batch(prompts) if args.verbose: print(f'Prompt(s) encoded: {prompts}') @@ -42,7 +48,10 @@ def main(args): if args.verbose: print("Generating tokens ...\n") start_time = time.time() - output_tokens = model.generate(params) + if assistant_model is None: + output_tokens = model.generate(params) + else: + output_tokens = model.generate_with_assist(assistant_model, params) run_time = time.time() - start_time for i in range(len(prompts)): @@ -56,9 +65,16 @@ def main(args): print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}") print() + if __name__ == "__main__": parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end token generation loop example for gen-ai") parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)') + parser.add_argument( + "-a", + "--assistant_model", + type=str, + help="Assistant onnx model folder path (must contain config.json and model.onnx)", + ) parser.add_argument('-pr', '--prompts', nargs='*', required=False, help='Input prompts to generate tokens from. Provide this parameter multiple times to batch multiple prompts') parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt') parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt') @@ -72,4 +88,4 @@ def main(args): parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.') args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index 4532f307a..19ebaae07 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -15,10 +15,16 @@ def main(args): if args.verbose: print("Tokenizer created") if args.verbose: print() + assistant_model = None + if hasattr(args, "assistant_model"): + assistant_model = og.Model(args.assistant_model) + if args.verbose: + print("Assistant model loaded") + search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args} if args.verbose: print(search_options) - + if args.chat_template: if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1: print("Error, chat template must have exactly one pair of curly braces, e.g. 
'<|user|>\n{input} <|end|>\n<|assistant|>'") @@ -43,13 +49,16 @@ def main(args): params = og.GeneratorParams(model) params.set_search_options(**search_options) params.input_ids = input_tokens - generator = og.Generator(model, params) + if assistant_model is not None: + generator = og.SpeculativeDecodingGenerator(model, assistant_model, params) + else: + generator = og.Generator(model, params) if args.verbose: print("Generator created") if args.verbose: print("Running generation loop ...") if args.timings: first = True - new_tokens = [] + generated_tokens = [] print() print("Output: ", end='', flush=True) @@ -63,9 +72,11 @@ def main(args): first_token_timestamp = time.time() first = False - new_token = generator.get_next_tokens()[0] - print(tokenizer_stream.decode(new_token), end='', flush=True) - if args.timings: new_tokens.append(new_token) + new_tokens = generator.get_next_tokens() + for new_token in new_tokens: + print(tokenizer_stream.decode(new_token), end="", flush=True) + if args.timings: + generated_tokens.extend(new_tokens) except KeyboardInterrupt: print(" --control+c pressed, aborting generation--") print() @@ -77,12 +88,20 @@ def main(args): if args.timings: prompt_time = first_token_timestamp - started_timestamp run_time = time.time() - first_token_timestamp - print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps") + print( + f"Prompt length: {len(input_tokens)}, New tokens: {len(generated_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(generated_tokens)/run_time:.2f} tps" + ) if __name__ == "__main__": parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai") parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)') + parser.add_argument( + "-a", + "--assistant_model", + type=str, + help="Assistant onnx model folder path (must contain config.json and model.onnx)", + ) parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt') parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt') parser.add_argument('-ds', '--do_random_sampling', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false') @@ -94,4 +113,4 @@ def main(args): parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false') parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. 
User input will be injected into {input}') args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/src/generators.cpp b/src/generators.cpp index fe005111a..84ce0b5eb 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -4,6 +4,7 @@ #include "generators.h" #include "sequences.h" #include "models/model.h" +#include "models/decoder_only.h" #include "search.h" #if USE_CUDA #include "search_cuda.h" @@ -93,6 +94,14 @@ std::unique_ptr CreateGenerator(const Model& model, const GeneratorPa return std::make_unique(model, params); } +std::unique_ptr CreateAssistantGenerator(const Model& model, const GeneratorParams& params) { + return std::make_unique(model, params); +} + +std::unique_ptr CreateSpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params) { + return std::make_unique(model, assistant_model, params); +} + std::unique_ptr CreateSearch(const GeneratorParams& params) { #if USE_CUDA if (params.device_type == DeviceType::CUDA) { @@ -108,6 +117,16 @@ std::unique_ptr CreateSearch(const GeneratorParams& params) { return std::make_unique(params); } +std::unique_ptr CreateSpeculativeSearch(const GeneratorParams& params) { +#if USE_CUDA + throw std::runtime_error("Speculative decoding is not supported on CUDA"); +#endif + if (params.search.num_beams > 1) { + throw std::runtime_error("Speculative decoding is not supported with beam search"); + } + return std::make_unique(params); +} + Generator::Generator(const Model& model, const GeneratorParams& params) : model_{model.shared_from_this()} { if (params.search.max_length == 0) throw std::runtime_error("search max_length is 0"); @@ -196,6 +215,171 @@ RoamingArray Generator::GetSequence(size_t index) const { return search_->GetSequence(index); } +AssistantGenerator::AssistantGenerator(const Model& model, const GeneratorParams& params) + : Generator(model, params) { + if (params.search.num_beams != 1) + throw std::runtime_error("AssistantGenerator only supports num_beams=1, got " + std::to_string(params.search.num_beams)); + if (params.batch_size != 1) + throw std::runtime_error("AssistantGenerator only supports batch_size=1, got " + std::to_string(params.batch_size)); + if (params.vocab_size < 1) + throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size)); + if (params.sequence_length >= params.search.max_length) + throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")"); + + state_ = std::make_unique( + *std::dynamic_pointer_cast(model_), search_->GetSequenceLengths(), params); +} + +void AssistantGenerator::ComputeLogits() { + if (computed_logits_) + throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); + + auto sequence_length = search_->GetSequenceLength(); + auto next_token_length = first_run_in_assist_ ? 
2 : 1; + auto past_length = sequence_length - next_token_length; + auto logits = state_->Run(search_->GetSequence(0), next_token_length, past_length, 1); + if (g_log.enabled && g_log.speculative_decoding) { + auto& stream = Log("speculative_decoding"); + DumpSpan(stream, logits.GetCPU()); + stream << std::endl; + } + search_->SetLogits(logits); + computed_logits_ = true; + + auto& search = search_->params_->search; + search_->ApplyMinLength(search.min_length); + search_->ApplyRepetitionPenalty(search.repetition_penalty); + first_run_in_assist_ = false; +} + +void AssistantGenerator::GenerateNextToken() { + Generator::GenerateNextToken(); + candidate_length_++; +} + +void AssistantGenerator::AcceptCandidateTokens(RoamingArray next_tokens) { + search_->DropLastTokens(candidate_length_); + search_->SetNextTokens(next_tokens); + candidate_length_ = 0; + if (g_log.enabled && g_log.speculative_decoding) { + auto& stream = Log("speculative_decoding"); + stream << SGR::Fg_Green << "assistant sequence: " << SGR::Reset << std::endl; + DumpSpan(stream, search_->GetSequence(0).GetCPU()); + stream << std::endl + << "length: " << search_->GetSequenceLength() << std::endl; + } + first_run_in_assist_ = true; +} + +SpeculativeDecodingGenerator::SpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params) + : assistant_generator_{CreateAssistantGenerator(assistant_model, params)}, + model_{model.shared_from_this()} { + if (params.search.max_length == 0) + throw std::runtime_error("search max_length is 0"); + if (params.search.max_length > model.config_->model.context_length) + throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")"); + if (params.batch_size != 1) + throw std::runtime_error("batch_size must be 1, is " + std::to_string(params.batch_size)); + if (params.vocab_size < 1) + throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size)); + if (params.sequence_length >= params.search.max_length) + throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")"); + if (params.input_ids.empty() || params.input_ids.data() == nullptr) + throw std::runtime_error("input_ids not set in GeneratorParams"); + + if (model.config_->model.type != "llama" && + model.config_->model.type != "gemma" && + model.config_->model.type != "gemma2" && + model.config_->model.type != "mistral" && + model.config_->model.type != "phi" && + model.config_->model.type != "phi3" && + model.config_->model.type != "phi3small" && + model.config_->model.type != "qwen2") + throw std::runtime_error("Speculative decoding is not supported for this model type " + model.config_->model.type); + + search_ = CreateSpeculativeSearch(params); + state_ = std::make_unique( + *std::dynamic_pointer_cast(model_), search_->GetSequenceLengths(), params); +} + +void SpeculativeDecodingGenerator::ComputeLogits() { + if (computed_logits_) + throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); + + candidate_length_ = 0; + while (!assistant_generator_->IsDone() && candidate_length_ < max_candidate_length_) { + assistant_generator_->ComputeLogits(); + assistant_generator_->GenerateNextToken(); + candidate_length_++; + } + + auto candidate_sequence = 
assistant_generator_->search_->GetSequence(0); + if (g_log.enabled && g_log.speculative_decoding) { + auto& stream = Log("speculative_decoding"); + stream << SGR::Fg_Green << "candidates from assistant model: " << SGR::Reset << std::endl; + stream << SGR::Fg_Green << "candidate count: " << SGR::Reset << candidate_length_ << std::endl; + DumpSpan(stream, candidate_sequence.GetCPU()); + } + + auto logits = state_->Run(candidate_sequence, candidate_length_ + 1, search_->GetSequenceLength() - 1, candidate_length_ + 1); + if (g_log.enabled && g_log.speculative_decoding) { + auto& stream = Log("speculative_decoding"); + stream << SGR::Fg_Green << "produced logits from main model: " << SGR::Reset << std::endl; + } + + search_->SetLogits(logits); + computed_logits_ = true; +} + +void SpeculativeDecodingGenerator::GenerateNextToken() { + if (!computed_logits_) + throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); + computed_logits_ = false; + auto& search = search_->params_->search; + + if (g_log.enabled && g_log.generate_next_token) { + auto& stream = Log("generate_next_token"); + stream << SGR::Fg_Green << "do_sample: " << SGR::Reset << search.do_sample << ' ' + << SGR::Fg_Green << "top_k: " << SGR::Reset << search.top_k << ' ' + << SGR::Fg_Green << "top_p: " << SGR::Reset << search.top_p << ' ' + << SGR::Fg_Green << "temperature: " << SGR::Reset << search.temperature << ' ' + << SGR::Fg_Cyan << "sequence length: " << SGR::Reset << search_->GetSequenceLength() + << std::endl; + } + + if (search.do_sample) + throw std::runtime_error("Not implemented"); + if (search.top_k != 1) + throw std::runtime_error("Not implemented"); + if (search.top_p != 1.0f) + throw std::runtime_error("Not implemented"); + if (search.temperature != 1.0f) + throw std::runtime_error("Not implemented"); + + auto candidate_sequence = assistant_generator_->search_->GetSequence(0); + + // Compare with logits one by one to determine the accepted tokens. + // total new token count is accepted token count + 1. + auto next_tokens = search_->CheckCandidates(candidate_sequence, candidate_length_); + // Update sequence to drop tokens of size candidate_length_, + // and append next tokens. + assistant_generator_->AcceptCandidateTokens(next_tokens); + if (g_log.enabled && g_log.speculative_decoding) { + auto& stream = Log("speculative_decoding"); + stream << SGR::Fg_Green << "candidate count: " << SGR::Reset << candidate_length_ << std::endl; + stream << SGR::Fg_Green << "next tokens: " << SGR::Reset; + DumpSpan(stream, next_tokens.GetCPU()); + stream << std::endl; + } +} + +bool SpeculativeDecodingGenerator::IsDone() const { + if (computed_logits_) + throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); + + return search_->IsDone(); +} + TokenSequences Generate(const Model& model, const GeneratorParams& params) { auto generator = CreateGenerator(model, params); @@ -215,4 +399,19 @@ TokenSequences Generate(const Model& model, const GeneratorParams& params) { return result; } +TokenSequences Generate(const Model& model, const Model& assistant_model, const GeneratorParams& params) { + auto generator = CreateSpeculativeDecodingGenerator(model, assistant_model, params); + + while (!generator->IsDone()) { + generator->ComputeLogits(); + generator->GenerateNextToken(); + } + + // Supports only single batch size, single sequence. 
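+  // At this point the search sequence already holds the prompt plus every
+  // accepted candidate token and bonus token, so it can be copied out directly.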
+  TokenSequences result = {{}};
+  auto sequence_cpu = generator->search_->GetSequence(0).GetCPU();
+  result[0].assign(sequence_cpu.begin(), sequence_cpu.end());
+  return result;
+}
+
 }  // namespace Generators
diff --git a/src/generators.h b/src/generators.h
index 7a8f08951..ffe10744a 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -129,8 +129,8 @@ struct Generator {
   Generator(const Model& model, const GeneratorParams& params);
 
   bool IsDone() const;
-  void ComputeLogits();
-  void GenerateNextToken();
+  virtual void ComputeLogits();
+  virtual void GenerateNextToken();
 
   RoamingArray<int32_t> GetSequence(size_t index) const;
 
@@ -140,6 +140,42 @@
   bool computed_logits_{};  // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
 };
 
+struct AssistantGenerator : Generator {
+  AssistantGenerator(const Model& model, const GeneratorParams& params);
+
+  void ComputeLogits() override;
+  void GenerateNextToken() override;
+
+  void AcceptCandidateTokens(RoamingArray<int32_t> next_tokens);
+  RoamingArray<int32_t> GetCandidateTokens() const;
+
+  int candidate_length_{};       // Set to the number of generated candidates in ComputeLogits() and the number of selected candidates after GenerateNextToken().
+  int max_candidate_length_{5};  // TODO: Move to param config.
+
+ protected:
+  void ComputeLogits(RoamingArray<int32_t> next_tokens);
+
+ private:
+  bool first_run_in_assist_{true};  // Set to false in ComputeLogits() and true after AcceptCandidateTokens().
+};
+
+// TODO: Inherit from Generator?
+struct SpeculativeDecodingGenerator {
+  SpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params);
+
+  bool IsDone() const;
+  void ComputeLogits();
+  void GenerateNextToken();
+
+  std::unique_ptr<AssistantGenerator> assistant_generator_;
+  std::shared_ptr<const Model> model_;
+  std::unique_ptr<State> state_;
+  std::unique_ptr<Search> search_;
+  bool computed_logits_{};       // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
+  int candidate_length_{};       // Set to the number of generated candidates in ComputeLogits() and the number of selected candidates after GenerateNextToken().
+  int max_candidate_length_{5};  // TODO: Move to param config.
+};
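For reference, a minimal end-to-end sketch of how this API is driven from
Python via the bindings added later in this patch. The model paths and prompt
are placeholders, tokenizer.create_stream() is assumed from the existing
examples, and only batch size 1 with greedy search is supported at this stage:

    import onnxruntime_genai as og

    model = og.Model("path/to/model")          # main model (placeholder path)
    assistant = og.Model("path/to/assistant")  # e.g. an int4 quantized copy of the same model
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()

    params = og.GeneratorParams(model)
    params.set_search_options(max_length=128)
    params.input_ids = tokenizer.encode("I like walking my cute dog")

    generator = og.SpeculativeDecodingGenerator(model, assistant, params)
    while not generator.is_done():
        generator.compute_logits()       # assistant drafts candidates, main model verifies them
        generator.generate_next_token()  # longest matching prefix plus one bonus token is accepted
        for token in generator.get_next_tokens():  # may yield several tokens per step
            print(tokenizer_stream.decode(token), end="", flush=True)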
+
 struct OrtGlobals {
   OrtGlobals();
 
@@ -161,7 +197,9 @@
 std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path);
 std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
 std::shared_ptr<GeneratorParams> CreateGeneratorParams();  // For benchmarking purposes only
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
+std::unique_ptr<SpeculativeDecodingGenerator> CreateSpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params);
 std::vector<std::vector<int32_t>> Generate(const Model& model, const GeneratorParams& params);  // Uses CreateGenerator and a simple loop to return the entire sequence
+std::vector<std::vector<int32_t>> Generate(const Model& model, const Model& assistant_model, const GeneratorParams& params);
 
 float Float16ToFloat32(uint16_t v);  // v is an IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
 void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);
diff --git a/src/logging.cpp b/src/logging.cpp
index e92afee82..2107e7bec 100644
--- a/src/logging.cpp
+++ b/src/logging.cpp
@@ -36,6 +36,8 @@ void SetLogBool(std::string_view name, bool value) {
     g_log.model_output_values = value;
   else if (name == "model_logits")
     g_log.model_logits = value;
+  else if (name == "speculative_decoding")
+    g_log.speculative_decoding = value;
   else
     throw JSON::unknown_value_error{};
 }
diff --git a/src/logging.h b/src/logging.h
index 99dc8c3d4..428be1b26 100644
--- a/src/logging.h
+++ b/src/logging.h
@@ -42,6 +42,7 @@ struct LogItems {
   bool model_output_shapes{};  // Before the model runs there are only the output shapes, no values in them. Useful for pre Session::Run debugging
   bool model_output_values{};  // After the model runs the output tensor values can be displayed
   bool model_logits{};         // Same as model_output_values but only for the logits
+  bool speculative_decoding{};  // Log speculative decoding steps.
 };
 
 extern LogItems g_log;
diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 206549be3..899593058 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -43,4 +43,42 @@ void DecoderOnly_State::UpdateInputsOutputs(const RoamingArray<int32_t>& next_to
   logits_.Update();
 }
 
+RoamingArray<float> SpeculativeDecodingDecoderOnly_State::Run(RoamingArray<int32_t> sequence, int next_token_length, int past_length, int return_last_logit_count) {
+  int batch_size = static_cast<int>(input_ids_.GetShape()[0]);
+  if (batch_size != 1)
+    throw std::runtime_error("Speculative decoding only supports batch size 1, got " + std::to_string(batch_size));
+
+  auto total_length = past_length + next_token_length;
+  auto total_logits = first_run_ ? total_length : next_token_length;
+  // NB(bowenbao): workaround gqa limitation on token phase.
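+  // The workaround below would request logits for every position
+  // (total_logits = total_length) whenever more than one token is fed during
+  // the token phase; it is left disabled pending GQA support for seqlen > 1
+  // (https://github.com/microsoft/onnxruntime/pull/21523).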
+ // if (next_token_length > 1) { + // total_logits = total_length; + // } + UpdateInputsOutputsFromSequence(sequence, next_token_length, past_length); + State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); + + return logits_.Get(total_logits - return_last_logit_count, return_last_logit_count); +} + +void SpeculativeDecodingDecoderOnly_State::UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length) { + auto total_length = past_length + next_token_length; + if (g_log.enabled && g_log.speculative_decoding) { + auto& stream = Log("speculative_decoding"); + stream << "UpdateInputsOutputsFromSequence: past_length=" << past_length << ", next_token_length=" << next_token_length << ", total_length=" << total_length << std::endl; + } + if (first_run_) { + // First run input ids includes prompt tokens. + input_ids_.Update(sequence, 0, total_length); + position_inputs_.Update(total_length, 0); + kv_cache_.UpdatePresent(total_length); + logits_.Update(total_length); + } else { + // Subsequent runs input ids only include candidate tokens. + input_ids_.Update(sequence, past_length, next_token_length); + position_inputs_.Update(total_length, past_length); + kv_cache_.UpdateAndResize(total_length, past_length); + logits_.Update(next_token_length); + } +} + } // namespace Generators diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index cc519de0b..b66ecf91b 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -21,7 +21,7 @@ struct DecoderOnly_State : State { RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); }; - private: + protected: void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, int current_length); const DecoderOnly_Model& model_; @@ -34,4 +34,12 @@ struct DecoderOnly_State : State { ExtraInputs extra_inputs_{model_, *this}; }; +struct SpeculativeDecodingDecoderOnly_State : DecoderOnly_State { + SpeculativeDecodingDecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths, const GeneratorParams& params) : DecoderOnly_State{model, sequence_lengths, params} {}; + RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) override; + + protected: + void UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length); +}; + } // namespace Generators diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 6d281d247..32d540937 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -126,4 +126,36 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { } } +void InputIDs::Update(RoamingArray next_tokens, size_t start, size_t token_count) { + switch (model_.device_type_) { + case DeviceType::CPU: { + break; + } + default: + throw std::runtime_error("Update with token count not supported for device type " + to_string(model_.device_type_)); + } + if (shape_[0] != 1) { + throw std::runtime_error("Update with token count only supported for batch size 1, got " + std::to_string(shape_[0])); + } + shape_[1] = token_count; + + if (!sb_input_ids_) { + value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + } else { + value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_); + } + state_.inputs_[input_index_] = value_.get(); + if (type_ == 
Ort::TypeToTensorType::type) { + auto* data = value_->GetTensorMutableData(); + auto next_tokens_cpu = next_tokens.GetCPU(); + assert(next_tokens_cpu.size() >= start + token_count); + for (int i = 0; i < token_count; i++) { + data[i] = next_tokens_cpu[start + i]; + } + } else { + auto* data = value_->GetTensorMutableData() + start; + memcpy(data, next_tokens.GetCPU().data(), shape_[0] * token_count * sizeof(int32_t)); + } +} + } // namespace Generators diff --git a/src/models/input_ids.h b/src/models/input_ids.h index 93874821a..a3c0cc32a 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -9,8 +9,14 @@ struct InputIDs { InputIDs(const InputIDs&) = delete; InputIDs& operator=(const InputIDs&) = delete; + // Register input_ids as ORT session input. + // Called only once during initialization of state. void Add(); + // Resize input_ids to [1], update value with next_tokens. + // next_tokens is assumed to have length 1. void Update(RoamingArray next_tokens); + // Resize input_ids to [token_count], update value with next_tokens[start:start + token_count]. + void Update(RoamingArray next_tokens, size_t start, size_t token_count); auto& GetShape() const { return shape_; } const char* name_; diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 3c2e0dbfa..cbac4bade 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -213,6 +213,54 @@ void KV_Cache::Update(std::span beam_indices, int current_length) } } +void KV_Cache::UpdatePresent(int current_length) { + // Used for speculative decoding main generator. + // This can be later refactored to merge with tensor allocation during initialization. + if (shape_[2] == current_length) + return; + shape_[2] = current_length; + // If we're sharing past & present buffers there is nothing to do here, so early exit + if (past_present_share_buffer_) + return; + for (int i = 0; i < layer_count_ * 2; i++) { + presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + state_.outputs_[output_index_ + i] = presents_[i].get(); + } +} + +void KV_Cache::UpdateAndResize(int current_length, int past_length) { + // If we're sharing past & present buffers there is nothing to do here, so early exit + if (past_present_share_buffer_) + return; + if (shape_[0] != 1) + throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports batch size 1, got " + std::to_string(shape_[0])); + if (model_.device_type_ != DeviceType::CPU) + throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports CPU"); + + auto element_type = presents_[0]->GetTensorTypeAndShapeInfo()->GetElementType(); + auto element_size = SizeOf(element_type); + auto new_shape = std::array({1, shape_[1], past_length, shape_[3]}); + if (shape_[2] != past_length) { + for (int i = 0; i < layer_count_ * 2; i++) { + auto new_present = OrtValue::CreateTensor(*model_.allocator_device_, new_shape, type_); + const auto* present_data = reinterpret_cast(presents_[i]->GetTensorRawData()); + auto* new_present_data = reinterpret_cast(new_present->GetTensorMutableRawData()); + + // Copy past_length kv-cache + for (int j = 0; j < shape_[1]; j++) { + memcpy( + new_present_data + j * past_length * shape_[3] * element_size, + present_data + j * shape_[2] * shape_[3] * element_size, + past_length * shape_[3] * element_size); + } + + presents_[i] = std::move(new_present); + } + } + + Update({}, current_length); +} + // Copy present state to past state reordered by the beam_indices template void 
KV_Cache::PickPastState(std::span beam_indices, int index) { diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index bfefba973..ae7b57547 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -32,8 +32,16 @@ struct KV_Cache { KV_Cache(const Model& model, State& state); void AddEncoder(); // If model has an initial encoder step, this is used + // Register input_ids as ORT session input. + // Called only once during initialization of state. void Add(); + // Move present to past. Prepare present output for next generation iteration. void Update(std::span beam_indices, int current_length); + // Used by speculative decoding + // Resize present to new sequence length. + void UpdatePresent(int current_length); + // Resize past to new sequence length, and drop past that is > past_length. + void UpdateAndResize(int current_length, int past_length); template void PickPastState(std::span beam_indices, int index); void PickPastState(std::span beam_indices, int index); diff --git a/src/models/logits.cpp b/src/models/logits.cpp index f1d8beae7..d189dd534 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -190,6 +190,37 @@ RoamingArray Logits::Get() { #pragma warning(pop) +RoamingArray Logits::Get(size_t start, size_t size) { + const size_t num_beams = state_.params_->search.num_beams; + if (num_beams != 1) + throw std::runtime_error("Get with start and size not supported for num_beams != 1, got " + std::to_string(num_beams)); + if (shape_[0] != 1) + throw std::runtime_error("Get with start and size not supported for batch size != 1, got " + std::to_string(shape_[0])); + + size_t element_count = shape_[1] * shape_[2]; + size_t element_size = type_ == Ort::TypeToTensorType::type ? 4 : 2; + size_t selected_element_count = size * shape_[2]; + + output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array({1, static_cast(size), shape_[2]}), type_); + OrtValue* logits_of_selected_tokens = output_last_tokens_.get(); + + auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size}; + auto logits_of_selected_tokens_raw = std::span{logits_of_selected_tokens->GetTensorMutableData(), selected_element_count * element_size}; + auto source = logits_raw.subspan(start * shape_[2] * element_size, selected_element_count * element_size); + copy(source, logits_of_selected_tokens_raw); + + if (type_ == Ort::TypeToTensorType::type) { + std::unique_ptr logits_of_selected_tokens_fp32; + ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_selected_tokens, logits_of_selected_tokens_fp32, model_.device_type_, model_.cuda_stream_); + output_last_tokens_ = std::move(logits_of_selected_tokens_fp32); + logits_of_selected_tokens = output_last_tokens_.get(); + } + + auto batched_logits_cpu = cpu_span{logits_of_selected_tokens->GetTensorMutableData(), selected_element_count}; + HandleEOSArray(batched_logits_cpu); + return batched_logits_cpu; +} + void Logits::Update() { if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == 1) { return; @@ -201,6 +232,18 @@ void Logits::Update() { state_.outputs_[output_index_] = output_raw_.get(); } +void Logits::Update(size_t token_count) { + if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == token_count) { + return; + } + + shape_[1] = token_count; + StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType::type ? sb_logits16_ : sb_logits32_; + output_raw_ = !sb_logits ? 
OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) + : sb_logits->CreateTensorOnStaticBuffer(shape_, type_); + state_.outputs_[output_index_] = output_raw_.get(); +} + void Logits::HandleEOSArray(cpu_span batched_logits) { if (model_.config_->model.eos_token_ids.empty()) return; diff --git a/src/models/logits.h b/src/models/logits.h index f55f4e464..94e57c355 100644 --- a/src/models/logits.h +++ b/src/models/logits.h @@ -9,10 +9,17 @@ namespace Generators { struct Logits { Logits(const Model& model, State& state); + // Register input_ids as ORT session input. void Add(); + // For first iteration, find last token of each beam and store it in output_last_tokens_. + // Also resizes logits to [bz, 1, vocab_size] for subsequent calls. RoamingArray Get(); + // Retrieves logits[:, start:start + size, :]. + RoamingArray Get(size_t start, size_t size); // batch_size x size x vocab_size void Update(); + // Resize logits to [bz, token_count, vocab_size]. + void Update(size_t token_count); private: void HandleEOSArray(cpu_span logits); diff --git a/src/models/model.h b/src/models/model.h index f52f56499..5a6e39947 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -33,6 +33,9 @@ struct State { OrtValue* GetOutput(const char* name); + // Used by speculative search + virtual RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { throw std::runtime_error("Not implemented"); }; + std::shared_ptr params_; std::vector input_names_, output_names_; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 032efee05..aaded5817 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -84,6 +84,15 @@ void PositionInputs::Update(int current_length) { } } +void PositionInputs::Update(int current_length, int past_length) { + if (has_posid_input_) { + UpdatePositionIDs(current_length, past_length); + } + if (has_mask_input_) { + UpdateAttentionMask(current_length, past_length); + } +} + void PositionInputs::AddAttentionMask() { mask_input_index_ = state_.inputs_.size(); @@ -193,6 +202,21 @@ void PositionInputs::UpdatePositionIDs(int current_length) { } } +void PositionInputs::UpdatePositionIDs(int current_length, int past_length) { + if (model_.device_type_ != DeviceType::CPU) + throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported on CPU."); + if (position_ids_shape_[0] != 1) + throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported for batch_size=1."); + assert(current_length > past_length); + position_ids_shape_[1] = current_length - past_length; + position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_); + if (type_ == Ort::TypeToTensorType::type) + UpdatePositionIDsImpl(current_length, past_length); + else + UpdatePositionIDsImpl(current_length, past_length); + state_.inputs_[posid_input_index_] = position_ids_.get(); +} + void PositionInputs::UpdateAttentionMask(int current_length) { // Update attention mask if (sb_attention_mask_) { @@ -321,6 +345,22 @@ void PositionInputs::UpdateAttentionMask(int current_length) { is_first_mask_update_ = false; } +void PositionInputs::UpdateAttentionMask(int current_length, int past_length) { + if (model_.device_type_ != DeviceType::CPU) + throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported on CPU."); + if (attention_mask_shape_[0] != 1) + throw std::runtime_error("PositionInputs::UpdateAttentionMask - 
past_length only supported for batch_size=1."); + attention_mask_shape_[1] = current_length; + attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); + if (type_ == Ort::TypeToTensorType::type) + UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), current_length, past_length); + else + UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), current_length, past_length); + attention_mask_ = std::move(attention_mask_next_); + state_.inputs_[mask_input_index_] = attention_mask_.get(); + is_first_mask_update_ = false; +} + template void PositionInputs::InitializeTensors(std::array shape, cpu_span sequence_lengths) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. @@ -360,6 +400,14 @@ void PositionInputs::UpdatePositionIDsImpl() { } }; +template +void PositionInputs::UpdatePositionIDsImpl(int current_length, int past_length) { + auto* data = position_ids_->GetTensorMutableData(); + for (int i = 0; i < current_length - past_length; i++) { + data[i] = i + past_length; + } +}; + template void PositionInputs::UpdateAttentionMaskImpl(T* data, const T* old_data, int current_length) { for (int i = 0; i < attention_mask_shape_[0]; i++) { @@ -370,4 +418,11 @@ void PositionInputs::UpdateAttentionMaskImpl(T* data, const T* old_data, int cur } }; +template +void PositionInputs::UpdateAttentionMaskImpl(T* data, int current_length, int past_length) { + for (int i = 0; i < current_length; i++) { + data[i] = 1; + } +}; + } // namespace Generators diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 259f5c0c2..dcbc14d7e 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -14,6 +14,7 @@ struct PositionInputs { void Add(); void Update(int current_length); + void Update(int current_length, int past_length); private: void AddAttentionMask(); @@ -21,6 +22,9 @@ struct PositionInputs { void UpdatePositionIDs(int current_length); void UpdateAttentionMask(int current_length); + // Used by speculative decoding. 
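+  // current_length is the total sequence length including candidate tokens;
+  // past_length is the number of positions already covered by the KV cache,
+  // so position ids are regenerated for the range [past_length, current_length).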
+ void UpdatePositionIDs(int current_length, int past_length); + void UpdateAttentionMask(int current_length, int past_length); template void InitializeTensors(std::array shape, cpu_span sequence_lengths); @@ -30,6 +34,12 @@ struct PositionInputs { template void UpdateAttentionMaskImpl(T* data, const T* old_data, int current_length); + // Used by speculative decoding + template + void UpdatePositionIDsImpl(int current_length, int past_length); + template + void UpdateAttentionMaskImpl(T* data, int current_length, int past_length); + const Model& model_; State& state_; diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 90a1cd2fb..0f2ab24c1 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -605,10 +605,10 @@ def make_less(self, name, inputs): self.make_node("Less", inputs=inputs, outputs=[output], name=name) self.make_value_info(output, TensorProto.BOOL, shape=None) - def make_range(self, name, inputs): + def make_range(self, name, inputs, shape): output = f"{name}/output_0" self.make_node("Range", inputs=inputs, outputs=[output], name=name) - self.make_value_info(output, TensorProto.INT64, shape=["unk"]) + self.make_value_info(output, TensorProto.INT64, shape=shape) def make_slice(self, name, inputs, dtype, shape): output = f"{name}/output_0" @@ -635,6 +635,18 @@ def make_tanh(self, name, root_input, dtype, shape): self.make_node("Tanh", inputs=[root_input], outputs=[output], name=name) self.make_value_info(output, dtype, shape=shape) + def make_trilu(self, name, inputs, upper: int, dtype, shape): + output = f"{name}/output_0" + self.make_node( + "Trilu", + inputs=inputs, + outputs=[output], + name=name, + upper=upper, + domain="com.microsoft", + ) + self.make_value_info(output, dtype, shape=shape) + def make_matmul(self, matmul, basename, root_input, **kwargs): if self.onnx_dtype in {"fp16", "fp32"}: return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs) @@ -1809,61 +1821,79 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): unsqueeze_6_name = f"{basename}/Unsqueeze_6" # shared unsqueeze for input_ids and attention_mask self.make_unsqueeze(unsqueeze_6_name, unsqueeze_inputs, dtype=TensorProto.INT64, shape=[1]) concat_2_name = f"{basename}/Concat_2" - concat_inputs = [f"{unsqueeze_4_name}/output_0", f"{unsqueeze_5_name}/output_0"] + concat_inputs = [f"{unsqueeze_4_name}/output_0", f"{unsqueeze_3_name}/output_0"] self.make_concat(concat_2_name, concat_inputs, dtype=TensorProto.INT64, shape=[2], axis=0) constant_shape_name = f"{basename}/ConstantOfShape_2" constant_shape_numpy_dtype = self.to_numpy_dtype[self.io_dtype] constant_shape_value = numpy_helper.from_array(np.array([np.finfo(constant_shape_numpy_dtype).min], dtype=constant_shape_numpy_dtype)) - self.make_constant_of_shape(constant_shape_name, f"{concat_2_name}/output_0", value=constant_shape_value, dtype=self.io_dtype, shape=['unk', 'unk']) + self.make_constant_of_shape( + constant_shape_name, + f"{concat_2_name}/output_0", + value=constant_shape_value, + dtype=self.io_dtype, + shape=["sequence_length", "total_sequence_length"], + ) # Top path - shape_4_name = f"{basename}/Shape_4" - self.make_shape(shape_4_name, f"{constant_shape_name}/output_0", shape=[2]) - slice_1_name = f"{basename}/Slice_1" - slice_1_inputs = [f"{shape_4_name}/output_0", "/model/constants/TensorProto.INT64/1D/-1", f"/model/constants/TensorProto.INT64/1D/{np.iinfo(np.int64).max}", "/model/constants/TensorProto.INT64/1D/0"] - 
self.make_slice(slice_1_name, slice_1_inputs, dtype=TensorProto.INT64, shape=[1]) - squeeze_1_name = f"{basename}/Squeeze_1" - squeeze_1_inputs = [f"{slice_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] - self.make_squeeze(squeeze_1_name, squeeze_1_inputs) - unsqueeze_7_name = f"{basename}/output_0" - unsqueeze_7_inputs = [f"{squeeze_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] - self.make_unsqueeze(unsqueeze_7_name, unsqueeze_7_inputs, dtype=TensorProto.INT64, shape=[1]) concat_3_name = f"{basename}/Concat_3" - concat_3_inputs = [f"{unsqueeze_7_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"] + concat_3_inputs = [ + f"{unsqueeze_4_name}/output_0", + "/model/constants/TensorProto.INT64/1D/1", + ] self.make_concat(concat_3_name, concat_3_inputs, dtype=TensorProto.INT64, shape=[2], axis=0) # Bottom path - shape_5_name = f"{basename}/Shape_5" - self.make_shape(shape_5_name, f"{constant_shape_name}/output_0", shape=[2]) - slice_2_name = f"{basename}/Slice_2" - slice_2_inputs = [f"{shape_5_name}/output_0", "/model/constants/TensorProto.INT64/1D/-1", f"/model/constants/TensorProto.INT64/1D/{np.iinfo(np.int64).max}", "/model/constants/TensorProto.INT64/1D/0"] - self.make_slice(slice_2_name, slice_2_inputs, dtype=TensorProto.INT64, shape=[1]) - squeeze_2_name = f"{basename}/Squeeze_2" - squeeze_2_inputs = [f"{slice_2_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] - self.make_squeeze(squeeze_2_name, squeeze_2_inputs) range_name = f"{basename}/Range" - range_inputs = ["/model/constants/TensorProto.INT64/0D/0", f"{squeeze_2_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"] - self.make_range(range_name, range_inputs) + range_inputs = [ + "/model/constants/TensorProto.INT64/0D/0", + f"{basename}/Gather_2/output_0", + "/model/constants/TensorProto.INT64/0D/1", + ] + self.make_range(range_name, range_inputs, shape=["sequence_length"]) add_2_name = f"{basename}/Add_2" - add_inputs = [f"{range_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"] - self.make_add(add_2_name, add_inputs, dtype=TensorProto.INT64, shape=["unk"]) + add_inputs = [f"{range_name}/output_0", f"{past_key_gather_name}/output_0"] + self.make_add( + add_2_name, add_inputs, dtype=TensorProto.INT64, shape=["sequence_length"] + ) + range_2_name = f"{basename}/Range_2" + range_2_inputs = [ + "/model/constants/TensorProto.INT64/0D/0", + f"{shared_add_name}/output_0", + "/model/constants/TensorProto.INT64/0D/1", + ] + self.make_range(range_2_name, range_2_inputs, shape=["total_sequence_length"]) # Merged path reshape_name = f"{basename}/Reshape" reshape_inputs = [f"{add_2_name}/output_0", f"{concat_3_name}/output_0"] self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None) less_name = f"{basename}/Less" - less_inputs = [f"{range_name}/output_0", f"{reshape_name}/output_0"] + less_inputs = [f"{reshape_name}/output_0", f"{range_2_name}/output_0"] self.make_less(less_name, less_inputs) where_2_name = f"{basename}/Where_2" - where_2_inputs = [f"{less_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/0", f"{constant_shape_name}/output_0"] + where_2_inputs = [ + f"{less_name}/output_0", + f"{constant_shape_name}/output_0", + f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/0", + ] self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=None) + unsqueeze_8_name = f"{basename}/Unsqueeze_8" unsqueeze_8_inputs = [f"{where_2_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] - 
self.make_unsqueeze(unsqueeze_8_name, unsqueeze_8_inputs, dtype=self.io_dtype, shape=None) + self.make_unsqueeze( + unsqueeze_8_name, + unsqueeze_8_inputs, + dtype=self.io_dtype, + shape=[1, "sequence_length", "total_sequence_length"], + ) unsqueeze_9_name = f"{basename}/Unsqueeze_9" unsqueeze_9_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"] - self.make_unsqueeze(unsqueeze_9_name, unsqueeze_9_inputs, dtype=self.io_dtype, shape=None) + self.make_unsqueeze( + unsqueeze_9_name, + unsqueeze_9_inputs, + dtype=self.io_dtype, + shape=[1, 1, "sequence_length", "total_sequence_length"], + ) expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", unsqueeze_for_concat=unsqueeze_3_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True) return unsqueeze_6_name, expand_name diff --git a/src/python/python.cpp b/src/python/python.cpp index 0a9bcd553..81ea20bec 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -345,6 +345,38 @@ struct PyGenerator { PyRoamingArray py_sequencelengths_; }; +// TODO(bowenbao): merge with PyGenerator? +struct PySpeculativeDecodingGenerator { + PySpeculativeDecodingGenerator(Model& model, Model& assistant_model, PyGeneratorParams& params) { + params.Prepare(); + generator_ = CreateSpeculativeDecodingGenerator(model, assistant_model, params); + } + + pybind11::array_t GetNextTokens() { + py_tokens_.Assign(generator_->search_->GetNextTokens()); + return ToPython(py_tokens_.GetCPU()); + } + + void ComputeLogits() { + generator_->ComputeLogits(); + } + + void GenerateNextToken() { + generator_->GenerateNextToken(); + } + + bool IsDone() const { + return generator_->IsDone(); + } + + private: + std::unique_ptr generator_; + PyRoamingArray py_tokens_; + PyRoamingArray py_indices_; + PyRoamingArray py_sequence_; + PyRoamingArray py_sequencelengths_; +}; + void SetLogOptions(const pybind11::kwargs& dict) { for (auto& entry : dict) { auto name = entry.first.cast(); @@ -433,6 +465,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { return CreateModel(GetOrtEnv(), config_path.c_str()); })) .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) + .def("generate_with_assist", [](Model& model, const Model& assistant_model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, assistant_model, params); }) .def_property_readonly( "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on") .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); }); @@ -446,6 +479,13 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("get_next_tokens", &PyGenerator::GetNextTokens) .def("get_sequence", &PyGenerator::GetSequence); + pybind11::class_(m, "SpeculativeDecodingGenerator") + .def(pybind11::init()) + .def("is_done", &PySpeculativeDecodingGenerator::IsDone) + .def("compute_logits", &PySpeculativeDecodingGenerator::ComputeLogits) + .def("generate_next_token", &PySpeculativeDecodingGenerator::GenerateNextToken) + .def("get_next_tokens", &PySpeculativeDecodingGenerator::GetNextTokens); + pybind11::class_(m, "Images") .def_static("open", [](pybind11::args image_paths) { if (image_paths.empty()) diff --git a/src/search.cpp b/src/search.cpp index d7a9d3c69..382159e1d 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -49,6 +49,10 @@ RoamingArray GreedySearch_Cpu::GetNextTokens() { return 
next_tokens_;
 }
 
+RoamingArray<int32_t> SpeculativeGreedySearch_Cpu::GetNextTokens() {
+  return next_accepted_tokens_;
+}
+
 RoamingArray<int32_t> BeamSearch_Cpu::GetNextTokens() {
   return beam_scorer_->GetNextTokens();
 }
@@ -252,6 +256,96 @@ void GreedySearch_Cpu::AppendNextTokensToSequences() {
   }
 }
 
+void GreedySearch_Cpu::SetNextTokens(RoamingArray<int32_t> next_tokens) {
+  auto next_tokens_cpu = next_tokens.GetCPU();
+  auto batch_size = params_->batch_size;
+  auto tokens_count_per_batch = next_tokens_cpu.size() / batch_size;
+  for (size_t j = 0; j < tokens_count_per_batch; j++) {
+    for (size_t i = 0; i < batch_size; i++) {
+      SetNextToken(i, next_tokens_cpu[i * tokens_count_per_batch + j]);
+    }
+    AppendNextTokensToSequences();
+  }
+}
+
+void GreedySearch_Cpu::DropLastTokens(size_t num_tokens) {
+  auto sequences_cpu = sequences_.GetSequences();
+  auto new_sequence_length = sequences_.GetSequenceLength() - num_tokens;
+  for (size_t i = 0; i < params_->batch_size; ++i) {
+    if (!eos_seen_[i])
+      continue;
+    auto sequence_cpu = sequences_cpu.subspan(i * params_->search.max_length + new_sequence_length, num_tokens);
+    for (size_t j = 0; j < num_tokens; ++j) {
+      if (sequence_cpu[j] == params_->eos_token_id) {
+        not_done_count_++;
+        done_ = false;
+        eos_seen_[i] = false;
+        if (g_log.enabled && g_log.hit_eos)
+          Log("hit_eos", "Reverted EOS seen on batch " + std::to_string(i));
+      }
+    }
+  }
+  sequences_.DropLastTokens({num_tokens});
+}
+
+RoamingArray<int32_t> SpeculativeGreedySearch_Cpu::CheckCandidates(RoamingArray<int32_t> sequence, int candidate_length) {
+  if (params_->batch_size != 1)
+    throw std::runtime_error("Speculative search only supports batch size 1");
+  auto sequence_cpu = sequence.GetCPU();
+  auto prev_sequence_length = sequence_cpu.size() - candidate_length;
+  auto candidate_tokens_cpu = sequence.GetCPU().subspan(prev_sequence_length, candidate_length);
+  int logit_index = 0;
+  for (; logit_index < candidate_length + 1; logit_index++) {
+    ApplyMinLength(params_->search.min_length, logit_index);
+    ApplyRepetitionPenalty(params_->search.repetition_penalty, logit_index);
+    std::span<float> const scores = next_token_scores_.subspan(logit_index * params_->vocab_size, params_->vocab_size);
+
+    if (g_log.enabled && g_log.model_logits) {
+      auto& stream = Log("speculative_decoding");
+      stream << "model_logits of logit_index=" << logit_index << std::endl;
+      DumpSpan(stream, scores);
+      stream << std::endl;
+    }
+
+    auto const token = static_cast<int32_t>(std::distance(scores.begin(), std::max_element(scores.begin(), scores.end())));
+    SetNextToken(0, token);
+    AppendNextTokensToSequences();
+    if (done_ || logit_index == candidate_length || candidate_tokens_cpu[logit_index] != token) {
+      break;
+    }
+  }
+  auto next_tokens = sequences_.GetSequence(0).subspan(prev_sequence_length, logit_index + 1);
+  next_accepted_tokens_ = cpu_span<int32_t>{next_tokens.data(), next_tokens.size()};
+  return next_accepted_tokens_;
+}
+
+void SpeculativeGreedySearch_Cpu::ApplyMinLength(int min_length, size_t token_idx) {
+  if (sequences_.GetSequenceLength() >= min_length) {
+    return;
+  }
+
+  std::span<float> const scores = next_token_scores_.subspan(token_idx * params_->vocab_size, params_->vocab_size);
+  scores[params_->eos_token_id] = std::numeric_limits<float>::lowest();
+}
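+// Worked example of the acceptance rule in CheckCandidates above, with
+// hypothetical token ids: given assistant candidates [5, 9, 2] and main-model
+// argmax tokens [5, 9, 7, ...], positions 0 and 1 match, position 2 does not,
+// so the loop appends [5, 9, 7] and stops: the matching prefix plus one
+// corrected "bonus" token from the main model.
+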
+void SpeculativeGreedySearch_Cpu::ApplyRepetitionPenalty(float penalty, size_t token_idx) {
+  if (penalty == 1.0f)
+    return;
+
+  std::span<float> const scores = next_token_scores_.subspan(token_idx * params_->vocab_size, params_->vocab_size);
+  std::span<const int32_t> const sequence = sequences_.GetSequence(token_idx);
+
+  std::unordered_set<int32_t> unique_word_ids;
+  for (const auto& word_id : sequence) {
+    unique_word_ids.insert(word_id);
+  }
+
+  for (const int32_t word_id : unique_word_ids) {
+    float const score = scores[word_id];
+    scores[word_id] = (score < 0 ? score * penalty : score / penalty);
+  }
+}
+
 bool BeamSearch_Cpu::IsDone() const {
   if (beam_scorer_->IsDone()) {
     return true;
diff --git a/src/search.h b/src/search.h
index 901cb437f..4d5df105a 100644
--- a/src/search.h
+++ b/src/search.h
@@ -27,6 +27,11 @@ struct Search {
   virtual void ApplyMinLength(int min_length) = 0;
   virtual void ApplyRepetitionPenalty(float penalty) = 0;
 
+  // Used by Speculative search
+  virtual void DropLastTokens(size_t num_tokens) { assert(false); };
+  virtual void SetNextTokens(RoamingArray<int32_t> next_tokens) { assert(false); };
+  virtual RoamingArray<int32_t> CheckCandidates(RoamingArray<int32_t> sequence, int candidate_length) { assert(false); };
+
   std::shared_ptr<const GeneratorParams> params_;
 };
 
@@ -51,7 +56,7 @@ struct Search_Cpu : Search {
 
   cpu_span<int32_t> next_tokens_;  // shape (beam_size*batch_size)
 
-  std::span<float> next_token_scores_;  // shape (beam_size*batch_size, vocab_size)
+  std::span<float> next_token_scores_;  // shape (beam_size*batch_size, vocab_size) or shape (candidate_tokens_count, vocab_size) for speculative search
 
   Sequences sequences_;
   bool done_{};
@@ -68,11 +73,17 @@ struct GreedySearch_Cpu : Search_Cpu {
   void SampleTopP(float p, float temperature) override;
   void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override;
 
-  private:
-  bool PadIfAlreadyEOS(size_t batch_id);
+  // Used by Speculative search.
+  void SetNextTokens(RoamingArray<int32_t> next_tokens) override;
+  void DropLastTokens(size_t num_tokens) override;
+
+ protected:
   void SetNextToken(size_t batch_id, int32_t token);
   void AppendNextTokensToSequences();
 
+ private:
+  bool PadIfAlreadyEOS(size_t batch_id);
+
   std::unique_ptr<int32_t[]> next_tokens_buffer_;
   std::unique_ptr<float[]> temp_topk_buffer_;
 
@@ -106,4 +117,18 @@ struct BeamSearch_Cpu : Search_Cpu {
   std::unique_ptr<BeamScorer> beam_scorer_;
 };
 
+struct SpeculativeGreedySearch_Cpu : GreedySearch_Cpu {
+  SpeculativeGreedySearch_Cpu(const GeneratorParams& params) : GreedySearch_Cpu(params){};
+  RoamingArray<int32_t> CheckCandidates(RoamingArray<int32_t> sequence, int candidate_length);
+
+  RoamingArray<int32_t> GetNextTokens() override;
+
+ protected:
+  void ApplyMinLength(int min_length, size_t token_idx);
+  void ApplyRepetitionPenalty(float penalty, size_t token_idx);
+
+ private:
+  cpu_span<int32_t> next_accepted_tokens_;  // shape (accepted_token_counts) for speculative search
+};
+
 }  // namespace Generators
\ No newline at end of file
diff --git a/src/sequences.cpp b/src/sequences.cpp
index 39354a56c..e5d1fff85 100644
--- a/src/sequences.cpp
+++ b/src/sequences.cpp
@@ -73,4 +73,9 @@ void Sequences::AppendNextTokenToSequences(std::span<const int32_t> next_tokens)
   ++current_length_;
 }
 
+void Sequences::DropLastTokens(size_t num_tokens) {
+  current_length_ -= static_cast<int>(num_tokens);
+  assert(current_length_ >= 0);
+}
+
 }  // namespace Generators
diff --git a/src/sequences.h b/src/sequences.h
index 5407a3bc1..4b45ecbf3 100644
--- a/src/sequences.h
+++ b/src/sequences.h
@@ -19,6 +19,9 @@ struct Sequences {
   // Used by Greedy search:
   void AppendNextTokenToSequences(std::span<const int32_t> next_tokens);
 
+  // Used by Speculative search:
+  void DropLastTokens(size_t num_tokens);
+
  private:
   std::unique_ptr<int32_t[]> sequences_buffer_;
 

From dec83aa4e8d0aa7853f5799370d7f83aa4acc158 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Thu, 29 Aug 2024 10:19:24 -0700
Subject: [PATCH 02/13] merge main

---
 .github/policies/issueLabeler.yml | 86 ++++++++++++
.github/policies/test_issueLabeler.yml | 24 ++++ .github/workflows/linux-cpu-arm64-build.yml | 19 ++- .github/workflows/linux-cpu-x64-build.yml | 20 +-- .github/workflows/linux-gpu-x64-build.yml | 36 +++-- .github/workflows/win-cpu-x64-build.yml | 10 +- .github/workflows/win-cuda-x64-build.yml | 4 +- .github/workflows/win-directml-x64-build.yml | 6 +- .pipelines/nuget-publishing.yml | 7 +- .pipelines/pypl-publishing.yml | 8 +- .pipelines/stages/jobs/capi-packaging-job.yml | 11 +- .../stages/jobs/nuget-packaging-job.yml | 1 + .../stages/jobs/nuget-validation-job.yml | 68 +++++++-- .pipelines/stages/jobs/py-packaging-job.yml | 13 +- .pipelines/stages/jobs/py-validation-job.yml | 33 +++-- .../stages/jobs/steps/capi-linux-step.yml | 14 +- .../stages/jobs/steps/capi-win-step.yml | 2 +- ...nt-governance-component-detection-step.yml | 3 +- .../steps/compliant/win-esrp-dll-step.yml | 32 +++++ .../stages/jobs/steps/utils/download-ort.yml | 2 +- .pipelines/stages/nuget-packaging-stage.yml | 4 +- .pipelines/stages/py-validation-stage.yml | 12 ++ CMakeLists.txt | 2 +- VERSION_INFO | 2 +- benchmark/python/benchmark_e2e.py | 20 +-- build.py | 6 +- cmake/check_cuda.cmake | 1 + cmake/check_dml.cmake | 2 + cmake/check_rocm.cmake | 2 + cmake/global_variables.cmake | 8 +- examples/c/CMakeLists.txt | 4 +- examples/chat_app/README.md | 17 ++- examples/chat_app/app.py | 8 +- examples/chat_app/requirements.txt | 9 -- examples/csharp/HelloPhi/Program.cs | 30 ++-- examples/python/README.md | 4 +- nuget/MANAGED_PACKAGE.md | 3 + ...crosoft.ML.OnnxRuntimeGenAI.Managed.nuspec | 2 +- nuget/PACKAGE.md | 131 ++++++++++++++++++ src/config.cpp | 12 ++ src/config.h | 6 + .../Microsoft.ML.OnnxRuntimeGenAI.csproj | 9 +- src/csharp/Model.cs | 7 +- src/csharp/Utils.cs | 24 +++- src/dml/dml_helpers.cpp | 84 +++++++---- src/generators.cpp | 28 +++- src/generators.h | 9 +- src/java/UpdatingJavaBindings.md | 4 + .../ai/onnxruntime/genai/GeneratorParams.java | 21 +++ .../java/ai/onnxruntime/genai/Images.java | 36 +++++ .../genai/MultiModalProcessor.java | 85 ++++++++++++ .../ai/onnxruntime/genai/NamedTensors.java | 34 +++++ .../native/ai_onnxruntime_genai_Generator.cpp | 21 ++- .../ai_onnxruntime_genai_GeneratorParams.cpp | 26 ++-- .../native/ai_onnxruntime_genai_Images.cpp | 37 +++++ .../native/ai_onnxruntime_genai_Model.cpp | 11 +- ..._onnxruntime_genai_MultiModalProcessor.cpp | 78 +++++++++++ .../ai_onnxruntime_genai_NamedTensors.cpp | 20 +++ .../native/ai_onnxruntime_genai_Sequences.cpp | 9 +- .../native/ai_onnxruntime_genai_Tensor.cpp | 21 +-- .../native/ai_onnxruntime_genai_Tokenizer.cpp | 15 +- .../ai_onnxruntime_genai_TokenizerStream.cpp | 11 +- src/java/src/main/native/utils.h | 2 +- .../genai/MultiModalProcessorTest.java | 37 +++++ .../java/ai/onnxruntime/genai/landscape.jpg | Bin 0 -> 337412 bytes src/java/windows-unittests.cmake | 1 + src/leakcheck.h | 41 ++++++ src/logging.cpp | 2 + src/logging.h | 1 + src/models/captured_graph_pool.cpp | 4 +- src/models/debugging.cpp | 98 +++---------- src/models/env_utils.cpp | 57 ++++++++ src/models/env_utils.h | 15 ++ src/models/input_ids.cpp | 8 +- src/models/kv_cache.cpp | 4 +- src/models/logits.cpp | 30 ++-- src/models/logits.h | 1 + src/models/model.cpp | 53 ++++++- src/models/model.h | 6 +- src/models/onnxruntime_api.h | 104 +++++++------- src/models/onnxruntime_inline.h | 105 ++++++++------ src/models/position_inputs.cpp | 20 +-- src/models/utils.cpp | 26 ++-- src/ort_genai.h | 6 + src/ort_genai_c.cpp | 44 ++++++ src/ort_genai_c.h | 8 ++ src/python/CMakeLists.txt | 5 +- 
src/python/__init__.py.in | 2 +- src/python/py/_dll_directory.py | 41 +++++- src/python/py/models/builder.py | 8 +- src/python/py/models/quantized_model.py | 39 +++--- src/python/python.cpp | 112 +++++---------- src/python/setup.py.in | 24 +++- src/search.h | 4 +- src/tensor.h | 2 +- test/c_api_tests.cpp | 59 ++++++++ ...Microsoft.ML.OnnxRuntimeGenAI.Tests.csproj | 30 +++- test/python/_test_utils.py | 109 +++++++++++---- test/python/conftest.py | 15 +- test/python/requirements-cpu.txt | 2 +- test/python/requirements-cuda.txt | 4 +- test/python/requirements-directml.txt | 2 +- test/python/test_onnxruntime_genai.py | 43 +++--- test/python/test_onnxruntime_genai_e2e.py | 67 +++++---- .../nuget/generate_nuspec_for_native_nuget.py | 7 +- 105 files changed, 1848 insertions(+), 674 deletions(-) create mode 100644 .github/policies/issueLabeler.yml create mode 100644 .github/policies/test_issueLabeler.yml delete mode 100644 examples/chat_app/requirements.txt create mode 100644 nuget/MANAGED_PACKAGE.md create mode 100644 nuget/PACKAGE.md create mode 100644 src/java/src/main/java/ai/onnxruntime/genai/Images.java create mode 100644 src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java create mode 100644 src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java create mode 100644 src/java/src/main/native/ai_onnxruntime_genai_Images.cpp create mode 100644 src/java/src/main/native/ai_onnxruntime_genai_MultiModalProcessor.cpp create mode 100644 src/java/src/main/native/ai_onnxruntime_genai_NamedTensors.cpp create mode 100644 src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java create mode 100644 src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg create mode 100644 src/leakcheck.h create mode 100644 src/models/env_utils.cpp create mode 100644 src/models/env_utils.h diff --git a/.github/policies/issueLabeler.yml b/.github/policies/issueLabeler.yml new file mode 100644 index 000000000..462c0c710 --- /dev/null +++ b/.github/policies/issueLabeler.yml @@ -0,0 +1,86 @@ +id: +name: Issue Triage +description: Assign label to issues +owner: +resource: repository +where: +configuration: + resourceManagementConfiguration: + eventResponderTasks: + - if: + - payloadType: Issues + - and: + - isOpen + - not: + and: + - isAssignedToSomeone + - isLabeled + then: + - if: + - or: + - titleContains: + pattern: '/\bcuda\b/i' + isRegex: True + - bodyContains: + pattern: '/\bcuda\b/i' + isRegex: True + then: + - addLabel: + label: ep:CUDA + - if: + - or: + - titleContains: + pattern: '/\bjava\b/i' + isRegex: True + - bodyContains: + pattern: '/\bjava\b/i' + isRegex: True + then: + - addLabel: + label: api:Java + - if: + - or: + - titleContains: + pattern: '/(\bdirect\s*ml\b|\bdml\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bdirect\s*ml\b|\bdml\b)/i' + isRegex: True + then: + - addLabel: + label: ep:DML + - if: + - or: + - titleContains: + pattern: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bobj(?:ective)?-?c\b|\bnnapi\b|\bmobile\b|\bandroid\b|\bios\b|\bxamarin\b|\bmaui\b)/i' + isRegex: True + then: + - addLabel: + label: platform:mobile + - if: + - or: + - titleContains: + pattern: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' + isRegex: True + - bodyContains: + pattern: '/(\bwindows\b|\bwinrt\b|\bwinml\b)/i' + isRegex: True + then: + - addLabel: + label: platform:windows + - if: + - or: + - titleContains: + pattern: '/\btransformers(?!\.js)\b/i' + isRegex: True + - bodyContains: + 
pattern: '/\btransformers(?!\.js)\b/i' + isRegex: True + then: + - addLabel: + label: model:transformer +onFailure: +onSuccess: diff --git a/.github/policies/test_issueLabeler.yml b/.github/policies/test_issueLabeler.yml new file mode 100644 index 000000000..eb4e4489e --- /dev/null +++ b/.github/policies/test_issueLabeler.yml @@ -0,0 +1,24 @@ +id: +name: Issue Triage +description: Assign label to issues +owner: +resource: repository +where: +configuration: + resourceManagementConfiguration: + eventResponderTasks: + - if: + - payloadType: Issues + - isOpen + then: + - if: + - or: + - titleContains: + pattern: shark + - bodyContains: + pattern: strawberry + then: + - addLabel: + label: wontfix +onFailure: +onSuccess: diff --git a/.github/workflows/linux-cpu-arm64-build.yml b/.github/workflows/linux-cpu-arm64-build.yml index 26b749c5e..8a72cd25c 100644 --- a/.github/workflows/linux-cpu-arm64-build.yml +++ b/.github/workflows/linux-cpu-arm64-build.yml @@ -25,19 +25,17 @@ jobs: with: submodules: 'true' - - name: Install jq - run: | - sudo apt-get install jq - - uses: actions/setup-dotnet@v4 with: dotnet-version: '8.0.x' - name: Get the Latest OnnxRuntime Nightly Version + shell: pwsh run: | - ORT_NIGHTLY_VERSION=$(curl -s "${{ env.ORT_NIGHTLY_REST_API }}" | jq -r '.value[0].versions[0].normalizedVersion') - echo "$ORT_NIGHTLY_VERSION" - echo "ORT_NIGHTLY_VERSION=$ORT_NIGHTLY_VERSION" >> $GITHUB_ENV + $resp = Invoke-RestMethod "${{ env.ORT_NIGHTLY_REST_API }}" + $ORT_NIGHTLY_VERSION = $resp.value[0].versions[0].normalizedVersion + Write-Host "$ORT_NIGHTLY_VERSION" + "ORT_NIGHTLY_VERSION=$ORT_NIGHTLY_VERSION" | Out-File -FilePath $env:GITHUB_ENV -Append - name: Download OnnxRuntime Nightly run: | @@ -59,8 +57,7 @@ jobs: mkdir -p ort/lib mv microsoft.ml.onnxruntime/**/build/native/include ort/ mv microsoft.ml.onnxruntime/**/runtimes/linux-arm64/native/* ort/lib/ - ort_version=$(echo ${{ env.ORT_NIGHTLY_VERSION }} | cut -d- -f1-1) - cp ort/lib/libonnxruntime.so ort/lib/libonnxruntime.so.$ort_version + cp ort/lib/libonnxruntime.so ort/lib/libonnxruntime.so.1 - name: Download Docker Image run: | @@ -73,7 +70,7 @@ jobs: --container-registry onnxruntimebuildcache \ --repository ort_genai_linux_arm64_gha - - name: Doker -- Configure with CMake and GCC + - name: Docker -- Configure with CMake and GCC run: | docker run --rm \ --volume $GITHUB_WORKSPACE:/onnxruntime_src \ @@ -85,7 +82,7 @@ jobs: --volume $GITHUB_WORKSPACE:/onnxruntime_src \ -w /onnxruntime_src ort_genai_linux_arm64_gha bash -c "/usr/bin/cmake --build --preset linux_gcc_cpu_release" - - name: Dokcer -- check test directory + - name: Docker -- Check test directory run: | docker run --rm \ --volume $GITHUB_WORKSPACE:/onnxruntime_src \ diff --git a/.github/workflows/linux-cpu-x64-build.yml b/.github/workflows/linux-cpu-x64-build.yml index 7b7e08c7b..5dd1adcbe 100644 --- a/.github/workflows/linux-cpu-x64-build.yml +++ b/.github/workflows/linux-cpu-x64-build.yml @@ -23,19 +23,17 @@ jobs: with: submodules: true - - name: Install jq - run: | - sudo apt-get install jq - - uses: actions/setup-dotnet@v4 with: dotnet-version: '8.0.x' - name: Get the Latest OnnxRuntime Nightly Version + shell: pwsh run: | - ORT_NIGHTLY_VERSION=$(curl -s "${{ env.ORT_NIGHTLY_REST_API }}" | jq -r '.value[0].versions[0].normalizedVersion') - echo "$ORT_NIGHTLY_VERSION" - echo "ORT_NIGHTLY_VERSION=$ORT_NIGHTLY_VERSION" >> $GITHUB_ENV + $resp = Invoke-RestMethod "${{ env.ORT_NIGHTLY_REST_API }}" + $ORT_NIGHTLY_VERSION = $resp.value[0].versions[0].normalizedVersion + 
Write-Host "$ORT_NIGHTLY_VERSION" + "ORT_NIGHTLY_VERSION=$ORT_NIGHTLY_VERSION" | Out-File -FilePath $env:GITHUB_ENV -Append - name: Download OnnxRuntime Nightly run: | @@ -58,8 +56,7 @@ jobs: mkdir -p ort/lib mv microsoft.ml.onnxruntime/${{ env.ORT_NIGHTLY_VERSION }}/build/native/include ort/ mv microsoft.ml.onnxruntime/${{ env.ORT_NIGHTLY_VERSION }}/runtimes/linux-x64/native/* ort/lib/ - ort_version=$(echo ${{ env.ORT_NIGHTLY_VERSION }} | cut -d- -f1-1) - cp ort/lib/libonnxruntime.so ort/lib/libonnxruntime.so.$ort_version + cp ort/lib/libonnxruntime.so ort/lib/libonnxruntime.so.1 - name: Build with CMake and GCC run: | @@ -91,6 +88,11 @@ jobs: run: | python3 test/python/test_onnxruntime_genai.py --cwd test/python --test_models test/test_models + - name: Build the C# API and Run the C# Tests + run: | + cd test/csharp + dotnet test /p:Configuration=Release /p:NativeBuildOutputDir="../../build/cpu/" + - name: Verify Build Artifacts if: always() continue-on-error: true diff --git a/.github/workflows/linux-gpu-x64-build.yml b/.github/workflows/linux-gpu-x64-build.yml index 3006b750f..fd08add72 100644 --- a/.github/workflows/linux-gpu-x64-build.yml +++ b/.github/workflows/linux-gpu-x64-build.yml @@ -37,19 +37,17 @@ jobs: path: manylinux submodules: true - - name: Install jq - run: | - sudo apt-get install jq - - uses: actions/setup-dotnet@v4 with: dotnet-version: '8.0.x' - - name: Download OnnxRuntime + - name: Get the Latest OnnxRuntime Nightly Version + shell: pwsh run: | - ORT_NIGHTLY_VERSION=$(curl -s "${{ env.ORT_NIGHTLY_REST_API }}" | jq -r '.value[0].versions[0].normalizedVersion') - echo "$ORT_NIGHTLY_VERSION" - echo "ORT_NIGHTLY_VERSION=$ORT_NIGHTLY_VERSION" >> $GITHUB_ENV + $resp = Invoke-RestMethod "${{ env.ORT_NIGHTLY_REST_API }}" + $ORT_NIGHTLY_VERSION = $resp.value[0].versions[0].normalizedVersion + Write-Host "$ORT_NIGHTLY_VERSION" + "ORT_NIGHTLY_VERSION=$ORT_NIGHTLY_VERSION" | Out-File -FilePath $env:GITHUB_ENV -Append - name: Download OnnxRuntime Nightly run: | @@ -65,15 +63,13 @@ jobs: ls -R ${{ env.ORT_PACKAGE_NAME }} continue-on-error: true -# TODO: Find out why do we need to to have libonnxruntime.so.$ort_version - name: Extract OnnxRuntime library and header files run: | set -e -x mkdir -p ort/lib mv microsoft.ml.onnxruntime.gpu.linux/${{ env.ORT_NIGHTLY_VERSION }}/buildTransitive/native/include ort/ mv microsoft.ml.onnxruntime.gpu.linux/${{ env.ORT_NIGHTLY_VERSION }}/runtimes/linux-x64/native/* ort/lib/ - ort_version=$(echo ${{ env.ORT_NIGHTLY_VERSION }} | cut -d- -f1-1) - cp ort/lib/libonnxruntime.so ort/lib/libonnxruntime.so.$ort_version + cp ort/lib/libonnxruntime.so ort/lib/libonnxruntime.so.1 - name: Get Docker Image @@ -130,13 +126,25 @@ jobs: docker run \ --gpus all \ --rm \ + --volume /data/ortgenai_pytorch_models:/data/ortgenai_pytorch_models \ --volume $GITHUB_WORKSPACE:/ort_genai_src \ -e HF_TOKEN=$HF_TOKEN \ -w /ort_genai_src onnxruntimecudabuildx64 bash -c " \ ${{ env.PYTHON_EXECUTABLE }} -m pip install -r test/python/requirements.txt --user && \ ${{ env.PYTHON_EXECUTABLE }} -m pip install -r test/python/requirements-cuda.txt --user && \ ${{ env.PYTHON_EXECUTABLE }} -m pip install /ort_genai_src/build/cuda/wheel/onnxruntime_genai*manylinux*.whl --user && \ - ${{ env.PYTHON_EXECUTABLE }} test/python/test_onnxruntime_genai.py --cwd test/python --test_models test/test_models" + ${{ env.PYTHON_EXECUTABLE }} test/python/test_onnxruntime_genai.py --cwd test/python --test_models test/test_models --e2e" + + # TODO: Enable this by adding dotnet to the docker image + # 
- name: Build the C# API and Run the C# Tests + # run: | + # echo "Building the C# API and running the C# tests" + # docker run \ + # --gpus all \ + # --rm \ + # --volume $GITHUB_WORKSPACE:/ort_genai_src \ + # -w /ort_genai_src/test/csharp onnxruntimecudabuildx64 bash -c " \ + # dotnet test /p:NativeBuildOutputDir='/ort_genai_src/build/cuda/'" - name: Docker -- Run unit tests run: | @@ -144,6 +152,6 @@ jobs: docker run \ --gpus all \ --rm \ + --volume /data/ortgenai_pytorch_models:/data/ortgenai_pytorch_models \ --volume $GITHUB_WORKSPACE:/ort_genai_src \ - -w /ort_genai_src onnxruntimecudabuildx64 bash -c "ls /ort_genai_src/build/cuda/ && ls /ort_genai_src/build/cuda/lib && \ - LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/ort_genai_src/build/cuda/ /ort_genai_src/build/cuda/test/unit_tests" + -w /ort_genai_src onnxruntimecudabuildx64 bash -c "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/ort_genai_src/build/cuda/ /ort_genai_src/build/cuda/test/unit_tests" diff --git a/.github/workflows/win-cpu-x64-build.yml b/.github/workflows/win-cpu-x64-build.yml index 9b6415550..b63cb7c8f 100644 --- a/.github/workflows/win-cpu-x64-build.yml +++ b/.github/workflows/win-cpu-x64-build.yml @@ -72,11 +72,6 @@ jobs: run: | cmake --build --preset windows_x64_cpu_release --parallel - - name: Build the C# API and Run the C# Tests - run: | - cd test\csharp - dotnet test /p:NativeBuildOutputDir="$env:GITHUB_WORKSPACE\$env:binaryDir\Release" - - name: Install the python wheel and test dependencies run: | python3 -m pip install -r test\python\requirements.txt --user @@ -94,7 +89,10 @@ jobs: run: | python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" - + - name: Build the C# API and Run the C# Tests + run: | + cd test\csharp + dotnet test /p:NativeBuildOutputDir="$env:GITHUB_WORKSPACE\$env:binaryDir\Release" - name: Verify Build Artifacts if: always() diff --git a/.github/workflows/win-cuda-x64-build.yml b/.github/workflows/win-cuda-x64-build.yml index 858fdd7a6..36ec38d04 100644 --- a/.github/workflows/win-cuda-x64-build.yml +++ b/.github/workflows/win-cuda-x64-build.yml @@ -93,8 +93,7 @@ jobs: - name: Run the Python Tests run: | - python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" - + python test/python/test_onnxruntime_genai.py --cwd "test\python" --test_models "test\test_models" --e2e - name: Verify Build Artifacts if: always() @@ -107,5 +106,4 @@ jobs: run: | $env:PATH = "${{ env.cuda_dir }}\\v${{ env.cuda_version }}\\bin;" + $env:PATH echo "Current PATH variable is: $env:PATH" - Get-ChildItem "${{ env.cuda_dir }}\\v${{ env.cuda_version }}\\bin" & .\$env:binaryDir\test\Release\unit_tests.exe \ No newline at end of file diff --git a/.github/workflows/win-directml-x64-build.yml b/.github/workflows/win-directml-x64-build.yml index c46967ad5..a2c29f3c9 100644 --- a/.github/workflows/win-directml-x64-build.yml +++ b/.github/workflows/win-directml-x64-build.yml @@ -18,9 +18,9 @@ env: ort_zip: "Microsoft.ML.OnnxRuntime.DirectML.1.17.3.zip" # TODO: Update with nightly ORT-DML build ort_url: "https://github.com/microsoft/onnxruntime/releases/download/v1.17.3/Microsoft.ML.OnnxRuntime.DirectML.1.17.3.zip" - dml_dir: "Microsoft.AI.DirectML.1.14.2" - dml_zip: "Microsoft.AI.DirectML.1.14.2.zip" - dml_url: "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.14.2" + dml_dir: "Microsoft.AI.DirectML.1.15.1" + dml_zip: "Microsoft.AI.DirectML.1.15.1.zip" + dml_url: "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.1" d3d12_dir: 
"Microsoft.Direct3D.D3D12.1.614.0" d3d12_zip: "Microsoft.Direct3D.D3D12.1.614.0.zip" d3d12_url: "https://www.nuget.org/api/v2/package/Microsoft.Direct3D.D3D12/1.614.0" diff --git a/.pipelines/nuget-publishing.yml b/.pipelines/nuget-publishing.yml index 3f67a1f1a..137e2c36f 100644 --- a/.pipelines/nuget-publishing.yml +++ b/.pipelines/nuget-publishing.yml @@ -36,17 +36,17 @@ parameters: - name: ort_version displayName: 'OnnxRuntime version' type: string - default: '1.18.0-dev-20240426-1256-b842effa29' + default: '1.19.0-dev-20240805-1630-ee2fe87e2d' - name: ort_cuda_version displayName: 'OnnxRuntime GPU version' type: string - default: '1.18.0-dev-20240426-0614-b842effa29' + default: '1.19.0-dev-20240805-0337-88c811b638' - name: ort_dml_version displayName: 'OnnxRuntime DML version' type: string - default: '1.18.0-dev-20240426-0116-b842effa29' + default: '1.19.0-dev-20240805-1630-ee2fe87e2d' - name: cuda_version displayName: 'CUDA version' @@ -98,6 +98,7 @@ stages: enable_win_dml: ${{ parameters.enable_win_dml }} enable_win_arm64: ${{ parameters.enable_win_arm64 }} ort_version: ${{ parameters.ort_version }} + ort_cuda_version: ${{ parameters.ort_cuda_version }} ort_dml_version: ${{ parameters.ort_dml_version }} build_config: ${{ parameters.build_config }} diff --git a/.pipelines/pypl-publishing.yml b/.pipelines/pypl-publishing.yml index 2c4b93da5..6ecbc938e 100644 --- a/.pipelines/pypl-publishing.yml +++ b/.pipelines/pypl-publishing.yml @@ -42,7 +42,7 @@ parameters: - name: ort_version displayName: 'OnnxRuntime version' type: string - default: '1.18.0-dev-20240426-1256-b842effa29' + default: '1.19.0-dev-20240805-1630-ee2fe87e2d' - name: ort_cuda_118_version displayName: 'OnnxRuntime GPU version for CUDA 11.8' @@ -52,17 +52,17 @@ parameters: - name: ort_cuda_122_version displayName: 'OnnxRuntime GPU version for CUDA 12.2' type: string - default: '1.19.0-dev-20240530-0257-25ac65375c' + default: '1.19.0-dev-20240805-0337-88c811b638' - name: ort_dml_version displayName: 'OnnxRuntime DML version' type: string - default: '1.18.0-dev-20240426-0116-b842effa29' + default: '1.19.0-dev-20240805-1630-ee2fe87e2d' - name: ort_rocm_version displayName: 'OnnxRuntime ROCm version' type: string - default: '1.19.0-dev-20240602-1103-217b66f' + default: '1.19.0-dev-20240805-0337-88c811b638' - name: cuda_versions displayName: 'CUDA versions' diff --git a/.pipelines/stages/jobs/capi-packaging-job.yml b/.pipelines/stages/jobs/capi-packaging-job.yml index 4530bbb86..1d936eccb 100644 --- a/.pipelines/stages/jobs/capi-packaging-job.yml +++ b/.pipelines/stages/jobs/capi-packaging-job.yml @@ -55,10 +55,7 @@ jobs: - name: os value: ${{ parameters.os }} - name: feed_name - ${{ if and(eq(parameters.cuda_version, '12.2'), eq(parameters.ep, 'cuda')) }}: - value: 'ort-cuda-12-nightly' - ${{ else }}: - value: '7982ae20-ed19-4a35-a362-a96ac99897b7' + value: '7982ae20-ed19-4a35-a362-a96ac99897b7' - name: ort_filename ${{ if eq(parameters.ep, 'cpu') }}: value: 'Microsoft.ML.OnnxRuntime' @@ -77,11 +74,11 @@ jobs: - name: ortHome value: 'ort' - name: dml_dir - value: 'Microsoft.AI.DirectML.1.15.0' + value: 'Microsoft.AI.DirectML.1.15.1' - name: dml_zip - value: 'Microsoft.AI.DirectML.1.15.0.zip' + value: 'Microsoft.AI.DirectML.1.15.1.zip' - name: dml_url - value: "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0" + value: "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.1" - name: d3d12_dir value: 'Microsoft.Direct3D.D3D12.1.614.0' diff --git a/.pipelines/stages/jobs/nuget-packaging-job.yml 
b/.pipelines/stages/jobs/nuget-packaging-job.yml index 02699db52..e01374b9e 100644 --- a/.pipelines/stages/jobs/nuget-packaging-job.yml +++ b/.pipelines/stages/jobs/nuget-packaging-job.yml @@ -138,6 +138,7 @@ jobs: - template: steps/utils/set-genai-version.yml - powershell: | + dotnet --info dotnet build Microsoft.ML.OnnxRuntimeGenAI.csproj -p:Configuration="$(buildConfig)" --verbosity normal displayName: 'Build CSharp' workingDirectory: '$(Build.Repository.LocalPath)\src\csharp' diff --git a/.pipelines/stages/jobs/nuget-validation-job.yml b/.pipelines/stages/jobs/nuget-validation-job.yml index c2fd44f92..e382a3ca7 100644 --- a/.pipelines/stages/jobs/nuget-validation-job.yml +++ b/.pipelines/stages/jobs/nuget-validation-job.yml @@ -84,6 +84,12 @@ jobs: ${{ else }}: value: 'cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4' + - name: cuda_docker_image + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 + ${{ else }}: + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240610.1 + workspace: clean: all steps: @@ -98,12 +104,18 @@ jobs: displayName: Copy python 3.12.3 version to agent tools directory condition: and(eq(variables['arch'], 'arm64'), eq(variables['os'], 'win')) + - task: NuGetAuthenticate@1 + - task: UsePythonVersion@0 inputs: versionSpec: 3.12 addToPath: true architecture: $(arch) + - task: UseDotNet@2 + inputs: + version: '8.x' + - template: steps/utils/download-huggingface-model.yml parameters: StepName: 'Download Model from HuggingFace' @@ -139,19 +151,57 @@ jobs: Copy-Item -Force -Recurse -Verbose $(Build.BinariesDirectory)/nuget/* -Destination examples/csharp/HelloPhi/ cd examples/csharp/HelloPhi Move-Item models\$(prebuild_phi3_mini_model_folder) models\phi-3 - dotnet restore --arch $(arch) /property:Configuration=$(csproj_configuration) --source https://api.nuget.org/v3/index.json --source $PWD --verbosity normal - dotnet run --arch $(arch) --configuration $(csproj_configuration) --no-restore --verbosity normal -- -m ./models/phi-3 + dotnet restore -r $(os)-$(arch) /property:Configuration=$(csproj_configuration) --source https://api.nuget.org/v3/index.json --source https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json --source $PWD --disable-parallel --verbosity detailed + dotnet run -r $(os)-$(arch) --configuration $(csproj_configuration) --no-restore --verbosity normal -- -m ./models/phi-3 displayName: 'Run Example With Artifact' workingDirectory: '$(Build.Repository.LocalPath)' + env: + NUGET_PLUGIN_HANDSHAKE_TIMEOUT_IN_SECONDS: 180 + NUGET_PLUGIN_REQUEST_TIMEOUT_IN_SECONDS: 180 - ${{ elseif eq(parameters.os, 'linux') }}: - bash: | - dotnet --info - cp $(Build.BinariesDirectory)/nuget/* examples/csharp/HelloPhi/ - cd examples/csharp/HelloPhi - mv models/$(prebuild_phi3_mini_model_folder) models/phi-3 - dotnet restore --arch $(arch) /property:Configuration=$(csproj_configuration) --source https://api.nuget.org/v3/index.json --source $PWD --verbosity normal - dotnet run --arch $(arch) --configuration $(csproj_configuration) --no-restore --verbosity normal -- -m ./models/phi-3 - displayName: 'Run Example With Artifact' + dotnet --info + cp $(Build.BinariesDirectory)/nuget/* examples/csharp/HelloPhi/ + cd examples/csharp/HelloPhi + mv models/$(prebuild_phi3_mini_model_folder) models/phi-3 + dotnet restore -r $(os)-$(arch) /property:Configuration=$(csproj_configuration) 
--source https://api.nuget.org/v3/index.json --source https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json --source $PWD --disable-parallel --verbosity detailed + dotnet build ./HelloPhi.csproj -r $(os)-$(arch) /property:Configuration=$(csproj_configuration) --no-restore --self-contained + displayName: 'Perform dotnet restore & build' workingDirectory: '$(Build.Repository.LocalPath)' + env: + NUGET_PLUGIN_HANDSHAKE_TIMEOUT_IN_SECONDS: 180 + NUGET_PLUGIN_REQUEST_TIMEOUT_IN_SECONDS: 180 + + - ${{ if eq(parameters.ep, 'cuda') }}: + - bash: | + set -e -x + az login --identity --username 63b63039-6328-442f-954b-5a64d124e5b4 + az acr login --name onnxruntimebuildcache --subscription 00c06639-6ee4-454e-8058-8d8b1703bd87 + docker pull $(cuda_docker_image) + + docker run \ + --gpus all \ + --rm \ + --volume $(Build.Repository.LocalPath):/ort_genai_src \ + --volume $(Build.BinariesDirectory):/ort_genai_binary \ + -e HF_TOKEN=$HF_TOKEN \ + -w /ort_genai_src/ $(cuda_docker_image) \ + bash -c " \ + export ORTGENAI_LOG_ORT_LIB=1 && \ + cd /ort_genai_src/examples/csharp/HelloPhi && \ + chmod +x ./bin/Release_Cuda/net6.0/linux-x64/HelloPhi && \ + ./bin/Release_Cuda/net6.0/linux-x64/HelloPhi -m ./models/phi-3" + + displayName: 'Run Example With Artifact' + workingDirectory: '$(Build.Repository.LocalPath)' + + - ${{ elseif eq(parameters.ep, 'cpu') }}: + - bash: | + export ORTGENAI_LOG_ORT_LIB=1 + cd examples/csharp/HelloPhi + dotnet run -r $(os)-$(arch) --configuration $(csproj_configuration) --no-build --verbosity normal -- -m ./models/phi-3 + displayName: 'Run Example With Artifact' + workingDirectory: '$(Build.Repository.LocalPath)' + - template: steps/compliant-and-cleanup-step.yml diff --git a/.pipelines/stages/jobs/py-packaging-job.yml b/.pipelines/stages/jobs/py-packaging-job.yml index 4e65e0454..99582c9f1 100644 --- a/.pipelines/stages/jobs/py-packaging-job.yml +++ b/.pipelines/stages/jobs/py-packaging-job.yml @@ -93,10 +93,7 @@ jobs: value: ${{ parameters.os }} - name: feed_name - ${{ if and(eq(parameters.cuda_version, '12.2'), eq(parameters.ep, 'cuda')) }}: - value: 'ort-cuda-12-nightly' - ${{ else }}: - value: '7982ae20-ed19-4a35-a362-a96ac99897b7' + value: '7982ae20-ed19-4a35-a362-a96ac99897b7' - name: ort_filename ${{ if eq(parameters.ep, 'cpu') }}: @@ -114,11 +111,11 @@ jobs: value: 'Microsoft.ML.OnnxRuntime' - name: dml_dir - value: 'Microsoft.AI.DirectML.1.14.2' + value: 'Microsoft.AI.DirectML.1.15.1' - name: dml_zip - value: 'Microsoft.AI.DirectML.1.14.2.zip' + value: 'Microsoft.AI.DirectML.1.15.1.zip' - name: dml_url - value: "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.14.2" + value: "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.1" - name: d3d12_dir value: 'Microsoft.Direct3D.D3D12.1.614.0' @@ -163,7 +160,7 @@ jobs: import subprocess subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', 'build', 'packaging', 'twine']) workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' + displayName: 'Install python modules' - ${{ if eq(parameters.os, 'linux') }}: - template: steps/capi-linux-step.yml diff --git a/.pipelines/stages/jobs/py-validation-job.yml b/.pipelines/stages/jobs/py-validation-job.yml index 0e1310288..6e3bd6625 100644 --- a/.pipelines/stages/jobs/py-validation-job.yml +++ b/.pipelines/stages/jobs/py-validation-job.yml @@ -107,6 +107,12 @@ jobs: ${{ else }}: value: 'cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4' + - name: cuda_docker_image + ${{ if 
eq(parameters.cuda_version, '11.8') }}: + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 + ${{ else }}: + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20240610.1 + steps: - checkout: self clean: true @@ -174,7 +180,7 @@ jobs: set -e -x az login --identity --username 63b63039-6328-442f-954b-5a64d124e5b4 az acr login --name onnxruntimebuildcache --subscription 00c06639-6ee4-454e-8058-8d8b1703bd87 - docker pull onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 + docker pull $(cuda_docker_image) python_exe=/opt/python/cp310-cp310/bin/python3.10 docker run \ @@ -183,21 +189,25 @@ jobs: --volume $(Build.Repository.LocalPath):/ort_genai_src \ --volume $(Build.BinariesDirectory):/ort_genai_binary \ -e HF_TOKEN=$HF_TOKEN \ - -w /ort_genai_src/ onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20240531.1 \ + -w /ort_genai_src/ $(cuda_docker_image) \ bash -c " \ - $python_exe -m pip install numpy transformers torch onnx onnxruntime && \ + export ORTGENAI_LOG_ORT_LIB=1 && \ + $python_exe -m pip install -r /ort_genai_src/test/python/requirements.txt && \ + $python_exe -m pip install -r /ort_genai_src/test/python/requirements-cuda.txt && \ cd /ort_genai_src/examples/python && \ $python_exe -m pip install --no-index --find-links=/ort_genai_binary/wheel $(pip_package_name) && \ - $python_exe model-generate.py -m ./models/$(prebuild_phi3_mini_model_folder)" + $python_exe model-generate.py -m ./models/$(prebuild_phi3_mini_model_folder) --min_length 25 --max_length 50 --verbose" displayName: 'Run Example With Artifact' workingDirectory: '$(Build.Repository.LocalPath)' - ${{ elseif eq(parameters.ep, 'cpu') }}: - bash: | - python -m pip install numpy transformers torch onnx onnxruntime + export ORTGENAI_LOG_ORT_LIB=1 + python -m pip install -r test/python/requirements.txt + python -m pip install -r test/python/requirements-cpu.txt cd examples/python python -m pip install --no-index --find-links=$(Build.BinariesDirectory)/wheel $(pip_package_name) - python model-generate.py -m ./models/$(prebuild_phi3_mini_model_folder) + python model-generate.py -m ./models/$(prebuild_phi3_mini_model_folder) --min_length 25 --max_length 50 --verbose displayName: 'Run Example With Artifact' workingDirectory: '$(Build.Repository.LocalPath)' @@ -209,16 +219,23 @@ jobs: displayName: 'Download CUDA $(cuda_version)' workingDirectory: '$(Build.Repository.LocalPath)' - powershell: | + python -m pip install -r test/python/requirements.txt if ("$(ep)" -eq "cuda") { $env:CUDA_PATH = '$(Build.Repository.LocalPath)\cuda_sdk\v$(cuda_version)' $env:PATH = "$env:CUDA_PATH\bin;$env:CUDA_PATH\extras\CUPTI\lib64;$env:PATH" Write-Host $env:PATH + python -m pip install -r test/python/requirements-cuda.txt + } + elseif ("$(ep)" -eq "directml") { + python -m pip install -r test/python/requirements-directml.txt + } + else { + python -m pip install -r test/python/requirements-cpu.txt } - python -m pip install numpy transformers torch onnx onnxruntime cd examples\python python -m pip install --no-index --find-links=$(Build.BinariesDirectory)/wheel $(pip_package_name) - python model-generate.py -m .\models\$(prebuild_phi3_mini_model_folder) + python model-generate.py -m .\models\$(prebuild_phi3_mini_model_folder) --min_length 25 --max_length 50 --verbose displayName: 'Run Example With Artifact' workingDirectory: 
'$(Build.Repository.LocalPath)' diff --git a/.pipelines/stages/jobs/steps/capi-linux-step.yml b/.pipelines/stages/jobs/steps/capi-linux-step.yml index 89412932a..c3074f607 100644 --- a/.pipelines/stages/jobs/steps/capi-linux-step.yml +++ b/.pipelines/stages/jobs/steps/capi-linux-step.yml @@ -34,8 +34,12 @@ steps: - bash: | echo "arch=$(arch)" + echo "ort_filename=$(ort_filename)" + echo "ort_version=$(ort_version)" echo "ep=$(ep)" - echo "build_config=$(build_config)" + echo "cuda_version=$(cuda_version)" + echo "target=${{ parameters.target }}" + echo "build_config=${{ parameters.build_config }}" displayName: 'Print Parameters' - template: utils/download-ort.yml @@ -118,7 +122,9 @@ steps: docker run \ --rm \ --volume $(Build.Repository.LocalPath):/ort_genai_src \ - -w /ort_genai_src/ ortgenai$(ep)build$(arch) \ + -w /ort_genai_src/ \ + -e ONNXRUNTIME_VERSION=$(ONNXRUNTIME_VERSION) \ + ortgenai$(ep)build$(arch) \ bash -c " \ /usr/bin/cmake --preset linux_gcc_$(ep)_$(build_config) \ -DENABLE_TESTS=OFF \ @@ -140,7 +146,9 @@ steps: docker run \ --rm \ --volume $(Build.Repository.LocalPath):/ort_genai_src \ - -w /ort_genai_src/ ortgenai$(ep)build$(arch) \ + -w /ort_genai_src/ \ + -e ONNXRUNTIME_VERSION=$(ONNXRUNTIME_VERSION) \ + ortgenai$(ep)build$(arch) \ bash -c " \ /usr/bin/cmake --build --preset linux_gcc_$(ep)_$(build_config) \ -DENABLE_TESTS=OFF \ diff --git a/.pipelines/stages/jobs/steps/capi-win-step.yml b/.pipelines/stages/jobs/steps/capi-win-step.yml index 3000f244d..ed52f8172 100644 --- a/.pipelines/stages/jobs/steps/capi-win-step.yml +++ b/.pipelines/stages/jobs/steps/capi-win-step.yml @@ -33,6 +33,7 @@ steps: - script: | echo "arch=$(arch)" + echo "ort_filename=$(ort_filename)" echo "ort_version=$(ort_version)" echo "ep=$(ep)" echo "cuda_version=$(cuda_version)" @@ -82,7 +83,6 @@ steps: condition: ne(variables['ep'], 'cuda') workingDirectory: '$(Build.Repository.LocalPath)' - - powershell: | cmake --build --preset windows_$(arch)_$(ep)_$(build_config) --parallel --target ${{ parameters.target }} displayName: 'Build C API' diff --git a/.pipelines/stages/jobs/steps/compliant/component-governance-component-detection-step.yml b/.pipelines/stages/jobs/steps/compliant/component-governance-component-detection-step.yml index 0d63911f5..9ecdb7bd8 100644 --- a/.pipelines/stages/jobs/steps/compliant/component-governance-component-detection-step.yml +++ b/.pipelines/stages/jobs/steps/compliant/component-governance-component-detection-step.yml @@ -21,5 +21,6 @@ steps: and(eq('${{parameters.condition}}', 'succeeded'), succeeded())) ignoreDirectories: '$(Build.Repository.LocalPath)/build/cpu/_deps/, - $(Build.Repository.LocalPath)/build/cpu/win-arm64/_deps' + $(Build.Repository.LocalPath)/build/cpu/win-arm64/_deps, + $(Build.Repository.LocalPath)/build/cpu/win-x64/_deps' diff --git a/.pipelines/stages/jobs/steps/compliant/win-esrp-dll-step.yml b/.pipelines/stages/jobs/steps/compliant/win-esrp-dll-step.yml index 69a99ec69..03603d4e2 100644 --- a/.pipelines/stages/jobs/steps/compliant/win-esrp-dll-step.yml +++ b/.pipelines/stages/jobs/steps/compliant/win-esrp-dll-step.yml @@ -32,3 +32,35 @@ steps: SessionTimeout: 90 ServiceEndpointUrl: 'https://api.esrp.microsoft.com/api/v2' MaxConcurrency: 25 + signConfigType: inlineSignParams + inlineOperation: | + [ + { + "keyCode": "CP-230012", + "operationSetCode": "SigntoolSign", + "parameters": [ + { + "parameterName": "OpusName", + "parameterValue": "Microsoft" + }, + { + "parameterName": "OpusInfo", + "parameterValue": "http://www.microsoft.com" + }, + { 
+ "parameterName": "PageHash", + "parameterValue": "/NPH" + }, + { + "parameterName": "FileDigest", + "parameterValue": "/fd sha256" + }, + { + "parameterName": "TimeStamp", + "parameterValue": "/tr \"http://rfc3161.gtm.corp.microsoft.com/TSS/HttpTspServer\" /td sha256" + } + ], + "toolName": "signtool.exe", + "toolVersion": "6.2.9304.0" + } + ] diff --git a/.pipelines/stages/jobs/steps/utils/download-ort.yml b/.pipelines/stages/jobs/steps/utils/download-ort.yml index bb35c196f..47717112c 100644 --- a/.pipelines/stages/jobs/steps/utils/download-ort.yml +++ b/.pipelines/stages/jobs/steps/utils/download-ort.yml @@ -9,7 +9,7 @@ steps: - task: DownloadPackage@1 inputs: packageType: 'nuget' - feed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/$(feed_name)' + feed: '2692857e-05ef-43b4-ba9c-ccf1c22c437c/$(feed_name)' # projectID/feedID definition: '$(ort_filename)' # Can also be package name version: '$(ort_version)' extract: false diff --git a/.pipelines/stages/nuget-packaging-stage.yml b/.pipelines/stages/nuget-packaging-stage.yml index ab89c0cb1..4137b29bd 100644 --- a/.pipelines/stages/nuget-packaging-stage.yml +++ b/.pipelines/stages/nuget-packaging-stage.yml @@ -31,6 +31,8 @@ parameters: - name: ort_version type: string +- name: ort_cuda_version + type: string - name: ort_dml_version type: string - name: build_config @@ -53,7 +55,7 @@ stages: - template: jobs/nuget-packaging-job.yml parameters: ep: 'cuda' - ort_version: ${{ parameters.ort_version }} + ort_version: ${{ parameters.ort_cuda_version }} build_config: ${{ parameters.build_config }} enable_linux_cuda: ${{ parameters.enable_linux_cuda }} enable_win_cuda: ${{ parameters.enable_win_cuda }} diff --git a/.pipelines/stages/py-validation-stage.yml b/.pipelines/stages/py-validation-stage.yml index 6bbee518a..d619c3acb 100644 --- a/.pipelines/stages/py-validation-stage.yml +++ b/.pipelines/stages/py-validation-stage.yml @@ -79,6 +79,18 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + - ${{ if and(eq(parameters.enable_win_cuda, true), contains(parameters.cuda_versions, '12.2')) }}: + - template: jobs/py-validation-job.yml + parameters: + arch: 'x64' + cuda_version: '12.2' + cuda_display_version: '122' + ep: 'cuda' + ort_version: ${{ parameters.ort_cuda_122_version }} + os: 'win' + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + - ${{ if eq(parameters.enable_linux_cpu, true) }}: - template: jobs/py-validation-job.yml parameters: diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c816a9c3..06d196d7b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -115,7 +115,7 @@ endif() if(USE_DML) list(APPEND onnxruntime_libs "${ORT_LIB_DIR}/DirectML.dll") list(APPEND onnxruntime_libs "${ORT_LIB_DIR}/D3D12Core.dll") - list(APPEND ortgenai_embed_libs "${ORT_LIB_DIR}/DirectML.dll" "${ORT_LIB_DIR}/D3D12Core.dll") + list(APPEND ortgenai_embed_libs "${ORT_LIB_DIR}/D3D12Core.dll") target_include_directories(onnxruntime-genai PRIVATE $) target_include_directories(onnxruntime-genai PRIVATE $/directx) target_include_directories(onnxruntime-genai PRIVATE $) diff --git a/VERSION_INFO b/VERSION_INFO index 5c73f2ac0..9974c4be8 100644 --- a/VERSION_INFO +++ b/VERSION_INFO @@ -1 +1 @@ -0.4.0-dev \ No newline at end of file +0.5.0-dev \ No newline at end of file diff --git a/benchmark/python/benchmark_e2e.py b/benchmark/python/benchmark_e2e.py index a4066e03f..08fa63aec 100644 --- a/benchmark/python/benchmark_e2e.py +++ b/benchmark/python/benchmark_e2e.py @@ -173,7 +173,7 @@ def 
save_results(args, results, filename, print_memory_usage=False): BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records) print(f"Results saved in {filename}!") -def run_benchmark_memory(args, model, tokenizer, batch_size, prompt_length, generation_length, max_length): +def run_benchmark_memory(args, batch_size, prompt_length, generation_length, max_length): """ This function is to run benchmark and print the momory usage """ @@ -186,7 +186,7 @@ def run_benchmark_memory(args, model, tokenizer, batch_size, prompt_length, gene monitor_thread.start() - metrics = run_benchmark(args, model, tokenizer, batch_size, prompt_length, generation_length, max_length) + metrics = run_benchmark(args, batch_size, prompt_length, generation_length, max_length) stop_monitoring = True monitor_thread.join() @@ -198,7 +198,7 @@ def run_benchmark_memory(args, model, tokenizer, batch_size, prompt_length, gene return metrics -def run_benchmark(args, model, tokenizer, batch_size, prompt_length, generation_length, max_length): +def run_benchmark(args, batch_size, prompt_length, generation_length, max_length): # Get user arguments num_repetitions = args.repetitions @@ -351,13 +351,7 @@ def run_benchmark(args, model, tokenizer, batch_size, prompt_length, generation_ def main(args): all_csv_metrics = [] - # Get tokenizer, and model - model_path = args.input_folder - if args.verbose: print(f"Loading model... ") - model=og.Model(f'{model_path}') - if args.verbose: print("Model loaded, loading tokenizer...") - tokenizer = og.Tokenizer(model) - if args.verbose: print("Tokenizer loaded, starting benchmark...") + for batch_size in args.batch_sizes: for l, prompt_length in enumerate(args.prompt_lengths): for g, gen_length in enumerate(args.generation_lengths): @@ -368,9 +362,9 @@ def main(args): max_length = prompt_length + gen_length print(f"Args: batch_size = {batch_size}, prompt_length = {prompt_length}, tokens = {gen_length}, max_length = {max_length}") if args.print_memory_usage: - metrics = run_benchmark_memory(args, model, tokenizer, batch_size, prompt_length, gen_length, max_length) + metrics = run_benchmark_memory(args, batch_size, prompt_length, gen_length, max_length) else: - metrics = run_benchmark(args, model, tokenizer, batch_size, prompt_length, gen_length, max_length) + metrics = run_benchmark(args, batch_size, prompt_length, gen_length, max_length) all_csv_metrics.append(metrics) # Add metrics to CSV if args.verbose: print("Adding results to CSV") @@ -410,4 +404,4 @@ def str2strlist(value): parser.add_argument('-mn', '--model_name', type=str, default='model_name', help='Model name defined by users') parser.add_argument('-pr', '--precision', type=str, default='fp16', help='Model precision for metrics info') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/build.py b/build.py index c03a06c45..07f9db1e1 100644 --- a/build.py +++ b/build.py @@ -341,7 +341,7 @@ def _get_csharp_properties(args: argparse.Namespace): configuration = f"/p:Configuration={args.config}" platform = "/p:Platform=Any CPU" # need an extra config on windows as the actual build output is in the original build dir / config / config - native_lib_path = f"/p:NativeBuildOutputDir={str(args.build_dir / args.config)}" + native_lib_path = f"/p:NativeBuildOutputDir={str(args.build_dir / args.config) if util.is_windows() else str(args.build_dir)}" props = [configuration, platform, native_lib_path] @@ -519,7 +519,7 @@ def build(args: argparse.Namespace, env: dict[str, str]): 
 util.run(make_command, env=env)
 
-    if util.is_windows() and not args.skip_csharp:
+    if not args.skip_csharp:
         dotnet = str(_resolve_executable_path("dotnet"))
 
         # Build the library
@@ -536,7 +536,7 @@ def test(args: argparse.Namespace, env: dict[str, str]):
     ctest_cmd = [str(args.ctest_path), "--build-config", args.config, "--verbose", "--timeout", "10800"]
     util.run(ctest_cmd, cwd=str(args.build_dir))
 
-    if util.is_windows() and not args.skip_csharp:
+    if not args.skip_csharp:
         dotnet = str(_resolve_executable_path("dotnet"))
         csharp_test_command = [dotnet, "test"]
         csharp_test_command += _get_csharp_properties(args)
diff --git a/cmake/check_cuda.cmake b/cmake/check_cuda.cmake
index 4867eddc4..620a5b27a 100644
--- a/cmake/check_cuda.cmake
+++ b/cmake/check_cuda.cmake
@@ -60,6 +60,7 @@ elseif(USE_CUDA)
 else()
   file(GLOB generator_cuda_srcs "${GENERATORS_ROOT}/*_cuda*.*")
   list(REMOVE_ITEM generator_srcs ${generator_cuda_srcs})
+  add_compile_definitions(USE_CUDA=0)
 endif()
 
 if(USE_CUDA AND NOT EXISTS "${ORT_LIB_DIR}/${ONNXRUNTIME_PROVIDERS_CUDA_LIB}")
diff --git a/cmake/check_dml.cmake b/cmake/check_dml.cmake
index 505b443a0..fbc364758 100644
--- a/cmake/check_dml.cmake
+++ b/cmake/check_dml.cmake
@@ -15,4 +15,6 @@ if(USE_DML)
   else()
     message(FATAL_ERROR "USE_DML is ON but this isn't windows.")
   endif()
+else()
+  add_compile_definitions(USE_DML=0)
 endif()
\ No newline at end of file
diff --git a/cmake/check_rocm.cmake b/cmake/check_rocm.cmake
index 9526449b1..567118f32 100644
--- a/cmake/check_rocm.cmake
+++ b/cmake/check_rocm.cmake
@@ -5,4 +5,6 @@ endif()
 if(USE_ROCM)
   list(APPEND onnxruntime_libs "${ORT_LIB_DIR}/${ONNXRUNTIME_PROVIDERS_ROCM_LIB}")
   add_compile_definitions(USE_ROCM=1)
+else()
+  add_compile_definitions(USE_ROCM=0)
 endif()
\ No newline at end of file
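A side note on the three cmake changes above: defining `USE_CUDA`, `USE_DML`, and `USE_ROCM` to 0 when a backend is disabled means the macros are always defined, to either 1 or 0, so sources can branch with plain value tests and stay clean under `-Wundef`. A minimal illustration of the convention (assumed usage pattern, not code from this repository):

```cpp
// Every backend macro is expected to be defined by the build system, to 0 or 1.
#include <cstdio>

#if !defined(USE_CUDA)
#error "build system is expected to define USE_CUDA to 0 or 1"
#endif

int main() {
#if USE_CUDA  // value test; works because the macro is never left undefined
  std::puts("built with the CUDA backend");
#else
  std::puts("built without the CUDA backend");
#endif
  return 0;
}
```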
diff --git a/cmake/global_variables.cmake b/cmake/global_variables.cmake
index e4586d4e9..2953ee287 100644
--- a/cmake/global_variables.cmake
+++ b/cmake/global_variables.cmake
@@ -13,7 +13,13 @@ set(VERSION_INFO ${ver})
 # VERSION_PATCH: 0
 string(REPLACE "-" ";" VERSION_LIST ${VERSION_INFO})
 list(GET VERSION_LIST 0 VERSION_STR)
-list(GET VERSION_LIST 1 VERSION_SUFFIX)
+# Check if it is a stable or dev version
+list(LENGTH VERSION_LIST VERSION_LIST_LENGTH)
+if(VERSION_LIST_LENGTH GREATER 1)
+  list(GET VERSION_LIST 1 VERSION_SUFFIX)
+else()
+  set(VERSION_SUFFIX "") # Set VERSION_SUFFIX to empty if stable version
+endif()
 string(REPLACE "." ";" VERSION_LIST ${VERSION_STR})
 list(GET VERSION_LIST 0 VERSION_MAJOR)
 list(GET VERSION_LIST 1 VERSION_MINOR)
diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt
index 3a55307fa..56420786e 100644
--- a/examples/c/CMakeLists.txt
+++ b/examples/c/CMakeLists.txt
@@ -57,11 +57,11 @@ if(USE_CUDA)
   target_link_libraries(
     phi3 PUBLIC
-    cublasLt cublas cudnn curand cufft cudart)
+    cublas curand cudart)
   target_link_libraries(
     phi3v PUBLIC
-    cublasLt cublas cudnn curand cufft cudart)
+    cublas curand cudart)
 endif()
 
 file(GLOB ort_genai_libs "${CMAKE_SOURCE_DIR}/lib/${ONNXRUNTIME_GENAI_DEPENDENCY}")
diff --git a/examples/chat_app/README.md b/examples/chat_app/README.md
index 7720d14fe..3755325c5 100755
--- a/examples/chat_app/README.md
+++ b/examples/chat_app/README.md
@@ -32,7 +32,8 @@ This is a chat demo using the various versions of the LLMs
 3. Install the requirements
 
    ```bash
-   pip install -r requirements.txt
+   pip install huggingface-hub mdtex2html
+   pip install gradio==4.36.0 # Gradio 3.47 breaks the UI and versions between 3.42 and 3.47
haven't been tested -huggingface-hub -markdown -mdtex2html -protobuf==3.20.3 # protobuf 4.x aborts with OOM when optimizing large models -Pygments -sentencepiece -tabulate -torch diff --git a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs index 5d554ead6..82576c1ad 100644 --- a/examples/csharp/HelloPhi/Program.cs +++ b/examples/csharp/HelloPhi/Program.cs @@ -5,10 +5,10 @@ void PrintUsage() { Console.WriteLine("Usage:"); Console.WriteLine(" -m model_path"); - Console.WriteLine(" -i (optional): Intereactive mode"); + Console.WriteLine(" -i (optional): Interactive mode"); } -OgaHandle ogaHandle = new OgaHandle(); +using OgaHandle ogaHandle = new OgaHandle(); if (args.Length < 1) { @@ -16,7 +16,7 @@ void PrintUsage() Environment.Exit(-1); } -bool intereactive = false; +bool interactive = false; string modelPath = string.Empty; uint i = 0; @@ -25,7 +25,7 @@ void PrintUsage() var arg = args[i]; if (arg == "-i") { - intereactive = true; + interactive = true; } else if (arg == "-m") { @@ -47,13 +47,13 @@ void PrintUsage() Console.WriteLine("-------------"); Console.WriteLine("Model path: " + modelPath); -Console.WriteLine("Intereactive: " + intereactive); +Console.WriteLine("Interactive: " + interactive); using Model model = new Model(modelPath); using Tokenizer tokenizer = new Tokenizer(model); var option = 2; -if (intereactive) +if (interactive) { Console.WriteLine("Please enter option number:"); Console.WriteLine("1. Complete Output"); @@ -64,7 +64,7 @@ void PrintUsage() do { string prompt = "def is_prime(num):"; // Example prompt - if (intereactive) + if (interactive) { Console.WriteLine("Prompt:"); prompt = Console.ReadLine(); @@ -72,7 +72,7 @@ void PrintUsage() if (string.IsNullOrEmpty(prompt)) { continue; - } + } var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>"); using GeneratorParams generatorParams = new GeneratorParams(model); @@ -80,17 +80,22 @@ void PrintUsage() generatorParams.SetInputSequences(sequences); if (option == 1) // Complete Output { + var watch = System.Diagnostics.Stopwatch.StartNew(); var outputSequences = model.Generate(generatorParams); var outputString = tokenizer.Decode(outputSequences[0]); - + watch.Stop(); + var runTimeInSeconds = watch.Elapsed.TotalSeconds; Console.WriteLine("Output:"); Console.WriteLine(outputString); + var totalTokens = outputSequences[0].Length; + Console.WriteLine($"Tokens: {totalTokens} Time: {runTimeInSeconds:0.00} Tokens per second: {totalTokens / runTimeInSeconds:0.00}"); } else if (option == 2) //Streaming Output { using var tokenizerStream = tokenizer.CreateStream(); using var generator = new Generator(model, generatorParams); + var watch = System.Diagnostics.Stopwatch.StartNew(); while (!generator.IsDone()) { generator.ComputeLogits(); @@ -98,5 +103,10 @@ void PrintUsage() Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1])); } Console.WriteLine(); + watch.Stop(); + var runTimeInSeconds = watch.Elapsed.TotalSeconds; + var outputSequence = generator.GetSequence(0); + var totalTokens = outputSequence.Length; + Console.WriteLine($"Streaming Tokens: {totalTokens} Time: {runTimeInSeconds:0.00} Tokens per second: {totalTokens / runTimeInSeconds:0.00}"); } -} while (intereactive); \ No newline at end of file +} while (interactive); diff --git a/examples/python/README.md b/examples/python/README.md index ed7f027f0..9410bd8bc 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -41,6 +41,6 @@ The `model-qa` script streams the output text token by token. 
To run the python examples...
```bash
-python model-generate.py -m {path to model folder} -ep {cpu or cuda} -i {string prompt}
-python model-qa.py -m {path to model folder} -ep {cpu or cuda}
+python model-generate.py -m {path to model folder} -pr {input prompt}
+python model-qa.py -m {path to model folder}
```
diff --git a/nuget/MANAGED_PACKAGE.md b/nuget/MANAGED_PACKAGE.md
new file mode 100644
index 000000000..8d3dc0fb1
--- /dev/null
+++ b/nuget/MANAGED_PACKAGE.md
@@ -0,0 +1,3 @@
+## About
+
+This package is a dependency of [Microsoft.ML.OnnxRuntimeGenAI](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntimeGenAI) and does not need to be installed directly.
diff --git a/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec b/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec
index b204ab49a..f4d06bc0a 100644
--- a/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec
+++ b/nuget/Microsoft.ML.OnnxRuntimeGenAI.Managed.nuspec
@@ -19,7 +19,7 @@
 
-
+
 
diff --git a/nuget/PACKAGE.md b/nuget/PACKAGE.md
new file mode 100644
index 000000000..7f7c324db
--- /dev/null
+++ b/nuget/PACKAGE.md
@@ -0,0 +1,131 @@
+## About
+
+Run Llama, Phi (Language + Vision!), Gemma, Mistral with ONNX Runtime.
+
+This API gives you an easy, flexible and performant way of running LLMs on device using .NET/C#.
+
+It implements the generative AI loop for ONNX models, including pre and post processing, inference with ONNX Runtime, logits processing, search and sampling, and KV cache management.
+
+You can call a high level `generate()` method to generate all of the output at once, or stream the output one token at a time.
+
+## Key Features
+
+* Language and vision pre and post processing
+* Inference using ONNX Runtime
+* Generation tuning with greedy, beam search and random sampling
+* KV cache management to optimize performance
+* Multi target execution (CPU, GPU, with NPU coming!)
+
+## Sample
+
+```csharp
+// See https://aka.ms/new-console-template for more information
+using Microsoft.ML.OnnxRuntimeGenAI;
+
+OgaHandle ogaHandle = new OgaHandle();
+
+// Specify the location of your downloaded model.
+// Many models are published on HuggingFace e.g.
+// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx
+string modelPath = "...";
+Console.WriteLine("Model path: " + modelPath);
+
+using Model model = new Model(modelPath);
+using Tokenizer tokenizer = new Tokenizer(model);
+
+// Set your prompt here
+string prompt = "public static bool IsPrime(int number)";
+var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>");
+
+using GeneratorParams generatorParams = new GeneratorParams(model);
+generatorParams.SetSearchOption("max_length", 512);
+generatorParams.SetInputSequences(sequences);
+
+using var tokenizerStream = tokenizer.CreateStream();
+using var generator = new Generator(model, generatorParams);
+while (!generator.IsDone())
+{
+    generator.ComputeLogits();
+    generator.GenerateNextToken();
+    Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1]));
+}
+```
+
+Generates the following output:
+
+```
+Here's a complete implementation of the `IsPrime` function in C# that checks if a given number is prime. The function includes basic input validation and comments for clarity.
+```
+
+```csharp
+using System;
+
+namespace PrimeChecker
+{
+    public class PrimeChecker
+    {
+        /// <summary>
+        /// Checks if the given number is prime.
+        /// </summary>
+        /// <param name="number">The number to check.</param>
+        /// <returns>true if the number is prime; otherwise, false.</returns>
+        public static bool IsPrime(int number)
+        {
+            // Input validation
+            if (number < 2)
+            {
+                return false;
+            }
+
+            // 2 is the only even prime number
+            if (number == 2)
+            {
+                return true;
+            }
+
+            // Exclude even numbers greater than 2
+            if (number % 2 == 0)
+            {
+                return false;
+            }
+
+            // Check for factors up to the square root of the number
+            int limit = (int)Math.Floor(Math.Sqrt(number));
+            for (int i = 3; i <= limit; i += 2)
+            {
+                if (number % i == 0)
+                {
+                    return false;
+                }
+            }
+
+            return true;
+        }
+
+        static void Main(string[] args)
+        {
+            int number = 29;
+            bool isPrime = PrimeChecker.IsPrime(number);
+
+            Console.WriteLine($"Is {number} prime? {isPrime}");
+        }
+    }
+}
+```
+
+```
+This implementation checks if a number is prime by iterating only up to the square root of the number, which is an optimization over checking all numbers up to the number itself. It also excludes even numbers greater than 2, as they cannot be prime.
+```
+
+## Source code repository
+
+ONNX Runtime is an open source project. See:
+* [https://github.com/microsoft/onnxruntime](https://github.com/microsoft/onnxruntime)
+* [https://github.com/microsoft/onnxruntime-genai](https://github.com/microsoft/onnxruntime-genai)
+
+## Documentation
+
+See [https://onnxruntime.ai/docs/genai](https://onnxruntime.ai/docs/genai)
+
diff --git a/src/config.cpp b/src/config.cpp
index d6bb344c9..00708faae 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -64,6 +64,10 @@ struct SessionOptions_Element : JSON::Element {
       v_.log_id = value;
     else if (name == "enable_profiling")
       v_.enable_profiling = value;
+    else if (name == "ep_context_embed_mode")
+      v_.ep_context_embed_mode = value;
+    else if (name == "ep_context_file_path")
+      v_.ep_context_file_path = value;
     else
       throw JSON::unknown_value_error{};
   }
@@ -84,6 +88,14 @@ struct SessionOptions_Element : JSON::Element {
       v_.enable_cpu_mem_arena = value;
     else if (name == "enable_mem_pattern")
       v_.enable_mem_pattern = value;
+    else if (name == "disable_cpu_ep_fallback")
+      v_.disable_cpu_ep_fallback = value;
+    else if (name == "disable_quant_qdq")
+      v_.disable_quant_qdq = value;
+    else if (name == "enable_quant_qdq_cleanup")
+      v_.enable_quant_qdq_cleanup = value;
+    else if (name == "ep_context_enable")
+      v_.ep_context_enable = value;
     else
       throw JSON::unknown_value_error{};
   }
diff --git a/src/config.h b/src/config.h
index 4979c555e..7263dbda3 100644
--- a/src/config.h
+++ b/src/config.h
@@ -27,6 +27,12 @@ struct Config {
     std::optional<int> inter_op_num_threads;
     std::optional<bool> enable_cpu_mem_arena;
     std::optional<bool> enable_mem_pattern;
+    std::optional<bool> disable_cpu_ep_fallback;
+    std::optional<bool> disable_quant_qdq;
+    std::optional<bool> enable_quant_qdq_cleanup;
+    std::optional<bool> ep_context_enable;
+    std::optional<std::string> ep_context_embed_mode;
+    std::optional<std::string> ep_context_file_path;
     std::optional<std::string> log_id;
     std::optional<int> log_severity_level;
     std::optional<std::string> enable_profiling;
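For context, these options are consumed from the `session_options` section of a model's `genai_config.json`. A hedged sketch follows; only the `session_options` keys come from the parser above, and the surrounding nesting and values are assumptions for illustration:

```json
{
  "model": {
    "decoder": {
      "session_options": {
        "disable_cpu_ep_fallback": false,
        "disable_quant_qdq": false,
        "enable_quant_qdq_cleanup": true,
        "ep_context_enable": true,
        "ep_context_embed_mode": "0",
        "ep_context_file_path": "model_ctx.onnx"
      }
    }
  }
}
```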
protected virtual void Dispose(bool disposing)
         {
             return;
         }
-        NativeMethods.OgaDestroyModel(_modelHandle);
-        _modelHandle = IntPtr.Zero;
+        if (_modelHandle != IntPtr.Zero)
+        {
+            NativeMethods.OgaDestroyModel(_modelHandle);
+            _modelHandle = IntPtr.Zero;
+        }
         _disposed = true;
     }
 }
diff --git a/src/csharp/Utils.cs b/src/csharp/Utils.cs
index b84f1d407..90d007bc7 100644
--- a/src/csharp/Utils.cs
+++ b/src/csharp/Utils.cs
@@ -7,11 +7,33 @@
 namespace Microsoft.ML.OnnxRuntimeGenAI
 {
-    public class OgaHandle
+    public class OgaHandle : IDisposable
     {
+        private bool _disposed = false;
+
+        public OgaHandle()
+        {
+        }
+
         ~OgaHandle()
         {
+            Dispose(false);
+        }
+
+        public void Dispose()
+        {
+            Dispose(true);
+            GC.SuppressFinalize(this);
+        }
+
+        protected virtual void Dispose(bool disposing)
+        {
+            if (_disposed)
+            {
+                return;
+            }
             NativeMethods.OgaShutdown();
+            _disposed = true;
         }
     }
diff --git a/src/dml/dml_helpers.cpp b/src/dml/dml_helpers.cpp
index 6016cfceb..876ed7036 100644
--- a/src/dml/dml_helpers.cpp
+++ b/src/dml/dml_helpers.cpp
@@ -77,6 +77,9 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
 static ComPtr<IDXGIAdapter1> CreatePerformantAdapter() {
   auto filtered_adapters = EnumerateAdapters();
+  if (filtered_adapters.empty()) {
+    throw std::runtime_error("No adapter is available for DML.");
+  }
   return filtered_adapters.front();
 }
 
@@ -228,33 +231,61 @@ void ExecuteReusableCommandList(
   execution_context->ExecuteCommandList(command_list_state.graphics_command_list.Get(), fence.GetAddressOf(), &completion_value);
 }
 
-static uint64_t DataTypeSizeInBytes(DML_TENSOR_DATA_TYPE dml_data_type) {
-  switch (dml_data_type) {
-    case DML_TENSOR_DATA_TYPE_FLOAT16:
-      return sizeof(Ort::Float16_t);
+// Copied from https://learn.microsoft.com/en-us/windows/ai/directml/dml-helper-functions#dmlcalcbuffertensorsize
+static UINT64 DMLCalcBufferTensorSize(
+    DML_TENSOR_DATA_TYPE dataType,
+    UINT dimensionCount,
+    _In_reads_(dimensionCount) const UINT* sizes,
+    _In_reads_opt_(dimensionCount) const UINT* strides) {
+  UINT elementSizeInBytes = 0;
+  switch (dataType) {
     case DML_TENSOR_DATA_TYPE_FLOAT32:
-      return sizeof(float);
-    case DML_TENSOR_DATA_TYPE_FLOAT64:
-      return sizeof(double);
-    case DML_TENSOR_DATA_TYPE_UINT8:
-      return sizeof(uint8_t);
-    case DML_TENSOR_DATA_TYPE_UINT16:
-      return sizeof(uint16_t);
     case DML_TENSOR_DATA_TYPE_UINT32:
-      return sizeof(uint32_t);
-    case DML_TENSOR_DATA_TYPE_UINT64:
-      return sizeof(uint64_t);
-    case DML_TENSOR_DATA_TYPE_INT8:
-      return sizeof(int8_t);
-    case DML_TENSOR_DATA_TYPE_INT16:
-      return sizeof(int16_t);
     case DML_TENSOR_DATA_TYPE_INT32:
-      return sizeof(int32_t);
+      elementSizeInBytes = 4;
+      break;
+
+    case DML_TENSOR_DATA_TYPE_FLOAT16:
+    case DML_TENSOR_DATA_TYPE_UINT16:
+    case DML_TENSOR_DATA_TYPE_INT16:
+      elementSizeInBytes = 2;
+      break;
+
+    case DML_TENSOR_DATA_TYPE_UINT8:
+    case DML_TENSOR_DATA_TYPE_INT8:
+      elementSizeInBytes = 1;
+      break;
+
+    case DML_TENSOR_DATA_TYPE_FLOAT64:
+    case DML_TENSOR_DATA_TYPE_UINT64:
     case DML_TENSOR_DATA_TYPE_INT64:
-      return sizeof(int64_t);
+      elementSizeInBytes = 8;
+      break;
+
     default:
-      THROW_HR(E_NOTIMPL);
+      return 0;  // Invalid data type
+  }
+
+  UINT64 minimumImpliedSizeInBytes = 0;
+  if (!strides) {
+    minimumImpliedSizeInBytes = sizes[0];
+    for (UINT i = 1; i < dimensionCount; ++i) {
+      minimumImpliedSizeInBytes *= sizes[i];
+    }
+    minimumImpliedSizeInBytes *= elementSizeInBytes;
+  } else {
+    UINT indexOfLastElement = 0;
+    for (UINT i = 0; i < dimensionCount; ++i) {
+      indexOfLastElement += (sizes[i] - 1) * strides[i];
+    }
+
+    minimumImpliedSizeInBytes =
(static_cast<UINT64>(indexOfLastElement) + 1) * elementSizeInBytes;
+  }
+
+  // Round up to the nearest 4 bytes.
+  minimumImpliedSizeInBytes = (minimumImpliedSizeInBytes + 3) & ~3ull;
+
+  return minimumImpliedSizeInBytes;
 }
 
 ComPtr<IDMLCompiledOperator> CreateCastOperator(
@@ -267,7 +298,7 @@ ComPtr<IDMLCompiledOperator> CreateCastOperator(
   input_buffer_desc.Sizes = &num_elements;
   input_buffer_desc.DimensionCount = 1;
   input_buffer_desc.DataType = source_data_type;
-  input_buffer_desc.TotalTensorSizeInBytes = num_elements * DataTypeSizeInBytes(source_data_type);
+  input_buffer_desc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(source_data_type, 1, &num_elements, NULL);
   DML_TENSOR_DESC input_tensor_desc = {DML_TENSOR_TYPE_BUFFER, &input_buffer_desc};
 
   // Create the output tensor desc
@@ -275,7 +306,7 @@ ComPtr<IDMLCompiledOperator> CreateCastOperator(
   output_buffer_desc.Sizes = &num_elements;
   output_buffer_desc.DimensionCount = 1;
   output_buffer_desc.DataType = target_data_type;
-  output_buffer_desc.TotalTensorSizeInBytes = num_elements * DataTypeSizeInBytes(target_data_type);
+  output_buffer_desc.TotalTensorSizeInBytes = DMLCalcBufferTensorSize(target_data_type, 1, &num_elements, NULL);
   DML_TENSOR_DESC output_tensor_desc = {DML_TENSOR_TYPE_BUFFER, &output_buffer_desc};
 
   DML_CAST_OPERATOR_DESC cast_op_desc{};
@@ -377,7 +408,7 @@ void DmlCastInputToOutput(
   auto dml_from_type = DmlHelpers::OrtToDmlDataType(in.GetTensorTypeAndShapeInfo()->GetElementType());
   auto dml_to_type = DmlHelpers::OrtToDmlDataType(p_out->GetTensorTypeAndShapeInfo()->GetElementType());
 
-  bool rebind = command_list_state.previousOutput != p_out.get();
+  bool rebind = command_list_state.previousInput != &in || command_list_state.previousOutput != p_out.get();
 
   // If the sizes change, we need to recompile the operator and rebuild the command lists. It should only happen
   // once after the very first iteration.
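As an aside on the sizing change above: unlike the removed `DataTypeSizeInBytes` (element count times element size), `DMLCalcBufferTensorSize` also handles strided layouts and rounds the result up to DirectML's 4-byte buffer granularity. A minimal standalone sketch of the same arithmetic, our own illustration rather than code from this patch:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // A 2x3 FLOAT16 tensor (element size 2 bytes), packed layout:
  // 2 * 3 elements * 2 bytes = 12 bytes, already a multiple of 4.
  uint64_t packed = uint64_t{2} * 3 * 2;

  // Same tensor with padded strides {4, 1}: the last element lives at
  // flat index (2-1)*4 + (3-1)*1 = 6, so (6 + 1) * 2 = 14 bytes...
  uint64_t strided = (uint64_t{(2 - 1) * 4 + (3 - 1) * 1} + 1) * 2;
  // ...which rounds up to the nearest 4 bytes, as in the helper above.
  strided = (strided + 3) & ~3ull;

  std::cout << packed << " " << strided << "\n";  // prints: 12 16
}
```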
@@ -403,6 +434,7 @@ void DmlCastInputToOutput(
     DML_BINDING_DESC input_array_binding_desc = DML_BINDING_DESC{DML_BINDING_TYPE_NONE, nullptr};
     execution_context->InitializeOperator(compiled_cast_operator.Get(), persistent_resource_bindingDesc, input_array_binding_desc);
     command_list_state = DmlHelpers::BuildReusableCommandList(dml_device, compiled_cast_operator.Get(), persistent_resource.Get(), persistent_resource_binding);
+    command_list_state.previousInput = &in;
     command_list_state.previousOutput = p_out.get();
   }
 
@@ -413,10 +445,10 @@ void DmlCastInputToOutput(
   Ort::ThrowOnError(ort_dml_api->GetD3D12ResourceFromAllocation(&allocator, p_out->GetTensorMutableData(), &target_resource));
 
   std::array<ID3D12Resource*, 1> input_resources = {source_resource.Get()};
-  std::array<uint64_t, 1> input_sizes = {element_count * DataTypeSizeInBytes(dml_from_type)};
+  std::array<uint64_t, 1> input_sizes = {DMLCalcBufferTensorSize(dml_from_type, 1, (uint32_t*)&element_count, NULL)};
 
   std::array<ID3D12Resource*, 1> output_resources = {target_resource.Get()};
-  std::array<uint64_t, 1> output_sizes = {element_count * DataTypeSizeInBytes(dml_to_type)};
+  std::array<uint64_t, 1> output_sizes = {DMLCalcBufferTensorSize(dml_to_type, 1, (uint32_t*)&element_count, NULL)};
 
   // Make sure the source and target allocations are kept alive until the operation is done
   command_list_state.source_resource = std::move(source_resource);
diff --git a/src/generators.cpp b/src/generators.cpp
index 84ce0b5eb..8ada93793 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -16,13 +16,37 @@ static bool _ = (Ort::InitApi(), false);
 
 OrtGlobals::OrtGlobals() : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {}
 
-std::unique_ptr<OrtGlobals>& GetOrtGlobals() {
+// Ensure Shutdown() has been called before process exit
+struct ValidateShutdown {
+  ~ValidateShutdown() {
+    if (GetOrtGlobals()) {
+      std::cerr << "OGA Error: Shutdown must be called before process exit, please check the documentation for the proper API to call to ensure clean shutdown." << std::endl;
+      std::abort();
+    }
+  }
+};
+
+std::unique_ptr<OrtGlobals>&
+GetOrtGlobals() {
   static auto globals = std::make_unique<OrtGlobals>();
+  static auto validate = std::make_unique<ValidateShutdown>();  // Must be after the above line so the destructor runs before the above destructor
   return globals;
 }
 
+// Used by Shutdown() to display the counts and types of any leaked objects
+template <typename... Types>
+bool LeakTypeList<Types...>::Dump() {
+  ((LeakChecked<Types>::Count() != 0 ? std::cerr << "OGA Error: " << LeakChecked<Types>::Count() << " instances of " << typeid(Types).name() << " were leaked." << std::endl : std::cerr), ...);
+  return ((LeakChecked<Types>::Count() != 0) || ...);
+}
+
 void Shutdown() {
-  GetOrtGlobals().reset();
+  if (LeakTypes::Dump()) {
+    std::cerr << " Please see the documentation for the API being used to ensure proper cleanup." << std::endl;
+    std::abort();
+  }
+
+  GetOrtGlobals().reset();  // Delete now because on process exit is too late
 }
 
 OrtEnv& GetOrtEnv() {
diff --git a/src/generators.h b/src/generators.h
index ffe10744a..26d30e3fc 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -1,11 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
-
+// Licensed under the MIT License.
 #pragma once
-// Licensed under the MIT License.
 #include
 #include
 #include
+#include
 #include
 #include
 #include "filesystem.h"
@@ -31,6 +31,7 @@
 using cudaStream_t = void*;
 #endif
 
+#include "leakcheck.h"
 #include "smartptrs.h"
 #include "models/onnxruntime_api.h"
 #include "models/debugging.h"
@@ -55,7 +56,7 @@ enum struct DeviceType {
 
 std::string to_string(DeviceType device_type);
 
-struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
+struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChecked<GeneratorParams> {
   GeneratorParams() = default;  // This constructor is only used if doing a custom model handler vs built-in
   GeneratorParams(const Model& model);
@@ -125,7 +126,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
   // The model outlives the GeneratorParams
 };
 
-struct Generator {
+struct Generator : LeakChecked<Generator> {
   Generator(const Model& model, const GeneratorParams& params);
 
   bool IsDone() const;
diff --git a/src/java/UpdatingJavaBindings.md b/src/java/UpdatingJavaBindings.md
index a31620308..9ee369510 100644
--- a/src/java/UpdatingJavaBindings.md
+++ b/src/java/UpdatingJavaBindings.md
@@ -75,6 +75,10 @@ A header file for each class with the JNI function signatures will be generated
 
 To create/update the .cpp file that implements the JNI function (which will be located in src/main/native), cut-and-paste the relevant parts of the header into the .cpp file.
 
+If creating the .cpp file, add a `#include` for the corresponding generated header file.
+While not strictly necessary, doing so means we do not have to explicitly specify the language linkage
+(`extern "C"`) for the JNI functions, as the linkage is inherited from the earlier declarations in the header.
+
 Update the first 2 parameters of each new function to be meaningful by adding the parameter names
 Generated: `JNIEnv*, jobject`
 Meaningful: `JNIEnv* env, jobject thiz`
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
index 258d8974a..cb1f77b41 100644
--- a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
+++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
@@ -106,6 +106,24 @@ public void setInput(String name, Tensor tensor) throws GenAIException {
     setModelInput(nativeHandle, name, tensor.nativeHandle());
   }
 
+  /**
+   * Add a NamedTensors as a model input.
+   *
+   * @param namedTensors NamedTensors to add.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public void setInputs(NamedTensors namedTensors) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    if (namedTensors.nativeHandle() == 0) {
+      throw new IllegalArgumentException("namedTensors has been freed and is invalid");
+    }
+
+    setInputs(nativeHandle, namedTensors.nativeHandle());
+  }
+
   @Override
   public void close() {
     if (nativeHandle != 0) {
@@ -141,6 +159,9 @@ private native void setInputSequences(long nativeHandle, long sequencesHandle)
 
   private native void setModelInput(long nativeHandle, String inputName, long tensorHandle)
       throws GenAIException;
+
+  private native void setInputs(long nativeHandle, long namedTensorsHandle)
+      throws GenAIException;
 
   private native void setInputIDs(
       long nativeHandle, ByteBuffer tokenIds, int sequenceLength, int batchSize)
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Images.java b/src/java/src/main/java/ai/onnxruntime/genai/Images.java
new file mode 100644
index 000000000..51eaa61ac
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/Images.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+/** The Images class loads images from a file path so they can be processed by a MultiModalProcessor. */
+public class Images implements AutoCloseable {
+  private long nativeHandle;
+
+  public Images(String imagePath) throws GenAIException {
+    nativeHandle = loadImages(imagePath);
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyImages(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  long nativeHandle() {
+    return nativeHandle;
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long loadImages(String imagePath) throws GenAIException;
+
+  private native void destroyImages(long imagesHandle);
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java b/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java
new file mode 100644
index 000000000..85670b71d
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/MultiModalProcessor.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+/** The MultiModalProcessor class is responsible for converting text/images into a NamedTensors list that can be fed into a Generator class instance. */
+public class MultiModalProcessor implements AutoCloseable {
+  private long nativeHandle;
+
+  public MultiModalProcessor(Model model) throws GenAIException {
+    assert (model.nativeHandle() != 0);  // internal code should never pass an invalid model
+
+    nativeHandle = createMultiModalProcessor(model.nativeHandle());
+  }
+
+  /**
+   * Processes a string and images into a NamedTensors instance.
+   *
+   * @param prompt Text to encode as token ids.
+   * @param images Image input.
+   * @return NamedTensors object.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public NamedTensors processImages(String prompt, Images images) throws GenAIException {
+    long imagesHandle = (images == null) ? 0 : images.nativeHandle();
+    long namedTensorsHandle = processorProcessImages(nativeHandle, prompt, imagesHandle);
+
+    return new NamedTensors(namedTensorsHandle);
+  }
+
+  /**
+   * Decodes a sequence of token ids into text.
+   *
+   * @param sequence Collection of token ids to decode to text.
+   * @return The text representation of the sequence.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public String decode(int[] sequence) throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+
+    return processorDecode(nativeHandle, sequence);
+  }
+
+  /**
+   * Creates a TokenizerStream object for streaming tokenization. This is used with the Generator
+   * class to provide each token as it is generated.
+   *
+   * @return The new TokenizerStream instance.
+   * @throws GenAIException If the call to the GenAI native API fails.
+   */
+  public TokenizerStream createStream() throws GenAIException {
+    if (nativeHandle == 0) {
+      throw new IllegalStateException("Instance has been freed and is invalid");
+    }
+    return new TokenizerStream(createTokenizerStreamFromProcessor(nativeHandle));
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyMultiModalProcessor(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native long createMultiModalProcessor(long modelHandle) throws GenAIException;
+
+  private native void destroyMultiModalProcessor(long processorHandle);
+
+  private native long processorProcessImages(long processorHandle, String prompt, long imagesHandle) throws GenAIException;
+
+  private native String processorDecode(long processorHandle, int[] sequence) throws GenAIException;
+
+  private native long createTokenizerStreamFromProcessor(long processorHandle) throws GenAIException;
+}
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java b/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java
new file mode 100644
index 000000000..5811e554f
--- /dev/null
+++ b/src/java/src/main/java/ai/onnxruntime/genai/NamedTensors.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License.
+ */
+package ai.onnxruntime.genai;
+
+/** The NamedTensors class wraps the native collection of named tensors produced by a MultiModalProcessor. */
+public class NamedTensors implements AutoCloseable {
+  private long nativeHandle;
+
+  public NamedTensors(long handle) {
+    nativeHandle = handle;
+  }
+
+  @Override
+  public void close() {
+    if (nativeHandle != 0) {
+      destroyNamedTensors(nativeHandle);
+      nativeHandle = 0;
+    }
+  }
+
+  long nativeHandle() {
+    return nativeHandle;
+  }
+
+  static {
+    try {
+      GenAI.init();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
+    }
+  }
+
+  private native void destroyNamedTensors(long handle);
+}
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
index c6c9c122f..cfe684e19 100644
--- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
+++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
@@ -2,15 +2,14 @@
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License.
*/ -#include +#include "ai_onnxruntime_genai_Generator.h" + #include "ort_genai_c.h" #include "utils.h" -#include - using namespace Helpers; -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Generator_createGenerator(JNIEnv* env, jobject thiz, jlong model_handle, jlong generator_params_handle) { const OgaModel* model = reinterpret_cast(model_handle); @@ -23,27 +22,27 @@ Java_ai_onnxruntime_genai_Generator_createGenerator(JNIEnv* env, jobject thiz, j return reinterpret_cast(generator); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Generator_destroyGenerator(JNIEnv* env, jobject thiz, jlong native_handle) { OgaDestroyGenerator(reinterpret_cast(native_handle)); } -extern "C" JNIEXPORT jboolean JNICALL +JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong native_handle) { return OgaGenerator_IsDone(reinterpret_cast(native_handle)); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Generator_computeLogitsNative(JNIEnv* env, jobject thiz, jlong native_handle) { ThrowIfError(env, OgaGenerator_ComputeLogits(reinterpret_cast(native_handle))); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Generator_generateNextTokenNative(JNIEnv* env, jobject thiz, jlong native_handle) { ThrowIfError(env, OgaGenerator_GenerateNextToken(reinterpret_cast(native_handle))); } -extern "C" JNIEXPORT jintArray JNICALL +JNIEXPORT jintArray JNICALL Java_ai_onnxruntime_genai_Generator_getSequenceNative(JNIEnv* env, jobject thiz, jlong generator, jlong index) { const OgaGenerator* oga_generator = reinterpret_cast(generator); @@ -65,7 +64,7 @@ Java_ai_onnxruntime_genai_Generator_getSequenceNative(JNIEnv* env, jobject thiz, return java_int_array; } -extern "C" JNIEXPORT jint JNICALL +JNIEXPORT jint JNICALL Java_ai_onnxruntime_genai_Generator_getSequenceLastToken(JNIEnv* env, jobject thiz, jlong generator, jlong index) { const OgaGenerator* oga_generator = reinterpret_cast(generator); @@ -78,4 +77,4 @@ Java_ai_onnxruntime_genai_Generator_getSequenceLastToken(JNIEnv* env, jobject th } return jint(tokens[num_tokens - 1]); -} \ No newline at end of file +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp index 2d2bd3049..76caafe86 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp @@ -2,13 +2,14 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. 
*/ -#include +#include "ai_onnxruntime_genai_GeneratorParams.h" + #include "ort_genai_c.h" #include "utils.h" using namespace Helpers; -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_GeneratorParams_createGeneratorParams(JNIEnv* env, jobject thiz, jlong model_handle) { const OgaModel* model = reinterpret_cast(model_handle); OgaGeneratorParams* generator_params = nullptr; @@ -19,13 +20,13 @@ Java_ai_onnxruntime_genai_GeneratorParams_createGeneratorParams(JNIEnv* env, job return reinterpret_cast(generator_params); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_GeneratorParams_destroyGeneratorParams(JNIEnv* env, jobject thiz, jlong native_handle) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); OgaDestroyGeneratorParams(generator_params); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionNumber(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name, jdouble value) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); @@ -34,7 +35,7 @@ Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionNumber(JNIEnv* env, job ThrowIfError(env, OgaGeneratorParamsSetSearchNumber(generator_params, name, value)); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name, jboolean value) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); @@ -43,7 +44,7 @@ Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobje ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, value)); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_GeneratorParams_setInputSequences(JNIEnv* env, jobject thiz, jlong native_handle, jlong sequences_handle) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); @@ -52,7 +53,7 @@ Java_ai_onnxruntime_genai_GeneratorParams_setInputSequences(JNIEnv* env, jobject ThrowIfError(env, OgaGeneratorParamsSetInputSequences(generator_params, sequences)); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_GeneratorParams_setInputIDs(JNIEnv* env, jobject thiz, jlong native_handle, jobject token_ids, jint sequence_length, jint batch_size) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); @@ -63,7 +64,7 @@ Java_ai_onnxruntime_genai_GeneratorParams_setInputIDs(JNIEnv* env, jobject thiz, ThrowIfError(env, OgaGeneratorParamsSetInputIDs(generator_params, tokens, num_tokens, sequence_length, batch_size)); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_GeneratorParams_setModelInput(JNIEnv* env, jobject thiz, jlong native_handle, jstring input_name, jlong tensor) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); @@ -72,3 +73,12 @@ Java_ai_onnxruntime_genai_GeneratorParams_setModelInput(JNIEnv* env, jobject thi ThrowIfError(env, OgaGeneratorParamsSetModelInput(generator_params, name, input_tensor)); } + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_setInputs(JNIEnv* env, jobject thiz, jlong native_handle, + jlong namedTensors) { + OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); + OgaNamedTensors* input_tensor = reinterpret_cast(namedTensors); + + ThrowIfError(env, 
OgaGeneratorParamsSetInputs(generator_params, input_tensor)); +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Images.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Images.cpp new file mode 100644 index 000000000..49a1848a9 --- /dev/null +++ b/src/java/src/main/native/ai_onnxruntime_genai_Images.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +#include "ai_onnxruntime_genai_Images.h" + +#include "ort_genai_c.h" +#include "utils.h" + +using namespace Helpers; + +/* + * Class: ai_onnxruntime_genai_Images + * Method: loadImages + * Signature: (J)V + */ +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_Images_loadImages(JNIEnv* env, jobject thiz, jstring image_path) { + CString path(env, image_path); + + OgaImages* images = nullptr; + if (ThrowIfError(env, OgaLoadImage(path, &images))) { + return 0; + } + + return reinterpret_cast(images); +} + +/* + * Class: ai_onnxruntime_genai_Images + * Method: destroyImages + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_Images_destroyImages(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaDestroyImages(reinterpret_cast(native_handle)); +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp index 61e4a0828..5025b4f3a 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Model.cpp @@ -2,13 +2,14 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ -#include +#include "ai_onnxruntime_genai_Model.h" + #include "ort_genai_c.h" #include "utils.h" using namespace Helpers; -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Model_createModel(JNIEnv* env, jobject thiz, jstring model_path) { CString path{env, model_path}; @@ -20,13 +21,13 @@ Java_ai_onnxruntime_genai_Model_createModel(JNIEnv* env, jobject thiz, jstring m return reinterpret_cast(model); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Model_destroyModel(JNIEnv* env, jobject thiz, jlong model_handle) { OgaModel* model = reinterpret_cast(model_handle); OgaDestroyModel(model); } -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Model_generate(JNIEnv* env, jobject thiz, jlong model_handle, jlong generator_params_handle) { const OgaModel* model = reinterpret_cast(model_handle); @@ -37,4 +38,4 @@ Java_ai_onnxruntime_genai_Model_generate(JNIEnv* env, jobject thiz, jlong model_ } return reinterpret_cast(sequences); -} \ No newline at end of file +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_MultiModalProcessor.cpp b/src/java/src/main/native/ai_onnxruntime_genai_MultiModalProcessor.cpp new file mode 100644 index 000000000..9ce4ef225 --- /dev/null +++ b/src/java/src/main/native/ai_onnxruntime_genai_MultiModalProcessor.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ +#include "ai_onnxruntime_genai_MultiModalProcessor.h" + +#include "ort_genai_c.h" +#include "utils.h" + +using namespace Helpers; + +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_MultiModalProcessor_createMultiModalProcessor(JNIEnv* env, jobject thiz, jlong model_handle) { + const OgaModel* model = reinterpret_cast(model_handle); + OgaMultiModalProcessor* processor = nullptr; + + if (ThrowIfError(env, OgaCreateMultiModalProcessor(model, &processor))) { + return 0; + } + + return reinterpret_cast(processor); +} + +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_MultiModalProcessor_destroyMultiModalProcessor(JNIEnv* env, jobject thiz, jlong processor_handle) { + OgaMultiModalProcessor* processor = reinterpret_cast(processor_handle); + OgaDestroyMultiModalProcessor(processor); +} + +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_MultiModalProcessor_processorProcessImages(JNIEnv* env, jobject thiz, jlong processor_handle, + jstring prompt, jlong images_handle) { + const OgaMultiModalProcessor* processor = reinterpret_cast(processor_handle); + + const char* prompt_str = env->GetStringUTFChars(prompt, nullptr); + OgaImages* images = reinterpret_cast(images_handle); + + OgaNamedTensors* named_tensors = nullptr; + if (ThrowIfError(env, OgaProcessorProcessImages(processor, prompt_str, images, &named_tensors))) { + return 0; + } + + return reinterpret_cast(named_tensors); +} + +JNIEXPORT jstring JNICALL +Java_ai_onnxruntime_genai_MultiModalProcessor_processorDecode(JNIEnv* env, jobject thiz, jlong processor_handle, + jintArray sequence) { + const OgaMultiModalProcessor* processor = reinterpret_cast(processor_handle); + auto num_tokens = env->GetArrayLength(sequence); + jint* jtokens = env->GetIntArrayElements(sequence, nullptr); + const int32_t* tokens = reinterpret_cast(jtokens); // convert between 32-bit types + const char* decoded_text = nullptr; + + bool error = ThrowIfError(env, OgaProcessorDecode(processor, tokens, num_tokens, &decoded_text)); + env->ReleaseIntArrayElements(sequence, jtokens, JNI_ABORT); + + if (error) { + return nullptr; + } + + jstring result = env->NewStringUTF(decoded_text); + OgaDestroyString(decoded_text); + + return result; +} + +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_MultiModalProcessor_createTokenizerStreamFromProcessor(JNIEnv* env, jobject thiz, jlong processor_handle) { + const OgaMultiModalProcessor* processor = reinterpret_cast(processor_handle); + OgaTokenizerStream* tokenizer_stream = nullptr; + + if (ThrowIfError(env, OgaCreateTokenizerStreamFromProcessor(processor, &tokenizer_stream))) { + return 0; + } + + return reinterpret_cast(tokenizer_stream); +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_NamedTensors.cpp b/src/java/src/main/native/ai_onnxruntime_genai_NamedTensors.cpp new file mode 100644 index 000000000..a1978d7c7 --- /dev/null +++ b/src/java/src/main/native/ai_onnxruntime_genai_NamedTensors.cpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. 
+ */ +#include "ai_onnxruntime_genai_NamedTensors.h" + +#include "ort_genai_c.h" +#include "utils.h" + +using namespace Helpers; + +/* + * Class: ai_onnxruntime_genai_NamedTensors + * Method: destroyNamedTensors + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_NamedTensors_destroyNamedTensors(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaDestroyNamedTensors(reinterpret_cast(native_handle)); +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp index ac57f2337..0eff851a2 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Sequences.cpp @@ -2,26 +2,27 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ -#include +#include "ai_onnxruntime_genai_Sequences.h" + #include "ort_genai_c.h" #include "utils.h" using namespace Helpers; -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Sequences_destroySequences(JNIEnv* env, jobject thiz, jlong sequences_handle) { OgaSequences* sequences = reinterpret_cast(sequences_handle); OgaDestroySequences(sequences); } -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Sequences_getSequencesCount(JNIEnv* env, jobject thiz, jlong sequences_handle) { const OgaSequences* sequences = reinterpret_cast(sequences_handle); size_t num_sequences = OgaSequencesCount(sequences); return static_cast(num_sequences); } -extern "C" JNIEXPORT jintArray JNICALL +JNIEXPORT jintArray JNICALL Java_ai_onnxruntime_genai_Sequences_getSequenceNative(JNIEnv* env, jobject thiz, jlong sequences_handle, jlong sequence_index) { const OgaSequences* sequences = reinterpret_cast(sequences_handle); diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp index 1b3e5deb8..58d9aa9fe 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Tensor.cpp @@ -2,24 +2,21 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. 
*/ -#include +#include "ai_onnxruntime_genai_Tensor.h" + #include "ort_genai_c.h" #include "utils.h" using namespace Helpers; -#ifdef __cplusplus -extern "C" { -#endif - /* * Class: ai_onnxruntime_genai_Tensor * Method: createTensor * Signature: (Ljava/nio/ByteBuffer;[JI)J */ -JNIEXPORT -jlong JNICALL Java_ai_onnxruntime_genai_Tensor_createTensor(JNIEnv* env, jobject thiz, jobject tensor_data, - jlongArray shape_dims_in, jint element_type_in) { +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_Tensor_createTensor(JNIEnv* env, jobject thiz, jobject tensor_data, + jlongArray shape_dims_in, jint element_type_in) { void* data = env->GetDirectBufferAddress(tensor_data); const int64_t* shape_dims = env->GetLongArrayElements(shape_dims_in, /*isCopy*/ 0); size_t shape_dims_count = env->GetArrayLength(shape_dims_in); @@ -38,11 +35,7 @@ jlong JNICALL Java_ai_onnxruntime_genai_Tensor_createTensor(JNIEnv* env, jobject * Method: destroyTensor * Signature: (J)V */ -JNIEXPORT -void JNICALL Java_ai_onnxruntime_genai_Tensor_destroyTensor(JNIEnv* env, jobject thiz, jlong native_handle) { +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_Tensor_destroyTensor(JNIEnv* env, jobject thiz, jlong native_handle) { OgaDestroyTensor(reinterpret_cast(native_handle)); } - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp index 92b56e7f5..b3a0728c2 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Tokenizer.cpp @@ -2,13 +2,14 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ -#include +#include "ai_onnxruntime_genai_Tokenizer.h" + #include "ort_genai_c.h" #include "utils.h" using namespace Helpers; -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Tokenizer_createTokenizer(JNIEnv* env, jobject thiz, jlong model_handle) { const OgaModel* model = reinterpret_cast(model_handle); OgaTokenizer* tokenizer = nullptr; @@ -20,13 +21,13 @@ Java_ai_onnxruntime_genai_Tokenizer_createTokenizer(JNIEnv* env, jobject thiz, j return reinterpret_cast(tokenizer); } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Tokenizer_destroyTokenizer(JNIEnv* env, jobject thiz, jlong tokenizer_handle) { OgaTokenizer* tokenizer = reinterpret_cast(tokenizer_handle); OgaDestroyTokenizer(tokenizer); } -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Tokenizer_tokenizerEncode(JNIEnv* env, jobject thiz, jlong tokenizer_handle, jobjectArray strings) { const OgaTokenizer* tokenizer = reinterpret_cast(tokenizer_handle); @@ -49,7 +50,7 @@ Java_ai_onnxruntime_genai_Tokenizer_tokenizerEncode(JNIEnv* env, jobject thiz, j return reinterpret_cast(sequences); } -extern "C" JNIEXPORT jstring JNICALL +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_genai_Tokenizer_tokenizerDecode(JNIEnv* env, jobject thiz, jlong tokenizer_handle, jintArray sequence) { const OgaTokenizer* tokenizer = reinterpret_cast(tokenizer_handle); @@ -71,7 +72,7 @@ Java_ai_onnxruntime_genai_Tokenizer_tokenizerDecode(JNIEnv* env, jobject thiz, j return result; } -extern "C" JNIEXPORT jlong JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Tokenizer_createTokenizerStream(JNIEnv* env, jobject thiz, jlong tokenizer_handle) { const OgaTokenizer* tokenizer = reinterpret_cast(tokenizer_handle); OgaTokenizerStream* 
tokenizer_stream = nullptr; @@ -81,4 +82,4 @@ Java_ai_onnxruntime_genai_Tokenizer_createTokenizerStream(JNIEnv* env, jobject t } return reinterpret_cast(tokenizer_stream); -} \ No newline at end of file +} diff --git a/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp b/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp index 8e725c807..9d0c14158 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_TokenizerStream.cpp @@ -2,15 +2,14 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ -#include +#include "ai_onnxruntime_genai_TokenizerStream.h" + #include "ort_genai_c.h" #include "utils.h" -#include - using namespace Helpers; -extern "C" JNIEXPORT jstring JNICALL +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_genai_TokenizerStream_tokenizerStreamDecode(JNIEnv* env, jobject thiz, jlong tokenizer_stream_handle, jint token) { OgaTokenizerStream* tokenizer_stream = reinterpret_cast(tokenizer_stream_handle); @@ -27,9 +26,9 @@ Java_ai_onnxruntime_genai_TokenizerStream_tokenizerStreamDecode(JNIEnv* env, job return result; } -extern "C" JNIEXPORT void JNICALL +JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_TokenizerStream_destroyTokenizerStream(JNIEnv* env, jobject thiz, jlong tokenizer_stream_handle) { OgaTokenizerStream* tokenizer_stream = reinterpret_cast(tokenizer_stream_handle); OgaDestroyTokenizerStream(tokenizer_stream); -} \ No newline at end of file +} diff --git a/src/java/src/main/native/utils.h b/src/java/src/main/native/utils.h index b56fdbafc..f8952b7f0 100644 --- a/src/java/src/main/native/utils.h +++ b/src/java/src/main/native/utils.h @@ -30,7 +30,7 @@ bool ThrowIfError(JNIEnv* env, OgaResult* result); // handle conversion/release of jstring to const char* struct CString { CString(JNIEnv* env, jstring str) - : env_{env}, str_{str}, cstr{env->GetStringUTFChars(str, /* isCopy */ nullptr)} { + : cstr{env->GetStringUTFChars(str, /* isCopy */ nullptr)}, env_{env}, str_{str} { } const char* cstr; diff --git a/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java b/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java new file mode 100644 index 000000000..5ddce67f0 --- /dev/null +++ b/src/java/src/test/java/ai/onnxruntime/genai/MultiModalProcessorTest.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime.genai; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +// NOTE: Typical usage is covered in GenerationTest.java so we are just filling test gaps here. 
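+// Aside (a usage sketch of our own, not part of this test): Images, NamedTensors,
+// MultiModalProcessor and GeneratorParams above are all AutoCloseable, so the
+// idiomatic way to hold them is try-with-resources, e.g.:
+//
+//   try (Images images = new Images("landscape.jpg");
+//        NamedTensors inputs = multiModalProcessor.processImages(prompt, images)) {
+//     generatorParams.setInputs(inputs);
+//     // ... run the Generator loop ...
+//   }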
+public class MultiModalProcessorTest { + @Test + public void testBatchEncodeDecode() throws GenAIException { + try (Model model = new Model(TestUtils.testModelPath()); + MultiModalProcessor multiModalProcessor = new MultiModalProcessor(model)) { + TokenizerStream stream = multiModalProcessor.createStream(); + GeneratorParams generatorParams = model.createGeneratorParams(); + String inputs = new String("This is a test"); + Images image = new Images("/src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg"); + NamedTensors processed = multiModalProcessor.processImages(inputs, image); + generatorParams.setInputs(processed); + + Generator generator = new Generator(model, generatorParams); + + String fullAnswer = new String(); + while (!generator.isDone()) { + generator.computeLogits(); + generator.generateNextToken(); + + int token = generator.getLastTokenInSequence(0); + + fullAnswer += stream.decode(token); + } + } + } +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg b/src/java/src/test/java/ai/onnxruntime/genai/landscape.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce6948899961274994ef9c399b336faa2e0c1ce5 GIT binary patch literal 337412 zcmeFY2UJwcwkX<7&N)bKa?TmaNX{9WB%u)+n%Iqih$uk?NeZZdk~58@Zh|68lq9)9 zk{|*Cq9CAvAb+9z?BjXwynD~SW4!;zA7igESZj9GtXWlaR?Sq!>GbIm=(TRRuQv!} zXea_A0f9h7AXYFT2n;}Wz#rQmuqyxyf$@Kq;{mWV7!M>4e1V1_831Mgz8S!u?@fZA zIOPC*_80(vXSDwD)$$7TfI%#L0+E5CK7m0H6W>cd$dJH51VlkXT1rA1B4q0B=^JoX zD-2MUl95u8l2MV8hnxY5D$;Twp!YOl5CK3_N=8;1od0V~XMLPKr&}Pge^Qr~kOC=b z8=B&8{j8+lvEpg1`V73Ih}0fIvJN5C{=~46)MJhB!Dv zA;hnNj!6OLNP|G`o}ode)_OK)G!1oiAfW)#Px^npzT$pP8Z<6$U;%;rS^S>?w4OmB zNPue=0J^LfEYuT#;{e$IGBW6_d>eonJp9gJyfZd1ga88s;B#m2rC+evPn=({<{9jT z2=D@M&c+<%h44Cqy8!rB7zzf!_$C0H5atWJ48Rir%;%3n_yX_&05c$9?x7$M0mWH4 z66Waxz)}EA5n^Md1Hh^v5DCSl-(Zj5U?eOY7$*p%6BrZ`awcntkf$(2T3J~MVgS4B z4?`lwE!;i*+(Wz|I)R8F_kaix=+~UjT7l@!cnblTtRStdpdcKW)Chzfy(2Dy8}Amab)BK{vY{B2mjjYG@|<_!yh1pr;y0jtb6;1VG2054ypZ(sn# zH{kyy;s0T?-^OqT{#w@nKwta=#NZ|Yq8VlY;cpFs2C>xXqDh5@6YC!d% zCeSm`OHdDJ05k%c1kHjLK`Wq5&^G8R=sOq;CI(Z2>A@^uE-*h>6f6T)0;_{{!6slU zup{^)_!2k(i~>i2W5G$_yWnha0k{HO2W|ns1oweQ!EeAz;C1j9@DUz99t9o)9tWNP zo+O?Uo)(@lo(-N0-X*+Xyhyw_ycE0#c!hY?cv!rbcmsHoc#C+Oc>DMud~$pyd~SSk zd?kDxd~Sv5u~Z4#iXsI z!=x*u-^pmn1j*FMY{=kb*U27`)sS_Q&5`Yrlauq1E0bH2`;y0yKOnCqe?`7Ten>$_ zAwr=;;Y@*|NTDdBctP=oVwaMVQh-vE5=x1rOrflx?4(?vJfNbdlAtoA@}jy%l|zN4 z8l(C|O-jvAtwrrh9Z8)<-9SA|{gH;0Mu0|##+@dbCWoelW{PH?mY!Cc)|?hWdxy4? 
zwvYA$9Wk8%oi3dhT^wBzT_@c#Jw81zy$-!6eJp)3eK-9I0}+EDgCPT)A&CLa@S0(p zk&aQ0(T*{kF^92@aghn1Nr1_a$)72esh$bPbi~ZTti=ptzRg_2JjVR(9NRg~bFg!X z=boIKJa@zbVbNvrV@YReVwq+!SaM_#k*r&`>Z^uu5=Fh*C&HC|IaaXhN7sSV`DV_>u642%d<% zh>u8)$gn7$sDh}k=p)flF+wqAu>i3`u_i5n745}zdbB%zY& zlD(2Yq~xXirHZ9yrRk)Nq;E*KN`H|Nk@1wtk(rRCl+}~HCfh8#BPS{clY1;TEl)3R zDt}YHL;gfTQ6WU3Mqyo%Pw}E+uHuvugOY{P9i>;wgv#2=*OZ?rAF3#*AXVyBKA#sq z?{~iP{0CJ*RhVkA>arS-n!8$o+M+r{{i1rl`l1H6hPy_g#yd@3O)t$-%{47yEx1;- z)+cQ#?GSB@_BS0BovS)8F5q9#zmRz0wJx2mt!}pNf*z0FCA})WZGCzDEBY@C2n~!4 z(ha5zISf4wD-5@d6pgMLbs19_TN`H^zcUdr2{CCk#WOWBy=OXS#%~s2)?^McH#EOz zK5rpt5p40)lE~7+GRJb&O4=&Qs@Iy{+ReJs`m2q$O{&eTt&lCsw!@Cv&e^WQ?wkDu z`@8mw4w4R69R?g(9eo{}p~O%-Xeo5xN!KaUX~kLIIl&p{BIpwC(&x(R>hId-M(yVA z)^L&VqW#6ni^uNf?uG9A9)=#d9^0Nep7%XBy)?Ztyw+iAu)DA|Z#D0`-XAWhUAlK^ z-ABVG%jcu7j&HW_4%`5q2mj`0?pNk_>TmD=1VMthh-e9*3-Al*3giq74;&2=4N3@F z3|0xw4Bieg4k--%6`LpP*u+`X|EYaiPh z#}OA3w;Zn@UzI?W5Rx!;^Zd=cTLia!Z@s=Pd;8(-ABkRx{Yg?uSxLutJn!@;OD8`_ zK27mXd7Y||`Z$d!Eg%h-u904Tm-cSt-Q|0x_plj|jKqw+Ot;LwEV-<_`(*b~_ZJ_S zJZOH%_b~0@akfwPWR7;ulU(-P+qvH!c|96?togVmk0UQB?lUqCn4h}RT6 zVS18O3#vucZq!|@!_^ztcQhz9)HL!p=3?kEiC8c;0=wPh)3n&^*gV=|)Y8#*?&( z>TQ3e_Uc)mN?%LAQh(Ec!T@$qeh@PxKZJR$@ESX;INUsj~|=} zpS1ok{xNroI6XaGVnzJn=LLjHxSvHm5}ttw$#C}|NoffwNsy{WIN*8t!H^IS*d^Zp zHJ;B;+j$_qUTQq{@}^RzK{_xWUxUaHn02I?jc25vr?MB1hB}#QxJo!82mwR7L&6dM z0ii14YCJ!Qs{rs>u_O=VCkoO}jptXa^ojHbIsAPT9*!voYq{xK6o zkg4gP|_H z2zYt^**XXn;{TH$UY?RLe;5MjD-@8m^uLT5klvpS{_PxpA6<|?D)e`3R4DAM=w~PX zUfz-bv7d|S_cn$>NKZMSLO1-MtN|GB`Lj_d(o6C;1W6ewDJ5|!8F3lqUxNhNnVPCx z2=qjqjo$Er8V^t+;pOY4BCQ~!EGH!cgkVxqa^iAwGRor0N^+j!ii(Ob4`rCFjJ%ZM z&td!`VG#lgg@uH{y!3qiVP@_Km>Q3al!BbBjIy+ww5*b%w1R@1j1oXfUQW?XPDW8m zPF6umPC-^iK~7OwL0ZNwPy!xw={JJrK=9}r;O_tX*z^>nl$CW9b!0DS>*#3#9~mWW zSt(sPxeHQy(z3Erzcn`u2t~RB(Z@fg#$aK8kTga^U?9Gi+x z2+SQB7-AC`=>Lx`_s;~xf0f?&$M9541HF8`BQCfje{xAi>P#Y1@;3j-ctHP&2yAdK zcclA&gL<|`RP10Op+In@#v?5uBOCqh(z3CF-U6c}*C+2;IXds@SSfBpKkTm5~1 zRzV;?cbJO1=h?nh5~5u6aK&FFU(uUT}oD2R@_ro*+cwnUy3Vxz~sd}J)}I{ zr99-MZwe^dJ*T1nsf0ahy|HNwi(pRd#*-4e3j>B*I zQp+FtuS)IAhf(ndb~|7(0d4d%%l3~w|BILZXrlVN0sbf9Ur4{XNdWy{Wq|ty_K{u&4Vd2|E|0FsNBfrFvbSx_8^ zj*5zkhKiAfhLMw=mY$QBm64H^mxuZ6;i0AY`TX(u_2YkigBj@QnOKra08$KR zFDiy#O=$lU0;lu9%@?oJPavwFM^7NH7r-?b@aYQ>EeU}bktRNf2Lz@COLG$YYgMb=B8W0{BAB;~-L_$hRf(HhZg7NS{v~=7A z^gO)MGO}9Q+d7t3)*haO4DP`pp*Lc$r{BF-P*m6Ovar6Rb98KcX@^hlg3Y_-FT0Fh z$c(t+Mxw3>e%(D*AZ{PTdUFOL|OEbk*ZzY<_t zQmLn(C^6BMTjA4RY0~Q~vq82zc!Oz85jaj>OC819G;Yg8=7_l=Y4=VZ?)N`QfRf_Abj(Hc>Yt{!nfW$?JAYUj`z+5{w#TI zbfSBz(XFn!e)7oL}s3h=8kk(q@SAjo(Qr&a`7@S?TM3d*ZegbWopl z=Q=M9zFcmB2XTR<^N;8GhF=RHpfEQ!X{O_OI+psuZK>yC{ZwjGAKLBr-PV)XGV1HH zxH$~$>T~WsaqsSr;Icmj!AqnI)_powj}k1@G#p*yk~k}FWGQ+OUzn4SSxoM0)n6Xs zEyp_%L)IzOvO7F??MB`oXEFTHEO3{Zl8xb_7>jwy*muTHAM?m{8G@=S4eP9QZ{%jH zhGi{oF5mY0{N?S?_AOg0%`V#Ls1(I7k82(R8+0jog2rZ!-642+Ph>tN@@nKcg-mfa zVXHxgignM0;=Ts=a=ujgsMS6!cJRYi`JFlu@gf7Bi6>(+kB*mN=}SofRZ z99sI`)V0kTWf$oznIl7y;v3zXylsf+v4bC7>cce&Z@E|HKRyd?QsNRfnUx>bP=tl` zR{Xe6adiF^B-jx1cu0D zzlld1W2E2mUgK1a8mMXLOy2DzsllMCo(h}Ppld2*8hTT?Rxr5o`c~{@R)tDyYwM&S zv~_ZFlBoj2_|!+J1_8%UrbMACq3%1&yKXti2_)eo_G1Og1wjU*Z2MF!K;? 
zhX?I)GF1&f1qr`1PNGzY7SUi2&-*=g`xN{(Pp@j(Z~e=0qFL6Ne-K4j`R)K z7RSBIy^VO3kWz^VOYq`fDi8zc#iiT4Z0r0n^O5wy`hY!dY`u z|1zrXqmR7m-cEhNn$vj4xQ%_P*r(&?2}*{l3^@z}a~d~peLSzsvmv)SRu1JIa`@CP z4v(Vc9A+gKVhB$fCq$Iu1)#cq$YP87cI5jJs$JK|yb7znbptiW(X<8jA8XazKb ziMK{X_+Uf{;{+9>6cxIWW3i0ZF{@TULq;+&!kM@lhKsJngr=xc5TOqgu5H$qiSS_> z&%W=$`mQ{P(JYhOLnJZV;qj0!U5p$rG_qFowcN*I1KlAho58pILN|O%P-vSLbgc!H z30mC)sx4WH#3FN8Kf;seNu6x#3Z-lCvNiLu*b+G*#&}dI*g@CeZX>^Mmqof2Ut))F zvBk21MEyiZqe2DV1ivq~0Ao-n?l`R1QA?)j(+xvkAQbKaL|2U=;j;&qFjiG+0z#=N zCP-wn%pS19cXEPnTHym?06W3c#2w+PerR}FuzWSP&=NXPXi$KWN*u6=(M6T-0OI8K zt1GN&t~k#Mtci(-E>7@%9nmU3v{}O2(~+GLa$YcNI!{u4X0Lh4#`{g53nIV^IB`%hD*c&>24XLILComZy zGLw*$MZ6*O=d#EmiTK3K2p$izCe4&xYnoTLvs26qWd!reS{E)Ks77oQjow3oF1(q9 zQF@e(s(p35oSQc}cF9W`bWx&sudz=Wxbzb8G_TCCHB;M6@95qPiEBlBB|5=O7sF#4 zO$lOdJ!u{GbMl25xA67yx9CxX-_Ji0J7U;v9AGiSf0`aS(eRMPT)*&H>gxM%7}$7lNyyaeOBBK|;pC0L)sT!B1B_u#7bRm^y2@&-6nTF^6xrxts(hJh>-xFT0 zK)m7LRL}b4#x^W7lrER3sWfNE6pLLSD(gUZ<<*xBD4BF3biK&fA-rZMoLH+lA#h@h>hixhCBW zA($0k)4hxoXSsuZQgdG==i6nuZyv6nJ)E?W?WA92Q#XCB@Dp!@*BDI2;g1;MayMP# zYF6xN@_mPIf7W2WyAU1tm2+5iR0@S34n8-x14qvv$xDS^sVRnQt+f|zjXqudMue-a zKYZDDwY90#&8?x;Y`l-9m07#2fhm#r!=%FMll#+Q%k{6=cw1hnR)Sl!pO(K(B0GwQ zdEjjXQwMk^A)~@M6y#W>S!dpd+UGPXj7~oj92)*wHSywJX};w-Jz{ac@HOc#YkB6a z5nEA9Yjq=uZ1yEs&1Xv}-rO+z$2K(yZ4WNVq!NWY2gCw9W590i4SX|vXZZk8g|>vD z5LMzlWc^*{-XwXfYNAQh}jeBU_(yVG+YaE~|v8@xOX9^dM zAyKF)>K#G@DiyF^ho~w)gcW5K+6lgq@CLpiG5|DP0(7uV?3p44w5+||l&>YMN-r}D zU2%>NHWC)T^KQ1ya|~$ws&Ye)VlZhl$j?qE(*Dv;4WnnB_7B!r1%b&s z=Qb$Oj^7o{F$a;nK5lB#0z}U)8my^IlDwyWqvs+Nyy>*ES%q)HNFg~J8U7$*Gr4E$ zL$7iP17D)ufyF~f$1U__oll~Y)-0lS7m-_K*pvfdQuc}Xw9U|=3!krLm);jm@?AqO zCp~t&l62QN@q8|-d{N9ney>VjepA`hZu@EKdB0j_OoY`&1lsk}SboJZ>fP%sk1Lds zk%~7{pU7v8Pn}uq`R+;FndPDYEEC!V5Cy=)GsX{ds$NTn))#NuLz|2?0aLbVHwDcf zu!mwNae%eLq49p%J`DCZ+Xaa9jKXqn>Jql5{0yrGqnou$=22d&@Elt(a_#d+?X>`}XXYGH>duK1rGQ&Ths0D)~HdDp%*sZE5z&x)+IZ4Esjg z`zntn_gY;f22l~!UmM`KoypTk0iu5|WlKtJu7y&@95CEqVJ^TNi^T?gtz$&Hzw zA`^)TW}J3?pr4&=hkhg|eVienvS-Jr^403`c@CjB<3Lal?=WgLXtVW5um?LamQ~&J z2+auJv?d!#Hz2H{Zi&0iq z>nbzM3gIvd~JaU|mctlk3fqqjr2Ll_W-x z165kPF+=Y+)M#vwhv{T;{Fl|%qieRytC3wcg)PrKdgeBhBGrG~V;D?ixObgNXM9!7 zYPpx|6qIZenDr)0A++xzxyR;PGp>aakF7DE(fZDtG@_xXx+cQRsgdC_BL~Iw=UaIY zaI$q?9#J6IAw!N+9~DLA@eAX@o?{NT(5v5vN6vS$P=sj5aj0lg5p`a^zj+E8xyO&% zCRiwrSh}mBOl(gUJRG|BS{?2>v9N#Y=|M(VR z7UTi>cue<1_JdB=^2)nj>Wv>ys>|7~Q5kJ^bloURP;HxgtbwVJA%=N)dVpx>pkKex zUFPZ?#>%;;8Mt(aIJ`;Rq~!ym)KueMuX`dzq!OBKw|1`x&8zIWSjW;9r0R}Bf84)v zZ8;=Dm1gNhL%MYO>N_uM0{uHV!V|78)T!S*YV27-RMj?HbX{<~ahj7mrfNtv=> zU~@CIJ66`;RR0p+Nn@L;edyxf`=y~0Dm4EhMeAAga&BN@1CL<>+c0XeU@3g=GuZwP z9*TN4@mp&Ys^wCG%;T|sX%ZWX7v#cCEQ0xSY8qUzG{PigZ;X7P4{+AL0JtjUcz&b5pFg1<}9yk)n#Fn>aEmq zR@M!fpwf|UFzcI4x`5QsjD>D{^(r^!hAm3k+sjH_cUedS-rqqJj0M`a2r%wjY^4tf zY$;v8W5B+CUt67;YqkH}&J9$BE1_pV@uMjQ_dkZj!Tjmru7kx6<+Zf8p$i1AIk;cgtKBwDWU*DSF zA^roEq1RAf*uN!e07N`yH8nt7g99Q{gkgU6h-9F~hiWvjiwKwpJ6Sq8Dex=8v_CWI zfGyrm^A`oo5FGfRb=?6wGK;Gqn#7%jPELS%2MQK1*h351?EeUwGHfRW>=HNaZSGi? z$SACx$dS6g2@+27%OM9mj2wA$@qyv80hc!>FQGTt=!oS#1F~yu1{&6igkG!3 zFS(ZqU*FF(NY6T6LW`vs`w$PA8zy7QX_H!->8W^zhlftn&}Dt??d%U6gE#7x`Jp*;;#MgR4wZ7H#U&Vt$YRIk)_b z?J4Nkb;fg$|A*GFbzg&7+)eFb^68X*f$+DCRSDLw73su8ftcwybQ~KVtKhII2i8}X zo5LtN_OfM8K*aZ0ahY{>Yr9dHJ&R*%j;{FN_RB+hdH7u|Oz*@3(Gj0_91!byFZbn- z>vk0w%S)0I%7)ymnO~llQQu{&W1r;9Bdnc?11#Npu>G-}qW39?ICn}WcJUMxD^^ep z%9$$t&O_@VJfdgeFm5pREMi|~Kz!?SXGw}uAD^azjTVdvP2=Ean~CnJn|C?yK{P#? 
zQW)p&LjpLsnrKpnZ*J+c>a?FSO{I0e3ztYj>!xi3jvs-@Jj9vgjKsB=c;YiG@g010 zYfGeSOELWGw~CI045p0S8S}01%6+RivYX}0J1ugO$;Fv$<3@Cn5x^Slx4t5JUeqtU z*m!pg_@5K8aYC^XR)Lc23D{Ikb3q$yQxjHBK0>c)s=}tE__#?VPh@1$*B^(A@pZ`z zsvu$Ib|7S|)TRJcy!X2+G(C2!lIK;M4^*l#-1Sj1ja@BTc_h)XLK?}O?|rs0#?0jj zD=4bv^lHh+&Jn87zc09;Sn0{6k|-|K^=51)=@f(@?|89k*P)B-g8Sv6kQ3&tcc}Zc zY2+nVk3uaIc%ewvQFMja4a};bZoaN38iF;V5W&SQ8%bxwSd9jZc&cCu#WJ`R?+%h` zAh_)i_F}5WFjk=fCm^CKnNql^z%v>?3uHeWh?25fCo_pz&tl_bA&5Qn0bzhVj}gw( zyrY3Eoq){(`q=6wz9opM5JLl9F0rs+9l*`(Pvaic5JP=1lC1L9z&?IcDqPuA*nySv zL@Tl?>93c)0?ik!ms+ejVT#@#78AnSnRZjD^vqwMf^@BQ1wI;c#?3wF z)lK@Gw4gK1?Wf8@Jy%DRiMMP79hDRPB9!@SQGv*Q&m!+bAJb z6ff{)+6N8hvQO=}N}a+K?f3u&A%co?t#gt@!vZea){p8^Mlsq`GZ==oW%Low?OLio~SRP)^YG{vZ^xqpOaZmLf1PQaK%ZShcKxC|G*vqGW*SZtUVz+S!j*zgBhA zHxX%nXaD*w*U_8iBJ0U>-jDo}g}n@;;qYM2t3KOm&-UBaA0d<}hqX;*H1!oP!Ng=Z zNZR{FgI{-K%PV3G#T;_3+cA;nb0w5L90>ivX74N>an!twE+e3VbC8zG%I4*+d#I!jhh z3Iay^_^1^yZ8)6W1cyrr;4_>hUQvJe4_QAI(ythGBMosDX8v*wU{&I}K-3z0mNl)P zWIA&uaIKTvl=?$ZHm9%lL}Lq)ui@Zz`Q(`PKGQ`(2~X>H;-Ryv`4ka17a5b1LeC4T zR@3qSfL;t@S_4MHM}`<4nl*{2#=BnTtF9V7ILFpoSv>gA^Xe5J)^iK*Im{mywli+o z1a`y}xhQw5WrcHm`5{D{oHjE>{9v?49)sbG37@_6^IxK`M&+JM{9#2!NqESx=<4$nHPztkyFolLN%nc$i3HaZH=`r>DkGL0 zb+hjTRHHAYhAF6Mh3sbVdBwWj%+MG3bVI7fkomQZBA}FH$P@q z(dZmh9>~QmiuzSm#g1(Zkj-q|!<)eZ3x-ZqsEX1sAuds*(~1y*f^BHm6tsEU@`{q? zzitq@2hEt54xlP51V>6`0i_*v4O4uAvmNV5f4Ul#4^}^lLG?l&&*mmCn`TPM~td66B z9IKQsWHgcX#VEaIV_0D$n_WCn6YHFh_6v(?eBHPq8h^}M^ELSbv4kp6eHQB=N^eN zmIro6;u3nQfd3O_!W9qN-8PhWD2W8&YmeUW;SSMjoW-%3?2u2~ElG~#ql!#GIN<84 zCyFg-&3IxG*RgUge?a<;NJ}T)nZu5)Zug?l->e4e%T`!h{(EbvYdE4 zD&7eyGNEa??BZ@Uf^dYNx4;-ER730*;B+xaQ}b;^S0;kJ!w zCpB>s{OJb2dHX?P zon>Kqg8ddG;>nBs5{Dk`9BJH(uWyf#;M-)q@i%p!swcCi_3K@~so_)ixm06vB2E*; z1~S9Cr%m-8ab?esiW*c}+$+T%s%exhugRwf%kbZ^3aw0fF!#fph*PbYL0CXSTTvw~ zKz;qV(jvN2G5yuKl}l>`+s(YAd+X`-BDu{53gQS7RKVvZtIU)aeZB5MoqA2x-#OJy zs7N7oK8d>=;Lc-7N4w`#-&T+-OEDLVa+7u$CLNpT_MeQc6y7_bSM|9W!NvISLed-V z^}zN|qtzuh=$1adSy@|`WNw0fOi9qSW>w~|kxXuu{&vl2F-OH@ogr}$>jLJSJi7Xp zd-;937IX58T7`CB;-tf9b)ynrW4F2*2TCcAaJqC6CeW^;bd!(TC$7NWoZv*H+54R7 z^2xx0x}^`2nXoQc zPK^{%%K)F385A!DG{ZFtMk?xSoxH@Ticdib4>bMVY>6#i>pt-hejlMh`-3Kg^0m>O z!r6wfsiM|EuJ_A!dZaEOftP005)v=3%9)*lQXdA1g&V~sGGWV-k5=2GsoPw6l5poG|24EL`bE$4U{K2YJVkC;7Yian_*zupT69CWVN2dWh1E9 zz`(6Jq6;6ec?iV28)+}k0m1AJW%ni!oN5Q3FlJ)T zlGASsfg|to)-zLHTfe&-QI2r20*-8mvukrsK`ul~^z;H<&QxB5gL z3uM7{6nfrDrgAa_0`rVcnZQBwaf$`#8oc6uz@h>s69~5Ji#L89U!Z=T)@II4lmR|R zzBTXd_X@UroD2Pm6j=v34AKcfY(=Y2ti2Bu!*!f!s4uCBE>1JF^e-lhzKz$8Aq+n{ z@slWed#t$qWqCe>+g-$8xA$>w`t{*qbBy>jrD)$%*AoK>Awt-zM#i!bwMkmC;*@D% zb8+kml+s{kgJ!acQDDl>+1|ByW@7;7T0HLHC+T2nI8`9foQ(%J707xmdb>nk2E67> zxz3Es&X_J?!yh(!wNSpskY(y~i@!n~t*U2;o`*g{3y{obmVU)(=&lv??0xhhMav1a zj$54<_ci{C3qu-;sTO;-zL8mpWIi!v zx60#OiY>8?&{#9w=a~hYhI~&4-)yUyVtmKQaV-OB=!EX=5_?y)c&Z59gktqt z1vQWFNKN5hF*K5#eA33QHPmv&z<#Vrlx(fcG*M6EM%hJ5HZ~k>iImMpVeiKSM(*%A z?+hPPmTHXH2(*3zxBwvY(#a7%6AMy26xCM?!{pnTR2Q_N^btU63)T%~^Hs4jwmk1! 
zHre2f_tdrcIoXV{1B&Eiz1hL;en*tuZ$Gh_Pi)zJ1{VQ|YuSS>mxQ-PxUu%yi#1b* z7iAX&8*vUD-w&|#?Vn@Xina-zu5P7oh1w^5-(S{8LcayfTPAQWc$6!Wsz*hzUO7G7 zRBs*mI=Xq=;iZTGv1Y$Fvn5coX>i#_tO-yEzMhw>pqFEk2Hnc4qyk0QU59bEdH$+T z{6aVz0L6lWKsV@%`+OXAt~Ado(AgD9%I3gd6VKK)DVg2(nh}#+XfY|l$=61p;cFJT z4A#+@7G9QJ`k~=17HJb;c_v^3E@!6EK3!!SjT)0|!B;4?2NfE~7TpHrvwCtu#491i z3poUHt=AvU{{UIK{3+R1{S?g<0_BKABfZNrarEk@f+)?_WoXE-lrct+FC{c<$;qFU z-hI_`x0dH9X4#e&jzZhU60;W9AQ&$t7H!Al@Q76#+yTTDD(1>`<77CGSn^LR zH>XRPV_05I1H`~8o9bJlo3E!&WyclW;;0n{L@BksnwM&cx2uJel0a5GOKv^pem}66 zlYAAiO6syIv0IU594&uG>iE2N+97ufnJg3%%r#}phMTg(u)>eLa!EHuLZ^$N#h!jY zLzR*?&g@rFj*A#Rvw7t_eiJOY;bS;UDzS`^036pnANuMosoL6Ia~QTDd9irD-FkC8 zMX?BeLI#dUq`sezoh`GCrIrLzjLl^>?|aOAL***Q{H&wSi3TJMZ>cAn?4@WnVAlt> z*S|N>tG11VY|`N>o7fuyEPcJ}a8hwqegqnGi;MpNBDdVnYd$XfrI(qVd9BW#P2=M4 z(MNlg1gKC+><98`s!f0+q7H=k@%lD73rBFaAlldABcoW|YDi|{HdMQlEzUrwA$#%3 zD>tW~j?P%fJW?q<%vk>bE|;`NG}}7J4uMAbt`mR-k3HkMZ#-`PRNH1O5+nevim}T^ zUgt(iBTr&IqhkQ9fu_B;aLmG0&!fD_)BT#Kd2mOCSCTDprfi+<%rk$>)Z6G#<1@Ct zP8GIY-Z|&wsMu?%vl)p&Wb1@F*e-9<&B73Z*Vwz9sd9y zkI~t|lB{T%CinrB>^ zyy1d^JkA3F=(i6+{qKS6Z2tfSmQCi;H%JA@4W3>+e7l z8*>xZ^s|@lx}WD0LgG?1-Sb}j_;R0|{{ZX%0K8EuPVf@S#Hh!!qJE2G=vJHRrNyx% zjc;RI_6&?k1m5I}D&LEDkCyTHazLSmLgPY7BfIF~pC{wznqr{PAtJ;!W4otIhtQ`m z$g@ZcjNR}`wY|92sBk-#b-oA`WILsYl%WcBC9F=I;NR)f1q`7u0UGLiq+LCXeF~aI zb;J{%qRdYdTf0q_+QDRS7v)?=Ue0)a@|Sih2~usX>}-E0$MD>&R=kcb7a-CN53;@G zt3EcUkhnxSGP8GElg3W7-s9u4E#^D7EpSk&*2i84kE85XWCSsSbwF+C=yK!s_jFma zA4L00DPW^E4s;4+DyJ4kkD-E8*`WtA4f%mCgk*i5PA_>+*@`qpvz6q3v%VgL(IfN|SHipBn6*#Z@?m;{e8U zMk3&g6}RMYQSBQsjgA2B*+?Ttu^fB8FT3&00ez){7m(t>*_*RWyVcLMZ3G6(5~H%< z`aSw~D4iKp^Kh{mTyXbk+cR;nVgR-HzmmQ9{{T5tjj~V}J)a|TNYrVMj!uYe5t!st zbazq-UV{sDUhb8nnBt1Bd_g##l8n2Q1&9^_UmnnLBj{sPOBy?5F)G0_RnsPQ?!zCJ ziT&M8F3U7X7_H*D8R*M#(&JAb!1{p#o>^E1^9~{XT_!>l=-Cc?_RyrVAR%1X_j&F; z9xB7ddydUJD$!93xKp}ecDGB%1yL07`dg<{t=EK}j|7q7QN*fdW#$AOFNyv}RYuCP zj03V@#@gb;+uwKiDqlmlgreJ`j79HoEPLL^4Oy{AvrD!GoE#khjn|!JCmbK$K3-Rn z{t$=@&d91P-5gE_+K+Q5PsiY<5|NR-IYJ7C98r>ixk%N$XOr=($ z(e*^$M88MUnm={xJx5v96dJU-O`QsOnJ!~1OEV3P_Rk5e3_oznU0SR>kaB7yg$|c zm6kUl^a^lqFvw-o=u|2yeVymi{M}@mIeh>w#@0mb(Y-Gp#alGb z9L$Jp%Egx4T*Hr&_mxV|l##gy66<2RdS8)))0FoXH_|a1tVOS(2wh0;W9d@Gj?|i5 zI2=8ft3%&qkzKA4tg21PvrEbQ$M>EL4I2ofm@wA=03i&$$|0xi76}_IwIl*RC;BwE znq+ODTmOy~MHeg5hYPekjBBEh+)s&N=JXz=M`V??%sV2=NpqmvZhfMPKUgyWPk)wQ5 zAlTqB<~+poB$uHV-tu*g?$-6FlWQ%!ED39sYo3-xcWuS?D7Iqlv&;hKps6H*htZqL zedKl9XyRMwuA!9tm-c&=-v_BZ@`n>X^aHnPsYa5J+Kb6Buy8yvY3l4vm(Bkyzg1DrCLlq<)asx`#bwC}%YAm`I z7Vk2(U+`>_9iC;6QlYdUo)*K6IDCHGRaTQ`UaX4=BJ^J*!-h{hRYaa-ZbbQOEWEqDtNfRE}OONh2brl8B)vv zHsH5m@QMCl~Rl6grWF5}!L2w2t+@3OZuF?7)PkA@1`4be#Vg=Shgh+Fi+Max| zgVlZC`sj6evt;4~V#5439WCEi!QjRK!{kdNxNz#A&dT6{ z$BGukVd3QW)noZAp>D`y2!WMa*{QKC*6U-F$^BJT6TIwNSq>iUWJZ}W zDxnin&2A!b!Ux`E>gf!z*o=m&kjqYU*Dnt%PJt05XqcQq7iAr%7vAF{c|Xq4XTg`5 zzeKtQY$a8DlDr_HEq&(?(9XJ8(Q{Nn7YvEWl2jZWv-)*LB&m^wkVsv5zzb%H@J^ox zkCO?>7-GQKD;+qy+uy3H3WE@)&t(z;Vt#HMu&OsLWlYtS3$X^|j=zZ;=u(Lqg&_k8QBi(N~NNP`mqG>j`&^seu z=;{g8c%OM6CCPa`?}PNaZpRkX6v=SNZilcNdao%(raoJ1P^IEtD1kep+iUfiw{yqN zm{L5HvNXWiN&-<@>;*V=sWoR zk0n>7&fdqt{{RLs~19)7tYey_`h2wo)73HMdzF?wK2%%93*= z0$2=kxap*B-TwfmK|Jw7fu=eWV7%qet3>bk{26KTlRBAK5yHJ5!p-Xbf_|^|HE|j+ z%JygZSaI~die$y4nWZadCj8wH$A&ogW2CO_n1*dTZDGuQmyg*ePa8;x>P)!1HQ3mU zhA~Hvp~F$G%Oc0TkYfG2J_3f;vT zc?GO3s%z3_w@S^zt z09hCSIL{An%~_3lGi>-faUd@;oVcv3jEpH4bEDW*nS8;cJ$fcCJ)iYe?@0xrBZH@F_gRkNz#N)snI zdkbJsb@9)Rlp!o`rp=%XLAJa+uI*`&UJO=-XoH2s;dW9BG>;+TlX>aVx6=l~#c=0T z01kD-m0W6(vP~Fh-ZqXXj_u^sE<=MDnbVIvI@q~C-6~~4X2h0D*nHoX^+8&fV`FcY z=eD1)xHCxy$<667J;iFBVgVVtL3K9sdp!#1n;9WI=PTc~$J6|}R-Z|+ZCC*j4CkIL 
zdS09g@8y28{Vy9%%GTvnA2HFFyT|RI%CT3{W`|l z)7J>CW@T0*qIv7f_)^a`t#XPj-PHS3gqmnc2<(LpsCbv!e0>kA{{TlzwT-4VM^z4_ z+}^;e2IEKmufQpf8fa8n&CR0Y-99-wk&v>vAaGbW;q<=6S+p36;_bz9*9t=b-5ug( z`&$q3`#<@9`Z>k&s3fe|FKzCeO*aa(j`MfH#F>UiN3Z7j^<6lWYx zE~>tpcjz;4$`!K2SPmRT#}YF2f1CdRQqKdK0^J>YIT6+I`2ITAws!W@XWJu@(U;>0 zFPlPc19?Qpk6LS8Z>D$l3RIKwLhKECU<#B9j`R$r~cVz= zW0aF*eEhso^1CExr;m6|4r1FusX|a)H??k*9OS1KT zqjkHJ=&9pyP%e%L0FrKQaP(Kl>ChV3>(VSWxw#hT{zWy^ZX}bK;Mlf<77E;dVyBU0 zg|||!4c@kg$B&@IDJ*OOxL`qGFQMVAd55w79S+X-lYkNjJ1D-O@PAq3pE@h2c-*!rV!9T~n%?i* zuM8u6VWJu{WzViEc(`)tbU7;ZH>^O(g@L_=zn7jB1mg|5IaFN^TDF3~WhQIW@%Xim z8-p)P5(aaMF9{bnI$zOx>EmULWD15MNg6gJ^0NJA30kgPUiuwP>@Sy3 zNC3N{YA>Rf$q|!>j^+9zVfjBdBrwU>wLcXWFstQGDFVhe-kw|Y-mEex@h zWs))*A&DZ~=&|JAd6)RVE^P8hkZm~F*ky6WhK2!W{3iYY3Nwc5(E&HAIM_*B~bB1Sf?u1HAQIc~dS7 zGuU*GKR%YRquFz5PKCI`4SF^eJZx`*rdexWo) z&WybB_ZLEMWf#@HT5ND%jJFX74)6f`3xy&nqnbGGptAtXI|sK`Gn7ypOOQ0LE%U1H z>BxfCRXC_2Lu<`i7UIWGkaK=jU?gO3W$nLZyR^0}Z(=d#8c}Q}g;i}k4wPAebt2=m zdFox_qXBV=H#*<5hYIhE!aGj9-z`#yY{pp2Y=mAYw2DGU9Axrq!T$g))O+5xkv@O? zf=ZV~jU7`_W4u@LEz|!1Xo?=Lsg6W9RZ;>CS%^QAkD*H1qRcBQ&hFNxmfu@;fz9&wt8I$hoQc#Al7>B^83IcfAW5Q=aBts95xNp3bdr&!fceEY zfLLaEEC#7_bD-Dx^vrGm(&JJoW$$Bh*0H{r{hHXaXC&$`Zfefj85TA=X@A2)0Aw++ z$S~r+F$af=n&4ePABMHKd5y2nYIHV-WrVKIf@aCm>RZzhmWC-?>tjl0VUC^Mf1~#3 zjwQy0jCpiUg?=t&-JAE?5vc{tv@;dpG7U=!$N4EgZh=4Xv8~Y(q(oLc8_i&`1Ixap%Cuog0 zlmUscCf~~T=`9F(mQ~asxdeqIc|Htt@%c1JtisGf$TeW6n4c#~#4OHS*<(#F&vxm4 zLRRKN!uMcE2?EJ`5_!vUcWT4C*|uXlD*R_^B~nXqBmi=amyah?2SRQQ!08npmlhfM zKAm>kwXBS+>Pp)bPIXmYmqmSu{B!P6VmXf2nk!`r7s3n9pzE!%!?W4T!J z!4Ohcn&m;pJ6JDqoAsU!&9o{7Wk3t`RgVvu=CNFpqg8xe znHdPNU`76ML$gOLyF{)4V6A{v=35U=sS%utk~#GUOIxMmz%Ey2P7-XpJb}j^q*&V8 zk=**alWBrk&m|cyZuVo+AnqS5zgJ$wF)E9JVSQC^{o0WgS7?h{u`;i1g7D4I$J6Be zXycXM)tqVK#zo^SGP7fIuRT#DE@Wk~2ICTR78PTBAGu*Val#x)AR|;>J?x6!RxITE zuiI>jl8qxWAs1Wo8~6CX#Qy*;m#I9*Ifc{2aEft}mm_ji7QOA#;PsWvh_(QLaj*yp z*4c9M$j$n?r)ec6jbc?Ljg>lg@3%%qPX6zAT4op{NgBs}T!DQq$o!5TJ!hTnU8$tr zHHcPZP(U7ULyzcJLjrV|6WfRvzoAr!rxHkA#D@%#R{OSA=t&x)`nKLs?BY~>rf1=` zgK&*{T#Kz&E;y@GW(}S;U_$0v#GN`^UR$SAZ4qv_Qd-uM&GHKLa$-wFj1G)4vWtBm zW5>%T-V1E|!*wdzI^UR<>Hh#KjB>j-9v;WRp=H_QMo`VOK7qy9DIAXD(B$|$X}f$1 zlKf~&k-4(nucBF(EEBHj{q1?%2G+>SBI}iwMF5+RXWNIz$?NSO1?-cvy9QT~-I-gZ zrVGOP)pfA)PT#cP(5$O8V!+668DaHn)$#tWmX5`7k<~4JW@aTx$BvI^_&!aOXj_2r zm9UYlsC}7}F0=K2z)Wzpna)8?hzteq>>_n}`13ptK)XUmot>XD3l>$s(%w4ne11OO zLn>X7G$e+w-ZALz-IB6xmuC*LWK}@L!S8)v+w!}*Nla9o5e}YaU46gnD_jm+@p%z< z>Ma3sqnMnp$|mc__px?~gS26a#x<}wsmqh)BTBMar;!pO7X|1dV4a*;%Ohvq7&8n(;x@FK z$lJ-`p6AclY*VI{86M0IVL9JC3P3{r0?!L zeBVr8XO0zZ1ZLP37Xss_9tz1=w^gUezFRt~l{?-X@64AS7f)aPU-DY$?zv@+?`38t z*^GG8^_6?PNh1)_Iw<1I27WI801A0ec~T=L(IA2g9F4WVKTnmHj!y5@)w}R@^=MZL zS&7VXNKZ2p>_^8}>{0HSc8#Q}Mxc?KaVA{`-BIluy_t7W92G{aI$55U`c_{bv&B0@ zv+6@RxxR#cQ}_IQ#N9WC(QRjc9CGaNnVjp5LnK1`dB8TKjnpvfDN&UjkIhrFvJ&g=@{ zLW=@*=i|dFaN0J-CM1xKaz*c8eYT~lzC$8e@2733xmsjvCOYp0!51Cm^fK;KNIZG5 zCDz-{lYPAD5ydLJnSqP*o_qSYTB0=6a9-af>lLUNpqy_D1t+2yXc&DOSlMx64ThE- zpkBRyi{uw+hd@_~KQTYy`1vQ0DLU#$XM6H%H!BR1q*&B3IvexcUdN%DQl1ba0hG$X zfpWKFdZDbX(!>hoB*s{&ux<>Ex%w~3dUVTvR-^yX@a#q`R1td&t}pDiK2nO$U;$Oe z_8pe*t5a3tc3`AFUSb)w>3v=mtc=|&5q=SFnCPz;tgPv1Tct89Zr!8+3zKX`xH&$A zd=)doz>rW74c5lMdV9=jz{pPN*r~7pISQnYCA{}Z>w3@T>cqzAgy2Q3&JU%BPMDKQ zQ0&OOD!`IhYDN2=pS$&c#!}ltjFT4&rwQ!(73oV7HI!)HY(dBl`}qF=IQ}l7nehht z>%)z@Tq&JO?FibXTdS_Et6#V5bSRcZBF}4SW*k4vr!+?^HYc>9?@!lnf5_plAy_652hD8KT!uJ>w@(kJonpq( zOuC0gd~rZC?>cU#_UX)kuInX`%5^GqL>Nis_YkIkSwpn6jhkV_STt;ml5Tfp7dbaA z>iuOfSe3+03ZOQ4iwhBEE9BCWCCSKeLN(NBowVrH!PVrKh>VG7)$M#ww<4#H(JsuR zdeMWTkigqjYS2mn4T%@X0dF;6kEFJ!{{YT@%Ba%n;2iGY4ELQaIC_0<9sbVIKI~dd!v;~27aXEIFF+(a 
zfWeil6qUab#P?SpYt~f-)a@COlQyHe*xQuP4Eu@x@!&+pK(Yo#1ly!RT`$%qK32G# zF64mV@W{Ou$aG|jhKZjs$tY;p=uw5iJ>ABE+J@_ws#^;PF+Z`uy0wTHHi?<(A@MlbokzQ&_{5jX}ujL z1B(WJofuHuMDu7*XX^MXiv%H9fsq2l8$Y|~b(9?{B_)V9fH5ecmg=_P!^iq~TU>jY z9Z(bs17Y!l?5f9T z1s5hDu~68LPapd}uCeWo!CA=0_aAZ#NzQI`@Q4Sh!M(x z1=h)H3&ui@OAdsa%sp?srI^Te3RLeT01Ij_-u_>X&0hj}dV3Ag5;o#)7_u6zfNb&m zZ%c<2**2KTC9Lp~t*-66-$X7K(J58{Viw-AKJVms^rc6aDdWJAF>QId zj$dzg?zgqWA8y%tDp1PNC?e-38fuMy)_S_=Q4GwWDyy+eTZB6=#v$JO$R%*H6c8WsQ0Sg-r>#t2? z_jNpxw0Pt@12wxZ#rwLk*_6nPr;HX-Ev8(XZ@(kQVn#FADbQjhQV$*@xsTAM3N3?% zH)z`uSfgug0q^?CzGco>+3p>9O9EB6c;S^Dc>b=D@O+y!=8f?}TP1hXg!xQ)K7~FF zl)lp|mLL+QKQ=e}&QHhD#q!iowL*=v%=&_igs3-G3|ag9bpqH(RCAqH`U|PfKk&Zt zy-lJg-R(DU>$*d)vd_vdj#9k+q6Sz}W-0@o6bFvh>G=Nu%UIEumHQfHvq!rGi@lBQ zh{={OLavqADyhFj2Dui$q>Cm|?a8xP3eR_UBg>AP)TKky_o@)d@x@;?* z$L~5!riq(pQYA(faO0n)#C@I$7eFA$j$zV0C3MT6=|rw1l(Gg3VU}YSIOCT`d%di1 zr6wIB7&4W=G@LAr=zZmMj&&vo&Tei6ik@OY?xhO?w&%F!HL5P4bMTx%*HU+@^ys4n zn5owyEQiu3g^|2}W|(EeYhn02)xhGU1CMAw(^}ozq{p!h%HKd&?QmBqTM{ln?b8wV znvESQj(hDkV5fK(9R@3(Dan0`Y*H=@ZUqZwvEsG+jNNBP^_Aw_D~1{mH$Fd;IY0p zh8Anthnqygy&r@!NX5DW>51VumwXeY?)-K`C{&Ot9a7|zd$x-vS#f<`Ya6z5nj=gW zz>+jKJ+$c8D9x0q7i)uNE76lYYXb0Z2(|!P;}C9NpwvhZsRUTEk&q{!q@5i5^`>MT zf+B^iSsW^Z&FFONO9i_=YzZeM=f!aJyO>hc-)BS^lKdq`t$rUT-1mMjw^WYB%8VW_tk+p8vwXA;G0yE3iKh#g|*Bl>&K ztf6FPom(3&pp(R>9Fv^Dx!FSp~&7a$b^!rFPXuCpwe!3F5^u z@_ZHaQZr4lHp;8Zk-!bOm#laZ{cIktzH?}b%oVS{jfMxb1@bEwliqoK^=lf4%slilkn%H&vKB5h>34mYu+Ocmh#m782E8m)h5-3E^IG2w za$>S-2)VzW-q%K(6og8m}`pQq#GFNJ22Hg$?SgE1;tb}#+juDnLm5<>?X1p`&?;gjH=KNv`sSqq>IaLiWh zZf+^98zgK@>yQ^g(j448AL9Q2q*F#l&=@JxoLqC;9pxIfB}1rNn~~`S;^Xi1Jb9ZZ zqdum5U=l$?orh_#KA#OJ%_by}){NTr;q>=XdcRJMpp>X=1a=q7*pEsr`ZIY?SJzM` z+az*1Am<{BjuuC^r^ANp)$eK)TiBd(gcclE16+UWras8wQdV({z?9c8SJmdfx$r?o z&n?0z7wIjy)p|c;^P9S$G)l%eg^+V4%W)2|_`C9b5A(OBr7&%gNf_8RmSu@Z-A&$% zuP3|OGwY!J-|Q{F;L)>2hk6)3gG(Fq_szB!ZyG`)al~(_U`C?d z)cSrtLvW>moeaD=EK1n(_vG@QdG~bijI$h+25<@nHX3!DOULcU9UIBA@AorDoOcbp zj@!sd_mg21`!3PtBl2Fm(bTNj!YG(!k;>z-h|n(z#QZdCvr1LnNGh#hs%+rDDIRIx zkKR@C!?MKtsCj7FKvO3RES>&Kkd{hV5`mS*@ouBInvU-eF8Fx#t4@em#_a=>2%HNN zbQbjci7WW|DlXFpj$ko~L&alv#W-GWF6Mu(rOO#^8{8NuNEB1 zGdi-^+DOB5d*p71)bgL$AV#up4aiw+RfCOLTif>2l&a9}5I-Q~aFLdo3t^WjDhmNzc#3J^ ze(tr~#nPE>y?2R`2vVx7K695!le8#!i?d|0ux3MUb;Y?vT3}GibY6h5H_YYI?)^U< zij^h47-O1%bhep)S>xsB9Sk`Z$=jcYLt=fN8WW1UV;)LWAkg!UQ=^w3kHN6Trh`k~ z=TGHOupzK4E~6u3-s0%@`nnd9T~5&Fq7BNDy|NwaV;^z%amP}~&Kyplt75Tl<@DoG{BBsNpb#1s0zAKp-rr8W<6yz6Ug{K-6i4?W6ndUe(Dwrw*- zbH|7P6~4Er)ml)>TKDKTb1P>X@%a7*T^NNn8rg4cN80jSHAc+|<`%WXSnF>dK1|MU zhOZG6aqVuP!hitGFCgQ$li*%Yl>N`>3ASyd$&|LlfH)5j2aX`(Eh3)eGA3~Ul!H(^KzK#Z?+4`S(;>^D)#SCxW9AAgC ztF4YVBwSnNJoK36jzzpKBmnEpk{?#_^^w1ip0t!(nB{KX+T4=fSl=85L<;L~!SYcU zhOFGelG{tGGO!mK>IZMy@%sbgQXBT$un#>}<1JdtQL!7{)Dk}t9?cHn5-rPI1<8`y zUi~%wWgi-X!=weT;i(cxN|Hg4#tFxmlZ$d|_A48chP7$UlMo7kMTVW1rA?0Y7$`28 zTK(P}IDa`rC549x)Br;bLAYES@+k&jUGd^8!Iw6@IT!3ve3&WJv&V_OiDUS>6*I=m zwMQdkiYXT6>2YLvcEXmAlcwn3A`rjPMbXIExEg>L`Z11}R-f(#Q zZ1Q~{6eMvF0-i1{F&4=eAE8kQmDWdKSh?pMJshyAn`%S7Mq?wTi>Xn91qQATJCTly%-s1C{PeBQkd{d!bz^sMw_+7RCftT^pNQ$(Gyur5o7f%_IoKdyiu)WH z$z3!`lIEZ*FEzF8Zc0Sq@ygW@&auO~H{K}c>}%zg>+1Y`nzx2@R?JjR;Xny6tnkx%x zy8aR~uQO)6kVeFi**?9Ea?<|*RUKU{yRc1(vas#c+y0t@2|_^vAmUpyvD4G&#~xmv z-uudN@kM$OBNZW(6<~7LoE{pDWr``r#V{N;5WhOIvE-tEc~rqGmB1*;Z_nDk->d!K z@Owc7%Z%z6ix9wAsT^DqEmG@?;dyx5aOBlb4=JUT18Z zmjcFFavHtGgE{o%PK1AV{*(j>8b=r`S*$?mI4iwh=JJ)k?itrL5Z+={f}PtEq}-8y zn*9o+XN}nr$i$n6X1Couf8o_p>@38*s0R*!zPCSCzOTob`!>f3h0-ZP7(36^Lig*y!jQ`c{>Rg zi3!&tJeS|6k2xf$%?k{ZeoBzKbXG|nCgJ6GSWDH`vCEyaapCI}47jibPdGfryp5w^ 
z-2ugt6LrZ}%kEo;p;v6-ST-|A19KK5g!(*DR?V=;*X5j+76Gi$-_w>1WBvE}?T3dX zKx)j&K!yy*l-uV%j1nK_a%9FFu{=5==i=P>r01a_M3K zH)T9J;GR4dW%8_fa_w$IFi--uf`UGk7snZ+Rv={BlE*>^r&LeyVR6F2K-6Vx9`5fb z`f`0<_v<|J&9TUZRn+qU2(daIo*7qJaISV_V3gn;ET;DC?CE9W{HO~XDrC&E9F$~k zHOG?ruUAAiNQ`_O%XSXDM313CfH)NaOJj-GRtGO0FPoI~)Ra&}Tq@>ONYfuU;G@|> z#WTcFOOm2CVUp{WS02G6dooCwFm(*dMlzv3>)8C|0_@V3d2ndJA+wNf;{iL;Z?LR) zHKWo6P~g}Z4xvB*4=08{La%`nY1kDcQmpvKsK`cviB`Xq;`elew@$_X9PC=E99VWRK z^tia&r%93*#f%Z%bOnbJkEG3BwA&ub7A4Cvy|_ug(vi1qA16mV+c=290`g|ZHN+sX z_M12H_-*L%`~UX?%YT+_)X7CfPQW7|9!PS%DnjkE6$17-T?7weNLqVSDs`ojPxX zB@Pjbo12Cgux^K08h8G){FEo!4dVvtNdZO5-^uly9c5hTaJysK!~X!ZN-Uv5qZ4BO zllcDtQQ9{OW0dR0^nv26saKoUh4HsQa~#=N<9ltsY%WaS|8{CtR zHHI=TdmkL1%*^Mc>a%BEy1Xr$R>X$aMW97nd8`2H^0w)mfAU7)p; zVjZ+ROm}#`?>e7k+G2`5r)b(GW_aF3F3d_O7`73{xF_oVujS2<%Jg5>nqZIr5X+5umOhw z_K$Y0OBU>}7_p9R)9INZE$ogqdYWyxB*HW-LJ`BREVkmyznV!*n{0wzt70;XuHj|g zUkAH}ZkK)KMCz-~mSEd-+gvUP02M|YVR#|sl^4m$Hz(_VnocRXaR1Wq^JXQ?Yzk?M zHKYDdDdX{uPVbvBAe*iN_WE*k>(O;-jF_Nzjg*{5t&qzsy^KHY`Uo4rCRZ*wBdLBMi?_{`QaADH7IXCI?)Ww~RhdCpg zLPxbrf0Ua73w+o1t5Y3|PnAOg6Qz8ZX>0mR6Wi}($WzpcJ z$s9$%zLyF`nl=~jyqyU*bRuR{R~Q1ZF(FW!*aq~dLs`Qq#>YamMgzjZ?0W&%Uq{FK ztx=W9Ve{O1{jRI^lUTGl0QJKuLo5-8`;p2N}9FYf4y zC{pTjY^`DDwplGoEr`1=P817b09_b&eJmueSK!UuS)#_V%X{Vqt~h+BS;<#~Cp`pg z9ABgk2kq702h7Q|a)h?@1t(TzZzs`T-S0ZeaP!E^@bLPmjqP)9pP%hjvD}Mxnu}*q zZq0jn@Aj&XjO_5Rj!98uwm}IQ;ycS1Sx!vi4}~I?;a5URp}Vii>pteZVy4#|;`bxY zf4^Rs_+Z$P4*>@PTc7J!9uL-kZ=Meo>=a($3zc5}IWLdsbFR3?*fs59!%lzZ(5MHn zs8T_=u-B#a$ws#tj$x=j&d-H;Paj9MI)kS%AoGFU#017q zW#ja}&Mk0!+~1P44Gq{baIgdeYxDYEt}h=jknOt``Z+DZBWeNB9lbAgpO2HsZRASZ zEQ(dtbym_ENM14qdKMNVn^nnUX047)6(wg1*fu#522q|%qvbtaK3#@Ltt4EuNpeb# zJa6R_l>JpI0<7jXwV2&W9&^RS+!WqQ&l~mGCqjc~I1#1O99-KhJsCMY z-B4dGj|CSeAR74wUuTP_{{VCFbJdm7UC+pqmgF_UVWTh>IWnxh+LvjK)yW!}mg{{; zYjA%fdE=S8K$D(66*z}LG6bF^>2dh{NZVtbcR;n=pTacc`aR0&Jn4(!hDNX~dI$!b z$1LB+iz*POX9dNq<6hiLCUvbL5e-RfQG*SAG<@=-*$_#yy4V7(dmlx*ZPM}kbjh1E z*iQ?dCv__A&yholFXSb9@yQ5a#c~C6T?Y<5Z}Y3djJZh3)lxt}0@g07{U3XkdQpJI zP}$T52FyEn-@LV)n#GbiX5p~j#>JHA!E$G5f0xx%vIwTTuqPPX6&f3&{{XYM@@ev^ z4+zBHRcnwySqCSFljS5?KR5YE*`m@=7R+KQFWIks+tb6xcE`f3+3;@NVI*z|We#yh1-QPo_&HApz(mRmfp8cd4W8q7$6r9M&j}5n z7FJcqkI>#PkL|PhSkF&3U7}5)Q5FS>BPWOx-p8!VkGbQsrHaomKqNX9ARBYX!_v;J z+TJvRVI7METs+~+Nt(aduma(sE!%K%q|yt4llM{{ZSdc0nKuiBd5C81MKLP^qs$4g{638})T!pMQO(#7qe{4Qj)PL($UWNh9GMa|F#t%u zN-V^V<1e@F>c0z9R-#L{V3XO^A>ty`v5)BtV5)>(Uql&yYD#EP~e;AXr%QbNKPhYEEmV?$_c3 z54Bm>EpLQdM0c7%6o5ftbAP9gkMlfBK-Oi%1-<%QJV@VHyx^gb47iPkxR6bkrGom; zABdoQLx+2I$}4;PL0GDGZp{;LEEa7bSh)+wtb+)5K(Ca*PVF z(UkLzD+?R^`X8Q~wHb`3d_k2!1g}QjDDnKSt^O(L4Rs!np{D_3kkC`mSgty0d(tmp zLvxq&Aw6AC)d<@=t0)?(XkgG*^q*) zIn|eqJa{x!<;k+kv&yow224ut4`3y@W?xxTm34@qRsn8qI7wcX?JMY{B2ZaHpHn1# zmu{S~{{Z#n{Zzc2M9`zgQ~(a1am;H|FP3!iaFyg@$!o$?_IJK96K>zf;Ma6# zRH?RHWxEm(Hip?dG2{AttJg~eY-B4LP{*Zy5ph?riKWaffYcI8SP#hY`eZxCB#Q&C zAm|U)VZr0^<9xNJsCF6ll+T6IKpYF=#J^9Yz?-kue&1Rs#Bn4*+#m~c@^oq+f`YEf ztUK4pcU7;0#S;3y3iJ~9KZEpm@!2H%X3Qsuj@9BHHDEiLsxq6FT}{XtE*wwM{Ou8RO*dl_P4Ty9(95fT z9%W}BJHmn2UJ^pzyN-LvK?97i=qKklf4uryX2@`D^kg9`IuK%E(j2hs5uNV_X{M#QfB>Vt#ROK6)@g z?B}7QRM^uE$T*9?V7;=XDY&s3_R%S}hLAlOY{JLQ+6^v5xfyev1OEVvO}K7-8_(yi zx#qT>Z#_xo&lQLbdjoJm?Wp3}i|sEXliCXd(mXtp`n)`RzT3jGDFj>}KRVZuw810~ zo~1h&9*Mt>wwTSo*;GNIWh%uFFn^HxR2x8&MLb+wz>NXdy;H}wD*&?avs?z|NB;l{ zT6B$g*r>axH&M=h$M^pL(&uZsYlQJDU<$ZL1Y=+rS$J6@R`rzF(-_HDyD@9Hvf+Gq zF{si~vmmvFhDPCa=w zJR7w4ZLc9DFdFDV=Jcv&Io*R_6L7ce_CAFf4zI>6#czji7aWoCFOTOeWJQ4*oqYEE zn$((N(2S4H4>x9=+9>XlfJQ@|So5jzgKoiV7uOHUdsbexHw$pNYq1Am2+8J9VOAYVMIBK&mogNH!Gc(Nut%T!O>Re*HN%X!v_oY<_!F 
zBw!984Se+WYFkCmd`HjAGi+ZAnSzWY$1jvV5WnkV@!9UQ*XZ*DWB6)D_&z))2`Smn_;O(mC zhXo+1^Z7p>4G!$&n_Ewwxm`EQB+FlFz)+NBW0(f^08qTAc%TwgT$5{9_j*(%)UF8@hYCh6R!Jdl#38R$K+M%ha=Gt!FjevUZ2-ft-YL+&Dly30 zEJgR2FOMns-IF^E5UBa?Y%6v>s@puw016nb`G*ezzZGV{Bs(l+NLOG^O~iPavh*uf z=t8B8ouhU*h~~bXneo;{M$UJWa*bj@zX(0r=9Ny>BNOKzh+|U9j#l7iZY(Z4{3xmE z%>F)IMNNrR1uu%^NhI@%Y^PQcvi=)tYBY=$BFxUsU{rKK>o;4;U*iem=AsDJb9mzl zW8!-pq>;M4?(fO)Q(%Uxfn_A|3$C4?-|_Ppkui$efI~I=aXkj0+{#=9RxCxd0^?pT zwWEjmf92>$v7#Q4%R)f|7ViDs2`1d4O_`cC36+T?U2z}YJw$!#A$hmDw_0}vh>@zUSv_S6;Wme$Ub{<;h(o!#_%c2ChkQHs_NX} z4lNEDvwGUSYz42P*Ph`5BZFY6YhSeVyR}Rt;4Rw+Qb?$qAh5~WgHeGZj+3A07D#@jPD0T*meqg=9Ir{m)8` zG2LI8PqW{ymi^py{H}g7w!EvjK#0Xql2+d@j~$1zp+*gCTbS4z8upiGg>?u4`KnE# zWjQLzit_~v-aeeKnql{Ki;kx)Ot(@<`|c*hk)UjK_N$v!vh!O?RNG!mHMc~PKOZJM-~#BuiYaRg-spUI=XF;PYhhf&hZ-p&#NW3nuMnDFPnVjZ z_&aNkKtotNl1;gZ2OrC?NPtC;HFU?rmR?F^Hn;%%l&%~0f;$Zr+7dfsJ)UZD5y%HI z2bj}2ULI`Q6k!@jn4vc%Pek9*@%~Cvyy%%42Kyt>4y00mhF0{-*c@$=ON$2%@@DTpxZm8{KwC%g52@i$fKFTwi* zY9kCTO|EgfdQoLo1prucQ#=b$Y|uP$5bwNe-z&3yo>iWgw{Juz1uUng0V+WB(-uds zM4$)wP&y+mn#3N%s5NM9^3zDMV{xNu0UDbMM)Vdc4VbGE&gG5Ql-Q>3fSay z3U_i}6<%~bC&!D&@R)oF5K}DE-se#o-09FnXKv*)Ud`Lh>5-7c1_Uz>AX|^2#mB`| zyWuU@TXy`OJbAtCo8Y~c2@Hpna!xsOE&WHskdk&PDK`M`(sIwad9UmB21}^!LA>kM8`~+hrY|U~;oW%cA9~^Dgs;+x4`S zGBd6~pt%GaYA@r*HZQ~GR@xoUCXAc4bgB4#<OCC2e0cJMDE@`~t_xbWvotxoE z<=aVcB2ce)0di|#$ec^t5-fW#}k*avF#!~idmhq zH)bSs)RUJ-^^}Re5mep0Zp^YH*k0oJ_v^=oNn?%BF;L*Gj(3T0uPWA z#gFC5S=jE`+f44;L68yyGLwFg!Fp|34%AiTVopF18&EU8dpw+S){QNu> z+$(Z4=M8Qp*V6I+PU|~S7Si}Qy@?|K0A%l%^K$s~NbNJ2o*cMkc9Hz{g?DO8`z{xluLDqq_whOMuWXnM;t)9!j6t^Ra(w5JW<2Q zZ}y1&-qRl?>uX*XD`!weiZHV{%hks`OVi@zVUv45l-&949ib}vHUUeFfVm`exHq0s zN<(XFFk&>fZnbXgBwMz)C@eccBfY%xd(R&vwnA*6gJNt!%R$oU)}W<>C5`n0+Vc@i zHzxMI{{SD_p(@SB686(Bb9L!-Ynv;;v&S@k$ph%)ztC=q^ZB2 z$DdleX4#k$?}!VjC!=%F-pSDX&-xUxWW_qywZfsu`d*rEJboPad>tSS%ngl)H*Ht5 z!MsLjH{v!sQBAaH6@ov4>to@DHqS>O8 zeA0tuW{+G>BcZ&9cQ*TbpDho&lChp#U!5{7q*>0;7zUA5SO&$Btf{npIVIir?2*EN z7;3{W(b4*KT&o;$D#FPjy^bhbQS_I26-K}w7ElhN!%R)RQ*xBycN0M^i>1~;O|QhK z>^*V29E)m2?`(L_Poqm5EVgcQuRYh$R{I`39TNd{?FU=)Yr|bMOepGXTEy$2(^>~| zq!Pe`%zFo={K{js_GBW&4xW}aU#In+I~Il|$1Wr+FKg*zyy90{qCFcWx`0KlJo@yO zorn$&MGBkW{{R}9WLX;{ir?iCpJT^@<&I^w&Aw3J*m}-BIzep+xx@l_?fmZ^02SvJ zV8m#CanPe|$t=WOdB7I@xjMY-gRpcs+~1fR+zyXQf=om!9M<6JZ7-!6k7rjHT!YWt zvtQ7yLm8NWZO*#zg$*!b3FaMGYd zr$pn41Cx~`Uam?;%ylDAKDXxkHRx=}#Ggl)fvWo(ym5LVoW1>%Y}nOzNLhj@Y?}DJ zxk(p$D`xUZmz1wH6;5OzP~^3_%HNdfKOc|9p#c^ej&MH>8eE%pUfOf|RhyREn1xM< zVQY@xm&cDb1(rLs069~09Gq7NiT86;MTVoKZVkn5fe5z0gk|(qJbwdcKux6ujH>Yw zahq{*ePYvcx~V{}t9#?GKYtG&+Ckf1A_?cdtNBU5lsl+djw-0+U9ll}d%4H$`1wN8 z03#!NXnR4+iMc+7VGJ1P3i?;EHuU15VyVR117Iz4#Ckk2uS3VqM#PX-_5`)`10PO{ zQ9DygvF5A-A-t$YtiI23EIfRi12Td@EJj*r4fvM_Zc}&lv^!QKXO}q^Fi2hF0tZX| z-Q(pS#`orzW@JfGSe{nT88_Xr2L*%H{a)4n@BA}wFe(Jd0U?!v1Z8}c!1NdTx{`l` z(Z=x`buiSBI9U2pe~H#r>aDI!WIB-63flK@UQPUd9RlLo6Ww5Vlpg1y{vB#nr-oKY zZm(rL!<>RWDE^&IBteV00JXs@dwe->j))bw_6^ zF(THr{>C`BkL4E6CSZi(aMlUsG52xw`5#l|DX>G5OP4GF?NB-@-PGcxNixd0}2x7evpe%f+(|>ijq)7{I^8lRzAYEBa{{SLiABBp#cr}cpWMnJ> z=36Ms7_#nC_Hx-32ON~%<;3IQ!8yEio5`swrU=@!W0L9K$zDvse2YO!E%=ioQ^L-M2kB`{K zKp{W?&N3fSkdqVN&(QJm5^!k{s5?BWF(D)bmKc^8;L*iirpc2;k}+-TfHWhbrDIM? 
zy!XXYLCioL#`n=%KUbX4H;&7oDzSj1=l}p5vdlLk?!Ltt_&9lX+H4t_g`5C^fh-(0 zd1RlP@vP4Tv81Kl1GroP$#u#%iwE`9>98^Fc-%fhG8lsxEI2aa-(iIY*3y`#8?`V7 zAiGGLz6@GM?oGVo@?YzTmx$G#-bAvzHhfqrw;j)$lKKx)ODZ_jSw)SzapK|1y*z|v z1E>-6*lMj4@JYh_ruIHYZc>6>i6ftNh%mDYc7jhI6ZH7&w60bdk_Fju2IM7!4r7tS zl=1#??F+ktMaZ$TsBUmb>v;X$5~F-w86Tr^E;;exmuKPnKCk??XWs1N9(Lm9u(h#t z!^iefSc)U4VW9^_Jj^QAONF+V9K>aN@U)9lz!iuEg5ueID;!(x@%`tk^ZG5bhE%pl z!z@ocWOS3IRr0ZU@vCmpRAS?tAMe&my`Wg!+#fv*Zp8xR=%8t)Lodql`w11A`+2fLSL(x(Wyg-X}XzV7OnqGyqLSEo)?KD&)qkU+3v<^j1EPaAu#miiqR<@D%Wiz&xINFa0C+Wo6W!ar1NhwZlxI_u}+-hH|@8_qQxW28Gxs(ZDL zAeOPPBjw3+$LUjcd2+X5w!cfG#pC?sG@Z9C2V)!D5~lvHOnT_w-v0om65=9JK`YLq zx6-1;7>nIbm$#w$OUG5Fz|R?zXO(3--7SNE=i_#9rZf)Y~?Z@$@U-8w4*HeWENTZm=t1m$kGkTpRxYUAg}NN@0Z~0Z~FA*2*n& z`^t73u_zlnYN${s#4B{Tcr*LXva19FY;H$C$wL~rxjGZ(qs2wdt$u%&g!6XagqkUL zQ4WZ^5--`GNvI|xxKi_ot{;#m~N>yogwzDo0__#HY?5L7TFB-Z25 z-c*YCRb_iJt=&eJvC#QMcH^%ho^_052-ewLt(&5~Uhl{0^`T9sGB#E+Q>mhEp_w*Nsy?HpTX#gb_)(i)8>Z?m5DrCG(r~}XB zlZV^%>9cHd?Cf2v!?lIkIkC<@k5}X5O2RXA%ORmtpuSz~eN{I`ZK)U+%e!F9erF@2 z9dxK9Lx2WEI@sLW+@60=tF4pb_QVT1k_pwa7wE&oT)BW4lIi@l0-&xN;hK_|NL^rmEZmN=reN=`?<&1`f2UnL)bToWX&BIvhubkv@{ zKhAQyxj(C>Hp;HswpQm2pthe&rj=Wv+qh@JY927`2y$%sbF&XcD$S*sn_j?k50;k` z#|&8>Va$CdV7O!S{CttT@;sbeD~Tj_i*&MD$C}jQEJ+nsQ^kV z_Z~#zEua#Qi9G7we?=Sm+H3qBGsV3?B#PLp3z2Sy^Dig6@$yFT%M93+G02v>9?$Qm zkDl*?_Tw2bmna{kqeax=igc5d{OxOh49m`{U-p(zsw~L1Wdz)^^^e?{Bx^ z;BKalMh5uetS%1p~CT-JJZ^C__X$zsb4PXETvbJB$quLzZT~kHoBy;H93K!W~4UN>_P%q8w zVzjBV4|8mlY`wboTO60i=*5!V#_`LIQsY}anBK*fI}!n5>oMyn4^4wEqTJlW=+-l` z%76k0*2LcAWuNCfe5&3O*%P2Q?dHFty~zB0yON7#QXI#OfSlEX19QV}I{Ib!RZb^# z2w(^yt{_Vg}4j})Tg6{TzppAT>NA87!r7LYOGwjr_hmEZS&>UQIJu=iQuqwu`0 z9kzIQoPvz}zyXVR@Ns?^olql%cRqbG`w3Y#%PcUYO%?*!mSLCO@NXmDJxbd&s0qP> zTFk`=0DWVpCFAf@3_calKbNnwd?>R=iI0n1p5b)qZX(L7U8UJFR>&%cR=>-@oIS*^ zoJdBXuNRq*M*jdu4m$R}3x-goaV(>l;%wf}9ue0*rV@B>46X9?5tyd!6vV0k0Ctb1 z_G*Y@GJ$)Wca3a28wW1N!Mn9!Pyhp%R&m1HyT*=L$;5K1uQ075xYt}>-L>0?B+~54 zOE}fMbZgIhq4OWmewBMPiFQ?x+#FABa?hu}adgK$ab9FME^-#Ix$PG8Y0LqP*xUiA z>=aG1%&tm;RE-Eae>RlC(MT#l<~kqcS03B&>$zTL*sm_k(V#3{(J=#FlE~kyN}}p) zZFD4!M`+=%Ph{Bu#!^m(>b5KSd3gSfFS5YrCO}Q^eroGFcFFv#T;#8vo=hc(E%*`1 z7uOpq!x!r6FO9On6R;cBatfQ9pyJ6s-y`Esq1y(^g`(Wuh`C{F+tBHJVLvyzuNjsm zd{GKz(tQcHpuhI2lUy%6r$wCX+)N|886^jthbzfp?dM+N;)tbADgX=fcPq(84#DKp zJW0~)e0+H4wrMR5#+q4AHy3uZpEfl(iDGpU=6WQvX{Si#TRb9hz^u&2qkV1C&E2RW zh+f#_6le~k@@kyQVxd%yVosT5?fE)=go+=CX3Jr_$P&WK(b@N&ugA}R#S~Hn9Tp(xSGXA-Ee9yy&8cFPmN{xz0#Q_6{;{&?QOLm%+oNk6^oCwhJD+#p=Vwj}DOHoQ z_W)_>a&6=M{N{~=uR4}D{W(UpWZQj^V51??vCBpALo$vY@5A@v0KJ7 zPL+r@H|#d;`cy((t_9eOT=Vw#8j#sEfZwFI$PRpR{{Yudi6xNoXh}T~2NAbVhoM^1 zK0KJ&FPjhm8XJ#By~c+mLgSQM8;)?p+wuLg>>C!~XD+gyVa^bGK0orYA&Hn(MxOGYY0OR(AWeBf{&;e|I_JKAXq!L69ZGLbqcJ8OU zrX99$q1IUxH!&f2k98k`x^!Rsw$N7Pu-6a>0Ni`dFT3REl6_j;pW%C3m1tEQ3yX7U zttaCq1_?FF8-;$*M{-(^W83GFV7J~OI6JWjp@)-;hn3nr&xg&cX<19MY7$Mvi`dy? 
z%O+l@$Fp0mXX0i~b$<)iV6rEi9%_+XrOKTY@hRyoNe{vcv#ak2+@MgJjEfMi`2C4| zc8T~-E*U{lCO|k-e4}E~t7YHAuJOfNFF(vYar~>!)T+YH2sY6CzJpU%c;461xvYHk z^LtOiP`ukaAmoIF-4|1!-KW^&r2hc(6KzQz@4nq^TWb=EMq)y7E7+((61f%AEJp9_bf{w6XNpkh*9)0T zp7Se3B)`h5y6=JAA5kyD$-gNbun@GN+e@DzyZ2Q;{{S?HWE(<(nF_BHY05h!ODcQ5 zD*0nGu~oTYs;NGTV~0=Ju4xIjEqxC-^YE@-Kgr!^zIu;qZ}>_!GEEa?B9)My$xdz;l^Q1+B-1MWBv*>!w)i2>+h7Qv>cvlIK9(hnRXkJQqv4sM_|npJ7-OH?rU=E3 zM$8R~J-0nCyz%mN)H&0lYzN8xKOchJ_#4@AyDUYop|K~=bKmj)dc?Yj$;Crl0nO=F zs1l$6YTY0QHIF3gTG7~t(&|Y*ZFnl&KrM^n2CdoKuR^g?hQmuXr{uqP(Br1yLAkk* zYi`@OLxp^Q4nR*awZXYMSoCtEYLg~R$VWazTjak^ny4}+`U{XfFGKlsYB2;6(idB6 zeK@#$e{V*h#=OLjKjl%}v=JKRpD;^=2h-_4*rlOe<^x44a<}EvnB3e8fy@cG*XU9T z5X@D9Su{(HR^8gA+BPyBt_Kp005-kPr%RBqNeQqk&PSN~G%8A~79%T*{I(zUR+d*Q zWXLJd*dENfdD5J&F)wp+Kq`Fahj{+nb<(ysQDQ*iJ&Y;XAh~NWW1@~6SoFT%kC{@& zSYe6CXf^JRK0cf9c==7upn?oo9*gv^YnyQ;WM%MoI^Uu%eFsnZ71C(R!xd03#c(gt zY#)^UWp;rqIqh=CPw4Px?<-!D|I`T;RK8>(Ch{w58-w4LE-mcbJe=KU@R_z5SwrmX zT~A6f++UaFKKGT~)nqufI$uQ4k;B8%lgGRB?1v(9B?$l)04fW8gu*Ybs#dh|@_63F z8^&}aMhEDBa<18>aNHA>%-9ooH*Ud`e(}3K|p-ilbN0rOtGVecAkS02xS*MnIv1unWvA*&9ulVwuQvTid9!terEQi-Ea`+Q{3zqgD*jE@OfgQ`6sNQpoNY6?@*}Bm&2}*m(D! zHgNOD@Av!qm&mN-;cu%k=@UD4F$C3?=^9CQZ8|Ba8!pBV1^*A^&HmHOeES%B1nPMwj^;`c{y7M6U!&v?`ny5IAhsLp;=k&X$yeM8+CX` zPvb*5scX=*bmfZL@6?4o^LqT!p@|EK3(`O)+ z6LTJ9z69AG*uHshciqo-%GbI9Adz*oi9dssmwC`gxV&~U`XB%jNbTmez|(`J=x*q6 z2Uc4(^kHr$Qmq_|k8vJ}86^eLgJX0gWr#iG!y3g+*%HuB$j?<>kXG@CP=mq)qYsk#t*)$|BRwzu=QblcMPem@2^n)gwQ zUi$O>YIS^{kC$5-crp<90O>#$zp3IR-AfC)k9Q>{jD;RCp}%gRgVN&mE770e&>f5~ zd`0&oh=pU9`#$fK`DUig>X-{wic>qBF&3iGv}29u{2l{)hFIuVMq?b;jfXTer~k{`36Zf?|viTZa|b zjB9Uc>Y_-}_aSl?zZbCSNm=2-!o-p;3XsJ5OD659rb!PH^pfOT%s9LFy;TzTk%2r) z4x}Ac{Cyg-J)rK{8G?hMTyJsGuQ{61gU(fm>;Z*(I33fqfh%Uh+}iR>(#`(>eMK=L zk*-4yPZAeRN`c97dVW4{M`N1t6&FIm0Jl4MN`aA4@+!tYU~Ftnx^d-B*^_$slehKPaHchiL-%)&qhb80v-V_*>L$_x!C)27glsWUV0B!Ssevv4LcA;)0Ay|TVPD6{OAnTXPK^W;AzfJ&rKC#tHuL2F&sAR@#Ewyp5TIS&M)$J7M87|hSXSF`83Rh?553q515WR z%MD3Tn>oa4ayo&hB$RQovKt%SZhp^BggV%gicR!x_P#uJX6`m{5Rutqr%SBjr)DT% z>H#|5__sd^JZqQMQr&>CBmizNZqH_bGojvVZ4HjP1BZ_mAVZQ=b81GXL&dymc3fF5 zE$1M7*P$9Q4Qv=#Tbp0&IIBP`bg&wc_$UqUW!BzqZ^*3x%*NpJi|cPVu0}Sv?AG17 z9IHyNHU3L~(O3bsw7=p69t>KL0tQ&~4$qpkP5HQVvH3yi*FL1&+s}T>bm&)8^BZ61 zqK!TO0D-;|Rg7%Xuq3cpgL)@7la#M%)>zi|X`$Cra+)#hGk?A8{{UIa{{Rn%1oMqA z=F`WI!|nVZ@T`VNR#{m{y58BLUgR^vyQl8*TSPpxUi-5)HI81Z=o5FeAl?OM;JF$asEvi-kqgr=J4E76oj*qxz7C4jco9~RF3j`yMC^rCH=L|o0D zCgznm<3HS^O(10#xlzM=ACiM7Spz3IIU94E@^wQ-JHO@Q$IiA+e(n2wc>W?aSQQRB z4`cEJRczAAkCJNPNwM#!;*lyC!5+yYP%V$-)q5;?$f!;^ok+Q{^rKpn+J%yPNKL`i zAr`Twv~16PF|AY!sORSVUOaZPttV+%dvDAuvRjq7EHxvvR9jBcLrDXn3P1=!vwKo~ zUyq(dN3_U|lM?G~NA&8Ww#HGG3N90oM zJ85}a=k21%i&;qo*pNpE^_+jz&_g6K%$P{IDg2@K>Ot^a7&yx0fq-o~EUnisy52`# zfVE~cr-MGuB0w#RUd25n^uF_~se~)N#aHlw*lXw{f*F7=2?Lm}kNHx1y!kH|33hFo z8yK#_*4{nx_LHF;arC`qQtB%k*HPGe%}i9atZktL zp3-<|WE>?&Kbyy_tlmuFB+gxhmdmIs^f^aV`_CSD5#6LM&(bHipNopAoa#VgInh*h zfzy_wEoI^DTXund_7$lZ%BYRWu(;IfO_=ci0NtzBRmSDHfCtQ{qTiR}&EzSgB zT-%)g0Nkn&t<#t>a4oH_-42hj>wnU*O0da{w#XZIM{SQ!>H6BCSYGS+J3f|G=WRqR znhR(MJlzwINqWrv-}+GuxifPC<`?#Sc+)y+UUaaM1;D+zs0@Q2o3qbdt%%SQ%uck4 z6BZqV@qayhe!;4FeAiK@K8K_9U1&zrrOK87jv;@e%6iWqf@C~71{ViF2)>>=x*v^O zrZ}`_B@yC zf1c>3EOLz5yBD2U^t-hV<-1PHv+ZevjaQ<>M?>jG_4E5)!8XpszY-p4287$aDt}pC z+HDBun$9nX>2q-Mc=B?`udK({)5(~uvshT_D!|+d&eNjS%JuNgXJIZ zI4jkLc;rmb-8)Blew|OY{1OOegNS#d9QQn35254D#~bTwkQGs34b+fsy&oO~b9^4aUICCD-=4Q++Y(TyZn9~@d&CZ$RcMeK8`zHb z@_y^b+P7tIjHBl8aHA{Q+8)pW?tN#EmD@eoDuH~77CiYcck(OK+I9?^IP8oqlG^CN zdOofFh!UC_V!(!))sh^lZ^4O z=D(+pkL*%xi(O854yh7GFZ@W@5Ci>>~#cz2(?uN!LEr`hq6Ygjh2so*1` 
zpI7$ltjcA_hdV3@WDOe-F6g7rZ%U;}U1e~+A*IjAcz>NtV|d1*!&`&q=>3Y#nL|wN zAQsB*TFRJ_>oVScYyCCiGkL3Hh+BFP0xlg#XdXVYvM0OZ7$XulU}_j2rO@bAYcekF z3jw?rKslIuuOH;Ct^|=wC9+dy(DM=9!nt=h+o{3AJ)mgl8s9ccO$jj;C~flUDm|NM zcwq}V7654L#C1Je`?|Irt`|(hT?r})(^MJm-Dyx-ziyf(;$V z`P0KZw6czR;>*o?Q%0)-h;)?P@UZ^v zHa)Rg!cmFzhNJ$9FKy%S(>td=UTiX?R&kWTDN)(O(oPzldD2G?Vd5gn#OeU_>)F~i zRdK=sxaMFqKisc2U5jZa0UHo}*2l-*ecf}o+4XdmUpj=!6tE=UwA(-k$GYEHS-jd^ zN~m4f^MGw{W814o!TUony5cRy%t5*KFy2u<#j?T2JH5wht7H1T?K2#{HT_E5e?RF$ zVk+RCCu`@A571?(q-ltIN=aP@v-bZ0S6-~`1Z)-v&hd3D#Oaq8wLjf*bclHtGH%hB zwfKh!Cmt350M}Z^Sy{Gd8#d0m z?b~l2Jl{S40M0MQcdRVCwNi?D1&cJ2@pYJaKN!KZZSsw+_LOceie;|~@T9K3vz77P zyP4d3n|X>kcY&_5Fjls>$ip?u*!aCh{{RVW=&BXVfX3=IU=(>jt*;UGmUz%Ja3bOb z`bVoq?)N&sclrTO3wgk?1Dl)dc>M*dz3puu3=#xLqC~JFOWlS+CFExII+-am{4F44 zmuF(7pdpm^WM;-JJv=vc#oU5wjmNKAe{G4=2Mdd1%%avBtv8Ju2Y9aqF$lyx&D? zUDaazCpx@j5WBe~l~&p-OUg^q^f$G`YT6JRNNh4Fkh%W= zqm!bLt^^KsU2obw2KpFLVD73F@0HruwTSt7PWQ1uWJlv-MftxE15yLo0M_9_0E~Zm zTZ29$xWe~6fq33W=+ULRX`7S+_BJD%LJeqk0N8MieskUXzaKs-qe^6E;jtV%LocP{ z`vC%-WescE=Jo`4*Qea8gx&5a0Ju_a56|ANE<-8>wYo|Gwp}LA>DJ|{f83#dOW83Ok`w>P!%A4|vg zR<1Y5PJpe8bghp@PA@_HdFfGw6pNlB7q^`GC8bDDX9>5?Y`ii504|#e1`MPo_WC`n z$;9fM6ql2L>IrUZen-TETk84=TJ@r7AyzVjgmZ#VO5eP$Mg<52gQkaGbHLx1N{xXI80bi~ z`*i;Rqx?UY3a|$H*@j??eGI&x>nPARM&dn|Hz4!qKO?8(&xTC}?9A5Y8iw?~+EAg* zfF}80QQKkY`@bI|vmhI=u)XcVN7&=4M0s^$qeGDG=2fj(ILfK)uwOM!SlI#8+WK?a zh`R|vBp36GWuX3Ni1(Ftk+~~jam~&Jm+LtFioo}eK74hxzaii0(q!m`Ixq&|wf_JP zjTLaS9?eHR?aF1+@%tSCNJwlC2vc!$X2|&D{9GVwGq5Kv=qA}8C(-@wY?snvc$=7W zUAxYUi~j)I{&q%DP>?J+%Nt#-=y`q5yWP|1oX)~Exwg3EKbMVjs^gf)cDm!Q75$GX zLLEe8vAHLorS$rlciewh{{T*&V_|cafMU&a(A{Cjpy zlO%w;k_GxcQwn5cfUSF5L7`>`rr9I1ao0+kdDXfin6E0s$wU*f$F@r&%An-T6<_B{8z2I3SYtD#LhqvEa@3Us*r2*|;J1e5dw}v;uKrF2><)tK;3?`rZ#7VRzA;NjT_I#D=$}(C^j|s2e~6 zumc-k9-6c9q<^m8dcT(oD=S*pU^(?4TgK43Z5~cPQBVZqXO+y|OOs$?Acr=^H}8j& z^1iF~LnL#+0DV!oth`Y>e_Q+C;pmX$jnLDyAPCXv#qpS(C-}~O@2=~2)v(600wISW z7#w3^fh2J%@3}PJdyU@b`Oe0kNLa(=84$Q6aO6?I9He+Pa-@JB1YT>!b7t>hzRRqt z-(w9RSfvRl7HAv1>Tt>oy4d_l?8?f@*vO_byyT1NewPkWBwqJ~RDUY?Rv| z53BUk)lsyA1_{a9!4sX2yq;$Mg*HiKE&)YwKpZ`ce4H|JPao9!?yjeoW(q19vpR1t~*_{Xt#g%}vUd_p)rW?4VdQ?SK zi7I!d2P|-@$x0uI{?nh7cpXT3OKUjin!&5JMb0DC}|~mkz~z?va#r65g=Bo}Q=&Bw(Zsh!Ymy0AOEw&u4i-E*mdBD->v;Ol z7d2+y$Qm<{0J0tyQay}+pH9z~KF%6-QZdf@k3D1{#?52_ z`b$|UjFpvg^W=5@MIxN6g)+T}wbXpH4R^9IIdg{7h<7U@D#Q+TV02a2Q^1KGdOJVC zjO>j#-Nr3szbf}V2|W+re+u0l5wPrMXe`C867Z5WYx=12uWM#Vkwj7RCt&F*9W1@? 
zkC@|y6>&vXLD--TZQtVm0J;4pg%`4tn4L;kT(_)Te16_`?YT6-Fa?OVg8J#8;iO$9 zVoP2f0a$@{cU~@0{F;AP`a5o7(6S={&vAQ>Zu)v2KOd1Kh8seqgRYMD-IH#H5vz1Q z$KG!F@~zUDo_HowqRNC47W_c^2*2>@UXYL$;UTU8C(oZ%{Mz#&n;A-y_SWlrxnel@ z@Z>k2nx(|=6t6sH@e^YKzCe#leybnm<^KR84ZW9A2M_>XQ_KOZ7g9Spz&6(P0NP@f zBP5F=!zsY$b}Hl0PgUdSy|Y$ETZ^1s16&3Jk)L_*Jy#`G&c!5gv}QL8ZpA$uyq@pB zyQ!id6NQW172+eO8kRe*N57lp%X_+MHw=OBkw9hE+zrO59#Q&#?({M=NYTkB3!$)P z3>1%7D=#O#o5@QwTR$z;?|US3O5ogaIk>s#`Y711v01c6S&j6`aB#OED)O8$KP5|y z^hQ9hI=JYy#x40CH+_ckFo>%0NUOjtk@^xoad1DB_g|$iCv8VKZ88RH5OcnzPqLV9 zI$c#=8|?eL?vVzWdu zIRb5afE85mEU6PC?d0ZsXIEZQ%{1>+d zaKBKJ`1yS8nf9opFzol&P*f{6EHmWfeht^X-u>63?CiVXR7g({Hw}Bg{{W2snt52{ z`p=_wsH z&OP3ImgnzdtJdvcdwGxkRihR{H6@MBwMz~b@z;L}T?-#H+cdi*gw4W_qPmh(PLCFDRS(SlPbC4=W{S~P$ zm+cohUfq~ui)HE39As9BazaRDDEAiv-h~j_pj1_uGg{%A%51mjbbOo8N_UZIV%)&q z!@Bu!_dIZb)=91;pmQG+r{1dx+Ne= zz}pr#yJ{{+6D=$*NbcPm=s@Qe$E#fGfaW1=1?_K?Y6p|Y%6iIP%=pAkwrimx*Cn{R zKCXeKaHT_JtE#X%`z)MXA6G*R)2vp@lU65pU~a^6n6oB+KOLcI##S(du*Fv3TNcT> z=g-R0;=%4OW>d~U(Q6h?OB1QJ)aj>ommP}*xJ~8Mlh_-e-aplXhUL{&#feZ<`a|te z0mmZbUUf!4$D zP&AWIe|RP0paLZ ziLyTKTQ_L4`TZoXSsvWrTjlWdSnG{j4D!Wu8x;iVhogTf$u&7PmyOw=v^n&`>a2Zc z>P`>4sRf7!v#^awt-d1U*dAm3I?GyGij$#SaN0%KkT_)YvyaG2`alV?W;QN`Sl>KEDf ze_P(SzIB&w81*=rp!XuR8FL`s>Y)Xq6Pk z8-R;q!y#(sWH^OfM`*ca`#w^wtr%vRg^Q8R^tgIFr?L9VUQ9`5V$6BipBPa^9%oq= zFh-2pscYkd@g$LdSxR-miksx5hBpNdpzz~C5nweu$3Q&Q*;o>EAC9;Tg^suWGy?h# z(WU(UJbbbkBKI7|h&7J9-G-H^<|5tOdAr^`4Gnc9Yl!9=pUJ5lq~GN8dMDYg8UugB zL$BeZ9M?4`4R!6I4r4c64xry!FF3aiWwkdv^&(o@h0@1CZ5P<7<+wwQ$NF*A45vZO z#^izQ_bFDh51UK=Rb4j9!?1z#e>=zdN@Ba?eP2dn zeocG%`hRy?j#qnOYa*wgjIFQI{{VMvy}0S>)r1JaUz5g>I|bQY;8J zw?D{z9H^N=UM(y|i2j{J6EkW!gKBxWU#6?O_sQeWT0TO{Z11kT&ECroD2d18Olj6# zyyhT2E;;NCx;VEezCB;)Tj_l?0R0i{{e-LLfI1s!-1p|z-3p9H%m}gPu^w-=b7@jaMr+b?&47fIk{OS_kt@wiyNYRfmx2^vG z8>{uZzVx|zk+2D^`tyUdZVO#Vy_)=D z-kvxn)Azi*p7XY`Hi4gofw+}Mmg(~Gbkc*rmzT8sA2!cpY@@!6I~aR(-JfrfIA%F7 zrEPwUp0mf{^_FH9Nq{k)i?0>&?|0+-x|?bEJ7<`efr2sfj!lodlZ$ut6H_n9*143} z<>3SoBz=T)ul*kHz;X~BB~YwXUJ46=f@ zA%?7{pHD~A@%TZP&H=uF^RG93dzCe?iCh7DkaPr_YC2!bqAsLr0bbBFsG2DwI-3D* z-jMF}s~2i(THue63lD9VruArde63Bha@LTO8)~k3`T7s&NnV1;jy1CaH9X*tYtu^b zH<)d+)JIvR+8rb~PmtiMYwP zn6mEvvM(+5ohWDcCR;XL?lh56tjcn3E*?T{=d9|hQz*s8k+6Iq`#hl}W4R-n+97cM zMSTQQLn5<7Dwaz&L!0PvaBsYP{{S-kE9r1n_9#k`aehzT?|$?Aso8d>i<5Op7c7iz z!nt48u2MtBoTOb`pSH^xd5aro747X(q_P(X5=#x`yUpqTMI>LC>HBH>R4~CBJZ0IX z8LU{6>W$p;QX*&TX)(u_6PsVV=K6dU%EwqB_O>7rd58Dt(Z>)szM7tvUQ(KfhT#(% z=vdt8`CdO}9iGa@_rAFNmaLRzioMi~uV%iA-)G6oL|qNb#sTcG(&mNKsq^2S`hPmJ zRzOq{Z^7fo;Ltf7gNZ?JkpBRFxr*b;uyrQ**q`K5GIhrGu;;ir4{5lo3&?H_wdX6P z=+QkR*|Nwm9iagzzsaYC4sVbnwCHLTkSiR}J^d^H03q_Ht3t$DJP%;z{#F(UK|WR0&Btr_!V#C~rqDZs09HP2|${{YIYWEUA&e41rc%S%E#X4FPI=52c{ z57Ewmqspb8K+HRg$0PJswRd`pwuZ!LeX3_BIn?~-)b}kXjLiwWExY*_Mf*w6qFD%h zq)_%Z;ov5mLbqrZ=f0MpaEJ>ugj(d<*0343OW31pOV{C8k34QI&tZ_G{{ZCnGtSmO zD%X`ERKD?uHp;}?hojt~GA;z>sG`7}Y!@TDk2iMpo;>>QXHn@%@Pty5fmNdh&%8Ji zzOn?xvr6dU*kuf3lHl_K#Y)?y?bnGm+Z0QXVCn`yjX?w~oSS$0b&t1f*}A#|y9jS_ z$GpQQkJyhozeTw{zNde)!x#@DfCYuk^zPE++xB!|r1ZKr=VNT}MfwY0Dt-}w1_+iU z6%C;x;~3=2-g!Ed{35~Pkwz=C4v%B9b&@v^SyyvAkExg8QIU=i6|}i+TaPWo%k#Au_(DzAC;)x)PIBf7410*Uey*UChBRF*V1Qp< z6aD(Zb0X?}B6ZT;@2ifX-A@}+ot!{UirHj&?{6ts%vcMOS;*oh=ZC#WEEonh#DZ)z z*4X&*eMT2S&JhbS{{YIMGWi!SE^YYiaG@ENHv^i$e4o(hQy3_@?Y+*O;lcd+6vQYX z>P~ud@bswIU;&NJh_)8_M+{ZRB~qO1bRSElL{!?!FLQI+E!<`Ln&q*&vAUqW#l|^x zoz(IA4WNrISOK9St;|o+@#O1)SGS$>Y&kY|$0E zMCi6G&3*>oZs_Kj1v+B3vBtpIS)bk_yHQ5Q!{tjho$iWo(Qeeah;?c>Te@ET7^2De zKK79#HtP|o%F4E}ADiC(vadWLA##KgM+nDGdp8#aGkjb~b-k^pdfmkCc#^O=QcxQ- zdMH2M>B-0H{{Zqrw+~B-b=qh_mOxSsRsracD)& 
z^&=laJN3IhkvlK|$_j(eir@S6h}#(M#@L$x0sga-fX0UMMhfg z4xxa(z~|{um);PBWLrxSanIw&Hn82CND9E@LE*LjpDW|l^X~5QjBdn?fCEoQi8{yS ze0kM;nhq#C8F{a#4e-nP6sCz}=yH^Dk<7u!dmKn!+*!C&$LlSu$|K1S zcmtpZ=wre9+RXE`K-EAzSe`L!Sd6$he5EzfNR2=MwZOK5&DJjS{vL(KIT3ZcUqizf z>3=Q$v-_-xL%0m!E+p%r$I4;D%KmOh{{V|`v`nzb$Q=P>7F>{?BhYNJt;L1)xdO#+ zYcn3}dbFzJaiB#|hQUIT%nJ;cdBB{1TOa*1Ns<;;;p|nhBzwl|e>LD6se>kkg~=Xp zc&t4L=x-T&@pTcmi3>j{@2F| zkpj|;Zl}(+Z7t>V>$9TT0^-(Hb?-67G4!zAFIz^j!hj5#OuMPV8vg)Gp%r4SqSKfW zkZXDQJ@1e6E%ri;Kp=x*tCnjIr;BxeS4So!0HIuhA5t}6cUdHIN~N<0D`d9P!==5G zA5X`bIjcUUK<tE!uzo)=F(6z~aOM!r&VbW%8}n@KX*T18hLlI%{va z@|3vWXihPdp`5l%jIfGTV;ttu-ZUZ?PcLN9sSa;Deq;E zl&PeGcM8@qa zzVK=^OqxFS_s8UFNG4E{J}kU!MbTY7RE-bypZ=rIueADVe*XY}zBMM)MmB;~B;ALU zi5Qc@*l~Pfe0;n)ZzYGy%Th%MX(T2^V!)m0=-V!_@#Crr?5c&L2LvFMCg(!k>^VL9 zZuw2$tlm<0@S(|4J?Xn*bGBFVd%RAly7^|oD*Z0Z8`-(50p?<>RkxcnZw z*XGol@=n9DBnlXWGi#fV+{Yt&zOJs0HnS-t3-q*8uab?|{{Ura78;DNVxW#aS8@uZEO^2uYGuYe;As5GAu_pUSauuhJU-IZIY``U>)!p zk;AW`@i&j)hOtgDVJ3hRlV#H5SXZILS=ie)t*y++(2$`)2k}U^)^n7bO0c|?h2b*CC7u4{nxyqtvV*(NbGsyv15Cxk70f1S?>Pwx!oC<-&7YM+m(o) z>w3Cmvd0>;TFY=i(^WkQB**&7k+!JJw?^G!G>{A7ko9c5ryV15duu-GK@>=>3lLRQ z;OM@y)YG+2VQ_>TyNtK#ufD0$Nxm*l&@58cQZYwE-`*+B$7vEyIpeA?s4)@m-v*8iIOytTZsghScy{D6UuRW&mSY^B1mzWF$0s4 zLB&8mqD*7t{{XE09%Rlobeb#_5LDj4F;zG|TzzLL)$+f0M7PF`IwenxO4r6T3zOV| z&7T7{Zy!tRE75`=AVpau5gscGt2O$rF19%QeBM^i2GFu)HPj18$P}=HQQK`f`;@>y0ZF$wvs(Ia@hvQ& zQ>BXh;M?ugh7G$pyxjC96xg5vmi9Lw>N>ww_1n+%8|m`-_0ouQ6LHOQPK&2Y=v0cE zWZC5en?%^yN;qtNX6bo9@qe4-Z95j{jzN$-(g0Lg4lWhy{oQ-YNH?~>wR@YlezLH} zpyrOsz>Js4$iAj{_>;$!>}9c=i{^qdc3;jd!bJk-ZAK(%@|}CN?d=-?AT&HDy%#}P z`;F%(_jTmS8^*<5P6{dRn%4?_y3{uaoPC}D}nbWd$9T&N%tIqhvKnU?Eu=k56iSy4EV7>5=@d=YEi<7Y*{>-IXfUuzxV-9$~F9i7Lu9zdg^Jk17brG6imlCfBuWY{2R@9M;yr zp6~w3QSnc@k0HftmS1e~tI{&a18WjUC+TrlMp)b(Qr1)F@%NRxwg*na3Xak$MOxzY zK=W1rKwxw?H~H%?WAkfJ5a;W0j0Q)N29?kem-wFHU9u*>u;sG z`w!u%DU)T|AqAayRN08tPqpLyMlcxph}PHsYf;6wPZ+jYAnak5o`!vA@;}LiV4mB6 zBADu8CxQN=Unk>}^9Zxg#i%v513)yOB!(!!h8*g7{CN6I8jU%qAp~<1_&aDA3*3uq zb8nKFVk||p`9GyfCaZya}Ow)e^*Qx5(UPh;CnqEw+|ITpmqu|^ZfeIfNx!JG&+2z&p|T(0EV=W#X+Ki zPQ2#-0Q9c&^>jgj0FGlyb0{NU!&K}t#;0_{7j9?t@#wBy$BL)sI`dGqi2Wdr!`*6= z(UW7@X$&f>P4$q&D8nq7{+q?h= zn#6xXg%Ac7I<|oa7DV)<;9gR!U4v{yIdcS%WTzv~%Ck z{EDQgt@v}s^(j}P0UWkb^E8iU$BDwe0Ds-Oh~aO;YmJv-R9tFB1R&g-7` z`PID~zjxLD07cG-Kq11Oh`97l?Wt0X4WFC~@Sbn9-@EW(rlWSR4f_Go{*Ddf_-81F z>2c3+4lVSnw~E0Tjz#(`Ev3d?@9QX)MXU!6?lsVQO~=O{)=B=&{4EU|y^dtZXh32j z+J8HD{bWDsB>w;g7{>c%JH$N1bKBeX)j>wJ`K~%Z5nwypzwr4#;fg0h3~YF<(eHhZ zIHjYQEcr;ASW*&u+-M(|j~+f-0|yxf4Svzg*7cOO5YDA#0T|rmfRUHJ+2rM7`?{av z+e`;}3*J7$ykh>1DqB3O0bM{M!u=DbzVFrjUx6I^B!cBmL$n^1R~xedya9a{R=Th~ z+@F-E+e;Dtl{KxrVyrQ_vk_|%VeE0!5$&UPXfMz7D*1L)lU|E<9?!{ojpLRwQG$_u zG{ugd-7MK+#F#pQ4-g<5i~j&3o3!TwN3nv$$^2U>hl$sr;o&6e3()?fVSXw=^#hK#zWVOy2G7sfWj zsCFQHg5PSXj#&0M%FAIPJq0_)Z>1lTspi^VTU@L$o+V3;i1n3Jk)46#voW`ZtXpszTOz}(@jfouCc z`az89NDXSa+!4=jlbrtmCV?>;wxsh4X4eF!&#r+%J(|((BxA-E2e1Xl)2BS~0+%Ar zN~ty-60zY@KfiC_kT*T{3*pt!bJ+ecB-pH-<;mmw#5otQJ_CH zX-q|jIRFmPmrYaKt5J)U`8B4Q#_imZ8sLgfmf&UnZ3a!pb}S#B{&DG0t@0MrEp^t` zu50e{2HNGXnDiG9-cq*ZTigS}ExTWnzx=u@Tmo9xu&~!l((rUr4u584#gWgkh0mOKBjSHfxgt90d)-Ow z;p6+gu$ZX!lXbsB0<$&V-F#mI&V27uba{{SfIQ&#s8^RpWw$0_@yG z99Kh&8)WVJx;v3Lwn0xaki=`w2T7@TqTIt(#DCF$tA`mjn$2{w-)mS?s4wu zVUV%{LBBD`9%Io(?s3(>_CRL13c}fHK45oi`4o6rtkyXZYx=e6Ht*@t!~k{;s;L_3 zd~Jmsj8y@*5`x)$ozLb`a`b}eB*d}dCtY>n=zd-lGTcPOt%w%`n;VR{@%lB&SSTH* zOLT}peU+`eWm#W{+V{Sv)&6hYf1$Id8zhUTI+A}6c9Iw>HM>E&^MUHRK0k|9F0A_6 z)@B1$4bV?dq1RGkERkZPgpv)k1&4Q|!O8ganWXCgNoKi0VWs+9Tr}mA5ZuJx*0q<2 za=-g&0}~QEo7@l<$NRi;c={sDak~IC*UU0LI=$z-t6L%h%9rA88ukQ@dR2kAt!+Rg z{JzF@tUzm 
zg}$JZa0ZN#rEm4kgdW1)0*9j!^7PEww9a6WmOpl7PtWOo|X%T z2alx0IGC0?01b)e`dDz&n$k=O<{FXCbjh3P&_}bC%~-ZfY^KFSD)-@)75d5mL1`J3 zf}++uFMf(Ievdv7dd#qTSZ>C>gj4rpj7ExYEDDZRKcmL8HOusu>s1T$(!tc zYK(}C>m+tokDrHE$IA4#Rc}-B_F66lL{Iy8i$V`O18XBPnnIu`0Y)_jC0>td6KyjOHgH$v3kwP*7Xd`kG*~ffR&l z3umF>Tw90A^1qy+Cq!4Uwu2%~^j@+(o61M;Jf~S(k>0bAro&zyVh>%0UO$@Dpk*p6 z+cjH)7;ub^UzU^r7DY>pZgTkm{)H5B1jGg_&RZkG_w?V@{ZwDj*+QdI#XOW+R- zoL)IU{hzn`zZMxAlE_?z8ivyM;O~7O9lU^x z1;hXjOnAvQ%f!A8r-4}3XTCZcb}iU_HuUSIIIcYCTHxC*ZsrtmA$ALbyD=rVWmtY+ z7a`;3&8{YmbKc|+aZe}3!pj`$y*T`>Gsc=;6y+8mgXrh?^H6UWMJAJoUZm-uSp&@l>G8;a%9?Qf*dF5~P;C1(Es z?m#Odwq`PWdR-;xdQ`cuk2b7EKA|7 zkhxxOM-A_f(B2T@j{(FIH&I~2iE%t!7-7m&0b<23lH6kn)D_{&Kk+G2SZJ%l2ra}I zpHD8cp!7$Dcc6g0=Hl9qwOG8TX6qyCWy@QO;@&)YwqhB|?($m(2)sG4>2!D~Gfm-k z8XE^aELpl9j`mOelmiwa8$|-auhBQ^i!C)v=VVz`c}N6po7Go`oYK zfUCbt$=bz_tmL}CcUnm@7He6oRDi=P_o3nO`x;;KPy2ODa};GPncJdKRgjMdlv7=@?Y5g0X%&{%Chd}6 zlWaH4*~3I-g;}G{=@9}qT$CL>WPL>6tE3=|v}&R?WdYY5ofYKiq1#;dmNq0TSX%biW@e9CpTpZczQdegYRCb42a5t}*@b&0eB{;l7+hmSOb4$uy3yD{^~TG_L&C!Zz9 z)&Bt8`DWU-mc+Io;4xNZ2I@}^?W5&l^;F&4WLQ`m5Xx+h*X2BUUnTSQQ5FLccIkuA zEzwzb`7hN;412$<^m*sMJNvWwS+VDDoFJUW5UUVl>b~o{ zj}y=6FA3s}gRlxz5*V|bH)q#S2mq5Jj#0_cx$-$xTcWZ=l!9IRN zb;|zVOo3~8zD<6Ob(e0K-WV(&v;aC@ZtY99&a4_B%DMpLjWBO31N#)sqC$XL*J67O z$Q>l{{{WAg>`xax`S0EI5m9|Qsim=9=AaR!n)k5uc>9#n6*%k|jSG&|zF0@)C=e=~ zoCZ^(=5^Pjva2TNd0g?u2@S&-?;D(T=pwl|XH4rtd%81#8r;jE8hd`lCk0X$ zb}ns+2eh2H8^|a-QOZ@o0T8Gi;B+Gu>(RRWJb3^U zLF~S+iCpaQBa*N?LW{)GHnrl>n);rCG}Rc%$QRkHNwo%fS_ z&N+I|(dfoyYzq;n<^{4n9QX1!CjRrvSBR0m&V+_(Av)g9m$wh0yZ->!JY`&DyhAJ5Q>iPh?P71;^?#SvN&bYPLC9EK>*K#(L_Y}a5$&bb7RJ|vbPp(O{}fk zaqT|sGoV-F#GBZG&55;}rNsIDYJLYp+tX}+tEocr zgI;|Y>CP+Zcvumoho4%HX!trAm^eF^9rXaO4;TAFJ958U)_+G^XSL(vNGcANwYj;D zJt|9y8BN8(IO*~Z3a@B_SY+b>79<82vA3a@?xf|vg0E|-Eot)-7S=l+!#Z^8vYC!8NK6lU2I*>Haqb%8eR6`p?@)wC8C;Zb|vPs3I0 zL|;%ZZr{nNZf7e^n81!T8u+(1K{JCBP7tc-fmev0ZV=wF)~fiNjJc>+2el8$==kL7Rge+pG}z>1H%2S-RF7Sja<{|q z9UJl~GDgDcK~rKi)1T3+7jBR!E`fpOVWsOIZFbNq+st_X0G)+ePdymOr9m-1gdGXS z{4#NHa=x#urZ76&v+-82Rq4;+A2jI zPYsQz7-~hRuAl>R<|;*q^jSEl4!Yj=)5bL-+zV;=Xh6Sc(_MMK)focUCd=nPCyyPz zcF<3pz_U+(M6 zl^qx?g3W7`9Ow^0Umq#s<`OTc=-WzMyQ|uLzRJ%TD#ARJ-ozZYcvngAps@~>aQ++vS}?2#!$)jp~owDhM2ko8nQx|KN$N}hpxAWZRQ!r9*YlG&lxX^Q2#D6#Wm1%CAHi;}7 zBcUulMq|8YU$;hfnO(^(qg=2&+)KM`YEvf;o_kScu7^*Tup#)xsQK_u*(=?In;A7H@#fyyTXzX zS~3-}(<}5L(NF;Y z0J%~|BobQB8RuhWwj_7>bNafaCUC6?yPJYeIfY+)8uF>U@$z4(EXaUd96-N6^Qns> zGJ@bRW#O^M?-^j!JbaAVUt%yQk=Yy&7@%t|T~j{!f7&#XrZj|;XKnb}dbN!xEH{GX=4(5BGLeuQw^?Nd-GM;XS zc0b6dB=|_NaTs7EKm!A!vn08HE=a4-EFJ84*jY)vv0HGb*us?A-vBjHs*Zjm%6_h$ ziou?L8ElDUO| zzc|&G`l`=x-8NaJl;rFUhzo58tdT!CN*6i13y@285=(z_hCsoB0AApdN2TfI-d$A#@ukYxu^~%oYkwk*fHIRX#F7aGPp}3%+7&rCIwry_M^H_CD#pC)i6Uc7LZZTj- zm?xp1kL@>pC1JyRegHW5{C|*W-4ID*#4UWU#=YMkuw0jsR^%I%IcPQajH;eLB++sL z(&S563zAK_`Z*;s!C6UdP}~r2pc%i92P6S5c&fSX&?h*yX}_zoSM+7C77> z#MocnXT$G3KlBL$Ne%Jj)^00v;xB+)cs_*;(r1j8BO}kE#t7)i(%_N&e2$Q@ z#Y!8MP`WAB$~>!kbhgN3;1()FAQrItG(te6Iuc65Ft(Hh-yZz$B&xj2hrhS#~V?l1ECs#I-d4c)*FJn9oI@&r zc&u0u2ikmn2;$@qah4X4nl<6kJxcjg1y`oC|-%5u9Ui#bwGhyW<|=%(?X2z-5<+310zDZ?rR zEqqqM4t|Q^`npP$-jKmw-HbB+NAD_zDDgaqEL2-_a6s1#u*-Yn@>87O-v^qTP8cKUw<) z@g`)rr@P)(%POd_Rlb=EVe~pX&&~CfthRwHiX^#B6c7S*@_#6eAN8NSv?}IC1T}{>`M#I8S8dWz(!6@IgRSltAR5$4 zgaG7#TT3?+3wm44=cP!xD%o1*;M|iPiR!s0@%(62#w6tdVA5DIWz>k%qshr9tp5PF z&uK=;3zdk*WQlozV!HdRD+y@S<%nW#v8R}N`9G`h0i{c`YbydcLs?jNQ;_aQ;$JaH z-c>s( zF@-M@Iy->O&Ag9;9$F=iIn!oMn?6(4`8d5OjfT$<3U+oERVqUaRk*ivNRVR`-7VKc z*Ka7cX=S>Q<(WoATp0i-lxbFVc-}AhZQoZf=}f0$S(6N+c{rHA3e4k&eaNRFEWdH@ zeo9%Qk%r4AgxP>V*(9fiD4@w=-TY^(n|&->C##dt5Y|>waA42UNgSI>Jw7k%dbudK 
zitMCuIR?X`&3;cUszma9UOOg|Inxwy6kr$v7Zm{RVVgLbP2@a0hU;kEe(7F4~jBWUZTOBoEn6q7m zIp)jBc>E4n=9mYa@I>PVN!)Gt06b$d#e8c_BWijDe8Ao z{&SOVR%Lw4`ZITVbbCx{phB-8VoioOyLTnccfLOu5|Yea<_RiTE$phyev0w#Jg2Rx z_LwCz6w7$(LV~tXmiGLQ-hZsP0_9#zRGa6kD<5fi`}Ijbf4|@V0E%zgRstxY*x~a+8DipIuxO!f$MIxL0Q5)2rUa)>b=b=Ti}-#Z}pG z98sa@zF$YvrUqvuJYyHNtSvWu9Dq``fl|?qcO8HwZmj`k>d}2bw6wda` znHNs?N?qph{H<uR2Hqg)eo`F*dkUo)JVt3;u% zbtIL@`iSI@1o8P+Y@1||6%3>*@jPVdi5YN9FIW4yKW60kZe{6K*34PeTLMQB>2Y6q zy8FEPx|og?zY!YW&S=Stg`;eTF-8SsR_XGvc~7tK0GpBrJ^t6^FHiM#Skqd>Vn&1T z6}=32TZEenf*SdKHr`ii4;F{A{y$gXlB*Im7xs;Szjrw(B38Y@2L&BI&r5eer04?J zw>t~mQIXe{1uKA&&L>CHiV5=DV)TV5S6mlY-A0>BbMwo7{_ z{dDqCvgS88Jofq?g<`CDG0}%OI_fBF&_42i+5C|U|6v_7hCjIyPhNZSL}16D2a1&HBdFs4q%jPD9-v6$d)%YxzFo* z%su5+s754%VnNGJpR)cxu`tZ6LC8syXhFa=hY~Ey>HXjKKKk?7miHVws6K}tdyY?( z@Uu-e7Sn`|YjC&yuOBds0my~0vDDdYyj+`iJqkQ-x3K_RHkM=P^s|mGmo4i50OCEN z;TVE78f(lu+#mM;0MbE)q(W8Pu*{VXEHNj6S>KeBo~%i?HIFs%8vg)^QKO>YT zR~l!jH?V6A7Z3S=Sx6+>z)6z#uDm$gpsW5L-gTV&8t(uSF_z7ck+fiLy$m`OXLc?w z7@^Z!;>V-)Z}q>M^C^+qLKL$w)RABc+p3fNN`c$QusT>=-Y$@uE&i!in1;UtWk zTY6u1^)gM10NNaH!g4mZ50AY0`B@3HM^kM(HSWVB`G5U&BaBar2P0@lJU37*r`9TR zf7Vnl@O7oUS3Jd#--K~4aA)Xy&-AS%Ym%VYbE&>a203J@=y>@avaG|lIdj4UM~`Hfvu&KFb{IX$GU1rj2iANAi4rTkg3k@@Cj1 z+3_Op8fsZas)N?<_rI&&$6A&RS+p(_C1AJ`P3_Ed^l4z&5s9qg&qPI*qEHL7*_G_9~A-LxfaFW67ws2V$Kj;F_(g-OPW`k%roPQ+*im%U>P515UA z4R2-q+E(-0`i{{`2?JBx%syJ(EIUn&>qxz=;82sV$<0z0TK?KAjec4SX|HEJaR38r z>ry)W*5?#x>bA%Ot?lF6@#dcT`TW#b4&HuYuDo0{C#ej}nFCC%<~&{h0Oh4o6RVZs zxNe9h++0guEN)eUFJ`~aT6%3oHg=42fo1z$Wlc);Zgf2%VXtnSJsNVGjIYct2Ydej zRX}!?0rBRE)d1$=PTfD5 zOU$UYJ)V2w)xeAMkzzFS-__93g8bw0^ITV4>DyW_mkz1ast9$HFP7SM<|aI zs{Ln;V92IupEa`)pjPO=rjk=7XyrG8HdZ>kV^tm9^i+fjSO;Q;`ivN}FQ$v*`@WPr zCd)R@c$}8G3ofeaeO!qi@AGvhn{p4zH1kocmeA4dDJse%l!C~%1$)csG@i5TDt&K% zlkikkD~8l-T1f}b^7aayN$k>%ijk@3{i1|!Yi@s&Po9|^#R08HHNiu7o;`|?zNqGvRIU>jn*48$*Toq|-(&qfy^O4W=v#F){ zK_Z)Xe1d_f)B}I;>(APDviC^zT?r#Y&EBZar{&OoDs;um;MS7oHq^5c#PgPJonUJl z9S1!w_Fvgurk6U8=8h_K!G`xX=Dq6!YJC1Ppu2t(Yf%%Sw_xY~bqHasaXq}&uD$;N zK)nv6oj;E%yd z&~MHxM!)pdf(EzjzP-KL;6T@~^Izrh^5PdVr-M&5_KuOKlB&C$?)lfgv9bAh4l?5yvT zyq>iG01ZhyWXLy)mH-|Q;OD;L^yB2@@n#=}h*v2Zpt^;y@`tQ>KC8(1w}zl!kDl86 z2B~TCXwM(Euvn+=GvUF{bhE^ zwMhdyApo0_J*!~2Qekcjoi2OC_Hed6UytO97T*vqYz@bvJ)EeAllF;ac-3C|0s+*V z0pOu~zCM#OjWp)(P&;LggxmmD_ct9cr(CuyNYKVf(_KixIDCI?R2G*EF2$QO#mq>@ z3~MiG2g`DN+79Y^s^#Lfg^3*Jv(w$KDMcD_f~rYDa(Rz8qrsn#vH|e1z#R)-Lw}bo z-{&hc4D9S|OMzkNCn(wDT!3I)+Tz&;(-Z0E{aq~`c|`=+aWhx}r?kz%LNF-2nR)G} zbJ11&ss&&s%J#rq{GH633g#rDm29pCL;=#b>p!dhaHfI42b4A>@ZV1JcRl2O0~1x9 zq`qoE)FtjkQ3dzxT_It90xwru6ZPLpQJUo2Y?B?jrVp)yY5vLD9 zGWvcz=2KAWs+Er%oQvFCGK&yx)%vL< z`l?ixjSB$Wri{&wq>m@X$Mrr1FQcZ(!o1xrEzO|d!SVDlr3S{^vtY*K8}x^>?a?v1 zhFhGp>0eK4?0FgjdHlzn&kdA!gizQ>bAK6@dJBX zi}RE&_{N}ik$!O2UKR1+Vha-^sSHhp#eh!p_w(LWG4hx>S!_k_)&%*!&!J9`sA8z3 zf~7(Z8~#tMso-l&a%*FLZ|26>uaDw%*)SOu5m$tHD)BC@?s?Je>p^=qtN?QwR z5mYRBN$&mCE>{fpz3g=&BV&w=NaC`Ad_liQj0$s>!B!^g}j?AV|tL}Sj{S-mK@e16?Y!cr~^u>>7$ZCjuB ze)5!(E-XM9Xb2>db>LnaL-Nu1IYD+fq${?Q0y1AKo|6@$%HpyM>GAt?Di%pt z5>n?@cC$EB&|V$O{&#m&?Jkzd*);)iYh=Rymrw6Ean^0-M2(bYU34~EG1s%D!Q=Qp z5S&N=3&KtLMv=Nd;!tAJJ}Gw$dOl^pvxHxJ%GNtsn#;xs!HL8#(N6rFKlc31LTN0A zFdGxasmmW%!@-?uu_Q1#WuOFN2pt@w@8Lwa*+vY5hRk_}oII?tq*8tF`V@;dNtVF3 zJFw>ig?M@%=aj775?U58-GNq9lBKRJ4=cmu=n@oY1c|E_H!+))OZS#ehuuj|VC7dP z04M~tnVEZ>TfPoe=+$F&-INOv?gGrYbhr7cVZZv$N-BG}MWG?MPFb z803~Ghlej?+ueSum5ga;2NbyrAW-&s^C97U9q^f<83?yEL! z6ypJksZviD73{w2KVr8%oT`Uz18@z?pRA5PsVEy0t%EVPBp!9Q3p-2wTy)r$GZM3? 
zXKszO949nfnS1ZnZe#xdEqI)vLm^mWDjiT9)LW#-%0CxZyY+wSv{k5+kK2*y%{ z?JKSf-kW>rJ&Neth>*oqx>${E>2jGqZ_tjGimC?oQ85FB0Ld-rq_S>b>o6&d= zp$Ih>BP;ZGi6#;M09{VF3S{3n$}&|6AcEG6+sUkNHxDAzfL%p537@X_f>0+FGes(?zXr(86Q?+-QTOO z<4Cu71y=8Lu5a%snkH#ffszw&#>0hs?dtknUn#txpEGiX>NNnKZ?xI+l6fvpkAO+s zjvRHd1^9)>dxC=zlLDr~-3}c(RiWiBz(`0SjxH;DUEWU*A0Vif+Q_($eND3TaOg=T zZ$xXo3`r)|C9vl}e4g>C=0sEqqd5pUur>?mBx=8{q%^kxWs|RTdS62w-__HJWZBfT zIF>9dImNuRnRw3CGNYXPE zL3Fl6optXbdUQ_^363ny-WSCR`YjY|%JKW}^SIi7|IzXbxmPVAEv1$+M?XR2_3`-m zmdP3fd&#b;s4)UOwm9&8O;QR<`@5vFGU|(oF4y3O4n5!9f3odcp0nj_(7w*aalytV z;RF=7*DRAqyZ2j9CGI?_Q5N_@2wevhBOZv(Ul_c6^XjvLUMnJ~+`YN(F zbCaqhk|La_n`dV^tbv4!D_K{ElaG`AU1TF=DjqGY$Fys^N5o@dtI9ew{XePYeNUw) zAlf#2#`&rS4tnp^$efl;C4N=@um1oLmu}f+onCe^wT9)@cYB!!dvx2~=Fj&30Lb=zADj@#gbOH9 z#d40!d;3ltoT>GkwA_^$+6gz1uQ6ihLiX=(BmV&Sd1`1$Tz6$;1^BKQN(&B4EuSpX z2Yyy9=jwY&-NE>|Rt$#NaLFNd;Z&Qb>vH8eZmaa-%)MC5*u>EEppBgCi30pr&v`7} zIXM#Y1XWYxDkXug|oH!^Q1WC59%Xmb3- z$!jgUoSO6_&a!z-k1Lq|EqgvXvMU~z?}7P0-FZJ{d-%hv^G9;cd%jK3-bsXin(>Y` z@&3Rqu^BHSx{S4EeQ@+Rz2!e<_&#Sqn@okh*8s=kI!XrwhWM2BF!KIBZO-ZB$1cWpbvcnJAjVb6RAy$!XLz=l`|$8@ zX>%h%7(=s|Sa8VwEl1RMAFT1Op(OFf6qiG@KMTvLD%~$d9#|v)0Emxw$w}6Hk;$~| z#NH&csmFO4Snu@GK0|1S%I`Lhi|+krnc&$FGup_&D}JBQsx~s>RmH7sE;*HN z-cMuJREKJnpE(-r1~=yb4kVM_biVSKCY4HS++81tVgCy(tX6`Hh|m@hPBn>go; zGWm6lI;kZ0R5Ql0g}KpyE0EwM^b#qPp%qUY%B~DVf(Uh65N6{3&al# zwVNI1yz<|w&)0nI9HXf47l76r2eJq+evMdw~jmZRqz(3w4OuPB=b9ilV zp5(iC!4t>0fmdfOu4OOh0D3vz96IlkR+wT?RW6VJ`w ztd32dX%$_Z1HibEkXvQKyq729PE(k9H3UY=(W$UuWd^ZvZdhq+@R9f{Q~{_rzJsuP zRJ0=6YI`lkf3-vi2KOZP4%>f2$KlS4lLGeu6$a$5J*oS~$Z4lCuSGx|U-Rn`cySo`oEOTop{37G)db}#o1G-428{NjCk2btcv<}LSt#C10+;gS989H>;3>W~V zRG!hr#nI$js^G4DomJZQ3I)Nzqz3)mm&9~i>V$D@>ta`##l(%>2g*OO7RtH}5NY%1lsg89b_+vkM%H8)<$xvBP=pylB*IqDcd+L_?7QOALVNyl>R{AMD@O zc5bCh8zUm{g*O3@8CADK=zFgv`)Lxou6lqmzL(bcCnr%&Cf2twqeN6_%;}D!p^q#7 z0CI@5g(D3f?$!*;Eu4luWDKb{mE|P*zE5@8lCk!^ofHl4FtW9^#GZ@4Sn<~E`$|p) zi3DkP?wk5;$NnGfl{Q_Ruw)Fz#^t$IVtd(0-b39-{#D0}YzYb(cyFsSU;4>ERr?wF z?6d~HChR%bTT}MFg>wOIbRC&?c|J}1{VuAvQ=o1xYpshjd0F2b)(`vN_3oW=yqG0{+d5e4|m|TTPt5tu7KSAJ}(_*BncvMyF%^G z0nv5gnrxT#U#z21BR>idE)CB%N6>WO=y^!@o>rvNHNDPVilmRXhrIFf9fhn0J-QC{ zA#VIPI@=~v1C!I=dJ-qTh($NI?XM&xzp9M(7N^u6UOLPCoS zlvv`|weRTD1-qmvAbFcs6~~142M1!-ad?epA)|0MjIyF$I+3E^gN( z-=!8I z?@Ozb*7-l!!U=3i#Z>v*0!J1RaLpt;zbh|=LfGdkVJ5l)JlglXf2-~a-3Vjb^Vg7t z5m|1eYTYdXIPoNy`2Dp6Gv>?eh-FBnJpB$fWd5zr10?j9<#vk z_J}5dv6K}jqIuQ8?C@NIJrpU{O|Dd zNSAo3>#BgD9v*MWxBk3To|bgQIQzX&*~^NfzYJ|7>z zt)Eu!`cA0yCp7 zgwVuE)o*Zaxb{;dVbfv6!%@s`I#dGfaj53fvABMFT+u5Ff#(e6|NE(!Bo zkwmt=?Qnl(xae$0;1#}WY0pw{9Ke1vS?430pYT?JrHybpgV=vyq`RL^YSYbr&dLyV z0=o7F{hCw<+hBV&p(4Qf>#TA&^Iz~*_R^bfap&x;AeOM~tp`$hv=xeNbI;pS205AgC%3Ys^>$6DGdnwBo8nJ7#Zh>FlTDLq8~7tO1?1u#s-7}3btIXu<_Q}S@o*L8aFZ7@8OHNaH}N72W+-%hGq+bu{VO?k&l zE;{+3TE#|ybNr&K8+OpvQr8^79)62O`?|DdZ?D+4G{8$RNfK3DRhLvJCuMxHsUc^?#mPyhf{FP8cD|<^Y@uWwK!N5l z5sMbdTR)6U=<&z?@4MwX-Pe@%tsdon7}@shHh9hrsS3)iY(0r{&+hBcff*uIV4=XV zz+#v7Bd;KB8z#}SHs&?p1=X*(=!o`=x3RABE=FIgrLV%9P)RkiKn6}1F=k!&UF=8h zmOxoYmw+|^F^_3ynBt*@nh%Y_CJl;HcrpT7niDovrzD#YGk0(`p z9Sj>raySJewSeInce7>x01|?gq(mWIfFx;t@b+siIrOpZAIYy$cs>t?FtQQ}8WM7m z?d4TE68p;X+YsjilAQzNj$e1=>N{?ozk|<1ujBl50#IGn)+k1x-<)OR#e#V;g^l#J z&4>o~Rox_>H+5(jScCciKhmVM%W|UE9OKNl=zd*z?Y6ioQu-(a6MyCS{=^-j@?(~- z&VFt$9xUmbo7I;NT= z$^I3DjX@os^wP<{G0<~YT)eZW0%%s{vK+$tS7&86()a%W zrj+Jw3BI+Yme=PWqJh^#+DHfSc=6lIkZX1ec6sTDeCng+6i4O!b(a9+*fgxMAP=7O zRNSsj^wP31k~25RZ(0{MxYt~NO*6@b0@)HR{F=^LXz0I4^8zhQjN^L^KOTdMxr2I*VBij^_!OA(_>o}EOZ=9N1T?%j_YL8QbV6>4{ESoVK%S^ z_CCls@aDc#$I61rv4T9sdC9%c66NFjG0nR=P*fFcuZcR3qx6220bnEJu>j+)y7@P( 
zByR7hQF1`@H`AqXv&S9uVhdloql@zBk!e=&D&wKI3H$<+E!Dxe0f5kHeScZw@B_C> z_AQfN@#DvqK0iUr(nZ$8MkCE`o_@!VmXn>WImQXYs-9we=f{rCk%z^?k~9FD@nYT2 zw?jTR;}>h&5^c}i;%4#sGEyo`t;~n@i4I@ zf?G{5L#7HDta;7&M~~J~hs|7;&x1qe{)s6Kka>?J1^PVqh9TnC+FMACK%igD6$18I$IT+i~ zspF6sZUM0!x3b%(#=Xq{06F~Pl1>}kfuXs-XBXamUxme-roh050^}e(`e=AxrCAAN ziIiUuKwFJ{Key29JRDn*BO8?@SP*UWvg+uAAR1oe>YVrcD_4X8DJvVUOrYGBI_7=n ztgT#Bb7UBnKFXYWIC)8FG8M?EIJaO^ac-99CUyMnG!;0D4J^b9u~X4q&)EWYbx~M1rj{LEqj7ZzKW-~!n~Odf&+;4do64JLY*Z;6T97TMzoRd!@_sa9yJDioJ>M|x zYw_XA{{UAX{V3WR*%nZs>JGMRcNu(tZH(m$Oa_1_qAoIH(_SBLu(GUHZrwn~pHFYl zp#2Ek3ZcYkBYl%XJ)?n%2416ugshl)nRJ@1dBy&)kuk*EOk`M@3OGKV0c zS=zyt$Cw{y)1h@moGJmkd_}arN;srt%9jn%M#@SQXTOP8t0=cFiXl;o`>?9p$%ob-VoMs`cKAry;K7g4VMGj-($& z8$H^3&DCRfdN8QE=EboB<3o1Wy02Y**n6?pO8~r0+}=G3WF%tWm1AL_qv_}CJO-p& zx%Wz^oOJDrFH4}gKJ&-TyXW*t(Z)v;r#U>>{S;qrpAc(1mM?&?P%$@g=|8HeuWpFu6EyT=wF2s}E#$Gr5*YvvTC%Gp4%DolHdrZ6Wd?R_Ck#FDPpbvD%J&B4<5Y3nI!t#EB*2Z$bhJfbaRo4jrqIOfjQ1$(cD zt@WNSksc$94DF#}P+RS{CiLrsdnqgwa1P9X0I#F;jz;q*8NL8rTzIZXZJzt7SkVRt1%;t~3XUheLb+0J1oItxE8^$O~dY7DZFeNXgq2 zPFMFo$5>gdRb5~mZU7yFn%{4WcT|%jO^*if8!gu$rrm3^apyq=)=ZN?m%WBaPeY@Z z>wRbGt5uDevyGUJivq+hSmnvP@|pPT&R3l%QGg?Rbzt&$&&aRE4p%c=hpRnof7hj zi3F{6AP`4MLGAPxt$!am%HT8HNJ8jjVgzgUW9jDgkP@sL*!bEFkHW)i9Hcn#nF2{605-|4By;rnPD}HBU*)l4GGx}`kC?b;S9>-K$|v7@1a73@i)h)bPiZ2UShK3e0^Yy@=k6>%u9ktbFOg%Y3WrPmYw&pGK2|K4 zwpL}ng{_^m>aoQ-$}F2g1_>^`VZGVZa|`z}>im8VDC{nmAe#a%&-Iwm#wEO$$XN7P z1<2nl;f6mS!bjdn*n&)i>3|G=^LK{t>pTlwnq%M&QY#XdTY+U}`W~vdt2>NXU`2s8 zTYD^$bSYr8QUqqY<6{_nPu_Kr#{U4iu6UK~?_#)%mUUn|snF|W`o2^6X)0aVTOc`N zd5zNkhYys(kdEkuqc&@RE;ajI{{SD07_zfnl}emSpvR-^KPlt!^e<)rDZ86lh1rXL zKCLoy86BE49}@w21=i$ayxuzW=`tf5-pT;s`TMP%dONZ1ubX87D$Gaq{(BAJViHzrXLpDHwOM$%$Z~0cwgY%O`?)$=sNaw+b zIH+Dzo1@C|5oh_&3X{yW+*TLc%8 z9ghmwI4gW&Sk^yx*6)AQc7>WZDJt7rFt9AjdJDbj;av{Typ#o;LkQ_Rle!x5v~tI~ zOegXjv|o4Me%TZ|Jn{TrY2Pur;7}ZI#k}9e)xY(AQ_oxVh=58Zf&?}=IslI%`}um$ z>nT6&tn9)uE?OK$?Axcuvo|3pkHb^#0o?@RB^n#q%#YvS{APcO?7!_9Zz11D$x<6? zP|d%JLbfI)4jqWUO{C=^lkR^tv{rks^6M%olTWs7L3hc6qJV{hqqpBhLW_al%@D!E z2XezH7R;@?o%wC%b;yv0CQKeoZWZ}EF$(CZyp160y*^e+KSCsq8*d&aVdiphPeyJC z`AW^pEqMO`Wx8TXvl)qqg2ZaMDAbcx!*LXoDwb2GE5fCF+vHy5OgB-pD<<|m2{*;Z zdyl1q% zfms%8RJ4Gx!L7zq$!{X)bUd9h*=+A2iKGZE(z3?H zACtlNl~(#>DI4SadZ1yT`h!^n640ChzSg;d`LKrXopG~NDAU+SwnzJ>Y*Mpi{|fD#Fi zuglTNdJ)r}Y_ME`;7}UjmmHP3FIMUAYB>?s&x5!LV;dbo0^}7uH|XZ8zTo%Vr}TBQQZZZ!%D{08?$5Kz`yM`A zRtejC^I3W)rO@{2^6r@zcDDclXF0XI^G`Rf-t~T`XWi1k1T%0YhNYg`dycbtFXuTw zS5Dwe|;#LzH>2MpqIg0(hI#iK6x-y;OTFM6zT>9C#y50{Sj&(P^lDFWZsDy6r2`CE? 
zPFCzu#6FKszH&?EZ*Y4-(Ek7@3h&X!e7wKJWo3}J3=Ph|D6CtPZ8h`gE3K_j6c;xn zACq71SEf9>QL?~E^PiLcnz>>{&;lFe6`MRwYlC6t*7Z;Xe6(6|$K7;fr#oDHwEK>J zdQd)p#anaQEm(D&L6_6ArgL*dTV8w6F%-|uoUNrY-jZYaFH~UfqKXp(j_#iz{{YUl zx$@`Rm%=GiTjTsG3v>884SFpLFwsfu&|mqqt;?HPNxwCL*ZDPW4Q0OxmiGggRRgeI zhjSvHC2MwQZLis z195O`cug_XyP8uIc7bZOCDU4gbfyRVbY(Vrd1+?%O*!Z{y6eqTmcMOgRi4*Q`m|0h zYmdjRJT=zf*Oh&5MZSAlVFkX&TpobhD-JVoFuajon;)kk(%YB{Y;Wh$jqTzRy=^6DU%%FNfX?(?bo$m^eO-8-~& zlc;lkHNQmZe{0aD$l-uE*z=BSe?O3xbE<|a6m)PR=dcEa2X0v6yz;odI34dl@u|GFQX_>|Bp@A*M zZq02-3ZhU+;XJ_9@3g(cgOd5lP~XW&Xye*aWd(=?^R@BChr6p)cS`jZ7^Gen)3jHW zx2DdgoMb&4)zR%LLdIMYeA+cn@QPV^?mGd!sf#ZUoZ56ZQ-7ZIip)m7YR#t|Kp%t~ z3K&=&2q(;MdhMr)mkbt|u%LFP2Fw;E~lRc#%){{U@f=y|EQ#wJ%~Lx>{!UcH#3 zAyt>7US{kKZO!xl0Mk{+wBVy&5$CTqQtj))$DxhK3H%(@=j}0#U-QhuFT9?%Nt zdHi>(O{z z;#KT=OWHve^Hm!&*{n~_2-2pxW+a=Q)7x4YgU|S>ussG~P1NaLU&06rDL#EZQ&psj zeD@sJJk*V$Ly$?i9h6DSS#usf*%?U|`U7vrSM2M;E;~894I|+dx~lQi=Wk#qPzyG*~D*7J&VVbU&ZZqGFg35D!G9S(Fz z_Bs#6Mw<4EZA(5tTK;c6c`N}8HNQQ2+j}_|=3kSG`Ss&FHTnMl9lxViobvKfABMSf zHtaRW=dH#LnuAbz=po1bYj0t0dg3*=HE0X?dumL|{Q6gOHN|c{ zm%VAyqQb`CJ#n{XaIw%+#fK9(mtC4yY!hNJ3b&~$Eb^S;9b#ARh<7JpGg6eex!yt}@K z1$hx6mQ;y#F2E0#=A0$dIbL4eDY}LjE(ydJ0QSH8Mjua)$7C%ilWAq{sAdD0963yX zkxr4GF_1TR8898^W68v=-|5lE@K)H0BxoM%$C$+Ct;06X@vNKQyz70(nF=-#PSL#y z!l05=0FNhH>posfs@+NC>t>mdvH)F&9@8P_EZnB*FRbeKzni1uO__IqfO-f9?~(_i zx$*M;UU_}r;f#y1c@TwTaJS(8dOOxy`V}cNw(-e2SSkxfo`f*3vE_VUl^xCX8d+69 z2DMl-5g6CEm@m-X*0k@hGLkHBbEl&U)4H_^8$5-8DPT1$dA{8Yt%)Nnt#B=%{F>vm z;s_kWu=8o{FHVf?vbl`OWe!w;G;1;Tzr?3wUkOnin}h}@o6V&?p^+dg!s~uLZ>jG! zdgo!-wriOiOJ2pdOQ7px^1o4DbZIP&gi2YLj~2f3b~v~5Nb7=femwTZp033jY_J58 z7)ZxG++dbwW93_?$#s8sR5WU?E(T;+_KaJ~%j)-Etf7gW(%_h6X5b2PY)5}++9%~a zc}~{i<~2qvjHf1*Se@eg61sH%068A(tNBBu&BGxu#IK3fwZ|YlyleT+OOyU6-xUnf zCT!;5jupi4zV9CWWQ`O~*=+_%j*Q4M8hI=id%Q~QrO{#8BrHv>IatZRM+M~P{NMV> zaz`0(@&V%2E|pPVV(3mz3)% z>gp48V-qtHGFz)TAN_{oeh<}CSK!T|A>QqhsgRM3A@?ex+z8V~n^f7SsoB5%XPwb( zqA0|XENe4ZIT>ZUzE37kB-#D!k^OI)l8v9Tq!6epXOYF)O zI!C8YD15BmK5!X#0mvTwvU{j~$KFuL7D(}suw6-1$CMy> zNhALN3Qz7+C^YpYb&WGT+{y(n8mE=t% zsu4k8%-FnTkxOHf!^?CMPF62@O!z_uiO`7mb6}2a%6Pn|DJ~N2a>V_29`bItdmkrJ z?GZGxhLp3$RPOX_^2NCp+4)If+TrUuzFIGL9dBH4#k_@Qc&{9n-4{FNB;Hv%sapv6 zMOX|fpCV%{)Cm=-{w^M=JD&Haxn;=l^**zHB5_|)M|lNa15>CGH-qs0dG zI`8|unzpwR?F|Z%8j=E)UUAmvxo;@`Q^pN9H?F0Jc&9c%EO~wA6=9MoArfe*(kAxT+kZ_f!FDSiKr~7L zSj!MzNSV44Zt?r5_j|9q-B+y^I}{O#K@vM7GTcZvt;MpWzK6ws-g?n~D4)DXlK&eB+`F#(WWB#BEZFU9wp{pTo~&87{Xs@eRvrEldck}cG1vZoLu zjR^{X{{X|skNiJFtopAfQ0+0ZjDX{BoGu76;P;;Lt<4PgE0K`lnwzwj;NC|Ia(icw zx>poWfOTANVmqo18B5o0{;D1Csq&k)E%u3FCNZ-BT~$E?ODc`rTa@x{k@cSQd0Dh< z(nS+VC4xJHhCHDv8-sn^cj?LQD&5jFMdU)aEt3Pr*q%yGjHwJ#c`ceq{p6msgKWiQ zNu{0TNcynA94<0{j3i6T@P2d3?Gg2DnRz9_v;hK0&$*PdgJrs7mT_-?clo+>%-e#1 zSQ0|6g-CJNadGbwZyUN^!tuO7#*DG%#K&?EDaM&kjuP(Dn09$uT<$OgfQPJd6`L5= zE<7WT-Sxh=vq2(8$vH^!0A*$ejga>+PO_Mh#E~m3ipCyP0=%1^hHm|5BkKOPsdj0%Pt8Q82|tZf5dGa_~hMwZ|i$2Q_I;l zvV<&wvLsk2zVhQ3-dudI?`rf|C0N;H#YwXP(^9R5GDFjkr&2OV0NS>ECEf;Qa2Zq# zuyluK&33Kl4=)r@N$&YRkE9~mekmT%>4frJOv32=BY7@S{{X!AI=)U_-}14RKH8Ni z$(UstJJ_6G5^TYGs!`cCAdNh)yHv8g2m=z==zZt9r!d*(2=Q<+JWi-1lY=#rG~nI7 z^#CP7+P;7GcVeE*L=w6p$^k45*+u-K`q_T2DC7FNu?jXsi6+`1+LaMddOA#(YVq7bH)K1$G&^-jc~^{3I6~wy!x7M_`61J5H`+6(*vDJ5)eycvv>xy z(%V({syD1~TU}&c{{URRN7YCF05QAXRof=&eU!i%3o;GjD*)<%E?IJWpRDzAd0p*! 
zWo!(v zhkNNNKXY@|L%Xj!BP7zHk;T!n0{So`m5M!{F1l|YGu888i^54`kA1?1#Me%s{ctBN^&P4Gq zC)XsBFSlHJh4|72N4O{k@jy`6MZ@cTY<^e$nyE%cVt^xUlGW`4sX* zkrU9*6%jeX5QF5g-EQ}uH@&UmZEfW3zXt4&2gWu+u_FM*akM+L(tmSacFM~NL#%k`AI$BldfIhP^E3mc!103 zOC3q^hE0$VqmvokMDuiV%Ak-V15#(AWR{O8b)fzQJAuW8E z_I7;;55_i+CJgVXZmJ4EKr=0zOfkP_#jo=raPttY3nd?2&9k7q(w;iP#9V(g>4?)d)z z-Yd|F?;K!@D)fLF>+E{E^F9yR1i2a!9iAv9lg^=I(o%v~ye=p4|dl z^5|ML#hCyAi>oQckI|RMyZJ4YVbH$vO%)o#{++*}Rm%5_F|y=zv9UglMJl^YID+H@ zF#_XXE4j}8gRAm$`7) z>l=8z%2|yWu{L;ES~gX-C3u^4`5#?O6;9E#$F*$C6dco5P&%8mdIQdxjoIWC?*;yQND>FEL_ml*#5 zZC|r1KR$qS_J1~}f-H^k#QE|o{L1yVNN`^{Jm;MM04lu}@bTk~x521@Erz4dTWj`K zODaW*wX=N@Nwxm~=%p;!?F1_DR$KYVunONl6U%jP} zp6)+G$C#Wb;C(K(RItr2^%)x__bO6cBfDBs%i+7)^{{R=V=)vzX3QsnJQRpWtNNL&4tXp3-DdSZ? zDCedZeq>=`r~_ZlJ$0Rk79gJ82-oQ~TZ2g7WRL^8FCWo2{!4iN90cJ-?CeMuxJ^6l z=u^Ui22NXB98N~VywUP_d(KLYq<{=lnDg9ee@B&CW_Ea23~!j6v0YCO&TF$Y!mB1p zU^sY`+Q$9da>IG*=>T9+En;;yR_7fqud3j!(Jh^ z^vK1?b(PiBv&e;MtQEDI>ls6OCZIIzX!x8B4N;ogT7g=}{V9S&@_73-i#F?yT8fL` z)*rAQ~|Hj z2U~$-S^*^G&GPH{6!58qfwQligK_pKR1p3JHx_0@Ft{oPoZmSA0HU@wc%dW^2YUSf z0AiUcT$Mi)c*i7v!=FW8%zWM#~W%2gMaVSw#l7><8Y*0`SkC%tFJvE zlEAHxpP&1+Q3Z$w>tU$}hv@KsjsAAIJ=aezLdC4c;CZwar2t?7*4MvC8nGfFjrq?p z16t{n02j_|Ny&8SbeUc2TXWJG2Nn6KTrswnzh|DPn*#hdx6TjqQaa%i5G+QT`RXi) z7DE^Sk=SpgIb$k>6PL_-HB?(O8y91CLUUbslt^V*Ph?EVVO2R*2fy#2I! z4?PWP6^fqGdRuUdgHK`z^O~K&=Qu!p&(1Y9Jd)bsMZQ2N^0DXrpWUf8 zhU~Gu`z$LtblbzK(aYy2ISh*kVqQ-o_gnh?brG(T&_B`0~pK4VEI zTh$d1YaVK7w>|Xjw>=?jPN!crJw>~1@{b8hl)a$?##H|RrFm_$B1Bu9_Fsm*?=(;#bbL|AJI!Z{M9BVhK!HJSSZIe z6#!f2t-qh}8saNLj%thl*N#w9NaT*;b!!2BexC1Ue<{7J-HAbUIH|D!jMYb!&%eC) zR3UI|CtMYl>R6E*Y01d1(27qVGb4~#I-5jp9v}`k3;y#W=bUo7dZ~Tqx_2olmqkgi zNN%enX4eAz1r1CTY~ZVqr?T4IdRx`6S9L3?cv?1KcyD{#>AL>_Q~psncwYNT802k1 zB`im|eA741x#_*9-rYebRyGUT6nR-1fTBN6Fu?RwbFuo$Xl>@IOJ7srX~=|g+%Xz3 z;WqSqesZLWGz(>j99H%$LHc>g{vA@AG>$Ktr)7(NMKxnA098Is+pGGx{{Ygq=*c#|9g--}ym8t~UdN_w?$x^`Nwjl!do62$cmb~aTK@qw{hCnXd zLr+CL-v_?l{{Rk`PI6bCAllplP?*$YMY)&OVeS61nejF!tGkh*W5T~?I`_7Ho_1pF z5-7ns1vz^$LeIJ2kN#q)_^f1cGqQ^Uh*lWKd!w_i`Tqd@Pb))D8Xtm-hAN<73EoOT zP%hu@>FwjnN#K|gPE{Vjw+IL?A=9NAbT@JO%J)@MoJOni*?%XAX!3h$rSu)OCy_47 z9Hfs6F|Y`P>g)#=+qX;8kKWQ^+4jirG3@b1K{zuDxnt3h_rI#IaOZ0RK3*_buk$K6gaH5SP`JMr-*m-|_SMzXmirI!8QFHq*9DoC3ikg;=yX zaC&}LPgm8~gf=X=c-hVSaa;YE72Ge z+BT5H%^L3w+XdzfJe;21=AE;^^WH|Zah*?Z2ptc9>H4awR@wHaFe;?dYsY;ME-56> zGJa=w;U}Ijm|8(I4h4-@x(;$gj^0X7_>bNCZS%_xGv_ZWa(YS|ryBGki18{$L{H;ELmgKB^*{KD_e_{Unf)zmS#j<*o~}2vJ+v7 zHxhLByNP^{uJZeeotGwP?0Bxk-xk8_)04?>J>Qe^-g~!0X9prND==jPB|Q^k!w==| z`D*fKknIBTF(M-3b>?>v`L*8Gcy zXDDvgRUFoAqv9MpxbaV{o5_9OOH;tUNrl32S}^HqM>z(Ipac38tatl6TM54&Z}J~1p(&AQ)5 zU zEP7Gy)`=ESJ3dc8(X88Akz@?zT%Zb;u?Lu*>Z1F)Q_rHafysgwP9d&6-&>tv0~-x3 zZqDDVTImK++<}VT*5}iWxk@&5!J8}_)+=EmuP8|&S0Td-&nN!?W7T`AcGU~<84JyF z9h_b@&ABd*>b8B|K0Sq^*-R;9lpN)E)k#I-g>Yx(_HXxS)zxku5}`rLTQPi9mHYgZ+^=2k;P-z*NhFx4XsjW-Mtfe)cam?h>ogdeNIGH(z5QC#qf}>P$hZK4{O-DPo2>k#QKoyp1+s;mT2P7#u6yj5EO0L% zkYglXntkt-rWn^}k!41+TF7E`CMF{#UA8&nPP3=`x_t%>tY;*OE<-Nu)p3(`W=v8q zCGR(WQnzOz-+3Hmy&QmBy5;G8XIbC^GHm6LChV&+ldBAwV^+dR8=)M%Z|b(GJHSy9 zUS8H{`Et=ZnK&lAOdop^M{{Zd$>)P4Q&Leq1T!CpH zguR51jLR49{;%++9hz1D0BeO=0^ccPDr6w^RrYDThV#~Xx>=oMkr|^87`=mKyWSAp zTQrLt+s`TMdat5(ow^B-BbitPA$0?Y$azHIUOxW-j>bm>?~y`IcozNae5M^_ zo-R8wD>}FbCoHJ;QDD7pNQ9BA`b2yA}uv>xg1be5VIr<8>pB9mqcmi^8WzLc=_G4q-#3GCnp(fU5=Ml8NvD7 zXeF8kh@ep%lIY-osTz}>J3kLE7Jej>I#^hkBVRQaUCTccV6e2Qi+MfRMIP>?v^6-BoM-82c*BDXsW?e`gt7K;DWb!i0MX>FA zKF-av&Oj>Ute_G(y+0>{RMN^xfaWnPlR&-DPgyfRC-1Q*8N@6qnZYrR6HASnVGNmyH9WC6Snmv`78pI9s8* z$CIx+YuKQNXO?GdNf&_gHwo%x-0z{ob-9~;?|&S5!0te2-L;I40a2jJP11hrd&z%C 
zdz#qOWavXJnb5_Hg4nht=3cS5(kY!c->Rs}va+jbb9DsSKASat55(l!U`axY5HQb! z5S)w~)#Z1eQBljWN;w%uGHOw+i8l?wA5A2Vey7J5JIiyzF({5jia@x8COfuqg2SzU zL&xFNP_`!ITN2+td+bz-yM_sg4h@4_fP=@y>g4hCu1sZ6Zf)|7Z|$W}tA}frJG7w5 zZ~-tmtXHkoF$P_bCaxs`NhKIuG+R52u=mANM8D}j86SxqO~K{SX2^$E{9bH{%(`_##xCst&#D90eL%*$`&_b3Q8{FoD_h~-Ay=OA8v*rEQG#-)?A6u4wQ{flWN)H zQsIlO%Cd_xx1_kp>A{DV-fn6SgXe$;f(Z*C1*15HvCB7$Wsfde2vf%Epv&o+^wr-o zt)acL=hNG!I5)E~QVsOzpR-V@C4d}6fJ2TWW7f%wW$6BM-SgO5_O_z^D}#G~BmU>| zx=nbj{zbC7uq|d@b0X+lPWyd2Tk~S}PG-mCJ=%J?2FOCO*2hqOhEx(eWEMp#TH_(8 zZ8{sp>OV97S-6>8@~L1I3ylj~OqDt*SXeV+qY&Ko=%V+PQ^zSqioQ62p=>eLwYRg= zhsV;}F#^mPZ+&<;eT+I4jWp&FqT1XiRZl|Ucyggk6Pb99rAHFt23bcP6@xievJ-4Z zS3k+#P{?E57HO4HFT6G^$GWiae((L}T2Nz>C1QbzzY_N1q;*K&yRDLVqgEOuxY-qW zj}Y>Vz9XrFJfVoRF7_&y;}L;C{f$;bF_ zs4+`6(T++ZMNs4mmX_S7D@~VZ-NLULWqx8ieHDv7{{Yja4V|48@)5=tQDTGJoxZo# z`5&yKo;G6%iat&|m5$mo5r;|rH2c|qZsu3zi-tc5fj1k#9rXdeR_kMqU%TaL>d3Y_ zYt`D^WN~+BhR+in3n(Kk(`Y>}>GGXyG@Zkvp=8gUi;Y*C1>EO^dq4Ppq&_$}6MMaY zUS&GCzaP*_s_5pp<5=k#*n@_E(;7R-!CNr!t*7T?PLKcP-6=tA(DbDMJ? z3K)Vd%u3wg?!7P0XHqI85kR2xDIz;)G3j(EjP6mNAfS2A z5b*v^wB{tW?#AFZYWH~K>8DrMQkphm3#eOIt*&uzStnKKR!Qim_^XoW5Vx;S<~{HnL5vuUuuCkLNS!^;+V9E-nLm^ z{hdCu;}vObT2HfJ>a4a`N*Y_GZwiPLyEZ>vE=vNUnq#)PEG87?mll)12k${ zO$uB)YN3xuez7+5gLV1uS4H7H0R7UX$QqNV2J{l7lHybV!0_O2hrMNbNZKT87@baF zu~}>-$JSx${^CFF=%m^db&HB?jk9ZdO(blc4<6MO6md*|i3xTVy9;wK4)cq?{;IKD z-xsiM&dk2ga*5@Yl-aFvPVlco`~Lu0KaIu*5e>mh}W=vnyLl8faduHQ?~5v!B_>qH)uM2 zsZ;fJJZeXrp3}{5W%#Mz16r?A5f?LyY0gO%Jdc4TR3h_T2@Cd?u~_rlQOqtb-pW)H zwDL*Y5{~QJAGGmt)^*E$+KKj%GAEzr){?!H=~d4;{5R5;v4hTe`H24QZbmKksv*n{ zpN6VXMNbld2W?O*c*iWc9qY28uB|d?VMelBw$(l8fvZbugVvWdI48&cHN>|+j~+g2 zM`jr-gP)R+Sl0%qZDPwrW&9N&V_TX309v$rl&+NP$L! zpR&}M(&}_Nj>@T$$IbFxjROyw^vErIj?ocG4em!_JoFI6uGpK8ITh=CDYB^J8hOo* zH};up%!1-oJ)qoty8Rl(nxa>s1PdZqFeGyynxmC*W?O@QH_7yC=pos{*_2`!Xh+NT z7CQ2_y_*SS8WOtLbd>u@zr3o>4BwZS5tx9Wl6ips?Ob5UXQ*HXSqx5!&<0o(mwjW9V&O%ZT=`{# z+mqGMHp*hgO`V4Cn8+>&litxv@FKF1)@@anoo~?Z{&pV)VpIvrjqGwM$XJe#67F|x z$VT#YzA<;InI|I4BQ2qoK{huE{8)Z}{@rtJ7gNcp1b2>Dk5AZ})lxUH66*3~ia1TJ ztMa?8+ZK(UkWNH`OsBQEZL=e;jp4i~*py zJ%^fwNj%3@7HeA(^xyvH`LHW?_Va!ZrH_x(&C!+xg~2V(8phTXz^u)UvH7{J?ABb3 z1+Vg)YXpyMH^LxJo3+W50hT-Vif!I+?*6ajqO`_@DwoI-dolpfGKNqO3Z{iY(ED#e z1OiK#ZzcLp<+798-ZWCf58g=^)WcT8u}v>n!9e1)V)pLxv-kXRo_9lOV~GLaBstWz zuEhNYO34C-#74O+tb>5sEX=A&Q*4qK3KP%bpkpIQF^a!MfzD%ivt_OT$f zfb<{9zr3$s06{V+`aLW99B)DcD6n9Pb;MlieKcl~DAC)+o*3iFn@-GtzXnyq;6Hc1_|E6^KYJY^^MsKzU7dj9bWF13xL0!&JFGh2 z7@9jIK#B`YakjQzC;LzJo=!S~MlZDs+_mgkK}OexG{d>aUajPfWkAJ&Ry)LXg;wW@ zgw<-*G;ZOUyG6(0#Zj+ z)4QBg;nSK8=ac&1nVUw*CY!P_Rl1XrBTrVmA9dwwIku^`a*K+@2z7i&#z}9+&^t59 z{>{+dd0$m0bI{q0>$5boIF+nQ$9@4})n)6S?zeTRXNu=z#*y8aDj{Yz91X~j+)`}o zpIh~Dd0J{B*`zQyg^iV&x$zsN-^b*+SwGqS@|^O{Hn8r1k~tmO0_kqAVc#tK?g>BI zC)xIBrT+ky_tTbzQQk?KXH__zMo*INpWXaZANalRl$YjwC3Eda?ytg#stoRW8^0w6 zepVm(c)Y(`Op9hoqIuwEk~c!im$Nw?1(SKzVdLfV9!W7WI<9I)IoS~lf5qtdit??zlXn~SYwBYD)jSQW12dVzsg15AB#}Ni=N2*g zoViLTH-bh|pr)X97Pj$XjXLh5^>sen$eTk$a9Sg*r;@XIX(VqeCT+>Z96$YuH=UU@ zW<8)rO6s_yC^{(_{9<22l;r-dUaQTmX_rdG$YxGC5slP~4?;MiUWX%FKC7IJWtEDm zxlqNVJ;-A3!xvTc{{W#%%aChus0YvVIQ)NUOw&jmg}@3d&2=gP>Br=E{{a3!{yDx{ zh*91d7~%sc!ouZ&U+X0Gl{~L0aH7US*%B?!^!3u&RW~PbbRV*wU zQL2{vQLA!adH#hl*Ws|{J)bpGu}8d}Z)QljR|MD*8R3iXrrpC+v{I=V09B2b9wNru z-V8kr<$9k71=(a9(B$y5B3kA}D$qci%zC`{n@3mur?RmT0~;xL%Qgh@HM&{L?a9_} zeNO9pNn-1HZ*+3m5p@gd(LAPjFDorINPZweXDr+z;g}wrh?qXkl7Gw1^<8H=rkNv= z5#=gdEaOH3#dwlVjI%}4fA&bf+WpkM&V25mO!$O`HfD4s_ewaI9xsf^>LQary{L>5 z;;kS70m;*ujz@XQefN{1f7hbAuKdhqOGbns_Td+#==ODjX7k@pY@BX%0C?j-CEE5Z>i;K zH#_DMPAwZqV=}FZoJS!K0&gFWt(VJsz3nsVgoPE>WHw%KK()}GHpj%ZNRsJ=l;CX{ 
zV{I7nd&c{95|qOG!X$eX3&0!L;aNFP&3$E86f-(5>6x1MX9ERE6!LxF_N~<3SC{_9 z6SQp3NL8_lF=b%F;qSbaW%rK%0M*i`Uai<*2*p`$13#?x_kLT}RmK^&HUU$6lV;0z z%gb^xr?b8P0BGNGd%v-)z-;=$0040@%^u?y;N&2CJM z*tM6NchFCBhrFs}l+OT3M(qgRZm)>2;?694pRK7((k!6E%v9rb7f}6fD9zXKb>Erf zxyrK!k*JuIz=0c{7~B^u*79uM^F3znm$SrH zbrNKfvnpgbaSqr{>nE`tY`Z%`>Z;67XeES&(DHJ}4iCw{{5M_6dcN;Zaw2ldPS9T% z*kK+kg1oZk(Z|Yd97Gy2NR#&2t7WU<3z6Gv_m<&IKFr*|w6c^%Ym>pH5cU5#Pn z+@vcik|TEyASgHRytw+>QpQZ-R^(3Cz=g#+Ns~c8BR3_~$v?&SooBi?PYi1z%nbsl z+8uE`JdWdXtMO=Gl>OYbZhJ`unH`PHhlynb@ayF9{{V>fb+Xt5I0>8_gpdGlq#*I_p%t8>Ys7l4QJI(~ z^)Sy5i}3kBtMz$(4#Bf*+dOfki3yWX?t}*tF;$8+j|}7Be_P4^$G~`&jO}VccZLu* zXu+7sr1fy;{{RIak*95YN!nyzi^4*&2#(f5^0RSa^--qd)z`f;?7L3RL9;66qQ?@g z8QsqwUc{eS{hmSDgi~zq5|?q8AWfDam_G9?; zFBTmfrQUJBkHOCuOT$}ahC*a@$b*9c-d5s{TpPyo^h(kuu#Wnjd?c#OAuF^i+G$VBzw;vBUE|I3)^0gJ?(CKXpVKp z=t=CqIUdK=)oHTk14f4n^p6Kiqq|H~r{mOm|u_bw%`1V2>WLZlgOf<&5M^d1MA5JfQi+=KQnO>&Tu9))E2i%Hj`kJTWhWku__5KFrA?wIBTd}OZDuwN ze^oa_b?Ui2=k%!FM%*$sBv@t9wY#rP(%4CbP%s3cz2rfbqt^LKVi<|un@EZ-g=|9G zti#^YW=W%6>M(O_9J(*8oWRfv^glx4v7Wa{N>lyd4o+aH8sOQDI zyne+Eb>R4T&7q zH#hIUyQw=E;#B}M{FnT$4*dZv(VaJ9`eHTriTSGLSqN^VqJg* z<3=0HDC*_kf4boQx=;FoXd5s|B~WoX5TNCC<3xuZ=au`)mUdTEmuNT!@(aZ)G2z*8 zMDn~`r7CTxZ5)W=0l-yT-Q0DOKL+j;kM;d#?jI&DWQsSMLyU&yjAJRcXOiHOR+2rE z1WvnY{{U;U1Ub25%EixE448;aAHY!QBTJ#;U;E#Fap#0@- z!Xb>c(&4eV0o{%~pD(^h;)9Qmmy*2`Hpdn7mn#Fizz1zQd!$aUr|!PH&7O8=%0kYC zR60i1%&9ypmrh;|Ka5EAf>@tG?X*fYP&icp<7VFe^LzDlsCJGC9x?=i;y9MUW9a_? zwCxkfyrakOIZ0;T^_|RDl*qd>Zsn0IcSUfACi1TnuO8Cz^;_NdE95-7I#&B+V`mvo z3zKn(k0u!s+@?>;JT-K0G8IQ#1H>4PB6@jCACstJ+V+U&h{)nrvnxf+k=y=LaUV?) z`F~sMeJf0xPSCBQY~jU{>Lq3i!zS~i(?|M^o;6au~b<_o3i;G;7*1=gO?W_W2 zeF~F}mi#LH&R5=7pDYC=BWDU@bDKh=iJ$&rDWvMX-}}6KzvVsx40es1hoq!p$jAO= zv6KG*U+AhjK=$9wY@--UA{#JFvR1)54;;}dkx+!6mm<~<{;m^k>?NW;lc1J?dFDiArT~_0} zC$YMvBZs)LR$h{ksJ}t}s!C4O)2KMvQFygE7Aw|m8D$)u{!^^}@)(HYGM3%qREaN& zv~SjMPn^lSeyTbAr?}qRwjE12^5NC!bCPv+(H zU;ee;Ww`O>)78}tbrq{ETH8ILuD6dKKZe@6)opgNw)MnTRkbm77&N-u(9|AP*Py)> z4LY|f8q3gxO%3Y_tfWz1y7U*Vqz<~W_0$%cA09k^kxuHay1KH4;WS!2CS)Wm_#+ADfs8 zVyLxlM?9-S+vTG1Wyk$h$0;k( zg|a--q%kJo{+eR1$(PB^OuV4m5pN!fy&Pud`i{$3e|E9i9XshIHLxF#LtPAaYjY0X z+tsEeJV5361?yu$eE}aOg{hdxE=-Ok019!A{hCuki)lm@9(vO{Od}WV(wG6a$xjHL zVM55Gv-DN7qHRlHN6Sts4`mB04$8JHMYcn{ftOo$X3 zd-Kjw_)BJDb6eWPn|Xy*v91w;WAfhB zKf<{fm;+)IkZAVb^1K)$&t#Q}>YGDfOR~bB6Rjd)o5c$;)8>O!ugT{J&RyeSZ0&)Ht?)Y4?bPsG2w&pPpmCoAh?dyPdI zK`$+ZM`u0xsgpyAxcnpjsvVyte$PKR0=*d_%?>uzNj_}7jJ$sg_V8`xT8i}+dehRZ-n5*$>J1)8Z0M}W^$-W=|(U)kDNf<=2k!-<}t7PxIioIO^@|m%^ zvly^rt;sqvHMkiQ_q0#4i*H~;im!05z#8An;QPv4&A)cc7}*85POPoM+-u5yu9~3} zhLTAPX=0#cwz9B28Fps!RV;wTDl(9zbR~I@_uij8uF?!-xE9@jHB}rqZ+QOzSG=W= zhb+`wi{@^~ru}Q;zE8)<`U|7Ps9}FNAkmYZ@(kK`{Ivv9O9K=dEAa(k%&pOr7RC2( z@2cnn%PB34p?j;4QSI$pUnd{jY~SmP8HOgSp1=kUL z3HPZVTlckOF#ram`)r0|MDKlLhG~ zCESziDvjnaD@wvLZdFZ}G2~WBKAl0fZR0Xg175`qxSGSReEQBlstHjTCOgAtwd3@@3)9UvtIxl_EY$ zT*w`q;vE1on}topQRMqPX}j{f{{W23>b)f+1!KC6q?JZIWpG94rcbi(r{Bo`0Pyrl z&egmxLmclRp5taqG0w=Y;Ii?_)?o~B^&M49{{YDRYi$AX8!DAC#e7DGhJC9K&O@^O zU*VH!U886Vu)+U(N9T#mDBfOr+Et!r+wzY}@=TwbzIStnOk2(EOs$Ly~@94u6Q z8xX}=vh-Jz>b~-+c8)QS!uHFsUY2wuolnWl_OYZk(p}ugY+-yI?h7=X=pfI7gUfid z+h@XCgm#TtjrVsNjq&R6RrM*y*1&thYY1FN#OJ@|>po_(!ySDp}e0F(DeL;&Qg=@9^EnGyAzC-Oyv# z1(i!V5vb`({5<~vOUnAX^8OQ$PQhYc1yhK%*^6*IzizzwHkO_8+N3gp%o^LZ9?S`f zc>H;@Z4Bj#k##m8_N#FHi2Qvl{Y&YaYJy01ZLUdP#4a$R<;5N;wPg?g0Fh7HL{|8m zWGR!BC67)=Wd8tUoBsgH&+hBS*tVzH-q8XtaUQ$Jvh?g1x|7Vcto zF@_`foYm2i)n!tkksAxUHM9l2TKayc-PMZwosA^E^<9fP z70BcV57tJLCtXu6(Xql@M+?jt0A!F^ixIwT(tJ8`ZvOz`_I^v0-*TI@^XA&wHgn5y z&dQ|`^^P=c`*yEBYc?%ywRC{Ze_W*oCx>@hABe=-_w=VnDQH!Tj#46`cvc^D$_ 
zKOP$|tDF8?ckd}iB>1AuHe%qkE|=)>x5e5?_o+VZ_g=0}Pj}@tpsv}Jq+&qat1&K_ z9+I9yxVtv>hL+;Ps#Q8bk?KYLQU^;vjFB*eeb;+0}dL#reu%k8M+tp z__xLUl_ofZ!;-~FJWSduwcB&J%hN>ONi=_%^PGE3^2Sy14=TDg(EJEldK?I;$|vu) z8_!cEq@oF8eVk6=i<0&L6*$Mq@X6>usp|UL4XDo~NdX;a97>Iqz#K5TFYi9?ryEp= z?2(e0v1Sgzq*o=gm9SKfJ?C5g%e(&FU8+c8B|vE%z_YIg?Rb)0I3JdmdDYDyDMq@A4VM^71YY`eIKXC+{hdnU!6VKog-R znL0+pibRj0)2%2TRb046AsiA0gWh2xZ2WwmjA)l-RFEo^H&z4zn&s+vH=M7!{NKx3 zw2H?&D%-g`iCY}0v!FvFiCcwZj4^y%^rz2(_Fdv{Hyh~wvy{}{k>*7zuNXSw z6b){^TMqnGow8VER%rHM5if}!lO*|hUXJT|TMofSST(|^j#*GJ`l z;r_6rg_uOHRicoQ0KJMEbw3?%+-`l{S1wb|ahY7K2vQcnoZ9{R)sk3K5?K)5%%>6> z>&3n9a-C&SHpMJ1!b{5{o*{|9K0H`&BkKHbBHixU+$TnS)b`Zj4d7AEwLd00Cm~8Op%1I~P z&lDc(tc?V;k)4uAQX+QUq1Z~6$5cDymuZ!{8@Qy8b6z&nxJd@=aN6wvBthH0B9dfw zIHb!q@{I^_l&|CH&k0=2ofSm~)W>$aJdac(b z*ix%$*`|eKX-R2qjw!Ad?nyV4bdHc+2_(p52Ei|-$43v|(B5PP!(k&89NbJ0EL+!? zlStzF{#$#=OYN*EmB@Wq9TWhIi0I3VcyMRd(Zxy_M;w9SQy7z?3*r&{q?!7i-7!wg zkjWfpIWaqyw*wXouhQV_J!krM&m6aT7?OtJO@fwgvEhkcPE)7kquzfSM$NJ$iN4ZF z6uJaps_ZdM*Bb7hXyfW+VH*sg{pJ3DZs)4kZL zfMi(5uJUp8ISSN?rFmoCiFY-I{p5Iq{BnvC<*Q6hNy# z#pB06_kHhKy5&7y^G$gNY)FzkaS))W#wQ;q63D9`l+LdBN*EPP?RfT(0_nmfu}%EX zkGuTUc~G-OBDT(!kdneve3?cn(fnF&TX((n@^3j$$@idt=CC7l{w=)nl|RE0 zY-T_ggEiqQc_occ`~LtU+`j(+sfTC&t9Oh|BH|J~20D%A~BgB1gd=-g~}pQNm0zQ;7p06ozZ}8>NSj#%*b4 zHg3&q7z%b;5n*m?c2r173!bgaNah56DsWJA7t@>oY1q-45OKICQO&JQAa{&|!1ucI zbb;?<>pZ0tjEg3&$a%SrB%jsN!$z`|%rma8yv@bK(ej(mA1$?_=GowpuEUfZ&N;!a zHfSRyqm>oSxsP~yOW&_x7}_k>>^m*PueXZ+@|5t!DgXh=PdMiXz0zMPK?EIFE*?Jq>Hh#%`+kN^fRY1003Z+mA5Vj!2Ft{_2>_P zn%8}(2ELY3bnf-De((PP88)ATHtm?Qsu^QETgbx@*7A)PP6-plo{OM3%1jhGuBTX4^&Gho5J-9v? zh-oSq?|lan3-p?v7bEelnCb@$$Y?&93e_XI$&44C3Zj-+6~dNMyH_cBvkW{PFGHrFyQYk}VH{&p(O z$Kk1$$vdQpvoEYz5iyUx*T!GA3^IqLr{{V&J{K1ZlwpJZ0;7$9Ob+K@LT>zH2oRMUf zLT$}+qUHR})m0iTl##n_E;+|PK`MWXLDn!k#w3-D-N~1sUVL$gN|=>o zlxSMu;puW}-<#!rF}vOGy}Ari$nCr#l}1>SRyee7Irvk@tMa9%M9=VnV$p2jr7LjC zqlxv!d0!~Vi0ks#V+DCM`BYNW7>5_maS_gl>Tl$*WJlj`XR5v32boS9e)uq`2L^da3z z$?krm{3B~?d>ba-_sN3d&2=P<2t6r1H}}7rsy2O=Br7}0vN_j!q%3ML!5+=%Cy$y0 zbWw{$8?Z!F$!xhyzBenUf6)H`T81JPl%fj&S<3HmLBS-huBr7~o?H7Yrvd|NQM66T z6|o~W3`)C&UE@ucs^~{amDVIp%OeICU>DInDINVkCsZn|V3@&EY#3NG8+v>1Z~D4P zfga^;9Q2ocCrv$7 zI6|=v=BKkZJ)OYK zSLFWyZzCZs8@ZL=qjtN=mrWLav*Pdd-tXyDZB%gY8G{#0uOw<)oljo)cjJzT2;~tx zA)}OxK~}?)PO?dry5389x~7>nwvm7eLPi!aRHi8XJY$QFFDJb9-E4f9NF!2M$uR<0 zK@13N&~UmJz?_b#nTe*4fTEW>ucrT^3AN95gv%@b*rsgYU@;V%U0a;c2`=v$B!T66U4z(DOF&6aa7I*y5sT<1*p3| zN)uM0?xbVRE=S4PL^--5+}7u!Abi!99QLO}V8YziqpYr_mey9*RGnCBtEvn&m9?Hc zc>e$-GWD0PuBAK2k00f=mDK^M)|#|{X{gkL#Z?mlQ35H*_M*csmtOQPMe2O*&qN#= z&Ye4lWb8IHO3#&zh^5?qADXjyW6$_%(%OL{+QQ@KtYM|e&{B7V`N^XVe%n!|ra)u? 
zvalfbo_aL3#fh!G#g8}Otri|X>7eY!1D{*hAy95Ky$E4jkfTj(+R$k^)`z1rqVrdD2j~$DY~>p^Y4R+WOX<{F@bm>CNUfn@nO@ znSMeA6{x_{vgzwjBj9M8cG-d~d5>dPuulz-nt>1}HiOMv+TMjYk0))2uwt!aUMwmM zr%gv?eIC%HqYVZ74S3MN>phC%x43R}zt$~k$2)ms5xVUF3*y678h%?>u&@@35VkrT zt>&d9H(|~4=sbA*UV>;W4Xi9aO0;9-_QuRA zFZ~Z6f^CU%kSPM+E|tp9yx#bK|I&Pd-|PRY|`R zOk`V&6?Dr+JuFkdKT4ogm?36jMefVbapd=&?I}Q+6vD3WA`!d%q|8>24jbbcWV7lyqZ7!btQnaDy{&5kRvRZba$Wj{<5vh z!0>j`n=wR;F`bpUQ+vt8GvDXj`oFrWJ_(i(zZY|f2Jw=3WM60Ms#sTNQdL1$LRq^B zM&o_!ryPm;-{}C(HEA|R7zPZNJVfzuF8g;cAH3_R)|TH!Wpoxu+3(G6-=`~78)S&g zH;rXW;7PkX9+%ZaAL}~5oKxX4@e5ix7?ld3u>>4)?s)!N5qEZJZe$;f9`-&{kN*H( z>$RmH#GS@YZE!FE@nhX;C^iVwSq097WN)hX_-eJunD&x&8h%WkKk6uu7r9|;*q(cz zchTVS^JdnhFKOAfX;UK@5)*!F79RRU^ys!!0e6a}652=4h+krj)k_SQMK{nEKX1qR z!zR?rBLul2xBxbfFK$kRn0{y8M_IO=M@pq;3I(~{Se$ZS`BYDL>nRHII2s0IQ@aJi zM;2|$X1zLMO}0wIBCJmc0RtR*UfO5%znk)m#f~#%EIH0ulEu9!{BFJ7JS`?c8V%cq zkP8lOW%T9sa(Irc*|uQ7$-LdFUv$J=p>9ik`ML=skj)e;APFHQ+!L4!-b7gWE(VS` zdY#ty%63+$z$`J~V_szoWnXT6>wIr{?|R83{RYhhBy6z=c5(s^)?D&t>!VgqS2?8R zylIUb$Vn_R9*NU0Bf;dj{{UCY)rVq=*`t&s0DxJP0n=p5j-xo7fdGZ(anEmWhpzW~ zh^o^9Wq6;M1Wqh&6<6z@^^wx{$+n4=LJ)VC+<-!}n?5`Z`QN*#5Uh^PEQB%(08V(4 z*!nVWHjjVSc_{fU-FDfw?E5m9b|}J11)pd-Zp(GMt@kOv^`58Rb}1~Ie7gWJL<3Ah zRh@Innq@775oBGK;Z`M=ei+Vw3P#oIlq!oBL zIJy#aIT}`xuf|6RI9zHisK58=i3E>rNw&67JZ~$|ST=IvD9CrQa7$!~?oDu0rvQYsC{qlybl@BraxuirsV|ZL0BHAo zr&&cXrjgTX5tcj5ZiEaJn`eYyy03B#qh{I0`5Z8k*HE3{WAAh&-TKP#H*(0u=8&|D zY$FKyX~Q-+qC*DmUs=}v+mizs5#-WikvA9gH#CbY~mPXObNL`&!2aH{{S_bb)No*YLjT*F_Cv{&0@W* z$QCat%RUI4th?{}!DI=v1^`H)+U&Sp+;rmMPCtDo)>AEE(e0C^gE=D$1w7Zb?{_4A z>&MACF_@8_%^_oAMVAj-kLvtrV--a0y@1I$Y}`C@Wz+opU#+-K!4=&xx{ehbKrP?*xvRDrm*t zGxU;KdK}d73S4O?d_KGC?n)~ zP3Kkm-z8ow)_qbg@hnZ)Uo3msLmgg_wmqN)2Z*YwsJCou!^`2yPabngC%dAX1(X88 zV2!W3k4@X#lIr<9-`4uStFDJ{lV_Qyi4sOBgBAq}sL7S{i#^}wc~985lNAz3HFgYC zW2Sth*h7xzeY)y7J;(G(`HHC~S5_G6t7bazmmC|sJ@qB`-xez*XMV&+jT-{DZcBAE zeouKouVcMYgka*JaFFliq>;1bIKSQfU+F32WOPJaLkql%#M~#Mk?k9(I@PC?UeUA9 z@N%Lz5puA-Xx0{{rB{$n`Cpc9j-m76o_MPAbCTa+bmK@ikuJB#TWa>yYFjX0T$I6 z_}lE_%7Eac8>k)ma_RSKfH(3>l($4(8MMr=B@_{-{{Umv)RLB&%Pe*v02D6@-^$3z z)1%p;Eht4&>_$pUViykf_tXBblJcA-glf_xB;$ZmQC`i{k^Hz7^9*|5SG(}^Q~wq95LpVoT<1f}G`qXsf~1A<2Y` zaxdRf<7SdDt}d+CZ1`?R`@gr#J!_be-3Z=?zO2|U12S%%x*b2ya-V`%mPM0)7%du` zlEAbi@TM{^q2qf>UpC1M0U>j}k#yXHVH)3*ed}_mp3(X0DKS=P%CwdW2(`m6PFSI4 z@}BmTcDZh3$x@|cEjBZe0Ah+G_^6B&`lE})f zv#3USPCLQavHt+W#n4UOe5dbgieQnm01B{gn-Z-i>v;RU_tt)HKHbEA$+kYsN(hmm zcOmyC(kk%oFq_J5{Fk}p=(er=>Kl1?k%|T(Qpg%MS>{#fboX;^y2z;#%F7&q+a*)b zOD69w9!I~kaBjEWEcjqE?8TK>In##}RYAkP*`@W8b$=S#VBdes`&9`hns~N`ZJJQS zdx(u}+uTnE(=0LaZyeHcm08-F5ZPq19#yP7s&bMedd?5K-N!+)Y>_Z#2Ju~URy=QE zoDxRS?)M!zI?hIuwe5RO=%W;oO5)N1ssw%pl@rt6Zuw93sey%CYLuXm6_wbco8n&E z`Pnk-Gs>riWD>lmg;bmiUt47t@+rtRn4!u@%Esz&RTp3!?;mx;p3lhf%kKXGKy3R$ zEa12Mu$5V{QG$`f1Mw#(AMbyhcuU~;QaIt<#yL~2^>1dydffUDdbui?z7buj7>vPM zLJL_`HtFxVl1vwr-Ny^>C7ZzUwz%JD+9Y_RD;O%w6-irSpY0QFlY!^!H(RP~Y+_=Y z9*!-R(77zETZ`hW6M7$cz5Z=4`~Dyq8_zb?25Nm->yWqELDfk?cf9OVB1W_&pO zp8hY^{GgG`i7&Mki)dn}vPV&q@~Qn@1hORCe$eur{#XK zhBn=Tz>vXgMXzOHy8OO6?a}h3vppP}pqImU%4Nsu{{ZfAf+L0?UdnIGbpY|@ykE`o zf8-XzNX*=_rLJ$G2dBUHiT?m@s@6*ghDTd4=D%^Oc@PI607rOphbNSF-?z8<-cQL% ztveinm7y(^X2BGkHT67_`hGmcicvwa#}izf+tD0E4&(a?RmPlliZWlFnA-@)OO)~c zKw38}5wTV_D{b6gd%Wc;g61f~BK*4Iih|uJ_*a$vf7W&yYr2d=-B4v#BH>uwu^c_> zcv+MwYoo2VO5kGH`8g}glWj@u_l26l~*1Mj;UxgT5rHnBefCcG$jjRKGq<}5kZr|opCE8PZ0+o$1 z0ay+xB$tU36kIyG(<3X69PMDl_8yD6d$0SInqnu0b}G!-g|P!%h!lHFTRNhcGAv1F zYxa@lqi>~G?;&NmxHsk?wW)2h#U~VJjA}Ho(D%4&I%y0$Ff0Vch)c;~#4{&}>Fcz* z%HY}7JVb;GWLt~7>JpPAGsiKJ#vlU7styg844dDl>V4nk#-9m;f4 
z2}Pd9+8%6M)yW%s&mSQ?N)Q3e4>0WeJ^uh7f=>@>_$6nQ$q-2H33%J0Ti3(t^YXqg z=8A!A{#N&vY@MpWPuMvE z1Rv$cLR5ya7s~f1U*)ZRM%D0k$2x-1h}dbohmvLVGKJ9~nm->jvHS#>l^atv`vNVw z{;~f6tNzG5M~tbkVCcf&^s5n+-r5n^HPfNruTF1(HiWrHM@wGj@jCsEUM=5`q@(aR z!Z~ReCz)7CseePg?RvlSk|=P?7HdX%(%D|!>>H<7?E7r{C?OBVAXA~&B?pTb{(^Zr za(o2Ls!r3X0JkP=3HJBV{{U6}lple~mpL|;b8hOaPITm7Jfw^By?5+l;Y#fg?U7v= zFMKVU_ z;g(CVNo>)oj-K_EchPwJRpmuN<3h?<-q~?+0e%?qp(Rgr5N~naYYv`uBr?2lZkc&O zsx`^pC$S@E#~+m5eQ(j(Hie`g(-MH#XclOO;j z2Fl%QMTC5uomW}3{%?Cl_%Lwo5U@OCMvq{Kxae0s<4N_Ot)--aqzbvJD`c@q*JhP@ zT%XqPEl-eMO~u4^0J*YyF(%Ek@MpaD-EUv>xvYm5A~#S2g+kI~;ju;9Cg6z^FZpfX z)>ENtB!*?mLZt6WecwouPa*XigzIIXgiVVmL`^`MoVG6QlEnjgNqT>t>nX9yBtT~4 zvlLhml2}G}Z_95UN%xrblK%kRa!jLNj61H+8EYhJGBJzyH*xZN&k}kzr@~`IEb?U{ zzNX=IA)AF`gSmvBOX{Vi2%~_jF;dpCUXd=x)BWT>Sxs?ZSj)kW&PHJ3*eMUp7iLJI36i z0A1?kkIfw$D!6Cet17$Qlr@!0o#Su)Z1R4mlZKlV0Hx#t(6Pzbg^wfF@OSqyd2h|s z?p`sJj0r=GV*9(2Gk4+dzJFQeC7;ZwmCSglw&W9=@}fMB_j3}u#~&O0Z}y37^QvVV zYM}AXPbN=EHfZAVp7LK^LoViGiwP5L3`ZtU3aK4rtiSHtuCk+)%_6F_amKB#H^iKK z%$&5Cd!Gv3n}4&!<~NEU*WODep;p-@;W##l=6_Sv{by!@(UsX|Cv&?pir~5~%`p=ng$t;h(H*&~&WLYxelNE2ya$opc z6XY^^p9iuMMhKyf>0S&KT{?gMN34^~X$CEhFO8SlAUUx)fJaQ8`+BON4TooK+b9k$ z?-){XQ<8A6K|Zpkk@jSkV;E4DS6J-B611Kxt-# z(9s&o(=4v8uBkePiiFgiU0q#LWp!nBA)>Z~R?vlASzS?eWp!mqwW6|!R9jnJT|srL zEk#*U4GE*6H6R+pL}*PQHPzMBhgTZv>Vqq*tEx;iCW!d)K3j~+j>38PJH%i+=9Pyf>q z5^qF%;ydgggA_d!iTtM@lc6fGvoR+;awO|u4@dKKaYZXV;tJy0HS#LQtlm4fbH~cP zt|o#Q+#HA`mIS^zaOZ@J>w5jG`Rsm7SwghwQ z(ff3*;P}ogaf}mUP!vXPJobIpRUEJ5d(7rzzrt3IKv0Ge19u(R_mY2!_1@LAmF~7A z7@YHw(zmPp-|FenM=8mgO!68pJ~;70%~eQ|%$uw?gjqy|xzzsv#BOBW%LM+rll<>$ znPW4+2N4MA!5!h6mbE5 z{(ig1%&8Ja=K*8BL99krIKR{U=iX9eiu@#SjctB!p-nQq75HCf$AnlfafT=CEiwKS z*~OYQX$jWd8+{t_#8Hv}Qc2Uz{VHbdR2)Gf*VNnj)t1+{!c8kIE<|X0QK-J1EaloI zcq)dua>Ru|ZpWVVghwX!1f4+%bIsr0yixe;snM z0TJAs*nLv&w;B0A-*<~Avnq{@$&*Z~RN@xz(N%6!kzaY^B z7Ak3WiPRAm8QGR2CXbht;`Oip01fLa=ch>oo3~aD4ywY&{Vba=PQCr26%xF1hAj35 zXu2u~SH0Qf;DJ83lBEQjX^SUov`iOJ6q2VUn`PR%aBq{xqjteEc{?u5U4VO_Zt3kO z)_zS|b`)s~NX5i<3_##P+=e*3Zu-xetnr53tRj!a6cNLdZuXp8>Nw@`Z$2p<-HJ*>778Skz3}O9G)$d$ zmhobKPHdq@Q5yq1K!*S~CeIA5%5u4NU!1E#*Q-3hyTqz7W-2i^#Cd7ad-LsHx@t=l zF))eDN#2Ex-PuQ{`@dPd-Tc`q0=cY!sk^%Og}4>H=kGkMpDxDotdUtlOr#jWTfV2Q zbjJz3+`pUMb%10%n%!I$3~Wd@8G3b*Lo||Z<3l!Qgme2W@_N1wv&EmZM3{}Cf8qgX z*2=7+{#rO2;_*ef-ANQ(X7jH0*R<`Bp;eB?Fm!=Ig&&9i0QOHOE4Fgqb?Wzh+D<~9 zid~*rBZ;ysjK*h-gn@#XSB?8P9ysp&l-bZAO{{T7fI$fj> zUU@>t5SCOf#*cRSnyErRl z7SDs>)b(9n_mk?p@y$GOt(n5Pu~Z`{%aNXP$jbTtNBNaK4YO%O5R+z>@$~QtwpBvQ z&~Kw&RJ%}MuEQLG$Lldf zL)oHe?j420kOlHfuLRO=`VS|QhJ}K$VoMXQB=ehALpDUIa+b1kI7k3JKrRsNI zd+6rVIxgxt)ufi;-CeN59OI4%_esW^Mtkl`PPRC=m91_w?TB4smPi?ks{~!%t1}#) zc-5b$@1yV>(N1=0A_dG*syk*&U|94c+5Xq&bHaH||9?+Ok^z;MMDd)=31 zw3?_(QP3EGXe_?2{INDhT+d=m3KR+6W4N{svl}ZHq#+9LV;sOAz1+9R89uy ze!7X#$dXZvd#jDkIu*x`B#-KOUQ^0bShUbNKm%x{)MI;jPbawFBJa0*-Kt&Q@2T&O zbHNFHo>QWh-=3t(Z1^c z0CoK%cj0KomGLY<%bc*c0+Wvy*!xOQDDuNLw3G{@$c$2w391AA!`s+&87u|EZgYQlHm6q zSCjI5O(`sRfWghlCyDYM9!q9iq5IFe-c`Ry^PJksqEIZc4iI8%cO~)p%GEQb&S8%x zK(>pAl+QKCc`AAjyxrGCx}>q2mKVlAdMlhWye}V9Hqqq1G1q&(=R7Q5lQO12A$B{p zNY5#|X*1rWo=#n-mEBSs4%`BA^9c-cC?^oDy2Y}V74op}X;SSQJBD4HS_VhcqGKdu z$od=1{B~y17SR?4NU@CoD`pZRq3Obozw-TF@2mA*w4rg{G=vg{*sX~hq~)dEzv}+7 zU&$#edpUOsz`NNPl0r5-Egu(IntrZ=MH{7vD8O=sMfH;Vx553~eoBHOdBTLw_j26c zp6d<{{{Z!WNyw73#T+OZ9l15Nk0;7uhVof|#r=wO6G^i>s<|lQ?rnF<%$z>@i)Y$Y z0AO-uRVq1!fb#Qud|hp86_iMuI0nnNaoxyx zy0~NVoBCXfgR#$(UUvrrgB}ErXHV8m`q=1Zh{%ezD#}O&h(sxE>>RIF{ali7Ci=gl zspSb%pmfEgtOxm*eppb{R7w-1wl@V<-_Z7qzYHCHT2 z9Tp3GRvbH&Nwo8ZQpAg9IGIPg_Wk7s*PHD!2`tirINMgZ>2OEMarsXc8zHabxQ{$O?PQ~3wQI&4E%iMCzhUeUE^9I#k*02YQJNk^?tz~ 
z^Jw4LM+VMr(i3io=y3WTSKZf~HiIJ#>N3H5vLj<7`Z>*p3HBn7-PgUdO|#1m=V!7T z0uZgR7vqz6E@`w&?>TQJz4i)ZP_PN5U7}m7DZFtmii4vU#y|TQb+uMJSlHW5dqF4K zpx)x8n1x$v%zn>W+q^5U49#M4RXHuS{%`fQir3^WNwX;^B?Z|?v9}MS^yw^awkG!N zrTif1u)T>!lmNK;3VZ%bkH@q-G~#HRFj%F9z{e$St?=T)PEU8^LgU8UI7xCGkQC!e|Xwdd_6UCOCE1D^a@j`VGoe3y@F zGBs*#@Qs2kjkL#`9`E~Cysu+m+j3pgoQi6<=4kuMvU=Fkyh5hs9O_P2)q{W9lgC%Y z=$Mc}98m$2TFT4n`;zr_rhLAlIODNpYwAH>(6`HXg&P1=fqUCS=C3bfo;Fljr8;>)>FkD*2fhq#JX_T zHkl-Z;MQQkX@6(b{_pt61~yjCKGSwn+CamKxsy^rB~LKzUt5DOLWz>Z*j+mV9@h4& z0g+dPjF+6+))glN(@r%e`15o6KSGr}3vj_keb!UN2~-70em5UnZR z+lgZ%cpQ$Whk+hWtAA9L=jUk%;=TpQe`_fBf4htJ$1Hn11Vl&|QN~w;COvNRXY}#^ z0P-dEkoFMmEc^JIb7c}R3}KIq@SY5|hyMU~h6n3Br|@)g%Z5ZOlc8;RO5#VYldOZ^ zc}fkbw@&JVmarBWoJF}vm3qymZ=bE&cCDlC>kG=#ENz5aI;L$BuL`IR==)N2lHn(k z@Er}YVlG5-%+6mH8?k(t^e10fk$CwDpk$CoJh>nRz&E(D#z`GJ`$ikiv#aYXA)ggg z*}XIzn8Y!LJ}&j+pG@}KIl2!|tJk@uJk z?`f>6#qdz3*2>IVL6dJYUmvS!CWy-;WJ|DZ7Dg8*U-)P|idE3;uCf<) z#Q@PdMPG}KNtdUcbWt-7CO5GSeij$&pSk}4Z~5F9MR2k{Dz;p&U|8@=(d8x1t*so= zqL7icP1<`eZ{$>y?E{H`vgjSJjzsWdAASC|n6a=93c1R~f~dm%E>C^!`B^zSlW75% zW)}A8SaNISAi+2!e^>q&G*WF6w(!L@$wj%fl`>U&J(ei1Vy%7&HXv} zc_^XexYkYH(8t7}Dl-9!KYw@C??0>cH7O0wc2YG(TjkizZh|ZOs>qr`FDn;e%$K%D z+kS;0i+A`&q!D$n_B^krGD>B`#>>CS{d+Gypm+8{JFy&fg#liw_c1X~0qJQ}t%y*H=7gyzXy`>|{v8mb* zNQ_$*;%6CW$?n}D!^=zlnqWKF__X2FE%j}E7OA$0C1yMS|Hz&ba8L zA~$zuNmKVx{buveBV5}FcB$m+_x(=m$#=Ie+^g@I7N_dw3%M(iP1Z|L>wIq>(+;8NqL4X&QSCE zJRa%m%a$l&m6Hr?#4%IDAW{DS4KEkoZaH78>|W0<=W1BU>V`nv4dOmrFC>|7A9p9L z{ck?{d1=e5hUeEG!uPH#384qaH}kxI4!wFFI@-$W!$@sqaIQ5cP?KGH5^4gy2{o6j zCXu$+P|*sLYilb-Mb*{S)E62WBS30G6{3RvJb3;+f0m`Q39VD-q>?C+zYR-jD=Pu5 zp{5Z+XpIS}1kjO3Sx5mDgi+eE)gW57w$;|CGS!B*)rOSH6Gd52Wn~qsElL5bG)A(Z z8X;L+DFdPgfKhnydJD>gMMfYf-aJ(Ehd~r`F@9QO#W|`bQ;bCTLXu|xCa^>T7 z%KXvKK2~N#Enq{AeR@%2$I$xPI9uZqU{)~YZMLvfIP%kazur_x-EkO|j`a>9 z)bQvtU)9wcG&~qIgC^!&xXC`6zgbh5AWf|t9zqu=FT(h*{KRa%57v1a^u%tg9_5Mc zvLGwc;rD(f`E7B)Mk1ZaXDY)JfnjcQ$?&(Wrenm-ix36u*2Iu7ZNhrJ?(1<}(GB@W z=5u(1S(f(ct9s{7r<0`w%V^z}OhGFaXL%)p@p5^7cdPY(Hg5OGYb(gCBDw)}(#4qh zUEEa;s;WxZbQ{mYc{v`c+9SZO^a2=hNpc@C>F)cI{Y-Iw@W(xogD~CMZpuht-iGsw z^M9QBI@iK91hw!0cIBrPHPr|jXXXQ+8*zR`CN>NPr&D;Z&u>-!4~x=^kzF2(q+I8PSw z_j}yt@$=IfG=Sx7+=VJRf?p0>_r5hEp?#!B7YcyP0-{`#Z(6JSzs^-@w4*A^h%aCq zTJkXkdUfH&IF@&K7=Q_K_S1y!krRumf2`}hT~V^_3^R&LBRSNNY^%kJrt81?&+lq^ zX819V(F3ypvIYk6_O^k9F2^vwVu$;D#j(1czvVs}c${=$FCr!BQ=w4e!1D-!|APSmP)VIo?&?)ibl;+RR_i`4kN)5#GlBDuNM2~aLC#g|veZ+v+rd#23s z8OSfL3{8V?8=>)BeGmC>zxr6oylF5`G&!YZ?HH!=8Ke5iBa)Z1$_leYBbs*Z zIkbdA(MhfiylkkPX7Sz z{G@!hm-CZ*UY>I7%?y$xvKxV*#N=^4RXmMb88R$sDuyE1->2-dH&4gn^hU+Bz~sL7 z7?`Md)wjkmst+@3+3&S?-bM zY|(?H;Xy;d+Tz-_>MkBrp;k7-weI(?6UQCP7b4bBcs=I_S@m^m?C@-}PSGTapduBG zn9R;`n}cXZlj1OX?oZ<~{%bn+Bk+dM$t1<1-IJJc$&_vs_`jBtUD9>?SLS^i?5Dae zPo0oQin5Z*1?`SHH*+dh))g`%fkLT}5vqcCCCApMlaI>R)NK&$aj0FIMV7|yvIygO z1-h;snn<5{_ci0sGS4yD_Lw(QD(YC5ca1qoTa$(zoBfhM{;uQH&*x?r9%UOzF3({a zq(w*_EsRm)l$d+T zH-G2l$}nQCLa{6{U_!>dEY>zveaCa(d0tPOqariT@)_AgjIG5X99o6yles6Ambpr_ zrPaqKy2z(e#OEL0MME6YBeKG(S5F}ky zZVHJC`AFJ7_g&pBq8pedXtIvPNZPAiJ!Fg4b)HZCC}A=>q8mA|k1#@8(&W{%_q)I9 z^g|a3C`_U1X498E}BaTp=NH zxeC}r57tckzjxMCE;mL{QDnh#2P}9F8SOZ?T{l}7#s2`%qhlxm6t&1L&x~u)Y#SW? 
zXIs?!&hv<+;1d%V5pGEfIbcgRUDfIL<%{a+YSFxf5-l9rNjYY3hH9<-WmP0dGDN(t zL1$;;;W(_D$oVMuk^8^8-S$S?N~@_<3|EC0Qnt1`>6@&QMDlR>XF+Q0!Z;qqEPPa#A+jLZUD=UW<>No#9puDlVRCh!@Cmiiw6p=q*$m|ZT|qFx}la3J`WEf zn4UY8ibA@^!#*5&NnzS$i~XD4`?~aCo8s|q);!fsnHt$$j%yZ3yrvt^9oeLk5j2vA zEM{>P(MVN$HgPmE%hbW;{wwu+-d0A6%5Qqz^Sw4Sj7PuSZ!UFb z)dL?-F0xmv4$lOTi8p(xiNzILSuKY{b-^Ta`90j<7LlYaup0-sj!J%KrchfHEtT zQ7Zy2mISk3E#CJoQ`T~l`?({lOK1+vssj^?I8upwC}YrfZ$gP?+9TQ|+A`hk4%oC{ z445eV&f>aB8!g%r*_XKtHOr?Sc}xBsHYlSSl3|3dsm~(Bf7#@ZXWpy-0JLub$qLt#Zn}mKdatN0fz9a)|}+nLR0XW&Z$qQ-3y>aq~=V#r8fCTT6ac^{9DcZ{5uSMNN=CmF`#OrikBx8Oq{ z$CH!Oaoh3GOTtljnh0Gp-k`A(uH!*Jl;SX#>2Jn}~` zlaW8(DI0ESD=R2g=mC|6j#U}77f zi+xI*zA-L+X45B@^?$Ni_I<4|Gt62S%OtsV1FV;I_7bjXWtL6o3j*rsti>%OV%T!2 zv%>TIVdnbJ=f$(Y`@MO~5;UsC z*5eYPc#|u#q!W6~BmLzyA~8QBC?Z1W0;Slh^4<1Ck7eGs`^UyNzq%L$70fMy689Rg zHf~AZ%5Nr5`+kjfrk-FAyl|vhivmug8}yd6kW((;=Qh($>~xtGwsK-TM?$50%j3tK z8cA5%_=^#4f3sO=^6>$aHIFs9>^&c+OKF+xMu9n)_7Cs90EqVLdr8)sEVj7$GWoh& zxcF&^cbG&81V`ovu=zBh3dKWMM;|y}mbUMGbW%Bod+B3ZSBEFk{&CEs(O)0fRUE43 z#GMZ{`H8Fk4USgY1I7mKBg#4E#PB*Ny?pD>5(d#P&8xRR(N7=c>wFx=Hc&AoO~5^? za)~De4Lx2<5?I;516@ZScP6ef$6*z&*VGH$eI7OH=|)TArbguHs~hxwhxJr4&kQGb zF}=m`Tv_$+aI|qp_nmL3qf2Ra_dDPLCoow-I;D>*^!mNrRo@3nsAYK^Dn21HIIFiw zlkVyXAA7X?U@is81ypox!-BJURPvrn%6qWb0o1uWv(OJ>nBS+x!TGQBMDq4Gs3ak7 zRZj^$<7_j|yxA4A0O%O%Pd1$|=hX^S-N*nEF;yCKc|IOWIH)BK!yO8n+%V=p_k5I< z&iQnn3T{cnY;4P782rB1l!fbYkRWhu8L=azF+SZbReBHE*fqx8^!k21rj!U17C3>B zwU$E5mhN6Jtm^5g&7M|{K(UP0;`i+yntcx?`M;anSwT`{u*4EiYbf%S;9MhwW2jKk zGTzRtGsBC=@xNkmWey23-P8aJi*aY&`ni4Fl*-MX-k}%~AW(|ED9=|QBd%X>3Tlz2 zBNZa)Z#VrFHXAcWF?>ijjdZn}qXgSLo`;Ry9|21%TU|q6NB44g{_~-;m?_k2=eP6d zA0u{F!TU|FH1yJvCTOpcn+x#)yi!l)N%^`VAt}6e5rd#%=Jc?ibx{mvbaR>_F(&w2 zB!|*$>HEbt2ZZ_!&GfzHMEh(AGv2J4&D79&oxj=VJ+UfQ;D+eA!7l&qYWutceele|ZjO~$&Zwd_0T9Ze?N3l%_&lDpw)4ciVbj67XGRX_T>YLTF{a)7(V z(e=I~rHnf^@_D`IB?jx)dCeg<(xtaqC~UFFjPYaqkMQX)6lpM736Tx%!oZe1p6V~C zPEwHynGK*r0$W2ODoo0qf=+<_^f}{QsgA~| zpKE3sYL+J*a$p5{av3BBTr7$%#BtkA?g<~z>uV)P*LuvrXws50U3pk`c-P8$&Hn)Y zS+gP9S)&DFf-#d8H)aQcBz`UK(?90mhxXcADJ6*%mISvw++QK-Wm-1ZEp=077774T z3E|`3b)288r|p|GP|v$YREoi@qRgalr^a`Zswe*dlliTm&lq@+Fkr7CZk_lv^mmDb zzBTD&TdvFOA(U@-T7;U7c;p9Rc-kOlX!kUQMJiFT#PF5IQ{_ z(@OFF?~>kp<+Hevnl!*7t=1w!*7SoUX~fH<*mftguVM(-DN zUEFg209C!IQjsNXo=IiW5n`;zyiZCvJ>E63?DOrLM1cvMfHPx!Jrqp($$!JDWptII z_}1}Riw;dJfg7VH(~?Q!-SdXkJZOVu*j=NITJGu(7M*NcQsqAb^}D$GpLNN#uuNMm z-JHmR61X8#fswnrxYzz4>nl@e+LwU{5ygASeiS2$eak0V>gstQT&*i{Vz*pba54@? 
zCroha_~i0X{@W)qF3PzeP^*b$w(zyrjk4h7BTpgAZv8}6TA3@p&_QS3NC2}pLKQqb zr>vRja7zj~ke1qKUJSwDUhWyo{aph_>m6Z}uycqYj>)l_=VN@hS=L=>v zHy_N7jdpnC2P;U!pCSjdjz3PH$tD>wQPNd8FdRHflX5?ox{{}sGO{Daj8KwEkd4+v z{_Z-xh>2uTY(m`b(V4v-vv)4f59{d#c0;UzWFprbO@+rJ49y?)Uw!QAQACW>LfbsB znKC*Q8?95Uo+)3CXx@9enKnF24X_Iq2^4_YSPwo>;O;>et|{{VOQ{!w5ZhF2*hY;rNV{VV_VaW zrh{90NNONz(Acj+YikWf)s@xMhgQ~CUepm$P+V66DkHsh=x!?s6dknCn#`z4q%@ks zO=V~m23*l8wPmY7H6zs1DM?JdX15(#TUk+Mt0=8lMM;*dwH0-$K&-VLt1UC%SWNY>+KS(o+wU0zn(wMwQqcCZe%M`7CQ z{{S)XJe4`5W!b|qWmOiOkWp%1$71ia#Ul|<6Z9@p6CAn9*O5yF_qp$ zf!-m$Yom0e>bE`5TTEA#YdkEwtT{3X2STI{{GRrwMU7;RL?am(Mp;FgP(yDbo70Z5 zqgGLDcV$4_*priSC*}Iw{{TDKrW}((-Ey8gfN;yJ-J-fh$Mv7|ne)sLLHEK~Kx?7L zbVe<@zD>~bk3%tz_A1g7-5OxM4o~W-ZN3pjv>b%My4^s!g*b7<%P#Uy>c2^`a4zJq zVtAAcva+rccRTf)J_=aSM8*dfisuLRX{ezkMv+6H=(7UB53Ou^=@Z`9I(Cght%~Sw zo=NvjOAh$r4aLD4dAeSM)3Bl8-VqzV@U4pvJ0A1y=vF9@8~}`yYy6{!cOSi^tgabx zTc!0urYw4I!TZkwwSopxg!~OYL+|f>NHb{911Pblilxs<^J-MC^7N}Kn0yIh`2sclat;;{`2>A=(edEHAiq` zLJkWU-l+7kX(HLv>-x&7Zp8AL#`VZ-T`z4p`9%B2tDF?rBDB_Oss60SeT`i8y5`b-AD-^S<0UlW2PE%%i zc~_5lmz0nEwFNQX8I?~PMixPGIa`if>p!Zb?;&Y7pqmVoVs2MFmAJij?up=lKyZvk z$znlbO_zp#O*kWyKe^ogn6OaV9J%oWn0LO%kHF>2F+j4VrCs95kOXydAobq+_jLno z1zRQ{_mq>}QR({?YMxMqVh)5@Yw3QmJbC>nk^-L_n#C20z^8$811a{bSIKO?Pj}^d z$zk2_^K_mI*uD$ohS%_c3j};hjslm{t)J8PJKFYk7LRXrl_EviHe+~g&%L)Rjy#_y zANapCJ3hlM(qx-ya+@2Avz&v>8?3`8 z@NeGrd%s^u+8=A#^t=Ty*Cs4P;n?yt;5u^ZVWkAMF zdRg*(3F`a~o|t_e$*{$pg7(Z4$>iDD*%u&O&1q3oamkK+v(M_cRLy~6 zf)?5~hhZi`Y*z|67n%^=MmtAc2bkXxnyn*_|iMPA&#E#$kMwW=uUkx0L&d zYgR=o*ew#u%n&r7sscEf?lb=YXW5&lxR1*9be*~|yka7`u;T2`)#K3K$>ZdoX4@i- zm9b+iYE+hP<^KTPqT%RA{{YzCeOIErs$JxYRM-$KEz!U8{{UD0+?Y1miG(>NcZ9GY zfrdRapFefS_mS4ewVVJ1DGE@(U{>H(`u_lXEiY`5cLQ|!@FKoJag6jks*kUIi5y+F9MKZE?oZ95NQ+Z*E6VB~r#?jzIg&N|3feu&>SW)2+_h9#6q)0SGaSUnc{`Y+actN(BfQ!@( zC&~FvQ}e!iECxPd%vjr-(oB2*08TIMlVQmS;v`(zPO7&&mmBh%-QSg@?NYiwKoZ4^ zTwF4djxX*10NScoNfu0a1|F>^$?*9p9zEyYQl%YxD~1{|kD?0?+tHIHW`27MG8L`mJKZl`^U z=QMBkUanu0Q^ErUN#!y}%0#xk?Ta@?NTdDTX+}b100SB@Sj*()Nhem|jg*BeTeb_K6YJ;kZB9lZ_?Z&G%CJ=r=l= zh_J;$(y5KdI{@mh;Ovj!a>>SHvhJhfUG zp;S|<3*#((U#|C?z3dy_TgkS(Cv7VVu{$=Kie$>7T>?FqDR;jvvfFQ8&L;`pp?l5R6^CH|hGO&qMtBp(!1RB*{1 zqgP%$wvF3B`D{bZsY#|Q9PD~ z(CQ7C*gs~JTd=%D*o|y@M!h9LFASL5orU;^)A9PgloIUi6qaL^$a;6|HvG{qN|daaG8dC*~aE>Cr_Z-1E2^1p?#xIC%USBUrC&IDF*m_GsBw zJx$q&=Q{CMAZ8#1AdbL%c>Op-G1PNud-%i0+6-iCmbSW&&^@a3wrc{S_>y}EoN{}w zZc2ez0BzfHbkg(@?GG7Z7&?$H2p{!fJbpe{y~K_;hmDBN>!%8V>*Rk~FC|cf3K?tz zY6bQ7UHRyPy9NJT9Te6JiDNw-%>4zOGH(5~eIcJOo8fYRLOUbDyk5~tx7AZkM~ zU^&X3i!?9f(kjU0-7-iT>MwEV_jvyRuJ?ScDhRfTn?0NvF|cwif`QO!cBycJJ>ROA z{iFI=?4vl4tmK}VxnD{CFOCl<*M4f51+QOcYp7_Ld6~e`Bp!k9=5sn1Ha0A;^ zTNd^o{{XLEDc-W9Gn+XfP#hN##ha{tNABq1f;lC&&l@;xS0R@-S%xju=j&(s**wnK z($@G^X%G84h(XZE-GB}*m+IljdF~(w!!6Srv`}BSCg?%%iy+ zew}$6XTk)`V^UHi987d2?Yv#*yNaE*d?6fQ2Ua=Qq63RJ{{WYdQwPwKmGypqFKV53 z=_1T9MbMoEhfDK5uk)Wx0C^abF(qFmbvM@k07us0Ub>x(LeoYH%fe)l_=tP?XMVO< z&w0A9A!B139JqbskcSFZhJ)Ak$960~TlNo5~4E#A)^ z^pg6W-tIVjok%SNu1*D2RGhADl%4(WeqG8{Fv+v*tC(yJfiG^d!~Pd8{VfDyWqE8E zt&pf|asL3zX8eBlY4V+B_tHJwF}RE__WZ$y@Qrx7&aZWszy9q*8)(`BhIU&@NtlJ< zEP5M0&4*8ML-Dr_q$FZ7#g8a`y5@M4EKK{!Knh4DiZ4rd8gw_F^LzeP#ex|ufdB)~ zm(q87I4CxU-l|z~w-)uz3Y{)Zex6=C7Yutb0w&Fg={H22j#2L2c;xprPS7HfKvh;p zVa|Bu&@O4SLn~pkedjAn+2WODS$1bbYvikIMh8!3X|Y}B^)mY2)CVP%B#192tV%Yw z-!W|Qi~b@{cUZ8=8wPnW+RB~GYN4bD*j|Sz5=p-G73kuz0fMP;$B4M#Eyp|Wq~$pM z&skAC_uH=@42H?lSlsmE%))uQ3R1~+)XG52^VJ>7bB|R^>wKk*~Q--<&LVh z3OxjkzxLG`$t#P^8BvhrgN#T#l04hqKOcrIskoLDa;^a~3+M}Z{>v~Kru$sAbs`mcMt({`4E zIKcxgp;$UK^Nb?*WsMsL-hO-PJHl$uHqcJ?V;WrW$HRYP{{ZRY`l`r~nF4`jYwrox 
z>*y!grQlbRRz?Wrg0SlqmwhDj&X4~954_}I&Q0v!p{O5fA$DOS2VgI3G01P>{O7V8=fR)m#v3oPEX?$DIAE3((4z*~rI5ua06829n*b$7|ATh8x7JfTp*8F39`^V&){f~D> zmdtjI6s(F;ZQ`=4^Y0aTJf!sbN-)L3aG0ICayOs9@p#+TSD3>csic{AtIr^249}QZ z`904ELv^|2;QZBla4IvLiFp_+h}c*}t<`c|KY#asYN{vSvsqQ_7W_b8c%2y}M)A6- zmuQAEQY$5m!8-d{eRtMZrqb-YJPc5WCo3deY~{}`>yOgwH@m;%;h5etsu#yqTZ@2C zPDzzy-plnpUz3;78fcqqod9Pk5o=$&bL3W(>l_D_A>WwA(Z8(0y>1>?)=`yaY#Fi$ zx{h>J1J$Sf>3!uJY|!r7eXW%uuvJrQv)e2gqTcSgFTV1h7xbP5j@tJ9mTj9Py19Oh zhIrzs$Mv;d!?P?29Fv@6Ku99YvBzxQ2y^OrKC+>SBMe<&D;#%id71lA1gQi z0JKdny^g5anpw7;`#p(!lp~2cJzO(4S+se6OOT7z_MSMLKq>ZZw5XqDIU`8YKF$tJ zXDuUcy|+p**B1jrt&cdTEDG4gs4+e7Y&k)fD(%zb&y(wSw3(G5meNa>M$8p68{b7a zZRYjUtc%NeA8WTuJ#JSHze~DrA2HchU*z%Q$C$APpjq5>mkdh#GP^v8l6dPYUI|bs zDtLxgwxZSEXN|g#CS4rl-NXX6^cALcMk9E-MKYEtz`Y zX;&53@p_icEkl>#ptv>FAzas13ei$%q7{O+g4i)#&a zVW~J8+fd%N^(NY~)Izq^BJJbHkK!hbiUUvytgbbcrd?U%$B**r>X2(|tA#<uFlj=x5l}c9?I~fauTe7gqPR4SsDsoRdFw$P zv@A)X%9N#EnjWGJO;@W+YJ*eNuTiVhk*hW}DdvRMgI%Q!7Dl|(vMo+Oik8BVQ8mC{ zw$--Qv<+%b)p@Hy9kr!v29*;Q^wp%VJuyh4nN1a==b!}9Ub5T&*4>(Ynh0J<&DD-= ziPaw4xs`si3er+S%^Tz-Q0~U;E;<}u$3lx}h6LK9n5eb7HnK4sO0SOhq1Jl3LdeEP zXxSu5TIEr*a`x!FD9O6}z3eqv&A;LdaGPH4OOQ|k!=ZBAZ6s00(9iX?NZFlY4gg`P z)46Xag-a`ql5RM}frdyG$W)QLnPZwu^__RC^HZ@i64#JQA5t}3nDvmdV=r^ne_2CO zHcPY0h(X*kmCS0#F8g~B{;H<{Ar=ADk{CVyk1GEFx~f_y2*~4sOS;GvjF^rc#E^Qg zf8_mNTWBzMbrFKWtj*G5jcgfx3WB(|Q6oi=DI)@fP4okfIexN_XF$%-0%MJA$fxBz zTdciOeeD|zqslNzwSut(1@3U8e~0yyvYExaE>w#%E*^8b&wZaO7rN`J+8psB<`pax zpI(onD*MWUU=@oCvbf0W+v_u{+`Levse@z)VnvD9)^d`5d-J|{_J;yomcTaOfm(BTI`aRxJkcD@HqNH^DuO3g|rO70l&H$tVbvNT>C#CFu zxBDmFNfU`I9}~pc=6JnE%gO$$mHq>JUdt<@5I2F40L0>h(`}plKOZzuFTq_yF~Ofz$wO{0A{aeh70 z^f_qc_UqYs;Mw4lWoNTUEZC~BKzeH|{H788U*~-i16O7eJG^AgiNpYST#=sfWPexG z)R`4_TtJtiMSyhIg^@q5tF~ee%LI#L#^WP*$TvkB4c}Sg+g$PF!Q^7hfaDU+*q-L~ zw=cV^o$!K?Jd9grJGOQMisU|>?E3h9WqI4V^H|K7QGGPl?vb}?iB<8 z2S;{`mygjt*|f~a5$!O&#EiCkk0%3iem_S~OqOB+GEVU!mz$JR$;XdRkI2`xvHTcL z(F^TN9awK94b8hz!Ls4(Z5yrkuCFU*SD1~O6;_sGaLhWgyi@HGfB4DC*V2`bW(ZK) z%ejcl)hkKGHrFm)bh3FbtlWJZ_%Y`ECGg73QDbD>vx{7nlZS;jczVpf@Agr%#}31@ zNj$5Z&((wlxzAA-etBp2pSrIWEuINB`5tLPh}QU{ap%C8?2NH28?F4G7v6O-^^p3( z>_k)JZ?iO|Wlaw8q4-~DcX;7zga{f~kX-}y zo9|D(UVJBKlV}^lOkf+8=`F`xxc>ml_0?|Evm|okE0t(nmT8oi-gn}L6=^1ud$_}q zm-DnRY_Uv#*hv5}jzU0Zkz6dYXEO3z8gAEj1pB{r^YoJYbFPzPl@8G|Ok^=?v7idh zI(g;Qyj!mxaRufVk_h1&KwAd83|A!Xb$PtrJ0|!?D5Y40nWsEaj!k0#=PYl*9{&J& zJ@lSu--JxYZ;WE|+{U36LD7!EQFz48e_MABg z{bw!KNBM9|6fVyc8$wAOYD}#mIY$>;B5b>)blP4YUoP^E6;bUewL3|(&xtq);bqCA z-RBZZ1bxqqN6L8CeWd|uHf6gifq0dOjosU#nE7Xy^Ht?R9Gk-X;|T;=l<`S^ANFYD zUjG2yQ>Fa76tu3qQAc3YWxSKDiVfFOQz1;Zg8R&`y<37tICl@hgHe>UQ_nw z&Noi}13V;dqp&%y_}J(&JdBumOyG`>c_G6f5XY6RZzwR$z25K1Q)5Fkk2JR-LB1?F zBT>59zuIPa`6u0O{xQpaN1kzupnP!+ZW2ww@_sx5O+&>tBxv@nB1sg2-DY zibmw#X?&eW7EsEJ$U?%`z?=5)rx$*n-4hwU3@pkAJ^Q$~_@CwrvdJTa9UM*=g4h)q zZlsD7ecDOu{4|uL=`xJvz@2eo3k5z`dEb9^Jgy$I`nr-t2{R@( zRnP&ZgP`61zO32S9VWpQoZl0ZDyiYz!i$TiSMR^LZ%D;Lbv6?y2&&@9#r&E}(EGYA zsyN`8l@3v;11nw2TQ|+BIYc)QBRr9-w)R)KEtgj!w(of!@jv0eIVYN}rc~TSrB*8< zD=-I-R`|Co9`meDyUY16nf%lP2&5aN1inBj^s-)){{Sws5eF#GaHB|+E^f{;;mhlO z^U2U4jiX|+0U_Z4t0taSZY(~plI19*k{;^-i=wV9FQ!qXabxAA@}Hje&H6)-2^m$$ zwyc_{T|(yW{ztjH_fh)UC}c5~V2oQZ6DwRB4hh~rrziV2=P6EO4DoOm<{@m{<9XJ@ z$#m*(jAE{cSb{YNqx7-lxldK%?&h@?l*1~hyrU*7Vny;RZQN0`zjwR(x^r!5Cg*2? 
zhRz@gVq!IrpSG>p7;nmj$8@7mrDvU`gaN@@e-t))y z%+p9EYc?jlB{{Rx5F=v62FiKsP{faKj z9`WPy(&4=8_dnKEvJIVHIJSV*C}owmvABt?$R>-4 zZ(H1tdCSS`cfPJVjzl4(g`*+%R(T{WFT-@)#U31Tr&qn+c{;H+wT;BPE4h_e?@K6* zv@RFo+wG=Y;`qkT4_P0U{j1t`sHKqZPEPHE?ID?@F}U6CekkaVA4{t8zgJp_5u|gf z#Uo+4BXuU2N0V}kGW|ZPI>RHFR}+$nvMNF1e7Af3rQ4^oJqxza8lYuVTnZ>8q_zi> z+IF1dkv5-A>|e@Fz7)XTZp*1J?`>Ouac*G_6%qf zX4w*Ea(11mt}oGWk1EM5ep6@UUO6k!+KfrCNwgz8(-l}utaXwnJGO=%=LUUscnP!w zyxakn_E#*~XX>Fu{P!vTMYU~)MM4-7;DOmE0oZ zbS|nL-ub+7{Wzig%||~L%@YN?mS(ehJ;%!&4yC}AvL;z_=yxOas&=5YFzh^qLrwprH)`|PfJnb$?`t)($gqc48#L% z7f#<;{y(&8GBg0>z}Q%ecAr9%EM!Q=z+CZ*Xr%WlsPmBNe0kSPpKe+iN&twUF_4TD zc7^`{ulnD$Dp+RFUmI(2(xY81^Zc~kn`3vgNWppT6y5p#9Cv%q)z#7LF$HW)fro2b z9FN)lv;3V_c7;`oGXZUGM{Q7oz#TyC_I0Zw*pT*{5=xQ=s(LTHi>u{2&E<3K!3e;% zT3lO&ufc{%S+cb#5i#KkEzUf~-iMA(n#7A**3#sT%l9ipA~fYM7$(+R7H;bm>z`v+VeQ zqIc{s(2Sn$;dbdfn&#K8$|?g;m}S(CrDz5^d|A zK2PeR{{Tusa(U9U`%XqiA`}-9DYRfO574K@vv1TaFf&F?*PQ}AxFz0t30}Ng9?cN4 z#;k5wma{fTy~T`4_h0K)pjeapD}BE-ZRLOQXv)7BU~Sw_G2+8~+<%&@R0L3n0;m>M zT{NRMzcuKs?BXesKow9U4`wCHbebniCMb{Mt1WbF!nu)#>qyv1GdBi{{W}Q&Zc05 zIT!5&_FJRf!|J-Qc;%CxH?Xi#(0^lhMu^DEV~7Jw`(4z=nYBq1 z?JXsNWmC$QXD&`V?dJD`r@Y-s4vvToMMnDc3F z^SpmmBXxeJk7jOwbIIPyd#*vwN5jc^N}A-N$hI8zAoE*=9l9NFkClsdR3!9)11Fh~pv04i~JAoIUk`^xeonT)J%V`HU~+z2?g zZ2th&?EAla@?5g)*eQ)!%Nt{pS=)iQpUKJhyr1UmQw`XoF2-O;kC;Vn3$J%}mAgu^ zP9Ij%TwB`8o*Xo=B$Nl_83iNX42p9*$|$;nsibrbM#FJ~1n+No>|J=nl3? zTiE^Q=tNU(vheKVZ~?r4fpz-Tc;9JxKYJ%t6g!?96J)%LEt&Nj#%>nFy}6I9i|Zw* z5wgZ(nHXPq;~<^Vr=pp^A01mO%QSBqDx!1?x#pFz;7@x%k?vSMnN(iPNV@~+Nhd3t z`rY67kH%SYGKkonRy9D*tj+6l#s2{QAL}YbgqdSq-H^Ksg5Mx5mOM0^+}vK`{Bsk; zF+Iiid46)Lvj-6&1BVe{CUhTJ8b=H}&sit_D~&lu8>ZczGmzN?leNo3v*=sV{!LA-FixrTM(KZ16s&{bXVSe++Tf2+V9Ty zjvz1!Kn)>1DLst+e$oExAH1xd610+r5*KpJm`5qJYq}hpa7XJoPu}@TZpS9i7(3{U zk%N+ve0Y3|{aSLLSyRj$&1p#{AQMy7s&adn)7AUG3&$KRIr=~nc3v1Q5+|Wg~`7-cHXFN9tgYcP$jf8$yTLaM_=hpOj=B zSh2!M_s!@^lqMMCT!G2#OUT_>Jssxdxi|j+*GN{Cm9~{t5XJ;>?=NMyBbE@y{?VVj zll+|`#T;Z6XK3+x6tKRmvBvOh>;BEhv@%EP;Pso%j(15$DH-?Y90rl70y_?lQhUn6 zcBF0OQe^VBC9T4rRPX8gy6HBl@^LDz&VX1;A<0F8^*{5n{{ZayrV=ohCJzue$c#%9 zey!$CN=N?yXphwH{=d<5xKR`(8a&q-0W4W~yIv2ABKIFYwyhDyw=zkTfZ`DPLyK~r zKgtm#^T;I;#~p_gIZ<`L1kDEWDB$;B?J|4a_hYAYk|`9(@;g{GiWDX-!8&aDN4(xi zv--IDulSEl=izN$Sma%IYp!h}5%7c2cdtKv83NsFKOpB+j2@b7*z$YdS^oetK+_53 z2Xrt)XrcqAKl@JEMaV+Yf|X`e3t*<|mKj_IZtfApk^Z;WQ^&Q!|L_BJiA)JNr$ z)k7?Q+YZ$_Nhpa~)l`tuoio!Y@`rcHM;rR@jPCtK2IkFBO(tGx08*%-9%j>lQ`oy)C|set8^J5SvRRB9d7z*W%pIzALBL%3t4kp z;F|e3n+d%?d0I!W##zSF8*#I)u`t#(zX~HKmY3Fgzt}CbZ5di(%(pC#{3j#j5&WzB zudAi3n@-WJY}Qj{Vr+_S%H)#o9%(-6eoy%h*`6?R+N3fvkg={sw`ETEd$T@G=z7@w z-^p+1{*`3DKk8JAv^0Sp60Y}Uw?&tv+y1Izti5$*Nww92wxrQS1qh-P zG|(u3tQt{JU24%>D8Qh|%U00^t+fqkHI!DYwMoFyiU25uL`Hy`+t3qGb@ieKluZ(v zLBP?{RjAa1LsGScDXFKLl+{5?HI>G)^`=}+bqPAxUXUSMY75bhY71*CYbs14fQn@S zsTOERqaudXfSSr8D=kQMb?YryYFZFA*;4DTQqh;K4NAReYLS<$CWiHzQicr}G@2l3 z6+*R8jZvxaDe;<;)oLr!)0&EnH7JRpHPiASQvJHWo^v| zZ32jSs6}n6CIjcA1X6z;Xf$c5U;o$LsWFP{c_gzaZ#%yfvU(Jani#YUmO27SAiH-X z;i`6crHGh7$|4CFUe@(_$h5fPZ;dv|a8Xw{i6s{s^kZ(bll#x?zEZr^qh^o;7dcXG zCwVyq^}arRz$8WYcrvIZ2NMP2^nKlE0@ygiB5^X}8pG(e5_tPVc91Nl!YMcti&b8) zl-{GKPyOGm-Sx72Ilh24sTwgB&~8sw);9w#hw}Q$f?ej#D@Hhst&w*8r`@XXJ$e%b zJ2le5)D;SK=lfkc@gxlC2vrc5C3JRYp56V_nM9GJr&s!!JFW6o?_zWSqO3~Rne58o zKbDWbEKP@VV2cjzmB=BS-Hc=CVaEPevMELk z@-Phx1v+KCPVV+8Qv}{sWKm2zmW-j%h1eH zCQ6IDnKf9?GY?gaoOLs@GlV)bb9wUU)AA%1NeqEcC*ESCgC=xT&#K-2v#YM!%tq3u zNRgMOuGsRMrCk0mFRk^JxGYO7S;uExM>WYATZX2Q3}6>kiwjs^h+m-V9b2ufdjZ?# z_V`ul8%SWpfnjo2E0gW7kMyjH7qhrmF_XZ=30NP^NVMq-;#8mw zLnVQ+Cg0Mg?E6q)V>~M@?sOnX2Xi*^v_eZ~Hz?;I4sOTVIj&x^NnV9p7Aw!4$@VYZ 
zqh&;36IM4WtVoV4np2|#cWi?leB@gnr&2V$(UXG|4XFU|m8V5E(8O}@e>KRm)acE; z$GL45-JX;=cj>kS0-RoJX~cV#Y@L*w&QY+D0O=crTW>ip%c>Sckx>B|F|bj@&gZGY zd~T`-1ynFSr$R18UGRJe##NO`dwQpqZC&3Ci*bHDZ9-s=CSr~lm14^5?4@^W0yq0SP;9=td)IEd-QNh0+Y(Kq z;RT4Fi#`~%!Ac1wby=lI)86kqtzd=?k7*IZwY)))Z*v28?7=Lax#IZDk$c_MF{Rqj z=M>SjHUcu`jG^S}CfajpcpsJPr2D@no($6}yTvHRLe`bm%$Di4>MT>gTgKlk*!1xd z&chw*ON@vPk-Tfby*`}2yZ->QOWFp@6M)e(-q$@H+hy(Yl5#wr@5$4*(KNBX^$csG z62~ZucV=a2c^LZ5>~Mdy?>AEWxn)V?+6=MEuG6bBO37g&uTKc2UEEu})oRbxkR84p zUJ|VoEZB&0F0rWcvHBAF%6h-YV%==;q#*FHO?R7xXxO`_nwgdWRr zWX-&fkFD{-_1Nb?UNekua3o%4Tdr~MIOf^-AKmY0=a#-13EEp`Ualj689NmLAr#alDeRPEDt&kJZxlDFhPlnXM8Gj6s`;t-D#aeVrSWm+Lt% z)Z3eG*xCnC8$7N}nV5jr9$VJ(`@PQldHt8KEA$d1(tu%wnkRC+I`MS(e)11_Px#A6 z6YnD^VO$!-;gjXL-+Q=!TPGzQPWxWy=d;dCG2Sq3l#!&ITbSd{9OIAw0LSG20GPIc zf`&3S;74_mLcue0gmNT5S@*tI$x$(YxL}@rvNV?JMa|i;#TU4{&nty`5^UkNR8YvV zvpJQ?;tC3GqMz#|NW63T(i4H7YHjS2kQD*ST!8-oiEMd4ipRV1e5F~rUAjcsP#Cwk zB(k8%O9gGCcZTzn>N?LSi}Td|n`(|l z-bxJa8+5>&L#vKPPCuddyqDoEr)V2a%>;_YcAZj3IVMa`9?!(FN98?l=Q|{eWR@A7 zgQFr|Ga@jz81fQ&d)?e#@A8SB^OpK9{Jye&b?Oe7v0C|{I{+7G8S+_-tI{th6ec&)$G!G%F9OlBe94g zk@=o!E{e=Gjkvcf@3`aL{bg+N&9dPhig6b}2Qv;O`hK#7oJ$;QwMZEe$&M%ylr8!Z za-5#-BuKY^zLT+Y%w?1@9hj&Z;J$-R<$khH)%S}gkPErCO>hoDRB4y^bP_8FuH+1E z#j*>DZWUPL@5?XlsF6&P6wyHGUf2Ftl*99MV`GOlWl}&=!a`Iss2y`}`hRyH`#)P* zwrrC7Qb@(3fQ#vVkZgEUtnt;*GdR?RgtnI#X4kzVy-_Y8Aa*e@zXh)LzeZee;QD{) zsM2kweRRZ}5S&eb;`{Gc7g_3kV3869Xps~f3o${`CjD%_p&cu?4J5WHin-wc=sTQG zC@{|`nnl*(X`Cc%P0=w#DG*PCqgAo{XHBN z4Jxr)u0dBGAGf=~A3q)5+{Un(BalfVlGvPv?mj)ASk$mT97W`%g6#T}8R67soS(7d<(f}~<#{=FB!wuXgN`*) z$+U~hydp^S47syhok_NeKcQ3=&9uVyC73W$ro#Nkf-T=Hkda+fmB_io0yP7<)%5EQ&Y(5y3TRa6#e8@Z zzwbR&XgRj+spV%`7&Mm6fc%6{SA}Zq(#UQY$nAUcjE|{K{PMjsWsX@U0Tj6>Hc{;O z@!q>?5#bEQ%D4oyE3Bz=*2L+_{rBd+N}_gOg4T4%8@k@$dDqFUcF(Y)tW(HA4JDZv z@Zgy=$?Ig1DB;?6jhiUrIOHuMRRPp`dG(R|su+>uifQq6V4$nmV|P5B^`vhp>f+@+ zWm0mj79k~YP{yMtjw7f*c;bp|e5^^zOCPS1yqymf?!?^l6L#;VU9`692M=i&*=|lu z&E;k3`p+A^+>HQbQb|*8ZXVq|CRU^z0!IlgaLsH~bXwgmR*8o2hL8x=R>4O^dJp+@ zlEIE`cW;*b-aXuOVY^8>aS| zM`AKJ?qH54#_m(N=Lj)-Lld%^xqy)BmDRAiJJ*7 z*lV4*i5f8J=Klb@uRUm0c289mIfF-1|DZ3)^F=9U=I#lxoC8G)#n)bUN zS-z9W`@P*_FTxX`2$91B5O3YbyPV!v?y80OHCIihMtfL=%SUfo@v85?H(0yA8q>`R zCUB81^|igYV=I2Yi=p+E>BaDLWGhHSfQ}f`fk(Bz>mI93ofR=v3YuL&)B1G6nGSHK z_P>l|recbVVuUD;F^Ua$QSdLb@riu3JgtS8*beHO-nu2mJ7iCu^|n$8KZKfeE!ja> z`Iq$306R^qGwiWHg49F<+o1DvepM-jZO!=TODf#kBW~RX{uR#47zK@iHd1@XrS$0= z%$tKEq};gTTmJwr3N0csfJ=>JItdu!7wBcxc>F06aE4xh)VSAM@o@C%Pc(c{aY`mQObC{{UCWZE0$xjzcr!r!hJq zJs+nnNAQN$BaH0ug&~h9Y zCsFiQ^_40((RWuM<^jZZS$ba>c{+aU3ZeFhu)VBUE!gojy`NY99VD<&_cm5gEC#Hj ztTmRp)i57vqyGh0E~{8952I%kJ+%;*s&pA+Ncbm`%mRJ zom^Y5{L4ttx+27~k`$L3;xp*+$xCnB7u zkC!zWCgs#h{5f#~rz5AT`p+vyk;cl*8ID9Wjt%arI*(SK_jf?Z@j{Wrz)h)QwmYcE z;^N}#AFSFu1#J67aO}oAWVd1)(`M@z4h`NOK2wzGW`*!(T-lB&2)uZ<9d!g|c{YG; zaTUS_);q7KhgTof&}|YcK1-BJR}eYa-=`MLvGPgvyEj?or1TKZFad0E7jBiTcyZmy z*Wv1;{{Y8G-JnPqNin)gHBfDl9!(RcyRO-`V(b-^10~2TNgr?1-O#F{y_!48dld&D z8HRhk+-n~kf0_Eu+9sPDOS}OMm4cBZpLE<1#pCVCnXq<_5Ue7|EA&^N7snnsDQx>P zjLdHHII^)Agl>!O{{X9|BgO){rs)>p6t_M4*?9h+x~`rmLUIp-NWAT@4;NIEa*3b( z+_J?S8@iC8o!S5lWReY}j@Ns7UQ>DczbQ&}kU&FlM=8NWFk$N-btkL!{ty6)ts6{X1R-B1=?rcp_O@r8;N*9n>vIp3_m#CZc1@|Ga;bsH$dQX@mvehE`=;vp zsQ$X8VLXUc(8n-qF(W=7Z!gu=ku2yL^R^=h<7Q{aT-iP|gh~GZ@gHC4AneS{=!mRX zmonIx_`IS+$3F5ttNtT-z4vj^FUUR!9xaubkVeAbvzvpF&wB^Z&z7b|v+QXV!z`~Q zf=qdslgV^__mZTtGm@J_URPndL;+w1`D zc|BM9RL4aNZ;Xa>3*C^|2}U5}(369w8is3+p&1P`LFiND`|urWZvd zkZyDI_xSYEg-T~=%JEn$HiF)}i^uV{2zGF+Os>vS>=`V!99(f@%S7{k_BLwQ6VT?d zwpb->7DZy$V5%^+yYhR_bxC+t5S0W(EqBloezmydlj`WD+@UfV5I1>LTJCy#%=eO_ 
zda5{5R|@PdfEyAoim=968V@HQ7VDDbJg?HfR1KO!s@g6HLY9!)B~F{iyH9)S`D(5Z z`&QBGXHIP99fz_kOa3+U9T}*-3ZH zl#)QEMB-d~`<>)(A*{q!FdkM#QHAA{CxqvinSfVHq0cU^f-Q>u^aI@;bD z+-}^uy*Zvb_x!Imva&Z>27s$CNe@kav&a1QPVKsoz zlUaH|>4@Ewo7SsTkE}E|uDuO5^@h5#pyNYoD#}nu6^4k4!YV)_gwdK3DL^%rq9VH0 zAXj{qwyYwg;74U`4Rt7YtrQ}(R0a;UL}$%Y3Z@^L=g&(uVjGG&1s{TnqUsa)=&wRe z{z{QVfy`}w%IuGu@Ygzn&L~BOmbd7r2seY{$L%m_1RLq6j~+RggmWH$9z1CsoDCH$ zQP5jjC~K-afB(?4b!H5?TpvEH4@-OWGOXG*Qp>Ac+>|!dFQb>IQ;n`Pv}D>_l8Ist zBf9cQRL8U|k@>HskDT{>f8!vx8mq9)h1$oM-htWKOCsWnsUur|hd{CjRokJ4J0-QrB=`(f7im#JTCs{Hnmh!Q(sA%C*7|4elaGUFVp6mXQr+uFY zl2aK)Q-A>G*te{c)}#44a>=`8T(U(xLfI?h#UxGjzpAFh1ndL=$fV@a1}du+yuAE4 zd%fTOONkiTV8S~V0NVTmEy}cQzxuup^^hQu%uOAx7a*)h0l!0MXY%{Vw{q`2Wo&a! zNQ%WwYpRZQQh5IAh2mI~l5rCBj`S50AU#^b{1&NRxL;=B_Uod>K452p{e4(^Ir?|7^bIV>&H zYaUTJIWMm9X`Ew}k$_b$LZGvHYLQ9G?<3&lYMv0|f;K{+hGi!SZf+TTljYWWnx0!r z5C9z3`L|#Lxre9kY0I~(6kta3#lbczGV3DZ@773)j%0<6*k-`rKY1Uay7HB-Ke~CY z8At)FK<4vQ#beCdLQTcGHXSOzX;ZrlF)?23SBM+`09W;&_{yPfF>FEUFu7ZLTfUzM zk806iAl;d;1Z!?BJO2P}=pB%gh#?5FU#0BvEo-S3QPJ$u<-I)CfTfRST-fRkzLyUN zkHIFPIo)ocj^1xZjxxvz4Qy@=#-7XL>p18XvZg9-3AtZvcfQ4KsuXkoK{2K^)6dkTxNUFm7g|i}0wae?}j$PC>1Yf6h{No$=ZumStR&Hn(Uo*tIYyCB$O zX?AO~5+W)Bwb=gvy?eQRU#z8=tmb9)jJ?;1kQO!8_m;M z^1nlB>Y@0Ll-K2wtFqVPREt39GQIkQeD zOP0Jh>$ZL8DGc%Fx!OAKbIEg~Yl1e1djxBQYnv;!1S556XNsrPh<06` zX|mZ%N~|uSw01rG(Yuj(dPI;onZg8i0E@+85L-S;{{ZCBzs3l$uu?dZJ{2f>JKP(U z`o9KN4jGk0V%F4xm&c&Bn$&uuV%icoo!N{o0~l0eEy?Ahi5sUc3}16V-w3PTJ*Rmx z-Li?Ss<~qhA6Xn+ru)fM@EKxh5gbD(1CHCpUrP<48c+6_ec%2)Dz?<%k^{L4?U1#I zZaH78PK8qZEi0XNc8p{aDIBgWF2@B}*OYkr+~egw_N)jUJ}}y6S+EjjS1flOOuyBC zD^l~zHXRIq77;9LskBH#Z=@kadMuM{@cMmsS?eiZ3u!Az&RTHUAw#9eRZc;@KV3`P zwuv^x_|3uH0CDK3OD|WCs@fiH&s+WPJ*92g0@u?WZ^i8UbzGE-*$lhMwt zG;(8~d#vE=X>t|WwuRZyjns;xGBn-g@tB?S=sxvmwl9NZVuVgMd5Kp=Zaw0+`@GE$ z)_TvZ`P;SZgi8$hSnpJCfg8>3`CbWs+241%2Flp_%rM=<9X+6(Tt&=EWGr*VZ@W+TehlQB z-bvNg3pnLxb~$7^X5=C$-S1!BqH=VnR6vb}@-xoIO!Zjzp2R#=X6_kSBAGU64S`hS zT}K*xYDem%!aw_z?XP24Hhtzn+EO=$*5OX_F9(v|L*%3PK5C>YVT=%RkVkWhE*TPd zHd+tx;8_uq*hhvsjS?6V4odXrxB4E3xUL}c^ zbmee=FZY|yN}2UW?GfzKY~pBIL`OTQ<73l%@FQo`!@7(0l`PX_a(gy0?%`PUR&Z8E z@$Tw%W_Vc;gpM}N(By1K4p}!R!JpRoKUI!KM@_86utWm+%%ra&7YpuvZuj$Cty=g? 
(GIT binary patch data omitted)

diff --git a/src/java/windows-unittests.cmake b/src/java/windows-unittests.cmake
index a9aa84c78..c7306a143 100644
--- a/src/java/windows-unittests.cmake
+++ b/src/java/windows-unittests.cmake
@@ -8,6 +8,7 @@ FILE(TO_NATIVE_PATH ${JAVA_PACKAGE_LIB_DIR} PACKAGE_LIB_DIR_NATIVE_PATH)
 execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH}
                 --console=plain
                 cmakeCheck
+                spotlessApply
                 -DcmakeBuildDir=${BINDIR_NATIVE_PATH}
                 -DnativeLibDir=${PACKAGE_LIB_DIR_NATIVE_PATH}
                 -Dorg.gradle.daemon=false
diff --git a/src/leakcheck.h b/src/leakcheck.h
new file mode 100644
index 000000000..b71161d9f
--- /dev/null
+++ b/src/leakcheck.h
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file will track the number of instances of each type that are created and destroyed. This is useful for
+// debugging memory leaks. To use this, just add the type to the LeakTypeList in this file. Then have that type
+// inherit from LeakChecked<(itself)>.
+//
+// On process exit, ValidateShutdown() will call LeakTypeList::Dump() and print out any types that have leaked.
+
+namespace Generators {
+struct GeneratorParams;
+struct Generator;
+struct Model;
+struct Search;
+struct Tensor;
+struct Tokenizer;
+struct TokenizerStream;
+
+template <typename... Types>
+struct LeakTypeList {
+  template <typename T>
+  static constexpr bool is_tracked = (std::is_same_v<T, Types> || ...);
+  static bool Dump();
+};
+
+using LeakTypes = LeakTypeList<GeneratorParams, Generator, Model, Search, Tensor, Tokenizer, TokenizerStream>;
+
+template <typename T>
+struct LeakChecked {
+  static_assert(LeakTypes::is_tracked<T>, "Please add this type to 'TrackedTypes' above");
+
+  LeakChecked() { ++count_; }
+  ~LeakChecked() { --count_; }
+
+  static int Count() { return count_; }
+
+ private:
+  static inline std::atomic<int> count_;
+};
+
+}  // namespace Generators
\ No newline at end of file
diff --git a/src/logging.cpp b/src/logging.cpp
index 2107e7bec..fb3229e66 100644
--- a/src/logging.cpp
+++ b/src/logging.cpp
@@ -38,6 +38,8 @@ void SetLogBool(std::string_view name, bool value) {
     g_log.model_logits = value;
   else if (name == "speculative_decoding")
     g_log.speculative_decoding = value;
+  else if (name == "ort_lib")
+    g_log.ort_lib = value;
   else
     throw JSON::unknown_value_error{};
 }
diff --git a/src/logging.h b/src/logging.h
index 428be1b26..a29cfdf25 100644
--- a/src/logging.h
+++ b/src/logging.h
@@ -43,6 +43,7 @@ struct LogItems {
   bool model_output_values{};  // After the model runs the output tensor values can be displayed
   bool model_logits{};         // Same as model_output_values but only for the logits
   bool speculative_decoding{}; // Log speculative decoding steps.
+  bool ort_lib{};              // Log the onnxruntime library loading and api calls.
 };

 extern LogItems g_log;
diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp
index 91ab17c73..84c8bec11 100644
--- a/src/models/captured_graph_pool.cpp
+++ b/src/models/captured_graph_pool.cpp
@@ -81,11 +81,11 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
   auto output_type = session_info_->GetOutputDataType(config_->model.decoder.outputs.logits);

-  if (output_type == Ort::TypeToTensorType<float>::type) {
+  if (output_type == Ort::TypeToTensorType<float>) {
     new_captured_graph->sb_logits32_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
   }

-  if (output_type == Ort::TypeToTensorType<Ort::Float16_t>::type) {
+  if (output_type == Ort::TypeToTensorType<Ort::Float16_t>) {
     new_captured_graph->sb_logits16_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
   }
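Editorial note (not part of the patch): the `::type` removals in this and the following files come from redefining Ort::TypeToTensorType as a C++17 variable template instead of a trait struct, so call sites drop the trailing `::type`. A minimal sketch of the two shapes; the struct form is renamed here so both can coexist in one snippet:

    // Pre-patch shape: a trait struct, read as TypeToTensorTypeTrait<float>::type
    template <typename T> struct TypeToTensorTypeTrait;
    template <> struct TypeToTensorTypeTrait<float> {
      static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    };

    // Post-patch shape: a variable template, read as TypeToTensorType<float>
    template <typename T>
    inline constexpr ONNXTensorElementDataType TypeToTensorType = T::Unsupported_Type;
    template <>
    inline constexpr ONNXTensorElementDataType TypeToTensorType<float> = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;

    static_assert(TypeToTensorTypeTrait<float>::type == TypeToTensorType<float>);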
diff --git a/src/models/debugging.cpp b/src/models/debugging.cpp
index ba8d1b55c..cdca66869 100644
--- a/src/models/debugging.cpp
+++ b/src/models/debugging.cpp
@@ -7,32 +7,15 @@ namespace Generators {
 static constexpr size_t c_value_count = 10;  // Dump this many values from the start of a tensor

+template <typename... Types>
+const char* TypeToString(ONNXTensorElementDataType type, Ort::TypeList<Types...>) {
+  const char* name = "(please add type to list)";
+  ((type == Ort::TypeToTensorType<Types> ? name = typeid(Types).name(), true : false) || ...);
+  return name;
+}
+
 const char* TypeToString(ONNXTensorElementDataType type) {
-  switch (type) {
-    case Ort::TypeToTensorType<bool>::type:
-      return "bool";
-    case Ort::TypeToTensorType<int8_t>::type:
-      return "int8";
-    case Ort::TypeToTensorType<uint8_t>::type:
-      return "uint8";
-    case Ort::TypeToTensorType<int16_t>::type:
-      return "int16";
-    case Ort::TypeToTensorType<uint16_t>::type:
-      return "uint16";
-    case Ort::TypeToTensorType<int32_t>::type:
-      return "int32";
-    case Ort::TypeToTensorType<int64_t>::type:
-      return "int64";
-    case Ort::TypeToTensorType<Ort::Float16_t>::type:
-      return "float16";
-    case Ort::TypeToTensorType<float>::type:
-      return "float32";
-    case Ort::TypeToTensorType<double>::type:
-      return "float64";
-    default:
-      assert(false);
-      return "(please add type to list)";
-  }
+  return TypeToString(type, Ort::TensorTypes{});
 }

 std::ostream& operator<<(std::ostream& stream, Ort::Float16_t v) {
@@ -40,6 +23,11 @@ std::ostream& operator<<(std::ostream& stream, Ort::Float16_t v) {
   return stream;
 }

+std::ostream& operator<<(std::ostream& stream, Ort::BFloat16_t v) {
+  stream << "BF16:" << v.value;  // TODO: implement conversion when useful
+  return stream;
+}
+
 template <typename T>
 void DumpSpan(std::ostream& stream, std::span<const T> values) {
   if (values.size() <= c_value_count) {
@@ -66,66 +54,20 @@ template void DumpCudaSpan(std::ostream&, std::span<const float>);
 template void DumpCudaSpan(std::ostream&, std::span<const Ort::Float16_t>);
 #endif

+template <typename... Types>
+bool DumpSpan(std::ostream& stream, ONNXTensorElementDataType type, const void* p_values_raw, size_t count, Ort::TypeList<Types...>) {
+  return ((type == Ort::TypeToTensorType<Types> && (DumpSpan(stream, std::span{reinterpret_cast<const Types*>(p_values_raw), count}), true)) || ...);
+}
+
 void DumpValues(std::ostream& stream, ONNXTensorElementDataType type, const void* p_values_raw, size_t count) {
   if (count == 0) {
     return;
   }
   stream << SGR::Fg_Green << "Values[ " << SGR::Reset;
+  if (!DumpSpan(stream, type, p_values_raw, count, Ort::TensorTypes{}))
+    stream << SGR::Fg_Red << "Unhandled data type" << SGR::Reset;

-  switch (type) {
-    case Ort::TypeToTensorType<bool>::type:
-      DumpSpan(stream, std::span(reinterpret_cast<const bool*>(p_values_raw), count));
-      break;
-
-    case Ort::TypeToTensorType<int8_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const int8_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<uint8_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const uint8_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<int16_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const int16_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<uint16_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const uint16_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<int32_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const int32_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<uint32_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const uint32_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<int64_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const int64_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<uint64_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const uint64_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<Ort::Float16_t>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const Ort::Float16_t*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<float>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const float*>(p_values_raw), count});
-      break;
-
-    case Ort::TypeToTensorType<double>::type:
-      DumpSpan(stream, std::span{reinterpret_cast<const double*>(p_values_raw), count});
-      break;
-
-    default:
-      stream << SGR::Fg_Red << "Unhandled data type" << SGR::Reset;
-      break;
-  }
   stream << SGR::Fg_Green << "]" << SGR::Reset << std::endl;
 }
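Editorial note (not part of the patch): the TypeList/fold-expression idiom introduced above replaces each hand-written type switch with a pack expansion that tests every tensor type in turn and short-circuits on the first match. A self-contained sketch of the same pattern, independent of the ORT types:

    #include <cstdio>

    enum DataType { DT_INT32, DT_FLOAT };

    template <typename...> struct TypeList {};

    template <typename T> constexpr DataType DataTypeOf = DataType(-1);
    template <> constexpr DataType DataTypeOf<int> = DT_INT32;
    template <> constexpr DataType DataTypeOf<float> = DT_FLOAT;

    template <typename... Ts>
    bool Dispatch(DataType type, const void* data, TypeList<Ts...>) {
      // For each T in Ts...: if the runtime tag matches, run the typed body and stop.
      return ((type == DataTypeOf<Ts> &&
               (printf("first element: %g\n", double(*static_cast<const Ts*>(data))), true)) ||
              ...);
    }

    int main() {
      float v[] = {1.5f, 2.5f};
      return Dispatch(DT_FLOAT, v, TypeList<int, float>{}) ? 0 : 1;  // prints "first element: 1.5"
    }

The fold yields false when no type matched, which is how DumpValues above detects an unhandled data type.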
diff --git a/src/models/env_utils.cpp b/src/models/env_utils.cpp
new file mode 100644
index 000000000..d4d14634c
--- /dev/null
+++ b/src/models/env_utils.cpp
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "env_utils.h"
+
+#include <cstdlib>
+
+#if _MSC_VER
+#include <Windows.h>
+#endif
+
+namespace Generators {
+
+std::string GetEnvironmentVariable(const char* var_name) {
+#if _MSC_VER
+  // Why getenv() should be avoided on Windows:
+  // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
+  // Instead use the Win32 API: GetEnvironmentVariableA()
+
+  // Max limit of an environment variable on Windows including the null-terminating character
+  constexpr DWORD kBufferSize = 32767;
+
+  // Create buffer to hold the result
+  std::string buffer(kBufferSize, '\0');
+
+  // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
+  // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
+  // Therefore, if the function succeeds, kBufferSize should be larger than char_count.
+  auto char_count = ::GetEnvironmentVariableA(var_name, buffer.data(), kBufferSize);
+
+  if (kBufferSize > char_count) {
+    buffer.resize(char_count);
+    return buffer;
+  }
+
+  return {};
+#else
+  const char* val = getenv(var_name);
+  return val == nullptr ? "" : std::string(val);
+#endif  // _MSC_VER
+}
+
+void GetEnvironmentVariable(const char* var_name, bool& value) {
+  std::string str_value = GetEnvironmentVariable(var_name);
+  if (str_value == "1" || str_value == "true") {
+    value = true;
+  } else if (str_value == "0" || str_value == "false") {
+    value = false;
+  } else if (!str_value.empty()) {
+    throw std::invalid_argument("Invalid value for environment variable " + std::string(var_name) + ": " + str_value +
+                                ". Expected '1' or 'true' for true, '0' or 'false' for false.");
+  }
+
+  // Otherwise, value will not be modified.
+}
+
+}  // namespace Generators
diff --git a/src/models/env_utils.h b/src/models/env_utils.h
new file mode 100644
index 000000000..d436bedda
--- /dev/null
+++ b/src/models/env_utils.h
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <string>
+
+namespace Generators {
+
+std::string GetEnvironmentVariable(const char* var_name);
+
+// This overload is used to get boolean environment variables.
+// If the environment variable is set to "1" or "true" (case-sensitive), value will be set to true;
+// if it is set to "0" or "false", value will be set to false; any other non-empty value throws.
+// Otherwise, value will not be modified.
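+//
+// Illustrative call (ORTGENAI_LOG_ORT_LIB is read exactly this way in onnxruntime_api.h):
+//   bool ort_lib = false;  // keeps its value if the variable is unset
+//   GetEnvironmentVariable("ORTGENAI_LOG_ORT_LIB", ort_lib);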
+void GetEnvironmentVariable(const char* var_name, bool& value);
+
+}  // namespace Generators
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 32d540937..4620bc1fa 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -13,14 +13,14 @@ InputIDs::InputIDs(const Model& model, State& state)
   type_ = model_.session_info_->GetInputDataType(name_);

   // If 64-bit, convert from 32-bit to 64-bit
-  if (type_ == Ort::TypeToTensorType<int64_t>::type) {
+  if (type_ == Ort::TypeToTensorType<int64_t>) {
     value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_);
     auto* p_data = value_->GetTensorMutableData<int64_t>();
     for (auto v : state_.params_->input_ids) {
       *p_data++ = v;
     }
   } else {
-    if (type_ != Ort::TypeToTensorType<int32_t>::type)
+    if (type_ != Ort::TypeToTensorType<int32_t>)
       throw std::runtime_error("InputIDs must be int64 or int32");
     value_ = OrtValue::CreateTensor(model.allocator_cpu_.GetInfo(), std::span(const_cast<int32_t*>(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_);
   }
@@ -63,7 +63,7 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {

 #if USE_DML
       if (model_.device_type_ == DeviceType::DML) {
-        value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>::type);
+        value_int32_ = sb_input_ids_int32_->CreateTensorOnStaticBuffer(shape_, Ort::TypeToTensorType<int32_t>);
       }
 #endif
     }
@@ -72,7 +72,7 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
   }

   // Update input_ids with next tokens, converting from 32-bit to 64-bit
-  if (type_ == Ort::TypeToTensorType<int64_t>::type) {
+  if (type_ == Ort::TypeToTensorType<int64_t>) {
     switch (model_.device_type_) {
 #if USE_CUDA
       case DeviceType::CUDA: {
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index cbac4bade..2c69c3cc9 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -106,7 +106,7 @@ void KV_Cache_Combined::PickPastState(std::span<int32_t> beam_indices, int
 }

 void KV_Cache_Combined::PickPastState(std::span<int32_t> beam_indices, int index) {
-  if (type_ == Ort::TypeToTensorType<float>::type) {
+  if (type_ == Ort::TypeToTensorType<float>) {
     PickPastState<float>(beam_indices, index);
   } else {
     PickPastState<Ort::Float16_t>(beam_indices, index);
@@ -295,7 +295,7 @@ void KV_Cache::PickPastState(std::span<int32_t> beam_indices, int index) {
 }

 void KV_Cache::PickPastState(std::span<int32_t> beam_indices, int index) {
-  if (type_ == Ort::TypeToTensorType<float>::type) {
+  if (type_ == Ort::TypeToTensorType<float>) {
     PickPastState<float>(beam_indices, index);
   } else {
     PickPastState<Ort::Float16_t>(beam_indices, index);
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index d189dd534..ba09c1560 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -17,10 +17,10 @@ Logits::Logits(const Model& model, State& state)
   output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);

   if (state_.GetCapturedGraphInfo()) {
-    if (type_ == Ort::TypeToTensorType<float>::type) {
+    if (type_ == Ort::TypeToTensorType<float>) {
       sb_logits32_ = state_.GetCapturedGraphInfo()->sb_logits32_.get();
     }
-    if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type) {
+    if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
       sb_logits16_ = state_.GetCapturedGraphInfo()->sb_logits16_.get();
     }
   }
@@ -53,9 +53,16 @@ RoamingArray<float> Logits::Get() {
     // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it
     output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+
+#if USE_DML
+    if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
+      logits_of_last_token_fp32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
+    }
+#endif
+
     logits_of_last_token = output_last_tokens_.get();

-    size_t element_size = type_ == Ort::TypeToTensorType<float>::type ? 4 : 2;
+    size_t element_size = type_ == Ort::TypeToTensorType<float> ? 4 : 2;
     size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process

     const auto* input_ids = state_.params_->input_ids.data();
@@ -119,24 +126,27 @@ RoamingArray<float> Logits::Get() {
   }

   // Convert from float16 to float32 if necessary
-  if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type) {
-    std::unique_ptr<OrtValue> logits_of_last_token_fp32;
+  if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
 #if USE_DML
     if (model_.device_type_ == DeviceType::DML) {
       DmlHelpers::DmlCastInputToOutput(
           model_.GetDmlExecutionContext(),
           *model_.allocator_device_,
           *logits_of_last_token,
-          logits_of_last_token_fp32,
+          logits_of_last_token_fp32_,
           model_.GetDmlDevice(),
           model_.GetOrtDmlApi(),
           logits_cast_command_list_state_);
+
+      logits_of_last_token = logits_of_last_token_fp32_.get();
     } else
 #endif
+    {
+      std::unique_ptr<OrtValue> logits_of_last_token_fp32;
       ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_last_token, logits_of_last_token_fp32, model_.device_type_, model_.cuda_stream_);
-
-    output_last_tokens_ = std::move(logits_of_last_token_fp32);  // use output_last_tokens_ to hold the fp32 logits
-    logits_of_last_token = output_last_tokens_.get();
+      output_last_tokens_ = std::move(logits_of_last_token_fp32);  // use output_last_tokens_ to hold the fp32 logits
+      logits_of_last_token = output_last_tokens_.get();
+    }
   }

 #if USE_DML
@@ -226,7 +236,7 @@ void Logits::Update() {
     return;
   }

-  StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t>::type ? sb_logits16_ : sb_logits32_;
+  StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
   output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
                            : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
   state_.outputs_[output_index_] = output_raw_.get();
diff --git a/src/models/logits.h b/src/models/logits.h
index 94e57c355..ed7a281b3 100644
--- a/src/models/logits.h
+++ b/src/models/logits.h
@@ -49,6 +49,7 @@ struct Logits {

 #if USE_DML
   DmlReusedCommandListState logits_cast_command_list_state_{};
+  std::unique_ptr<OrtValue> logits_of_last_token_fp32_;
   std::unique_ptr<OrtValue> value32_cpu_;
 #endif
 };
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 245966597..3e17fd009 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -316,6 +316,40 @@ void Model::CreateSessionOptions() {
     ort_options.EnableProfiling(profile_file_prefix.c_str());
   }

+  if (options.disable_cpu_ep_fallback.has_value()) {
+    if (options.disable_cpu_ep_fallback.value())
+      ort_options.DisableCpuEpFallback();
+    else
+      ort_options.EnableCpuEpFallback();
+  }
+
+  if (options.disable_quant_qdq.has_value()) {
+    if (options.disable_quant_qdq.value())
+      ort_options.DisableQuantQdq();
+    else
+      ort_options.EnableQuantQdq();
+  }
+
+  if (options.enable_quant_qdq_cleanup.has_value()) {
+    if (options.enable_quant_qdq_cleanup.value())
+      ort_options.EnableQuantQdqCleanup();
+    else
+      ort_options.DisableQuantQdqCleanup();
+  }
+
+  if (options.ep_context_enable.has_value()) {
+    if (options.ep_context_enable.value())
+      ort_options.SetEpContextEnable();
+  }
+
+  if (options.ep_context_embed_mode.has_value()) {
+    ort_options.SetEpContextEmbedMode(options.ep_context_embed_mode.value().c_str());
+  }
+
+  if (options.ep_context_file_path.has_value()) {
+    ort_options.SetEpContextFilePath(options.ep_context_file_path.value().c_str());
+  }
+
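Editorial note (not part of the patch): each new setter above funnels into a single ORT session-config entry rather than a dedicated C API call; the wrappers added to onnxruntime_inline.h later in this patch reduce to, e.g.:

    // Equivalent effect of ort_options.DisableCpuEpFallback() (see onnxruntime_inline.h below):
    session_options.AddConfigEntry("session.disable_cpu_ep_fallback", "1");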
(provider_options.name == "cuda") { auto ort_provider_options = OrtCUDAProviderOptionsV2::Create(); @@ -348,11 +382,11 @@ void Model::CreateSessionOptions() { auto current_module_path = CurrentModulePath(); dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path); - auto directml_dll = current_module_path + "DirectML.dll"; - wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll.c_str(), nullptr, 0)); + constexpr auto directml_dll = "DirectML.dll"; + wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0)); THROW_LAST_ERROR_IF(!smart_directml_dll); - if (LoadLibraryEx(directml_dll.c_str(), nullptr, 0) == NULL) { + if (LoadLibraryEx(directml_dll, nullptr, 0) == NULL) { throw std::runtime_error("DirectML.dll not found"); } @@ -388,6 +422,13 @@ void Model::CreateSessionOptions() { device_type_ = DeviceType::DML; // We use a DML allocator for input/output caches, but other tensors will use CPU tensors #endif + } else if (provider_options.name == "qnn") { + std::unordered_map opts; + for (auto& option : provider_options.options) { + opts.emplace(option.first, option.second); + } + + ort_options.AppendExecutionProvider("QNN", opts); } else throw std::runtime_error("Unknown provider type: " + provider_options.name); } @@ -406,7 +447,7 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path) { if (config->model.type == "gpt2") return std::make_shared(std::move(config), ort_env); - if (config->model.type == "llama" || config->model.type == "gemma" || config->model.type == "gemma2" || config->model.type == "mistral" || config->model.type == "phi" || config->model.type == "phi3" || config->model.type == "phi3small" || config->model.type == "qwen2") + if (config->model.type == "llama" || config->model.type == "gemma" || config->model.type == "gemma2" || config->model.type == "mistral" || config->model.type == "phi" || config->model.type == "phi3" || config->model.type == "phi3small" || config->model.type == "phimoe" || config->model.type == "qwen2") return std::make_shared(std::move(config), ort_env); if (config->model.type == "whisper") return std::make_shared(std::move(config), ort_env); @@ -428,7 +469,7 @@ std::shared_ptr CreateGeneratorParams() { void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr& p_out, DeviceType device_type, cudaStream_t stream) { auto shape_info = in.GetTensorTypeAndShapeInfo(); auto shape = shape_info->GetShape(); - assert(shape_info->GetElementType() == Ort::TypeToTensorType::type); + assert(shape_info->GetElementType() == Ort::TypeToTensorType); bool allocate_p_out = p_out == nullptr; if (p_out) { @@ -467,7 +508,7 @@ void ConvertFp32ToFp16(OrtAllocator& allocator, OrtValue& in, std::unique_ptrGetShape(); - assert(shape_info->GetElementType() == Ort::TypeToTensorType::type); + assert(shape_info->GetElementType() == Ort::TypeToTensorType); bool allocate_p_out = p_out == nullptr; if (p_out) { diff --git a/src/models/model.h b/src/models/model.h index 5a6e39947..8bfccb8da 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -51,7 +51,7 @@ struct State { int current_batch_size_{0}; }; -struct TokenizerStream { +struct TokenizerStream : LeakChecked { TokenizerStream(const Tokenizer& tokenizer); const std::string& Decode(int32_t token); @@ -66,7 +66,7 @@ struct TokenizerStream { // Sequence length is vector.size()/count std::vector PadInputs(std::span> sequences, int32_t pad_token_id); -struct Tokenizer : std::enable_shared_from_this { +struct Tokenizer : std::enable_shared_from_this, 
diff --git a/src/models/model.h b/src/models/model.h
index 5a6e39947..8bfccb8da 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -51,7 +51,7 @@ struct State {
   int current_batch_size_{0};
 };

-struct TokenizerStream {
+struct TokenizerStream : LeakChecked<TokenizerStream> {
   TokenizerStream(const Tokenizer& tokenizer);

   const std::string& Decode(int32_t token);
@@ -66,7 +66,7 @@ struct TokenizerStream {
 // Sequence length is vector.size()/count
 std::vector<int32_t> PadInputs(std::span<std::span<const int32_t>> sequences, int32_t pad_token_id);

-struct Tokenizer : std::enable_shared_from_this<Tokenizer> {
+struct Tokenizer : std::enable_shared_from_this<Tokenizer>, LeakChecked<Tokenizer> {
   Tokenizer(Config& config);

   std::unique_ptr<TokenizerStream> CreateStream() const;
@@ -108,7 +108,7 @@ struct SessionInfo {
   std::unordered_map<std::string, ONNXTensorElementDataType> inputs_, outputs_;
 };

-struct Model : std::enable_shared_from_this<Model> {
+struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
   Model(std::unique_ptr<Config> config);

   virtual ~Model();
diff --git a/src/models/onnxruntime_api.h b/src/models/onnxruntime_api.h
index eebfe0278..ddc617f79 100644
--- a/src/models/onnxruntime_api.h
+++ b/src/models/onnxruntime_api.h
@@ -73,6 +73,7 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
 #include "onnxruntime_c_api.h"
 #include "../span.h"
 #include "../logging.h"
+#include "env_utils.h"

 #if defined(__ANDROID__)
 #include <android/log.h>
@@ -93,11 +94,14 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
 #define PATH_MAX (4096)
 #endif

-#define LOG_DEBUG(...) Generators::Log("debug", __VA_ARGS__)
-#define LOG_INFO(...) Generators::Log("info", __VA_ARGS__)
-#define LOG_WARN(...) Generators::Log("warning", __VA_ARGS__)
-#define LOG_ERROR(...) Generators::Log("error", __VA_ARGS__)
-#define LOG_FATAL(...) Generators::Log("fatal", __VA_ARGS__)
+#define LOG_WHEN_ENABLED(LOG_FUNC) \
+  if (Generators::g_log.enabled && Generators::g_log.ort_lib) LOG_FUNC
+
+#define LOG_DEBUG(...) LOG_WHEN_ENABLED(Generators::Log("debug", __VA_ARGS__))
+#define LOG_INFO(...) LOG_WHEN_ENABLED(Generators::Log("info", __VA_ARGS__))
+#define LOG_WARN(...) LOG_WHEN_ENABLED(Generators::Log("warning", __VA_ARGS__))
+#define LOG_ERROR(...) LOG_WHEN_ENABLED(Generators::Log("error", __VA_ARGS__))
+#define LOG_FATAL(...) LOG_WHEN_ENABLED(Generators::Log("fatal", __VA_ARGS__))

 #endif

@@ -112,21 +116,6 @@ using OrtApiBaseFn = const OrtApiBase* (*)(void);
 inline const OrtApi* api{};

 #if defined(__linux__)
-inline void* LoadDynamicLibraryIfExists(const std::string& path) {
-  LOG_INFO("Attempting to dlopen %s", path.c_str());
-  void* ort_lib_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
-  if (ort_lib_handle == nullptr) {
-    return nullptr;
-  }
-
-#if !defined(__ANDROID__)  // RTLD_DI_ORIGIN not available on Android
-  char pathname[PATH_MAX];
-  dlinfo((void*)ort_lib_handle, RTLD_DI_ORIGIN, &pathname);
-  LOG_INFO("Loaded native library at %s", pathname);
-#endif
-  return ort_lib_handle;
-}
-
 inline std::string GetCurrentModuleDir() {
   Dl_info dl_info;
   dladdr((void*)GetCurrentModuleDir, &dl_info);
@@ -140,6 +129,31 @@ inline std::string GetCurrentModuleDir() {
   return module_directory;
 }

+inline void* LoadDynamicLibraryIfExists(const std::string& path) {
+  LOG_INFO("Attempting to dlopen %s", path.c_str());
+  void* ort_lib_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
+  if (ort_lib_handle == nullptr) {
+    char* err = dlerror();
+    LOG_WARN("Error while dlopen: %s", (err != nullptr ? err : "Unknown"));
+    // Trying current dir
+    std::string current_module_dir = GetCurrentModuleDir();
+    std::string local_path{current_module_dir + "/" + path};
+    LOG_INFO("Attempting to dlopen %s", local_path.c_str());
+    ort_lib_handle = dlopen(local_path.c_str(), RTLD_NOW | RTLD_LOCAL);
+  }
+  if (ort_lib_handle) {
+#if !defined(__ANDROID__)  // RTLD_DI_ORIGIN not available on Android
+    char pathname[PATH_MAX];
+    dlinfo((void*)ort_lib_handle, RTLD_DI_ORIGIN, &pathname);
+    LOG_INFO("Loaded native library at %s", pathname);
+#endif
+  } else {
+    char* err = dlerror();
+    LOG_WARN("Error while dlopen: %s", (err != nullptr ?
err : "Unknown")); + } + return ort_lib_handle; +} + inline void InitApiWithDynamicFn(OrtApiBaseFn ort_api_base_fn) { if (ort_api_base_fn == nullptr) { throw std::runtime_error("OrtGetApiBase not found"); @@ -175,6 +189,13 @@ inline void InitApi() { return; } + bool ort_lib = false; + Generators::GetEnvironmentVariable("ORTGENAI_LOG_ORT_LIB", ort_lib); + if (ort_lib) { + Generators::SetLogBool("enabled", true); + Generators::SetLogBool("ort_lib", true); + } + #if defined(__linux__) // If the GenAI library links against the onnxruntime library, it will have a dependency on a specific // version of OrtGetApiBase. @@ -202,38 +223,12 @@ inline void InitApi() { #if !defined(__ANDROID__) if (ort_lib_handle == nullptr) { - const std::array target_libraries = { - std::string("libonnxruntime.so"), - std::string("libonnxruntime.so.1.18.0"), - std::string("libonnxruntime.so.1.19.0"), - std::string("libonnxruntime.so.1.20.0")}; - - // Search parent directory - std::string current_module_dir = GetCurrentModuleDir(); - for (const std::string& lib_name : target_libraries) { - std::string pip_path{current_module_dir + "/" + lib_name}; - ort_lib_handle = LoadDynamicLibraryIfExists(pip_path); - if (ort_lib_handle != nullptr) { - break; - } - } - - if (ort_lib_handle == nullptr) { - // Search for pip installation - for (const std::string& lib_name : target_libraries) { - std::string pip_path{current_module_dir + "/../onnxruntime/capi/" + lib_name}; - ort_lib_handle = LoadDynamicLibraryIfExists(pip_path); - if (ort_lib_handle != nullptr) { - break; - } - } - } + ort_lib_handle = LoadDynamicLibraryIfExists("libonnxruntime.so.1"); } #endif if (ort_lib_handle == nullptr) { - char* err = dlerror(); - throw std::runtime_error(std::string("Failed to load ") + path.c_str() + ": " + (err != nullptr ? err : "Unknown")); + throw std::runtime_error(std::string("Failed to load onnxruntime. 
Set ORTGENAI_LOG_ORT_LIB envvar to enable detailed logging."));
   }

   OrtApiBaseFn ort_api_base_fn = (OrtApiBaseFn)dlsym(ort_lib_handle, "OrtGetApiBase");
@@ -527,6 +522,19 @@ struct OrtSessionOptions {
   OrtSessionOptions& EnableCpuMemArena();   ///< Wraps OrtApi::EnableCpuMemArena
   OrtSessionOptions& DisableCpuMemArena();  ///< Wraps OrtApi::DisableCpuMemArena

+  OrtSessionOptions& EnableCpuEpFallback();
+  OrtSessionOptions& DisableCpuEpFallback();
+
+  OrtSessionOptions& EnableQuantQdq();
+  OrtSessionOptions& DisableQuantQdq();
+
+  OrtSessionOptions& EnableQuantQdqCleanup();
+  OrtSessionOptions& DisableQuantQdqCleanup();
+
+  OrtSessionOptions& SetEpContextEnable();
+  OrtSessionOptions& SetEpContextEmbedMode(const char* mode);
+  OrtSessionOptions& SetEpContextFilePath(const char* file_path);
+
   OrtSessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);  ///< Wraps OrtApi::SetOptimizedModelFilePath

   OrtSessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);  ///< Wraps OrtApi::EnableProfiling
diff --git a/src/models/onnxruntime_inline.h b/src/models/onnxruntime_inline.h
index cd7180e93..bb0ef70b2 100644
--- a/src/models/onnxruntime_inline.h
+++ b/src/models/onnxruntime_inline.h
@@ -57,61 +57,40 @@ struct StringAllocator : OrtAllocator {
   std::string string_;
 };

-// This template converts a C++ type into it's ONNXTensorElementDataType
+template <typename... Types>
+struct TypeList {};
+
+using TensorTypes = TypeList<bool, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, float, double, Float16_t, BFloat16_t>;
+
+// Variable templates to convert a C++ type into its ONNXTensorElementDataType
 template <typename T>
-struct TypeToTensorType;
+inline constexpr ONNXTensorElementDataType TypeToTensorType = T::Unsupported_Type;  // Force a compile error if hit, please add specialized version if type is valid
 template <>
-struct TypeToTensorType<float> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<bool> = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
 template <>
-struct TypeToTensorType<Float16_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<int8_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
 template <>
-struct TypeToTensorType<BFloat16_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<uint8_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
 template <>
-struct TypeToTensorType<double> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<int16_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
 template <>
-struct TypeToTensorType<int8_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<uint16_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
 template <>
-struct TypeToTensorType<int16_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<int32_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
 template <>
-struct TypeToTensorType<int32_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<uint32_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
 template <>
-struct TypeToTensorType<int64_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<int64_t> =
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
 template <>
-struct TypeToTensorType<uint8_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<uint64_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
 template <>
-struct TypeToTensorType<uint16_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<float> = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
 template <>
-struct TypeToTensorType<uint32_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<double> = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
 template <>
-struct TypeToTensorType<uint64_t> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<Float16_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
 template <>
-struct TypeToTensorType<bool> {
-  static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
-};
+inline constexpr ONNXTensorElementDataType TypeToTensorType<BFloat16_t> = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;

 inline std::vector<std::string> GetAvailableProviders() {
   int len;
@@ -534,6 +513,42 @@ inline OrtSessionOptions& OrtSessionOptions::DisableCpuMemArena() {
   return *this;
 }

+inline OrtSessionOptions& OrtSessionOptions::EnableCpuEpFallback() {
+  return AddConfigEntry("session.disable_cpu_ep_fallback", "0");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::DisableCpuEpFallback() {
+  return AddConfigEntry("session.disable_cpu_ep_fallback", "1");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::EnableQuantQdq() {
+  return AddConfigEntry("session.disable_quant_qdq", "0");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::DisableQuantQdq() {
+  return AddConfigEntry("session.disable_quant_qdq", "1");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::EnableQuantQdqCleanup() {
+  return AddConfigEntry("session.enable_quant_qdq_cleanup", "1");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::DisableQuantQdqCleanup() {
+  return AddConfigEntry("session.enable_quant_qdq_cleanup", "0");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::SetEpContextEnable() {
+  return AddConfigEntry("ep.context_enable", "1");
+}
+
+inline OrtSessionOptions& OrtSessionOptions::SetEpContextEmbedMode(const char* mode) {
+  return AddConfigEntry("ep.context_embed_mode", mode);
+}
+
+inline OrtSessionOptions& OrtSessionOptions::SetEpContextFilePath(const char* file_path) {
+  return AddConfigEntry("ep.context_file_path", file_path);
+}
+
 inline OrtSessionOptions& OrtSessionOptions::SetExecutionMode(ExecutionMode execution_mode) {
   Ort::ThrowOnError(Ort::api->SetSessionExecutionMode(this, execution_mode));
   return *this;
@@ -1132,7 +1147,7 @@ inline void OrtValue::FillSparseTensorBlockSparse(const OrtMemoryInfo& data_mem_

 template <typename T>
 inline std::unique_ptr<OrtValue> OrtValue::CreateTensor(const OrtMemoryInfo& info, std::span<T> p_data, std::span<const int64_t> shape) {
-  return CreateTensor(info, p_data.data(), p_data.size_bytes(), shape, Ort::TypeToTensorType<T>::type);
+  return CreateTensor(info, p_data.data(), p_data.size_bytes(), shape, Ort::TypeToTensorType<T>);
 }

 inline std::unique_ptr<OrtValue> OrtValue::CreateTensor(const OrtMemoryInfo& info, void* p_data, size_t p_data_byte_count, std::span<const int64_t> shape,
@@ -1144,7 +1159,7 @@ inline std::unique_ptr<OrtValue> OrtValue::CreateTensor(const OrtMemoryInfo& inf

 template <typename T>
 inline std::unique_ptr<OrtValue> OrtValue::CreateTensor(OrtAllocator& allocator, std::span<const int64_t> shape) {
-
return CreateTensor(allocator, shape, Ort::TypeToTensorType::type); + return CreateTensor(allocator, shape, Ort::TypeToTensorType); } inline std::unique_ptr OrtValue::CreateTensor(OrtAllocator& allocator, std::span shape, ONNXTensorElementDataType type) { @@ -1158,7 +1173,7 @@ inline std::unique_ptr OrtValue::CreateTensor(OrtAllocator& allocator, template inline std::unique_ptr OrtValue::CreateSparseTensor(const OrtMemoryInfo& info, T* p_data, const OrtShape& dense_shape, const OrtShape& values_shape) { - return CreateSparseTensor(info, p_data, dense_shape, values_shape, Ort::TypeToTensorType::type); + return CreateSparseTensor(info, p_data, dense_shape, values_shape, Ort::TypeToTensorType); } inline std::unique_ptr OrtValue::CreateSparseTensor(const OrtMemoryInfo& info, void* p_data, const OrtShape& dense_shape, @@ -1171,7 +1186,7 @@ inline std::unique_ptr OrtValue::CreateSparseTensor(const OrtMemoryInf template inline std::unique_ptr OrtValue::CreateSparseTensor(OrtAllocator* allocator, const OrtShape& dense_shape) { - return CreateSparseTensor(allocator, dense_shape, Ort::TypeToTensorType::type); + return CreateSparseTensor(allocator, dense_shape, Ort::TypeToTensorType); } inline std::unique_ptr OrtValue::CreateSparseTensor(OrtAllocator* allocator, const OrtShape& dense_shape, diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index aaded5817..f0745fe06 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -15,7 +15,7 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArrayHasInput(model_.config_->model.decoder.inputs.attention_mask); has_posid_input_ = model_.session_info_->HasInput(model_.config_->model.decoder.inputs.position_ids); - type_ = Ort::TypeToTensorType::type; + type_ = Ort::TypeToTensorType; if (has_mask_input_) { type_ = model_.session_info_->GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); } @@ -28,7 +28,7 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArrayGetInputDataType(model_.config_->model.decoder.inputs.position_ids); } - if (type_ != Ort::TypeToTensorType::type && type_ != Ort::TypeToTensorType::type) + if (type_ != Ort::TypeToTensorType && type_ != Ort::TypeToTensorType) throw std::runtime_error("position_ids & attention_mask only support int32 or int64 types"); std::array shape{state_.params_->batch_size, state_.params_->sequence_length}; // Only batch_size initially, as we haven't expanded over the beams yet @@ -38,7 +38,7 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArrayBatchBeamSize()); - if (type_ == Ort::TypeToTensorType::type) + if (type_ == Ort::TypeToTensorType) InitializeTensors(shape, sequence_lengths_unk); else InitializeTensors(shape, sequence_lengths_unk); @@ -117,7 +117,7 @@ void PositionInputs::UpdatePositionIDs(int current_length) { #if USE_CUDA position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_); assert(model_.device_type_ == DeviceType::CUDA); - if (type_ == Ort::TypeToTensorType::type) { + if (type_ == Ort::TypeToTensorType) { cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(), position_ids_next_->GetTensorData(), sizeof(int32_t) * position_ids_shape_[0], @@ -137,7 +137,7 @@ void PositionInputs::UpdatePositionIDs(int current_length) { ComPtr target_resource; Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); - if (type_ == 
Ort::TypeToTensorType::type) { + if (type_ == Ort::TypeToTensorType) { auto source = std::span(position_ids_next_->GetTensorData(), sizeof(int32_t) * position_ids_shape_[0]); model_.GetDmlUploadHeap()->BeginUploadToGpu( @@ -182,7 +182,7 @@ void PositionInputs::UpdatePositionIDs(int current_length) { } break; #endif case DeviceType::CPU: { - if (type_ == Ort::TypeToTensorType::type) + if (type_ == Ort::TypeToTensorType) UpdatePositionIDsImpl(); else UpdatePositionIDsImpl(); @@ -190,7 +190,7 @@ void PositionInputs::UpdatePositionIDs(int current_length) { } #if USE_CUDA case DeviceType::CUDA: - if (type_ == Ort::TypeToTensorType::type) + if (type_ == Ort::TypeToTensorType) cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), static_cast(position_ids_shape_[0]), model_.cuda_stream_); else cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), static_cast(position_ids_shape_[0]), model_.cuda_stream_); @@ -224,7 +224,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) { attention_mask_shape_[1] = state_.params_->search.max_length; attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); if (is_first_mask_update_) { - if (type_ == Ort::TypeToTensorType::type) { + if (type_ == Ort::TypeToTensorType) { cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(), 0, sizeof(int32_t) * attention_mask_shape_[0] * attention_mask_shape_[1], @@ -294,7 +294,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) { } #endif case DeviceType::CPU: { - if (type_ == Ort::TypeToTensorType::type) + if (type_ == Ort::TypeToTensorType) UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), current_length); @@ -308,7 +308,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) { case DeviceType::CUDA: { int max_seq_len = sb_attention_mask_ ? 
state_.params_->search.max_length : current_length; bool update_only = sb_attention_mask_ && !is_first_mask_update_; - if (type_ == Ort::TypeToTensorType::type) { + if (type_ == Ort::TypeToTensorType) { cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), static_cast(attention_mask_shape_[0]), diff --git a/src/models/utils.cpp b/src/models/utils.cpp index 0e4256e79..7f4d43629 100644 --- a/src/models/utils.cpp +++ b/src/models/utils.cpp @@ -6,31 +6,31 @@ namespace Generators { size_t SizeOf(ONNXTensorElementDataType type) { switch (type) { - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(uint8_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(int8_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(uint16_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(int16_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(uint32_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(int32_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(int64_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(int64_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(bool); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(float); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(double); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(Ort::Float16_t); - case Ort::TypeToTensorType::type: + case Ort::TypeToTensorType: return sizeof(Ort::BFloat16_t); default: throw std::runtime_error("Unsupported ONNXTensorElementDataType in GetTypeSize"); diff --git a/src/ort_genai.h b/src/ort_genai.h index d0c1d0c75..4a83b69e2 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -232,6 +232,12 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_GetSequenceData(this, index); } + std::unique_ptr GetOutput(const char* name) { + OgaTensor* out; + OgaCheckResult(OgaGenerator_GetOutput(this, name, &out)); + return std::unique_ptr(out); + } + #if __cplusplus >= 202002L std::span GetSequence(size_t index) const { return {GetSequenceData(index), GetSequenceCount(index)}; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 6f26d2857..f835c2111 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -208,6 +208,50 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) { + OGA_TRY + auto& generator = *reinterpret_cast(oga_generator); + auto* ortvalue_output = generator.state_->GetOutput(name); + auto type_info = ortvalue_output->GetTensorTypeAndShapeInfo(); + std::unique_ptr ortvalue_clone = OrtValue::CreateTensor(generator.model_->allocator_cpu_, + type_info->GetShape(), + type_info->GetElementType()); + // Copy data to ortvalue_clone + auto element_size = Generators::SizeOf(type_info->GetElementType()); + auto data_size = type_info->GetElementCount() * element_size; + if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::CUDA) { +#if USE_CUDA + cudaMemcpy(ortvalue_clone->GetTensorMutableRawData(), 
ortvalue_output->GetTensorMutableRawData(), data_size, cudaMemcpyDeviceToHost); +#endif + } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU && generator.model_->device_type_ == Generators::DeviceType::DML) { +#if USE_DML + ComPtr gpu_resource; + Ort::ThrowOnError(generator.model_->GetOrtDmlApi()->GetD3D12ResourceFromAllocation( + generator.model_->allocator_device_, + ortvalue_output->GetTensorMutableRawData(), + &gpu_resource)); + auto cpu_tensor = ortvalue_clone->GetTensorMutableRawData(); + generator.model_->GetDmlReadbackHeap()->ReadbackFromGpu( + std::span(reinterpret_cast(cpu_tensor), data_size), + gpu_resource.Get(), + 0, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS); +#endif + } else if (ortvalue_output->GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_CPU) { + std::copy(static_cast(ortvalue_output->GetTensorMutableRawData()), + static_cast(ortvalue_output->GetTensorMutableRawData()) + data_size, + static_cast(ortvalue_clone->GetTensorMutableRawData())); + } else { + throw std::runtime_error("Unsupported Device type: " + std::to_string(ortvalue_output->GetTensorMemoryInfo().GetDeviceType())); + } + + auto tensor = std::make_shared(std::move(ortvalue_clone)); + tensor->external_owner_ = tensor; + *out = reinterpret_cast(tensor.get()); + return nullptr; + OGA_CATCH +} + size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* oga_generator, size_t index) { auto& generator = *reinterpret_cast(oga_generator); return generator.GetSequence(static_cast(index)).GetCPU().size(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index ec97ce4e5..7b1f084c2 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -224,6 +224,14 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); +/* + * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor + * and will be released when the OgaTensor is destroyed + * \param[in] generator The generator to run the GetOutput on the name provided and the out pointer to store the output + * \return OgaResult containing the error message if the computation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out); + /* * \brief Returns the number of tokens in the sequence at the given index. * \param[in] generator The generator to get the count of the tokens for the sequence at the given index. 
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 189f4135c..6628862a3 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -13,9 +13,6 @@ if(NOT (CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Linu target_link_libraries(python PRIVATE ${ONNXRUNTIME_LIB}) endif() -if(CMAKE_SYSTEM_NAME STREQUAL "Linux") - set_property(TARGET python APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN/../onnxruntime/capi/") -endif() set_target_properties(python PROPERTIES OUTPUT_NAME "onnxruntime_genai") if(CMAKE_GENERATOR_TOOLSET MATCHES "Visual Studio") @@ -27,7 +24,7 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER) cmake_policy(SET CMP0104 OLD) enable_language(CUDA) set_target_properties(python PROPERTIES LINKER_LANGUAGE CUDA) - target_link_libraries(python PRIVATE cublasLt cublas cudnn curand cufft cudart) + target_link_libraries(python PRIVATE cublas curand cudart) endif() # Avoid warning of Calling FetchContent_Populate(Lib) is deprecated temporarily diff --git a/src/python/__init__.py.in b/src/python/__init__.py.in index 63b949a97..3811d48b9 100644 --- a/src/python/__init__.py.in +++ b/src/python/__init__.py.in @@ -6,7 +6,7 @@ from onnxruntime_genai import _dll_directory __version__ = "@VERSION_INFO@" __id__ = "@TARGET_NAME@" -_dll_directory.add_onnxruntime_dependency() +_dll_directory.add_onnxruntime_dependency(__id__) try: from onnxruntime_genai.@PACKAGE_DIR_NAME@ import * diff --git a/src/python/py/_dll_directory.py b/src/python/py/_dll_directory.py index 9479003ff..0e47dbdce 100644 --- a/src/python/py/_dll_directory.py +++ b/src/python/py/_dll_directory.py @@ -8,10 +8,16 @@ def _is_windows(): return sys.platform.startswith("win") -def add_onnxruntime_dependency(): - """Add the onnxruntime DLL directory to the DLL search path. +def _is_linux(): + return sys.platform.startswith("linux") + + +def add_onnxruntime_dependency(package_id: str): + """Add the onnxruntime shared library dependency. - This function is a no-op on non-Windows platforms. + On Windows, this function adds the onnxruntime DLL directory to the DLL search path. + On Linux, this function loads the onnxruntime shared library and its dependencies + so that they can be found by the dynamic linker. """ if _is_windows(): import importlib.util @@ -21,6 +27,35 @@ def add_onnxruntime_dependency(): ort_package_path = ort_package.submodule_search_locations[0] os.add_dll_directory(os.path.join(ort_package_path, "capi")) + if package_id == "onnxruntime-genai-directml": + # Load the DirectML.dll library to avoid loading it again in the native code. + # This avoids needing to know the exact path of the shared library from native code. + dml_path = os.path.join(ort_package_path, "capi", "DirectML.dll") + if not os.path.exists(dml_path): + raise ImportError("Could not find the DirectML.dll library. " + "Please check if the onnxruntime directml package is installed.") + + import ctypes + _ = ctypes.CDLL(dml_path) + + elif _is_linux(): + import importlib.util + import ctypes + import glob + + ort_package = importlib.util.find_spec("onnxruntime") + if not ort_package: + raise ImportError("Could not find the onnxruntime package.") + + # Load the onnxruntime shared library here since we can find the path in python with ease. + # This avoids needing to know the exact path of the shared library from native code. 
+ ort_package_path = ort_package.submodule_search_locations[0] + ort_lib_path = glob.glob(os.path.join(ort_package_path, "capi", "libonnxruntime.so*")) + if not ort_lib_path: + raise ImportError("Could not find the onnxruntime shared library.") + + _ = ctypes.CDLL(ort_lib_path[0]) + def add_cuda_dependency(): """Add the CUDA DLL directory to the DLL search path. diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 0f2ab24c1..3e2e00b02 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -32,7 +32,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers self.vocab_size = config.vocab_size - self.activation = config.hidden_activation if hasattr(config, "hidden_activation") else config.hidden_act + self.activation = config.hidden_activation if hasattr(config, "hidden_activation") and config.hidden_activation is not None else config.hidden_act self.model_name_or_path = config._name_or_path self.model_type = config.architectures[0] @@ -1620,11 +1620,11 @@ def make_model(self, input_path): from onnxruntime_genai.models.quantized_model import QuantModel q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size - model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size) + model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers) else: # Load PyTorch model - extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir} - model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs) + extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} + model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, cache_dir=self.cache_dir, use_auth_token=True, trust_remote_code=True, **extra_kwargs) # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py index 48c4ec7bd..f15f21cb9 100644 --- a/src/python/py/models/quantized_model.py +++ b/src/python/py/models/quantized_model.py @@ -83,17 +83,18 @@ def is_empty(self): class QuantizedModel: - def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size): + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers): self.quant_type = quant_type self.embedding = TensorModule() self.final_norm = TensorModule() self.lm_head = TensorModule() - self.layers = [] + self.layers = {} + self.num_layers = num_layers layer_id = 0 for weight_file in os.listdir(input_path): if weight_file.endswith(".safetensors"): - module = QuantizedDecoderLayer(layer_id, bits, group_size) + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size)) weights = 
load_file(os.path.join(input_path, weight_file)) # Map weights to modules @@ -115,10 +116,9 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in else: curr_layer_id = int(name.split(".")[2]) if curr_layer_id != layer_id: - # Add layer to list of modules - self.layers.append(module) + # Switch layer module used layer_id = curr_layer_id - module = QuantizedDecoderLayer(layer_id, bits, group_size) + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, bits, group_size)) # Map weights and biases of norm, attention, and feed-forward network # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj @@ -288,11 +288,7 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in module.mlp.up_proj.g_idx = tensor else: raise NotImplementedError(f"{name} in your quantized model is not recognized.") - - if not module.is_empty(): - # Append final layer to list of layers - self.layers.append(module) - + # Set LM head weights + biases if not already set if self.lm_head.weight is None: # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias) @@ -301,6 +297,7 @@ def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, in self.lm_head.bias = self.embedding.bias # Sort list of layers by layer id + self.layers = list(self.layers.values()) self.layers.sort(key=lambda m: m.layer_id) # Set properties of each layer based on quantization type @@ -521,11 +518,13 @@ def pack_ort_format(self, module, intweight): class AWQModel(QuantizedModel): - def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size): - super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + def __init__(self, quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) # Unpack and repack all `QuantizedTensorModule` classes in model for i, layer in enumerate(self.layers): + if i >= self.num_layers: + break print(f"Unpacking and repacking layer {i}") # Unpack and repack all `QuantizedTensorModule` classes in attention @@ -586,14 +585,16 @@ def reverse_reorder_tensor(self, tensor, bits): class GPTQModel(QuantizedModel): - def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size): - super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size) + def __init__(self, quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers): + super().__init__(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers) # Unpack and repack all `QuantizedTensorModule` classes in model for i, layer in enumerate(self.layers): + if i >= self.num_layers: + break print(f"Unpacking and repacking layer {i}") - # Unpack and repack all `QuantizedTensorModule` classes in attention + # Unpack and repack all `QuantizedTensorModule` classes in attention for name, q_tensors in layer.self_attn.__dict__.items(): if isinstance(q_tensors, QuantizedTensorModule) and q_tensors.qweight is not None: self.handle_qzeros(q_tensors) @@ -642,16 +643,16 @@ def __init__(self, module): class QuantModel: @staticmethod - def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, 
kv_size, intermediate_size):
+    def from_pretrained(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers):
         """
         Unpack quantized weights in PyTorch models, store them in a standard format, and repack them into ONNX Runtime's format. Also performs any pre-processing and post-processing when unpacking the quantized weights.
         """
         if quant_type == "awq":
-            model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size)
+            model = AWQModel(quant_type, input_path, bits, group_size, q_size, kv_size, intermediate_size, num_layers)
         elif quant_type == "gptq":
-            model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size)
+            model = GPTQModel(quant_type, input_path, bits, group_size, use_g_idx, q_size, kv_size, intermediate_size, num_layers)
         else:
             raise NotImplementedError(f"The {quant_type} quantized model is not currently supported.")

diff --git a/src/python/python.cpp b/src/python/python.cpp
index 81ea20bec..229fb90d3 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -46,29 +46,29 @@ pybind11::array_t<T> ToPython(std::span<T> v) {
 ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) {
   switch (type.num()) {
     case pybind11::detail::npy_api::NPY_BOOL_:
-      return Ort::TypeToTensorType<bool>::type;
+      return Ort::TypeToTensorType<bool>;
     case pybind11::detail::npy_api::NPY_UINT8_:
-      return Ort::TypeToTensorType<uint8_t>::type;
+      return Ort::TypeToTensorType<uint8_t>;
     case pybind11::detail::npy_api::NPY_INT8_:
-      return Ort::TypeToTensorType<int8_t>::type;
+      return Ort::TypeToTensorType<int8_t>;
     case pybind11::detail::npy_api::NPY_UINT16_:
-      return Ort::TypeToTensorType<uint16_t>::type;
+      return Ort::TypeToTensorType<uint16_t>;
     case pybind11::detail::npy_api::NPY_INT16_:
-      return Ort::TypeToTensorType<int16_t>::type;
+      return Ort::TypeToTensorType<int16_t>;
     case pybind11::detail::npy_api::NPY_UINT32_:
-      return Ort::TypeToTensorType<uint32_t>::type;
+      return Ort::TypeToTensorType<uint32_t>;
     case pybind11::detail::npy_api::NPY_INT32_:
-      return Ort::TypeToTensorType<int32_t>::type;
+      return Ort::TypeToTensorType<int32_t>;
     case pybind11::detail::npy_api::NPY_UINT64_:
-      return Ort::TypeToTensorType<uint64_t>::type;
+      return Ort::TypeToTensorType<uint64_t>;
     case pybind11::detail::npy_api::NPY_INT64_:
-      return Ort::TypeToTensorType<int64_t>::type;
+      return Ort::TypeToTensorType<int64_t>;
     case 23 /*NPY_FLOAT16*/:
-      return Ort::TypeToTensorType<Ort::Float16_t>::type;
+      return Ort::TypeToTensorType<Ort::Float16_t>;
     case pybind11::detail::npy_api::NPY_FLOAT_:
-      return Ort::TypeToTensorType<float>::type;
+      return Ort::TypeToTensorType<float>;
     case pybind11::detail::npy_api::NPY_DOUBLE_:
-      return Ort::TypeToTensorType<double>::type;
+      return Ort::TypeToTensorType<double>;
     default:
       throw std::runtime_error("Unsupported numpy type");
   }
@@ -76,64 +76,46 @@ ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) {
 
 int ToNumpyType(ONNXTensorElementDataType type) {
   switch (type) {
-    case Ort::TypeToTensorType<bool>::type:
+    case Ort::TypeToTensorType<bool>:
       return pybind11::detail::npy_api::NPY_BOOL_;
-    case Ort::TypeToTensorType<int8_t>::type:
+    case Ort::TypeToTensorType<int8_t>:
       return pybind11::detail::npy_api::NPY_INT8_;
-    case Ort::TypeToTensorType<uint8_t>::type:
+    case Ort::TypeToTensorType<uint8_t>:
       return pybind11::detail::npy_api::NPY_UINT8_;
-    case Ort::TypeToTensorType<int16_t>::type:
+    case Ort::TypeToTensorType<int16_t>:
       return pybind11::detail::npy_api::NPY_INT16_;
-    case Ort::TypeToTensorType<uint16_t>::type:
+    case Ort::TypeToTensorType<uint16_t>:
       return pybind11::detail::npy_api::NPY_UINT16_;
-    case Ort::TypeToTensorType<int32_t>::type:
+    case Ort::TypeToTensorType<int32_t>:
       return pybind11::detail::npy_api::NPY_INT32_;
-    case Ort::TypeToTensorType<uint32_t>::type:
+    case Ort::TypeToTensorType<uint32_t>:
       return pybind11::detail::npy_api::NPY_UINT32_;
-    case Ort::TypeToTensorType<int64_t>::type:
+    case Ort::TypeToTensorType<int64_t>:
       return pybind11::detail::npy_api::NPY_INT64_;
-    case Ort::TypeToTensorType<uint64_t>::type:
+    case Ort::TypeToTensorType<uint64_t>:
       return pybind11::detail::npy_api::NPY_UINT64_;
-    case Ort::TypeToTensorType<Ort::Float16_t>::type:
+    case Ort::TypeToTensorType<Ort::Float16_t>:
       return 23 /*NPY_FLOAT16*/;
-    case Ort::TypeToTensorType<float>::type:
+    case Ort::TypeToTensorType<float>:
       return pybind11::detail::npy_api::NPY_FLOAT_;
-    case Ort::TypeToTensorType<double>::type:
+    case Ort::TypeToTensorType<double>:
       return pybind11::detail::npy_api::NPY_DOUBLE_;
     default:
       throw std::runtime_error("Unsupported onnx type");
   }
 }
 
+template <typename... Types>
+std::string ToFormatDescriptor(ONNXTensorElementDataType type, Ort::TypeList<Types...>) {
+  std::string result;
+  ((type == Ort::TypeToTensorType<Types> ? result = pybind11::format_descriptor<Types>::format(), true : false) || ...);
+  if (result.empty())
+    throw std::runtime_error("Unsupported onnx type");
+  return result;
+}
+
 std::string ToFormatDescriptor(ONNXTensorElementDataType type) {
-  switch (type) {
-    case Ort::TypeToTensorType<bool>::type:
-      return pybind11::format_descriptor<bool>::format();
-    case Ort::TypeToTensorType<uint8_t>::type:
-      return pybind11::format_descriptor<uint8_t>::format();
-    case Ort::TypeToTensorType<int8_t>::type:
-      return pybind11::format_descriptor<int8_t>::format();
-    case Ort::TypeToTensorType<uint16_t>::type:
-      return pybind11::format_descriptor<uint16_t>::format();
-    case Ort::TypeToTensorType<int16_t>::type:
-      return pybind11::format_descriptor<int16_t>::format();
-    case Ort::TypeToTensorType<uint32_t>::type:
-      return pybind11::format_descriptor<uint32_t>::format();
-    case Ort::TypeToTensorType<int32_t>::type:
-      return pybind11::format_descriptor<int32_t>::format();
-    case Ort::TypeToTensorType<uint64_t>::type:
-      return pybind11::format_descriptor<uint64_t>::format();
-    case Ort::TypeToTensorType<int64_t>::type:
-      return pybind11::format_descriptor<int64_t>::format();
-    case Ort::TypeToTensorType<Ort::Float16_t>::type:
-      return pybind11::format_descriptor<Ort::Float16_t>::format();
-    case Ort::TypeToTensorType<float>::type:
-      return pybind11::format_descriptor<float>::format();
-    case Ort::TypeToTensorType<double>::type:
-      return pybind11::format_descriptor<double>::format();
-    default:
-      throw std::runtime_error("Unsupported onnx type");
-  }
+  return ToFormatDescriptor(type, Ort::TensorTypes{});
 }
 
 std::unique_ptr<OrtValue> ToOrtValue(pybind11::array& v) {
@@ -519,29 +501,9 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
 
   m.def("set_log_options", &SetLogOptions);
 
-  m.def("is_cuda_available", []() {
-#if USE_CUDA
-    return true;
-#else
-    return false;
-#endif
-  });
-
-  m.def("is_dml_available", []() {
-#if USE_DML
-    return true;
-#else
-    return false;
-#endif
-  });
-
-  m.def("is_rocm_available", []() {
-#if USE_ROCM
-    return true;
-#else
-    return false;
-#endif
-  });
+  m.def("is_cuda_available", []() { return USE_CUDA != 0; });
+  m.def("is_dml_available", []() { return USE_DML != 0; });
+  m.def("is_rocm_available", []() { return USE_ROCM != 0; });
 
   m.def("set_current_gpu_device_id", [](int device_id) { Ort::SetCurrentGpuDeviceId(device_id); });
   m.def("get_current_gpu_device_id", []() { return Ort::GetCurrentGpuDeviceId(); });
diff --git a/src/python/setup.py.in b/src/python/setup.py.in
index d4e5acf4f..a12fb0ec7 100644
--- a/src/python/setup.py.in
+++ b/src/python/setup.py.in
@@ -1,6 +1,7 @@
 from setuptools import setup, find_packages
 from setuptools.dist import Distribution
 import sys
+import os
 from os import path
 
 if sys.version_info < (3, 0):
@@ -24,18 +25,27 @@ package_name = '@TARGET_NAME@'
 
 def _onnxruntime_dependency() -> str:
     dependency = None
+    # Use dev version as default since CI tests use nightly version for testing
+    ort_version = os.environ.get("ONNXRUNTIME_VERSION", "1.19.0.dev20240805002")
+    is_nightly = True if "dev" in ort_version else False
+
     if package_name == "onnxruntime-genai":
-        dependency = "onnxruntime"
+        dependency = "onnxruntime" if not is_nightly else "ort-nightly"
+
+        import platform
+        # win arm64 whls are only available in onnxruntime-qnn
+        if platform.machine() == "ARM64" and sys.platform.startswith("win"):
+            dependency = "onnxruntime-qnn" if not is_nightly else "ort-nightly-qnn"
     elif package_name == "onnxruntime-genai-cuda":
-        dependency = "onnxruntime-gpu"
+        dependency = "onnxruntime-gpu" if not is_nightly else "ort-nightly-gpu"
     elif package_name == "onnxruntime-genai-directml":
-        dependency = "onnxruntime-directml"
+        dependency = "onnxruntime-directml" if not is_nightly else "ort-nightly-directml"
     elif package_name == "onnxruntime-genai-rocm":
-        dependency = "onnxruntime-rocm"
+        dependency = "onnxruntime-rocm" if not is_nightly else "ort-nightly-rocm"
     else:
         raise ValueError(f'Unable to determine the onnxruntime dependency for {package_name}.')
 
-    return dependency
+    return dependency if not ort_version else dependency + ">=" + ort_version
 
 
 setup(
@@ -48,8 +58,8 @@ setup(
     include_package_data=True,
     package_data={'': ['*.pyd', '*.dll', '*.so*'] + extras},
     install_requires=[
-        'numpy<2',
-        # _onnxruntime_dependency(),  # Uncomment this when the onnxruntime stable release contains the ort shared lib
+        'numpy>=1.21.6',
+        _onnxruntime_dependency(),
     ],
     distclass=BinaryDistribution,
     author="Microsoft Corporation",
diff --git a/src/search.h b/src/search.h
index 4d5df105a..730159d03 100644
--- a/src/search.h
+++ b/src/search.h
@@ -5,7 +5,7 @@
 
 namespace Generators {
 
-struct Search {
+struct Search : LeakChecked<Search> {
   Search(const GeneratorParams& params) : params_{params.shared_from_this()} {}
   virtual ~Search() = default;
 
@@ -104,7 +104,7 @@ struct BeamSearch_Cpu : Search_Cpu {
   RoamingArray<int32_t> GetSequence(size_t index) override;
   RoamingArray<int32_t> GetSequence(size_t batch_id, size_t beam_id);
 
-  bool IsDone() const;
+  bool IsDone() const override;
 
   void SelectTop() override;
 
diff --git a/src/tensor.h b/src/tensor.h
index 6fcde20c9..25b6dd706 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -2,7 +2,7 @@
 // Licensed under the MIT License.
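For reference, `LeakChecked<T>` (the new base of `Search` above and `Tensor` below) is a CRTP-style instance counter used to flag leaked objects at shutdown. A minimal sketch of the pattern, assuming a simple per-type atomic counter; this is a hypothetical simplification, not the repository's exact implementation:

    // Sketch of a CRTP instance counter like the LeakChecked<T> base used in
    // the hunks around this note (hypothetical simplified version).
    #include <atomic>
    #include <cstdio>

    template <typename T>
    struct LeakChecked {
      LeakChecked() { ++count_; }
      LeakChecked(const LeakChecked&) { ++count_; }
      ~LeakChecked() { --count_; }

      // Non-zero at shutdown indicates leaked T instances.
      static int Count() { return count_.load(); }

     private:
      static inline std::atomic<int> count_{};  // one counter per derived type T
    };

    struct Search : LeakChecked<Search> {};
    struct Tensor : LeakChecked<Tensor> {};

    int main() {
      Search s1, s2;
      std::printf("live Search objects: %d\n", Search::Count());  // prints 2
      std::printf("live Tensor objects: %d\n", Tensor::Count());  // prints 0
    }

Because each derived type instantiates its own `LeakChecked<T>`, every tracked class gets an independent counter without any shared registry.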
 namespace Generators {
 
-struct Tensor : std::enable_shared_from_this<Tensor> {
+struct Tensor : std::enable_shared_from_this<Tensor>, LeakChecked<Tensor> {
   Tensor() = default;
   Tensor(std::unique_ptr<OrtValue> ort_tensor) : ort_tensor_{std::move(ort_tensor)} {}
 
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index eba5aff15..8e8cc13cb 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -173,6 +173,65 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
 }
 #endif
 
+TEST(CAPITests, GetOutputCAPI) {
+  std::vector<int64_t> input_ids_shape{2, 4};
+  std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};
+
+  auto input_sequence_length = input_ids_shape[1];
+  auto batch_size = input_ids_shape[0];
+  int max_length = 10;
+
+  // To generate this file:
+  // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
+  // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2)
+
+  auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
+
+  auto params = OgaGeneratorParams::Create(*model);
+  params->SetSearchOption("max_length", max_length);
+  params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size);
+
+  auto generator = OgaGenerator::Create(*model, *params);
+
+  // check prompt
+  // full logits has shape [2, 4, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 4, 5]
+  std::vector<float> expected_sampled_logits_prompt{0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f,
+                                                    0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f,
+                                                    0.30323744f, 0.0545997f, 0.03894716f, 0.11702324f, 0.0410665f,
+                                                    -0.12675379f, -0.04443946f, 0.14492269f, 0.03021223f, -0.03212897f,
+                                                    0.29694548f, 0.00955007f, 0.0430819f, 0.10063869f, 0.0437237f,
+                                                    0.27329233f, 0.00841076f, -0.1060291f, 0.11328877f, 0.13369876f,
+                                                    -0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f,
+                                                    0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f};
+
+  generator->ComputeLogits();
+  auto prompt_logits_ptr = generator->GetOutput("logits");
+  auto prompt_logits = static_cast<float*>(prompt_logits_ptr->Data());
+  int num_prompt_outputs_to_check = 40;
+  int sample_size = 200;
+  float tolerance = 0.001f;
+  // Verify outputs match expected outputs
+  for (int i = 0; i < num_prompt_outputs_to_check; i++) {
+    EXPECT_NEAR(expected_sampled_logits_prompt[i], prompt_logits[i*sample_size], tolerance);
+  }
+
+  generator->GenerateNextToken();
+  // check for the 1st token generation
+  // full logits has shape [2, 1, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 1, 5]
+  std::vector<float> expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f,
+                                                       0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f};
+
+  generator->ComputeLogits();
+  auto token_gen_logits_ptr = generator->GetOutput("logits");
+  auto token_gen_logits = static_cast<float*>(token_gen_logits_ptr->Data());
+  int num_token_gen_outputs_to_check = 10;
+
+  for (int i = 0; i < num_token_gen_outputs_to_check; i++) {
+    EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance);
+  }
+  generator->GenerateNextToken();
+}
+
 #if TEST_PHI2
 
 struct Phi2Test {
diff --git a/test/csharp/Microsoft.ML.OnnxRuntimeGenAI.Tests.csproj b/test/csharp/Microsoft.ML.OnnxRuntimeGenAI.Tests.csproj
index 978deb04e..7d6c8ea74 100644
--- a/test/csharp/Microsoft.ML.OnnxRuntimeGenAI.Tests.csproj
+++ b/test/csharp/Microsoft.ML.OnnxRuntimeGenAI.Tests.csproj
@@ -1,7 +1,7 @@
 <Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     false
     AnyCPU
     true
@@ -27,22 +27,42 @@
- + + PreserveNewest false - + PreserveNewest false - + PreserveNewest false - + PreserveNewest false + + + + PreserveNewest + false + + + PreserveNewest + false + + + PreserveNewest + false + + + PreserveNewest + false + PreserveNewest false
diff --git a/test/python/_test_utils.py b/test/python/_test_utils.py
index a314454ba..808f8930e 100644
--- a/test/python/_test_utils.py
+++ b/test/python/_test_utils.py
@@ -52,32 +52,85 @@ def run_subprocess(
     return completed_process
 
 
-def download_models(download_path, device):
-    # python -m onnxruntime_genai.models.builder -m <model_name> -p int4 -e cpu -o <output_path> --extra_options num_hidden_layers=1
-    model_names = {
-        "cpu": {
-            "phi-2": "microsoft/phi-2",
-        },
-        "cuda": {
-            "phi-2": "microsoft/phi-2",
-        },
+def get_model_paths():
+    hf_paths = {
+        "phi-2": "microsoft/phi-2",
+        # "phi-3-mini": "microsoft/Phi-3-mini-128k-instruct",
     }
-    for model_name, model_identifier in model_names[device].items():
-        model_path = os.path.join(download_path, device, model_name)
-        if not os.path.exists(model_path):
-            command = [
-                sys.executable,
-                "-m",
-                "onnxruntime_genai.models.builder",
-                "-m",
-                model_identifier,
-                "-p",
-                "int4",
-                "-e",
-                device,
-                "-o",
-                model_path,
-                "--extra_options",
-                "num_hidden_layers=1",
-            ]
-            run_subprocess(command).check_returncode()
+
+    ci_data_path = os.path.join("/", "data", "ortgenai_pytorch_models")
+    if not os.path.exists(ci_data_path):
+        return {}, hf_paths
+
+    # Note: If a model has over 4B parameters, please add a quantized version
+    # to `ci_paths` instead of `hf_paths` to reduce file size and testing time.
+ ci_paths = { + "llama-2": os.path.join(ci_data_path, "Llama-2-7B-Chat-GPTQ"), + "llama-3": os.path.join(ci_data_path, "Meta-Llama-3-8B-AWQ"), + "mistral-v0.2": os.path.join(ci_data_path, "Mistral-7B-Instruct-v0.2-GPTQ"), + # "phi-2": os.path.join(ci_data_path, "phi2"), + # "gemma-2b": os.path.join(ci_data_path, "gemma-1.1-2b-it"), + "gemma-7b": os.path.join(ci_data_path, "gemma-7b-it-awq"), + # "phi-3-mini": os.path.join(ci_data_path, "phi3-mini-128k-instruct"), + } + + return ci_paths, hf_paths + + +def download_model(model_name, input_path, output_path, precision, device, one_layer=True): + command = [ + sys.executable, + "-m", + "onnxruntime_genai.models.builder", + ] + + if model_name is not None: + # If model_name is provided: + # python -m onnxruntime_genai.models.builder -m -o -p -e + command += ["-m", model_name] + elif input_path != "": + # If input_path is provided: + # python -m onnxruntime_genai.models.builder -i -o -p -e + command += ["-i", input_path] + else: + raise Exception("Either `model_name` or `input_path` can be provided for PyTorch models, not both.") + + command += [ + "-o", + output_path, + "-p", + precision, + "-e", + device, + ] + + extra_options = ["--extra_options"] + if device == "cpu" and precision == "int4": + extra_options += ["int4_accuracy_level=4"] + if one_layer: + extra_options += ["num_hidden_layers=1"] + if len(extra_options) > 1: + command += extra_options + + run_subprocess(command).check_returncode() + + +def download_models(download_path, precision, device): + ci_paths, hf_paths = get_model_paths() + output_paths = [] + + # python -m onnxruntime_genai.models.builder -i -o -p -e + for model_name, input_path in ci_paths.items(): + output_path = os.path.join(download_path, model_name, precision, device) + if not os.path.exists(output_path): + download_model(None, input_path, output_path, precision, device) + output_paths.append(output_path) + + # python -m onnxruntime_genai.models.builder -m -o -p -e + for model_name, hf_name in hf_paths.items(): + output_path = os.path.join(download_path, model_name, precision, device) + if not os.path.exists(output_path): + download_model(hf_name, "", output_path, precision, device) + output_paths.append(output_path) + + return output_paths diff --git a/test/python/conftest.py b/test/python/conftest.py index 08498d184..d3a08df69 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -18,41 +18,44 @@ def pytest_addoption(parser): ) -def get_path_for_model_and_device(data_path, model_name, device): - return os.path.join(data_path, device, model_name) +def get_path_for_model(data_path, model_name, precision, device): + return os.path.join(data_path, model_name, precision, device) @pytest.fixture def phi2_for(request): return functools.partial( - get_path_for_model_and_device, + get_path_for_model, request.config.getoption("--test_models"), "phi-2", + "int4", ) @pytest.fixture def gemma_for(request): return functools.partial( - get_path_for_model_and_device, + get_path_for_model, request.config.getoption("--test_models"), "gemma", + "int4", ) @pytest.fixture def llama_for(request): return functools.partial( - get_path_for_model_and_device, + get_path_for_model, request.config.getoption("--test_models"), "llama", + "int4", ) @pytest.fixture def path_for_model(request): return functools.partial( - get_path_for_model_and_device, request.config.getoption("--test_models") + get_path_for_model, request.config.getoption("--test_models") ) diff --git a/test/python/requirements-cpu.txt 
b/test/python/requirements-cpu.txt index 3dbd6031f..d656c9ea6 100644 --- a/test/python/requirements-cpu.txt +++ b/test/python/requirements-cpu.txt @@ -1,4 +1,4 @@ -f https://download.pytorch.org/whl/torch_stable.html -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ torch==2.2.1+cpu -ort-nightly==1.19.0.dev20240717002 +ort-nightly==1.20.0.dev20240805004 diff --git a/test/python/requirements-cuda.txt b/test/python/requirements-cuda.txt index d934b5a7f..f0db6a66d 100644 --- a/test/python/requirements-cuda.txt +++ b/test/python/requirements-cuda.txt @@ -1,4 +1,4 @@ -f https://download.pytorch.org/whl/torch_stable.html -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -torch==2.2.1+cu118 -ort-nightly-gpu==1.19.0.dev20240719001 +torch==2.2.1+cu121 +ort-nightly-gpu==1.20.0.dev20240806001 diff --git a/test/python/requirements-directml.txt b/test/python/requirements-directml.txt index 414abdfd6..6e7b86828 100644 --- a/test/python/requirements-directml.txt +++ b/test/python/requirements-directml.txt @@ -1,4 +1,4 @@ -f https://download.pytorch.org/whl/torch_stable.html -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ torch==2.2.1+cpu -ort-nightly-directml==1.19.0.dev20240717002 +ort-nightly-directml==1.20.0.dev20240805004 diff --git a/test/python/test_onnxruntime_genai.py b/test/python/test_onnxruntime_genai.py index 41d615e51..212de1cfd 100644 --- a/test/python/test_onnxruntime_genai.py +++ b/test/python/test_onnxruntime_genai.py @@ -1,13 +1,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - import argparse +import json import logging import os import pathlib import sys import sysconfig -from typing import Union +from typing import Union, List import onnxruntime_genai as og from _test_utils import download_models, run_subprocess @@ -34,17 +34,22 @@ def run_onnxruntime_genai_api_tests( "--test_models", test_models, ] - run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_onnxruntime_genai_e2e_tests( cwd: Union[str, bytes, os.PathLike], log: logging.Logger, + output_paths: List[Union[str, bytes, os.PathLike]], ): log.debug("Running: ONNX Runtime GenAI E2E Tests") - command = [sys.executable, "test_onnxruntime_genai_e2e.py"] + command = [ + sys.executable, + "test_onnxruntime_genai_e2e.py", + "--models", + json.dumps(output_paths), + ] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -74,23 +79,19 @@ def main(): log.info("Running onnxruntime-genai tests pipeline") - if not args.e2e: - if not ( - sysconfig.get_platform().endswith("arm64") or sys.version_info.minor < 8 - ): - download_models(os.path.abspath(args.test_models), "cpu") - if og.is_cuda_available(): - download_models( - os.path.abspath(args.test_models), - "cuda", - ) - - run_onnxruntime_genai_api_tests( - os.path.abspath(args.cwd), log, os.path.abspath(args.test_models) - ) - - else: - run_onnxruntime_genai_e2e_tests(os.path.abspath(args.cwd), log) + # Get INT4 ONNX models + output_paths = [] + if not ( + sysconfig.get_platform().endswith("arm64") or sys.version_info.minor < 8 + ): + output_paths += download_models(os.path.abspath(args.test_models), "int4", "cpu") + if og.is_cuda_available(): + output_paths += download_models(os.path.abspath(args.test_models), "int4", "cuda") + + # Run ONNX Runtime GenAI tests + run_onnxruntime_genai_api_tests(os.path.abspath(args.cwd), log, os.path.abspath(args.test_models)) + if args.e2e: + 
run_onnxruntime_genai_e2e_tests(os.path.abspath(args.cwd), log, output_paths) return 0 diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py index eaac1e087..9939242d2 100644 --- a/test/python/test_onnxruntime_genai_e2e.py +++ b/test/python/test_onnxruntime_genai_e2e.py @@ -1,37 +1,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from __future__ import annotations +import argparse +import json import os -import sys -import tempfile +import logging import onnxruntime_genai as og -from _test_utils import run_subprocess - -def download_model( - download_path: str | bytes | os.PathLike, device: str, model_identifier: str, precision: str -): - # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cpu -o download_path - # Or with cuda graph enabled: - # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cuda --extra_options enable_cuda_graph=1 -o download_path - command = [ - sys.executable, - "-m", - "onnxruntime_genai.models.builder", - "-m", - model_identifier, - "-p", - precision, - "-e", - device, - "-o", - download_path, - ] - if device == "cuda": - command.append("--extra_options") - command.append("enable_cuda_graph=1") - run_subprocess(command).check_returncode() +logging.basicConfig( + format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG +) +log = logging.getLogger("onnxruntime-genai-tests") def run_model(model_path: str | bytes | os.PathLike): @@ -47,7 +28,7 @@ def run_model(model_path: str | bytes | os.PathLike): sequences = tokenizer.encode_batch(prompts) params = og.GeneratorParams(model) params.set_search_options(max_length=200) - params.try_graph_capture_with_max_batch_size(16) + params.try_graph_capture_with_max_batch_size(4) params.input_ids = sequences output_sequences = model.generate(params) @@ -55,10 +36,28 @@ def run_model(model_path: str | bytes | os.PathLike): assert output +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--models", + type=str, + required=True, + help="List of model paths to run. 
Pass as `json.dumps(model_paths)` to this argument.", + ) + + args = parser.parse_args() + args.models = json.loads(args.models) + return args + + if __name__ == "__main__": - for model_name in ["microsoft/phi-2"]: - for precision in ["int4", "fp32"]: - with tempfile.TemporaryDirectory() as temp_dir: - device = "cuda" if og.is_cuda_available() else "cpu" - download_model(temp_dir, device, model_name, precision) - run_model(temp_dir) + args = get_args() + for model_path in args.models: + try: + log.info(f"Running {model_path}") + run_model(model_path) + except Exception as e: + log.error(e) + log.error(f"Failed to run {model_path}", exc_info=True) diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index ecc8f8a31..5ea725e17 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -74,7 +74,7 @@ def generate_license(line_list): line_list.append('LICENSE') def generate_readme(line_list): - line_list.append('README.md') + line_list.append('PACKAGE.md') def generate_project_url(line_list, project_url): line_list.append("" + project_url + "") @@ -102,6 +102,8 @@ def generate_dependencies(xml_text, package_version, ort_package_name, ort_packa xml_text.append(f'') xml_text.append(f'') xml_text.append(f'') + if ort_package_name.endswith("DirectML"): + xml_text.append(f'') xml_text.append("") xml_text.append("") @@ -110,7 +112,7 @@ def generate_files(lines, args): lines.append('') lines.append(f'') - lines.append(f'') + lines.append(f'') lines.append(f'') def add_native_artifact_if_exists(xml_lines, runtime, artifact): @@ -125,6 +127,7 @@ def add_native_artifact_if_exists(xml_lines, runtime, artifact): if runtime.startswith("win"): add_native_artifact_if_exists(lines, runtime, "onnxruntime-genai.lib") add_native_artifact_if_exists(lines, runtime, "onnxruntime-genai.dll") + add_native_artifact_if_exists(lines, runtime, "d3d12core.dll") if runtime.startswith("linux"): add_native_artifact_if_exists(lines, runtime, "libonnxruntime-genai.so") From 0dd4572efb90a56151390380c80a753cd6118d11 Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Wed, 4 Sep 2024 14:11:00 -0700 Subject: [PATCH 03/13] remove unnecessary --- examples/python/model-generate.py | 24 +--- examples/python/model-qa.py | 35 ++---- src/generators.cpp | 198 ------------------------------ src/generators.h | 38 ------ src/logging.cpp | 4 +- src/logging.h | 2 +- src/models/decoder_only.cpp | 13 +- src/models/decoder_only.h | 10 +- src/models/input_ids.cpp | 1 + src/models/kv_cache.cpp | 36 +++++- src/models/kv_cache.h | 2 + src/models/model.h | 1 - src/models/position_inputs.h | 4 +- src/python/py/models/builder.py | 94 +++++--------- src/python/python.cpp | 40 ------ src/search.cpp | 62 ---------- src/search.h | 21 +--- src/sequences.h | 1 + 18 files changed, 101 insertions(+), 485 deletions(-) diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py index d78b8ac6f..0a97f25b4 100644 --- a/examples/python/model-generate.py +++ b/examples/python/model-generate.py @@ -2,15 +2,9 @@ import argparse import time - def main(args): if args.verbose: print("Loading model...") model = og.Model(f'{args.model}') - assistant_model = ( - og.Model(f"{args.assistant_model}") - if hasattr(args, "assistant_model") - else None - ) if args.verbose: print("Model loaded") tokenizer = og.Tokenizer(model) if args.verbose: print("Tokenizer created") @@ -21,13 +15,13 @@ def main(args): prompts = ["I like walking my cute dog", 
"What is the best restaurant in town?", "Hello, how are you today?"] - + if args.chat_template: if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1: print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'") exit(1) prompts[:] = [f'{args.chat_template.format(input=text)}' for text in prompts] - + input_tokens = tokenizer.encode_batch(prompts) if args.verbose: print(f'Prompt(s) encoded: {prompts}') @@ -48,10 +42,7 @@ def main(args): if args.verbose: print("Generating tokens ...\n") start_time = time.time() - if assistant_model is None: - output_tokens = model.generate(params) - else: - output_tokens = model.generate_with_assist(assistant_model, params) + output_tokens = model.generate(params) run_time = time.time() - start_time for i in range(len(prompts)): @@ -65,16 +56,9 @@ def main(args): print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}") print() - if __name__ == "__main__": parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end token generation loop example for gen-ai") parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)') - parser.add_argument( - "-a", - "--assistant_model", - type=str, - help="Assistant onnx model folder path (must contain config.json and model.onnx)", - ) parser.add_argument('-pr', '--prompts', nargs='*', required=False, help='Input prompts to generate tokens from. Provide this parameter multiple times to batch multiple prompts') parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt') parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt') @@ -88,4 +72,4 @@ def main(args): parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index 19ebaae07..4532f307a 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -15,16 +15,10 @@ def main(args): if args.verbose: print("Tokenizer created") if args.verbose: print() - assistant_model = None - if hasattr(args, "assistant_model"): - assistant_model = og.Model(args.assistant_model) - if args.verbose: - print("Assistant model loaded") - search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args} if args.verbose: print(search_options) - + if args.chat_template: if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1: print("Error, chat template must have exactly one pair of curly braces, e.g. 
'<|user|>\n{input} <|end|>\n<|assistant|>'") @@ -49,16 +43,13 @@ def main(args): params = og.GeneratorParams(model) params.set_search_options(**search_options) params.input_ids = input_tokens - if assistant_model is not None: - generator = og.SpeculativeDecodingGenerator(model, assistant_model, params) - else: - generator = og.Generator(model, params) + generator = og.Generator(model, params) if args.verbose: print("Generator created") if args.verbose: print("Running generation loop ...") if args.timings: first = True - generated_tokens = [] + new_tokens = [] print() print("Output: ", end='', flush=True) @@ -72,11 +63,9 @@ def main(args): first_token_timestamp = time.time() first = False - new_tokens = generator.get_next_tokens() - for new_token in new_tokens: - print(tokenizer_stream.decode(new_token), end="", flush=True) - if args.timings: - generated_tokens.extend(new_tokens) + new_token = generator.get_next_tokens()[0] + print(tokenizer_stream.decode(new_token), end='', flush=True) + if args.timings: new_tokens.append(new_token) except KeyboardInterrupt: print(" --control+c pressed, aborting generation--") print() @@ -88,20 +77,12 @@ def main(args): if args.timings: prompt_time = first_token_timestamp - started_timestamp run_time = time.time() - first_token_timestamp - print( - f"Prompt length: {len(input_tokens)}, New tokens: {len(generated_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(generated_tokens)/run_time:.2f} tps" - ) + print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps") if __name__ == "__main__": parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai") parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)') - parser.add_argument( - "-a", - "--assistant_model", - type=str, - help="Assistant onnx model folder path (must contain config.json and model.onnx)", - ) parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt') parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt') parser.add_argument('-ds', '--do_random_sampling', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false') @@ -113,4 +94,4 @@ def main(args): parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false') parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. 
User input will be injected into {input}') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/src/generators.cpp b/src/generators.cpp index 8ada93793..c7f9b0c30 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -118,14 +118,6 @@ std::unique_ptr CreateGenerator(const Model& model, const GeneratorPa return std::make_unique(model, params); } -std::unique_ptr CreateAssistantGenerator(const Model& model, const GeneratorParams& params) { - return std::make_unique(model, params); -} - -std::unique_ptr CreateSpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params) { - return std::make_unique(model, assistant_model, params); -} - std::unique_ptr CreateSearch(const GeneratorParams& params) { #if USE_CUDA if (params.device_type == DeviceType::CUDA) { @@ -141,16 +133,6 @@ std::unique_ptr CreateSearch(const GeneratorParams& params) { return std::make_unique(params); } -std::unique_ptr CreateSpeculativeSearch(const GeneratorParams& params) { -#if USE_CUDA - throw std::runtime_error("Speculative decoding is not supported on CUDA"); -#endif - if (params.search.num_beams > 1) { - throw std::runtime_error("Speculative decoding is not supported with beam search"); - } - return std::make_unique(params); -} - Generator::Generator(const Model& model, const GeneratorParams& params) : model_{model.shared_from_this()} { if (params.search.max_length == 0) throw std::runtime_error("search max_length is 0"); @@ -239,171 +221,6 @@ RoamingArray Generator::GetSequence(size_t index) const { return search_->GetSequence(index); } -AssistantGenerator::AssistantGenerator(const Model& model, const GeneratorParams& params) - : Generator(model, params) { - if (params.search.num_beams != 1) - throw std::runtime_error("AssistantGenerator only supports num_beams=1, got " + std::to_string(params.search.num_beams)); - if (params.batch_size != 1) - throw std::runtime_error("AssistantGenerator only supports batch_size=1, got " + std::to_string(params.batch_size)); - if (params.vocab_size < 1) - throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size)); - if (params.sequence_length >= params.search.max_length) - throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")"); - - state_ = std::make_unique( - *std::dynamic_pointer_cast(model_), search_->GetSequenceLengths(), params); -} - -void AssistantGenerator::ComputeLogits() { - if (computed_logits_) - throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); - - auto sequence_length = search_->GetSequenceLength(); - auto next_token_length = first_run_in_assist_ ? 
2 : 1; - auto past_length = sequence_length - next_token_length; - auto logits = state_->Run(search_->GetSequence(0), next_token_length, past_length, 1); - if (g_log.enabled && g_log.speculative_decoding) { - auto& stream = Log("speculative_decoding"); - DumpSpan(stream, logits.GetCPU()); - stream << std::endl; - } - search_->SetLogits(logits); - computed_logits_ = true; - - auto& search = search_->params_->search; - search_->ApplyMinLength(search.min_length); - search_->ApplyRepetitionPenalty(search.repetition_penalty); - first_run_in_assist_ = false; -} - -void AssistantGenerator::GenerateNextToken() { - Generator::GenerateNextToken(); - candidate_length_++; -} - -void AssistantGenerator::AcceptCandidateTokens(RoamingArray next_tokens) { - search_->DropLastTokens(candidate_length_); - search_->SetNextTokens(next_tokens); - candidate_length_ = 0; - if (g_log.enabled && g_log.speculative_decoding) { - auto& stream = Log("speculative_decoding"); - stream << SGR::Fg_Green << "assistant sequence: " << SGR::Reset << std::endl; - DumpSpan(stream, search_->GetSequence(0).GetCPU()); - stream << std::endl - << "length: " << search_->GetSequenceLength() << std::endl; - } - first_run_in_assist_ = true; -} - -SpeculativeDecodingGenerator::SpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params) - : assistant_generator_{CreateAssistantGenerator(assistant_model, params)}, - model_{model.shared_from_this()} { - if (params.search.max_length == 0) - throw std::runtime_error("search max_length is 0"); - if (params.search.max_length > model.config_->model.context_length) - throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")"); - if (params.batch_size != 1) - throw std::runtime_error("batch_size must be 1, is " + std::to_string(params.batch_size)); - if (params.vocab_size < 1) - throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size)); - if (params.sequence_length >= params.search.max_length) - throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")"); - if (params.input_ids.empty() || params.input_ids.data() == nullptr) - throw std::runtime_error("input_ids not set in GeneratorParams"); - - if (model.config_->model.type != "llama" && - model.config_->model.type != "gemma" && - model.config_->model.type != "gemma2" && - model.config_->model.type != "mistral" && - model.config_->model.type != "phi" && - model.config_->model.type != "phi3" && - model.config_->model.type != "phi3small" && - model.config_->model.type != "qwen2") - throw std::runtime_error("Speculative decoding is not supported for this model type " + model.config_->model.type); - - search_ = CreateSpeculativeSearch(params); - state_ = std::make_unique( - *std::dynamic_pointer_cast(model_), search_->GetSequenceLengths(), params); -} - -void SpeculativeDecodingGenerator::ComputeLogits() { - if (computed_logits_) - throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); - - candidate_length_ = 0; - while (!assistant_generator_->IsDone() && candidate_length_ < max_candidate_length_) { - assistant_generator_->ComputeLogits(); - assistant_generator_->GenerateNextToken(); - candidate_length_++; - } - - auto candidate_sequence = 
assistant_generator_->search_->GetSequence(0); - if (g_log.enabled && g_log.speculative_decoding) { - auto& stream = Log("speculative_decoding"); - stream << SGR::Fg_Green << "candidates from assistant model: " << SGR::Reset << std::endl; - stream << SGR::Fg_Green << "candidate count: " << SGR::Reset << candidate_length_ << std::endl; - DumpSpan(stream, candidate_sequence.GetCPU()); - } - - auto logits = state_->Run(candidate_sequence, candidate_length_ + 1, search_->GetSequenceLength() - 1, candidate_length_ + 1); - if (g_log.enabled && g_log.speculative_decoding) { - auto& stream = Log("speculative_decoding"); - stream << SGR::Fg_Green << "produced logits from main model: " << SGR::Reset << std::endl; - } - - search_->SetLogits(logits); - computed_logits_ = true; -} - -void SpeculativeDecodingGenerator::GenerateNextToken() { - if (!computed_logits_) - throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); - computed_logits_ = false; - auto& search = search_->params_->search; - - if (g_log.enabled && g_log.generate_next_token) { - auto& stream = Log("generate_next_token"); - stream << SGR::Fg_Green << "do_sample: " << SGR::Reset << search.do_sample << ' ' - << SGR::Fg_Green << "top_k: " << SGR::Reset << search.top_k << ' ' - << SGR::Fg_Green << "top_p: " << SGR::Reset << search.top_p << ' ' - << SGR::Fg_Green << "temperature: " << SGR::Reset << search.temperature << ' ' - << SGR::Fg_Cyan << "sequence length: " << SGR::Reset << search_->GetSequenceLength() - << std::endl; - } - - if (search.do_sample) - throw std::runtime_error("Not implemented"); - if (search.top_k != 1) - throw std::runtime_error("Not implemented"); - if (search.top_p != 1.0f) - throw std::runtime_error("Not implemented"); - if (search.temperature != 1.0f) - throw std::runtime_error("Not implemented"); - - auto candidate_sequence = assistant_generator_->search_->GetSequence(0); - - // Compare with logits one by one to determine the accepted tokens. - // total new token count is accepted token count + 1. - auto next_tokens = search_->CheckCandidates(candidate_sequence, candidate_length_); - // Update sequence to drop tokens of size candidate_length_, - // and append next tokens. - assistant_generator_->AcceptCandidateTokens(next_tokens); - if (g_log.enabled && g_log.speculative_decoding) { - auto& stream = Log("speculative_decoding"); - stream << SGR::Fg_Green << "candidate count: " << SGR::Reset << candidate_length_ << std::endl; - stream << SGR::Fg_Green << "next tokens: " << SGR::Reset; - DumpSpan(stream, next_tokens.GetCPU()); - stream << std::endl; - } -} - -bool SpeculativeDecodingGenerator::IsDone() const { - if (computed_logits_) - throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); - - return search_->IsDone(); -} - TokenSequences Generate(const Model& model, const GeneratorParams& params) { auto generator = CreateGenerator(model, params); @@ -423,19 +240,4 @@ TokenSequences Generate(const Model& model, const GeneratorParams& params) { return result; } -TokenSequences Generate(const Model& model, const Model& assistant_model, const GeneratorParams& params) { - auto generator = CreateSpeculativeDecodingGenerator(model, assistant_model, params); - - while (!generator->IsDone()) { - generator->ComputeLogits(); - generator->GenerateNextToken(); - } - - // Supports only single batch size, single sequence. 
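The removed `GenerateNextToken` above verifies the assistant's candidate tokens against the main model greedily: accept candidates while the main model's argmax agrees, then emit the first disagreeing (or bonus) token, so each step yields at least one new token. A minimal sketch of that acceptance rule, assuming batch size 1, greedy top-1 selection, and row-major logits of shape [candidate_length + 1, vocab_size]; `AcceptCandidates` is a hypothetical standalone helper, not the repository's `CheckCandidates`:

    // Sketch of the greedy accept-then-append rule described above
    // (hypothetical standalone version; names and layout are assumptions).
    #include <cstddef>
    #include <cstdint>
    #include <span>
    #include <vector>

    // logits holds one row of vocab_size scores per verified position:
    // candidate_length rows for the candidates plus one bonus position.
    std::vector<int32_t> AcceptCandidates(std::span<const float> logits,
                                          std::span<const int32_t> candidates,
                                          size_t vocab_size) {
      std::vector<int32_t> next_tokens;
      for (size_t pos = 0; pos <= candidates.size(); ++pos) {
        auto row = logits.subspan(pos * vocab_size, vocab_size);
        int32_t best = 0;
        for (size_t v = 1; v < vocab_size; ++v)
          if (row[v] > row[best]) best = static_cast<int32_t>(v);
        next_tokens.push_back(best);  // the main model's token for this position
        // Stop at the first candidate the main model disagrees with; the argmax
        // just pushed is the corrected token, so output is always >= 1 token.
        if (pos < candidates.size() && best != candidates[pos]) break;
      }
      return next_tokens;
    }

If every candidate matches, the loop reaches the bonus row and appends one extra token, which is why the total new token count is the accepted count plus one.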
- TokenSequences result = {{}}; - auto sequence_cpu = generator->search_->GetSequence(0).GetCPU(); - result[0].assign(sequence_cpu.begin(), sequence_cpu.end()); - return result; -} - } // namespace Generators diff --git a/src/generators.h b/src/generators.h index 26d30e3fc..529e01943 100644 --- a/src/generators.h +++ b/src/generators.h @@ -141,42 +141,6 @@ struct Generator : LeakChecked { bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio }; -struct AssistantGenerator : Generator { - AssistantGenerator(const Model& model, const GeneratorParams& params); - - void ComputeLogits() override; - void GenerateNextToken() override; - - void AcceptCandidateTokens(RoamingArray next_tokens); - RoamingArray GetCandidateTokens() const; - - int candidate_length_{}; // Set to the number of generated candiates in ComputeLogits() and number of selected candidates after GenerateNextTokens(). - int max_candidate_length_{5}; // TODO: Move to param config. - - protected: - void ComputeLogits(RoamingArray next_tokens); - - private: - bool first_run_in_assist_{true}; // Set to false in ComputeLogits() and true after AcceptCandidateTokens(). -}; - -// TODO: Inherit from Generator? -struct SpeculativeDecodingGenerator { - SpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params); - - bool IsDone() const; - void ComputeLogits(); - void GenerateNextToken(); - - std::unique_ptr assistant_generator_; - std::shared_ptr model_; - std::unique_ptr state_; - std::unique_ptr search_; - bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio - int candidate_length_{}; // Set to the number of generated candiates in ComputeLogits() and number of selected candidates after GenerateNextTokens(). - int max_candidate_length_{5}; // TODO: Move to param config. 
-}; - struct OrtGlobals { OrtGlobals(); @@ -198,9 +162,7 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path); std::shared_ptr CreateGeneratorParams(const Model& model); std::shared_ptr CreateGeneratorParams(); // For benchmarking purposes only std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params); -std::unique_ptr CreateSpeculativeDecodingGenerator(const Model& model, const Model& assistant_model, const GeneratorParams& params); std::vector> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence -std::vector> Generate(const Model& model, const Model& assistant_model, const GeneratorParams& params); float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction void top_k_indices(std::span top_k, std::span inputs); diff --git a/src/logging.cpp b/src/logging.cpp index fb3229e66..be6589ad7 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -36,8 +36,8 @@ void SetLogBool(std::string_view name, bool value) { g_log.model_output_values = value; else if (name == "model_logits") g_log.model_logits = value; - else if (name == "speculative_decoding") - g_log.speculative_decoding = value; + else if (name == "continuous_decoding") + g_log.continuous_decoding = value; else if (name == "ort_lib") g_log.ort_lib = value; else diff --git a/src/logging.h b/src/logging.h index a29cfdf25..ebd09b1bf 100644 --- a/src/logging.h +++ b/src/logging.h @@ -42,7 +42,7 @@ struct LogItems { bool model_output_shapes{}; // Before the model runs there are only the output shapes, no values in them. Useful for pre Session::Run debugging bool model_output_values{}; // After the model runs the output tensor values can be displayed bool model_logits{}; // Same as model_output_values but only for the logits - bool speculative_decoding{}; // Log speculative decoding steps. + bool continuous_decoding{}; // Log continuous decoding steps. bool ort_lib{}; // Log the onnxruntime library loading and api calls. }; diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index 899593058..53a7fb9e1 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -26,6 +26,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra } RoamingArray DecoderOnly_State::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { + // TODO(aciddelgado): remove first_run if (!first_run_) { UpdateInputsOutputs(next_tokens, next_indices, current_length); } @@ -43,13 +44,14 @@ void DecoderOnly_State::UpdateInputsOutputs(const RoamingArray& next_to logits_.Update(); } -RoamingArray SpeculativeDecodingDecoderOnly_State::Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { +// TODO(aciddelgado): make general +RoamingArray DecoderOnly_State::Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { int batch_size = static_cast(input_ids_.GetShape()[0]); if (batch_size != 1) throw std::runtime_error("Speculative decoding only supports batch size 1, got " + std::to_string(batch_size)); auto total_length = past_length + next_token_length; - auto total_logits = first_run_ ? total_length : next_token_length; + auto total_logits = first_run_ ? total_length : next_token_length; // TODO(aciddelgado): remove first_run // NB(bowenbao): workaround gqa limitation on token phase. 
 // if (next_token_length > 1) {
 //   total_logits = total_length;
 // }
@@ -60,12 +62,13 @@ RoamingArray<float> SpeculativeDecodingDecoderOnly_State::Run(RoamingArray<int32_t>
-void SpeculativeDecodingDecoderOnly_State::UpdateInputsOutputsFromSequence(const RoamingArray<int32_t>& sequence, size_t next_token_length, int past_length) {
+void DecoderOnly_State::UpdateInputsOutputsFromSequence(const RoamingArray<int32_t>& sequence, size_t next_token_length, int past_length) {
   auto total_length = past_length + next_token_length;
-  if (g_log.enabled && g_log.speculative_decoding) {
-    auto& stream = Log("speculative_decoding");
+  if (g_log.enabled && g_log.continuous_decoding) {
+    auto& stream = Log("continuous_decoding");
     stream << "UpdateInputsOutputsFromSequence: past_length=" << past_length << ", next_token_length=" << next_token_length << ", total_length=" << total_length << std::endl;
   }
+  // TODO(aciddelgado): remove first_run
   if (first_run_) {
     // First run input ids includes prompt tokens.
     input_ids_.Update(sequence, 0, total_length);
diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h
index b66ecf91b..67241b2eb 100644
--- a/src/models/decoder_only.h
+++ b/src/models/decoder_only.h
@@ -19,10 +19,12 @@ struct DecoderOnly_Model : Model {
 struct DecoderOnly_State : State {
   DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
   RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;
+  RoamingArray<float> Run(RoamingArray<int32_t> sequence, int next_token_length, int past_length, int return_last_logit_count) override;
   const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); };
 
  protected:
   void UpdateInputsOutputs(const RoamingArray<int32_t>& next_tokens, RoamingArray<int32_t> next_indices, int current_length);
+  void UpdateInputsOutputsFromSequence(const RoamingArray<int32_t>& sequence, size_t next_token_length, int past_length);  // Refreshes inputs/outputs from the full sequence when running multiple new tokens at once (continuous decoding).
 
   const DecoderOnly_Model& model_;
   CapturedGraphInfoPtr captured_graph_info_;
@@ -34,12 +36,4 @@ struct DecoderOnly_State : State {
   ExtraInputs extra_inputs_{model_, *this};
 };
 
-struct SpeculativeDecodingDecoderOnly_State : DecoderOnly_State {
-  SpeculativeDecodingDecoderOnly_State(const DecoderOnly_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) : DecoderOnly_State{model, sequence_lengths, params} {};
-  RoamingArray<float> Run(RoamingArray<int32_t> sequence, int next_token_length, int past_length, int return_last_logit_count) override;
-
- protected:
-  void UpdateInputsOutputsFromSequence(const RoamingArray<int32_t>& sequence, size_t next_token_length, int past_length);
-};
-
 } // namespace Generators
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 7b23689c5..9ac973029 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -126,6 +126,7 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
   }
 }
 
+// TODO(aciddelgado): Is this ok? add cuda support
 void InputIDs::Update(RoamingArray<int32_t> next_tokens, size_t start, size_t token_count) {
   switch (model_.device_type_) {
     case DeviceType::CPU: {
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 2c69c3cc9..b444fc8e0 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -218,7 +218,7 @@ void KV_Cache::UpdatePresent(int current_length) {
   // This can be later refactored to merge with tensor allocation during initialization.
   if (shape_[2] == current_length)
     return;
-  shape_[2] = current_length;
+  shape_[2] = current_length;  // TODO(aciddelgado): is it ok to set this if past_present_share_buffer_ is true?
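The commented-out `RewindTo` below trims each present K/V buffer back to `new_length` positions by copying the leading rows of every head. A minimal standalone sketch of that copy, assuming batch size 1, a contiguous [1, num_heads, seq_len, head_size] float buffer, and CPU memory; `RewindCache` is a hypothetical helper, not the repository's code:

    // Sketch of the cache-rewind copy performed by the commented-out RewindTo
    // below (hypothetical helper; batch 1, row-major float layout assumed).
    #include <cstring>
    #include <vector>

    std::vector<float> RewindCache(const std::vector<float>& present,
                                   int num_heads, int seq_len, int head_size,
                                   int new_length) {
      std::vector<float> past(static_cast<size_t>(num_heads) * new_length * head_size);
      for (int h = 0; h < num_heads; ++h) {
        // The rows of one head are contiguous, so each head is one block copy
        // of its first new_length positions.
        std::memcpy(past.data() + static_cast<size_t>(h) * new_length * head_size,
                    present.data() + static_cast<size_t>(h) * seq_len * head_size,
                    sizeof(float) * new_length * head_size);
      }
      return past;
    }

With shared past/present buffers there is nothing to copy, which is why the real code early-exits when `past_present_share_buffer_` is set.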
// If we're sharing past & present buffers there is nothing to do here, so early exit if (past_present_share_buffer_) return; @@ -261,6 +261,40 @@ void KV_Cache::UpdateAndResize(int current_length, int past_length) { Update({}, current_length); } +// TODO(aciddelgado): RewindTo function +// void KV_Cache::RewindTo(int new_length) { +// // If we're sharing past & present buffers there is nothing to do here, so early exit +// if (past_present_share_buffer_) +// return; +// if (shape_[0] != 1) +// throw std::runtime_error("KV_Cache::RewindTo(int new_length) only supports batch size 1, got " + std::to_string(shape_[0])); +// if (model_.device_type_ != DeviceType::CPU) +// throw std::runtime_error("KV_Cache::RewindTo(int new_length) only supports CPU"); + +// auto element_type = presents_[0]->GetTensorTypeAndShapeInfo()->GetElementType(); +// auto element_size = SizeOf(element_type); +// auto new_shape = std::array({1, shape_[1], new_length, shape_[3]}); +// if (shape_[2] != new_length) { +// for (int i = 0; i < layer_count_ * 2; i++) { +// auto new_present = OrtValue::CreateTensor(*model_.allocator_device_, new_shape, type_); +// const auto* present_data = reinterpret_cast(presents_[i]->GetTensorRawData()); +// auto* new_present_data = reinterpret_cast(new_present->GetTensorMutableRawData()); + +// // Copy new_length kv-cache +// for (int j = 0; j < shape_[1]; j++) { +// memcpy( +// new_present_data + j * new_length * shape_[3] * element_size, +// present_data + j * shape_[2] * shape_[3] * element_size, +// new_length * shape_[3] * element_size); +// } + +// presents_[i] = std::move(new_present); +// } +// } + +// shape_[2] = new_length; +// } + // Copy present state to past state reordered by the beam_indices template void KV_Cache::PickPastState(std::span beam_indices, int index) { diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index ae7b57547..7b2abbbc7 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -42,6 +42,8 @@ struct KV_Cache { void UpdatePresent(int current_length); // Resize past to new sequence length, and drop past that is > past_length. void UpdateAndResize(int current_length, int past_length); + // Rewind cache to new_length. + // void RewindTo(int new_length); template void PickPastState(std::span beam_indices, int index); void PickPastState(std::span beam_indices, int index); diff --git a/src/models/model.h b/src/models/model.h index 8bfccb8da..a9a256962 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -33,7 +33,6 @@ struct State { OrtValue* GetOutput(const char* name); - // Used by speculative search virtual RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { throw std::runtime_error("Not implemented"); }; std::shared_ptr params_; diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index dcbc14d7e..10ceaa27d 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -22,7 +22,7 @@ struct PositionInputs { void UpdatePositionIDs(int current_length); void UpdateAttentionMask(int current_length); - // Used by speculative decoding. + // Used by continuous decoding. 
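The `past_length`-aware overloads declared just below extend these updates from the seqlen == 1 token phase to multi-token steps: position ids continue counting from `past_length`, and the attention mask marks all `total_length` positions valid. A minimal sketch, assuming batch size 1 and int64 buffers; `UpdateForContinuousDecoding` is a hypothetical helper, not the repository's implementation:

    // Sketch of the past_length-aware position/mask update declared below
    // (hypothetical helper; batch 1, int64 inputs assumed).
    #include <cstdint>
    #include <vector>

    void UpdateForContinuousDecoding(std::vector<int64_t>& position_ids,
                                     std::vector<int64_t>& attention_mask,
                                     int past_length, int total_length) {
      int next_token_length = total_length - past_length;
      position_ids.resize(next_token_length);
      for (int i = 0; i < next_token_length; ++i)
        position_ids[i] = past_length + i;  // continue counting from the cache
      attention_mask.assign(total_length, 1);  // attend to cache + new tokens
    }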
void UpdatePositionIDs(int current_length, int past_length); void UpdateAttentionMask(int current_length, int past_length); @@ -34,7 +34,7 @@ struct PositionInputs { template void UpdateAttentionMaskImpl(T* data, const T* old_data, int current_length); - // Used by speculative decoding + // Used by continuous decoding template void UpdatePositionIDsImpl(int current_length, int past_length); template diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 3e2e00b02..fb6523b33 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -605,10 +605,10 @@ def make_less(self, name, inputs): self.make_node("Less", inputs=inputs, outputs=[output], name=name) self.make_value_info(output, TensorProto.BOOL, shape=None) - def make_range(self, name, inputs, shape): + def make_range(self, name, inputs): output = f"{name}/output_0" self.make_node("Range", inputs=inputs, outputs=[output], name=name) - self.make_value_info(output, TensorProto.INT64, shape=shape) + self.make_value_info(output, TensorProto.INT64, shape=["unk"]) def make_slice(self, name, inputs, dtype, shape): output = f"{name}/output_0" @@ -635,18 +635,6 @@ def make_tanh(self, name, root_input, dtype, shape): self.make_node("Tanh", inputs=[root_input], outputs=[output], name=name) self.make_value_info(output, dtype, shape=shape) - def make_trilu(self, name, inputs, upper: int, dtype, shape): - output = f"{name}/output_0" - self.make_node( - "Trilu", - inputs=inputs, - outputs=[output], - name=name, - upper=upper, - domain="com.microsoft", - ) - self.make_value_info(output, dtype, shape=shape) - def make_matmul(self, matmul, basename, root_input, **kwargs): if self.onnx_dtype in {"fp16", "fp32"}: return self.make_matmul_fp16_or_fp32(matmul, basename, root_input, **kwargs) @@ -1821,79 +1809,61 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name): unsqueeze_6_name = f"{basename}/Unsqueeze_6" # shared unsqueeze for input_ids and attention_mask self.make_unsqueeze(unsqueeze_6_name, unsqueeze_inputs, dtype=TensorProto.INT64, shape=[1]) concat_2_name = f"{basename}/Concat_2" - concat_inputs = [f"{unsqueeze_4_name}/output_0", f"{unsqueeze_3_name}/output_0"] + concat_inputs = [f"{unsqueeze_4_name}/output_0", f"{unsqueeze_5_name}/output_0"] self.make_concat(concat_2_name, concat_inputs, dtype=TensorProto.INT64, shape=[2], axis=0) constant_shape_name = f"{basename}/ConstantOfShape_2" constant_shape_numpy_dtype = self.to_numpy_dtype[self.io_dtype] constant_shape_value = numpy_helper.from_array(np.array([np.finfo(constant_shape_numpy_dtype).min], dtype=constant_shape_numpy_dtype)) - self.make_constant_of_shape( - constant_shape_name, - f"{concat_2_name}/output_0", - value=constant_shape_value, - dtype=self.io_dtype, - shape=["sequence_length", "total_sequence_length"], - ) + self.make_constant_of_shape(constant_shape_name, f"{concat_2_name}/output_0", value=constant_shape_value, dtype=self.io_dtype, shape=['unk', 'unk']) # Top path + shape_4_name = f"{basename}/Shape_4" + self.make_shape(shape_4_name, f"{constant_shape_name}/output_0", shape=[2]) + slice_1_name = f"{basename}/Slice_1" + slice_1_inputs = [f"{shape_4_name}/output_0", "/model/constants/TensorProto.INT64/1D/-1", f"/model/constants/TensorProto.INT64/1D/{np.iinfo(np.int64).max}", "/model/constants/TensorProto.INT64/1D/0"] + self.make_slice(slice_1_name, slice_1_inputs, dtype=TensorProto.INT64, shape=[1]) + squeeze_1_name = f"{basename}/Squeeze_1" + squeeze_1_inputs = [f"{slice_1_name}/output_0", 
"/model/constants/TensorProto.INT64/1D/0"] + self.make_squeeze(squeeze_1_name, squeeze_1_inputs) + unsqueeze_7_name = f"{basename}/output_0" + unsqueeze_7_inputs = [f"{squeeze_1_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] + self.make_unsqueeze(unsqueeze_7_name, unsqueeze_7_inputs, dtype=TensorProto.INT64, shape=[1]) concat_3_name = f"{basename}/Concat_3" - concat_3_inputs = [ - f"{unsqueeze_4_name}/output_0", - "/model/constants/TensorProto.INT64/1D/1", - ] + concat_3_inputs = [f"{unsqueeze_7_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"] self.make_concat(concat_3_name, concat_3_inputs, dtype=TensorProto.INT64, shape=[2], axis=0) # Bottom path + shape_5_name = f"{basename}/Shape_5" + self.make_shape(shape_5_name, f"{constant_shape_name}/output_0", shape=[2]) + slice_2_name = f"{basename}/Slice_2" + slice_2_inputs = [f"{shape_5_name}/output_0", "/model/constants/TensorProto.INT64/1D/-1", f"/model/constants/TensorProto.INT64/1D/{np.iinfo(np.int64).max}", "/model/constants/TensorProto.INT64/1D/0"] + self.make_slice(slice_2_name, slice_2_inputs, dtype=TensorProto.INT64, shape=[1]) + squeeze_2_name = f"{basename}/Squeeze_2" + squeeze_2_inputs = [f"{slice_2_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] + self.make_squeeze(squeeze_2_name, squeeze_2_inputs) range_name = f"{basename}/Range" - range_inputs = [ - "/model/constants/TensorProto.INT64/0D/0", - f"{basename}/Gather_2/output_0", - "/model/constants/TensorProto.INT64/0D/1", - ] - self.make_range(range_name, range_inputs, shape=["sequence_length"]) + range_inputs = ["/model/constants/TensorProto.INT64/0D/0", f"{squeeze_2_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"] + self.make_range(range_name, range_inputs) add_2_name = f"{basename}/Add_2" - add_inputs = [f"{range_name}/output_0", f"{past_key_gather_name}/output_0"] - self.make_add( - add_2_name, add_inputs, dtype=TensorProto.INT64, shape=["sequence_length"] - ) - range_2_name = f"{basename}/Range_2" - range_2_inputs = [ - "/model/constants/TensorProto.INT64/0D/0", - f"{shared_add_name}/output_0", - "/model/constants/TensorProto.INT64/0D/1", - ] - self.make_range(range_2_name, range_2_inputs, shape=["total_sequence_length"]) + add_inputs = [f"{range_name}/output_0", "/model/constants/TensorProto.INT64/0D/1"] + self.make_add(add_2_name, add_inputs, dtype=TensorProto.INT64, shape=["unk"]) # Merged path reshape_name = f"{basename}/Reshape" reshape_inputs = [f"{add_2_name}/output_0", f"{concat_3_name}/output_0"] self.make_reshape(reshape_name, reshape_inputs, dtype=TensorProto.INT64, shape=None) less_name = f"{basename}/Less" - less_inputs = [f"{reshape_name}/output_0", f"{range_2_name}/output_0"] + less_inputs = [f"{range_name}/output_0", f"{reshape_name}/output_0"] self.make_less(less_name, less_inputs) where_2_name = f"{basename}/Where_2" - where_2_inputs = [ - f"{less_name}/output_0", - f"{constant_shape_name}/output_0", - f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/0", - ] + where_2_inputs = [f"{less_name}/output_0", f"/model/constants/{self.to_str_dtype[self.io_dtype]}/0D/0", f"{constant_shape_name}/output_0"] self.make_where(where_2_name, where_2_inputs, dtype=self.io_dtype, shape=None) - unsqueeze_8_name = f"{basename}/Unsqueeze_8" unsqueeze_8_inputs = [f"{where_2_name}/output_0", "/model/constants/TensorProto.INT64/1D/0"] - self.make_unsqueeze( - unsqueeze_8_name, - unsqueeze_8_inputs, - dtype=self.io_dtype, - shape=[1, "sequence_length", "total_sequence_length"], - ) + self.make_unsqueeze(unsqueeze_8_name, 
unsqueeze_8_inputs, dtype=self.io_dtype, shape=None) unsqueeze_9_name = f"{basename}/Unsqueeze_9" unsqueeze_9_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"] - self.make_unsqueeze( - unsqueeze_9_name, - unsqueeze_9_inputs, - dtype=self.io_dtype, - shape=[1, 1, "sequence_length", "total_sequence_length"], - ) + self.make_unsqueeze(unsqueeze_9_name, unsqueeze_9_inputs, dtype=self.io_dtype, shape=None) expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", unsqueeze_for_concat=unsqueeze_3_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True) return unsqueeze_6_name, expand_name diff --git a/src/python/python.cpp b/src/python/python.cpp index 229fb90d3..1340ef16c 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -327,38 +327,6 @@ struct PyGenerator { PyRoamingArray py_sequencelengths_; }; -// TODO(bowenbao): merge with PyGenerator? -struct PySpeculativeDecodingGenerator { - PySpeculativeDecodingGenerator(Model& model, Model& assistant_model, PyGeneratorParams& params) { - params.Prepare(); - generator_ = CreateSpeculativeDecodingGenerator(model, assistant_model, params); - } - - pybind11::array_t GetNextTokens() { - py_tokens_.Assign(generator_->search_->GetNextTokens()); - return ToPython(py_tokens_.GetCPU()); - } - - void ComputeLogits() { - generator_->ComputeLogits(); - } - - void GenerateNextToken() { - generator_->GenerateNextToken(); - } - - bool IsDone() const { - return generator_->IsDone(); - } - - private: - std::unique_ptr generator_; - PyRoamingArray py_tokens_; - PyRoamingArray py_indices_; - PyRoamingArray py_sequence_; - PyRoamingArray py_sequencelengths_; -}; - void SetLogOptions(const pybind11::kwargs& dict) { for (auto& entry : dict) { auto name = entry.first.cast(); @@ -447,7 +415,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { return CreateModel(GetOrtEnv(), config_path.c_str()); })) .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) - .def("generate_with_assist", [](Model& model, const Model& assistant_model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, assistant_model, params); }) .def_property_readonly( "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on") .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); }); @@ -461,13 +428,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("get_next_tokens", &PyGenerator::GetNextTokens) .def("get_sequence", &PyGenerator::GetSequence); - pybind11::class_(m, "SpeculativeDecodingGenerator") - .def(pybind11::init()) - .def("is_done", &PySpeculativeDecodingGenerator::IsDone) - .def("compute_logits", &PySpeculativeDecodingGenerator::ComputeLogits) - .def("generate_next_token", &PySpeculativeDecodingGenerator::GenerateNextToken) - .def("get_next_tokens", &PySpeculativeDecodingGenerator::GetNextTokens); - pybind11::class_(m, "Images") .def_static("open", [](pybind11::args image_paths) { if (image_paths.empty()) diff --git a/src/search.cpp b/src/search.cpp index 382159e1d..d4754e57e 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -49,10 +49,6 @@ RoamingArray GreedySearch_Cpu::GetNextTokens() { return next_tokens_; } -RoamingArray SpeculativeGreedySearch_Cpu::GetNextTokens() { - return next_accepted_tokens_; -} - RoamingArray BeamSearch_Cpu::GetNextTokens() { return 
beam_scorer_->GetNextTokens(); } @@ -288,64 +284,6 @@ void GreedySearch_Cpu::DropLastTokens(size_t num_tokens) { sequences_.DropLastTokens({num_tokens}); } -RoamingArray SpeculativeGreedySearch_Cpu::CheckCandidates(RoamingArray sequence, int candidate_length) { - if (params_->batch_size != 1) - throw std::runtime_error("Speculative search only supports batch size 1"); - auto sequence_cpu = sequence.GetCPU(); - auto prev_sequence_length = sequence_cpu.size() - candidate_length; - auto candidate_tokens_cpu = sequence.GetCPU().subspan(prev_sequence_length, candidate_length); - int logit_index = 0; - for (; logit_index < candidate_length + 1; logit_index++) { - ApplyMinLength(params_->search.min_length, logit_index); - ApplyRepetitionPenalty(params_->search.repetition_penalty, logit_index); - std::span const scores = next_token_scores_.subspan(logit_index * params_->vocab_size, params_->vocab_size); - - if (g_log.enabled && g_log.model_logits) { - auto& stream = Log("speculative_decoding"); - stream << "model_logits of logit_index=" << logit_index << std::endl; - DumpSpan(stream, scores); - stream << std::endl; - } - - auto const token = static_cast(std::distance(scores.begin(), std::max_element(scores.begin(), scores.end()))); - SetNextToken(0, token); - AppendNextTokensToSequences(); - if (done_ || logit_index == candidate_length || candidate_tokens_cpu[logit_index] != token) { - break; - } - } - auto next_tokens = sequences_.GetSequence(0).subspan(prev_sequence_length, logit_index + 1); - next_accepted_tokens_ = cpu_span{next_tokens.data(), next_tokens.size()}; - return next_accepted_tokens_; -} - -void SpeculativeGreedySearch_Cpu::ApplyMinLength(int min_length, size_t token_idx) { - if (sequences_.GetSequenceLength() >= min_length) { - return; - } - - std::span const scores = next_token_scores_.subspan(token_idx * params_->vocab_size, params_->vocab_size); - scores[params_->eos_token_id] = std::numeric_limits::lowest(); -} - -void SpeculativeGreedySearch_Cpu::ApplyRepetitionPenalty(float penalty, size_t token_idx) { - if (penalty == 1.0f) - return; - - std::span const scores = next_token_scores_.subspan(token_idx * params_->vocab_size, params_->vocab_size); - std::span const sequence = sequences_.GetSequence(token_idx); - - std::unordered_set unique_word_ids; - for (const auto& word_id : sequence) { - unique_word_ids.insert(word_id); - } - - for (const int32_t word_id : unique_word_ids) { - float const score = scores[word_id]; - scores[word_id] = (score < 0 ? 
score * penalty : score / penalty); - } -} - bool BeamSearch_Cpu::IsDone() const { if (beam_scorer_->IsDone()) { return true; diff --git a/src/search.h b/src/search.h index 730159d03..368dd417d 100644 --- a/src/search.h +++ b/src/search.h @@ -27,10 +27,9 @@ struct Search : LeakChecked { virtual void ApplyMinLength(int min_length) = 0; virtual void ApplyRepetitionPenalty(float penalty) = 0; - // Used by Speculative search + // Used by Continuous Decoding virtual void DropLastTokens(size_t num_tokens) { assert(false); }; virtual void SetNextTokens(RoamingArray next_tokens) { assert(false); }; - virtual RoamingArray CheckCandidates(RoamingArray sequence, int candidate_length) { assert(false); }; std::shared_ptr params_; }; @@ -56,7 +55,7 @@ struct Search_Cpu : Search { cpu_span next_tokens_; // shape (beam_size*batch_size) - std::span next_token_scores_; // shape (beam_size*batch_size, vocab_size) or shape(candidate_tokens_count, vocab_size) for speculative search + std::span next_token_scores_; // shape (beam_size*batch_size, vocab_size) Sequences sequences_; bool done_{}; @@ -73,7 +72,7 @@ struct GreedySearch_Cpu : Search_Cpu { void SampleTopP(float p, float temperature) override; void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override; - // Used by Speculative search. + // Used by continuous decoding search. void SetNextTokens(RoamingArray next_tokens) override; void DropLastTokens(size_t num_tokens) override; @@ -117,18 +116,4 @@ struct BeamSearch_Cpu : Search_Cpu { std::unique_ptr beam_scorer_; }; -struct SpeculativeGreedySearch_Cpu : GreedySearch_Cpu { - SpeculativeGreedySearch_Cpu(const GeneratorParams& params) : GreedySearch_Cpu(params) {}; - RoamingArray CheckCandidates(RoamingArray sequence, int candidate_length); - - RoamingArray GetNextTokens() override; - - protected: - void ApplyMinLength(int min_length, size_t token_idx); - void ApplyRepetitionPenalty(float penalty, size_t token_idx); - - private: - cpu_span next_accepted_tokens_; // shape(accepted_token_counts) for speculative search -}; - } // namespace Generators \ No newline at end of file diff --git a/src/sequences.h b/src/sequences.h index 4b45ecbf3..bdf13caa5 100644 --- a/src/sequences.h +++ b/src/sequences.h @@ -19,6 +19,7 @@ struct Sequences { // Used by Greedy search: void AppendNextTokenToSequences(std::span next_tokens); + // TODO(aciddelgado): Rewind sequences function // Used by Speculative search: void DropLastTokens(size_t num_tokens); From 102933317e888426dc1b1b8a42bde509c3e26c60 Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Tue, 17 Sep 2024 14:36:21 -0700 Subject: [PATCH 04/13] so ryan can see changes --- src/beam_search_scorer.cpp | 5 +- src/beam_search_scorer_cuda.cpp | 4 +- src/generators.cpp | 59 ++- src/generators.h | 11 +- src/logging.h | 1 + src/models/decoder_only.cpp | 104 ++--- src/models/decoder_only.h | 17 +- src/models/embeddings.cpp | 3 +- src/models/gpt.cpp | 11 +- src/models/gpt.h | 2 +- src/models/input_ids.cpp | 141 ++++--- src/models/input_ids.h | 4 +- src/models/kernels.cu | 53 +++ src/models/kernels.h | 4 + src/models/kv_cache.cpp | 140 +++---- src/models/kv_cache.h | 8 +- src/models/logits.cpp | 260 ++++++++++--- src/models/logits.h | 5 +- src/models/model.h | 11 +- src/models/multi_modal_vision_model.cpp | 14 +- src/models/multi_modal_vision_model.h | 2 +- src/models/position_inputs.cpp | 486 ++++++++++++++++++++---- src/models/position_inputs.h | 29 +- src/models/whisper.cpp | 5 +- src/models/whisper.h | 1 + src/ort_genai.h | 17 +- src/ort_genai_c.cpp | 87 +++-- 
src/ort_genai_c.h | 16 +- src/search.cpp | 2 +- src/search_cuda.cpp | 2 +- src/sequences.cpp | 16 + src/sequences.h | 1 + src/sequences_cuda.cpp | 13 +- src/sequences_cuda.h | 2 +- test/c_api_tests.cpp | 99 ++--- 35 files changed, 1187 insertions(+), 448 deletions(-) diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp index ec9760377..a10a496e1 100644 --- a/src/beam_search_scorer.cpp +++ b/src/beam_search_scorer.cpp @@ -65,7 +65,10 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) next_beam_indices_ptr_ = AllocateArray(batch_beam_size, &next_beam_indices_); // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. - size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + // TODO(aciddelgado): Initialize in first update function type thing. + // size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2; + hypothesis_buffer_ptr_ = AllocateArray(batch_beam_size * per_beam, &hypothesis_buffer_); memset(next_beam_scores_.data(), 0, next_beam_scores_.size_bytes()); diff --git a/src/beam_search_scorer_cuda.cpp b/src/beam_search_scorer_cuda.cpp index 4c48ed82a..695dda7cf 100644 --- a/src/beam_search_scorer_cuda.cpp +++ b/src/beam_search_scorer_cuda.cpp @@ -37,7 +37,9 @@ BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters) cuda::LaunchInitScoresKernel(next_beam_scores_.data(), parameters.batch_size, parameters.search.num_beams, stream_); // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. - size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + // TODO(aciddelgado): Initialize in first update function type thing. + // size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; + size_t per_beam = (state_cpu_->max_length_ * (state_cpu_->max_length_ + 1)) / 2; hypothesis_buffer_ptr_ = CudaMallocArray(batch_beam_size * per_beam, &hypothesis_buffer_); } diff --git a/src/generators.cpp b/src/generators.cpp index c7f9b0c30..554f4bba9 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -101,10 +101,10 @@ void GeneratorParams::TryGraphCapture(int max_bs) { void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { for (const auto& [name, tensor] : named_tensors) { if (name == Config::Defaults::InputIdsName) { - input_ids = std::span(tensor->ort_tensor_->GetTensorMutableData(), - tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetElementCount()); - batch_size = static_cast(tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0]); - sequence_length = static_cast(input_ids.size()) / batch_size; + // input_ids = std::span(tensor->ort_tensor_->GetTensorMutableData(), + // tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetElementCount()); + // batch_size = static_cast(tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0]); + // sequence_length = static_cast(input_ids.size()) / batch_size; } else { // If the nominal name is found in the map, use the graph name. // Else, use the nominal name as the graph name. 
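To make the hypothesis-buffer change in beam_search_scorer.cpp above concrete: the buffer holds one intermediate sequence of every length per beam, and since the prompt length is no longer known at construction time, the sum now runs over all lengths from 1 to max_length. A compile-time sketch of both sizings (function names are illustrative, not from this patch):

#include <cstddef>

// Room for one sequence of each length 1..max_length per beam.
constexpr size_t PerBeam(size_t max_length) {
  return max_length * (max_length + 1) / 2;
}

// The removed variant skipped the lengths below the known prompt length s.
constexpr size_t PerBeamFromPrompt(size_t max_length, size_t s) {
  return (max_length * (max_length + 1) - (s - 1) * s) / 2;
}

static_assert(PerBeam(4) == 1 + 2 + 3 + 4, "sum over all lengths");
static_assert(PerBeamFromPrompt(4, 2) == 2 + 3 + 4, "sum over lengths s..max only");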
@@ -142,20 +142,39 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
     throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size));
   if (params.vocab_size < 1)
     throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.vocab_size));
-  if (params.sequence_length >= params.search.max_length)
-    throw std::runtime_error("input sequence_length (" + std::to_string(params.sequence_length) + ") is >= max_length (" + std::to_string(params.search.max_length) + ")");
-  if (params.input_ids.empty() || params.input_ids.data() == nullptr)
-    throw std::runtime_error("input_ids not set in GeneratorParams");
 
   search_ = CreateSearch(params);
-  state_ = model.CreateState(search_->GetSequenceLengths(), params);
+  state_ = model.CreateState(search_->GetSequenceLengths(), params);  // Search sequence lengths set when creating state
 }
 
-void Generator::ComputeLogits() {
+// void Generator::AddInput(const std::string& name, const std::shared_ptr<Tensor>& tensor) {
+//   search_->AddInput(name, tensor);
+// }
+
+void Generator::AddTokens(cpu_span<int32_t> input_ids) {
+  // TODO(aciddelgado): check for first call after reset
+  search_->SetNextTokens(input_ids);
+  // state_->AddInputTokens(input_ids);  // Do this in Run instead
+
+  if (g_log.enabled && g_log.add_tokens) {
+    auto& stream = Log("add_tokens");
+    stream << "input_ids: ";
+    for (auto token : input_ids) {
+      stream << token << ' ';
+    }
+    stream << std::endl;
+  }
+
+  computed_logits_ = false;
+  ComputeLogits(input_ids);
+}
+
+void Generator::ComputeLogits(const RoamingArray<int32_t>& next_tokens) {
   if (computed_logits_)
-    throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first");
+    throw std::runtime_error("ComputeLogits called again without calling AddTokens or GenerateNextToken first");
 
-  auto logits = state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices());
+  // auto logits = state_->Run(candidate_sequence, candidate_length_ + 1, search_->GetSequenceLength() - 1, candidate_length_ + 1);
+  auto logits = state_->Run(search_->GetSequenceLength(), next_tokens, search_->GetNextIndices());
   if (g_log.enabled && g_log.model_logits) {
     auto& stream = Log("model_logits");
     DumpSpan(stream, logits.GetCPU());
@@ -164,9 +183,9 @@ void Generator::ComputeLogits() {
   search_->SetLogits(logits);
   computed_logits_ = true;
 
-  auto& search = search_->params_->search;
-  search_->ApplyMinLength(search.min_length);
-  search_->ApplyRepetitionPenalty(search.repetition_penalty);
+  // auto& search = search_->params_->search;
+  // search_->ApplyMinLength(search.min_length);
+  // search_->ApplyRepetitionPenalty(search.repetition_penalty);
 }
 
 bool Generator::IsDone() const {
@@ -177,10 +196,14 @@ bool Generator::IsDone() const {
 }
 
 void Generator::GenerateNextToken() {
-  if (!computed_logits_)
-    throw std::runtime_error("Must call ComputeLogits before GenerateNextToken");
+  // TODO(aciddelgado): check that AddTokens has been called
+  if (!computed_logits_) {
+    ComputeLogits(search_->GetNextTokens());
+  }
   computed_logits_ = false;
   auto& search = search_->params_->search;
+  search_->ApplyMinLength(search.min_length);
+  search_->ApplyRepetitionPenalty(search.repetition_penalty);
 
   if (g_log.enabled && g_log.generate_next_token) {
     auto& stream = Log("generate_next_token");
@@ -223,9 +246,9 @@ RoamingArray<int32_t> Generator::GetSequence(size_t index) const {
 
 TokenSequences Generate(const Model& model, const GeneratorParams& params) {
   auto generator =
CreateGenerator(model, params); + // generator->AddTokens(params.search.input_ids); while (!generator->IsDone()) { - generator->ComputeLogits(); generator->GenerateNextToken(); } diff --git a/src/generators.h b/src/generators.h index 529e01943..f46c39d04 100644 --- a/src/generators.h +++ b/src/generators.h @@ -71,7 +71,7 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec int batch_size{1}; int max_batch_size{0}; bool use_cuda_graph{}; - int sequence_length{}; + // int sequence_length{}; int hidden_size{}; int BatchBeamSize() const { return search.num_beams * batch_size; } @@ -96,7 +96,7 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec #endif // TODO: Move this to a separate GPT struct - std::span input_ids; // Array of [batchsize][sequence_length] + // std::span input_ids; // Array of [batchsize][sequence_length] struct Whisper { std::shared_ptr input_features; // float32 [batch_size, number_of_mels, something that is 3000] @@ -130,7 +130,9 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; - virtual void ComputeLogits(); + // virtual void ComputeLogits(); + // TODO(aciddelgado): Make this function work with batched inputs + virtual void AddTokens(cpu_span input_ids); // Add tokens to the input_ids virtual void GenerateNextToken(); RoamingArray GetSequence(size_t index) const; @@ -139,6 +141,9 @@ struct Generator : LeakChecked { std::unique_ptr state_; std::unique_ptr search_; bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio + + private: + void ComputeLogits(const RoamingArray& next_tokens); }; struct OrtGlobals { diff --git a/src/logging.h b/src/logging.h index ebd09b1bf..e5f894131 100644 --- a/src/logging.h +++ b/src/logging.h @@ -43,6 +43,7 @@ struct LogItems { bool model_output_values{}; // After the model runs the output tensor values can be displayed bool model_logits{}; // Same as model_output_values but only for the logits bool continuous_decoding{}; // Log continuous decoding steps. + bool add_tokens{}; // Log the addition of tokens to the input. bool ort_lib{}; // Log the onnxruntime library loading and api calls. 
}; diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index 53a7fb9e1..b18944dcf 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -9,8 +9,8 @@ DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort InitDeviceAllocator(*session_decoder_); } -std::unique_ptr DecoderOnly_Model::CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const { - return std::make_unique(*this, sequence_lengths, params); +std::unique_ptr DecoderOnly_Model::CreateState(RoamingArray sequence_lengths_unk, const GeneratorParams& params) const { + return std::make_unique(*this, sequence_lengths_unk, params); } DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params) @@ -25,63 +25,71 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra extra_inputs_.Add(); } -RoamingArray DecoderOnly_State::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - // TODO(aciddelgado): remove first_run - if (!first_run_) { - UpdateInputsOutputs(next_tokens, next_indices, current_length); - } +// void DecoderOnly_State::AddInputTokens(const RoamingArray& tokens) { +// input_ids_.AddInputTokens(tokens, reset_input_); +// reset_input_ = false; +// } + +RoamingArray DecoderOnly_State::Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) { + // if (!first_run_) { + UpdateInputsOutputs(next_tokens, next_indices, total_length); + // } int batch_size = static_cast(input_ids_.GetShape()[0]); State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); - + reset_input_ = true; + return logits_.Get(); } -void DecoderOnly_State::UpdateInputsOutputs(const RoamingArray& next_tokens_unk, RoamingArray beam_indices, int current_length) { +void DecoderOnly_State::UpdateInputsOutputs(const RoamingArray& next_tokens_unk, RoamingArray beam_indices, int total_length) { input_ids_.Update(next_tokens_unk); - position_inputs_.Update(current_length); - kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); + size_t new_length = input_ids_.GetShape()[1]; + position_inputs_.Update(total_length, new_length); + kv_cache_.Update(beam_indices.GetCPU(), total_length); + logits_.Update(new_length); } -// TODO(aciddelgado): make general -RoamingArray DecoderOnly_State::Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { - int batch_size = static_cast(input_ids_.GetShape()[0]); - if (batch_size != 1) - throw std::runtime_error("Speculative decoding only supports batch size 1, got " + std::to_string(batch_size)); +// TODO(aciddelgado): Transition into a new paradigm +// RoamingArray DecoderOnly_State::Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { +// int batch_size = static_cast(input_ids_.GetShape()[0]); +// if (batch_size != 1) +// throw std::runtime_error("Speculative decoding only supports batch size 1, got " + std::to_string(batch_size)); - auto total_length = past_length + next_token_length; - auto total_logits = first_run_ ? total_length : next_token_length; // TODO(aciddelgado): remove first_run - // NB(bowenbao): workaround gqa limitation on token phase. 
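Taken together with the Generator changes in generators.cpp above, the intended calling sequence becomes: seed the state with AddTokens, then loop on GenerateNextToken, which now computes logits lazily. A rough sketch against the internal API as it stands in this patch (batch size 1; the helper name and the exact cpu_span construction are assumptions):

#include <vector>
#include "generators.h"

void GreedyLoop(Generators::Model& model, Generators::GeneratorParams& params,
                std::vector<int32_t>& prompt_tokens) {
  auto generator = Generators::CreateGenerator(model, params);
  // Prompt tokens now enter through AddTokens instead of GeneratorParams::input_ids.
  generator->AddTokens(Generators::cpu_span<int32_t>{prompt_tokens.data(), prompt_tokens.size()});
  while (!generator->IsDone())
    generator->GenerateNextToken();  // computes logits itself when none are pending
}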
- // if (next_token_length > 1) { - // total_logits = total_length; - // } - UpdateInputsOutputsFromSequence(sequence, next_token_length, past_length); - State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); +// auto total_length = past_length + next_token_length; +// auto total_logits = first_run_ ? total_length : next_token_length; // TODO(aciddelgado): remove first_run +// // NB(bowenbao): workaround gqa limitation on token phase. +// // if (next_token_length > 1) { +// // total_logits = total_length; +// // } +// UpdateInputsOutputsFromSequence(sequence, next_token_length, past_length); +// State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); +// reset_input_ = true; - return logits_.Get(total_logits - return_last_logit_count, return_last_logit_count); -} +// return logits_.Get(total_logits - return_last_logit_count, return_last_logit_count); +// } -void DecoderOnly_State::UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length) { - auto total_length = past_length + next_token_length; - if (g_log.enabled && g_log.continuous_decoding) { - auto& stream = Log("continuous_decoding"); - stream << "UpdateInputsOutputsFromSequence: past_length=" << past_length << ", next_token_length=" << next_token_length << ", total_length=" << total_length << std::endl; - } - // TODO(aciddelgado): remove first_run - if (first_run_) { - // First run input ids includes prompt tokens. - input_ids_.Update(sequence, 0, total_length); - position_inputs_.Update(total_length, 0); - kv_cache_.UpdatePresent(total_length); - logits_.Update(total_length); - } else { - // Subsequent runs input ids only include candidate tokens. - input_ids_.Update(sequence, past_length, next_token_length); - position_inputs_.Update(total_length, past_length); - kv_cache_.UpdateAndResize(total_length, past_length); - logits_.Update(next_token_length); - } -} +// TODO(aciddelgado): update should append, not replace for input_ids. ensure correct next and past lengths +// void DecoderOnly_State::UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length) { +// auto total_length = past_length + next_token_length; +// if (g_log.enabled && g_log.continuous_decoding) { +// auto& stream = Log("continuous_decoding"); +// stream << "UpdateInputsOutputsFromSequence: past_length=" << past_length << ", next_token_length=" << next_token_length << ", total_length=" << total_length << std::endl; +// } +// // TODO(aciddelgado): remove first_run +// if (first_run_) { +// // First run input ids includes prompt tokens. +// input_ids_.Update(sequence, 0, total_length); +// position_inputs_.Update(total_length, 0); +// kv_cache_.UpdatePresent(total_length); +// logits_.Update(); +// } else { +// // Subsequent runs input ids only include candidate tokens. 
+// input_ids_.Update(sequence, past_length, next_token_length); +// position_inputs_.Update(total_length, past_length); +// kv_cache_.UpdateAndResize(total_length, past_length); +// logits_.Update(); +// } +// } } // namespace Generators diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 67241b2eb..43bfb3449 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -1,6 +1,6 @@ #pragma once #include "model.h" -#include "input_ids.h" +// #include "input_ids.h" #include "logits.h" #include "kv_cache.h" #include "position_inputs.h" @@ -11,15 +11,18 @@ namespace Generators { struct DecoderOnly_Model : Model { DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env); - std::unique_ptr CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const override; + std::unique_ptr CreateState(RoamingArray sequence_lengths_unk, const GeneratorParams& params) const override; + // std::unique_ptr CreateState(const GeneratorParams& params) const override; std::unique_ptr session_decoder_; }; struct DecoderOnly_State : State { - DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths, const GeneratorParams& params); - RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; - RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) override; + DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params); + // DecoderOnly_State(const DecoderOnly_Model& model, const GeneratorParams& params); + RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) override; + // RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) override; + // void AddInputTokens(const RoamingArray& tokens) override; const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); }; protected: @@ -29,11 +32,13 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; - InputIDs input_ids_{model_, *this}; + // InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache kv_cache_{model_, *this}; PositionInputs position_inputs_; ExtraInputs extra_inputs_{model_, *this}; + + bool reset_input_{true}; }; } // namespace Generators diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index 10508b89f..a25f2e115 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -7,11 +7,12 @@ namespace Generators { +// TODO(aciddelgado): get this right what is this Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode, const std::string& name) : model_{model}, state_{state}, shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, - state_.params_->sequence_length, state_.params_->hidden_size}, + 0, state_.params_->hidden_size}, type_{mode == Embeddings::Mode::Input ? 
model_.session_info_->GetInputDataType(name) : model_.session_info_->GetOutputDataType(name)}, diff --git a/src/models/gpt.cpp b/src/models/gpt.cpp index 275708038..e2ea18249 100644 --- a/src/models/gpt.cpp +++ b/src/models/gpt.cpp @@ -35,11 +35,12 @@ RoamingArray Gpt_State::Run(int current_length, RoamingArray nex return logits_.Get(); } -void Gpt_State::UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray beam_indices, int current_length) { - input_ids_.Update(next_tokens); - position_inputs_.Update(current_length); - kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); +void Gpt_State::UpdateInputsOutputs(const RoamingArray& next_tokens_unk, RoamingArray beam_indices, int total_length) { + input_ids_.Update(next_tokens_unk); + size_t new_length = input_ids_.GetShape()[1]; + position_inputs_.Update(total_length, new_length); + kv_cache_.Update(beam_indices.GetCPU(), total_length); + logits_.Update(new_length); } } // namespace Generators diff --git a/src/models/gpt.h b/src/models/gpt.h index 1363e952f..265d379ea 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -25,7 +25,7 @@ struct Gpt_State : State { const Gpt_Model& model_; - InputIDs input_ids_{model_, *this}; + // InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache_Combined kv_cache_{model_, *this}; PositionInputs position_inputs_; diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 9ac973029..fd72df177 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -5,28 +5,31 @@ namespace Generators { +// NOW IS 0-INITIALIZED InputIDs::InputIDs(const Model& model, State& state) : model_{model}, state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); - shape_ = {state_.params_->batch_size, state_.params_->sequence_length}; + shape_ = {state_.params_->search.num_beams * state_.params_->batch_size, 0}; type_ = model_.session_info_->GetInputDataType(name_); // If 64-bit, convert from 32-bit to 64-bit - if (type_ == Ort::TypeToTensorType) { - value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_); - auto* p_data = value_->GetTensorMutableData(); - for (auto v : state_.params_->input_ids) { - *p_data++ = v; - } - } else { - if (type_ != Ort::TypeToTensorType) - throw std::runtime_error("InputIDs must be int64 or int32"); - value_ = OrtValue::CreateTensor(model.allocator_cpu_.GetInfo(), std::span(const_cast(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_); - } - - value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams); - shape_[0] *= state_.params_->search.num_beams; + // if (type_ == Ort::TypeToTensorType) { + // value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_); + // auto* p_data = value_->GetTensorMutableData(); + // for (auto v : state_.params_->input_ids) { + // *p_data++ = v; + // } + // } else { + // if (type_ != Ort::TypeToTensorType) + // throw std::runtime_error("InputIDs must be int64 or int32"); + // value_ = OrtValue::CreateTensor(model.allocator_cpu_.GetInfo(), std::span(const_cast(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_); + // } + + // value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_); // TODO(aciddelgado): 0 initializing tensors allowed? 
+ + // value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams); + // shape_[0] *= state_.params_->search.num_beams; if (state_.GetCapturedGraphInfo()) { sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get(); @@ -46,10 +49,17 @@ void InputIDs::Add() { state_.input_names_.push_back(name_); } -void InputIDs::Update(RoamingArray next_tokens_unk) { - // Resize input_ids shape once if it doesn't match the decoder shape - if (shape_[1] != 1) { - shape_[1] = 1; +void InputIDs::Update(RoamingArray new_tokens) { + // // Resize input_ids shape once if it doesn't match the decoder shape + // if (shape_[1] != 1) { + // shape_[1] = 1; + // if (!sb_input_ids_) { + // value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + + // Resize input_ids shape to sequence_length of new_tokens + size_t sequence_length = static_cast(new_tokens.GetCPU().size()) / shape_[0]; + if (shape_[1] != sequence_length) { + shape_[1] = sequence_length; if (!sb_input_ids_) { value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); @@ -77,7 +87,7 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { #if USE_CUDA case DeviceType::CUDA: { auto* data = value_->GetTensorMutableData(); - auto next_tokens = next_tokens_unk.GetGPU(); + auto next_tokens = new_tokens.GetGPU(); cuda::LaunchInt32ToInt64(next_tokens.data(), data, static_cast(next_tokens.size()), model_.cuda_stream_); } break; #endif @@ -88,8 +98,8 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, value_int32_->GetTensorMutableRawData(), &source_resource)); auto source = std::span( - reinterpret_cast(next_tokens_unk.GetCPU().data()), - next_tokens_unk.GetCPU().size_bytes()); + reinterpret_cast(new_tokens.GetCPU().data()), + new_tokens.GetCPU().size_bytes()); model_.GetDmlUploadHeap()->BeginUploadToGpu( source_resource.Get(), @@ -109,9 +119,11 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { #endif case DeviceType::CPU: { auto* data = value_->GetTensorMutableData(); - auto next_tokens = next_tokens_unk.GetCPU(); - for (int i = 0; i < shape_[0]; i++) { - data[i] = next_tokens[i]; + auto next_tokens = new_tokens.GetCPU(); + for (int b = 0; b < shape_[0]; b++) { + for (int i = 0; i < shape_[1]; i++) { + data[b * shape_[1] + i] = next_tokens[b * shape_[1] + i]; + } } } } @@ -119,44 +131,55 @@ void InputIDs::Update(RoamingArray next_tokens_unk) { auto* data = value_->GetTensorMutableData(); #if USE_CUDA if (model_.device_type_ == DeviceType::CUDA) - cudaMemcpyAsync(data, next_tokens_unk.GetGPU().data(), shape_[0] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_); + cudaMemcpyAsync(data, new_tokens.GetGPU().data(), shape_[0] * shape_[1] * sizeof(int32_t), cudaMemcpyDeviceToDevice, model_.cuda_stream_); else #endif - memcpy(data, next_tokens_unk.GetCPU().data(), shape_[0] * sizeof(int32_t)); + memcpy(data, new_tokens.GetCPU().data(), shape_[0] * shape_[1] * sizeof(int32_t)); } } -// TODO(aciddelgado): Is this ok? 
add cuda support -void InputIDs::Update(RoamingArray next_tokens, size_t start, size_t token_count) { - switch (model_.device_type_) { - case DeviceType::CPU: { - break; - } - default: - throw std::runtime_error("Update with token count not supported for device type " + to_string(model_.device_type_)); - } - if (shape_[0] != 1) { - throw std::runtime_error("Update with token count only supported for batch size 1, got " + std::to_string(shape_[0])); - } - shape_[1] = token_count; - - if (!sb_input_ids_) { - value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - } else { - value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_); - } - state_.inputs_[input_index_] = value_.get(); - if (type_ == Ort::TypeToTensorType) { - auto* data = value_->GetTensorMutableData(); - auto next_tokens_cpu = next_tokens.GetCPU(); - assert(next_tokens_cpu.size() >= start + token_count); - for (int i = 0; i < token_count; i++) { - data[i] = next_tokens_cpu[start + i]; - } - } else { - auto* data = value_->GetTensorMutableData() + start; - memcpy(data, next_tokens.GetCPU().data(), shape_[0] * token_count * sizeof(int32_t)); - } -} +// Add tokens to the end of input ids tensor +// void InputIDs::AddInputTokens(RoamingArray tokens, bool is_first_tokens) { +// switch (model_.device_type_) { +// case DeviceType::CPU: { +// break; +// } +// default: +// throw std::runtime_error("Add Tokens not supported for device type " + to_string(model_.device_type_)); +// } +// if (shape_[0] != 1) { +// throw std::runtime_error("Add Tokens only supported for batch size 1, got " + std::to_string(shape_[0])); +// } +// auto tokens_cpu = tokens.GetCPU(); +// int start = is_first_tokens ? 0 : shape_[1]; +// int token_count = tokens_cpu.size(); +// shape_[1] = start + token_count; + +// std::unique_ptr temp_value; +// if (!sb_input_ids_) { +// temp_value = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); +// } else { +// temp_value = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_); +// } +// if (type_ == Ort::TypeToTensorType) { +// auto* data = temp_value->GetTensorMutableData(); +// auto next_tokens_cpu = next_tokens.GetCPU(); +// for (int i = 0; i < start; i++) { +// data[i] = value_->GetTensorData()[i]; +// } +// for (int i = 0; i < token_count; i++) { +// data[start + i] = tokens_cpu[i]; +// } +// } else { +// auto* data = temp_value->GetTensorMutableData(); +// if (is_first_tokens) { +// memcpy(data, value_->GetTensorData(), start * sizeof(int32_t)); +// data += start; +// } +// memcpy(data, tokens.GetCPU().data(), token_count * sizeof(int32_t)); +// } +// value_ = std::move(temp_value); +// state_.inputs_[input_index_] = value_.get(); +// } } // namespace Generators diff --git a/src/models/input_ids.h b/src/models/input_ids.h index a3c0cc32a..1f093de75 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -15,8 +15,8 @@ struct InputIDs { // Resize input_ids to [1], update value with next_tokens. // next_tokens is assumed to have length 1. void Update(RoamingArray next_tokens); - // Resize input_ids to [token_count], update value with next_tokens[start:start + token_count]. 
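The reworked InputIDs::Update above sizes the input_ids tensor from whatever chunk of tokens arrives: the whole prompt on the first call, then a single token per step. The removed multi-token Update just below is superseded by that behavior. In spirit (a CPU-only, vector-backed stand-in, not the OrtValue code):

#include <cstdint>
#include <vector>

// shape {batch, seq} follows the incoming chunk: seq = tokens / batch.
std::vector<int32_t> ResizeAndFill(const std::vector<int32_t>& new_tokens, size_t batch_size) {
  const size_t sequence_length = new_tokens.size() / batch_size;
  std::vector<int32_t> input_ids(batch_size * sequence_length);
  for (size_t b = 0; b < batch_size; b++)
    for (size_t i = 0; i < sequence_length; i++)
      input_ids[b * sequence_length + i] = new_tokens[b * sequence_length + i];
  return input_ids;  // first call: {batch, prompt_len}; later calls: {batch, 1}
}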
-  void Update(RoamingArray<int32_t> next_tokens, size_t start, size_t token_count);
+  // Add tokens to the end of input ids tensor
+  // void InputIDs::AddInputTokens(RoamingArray<int32_t> tokens, bool is_first_tokens);
 
   auto& GetShape() const { return shape_; }
   const char* name_;
diff --git a/src/models/kernels.cu b/src/models/kernels.cu
index f90be8347..ef60193de 100644
--- a/src/models/kernels.cu
+++ b/src/models/kernels.cu
@@ -23,6 +23,24 @@ void Launch_UpdatePositionIds(T* positions, int batch_beam_size, cudaStream_t st
 template void Launch_UpdatePositionIds(int32_t* positions, int batch_beam_size, cudaStream_t stream);
 template void Launch_UpdatePositionIds(int64_t* positions, int batch_beam_size, cudaStream_t stream);
 
+template <typename T>
+__global__ void UpdatePositionIds(T* positions, int total_length, int new_kv_length) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < new_kv_length) {
+    positions[i] = total_length + i;
+  }
+}
+
+template <typename T>
+void Launch_UpdatePositionIds(T* positions, int total_length, int new_kv_length, cudaStream_t stream) {
+  int threads = std::min(256, new_kv_length);
+  int blocks = (new_kv_length + threads - 1) / threads;
+  UpdatePositionIds<<<blocks, threads, 0, stream>>>(positions, total_length, new_kv_length);
+}
+
+template void Launch_UpdatePositionIds(int32_t* positions, int total_length, int new_kv_length, cudaStream_t stream);
+template void Launch_UpdatePositionIds(int64_t* positions, int total_length, int new_kv_length, cudaStream_t stream);
+
 template <typename T>
 __global__ void CopyAndUpdateAttentionMask(T* mask_data, const T* old_mask_data,
                                            int batch_beam_size, int current_length, int max_length) {
@@ -63,6 +81,41 @@ template void Launch_UpdateAttentionMask(int32_t* mask_data, const int32_t* old_
 template void Launch_UpdateAttentionMask(int64_t* mask_data, const int64_t* old_mask_data, int batch_beam_size,
                                          int current_length, int max_length, bool update_only, cudaStream_t stream);
 
+template <typename T>
+__global__ void UpdateAttentionMaskStatic(T* mask_data, int new_kv_length, int total_length) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  int past_length = total_length - new_kv_length;
+  if (i < new_kv_length) {
+    mask_data[past_length + i] = 1;
+  }
+}
+
+template <typename T>
+__global__ void UpdateAttentionMask(T* mask_data, int total_length) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i < total_length) {
+    mask_data[i] = 1;
+  }
+}
+
+template <typename T>
+void Launch_UpdateAttentionMask(T* mask_data, int new_kv_length, int total_length, bool update_static, cudaStream_t stream) {
+  // NOTE(WIP): two update modes. Sometimes we want to update the mask in place, keeping its
+  // actual 0s and writing 1s only for the new tokens; other times the mask is 1s all the way
+  // through on a new tensor, so we don't need the old one.
+  if (update_static) {
+    int threads = std::min(256, new_kv_length);
+    int blocks = (new_kv_length + threads - 1) / threads;
+    UpdateAttentionMaskStatic<<<blocks, threads, 0, stream>>>(mask_data, new_kv_length, total_length);
+  } else {
+    int threads = std::min(256, total_length);
+    int blocks = (total_length + threads - 1) / threads;
+    UpdateAttentionMask<<<blocks, threads, 0, stream>>>(mask_data, total_length);
+  }
+}
+
+template void Launch_UpdateAttentionMask(int32_t* mask_data, int new_kv_length, int total_length, bool update_static, cudaStream_t stream);
+template void Launch_UpdateAttentionMask(int64_t* mask_data, int new_kv_length, int total_length, bool update_static, cudaStream_t stream);
+
 __global__ void HandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= batch_beam_size)
diff --git a/src/models/kernels.h b/src/models/kernels.h
index b778d95a8..6b3868855 100644
--- a/src/models/kernels.h
+++ b/src/models/kernels.h
@@ -8,8 +8,12 @@ namespace cuda {
 template <typename T>
 void Launch_UpdatePositionIds(T* positions, int batch_beam_size, cudaStream_t stream);
 template <typename T>
+void Launch_UpdatePositionIds(T* positions, int total_length, int new_kv_length, cudaStream_t stream);
+template <typename T>
 void Launch_UpdateAttentionMask(T* mask_data, const T* old_mask_data, int batch_beam_size, int current_length,
                                 int max_length, bool update_only, cudaStream_t stream);
+template <typename T>
+void Launch_UpdateAttentionMask(T* mask_data, int new_kv_length, int total_length, bool update_static, cudaStream_t stream);
 
 void LaunchHandleEOSArray(float* batch_logits, int batch_beam_size, int vocab_size, const int32_t* eos_token_ids, int eos_token_ids_count, cudaStream_t stream);
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index b444fc8e0..cbdc1fae4 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -4,6 +4,7 @@
 namespace Generators {
 
+// TODO(aciddelgado): fix alternative kv cache implementations
 KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state)
     : model_{model},
       state_{state},
@@ -25,7 +26,8 @@ KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state)
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);
 
   empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
-  shape_[3] = state_.params_->sequence_length;
+  // shape_[3] = state_.params_->sequence_length;
+  shape_[3] = 0;
 
   for (int i = 0; i < layer_count_; ++i) {
     presents_.push_back(OrtValue::CreateTensor(*model.allocator_device_, shape_, type_));
@@ -144,23 +146,24 @@ KV_Cache::KV_Cache(const Model& model, State& state)
   empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
 
   // Set the size after empty_past_ has been created with 0 for this field
-  if (past_present_share_buffer_)
+  if (past_present_share_buffer_) {
     shape_[2] = state_.params_->search.max_length;
-  else
-    shape_[2] = state_.params_->sequence_length;
+    // else
+    //   shape_[2] = state_.params_->sequence_length;
 
-  if (state_.GetCapturedGraphInfo()) {
-    assert(past_present_share_buffer_);
-    sb_kv_caches_.reserve(layer_count_ * 2);
-    for (int i = 0; i < layer_count_ * 2; ++i) {
-      sb_kv_caches_.push_back(state_.GetCapturedGraphInfo()->sb_kv_caches_[i].get());
+    if (state_.GetCapturedGraphInfo()) {
+      sb_kv_caches_.reserve(layer_count_ * 2);
+      for (int i = 0; i < layer_count_ * 2; ++i) {
+        sb_kv_caches_.push_back(state_.GetCapturedGraphInfo()->sb_kv_caches_[i].get());
+      }
     }
-  }
 
-  for (int i = 0; i < layer_count_ * 2; ++i) {
-    presents_.push_back(
-        sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
-            : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
+    // NOTE(WIP): this used to run even without past_present_share_buffer_; for that path,
+    // do it on the first update instead.
+    for (int i = 0; i < layer_count_ * 2; ++i) {
+      presents_.push_back(
+          sb_kv_caches_.empty() ?
OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) + : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_)); + } } } @@ -192,74 +195,79 @@ void KV_Cache::Add() { } } -void KV_Cache::Update(std::span beam_indices, int current_length) { +// TODO(aciddelgado): consider 0-initializing pasts somewhere +void KV_Cache::Update(std::span beam_indices, int total_length) { // If we're sharing past & present buffers there is nothing to do here, so early exit if (past_present_share_buffer_) return; - for (int i = 0; i < layer_count_ * 2; i++) { - if (beam_indices.empty()) { - pasts_[i] = std::move(presents_[i]); - } else { - PickPastState(beam_indices, i); + if (!is_first_update_) { + for (int i = 0; i < layer_count_ * 2; i++) { + if (beam_indices.empty()) { + pasts_[i] = std::move(presents_[i]); + } else { + PickPastState(beam_indices, i); + } + state_.inputs_[input_index_ + i] = pasts_[i].get(); } - state_.inputs_[input_index_ + i] = pasts_[i].get(); } - shape_[2] = current_length; + shape_[2] = total_length; for (int i = 0; i < layer_count_ * 2; i++) { presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); state_.outputs_[output_index_ + i] = presents_[i].get(); } -} -void KV_Cache::UpdatePresent(int current_length) { - // Used for speculative decoding main generator. - // This can be later refactored to merge with tensor allocation during initialization. - if (shape_[2] == current_length) - return; - shape_[2] = current_length; // TODO(aciddelgado): is it ok to set this if past_present_share_buffer_ is true? - // If we're sharing past & present buffers there is nothing to do here, so early exit - if (past_present_share_buffer_) - return; - for (int i = 0; i < layer_count_ * 2; i++) { - presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - state_.outputs_[output_index_ + i] = presents_[i].get(); - } + is_first_update_ = false; } -void KV_Cache::UpdateAndResize(int current_length, int past_length) { - // If we're sharing past & present buffers there is nothing to do here, so early exit - if (past_present_share_buffer_) - return; - if (shape_[0] != 1) - throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports batch size 1, got " + std::to_string(shape_[0])); - if (model_.device_type_ != DeviceType::CPU) - throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports CPU"); - - auto element_type = presents_[0]->GetTensorTypeAndShapeInfo()->GetElementType(); - auto element_size = SizeOf(element_type); - auto new_shape = std::array({1, shape_[1], past_length, shape_[3]}); - if (shape_[2] != past_length) { - for (int i = 0; i < layer_count_ * 2; i++) { - auto new_present = OrtValue::CreateTensor(*model_.allocator_device_, new_shape, type_); - const auto* present_data = reinterpret_cast(presents_[i]->GetTensorRawData()); - auto* new_present_data = reinterpret_cast(new_present->GetTensorMutableRawData()); - - // Copy past_length kv-cache - for (int j = 0; j < shape_[1]; j++) { - memcpy( - new_present_data + j * past_length * shape_[3] * element_size, - present_data + j * shape_[2] * shape_[3] * element_size, - past_length * shape_[3] * element_size); - } +// void KV_Cache::UpdatePresent(int current_length) { +// // Used for speculative decoding main generator. +// // This can be later refactored to merge with tensor allocation during initialization. 
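A rough model of the new Update contract above, in which present buffers are reallocated to total_length every step and presents move into pasts only after the first update (a vector-backed stand-in for one layer; the struct is illustrative, not the OrtValue code):

#include <cstdint>
#include <vector>

// One KV layer: grow `present` to total_length each step, and move the previous
// present into past on every update after the first.
struct KvLayerSketch {
  std::vector<float> past, present;
  bool is_first_update = true;
  int64_t width;  // num_heads * head_size elements per token

  void Update(int total_length) {
    if (!is_first_update)
      past = std::move(present);                   // models pasts_[i] = std::move(presents_[i])
    present.assign(total_length * width, 0.0f);    // fresh present output for the next Run
    is_first_update = false;
  }
};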
+// if (shape_[2] == current_length) +// return; +// shape_[2] = current_length; // TODO(aciddelgado): is it ok to set this if past_present_share_buffer_ is true? +// // If we're sharing past & present buffers there is nothing to do here, so early exit +// if (past_present_share_buffer_) +// return; +// for (int i = 0; i < layer_count_ * 2; i++) { +// presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); +// state_.outputs_[output_index_ + i] = presents_[i].get(); +// } +// } - presents_[i] = std::move(new_present); - } - } +// void KV_Cache::UpdateAndResize(int current_length, int past_length) { +// // If we're sharing past & present buffers there is nothing to do here, so early exit +// if (past_present_share_buffer_) +// return; +// if (shape_[0] != 1) +// throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports batch size 1, got " + std::to_string(shape_[0])); +// if (model_.device_type_ != DeviceType::CPU) +// throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports CPU"); - Update({}, current_length); -} +// auto element_type = presents_[0]->GetTensorTypeAndShapeInfo()->GetElementType(); +// auto element_size = SizeOf(element_type); +// auto new_shape = std::array({1, shape_[1], past_length, shape_[3]}); +// if (shape_[2] != past_length) { +// for (int i = 0; i < layer_count_ * 2; i++) { +// auto new_present = OrtValue::CreateTensor(*model_.allocator_device_, new_shape, type_); +// const auto* present_data = reinterpret_cast(presents_[i]->GetTensorRawData()); +// auto* new_present_data = reinterpret_cast(new_present->GetTensorMutableRawData()); + +// // Copy past_length kv-cache +// for (int j = 0; j < shape_[1]; j++) { +// memcpy( +// new_present_data + j * past_length * shape_[3] * element_size, +// present_data + j * shape_[2] * shape_[3] * element_size, +// past_length * shape_[3] * element_size); +// } + +// presents_[i] = std::move(new_present); +// } +// } + +// Update({}, current_length); +// } // TODO(aciddelgado): RewindTo function // void KV_Cache::RewindTo(int new_length) { diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 7b2abbbc7..1ba434811 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -36,12 +36,12 @@ struct KV_Cache { // Called only once during initialization of state. void Add(); // Move present to past. Prepare present output for next generation iteration. - void Update(std::span beam_indices, int current_length); + void Update(std::span beam_indices, int total_length); // Used by speculative decoding // Resize present to new sequence length. - void UpdatePresent(int current_length); + // void UpdatePresent(int current_length); // Resize past to new sequence length, and drop past that is > past_length. - void UpdateAndResize(int current_length, int past_length); + // void UpdateAndResize(int current_length, int past_length); // Rewind cache to new_length. 
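What the still-commented RewindTo declared just below would do, following the kv_cache.cpp sketch above: keep the first new_length positions of every head and drop the rest. In outline (batch 1, CPU, contiguous [1, num_heads, seq_len, head_size] layout; names are illustrative):

#include <cstring>
#include <vector>

// Keep only the leading new_length positions of each head; drop everything after.
void RewindKvLayer(std::vector<float>& present, int num_heads,
                   int seq_len, int head_size, int new_length) {
  std::vector<float> rewound(static_cast<size_t>(num_heads) * new_length * head_size);
  for (int h = 0; h < num_heads; h++)
    std::memcpy(rewound.data() + static_cast<size_t>(h) * new_length * head_size,
                present.data() + static_cast<size_t>(h) * seq_len * head_size,
                sizeof(float) * new_length * head_size);
  present = std::move(rewound);
}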
// void RewindTo(int new_length); template @@ -55,6 +55,8 @@ struct KV_Cache { size_t input_index_{~0U}, output_index_{~0U}; bool past_present_share_buffer_; // True if model.decoder.past_present_share_buffer is set to true, and we're using cuda, and not beam search + bool is_first_update_{true}; + std::array shape_; ONNXTensorElementDataType type_; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index b90e08507..86df70ccb 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -12,7 +12,7 @@ namespace Generators { Logits::Logits(const Model& model, State& state) : model_{model}, state_{state}, - shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, state_.params_->sequence_length, state_.params_->vocab_size}, + shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, 0, state_.params_->vocab_size}, type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); @@ -37,22 +37,183 @@ Logits::Logits(const Model& model, State& state) #pragma warning(push) #pragma warning(disable : 4189) // local variable is initialized but not referenced +// RoamingArray Logits::Get() { +// size_t element_count = shape_[0] * shape_[1] * shape_[2]; + +// // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor +// // The model's output logits are {batch_size*num_beams, input_seq_len, vocab_size} +// OrtValue* logits_of_last_token = output_raw_.get(); +// if (shape_[1] != 1) { +// const size_t seq_length = shape_[1]; +// const size_t vocab_size = shape_[2]; +// const size_t num_beams = state_.params_->search.num_beams; +// const size_t element_count_last_token = shape_[0] * shape_[2]; + +// shape_[1] = 1; + +// // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it +// output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + +// #if USE_DML +// if (type_ == Ort::TypeToTensorType) { +// logits_of_last_token_fp32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); +// } +// #endif + +// logits_of_last_token = output_last_tokens_.get(); + +// size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2; +// size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process + +// const auto* input_ids = state_.params_->input_ids.data(); +// for (int batch_index = 0; batch_index < state_.params_->batch_size; batch_index++) { +// // Find the first non pad token from the end +// size_t token_index = seq_length; +// while (token_index-- > 0) { +// if (input_ids[token_index] != state_.params_->pad_token_id) +// break; +// } + +// for (int beam_index = 0; beam_index < num_beams; beam_index++) { +// switch (model_.device_type_) { +// #if USE_DML +// case DeviceType::DML: { +// ComPtr source_resource; +// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, output_raw_->GetTensorMutableRawData(), &source_resource)); + +// ComPtr target_resource; +// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, logits_of_last_token->GetTensorMutableRawData(), &target_resource)); + +// uint64_t source_offset = (vocab_index * seq_length + token_index * vocab_size) * element_size; +// uint64_t target_offset = vocab_index * element_size; +// uint64_t size_in_bytes = vocab_size * element_size; + +// model_.GetDmlExecutionContext()->CopyBufferRegion( +// target_resource.Get(), +// target_offset, +// D3D12_RESOURCE_STATE_UNORDERED_ACCESS, +// source_resource.Get(), +// source_offset, +// D3D12_RESOURCE_STATE_UNORDERED_ACCESS, +// size_in_bytes); +// } break; +// #endif + +// case DeviceType::CPU: +// case DeviceType::CUDA: { +// auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size}; +// auto logits_last_tokens = std::span{logits_of_last_token->GetTensorMutableData(), element_count_last_token * element_size}; +// auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size); +// auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size); +// if (model_.device_type_ == DeviceType::CUDA) +// #if USE_CUDA +// CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), cudaMemcpyDeviceToDevice, state_.params_->cuda_stream); +// #else +// throw std::runtime_error("Unexpected CUDA device usage"); +// #endif +// else +// copy(source, target); +// } break; +// } + +// vocab_index += vocab_size; +// } + +// input_ids += seq_length; +// } + +// element_count = shape_[0] * shape_[2]; // shape_[1] is now 1, so the element count must be updated +// } + +// // Convert from float16 to float32 if necessary +// if (type_ == Ort::TypeToTensorType) { +// #if USE_DML +// if (model_.device_type_ == DeviceType::DML) { +// DmlHelpers::DmlCastInputToOutput( +// model_.GetDmlExecutionContext(), +// *model_.allocator_device_, +// *logits_of_last_token, +// logits_of_last_token_fp32_, +// model_.GetDmlDevice(), +// model_.GetOrtDmlApi(), +// logits_cast_command_list_state_); + +// logits_of_last_token = logits_of_last_token_fp32_.get(); +// } else +// #endif +// { +// std::unique_ptr logits_of_last_token_fp32; +// ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_last_token, logits_of_last_token_fp32, model_.device_type_, model_.cuda_stream_); +// output_last_tokens_ = std::move(logits_of_last_token_fp32); // use output_last_tokens_ to hold the fp32 logits +// logits_of_last_token = output_last_tokens_.get(); +// } +// } + +// #if USE_DML +// // DML doesn't support on-device scoring yet, so we need to download some data to the CPU 
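The rewritten Get() that follows keeps the same last-token slice as this commented-out version: from logits shaped {batch * beams, seq_len, vocab} it extracts one vocab-sized row per sequence, at the last non-pad position of that batch's input_ids. A CPU-only sketch of the slice (illustrative, not the tensor/DML code):

#include <cstdint>
#include <vector>

std::vector<float> SliceLastTokenLogits(const std::vector<float>& logits,
                                        const std::vector<int32_t>& input_ids,
                                        size_t batch, size_t beams, size_t seq_len,
                                        size_t vocab, int32_t pad_token) {
  std::vector<float> out;
  out.reserve(batch * beams * vocab);
  for (size_t b = 0; b < batch; b++) {
    // Find the first non-pad token scanning from the end of this batch's input_ids.
    size_t token_index = seq_len;
    while (token_index-- > 0)
      if (input_ids[b * seq_len + token_index] != pad_token) break;
    for (size_t beam = 0; beam < beams; beam++) {
      const size_t row = b * beams + beam;
      const size_t src = (row * seq_len + token_index) * vocab;
      out.insert(out.end(), logits.begin() + src, logits.begin() + src + vocab);
    }
  }
  return out;  // shape {batch * beams, 1, vocab}, flattened
}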
+// if (model_.device_type_ == DeviceType::DML) { +// value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_); +// } +// #endif + +// assert(shape_[1] == 1); + +// #if USE_CUDA +// if (model_.device_type_ == DeviceType::CUDA) { +// auto batched_logits_gpu = gpu_span{logits_of_last_token->GetTensorMutableData(), element_count}; +// if (cuda_eos_token_ids_ptr_) +// cuda::LaunchHandleEOSArray( +// batched_logits_gpu.data(), +// static_cast(shape_[0]) /* batch_beam_size*/, +// static_cast(shape_[2]) /* vocab_size */, +// cuda_eos_token_ids_.data(), +// static_cast(cuda_eos_token_ids_.size()), +// model_.cuda_stream_); +// return batched_logits_gpu; +// } +// #elif USE_DML +// if (model_.device_type_ == DeviceType::DML) { +// // DML doesn't support on-device scoring yet, so we transfer the data to the CPU +// ComPtr gpu_resource; +// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( +// model_.allocator_device_, +// logits_of_last_token->GetTensorMutableData(), +// &gpu_resource)); +// auto cpu_tensor = value32_cpu_->GetTensorMutableData(); + +// model_.GetDmlReadbackHeap()->ReadbackFromGpu( +// std::span(reinterpret_cast(cpu_tensor), element_count * sizeof(float)), +// gpu_resource.Get(), +// 0, +// D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + +// auto batched_logits_cpu = cpu_span{cpu_tensor, element_count}; +// HandleEOSArray(batched_logits_cpu); +// return batched_logits_cpu; +// } +// #endif + +// auto batched_logits_cpu = cpu_span{logits_of_last_token->GetTensorMutableData(), element_count}; +// HandleEOSArray(batched_logits_cpu); +// return batched_logits_cpu; +// } + RoamingArray Logits::Get() { size_t element_count = shape_[0] * shape_[1] * shape_[2]; - // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor // The model's output logits are {batch_size*num_beams, input_seq_len, vocab_size} OrtValue* logits_of_last_token = output_raw_.get(); + std::array shape_last{shape_[0], 1, shape_[2]}; if (shape_[1] != 1) { const size_t seq_length = shape_[1]; const size_t vocab_size = shape_[2]; const size_t num_beams = state_.params_->search.num_beams; const size_t element_count_last_token = shape_[0] * shape_[2]; - shape_[1] = 1; + // shape_[1] = 1; // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it - output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_last, type_); #if USE_DML if (type_ == Ort::TypeToTensorType) { @@ -65,7 +226,8 @@ RoamingArray Logits::Get() { size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2; size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process - const auto* input_ids = state_.params_->input_ids.data(); + // const auto* input_ids = state_.params_->input_ids.data(); + const auto* input_ids = state_.input_ids_.Get()->GetTensorData(); // TODO(aciddelgado): make sure on CPU for (int batch_index = 0; batch_index < state_.params_->batch_size; batch_index++) { // Find the first non pad token from the end size_t token_index = seq_length; @@ -152,7 +314,7 @@ RoamingArray Logits::Get() { #if USE_DML // DML doesn't support on-device scoring yet, so we need to download some data to the CPU if (model_.device_type_ == DeviceType::DML) { - value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_); + value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_last); } #endif @@ -200,59 +362,59 @@ RoamingArray Logits::Get() { #pragma warning(pop) -RoamingArray Logits::Get(size_t start, size_t size) { - const size_t num_beams = state_.params_->search.num_beams; - if (num_beams != 1) - throw std::runtime_error("Get with start and size not supported for num_beams != 1, got " + std::to_string(num_beams)); - if (shape_[0] != 1) - throw std::runtime_error("Get with start and size not supported for batch size != 1, got " + std::to_string(shape_[0])); - - size_t element_count = shape_[1] * shape_[2]; - size_t element_size = type_ == Ort::TypeToTensorType ? 4 : 2; - size_t selected_element_count = size * shape_[2]; - - output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array({1, static_cast(size), shape_[2]}), type_); - OrtValue* logits_of_selected_tokens = output_last_tokens_.get(); - - auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size}; - auto logits_of_selected_tokens_raw = std::span{logits_of_selected_tokens->GetTensorMutableData(), selected_element_count * element_size}; - auto source = logits_raw.subspan(start * shape_[2] * element_size, selected_element_count * element_size); - copy(source, logits_of_selected_tokens_raw); - - if (type_ == Ort::TypeToTensorType) { - std::unique_ptr logits_of_selected_tokens_fp32; - ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_selected_tokens, logits_of_selected_tokens_fp32, model_.device_type_, model_.cuda_stream_); - output_last_tokens_ = std::move(logits_of_selected_tokens_fp32); - logits_of_selected_tokens = output_last_tokens_.get(); - } - - auto batched_logits_cpu = cpu_span{logits_of_selected_tokens->GetTensorMutableData(), selected_element_count}; - HandleEOSArray(batched_logits_cpu); - return batched_logits_cpu; -} - -void Logits::Update() { - if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == 1) { +// RoamingArray Logits::Get(size_t start, size_t size) { +// const size_t num_beams = state_.params_->search.num_beams; +// if (num_beams != 1) +// throw std::runtime_error("Get with start and size not supported for num_beams != 1, got " + std::to_string(num_beams)); +// if (shape_[0] != 1) +// throw std::runtime_error("Get with start and size not supported for batch size != 1, got " + std::to_string(shape_[0])); + +// size_t element_count = shape_[1] * shape_[2]; +// size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2;
+//   size_t selected_element_count = size * shape_[2];
+
+//   output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array<int64_t, 3>({1, static_cast<int64_t>(size), shape_[2]}), type_);
+//   OrtValue* logits_of_selected_tokens = output_last_tokens_.get();
+
+//   auto logits_raw = std::span<uint8_t>{output_raw_->GetTensorMutableData<uint8_t>(), element_count * element_size};
+//   auto logits_of_selected_tokens_raw = std::span<uint8_t>{logits_of_selected_tokens->GetTensorMutableData<uint8_t>(), selected_element_count * element_size};
+//   auto source = logits_raw.subspan(start * shape_[2] * element_size, selected_element_count * element_size);
+//   copy(source, logits_of_selected_tokens_raw);
+
+//   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>) {
+//     std::unique_ptr<OrtValue> logits_of_selected_tokens_fp32;
+//     ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_selected_tokens, logits_of_selected_tokens_fp32, model_.device_type_, model_.cuda_stream_);
+//     output_last_tokens_ = std::move(logits_of_selected_tokens_fp32);
+//     logits_of_selected_tokens = output_last_tokens_.get();
+//   }
+
+//   auto batched_logits_cpu = cpu_span<float>{logits_of_selected_tokens->GetTensorMutableData<float>(), selected_element_count};
+//   HandleEOSArray(batched_logits_cpu);
+//   return batched_logits_cpu;
+// }
+
+void Logits::Update(int new_kv_length) {
+  if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == new_kv_length) {
     return;
   }
 
+  shape_[1] = new_kv_length;
   StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
   output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
                            : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
   state_.outputs_[output_index_] = output_raw_.get();
 }
 
-void Logits::Update(size_t token_count) {
-  if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == token_count) {
-    return;
-  }
+// void Logits::Update() {
+//   if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == 1) {
+//     return;
+//   }
 
-  shape_[1] = token_count;
-  StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
-  output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
-                           : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
-  state_.outputs_[output_index_] = output_raw_.get();
-}
+//   StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType<Ort::Float16_t> ? sb_logits16_ : sb_logits32_;
+//   output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
+//                            : sb_logits->CreateTensorOnStaticBuffer(shape_, type_);
+//   state_.outputs_[output_index_] = output_raw_.get();
+// }
 
 void Logits::HandleEOSArray(cpu_span<float> batched_logits) {
   if (model_.config_->model.eos_token_ids.empty())
diff --git a/src/models/logits.h b/src/models/logits.h
index ed7a281b3..cd0fc25cc 100644
--- a/src/models/logits.h
+++ b/src/models/logits.h
@@ -17,9 +17,10 @@ struct Logits {
   // Retrieves logits[:, start:start + size, :].
   RoamingArray<float> Get(size_t start, size_t size);  // batch_size x size x vocab_size
 
-  void Update();
+  // void Update();
+  // Resize logits to [bz, new_kv_length, vocab_size].
+  void Update(int new_kv_length);
-  // Resize logits to [bz, token_count, vocab_size].
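A quick sketch of how the reworked Update() is meant to be driven (an editorial illustration using names from this patch; the token counts are invented): prefill, single-token decode, and multi-token speculative verification now share one path sized by the number of tokens fed in the current step.

  // Sized by the tokens appended this step via the State-owned input_ids_:
  int new_kv_length = static_cast<int>(state.input_ids_.GetShape()[1]);
  logits_.Update(new_kv_length);  // output buffer becomes {batch, new_kv_length, vocab_size}
  // e.g. Update(prompt_len) for prefill, Update(1) for a normal token step,
  // Update(k) to verify k drafted tokens; Get() still slices out the last
  // token's scores as {batch_beams, 1, vocab_size}.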
- void Update(size_t token_count); + // void Update(size_t token_count); private: void HandleEOSArray(cpu_span logits); diff --git a/src/models/model.h b/src/models/model.h index a9a256962..f68e96420 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -5,6 +5,7 @@ #include "captured_graph_pool.h" #include "utils.h" #include "prompt_image_processor.h" +#include "input_ids.h" #if USE_DML #include "dml_provider_factory.h" @@ -29,20 +30,25 @@ struct State { virtual ~State() = default; virtual RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices = {}) = 0; + // Used by continuous decoding + virtual RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { throw std::runtime_error("Not implemented"); }; + // virtual void AddInputTokens(const RoamingArray& tokens) { throw std::runtime_error("Not implemented"); }; + virtual const CapturedGraphInfo* GetCapturedGraphInfo() const { return nullptr; } OrtValue* GetOutput(const char* name); - virtual RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { throw std::runtime_error("Not implemented"); }; - std::shared_ptr params_; std::vector input_names_, output_names_; std::vector inputs_, outputs_; + InputIDs input_ids_{model_, *this}; // TODO(aciddelgado): is this ok here? + protected: void Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size); // Uses the inputs below to run void ClearIO(); // Clear all inputs/outputs + bool first_run_{true}; private: @@ -116,6 +122,7 @@ struct Model : std::enable_shared_from_this, LeakChecked { std::shared_ptr CreateMultiModalProcessor() const; virtual std::unique_ptr CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const = 0; + // virtual std::unique_ptr CreateState(const GeneratorParams& params) const = 0; std::unique_ptr ExpandInputs(std::unique_ptr& input, int num_beams) const; diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index ecfc01e6d..5270aac22 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -222,10 +222,11 @@ RoamingArray DecoderState::Run(int current_length, RoamingArray return logits_.Get(); } -void DecoderState::UpdateInputsOutputs(int current_length, RoamingArray beam_indices) { - position_inputs_.Update(current_length); - kv_cache_.Update(beam_indices.GetCPU(), current_length); - logits_.Update(); +void DecoderState::UpdateInputsOutputs(int total_length, RoamingArray beam_indices) { + size_t new_length = input_ids_.GetShape()[1]; + position_inputs_.Update(total_length, new_length); + kv_cache_.Update(beam_indices.GetCPU(), total_length); + logits_.Update(new_length); } MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& model, @@ -256,7 +257,10 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra vision_state_->Run(current_length, next_tokens, next_indices); // Run the select logic - Select(model_, params_->input_ids, embedding_state_->inputs_embeds_.Get(), + // TODO(aciddelgado): this may not work logically, done to get it to compile for decoder_only + const auto* input_ids = decoder_state_->input_ids_.Get()->GetTensorData(); + auto input_ids_span = std::span(input_ids, decoder_state_->input_ids_.GetShape()[1]); + Select(model_, input_ids_span, embedding_state_->inputs_embeds_.Get(), vision_state_->visual_features_.get(), vision_state_->num_image_tokens_, 
params_->hidden_size, params_->device_type, params_->cuda_stream);
 }
 
diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h
index 9b3e62646..b00462b4d 100644
--- a/src/models/multi_modal_vision_model.h
+++ b/src/models/multi_modal_vision_model.h
@@ -42,7 +42,7 @@ struct EmbeddingState : State {
   const MultiModalVisionModel& model_;
   const CapturedGraphInfo* captured_graph_info_;
 
-  InputIDs input_ids_{model_, *this};  // Model input
+  // InputIDs input_ids_{model_, *this};  // Model input
   Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output,  // Model output
                             model_.config_->model.embedding.outputs.embeddings};
 };
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 6fa770f06..bd2f72b48 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -9,6 +9,7 @@
 
 namespace Generators {
 
+// TODO(aciddelgado): initialization from the prompt's sequence length has been removed here; lengths start at zero until tokens are added
 PositionInputs::PositionInputs(const Model& model, State& state, RoamingArray<int32_t>& sequence_lengths_unk)
     : model_{model},
       state_{state} {
@@ -31,22 +32,23 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArray<int32_t>
       && type_ != Ort::TypeToTensorType<int64_t>)
     throw std::runtime_error("position_ids & attention_mask only support int32 or int64 types");
 
-  std::array<int64_t, 2> shape{state_.params_->batch_size, state_.params_->sequence_length};  // Only batch_size initially, as we haven't expanded over the beams yet
-  position_ids_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_);
-  position_ids_next_ = OrtValue::CreateTensor(model.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
-  attention_mask_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_);
+  std::array<int64_t, 2> shape{state_.params_->batch_size, 0};  // Only batch_size initially, as we haven't expanded over the beams yet
+  // position_ids_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_);
+  // position_ids_next_ = OrtValue::CreateTensor(model.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
+  // attention_mask_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_);
 
-  initial_sequence_lengths_.resize(state_.params_->BatchBeamSize());
+  // initial_sequence_lengths_.resize(state_.params_->BatchBeamSize());
 
   if (type_ == Ort::TypeToTensorType<int32_t>)
-    InitializeTensors<int32_t>(shape, sequence_lengths_unk);
+    InitializeSequenceLengths<int32_t>(shape, sequence_lengths_unk);
   else
-    InitializeTensors<int64_t>(shape, sequence_lengths_unk);
+    InitializeSequenceLengths<int64_t>(shape, sequence_lengths_unk);
 
-  position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams);
-  position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams);
-  attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams);
-  shape[0] *= state_.params_->search.num_beams;
+  // TODO(aciddelgado): what is this? does it break with 0 length?
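For orientation, a sketch of the lazy-initialization flow this constructor change implies (names follow the patch; the exact call sequence is an assumption based on the Update() logic below): the prompt no longer arrives through params_->input_ids, so position_ids_ and attention_mask_ cannot be built here and are instead created on the first Update() after tokens have been added.

  generator.AddTokens(prompt_tokens);              // populates State::input_ids_
  position_inputs.Update(prompt_len, prompt_len);  // first update takes the is_first_update_ branch
  // -> CreateAndInitializePositionIDs / CreateAndInitializeAttentionMask run once,
  //    after which later updates take the incremental paths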
+  // position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams);
+  // position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams);
+  // attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams);
+  // shape[0] *= state_.params_->search.num_beams;
 
   position_ids_shape_ = shape;
   attention_mask_shape_ = shape;
@@ -75,22 +77,51 @@ void PositionInputs::Add() {
   }
 }
 
-void PositionInputs::Update(int current_length) {
-  if (has_posid_input_) {
-    UpdatePositionIDs(current_length);
-  }
-  if (has_mask_input_) {
-    UpdateAttentionMask(current_length);
-  }
-}
+// void PositionInputs::Update(int current_length) {
+//   if (has_posid_input_) {
+//     UpdatePositionIDs(current_length);
+//   }
+//   if (has_mask_input_) {
+//     UpdateAttentionMask(current_length);
+//   }
+// }
 
-void PositionInputs::Update(int current_length, int past_length) {
+void PositionInputs::Update(int total_length, int new_length) {
   if (has_posid_input_) {
-    UpdatePositionIDs(current_length, past_length);
+    // Initialize on first update
+    if (is_first_update_) {
+      position_ids_shape_[1] = new_length;
+      if (type_ == Ort::TypeToTensorType<int32_t>)
+        CreateAndInitializePositionIDs<int32_t>(position_ids_shape_);
+      else
+        CreateAndInitializePositionIDs<int64_t>(position_ids_shape_);
+    } else {
+      // Batch size > 1 case
+      if (position_ids_shape_[0] > 1)
+        UpdatePositionIDs();
+      // Batch size = 1 case (continuous decoding)
+      else
+        UpdatePositionIDs(total_length, new_length);
+    }
   }
   if (has_mask_input_) {
-    UpdateAttentionMask(current_length, past_length);
+    // Initialize on first update
+    if (is_first_update_) {
+      attention_mask_shape_[1] = new_length;
+      if (type_ == Ort::TypeToTensorType<int32_t>)
+        CreateAndInitializeAttentionMask<int32_t>(attention_mask_shape_);
+      else
+        CreateAndInitializeAttentionMask<int64_t>(attention_mask_shape_);
+    } else {
+      // Batch size > 1 case
+      if (attention_mask_shape_[0] > 1)
+        UpdateAttentionMask(total_length);
+      // Batch size = 1 case
+      else
+        UpdateAttentionMask(total_length, new_length);
+    }
   }
+  is_first_update_ = false;
 }
 
 void PositionInputs::AddAttentionMask() {
@@ -107,7 +138,7 @@ void PositionInputs::AddPositionIDs() {
   state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str());
 }
 
-void PositionInputs::UpdatePositionIDs(int current_length) {
+void PositionInputs::UpdatePositionIDs() {
   // Reallocate position_ids for the 2nd and onward shape
   if (is_first_posid_update_) {
     position_ids_shape_[1] = 1;
@@ -202,19 +233,246 @@ void PositionInputs::UpdatePositionIDs(int current_length) {
   }
 }
 
-void PositionInputs::UpdatePositionIDs(int current_length, int past_length) {
-  if (model_.device_type_ != DeviceType::CPU)
-    throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported on CPU.");
-  if (position_ids_shape_[0] != 1)
-    throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported for batch_size=1.");
-  assert(current_length > past_length);
-  position_ids_shape_[1] = current_length - past_length;
-  position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_);
-  if (type_ == Ort::TypeToTensorType<int32_t>)
-    UpdatePositionIDsImpl<int32_t>(current_length, past_length);
-  else
-    UpdatePositionIDsImpl<int64_t>(current_length, past_length);
-  state_.inputs_[posid_input_index_] = position_ids_.get();
+void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) {
+  // Batch sizes above 1 are supported only when total_length == 0 or new_kv_length == 1; multi-token continuation requires batch_size == 1
+  if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1))
+    throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding.");
+  // Don't support DML
+  if (model_.device_type_ == DeviceType::DML)
+    throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML not supported for continuous decoding.");
+  // Reallocate position_ids when new_kv_length changes
+  if (position_ids_shape_[1] != new_kv_length) {
+    position_ids_shape_[1] = new_kv_length;
+    if (!sb_position_ids_) {
+      position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, type_);
+    } else {
+#if USE_CUDA
+      position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_);
+      assert(model_.device_type_ == DeviceType::CUDA);
+// #elif USE_DML
+//       position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_);
+//       assert(model_.device_type_ == DeviceType::DML);
+#endif
+    }
+    state_.inputs_[posid_input_index_] = position_ids_.get();
+  }
+  is_first_posid_update_ = false;
+  // Write the position IDs for the newly appended tokens
+  switch (model_.device_type_) {
+// #if USE_DML
+//     case DeviceType::DML: {
+//       ComPtr<ID3D12Resource> target_resource;
+//       Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource));
 
+//       // Lazily create the kernel only the first time it's needed
+//       if (!dml_update_position_ids_kernel_) {
+//         dml_update_position_ids_kernel_ = DmlIncrementValuesKernel(
+//             model_.GetD3D12Device(),
+//             model_.GetDmlExecutionContext(),
+//             static_cast<uint32_t>(position_ids_shape_[0]),
+//             type_,
+//             target_resource.Get());
+//       }
 
+//       // Execute the cached command list
+//       ComPtr<ID3D12Fence> fence;
+//       uint64_t completion_value;
+//       model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value);
+//     } break;
+// #endif
+    case DeviceType::CPU: {
+      if (type_ == Ort::TypeToTensorType<int32_t>)
+        UpdatePositionIDsImpl<int32_t>(total_length, new_kv_length);
+      else
+        UpdatePositionIDsImpl<int64_t>(total_length, new_kv_length);
+      break;
+    }
+#if USE_CUDA
+    case DeviceType::CUDA:
+      if (type_ == Ort::TypeToTensorType<int32_t>)
+        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), total_length, new_kv_length, model_.cuda_stream_);
+      else
+        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), total_length, new_kv_length, model_.cuda_stream_);
+      break;
+#endif
+    default:
+      throw std::runtime_error("PositionIDs::Update - Unsupported device type");
+  }
+}
+
+// void PositionInputs::UpdatePositionIDs(int current_length, int new_length) {
+//   // if (model_.device_type_ != DeviceType::CPU)
+//   //   throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported on CPU.");
+//   // if (position_ids_shape_[0] != 1)
+//   //   throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported for batch_size=1.");
+//   position_ids_shape_[1] = new_length;
+//   position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_);
+//   if (type_ == Ort::TypeToTensorType<int32_t>)
+//     UpdatePositionIDsImpl<int32_t>(current_length, past_length);
+//   else
+//     UpdatePositionIDsImpl<int64_t>(current_length, past_length);
+//   state_.inputs_[posid_input_index_] = position_ids_.get();
+// }
+
+void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) {
+  // Batch sizes above 1 are supported only when total_length == 0 or new_kv_length == 1; multi-token continuation requires batch_size == 1
+  if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1))
+    throw std::runtime_error("PositionInputs::UpdateAttentionMask - batch_size must be 1 for continuous decoding.");
+  // Don't support DML
+  if (model_.device_type_ == DeviceType::DML)
+    throw std::runtime_error("PositionInputs::UpdateAttentionMask - DML not supported for continuous decoding.");
+  // Update attention mask
+// if (sb_attention_mask_) {
+// #if USE_CUDA
+//   attention_mask_shape_[1] = state_.params_->search.max_length;
+//   attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
+//   if (is_first_mask_update_) {
+//     if (type_ == Ort::TypeToTensorType<int32_t>) {
+//       cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(),
+//                       0,
+//                       sizeof(int32_t) * attention_mask_shape_[0] * attention_mask_shape_[1],
+//                       model_.cuda_stream_);
+//     } else {
+//       cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(),
+//                       0,
+//                       sizeof(int64_t) * attention_mask_shape_[0] * attention_mask_shape_[1],
+//                       model_.cuda_stream_);
+//     }
+//   }
+// // #elif USE_DML
+// //   attention_mask_shape_[1] = state_.params_->search.max_length;
+// //   attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
+// //   attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
+// #endif
+// } else {
+//   attention_mask_shape_[1] = total_length;
 
+// // #if USE_DML
+// //   if (model_.device_type_ == DeviceType::DML) {
+// //     attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
+// //   }
+// // #endif
 
+//   attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
+// }
 
+  if (sb_attention_mask_ && is_first_mask_update_) {
+#if USE_CUDA
+    attention_mask_shape_[1] = state_.params_->search.max_length;
+    attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
+    if (is_first_mask_update_) {
+      int past_length = total_length - new_kv_length;
+      // TODO: cudaMemsetAsync fills individual bytes, so the value 1 below writes
+      // 0x01 into every byte of each element rather than an integer 1; an
+      // element-wise fill is needed to store exact ones.
+      if (type_ == Ort::TypeToTensorType<int32_t>) {
+        cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(),
+                        1,
+                        sizeof(int32_t) * past_length,
+                        model_.cuda_stream_);
+        cudaMemsetAsync(static_cast<uint8_t*>(attention_mask_->GetTensorMutableRawData()) + sizeof(int32_t) * past_length,
+                        0,
+                        sizeof(int32_t) * (total_length - past_length),
+                        model_.cuda_stream_);
+      } else {
+        cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(),
+                        1,
+                        sizeof(int64_t) * past_length,
+                        model_.cuda_stream_);
+        cudaMemsetAsync(static_cast<uint8_t*>(attention_mask_->GetTensorMutableRawData()) + sizeof(int64_t) * past_length,
+                        0,
+                        sizeof(int64_t) * (total_length - past_length),
+                        model_.cuda_stream_);
      }
+    }
+// #elif USE_DML
+//     attention_mask_shape_[1] = state_.params_->search.max_length;
+//     attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
+//     attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_);
#endif
+  } else {
+    attention_mask_shape_[1] = total_length;
+    attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
+  }
 
+  switch (model_.device_type_) {
+// #if USE_DML
+//     case DeviceType::DML: {
+//       ComPtr<ID3D12Resource> attention_mask_resource;
+//       Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource));
 
+//       ComPtr<ID3D12Resource> attention_mask_next_resource;
+//       Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_next_->GetTensorMutableRawData(), &attention_mask_next_resource));
+//       if (is_first_mask_update_) {
+//         dml_update_mask_kernel_ = DmlUpdateMaskKernel(
+//             model_.GetD3D12Device(),
+//             model_.GetDmlExecutionContext(),
+//             static_cast<uint32_t>(attention_mask_shape_[0]),
+//             static_cast<uint32_t>(attention_mask_shape_[1]),
+//             type_,
+//             current_length,
+//             attention_mask_resource.Get(),
+//             attention_mask_next_resource.Get());
+//         is_second_mask_update_ = true;
+//       } else if (is_second_mask_update_) {
+//         dml_update_mask_kernel_ = DmlUpdateMaskKernel(
+//             model_.GetD3D12Device(),
+//             model_.GetDmlExecutionContext(),
+//             static_cast<uint32_t>(attention_mask_shape_[0]),
+//             static_cast<uint32_t>(attention_mask_shape_[1]),
+//             type_,
+//             1,
+//             attention_mask_resource.Get(),
+//             attention_mask_next_resource.Get());
+//         is_second_mask_update_ = false;
+//       }
 
+//       ComPtr<ID3D12Fence> fence;
+//       uint64_t completion_value;
+//       model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_mask_kernel_->GetCommandList(), &fence, &completion_value);
+//       break;
+//     }
+// #endif
+    case DeviceType::CPU: {
+      if (type_ == Ort::TypeToTensorType<int32_t>)
+        UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData<int32_t>(), total_length);
+      else
+        UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData<int64_t>(), total_length);
+      break;
+    }
+#if USE_CUDA
+    case DeviceType::CUDA: {
+      // int max_seq_len = sb_attention_mask_ ? state_.params_->search.max_length : total_length;
+      bool update_static = sb_attention_mask_;
+      if (type_ == Ort::TypeToTensorType<int32_t>) {
+        cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData<int32_t>(),
+                                         new_kv_length,
+                                         total_length,
+                                         update_static,
+                                         model_.cuda_stream_);
+      } else {
+        cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData<int64_t>(),
+                                         new_kv_length,
+                                         total_length,
+                                         update_static,
+                                         model_.cuda_stream_);
+      }
+      break;
+    }
+#endif
+    default:
+      throw std::runtime_error("PositionInputs::Update - Unsupported device type");
+  }
+
+// #if USE_DML
+//   if (model_.device_type_ != DeviceType::DML) {
+//     attention_mask_ = std::move(attention_mask_next_);
+//   }
+// #else
+//   attention_mask_ = std::move(attention_mask_next_);
+// #endif
+
+  state_.inputs_[mask_input_index_] = attention_mask_.get();
+
+  is_first_mask_update_ = false;
 }
 
 void PositionInputs::UpdateAttentionMask(int current_length) {
@@ -250,7 +508,6 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
     attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
   }
 #endif
-
   attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
 }
 
@@ -345,49 +602,143 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
   is_first_mask_update_ = false;
 }
 
-void PositionInputs::UpdateAttentionMask(int current_length, int past_length) {
-  if (model_.device_type_ != DeviceType::CPU)
-    throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported on CPU.");
-  if (attention_mask_shape_[0] != 1)
-    throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported for batch_size=1.");
-  attention_mask_shape_[1] = current_length;
-  attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
-  if (type_ == Ort::TypeToTensorType<int32_t>)
-    UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int32_t>(), current_length, past_length);
-  else
-    UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int64_t>(), current_length, past_length);
-  attention_mask_ = std::move(attention_mask_next_);
-  state_.inputs_[mask_input_index_] = attention_mask_.get();
-  is_first_mask_update_ = false;
-}
+// void PositionInputs::UpdateAttentionMask(int current_length, int new_length) {
+//   if (model_.device_type_ != DeviceType::CPU)
+//     throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported on CPU.");
+//   if (attention_mask_shape_[0] != 1)
+//     throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported for batch_size=1.");
+//   attention_mask_shape_[1] = current_length;
+//   attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_);
+//   if (type_ == Ort::TypeToTensorType<int32_t>)
+//     UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int32_t>(), current_length, past_length);
+//   else
+//     UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int64_t>(), current_length, past_length);
+//   attention_mask_ = std::move(attention_mask_next_);
+//   state_.inputs_[mask_input_index_] = attention_mask_.get();
+//   is_first_mask_update_ = false;
+// }
 
 template <typename T>
-void PositionInputs::InitializeTensors(std::array<int64_t, 2> shape, cpu_span<int32_t> sequence_lengths) {
+void PositionInputs::CreateAndInitializePositionIDs(std::array<int64_t, 2> shape) {
-  // Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
   // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
-  auto* mask_data = attention_mask_->GetTensorMutableData<T>();
+  position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_);
+  position_ids_next_ = OrtValue::CreateTensor(model_.allocator_cpu_, std::array<int64_t, 2>{shape[0], 1}, type_);
   auto* position_data = position_ids_->GetTensorMutableData<T>();
   auto* position_data_next = position_ids_next_->GetTensorMutableData<T>();
-  const auto* word_id = state_.params_->input_ids.data();
-  auto* mask = mask_data;
+  const auto* word_id = state_.input_ids_.Get()->GetTensorData<int32_t>();
   auto* position = position_data;
   for (int i = 0; i < shape[0]; i++) {
     T abs_position = 0;
-    for (int j = 0; j < shape[1]; j++, word_id++, mask++, position++) {
+    for (int j = 0; j < shape[1]; j++, word_id++, position++) {
       if (*word_id == state_.params_->pad_token_id) {
-        *mask = 0;
         *position = 0;
       } else {
-        *mask = 1;
         *position = abs_position++;
       }
     }
 
     position_data_next[i] = abs_position;
-    for (int k = 0; k < state_.params_->search.num_beams; k++) {
-      sequence_lengths[i * state_.params_->search.num_beams + k] = static_cast<int32_t>(abs_position);
-      initial_sequence_lengths_[i * state_.params_->search.num_beams + k] = static_cast<int32_t>(abs_position);
+    // initial_sequence_lengths_[i] = static_cast<int32_t>(abs_position);
+  }
+
+  // Move tensors to the appropriate device and expand by num_beams
+  // (ExpandInputs returns the expanded tensor, so assign the result back)
+  position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams);
+  position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams);
+  position_ids_shape_[0] *= state_.params_->search.num_beams;
+}
+
+template <typename T>
+void PositionInputs::CreateAndInitializeAttentionMask(std::array<int64_t, 2> shape) {
+  // Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
+  attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_);
+  auto* mask_data = attention_mask_->GetTensorMutableData<T>();
+  const auto* word_id = state_.input_ids_.Get()->GetTensorData<int32_t>();
+  auto* mask = mask_data;
+  for (int i = 0; i < shape[0]; i++) {
+    for (int j = 0; j < shape[1]; j++, word_id++, mask++) {
+      if (*word_id == state_.params_->pad_token_id) {
+        *mask = 0;
+      } else {
+        *mask = 1;
+      }
+    }
+  }
+
+  // Move the tensor to the appropriate device and expand by num_beams
+  // (ExpandInputs returns the expanded tensor, so assign the result back)
+  attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams);
+  attention_mask_shape_[0] *= state_.params_->search.num_beams;
+}
+
+// template <typename T>
+// void PositionInputs::InitializeTensors(std::array<int64_t, 2> shape/*, cpu_span<int32_t> sequence_lengths_unk*/) {
+//   // Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
+//   // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
+//   auto* mask_data = attention_mask_->GetTensorMutableData<T>();
+//   auto* position_data = position_ids_->GetTensorMutableData<T>();
+//   auto* position_data_next = position_ids_next_->GetTensorMutableData<T>();
+//   const auto* word_id = state_.params_->input_ids.data();
+//   auto* mask = mask_data;
+//   auto* position = position_data;
+//   for (int i = 0; i < shape[0]; i++) {
+//     T abs_position = 0;
+//     for (int j = 0; j < shape[1]; j++, word_id++, mask++, position++) {
+//       if (*word_id == state_.params_->pad_token_id) {
+//         *mask = 0;
+//         *position = 0;
+//       } else {
+//         *mask = 1;
+//         *position = abs_position++;
+//       }
+//     }
 
+//     position_data_next[i] = abs_position;
+//     for (int k = 0; k < state_.params_->search.num_beams; k++) {
+//       // sequence_lengths_unk[i * state_.params_->search.num_beams + k] = static_cast<int32_t>(abs_position);
+//       initial_sequence_lengths_[i * state_.params_->search.num_beams + k] = static_cast<int32_t>(abs_position);
+//     }
+//   }
+// }
+
+template <typename T>
+void PositionInputs::InitializeSequenceLengths(std::array<int64_t, 2> shape, cpu_span<int32_t> sequence_lengths_unk) {
+  for (int i = 0; i < shape[0] * state_.params_->search.num_beams; i++) {
+    sequence_lengths_unk[i] = 0;
+    // initial_sequence_lengths_[i] = 0;
   }
 }
 
@@ -401,11 +752,10 @@ void PositionInputs::UpdatePositionIDsImpl() {
 };
 
 template <typename T>
-void PositionInputs::UpdatePositionIDsImpl(int current_length, int past_length) {
+void PositionInputs::UpdatePositionIDsImpl(int total_length, int new_kv_length) {
   auto* data = position_ids_->GetTensorMutableData<T>();
-  for (int i = 0; i < current_length - past_length; i++) {
-    data[i] = i + past_length;
-  }
+  // The new tokens continue from the past length (total_length - new_kv_length)
+  for (int i = 0; i < new_kv_length; i++)
+    data[i] = i + total_length - new_kv_length;
 };
 
 template <typename T>
@@ -419,8 +769,8 @@ void PositionInputs::UpdateAttentionMaskImpl(T* data, const T* old_data, int cur
 };
 
 template <typename T>
-void PositionInputs::UpdateAttentionMaskImpl(T* data, int current_length, int past_length) {
-  for (int i = 0; i < current_length; i++) {
+void PositionInputs::UpdateAttentionMaskImpl(T* data, int total_length) {
+  for (int i = 0; i < total_length; i++) {
     data[i] = 1;
   }
 };
diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h
index 10ceaa27d..de8875835 100644
--- a/src/models/position_inputs.h
+++ b/src/models/position_inputs.h
@@ -11,23 +11,31 @@ namespace Generators {
 
 struct PositionInputs {
   PositionInputs(const Model& model, State& state, RoamingArray<int32_t>& sequence_lengths);
+  PositionInputs(const Model& model, State& state);
 
   void Add();
-  void Update(int current_length);
-  void Update(int current_length, int past_length);
+  // void Update(int current_length);
+  void Update(int total_length, int new_length);
 
 private:
   void AddAttentionMask();
   void AddPositionIDs();
 
-  void UpdatePositionIDs(int current_length);
-  void UpdateAttentionMask(int current_length);
+  // Batch size > 1 case
+  void UpdatePositionIDs();
+  void UpdateAttentionMask(int total_length);
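As a concrete illustration of the continuous-decoding declarations below (an editorial sketch; the lengths are invented and batch_size == 1 is assumed): verifying k = 3 drafted tokens on top of a 10-token past gives

  Update(/*total_length*/ 13, /*new_length*/ 3);
  // position_ids   -> shape {1, 3},  values {10, 11, 12}
  // attention_mask -> shape {1, 13}, all ones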
+  // Used by continuous decoding.
-  void UpdatePositionIDs(int current_length, int past_length);
-  void UpdateAttentionMask(int current_length, int past_length);
+  void UpdatePositionIDs(int total_length, int new_length);
+  void UpdateAttentionMask(int total_length, int new_length);
 
+  // template <typename T>
+  // void InitializeTensors(std::array<int64_t, 2> shape/*, cpu_span<int32_t> sequence_lengths*/);
   template <typename T>
-  void InitializeTensors(std::array<int64_t, 2> shape, cpu_span<int32_t> sequence_lengths);
+  void InitializeSequenceLengths(std::array<int64_t, 2> shape, cpu_span<int32_t> sequence_lengths_unk);
+  template <typename T>
+  void CreateAndInitializePositionIDs(std::array<int64_t, 2> shape);
+  template <typename T>
+  void CreateAndInitializeAttentionMask(std::array<int64_t, 2> shape);
 
   template <typename T>
   void UpdatePositionIDsImpl();
@@ -36,9 +44,9 @@ struct PositionInputs {
 
   // Used by continuous decoding
   template <typename T>
-  void UpdatePositionIDsImpl(int current_length, int past_length);
+  void UpdatePositionIDsImpl(int total_length, int new_kv_length);
   template <typename T>
-  void UpdateAttentionMaskImpl(T* data, int current_length, int past_length);
+  void UpdateAttentionMaskImpl(T* data, int total_length);
 
   const Model& model_;
   State& state_;
@@ -58,7 +66,7 @@ struct PositionInputs {
   std::unique_ptr<OrtValue> position_ids_next_;    // Replaces position_ids_ after the first Run() call
   std::unique_ptr<OrtValue> attention_mask_next_;  // Replaces attention_mask_ after the first Run() call
 
-  std::vector<int32_t> initial_sequence_lengths_;
+  // std::vector<int32_t> initial_sequence_lengths_;
 
   // Used for decoding runs with cuda graphs.
   StaticBuffer* sb_position_ids_{};
@@ -66,6 +74,7 @@ struct PositionInputs {
 
   bool is_first_posid_update_{true};
   bool is_first_mask_update_{true};
+  bool is_first_update_{true};
 
 #if USE_DML
   std::optional<DmlUpdateMaskKernel> dml_update_mask_kernel_;
diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp
index b42c458d8..0eae1b196 100644
--- a/src/models/whisper.cpp
+++ b/src/models/whisper.cpp
@@ -31,7 +31,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray<int32_t> s
 
   auto sequence_lengths = sequence_lengths_unk.GetCPU();
   for (int i = 0; i < decoder_input_ids_.GetShape()[0]; i++) {
-    sequence_lengths[i] = static_cast<int32_t>(params_->sequence_length);
+    sequence_lengths[i] = 0;  // TODO(aciddelgado): what? was static_cast<int32_t>(params_->sequence_length)
   }
 
   input_names_.push_back("encoder_input_ids");
@@ -80,7 +80,8 @@ RoamingArray<float> Whisper_State::Run(int current_length, RoamingArray<int32_t>
 void Whisper_State::UpdateInputsOutputs(const RoamingArray<int32_t>& next_tokens, RoamingArray<int32_t> beam_indices, int current_length) {
   decoder_input_ids_.Update(next_tokens);
   kv_cache_.Update(beam_indices.GetCPU(), current_length);
-  logits_.Update();
+  size_t new_length = input_ids_.GetShape()[1];
+  logits_.Update(static_cast<int>(new_length));
 }
 
 }  // namespace Generators
diff --git a/src/models/whisper.h b/src/models/whisper.h
index 5f3872e4c..aaf3a3fb4 100644
--- a/src/models/whisper.h
+++ b/src/models/whisper.h
@@ -33,6 +33,7 @@ struct Whisper_State : State {
     Decoder,
   } run_state_{RunState::Encoder_Decoder_Init};
 
+  // TODO(aciddelgado): does decoder_input_ids behave differently from input_ids_?
InputIDs decoder_input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache kv_cache_{model_, *this}; diff --git a/src/ort_genai.h b/src/ort_genai.h index 19c658e1d..cab2b55d8 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -191,13 +191,14 @@ struct OgaGeneratorParams : OgaAbstract { OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value)); } - void SetInputIDs(const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { - OgaCheckResult(OgaGeneratorParamsSetInputIDs(this, input_ids, input_ids_count, sequence_length, batch_size)); - } + // void SetInputIDs(const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { + // OgaCheckResult(OgaGeneratorParamsSetInputIDs(this, input_ids, input_ids_count, sequence_length, batch_size)); + // } - void SetInputSequences(const OgaSequences& sequences) { - OgaCheckResult(OgaGeneratorParamsSetInputSequences(this, &sequences)); - } + // void SetInputSequences(const OgaSequences& sequences) { + // OgaCheckResult(OgaGeneratorParamsSetInputSequences(this, &sequences)); + // } + void SetModelInput(const char* name, OgaTensor& tensor) { OgaCheckResult(OgaGeneratorParamsSetModelInput(this, name, &tensor)); @@ -225,6 +226,10 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } + void AddInputTokens(const OgaSequences& sequences) { + OgaCheckResult(OgaGenerator_AddInputTokens(this, &sequences)); + } + void ComputeLogits() { OgaCheckResult(OgaGenerator_ComputeLogits(this)); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index a70d355f7..a42b4e6ed 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -129,35 +129,35 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGen OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* oga_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { - OGA_TRY - auto& params = *reinterpret_cast(oga_params); - params.input_ids = std::span(input_ids, input_ids_count); - params.sequence_length = static_cast(sequence_length); - params.batch_size = static_cast(batch_size); - if (params.sequence_length * params.batch_size != input_ids_count) - throw std::runtime_error("sequence length * batch size is not equal to input_ids_count"); - return nullptr; - OGA_CATCH -} - -OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* oga_params, const OgaSequences* p_sequences) { - OGA_TRY - auto& params = *reinterpret_cast(oga_params); - auto& sequences = *reinterpret_cast(p_sequences); - - std::vector> span_sequences; - for (size_t i = 0; i < sequences.size(); i++) { - span_sequences.emplace_back(sequences[i]); - } - - params.input_ids_owner = Generators::PadInputs(span_sequences, params.pad_token_id); - params.batch_size = static_cast(sequences.size()); - params.sequence_length = static_cast(params.input_ids_owner.size() / params.batch_size); - params.input_ids = params.input_ids_owner; - return nullptr; - OGA_CATCH -} +// OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* oga_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { +// OGA_TRY +// auto& params = *reinterpret_cast(oga_params); +// params.input_ids = std::span(input_ids, input_ids_count); +// params.sequence_length = static_cast(sequence_length); +// params.batch_size = static_cast(batch_size); +// if (params.sequence_length * params.batch_size != 
input_ids_count) +// throw std::runtime_error("sequence length * batch size is not equal to input_ids_count"); +// return nullptr; +// OGA_CATCH +// } + +// OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* oga_params, const OgaSequences* p_sequences) { +// OGA_TRY +// auto& params = *reinterpret_cast(oga_params); +// auto& sequences = *reinterpret_cast(p_sequences); + +// std::vector> span_sequences; +// for (size_t i = 0; i < sequences.size(); i++) { +// span_sequences.emplace_back(sequences[i]); +// } + +// params.input_ids_owner = Generators::PadInputs(span_sequences, params.pad_token_id); +// params.batch_size = static_cast(sequences.size()); +// params.sequence_length = static_cast(params.input_ids_owner.size() / params.batch_size); +// params.input_ids = params.input_ids_owner; +// return nullptr; +// OGA_CATCH +// } OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* oga_params, const OgaNamedTensors* p_named_tensors) { OGA_TRY @@ -206,13 +206,38 @@ bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) { return reinterpret_cast(generator)->IsDone(); } -OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) { +OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const OgaSequences* p_sequences) { OGA_TRY - reinterpret_cast(generator)->ComputeLogits(); + auto& generator = *reinterpret_cast(oga_generator); + auto& params = *generator.state_->params_; + auto& sequences = *reinterpret_cast(p_sequences); + + std::vector> span_sequences; + for (size_t i = 0; i < sequences.size(); i++) { + span_sequences.emplace_back(sequences[i]); + } + + auto input_ids = Generators::PadInputs(span_sequences, params.pad_token_id); + generator.AddTokens(input_ids); return nullptr; OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count) { + OGA_TRY + auto& generator = *reinterpret_cast(oga_generator); + generator.AddTokens(std::span(input_ids, input_ids_count)); + return nullptr; + OGA_CATCH +} + +// OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) { +// OGA_TRY +// reinterpret_cast(generator)->ComputeLogits(); +// return nullptr; +// OGA_CATCH +// } + OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) { OGA_TRY reinterpret_cast(generator)->GenerateNextToken(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 0c703405e..a8498cb5d 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -179,8 +179,8 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatch * \param[in] batch_size The batch size of the input ids. * \return OgaResult containing the error message if the setting of the input ids failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, - size_t input_ids_count, size_t sequence_length, size_t batch_size); +// OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, +// size_t input_ids_count, size_t sequence_length, size_t batch_size); /* * \brief Sets the input id sequences for the generator params. The input id sequences are used to seed the generation. @@ -188,7 +188,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorPar * \param[in] sequences The input id sequences. 
* \return OgaResult containing the error message if the setting of the input id sequences failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences); +// OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences); OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* generator_params, const OgaNamedTensors* named_tensors); @@ -225,6 +225,16 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); */ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); +/* + * \brief Adds the input ids to the generator. The input ids are used to seed the generation. + * \param[in] oga_params The generator params to get the pad token id. + * \param[in] oga_generator The generator to add the input ids to. + * \param[in] p_sequences The input id sequences. + * \return OgaResult containing the error message if the setting of the input ids failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const OgaSequences* p_sequences); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count); + /* * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. * \param[in] generator The generator to compute the logits for. diff --git a/src/search.cpp b/src/search.cpp index d4754e57e..ca51f41e8 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -9,7 +9,7 @@ namespace Generators { Search_Cpu::Search_Cpu(const GeneratorParams& params) : Search{params}, - sequences_{params.input_ids, params.batch_size, params.search.num_beams, params_->search.max_length} { + sequences_{/*params.input_ids,*/ params.batch_size, params.search.num_beams, params_->search.max_length} { auto batch_beam_size = params.BatchBeamSize(); sequence_lengths_buffer_ = AllocateArray(batch_beam_size, &sequence_lengths_); } diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp index 160d102b1..b4f4b2b2d 100644 --- a/src/search_cuda.cpp +++ b/src/search_cuda.cpp @@ -17,7 +17,7 @@ void OnCudaError(cudaError_t error) { Search_Cuda::Search_Cuda(const GeneratorParams& params) : Search{params}, - sequences_{params.input_ids, params.batch_size, params.search.num_beams, params_->search.max_length, params_->cuda_stream} { + sequences_{/*params.input_ids,*/ params.batch_size, params.search.num_beams, params_->search.max_length, params_->cuda_stream} { auto batch_beam_size = params.BatchBeamSize(); sequence_lengths_buffer_ = std::make_unique(batch_beam_size); sequence_lengths_ = cpu_span(sequence_lengths_buffer_.get(), batch_beam_size); diff --git a/src/sequences.cpp b/src/sequences.cpp index e5d1fff85..ad1ca2d43 100644 --- a/src/sequences.cpp +++ b/src/sequences.cpp @@ -33,6 +33,22 @@ Sequences::Sequences(std::span input_sequences, int batch_size, i } } +Sequences::Sequences(int batch_size, int beam_size, int max_length) + : batch_beam_size_{batch_size * beam_size}, + max_length_{max_length}, + current_length_{0} { + const size_t sequences_size = static_cast(batch_beam_size_) * max_length; + + if (beam_size == 1) { + sequences_buffer_ = std::make_unique(sequences_size); + sequences_ = cpu_span(sequences_buffer_.get(), sequences_size); + } else { + sequences_buffer_ = std::make_unique(2 * sequences_size); + 
sequences_ = cpu_span(sequences_buffer_.get(), sequences_size); + sequences_next_ = cpu_span(sequences_buffer_.get() + sequences_size, sequences_size); + } +} + cpu_span Sequences::GetSequence(size_t batch_beam_index) { auto span = sequences_.subspan(batch_beam_index * max_length_, current_length_); return cpu_span{span.data(), span.size()}; diff --git a/src/sequences.h b/src/sequences.h index bdf13caa5..538358200 100644 --- a/src/sequences.h +++ b/src/sequences.h @@ -4,6 +4,7 @@ namespace Generators { // This class keeps track of sequences generated. struct Sequences { Sequences(std::span input_sequence, int batch_size, int beam_size, int max_length); + Sequences(int batch_size, int beam_size, int max_length); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). cpu_span GetSequence(size_t batch_beam_index); diff --git a/src/sequences_cuda.cpp b/src/sequences_cuda.cpp index 8f7a643e5..1e6a2cfe6 100644 --- a/src/sequences_cuda.cpp +++ b/src/sequences_cuda.cpp @@ -10,12 +10,13 @@ void Launch_ExpandInputSequences(std::span input_sequences, std:: void Launch_AppendNextTokenToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int current_length, int max_length, cudaStream_t stream); } // namespace cuda -Sequences_Cuda::Sequences_Cuda(std::span input_sequences, int batch_size, int beam_size, int max_length, cudaStream_t stream) +// TODO(aciddelgado): make cuda sequences functional +Sequences_Cuda::Sequences_Cuda(/*std::span input_sequences,*/ int batch_size, int beam_size, int max_length, cudaStream_t stream) : stream_{stream}, batch_beam_size_{batch_size * beam_size}, max_length_{max_length}, - current_length_{static_cast(input_sequences.size()) / batch_size} { - assert(current_length_ * batch_size == input_sequences.size()); // Ensure size divided perfectly + current_length_{0} { + // assert(current_length_ * batch_size == input_sequences.size()); // Ensure size divided perfectly size_t sequences_size = batch_beam_size_ * max_length; if (beam_size == 1) { @@ -30,10 +31,10 @@ Sequences_Cuda::Sequences_Cuda(std::span input_sequences, int bat // TODO: input_sequences will be in cuda memory in the future, for now make a temp copy gpu_span input_sequences_gpu; - auto input_sequences_temp = CudaMallocArray(input_sequences.size(), &input_sequences_gpu); - cudaMemcpyAsync(input_sequences_gpu.data(), input_sequences.data(), input_sequences.size_bytes(), cudaMemcpyHostToDevice, stream); + // auto input_sequences_temp = CudaMallocArray(input_sequences.size(), &input_sequences_gpu); + // cudaMemcpyAsync(input_sequences_gpu.data(), input_sequences.data(), input_sequences.size_bytes(), cudaMemcpyHostToDevice, stream); - cuda::Launch_ExpandInputSequences(input_sequences_gpu, sequences_, batch_size, beam_size, current_length_, max_length, stream_); + // cuda::Launch_ExpandInputSequences(input_sequences_gpu, sequences_, batch_size, beam_size, current_length_, max_length, stream_); cudaStreamSynchronize(stream); // Until we remove the todo above, wait for this to complete as input_sequences_gpu is on the stack } diff --git a/src/sequences_cuda.h b/src/sequences_cuda.h index 8dc9038c3..b787b6e9a 100644 --- a/src/sequences_cuda.h +++ b/src/sequences_cuda.h @@ -3,7 +3,7 @@ namespace Generators { // This class keeps track of sequences generated. 
struct Sequences_Cuda { - Sequences_Cuda(std::span input_sequences, int batch_size, int beam_size, int max_length, cudaStream_t stream); + Sequences_Cuda(/*std::span input_sequences,*/ int batch_size, int beam_size, int max_length, cudaStream_t stream); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). RoamingArray GetSequence(size_t batch_beam_index); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index e6d9c2a65..35849f8f7 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -102,34 +102,35 @@ TEST(CAPITests, AppendTokensToSequence) { #endif } -TEST(CAPITests, EndToEndPhiBatch) { -#if TEST_PHI2 - auto model = OgaModel::Create(MODEL_PATH "phi-2"); - auto tokenizer = OgaTokenizer::Create(*model); - - const char* input_strings[] = { - "This is a test.", - "Rats are awesome pets!", - "The quick brown fox jumps over the lazy dog.", - }; - - auto input_sequences = OgaSequences::Create(); - for (auto& string : input_strings) - tokenizer->Encode(string, *input_sequences); - - auto params = OgaGeneratorParams::Create(*model); - params->SetSearchOption("max_length", 20); - params->SetInputSequences(*input_sequences); - - auto output_sequences = model->Generate(*params); - - // Decode The Batch - for (size_t i = 0; i < output_sequences->Count(); i++) { - auto out_string = tokenizer->Decode(output_sequences->Get(i)); - std::cout << "Decoded string:" << out_string << std::endl; - } -#endif -} +// TODO(aciddelgado): E2E API may be removed + we add tokens to generator directly now +// TEST(CAPITests, EndToEndPhiBatch) { +// #if TEST_PHI2 +// auto model = OgaModel::Create(MODEL_PATH "phi-2"); +// auto tokenizer = OgaTokenizer::Create(*model); + +// const char* input_strings[] = { +// "This is a test.", +// "Rats are awesome pets!", +// "The quick brown fox jumps over the lazy dog.", +// }; + +// auto input_sequences = OgaSequences::Create(); +// for (auto& string : input_strings) +// tokenizer->Encode(string, *input_sequences); + +// auto params = OgaGeneratorParams::Create(*model); +// params->SetSearchOption("max_length", 20); +// params->SetInputSequences(*input_sequences); + +// auto output_sequences = model->Generate(*params); + +// // Decode The Batch +// for (size_t i = 0; i < output_sequences->Count(); i++) { +// auto out_string = tokenizer->Decode(output_sequences->Get(i)); +// std::cout << "Decoded string:" << out_string << std::endl; +// } +// #endif +// } TEST(CAPITests, Tensor_And_AddExtraInput) { // Create a [3 4] shaped tensor @@ -179,12 +180,14 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", max_length); - params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); + params->SetSearchOption("batch_size", batch_size); + // params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); auto generator = OgaGenerator::Create(*model, *params); + generator->AddInputTokens(input_ids.data(), input_ids.size()); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -199,19 +202,20 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } - // Test high level API - auto sequences = model->Generate(*params); + // TODO(aciddelgado): E2E API may be removed + we add tokens to generator directly now + // // Test high level API + // 
auto sequences = model->Generate(*params); - // Verify outputs match expected outputs - for (int i = 0; i < batch_size; i++) { - const auto sequence_length = sequences->SequenceCount(i); - const auto* sequence_data = sequences->SequenceData(i); + // // Verify outputs match expected outputs + // for (int i = 0; i < batch_size; i++) { + // const auto sequence_length = sequences->SequenceCount(i); + // const auto* sequence_data = sequences->SequenceData(i); - ASSERT_LE(sequence_length, max_length); + // ASSERT_LE(sequence_length, max_length); - const auto* expected_output_start = &expected_output[i * max_length]; - EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); - } + // const auto* expected_output_start = &expected_output[i * max_length]; + // EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); + // } } #endif @@ -231,9 +235,10 @@ TEST(CAPITests, GetOutputCAPI) { auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", max_length); - params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); + // params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); auto generator = OgaGenerator::Create(*model, *params); + generator->AddInputTokens(input_ids.data(), input_ids.size()); // check prompt // full logits has shape [2, 4, 1000]. Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 4, 5] @@ -246,7 +251,7 @@ TEST(CAPITests, GetOutputCAPI) { -0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f, 0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f}; - generator->ComputeLogits(); + // generator->ComputeLogits(); auto prompt_logits_ptr = generator->GetOutput("logits"); auto prompt_logits = static_cast(prompt_logits_ptr->Data()); int num_prompt_outputs_to_check = 40; @@ -257,13 +262,14 @@ TEST(CAPITests, GetOutputCAPI) { EXPECT_NEAR(expected_sampled_logits_prompt[i], prompt_logits[i*sample_size], tolerance); } + generator->GenerateNextToken(); generator->GenerateNextToken(); // check for the 1st token generation // full logits has shape [2, 1, 1000]. 
Sample 1 for every 200 tokens and the expected sampled logits has shape [2, 1, 5] std::vector expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f, 0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f}; - generator->ComputeLogits(); + // generator->ComputeLogits(); auto token_gen_logits_ptr = generator->GetOutput("logits"); auto token_gen_logits = static_cast(token_gen_logits_ptr->Data()); int num_token_gen_outputs_to_check = 10; @@ -271,7 +277,7 @@ TEST(CAPITests, GetOutputCAPI) { for (int i = 0; i < num_token_gen_outputs_to_check; i++) { EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance); } - generator->GenerateNextToken(); + // generator->GenerateNextToken(); } #if TEST_PHI2 @@ -293,7 +299,7 @@ struct Phi2Test { tokenizer_->Encode(string, *input_sequences_); params_ = OgaGeneratorParams::Create(*model_); - params_->SetInputSequences(*input_sequences_); + // params_->SetInputSequences(*input_sequences_); params_->SetSearchOption("max_length", 40); } @@ -301,9 +307,10 @@ struct Phi2Test { // Low level loop { auto generator = OgaGenerator::Create(*model_, *params_); + generator->AddInputTokens(input_sequences_); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); } From 3e3d56a2cb00825198aa4e048b48da5000230e9e Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Thu, 19 Sep 2024 13:56:52 -0700 Subject: [PATCH 05/13] decoder only cpu greedy works --- benchmark/c/main.cpp | 39 ++++++++----- examples/c/src/main.cpp | 10 ++-- examples/c/src/phi3v.cpp | 4 +- examples/python/model-qa.py | 19 +++--- src/config.cpp | 4 +- src/config.h | 1 + src/generators.cpp | 10 +++- src/models/decoder_only.cpp | 1 + src/models/input_ids.cpp | 4 +- src/models/model.cpp | 4 +- src/models/model.h | 3 +- src/models/position_inputs.cpp | 9 ++- src/ort_genai.h | 8 ++- src/ort_genai_c.cpp | 7 ++- src/ort_genai_c.h | 4 +- src/python/python.cpp | 52 ++++++++++------- src/search.cpp | 7 +++ src/search.h | 1 + test/c_api_tests.cpp | 39 ++++++++----- test/model_tests.cpp | 102 +++++++++++++++++++-------------- test/sampling_benchmark.cpp | 28 ++++----- test/sampling_tests.cpp | 52 ++++++++--------- 22 files changed, 246 insertions(+), 162 deletions(-) diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp index e924a26f8..f5b6d154b 100644 --- a/benchmark/c/main.cpp +++ b/benchmark/c/main.cpp @@ -121,11 +121,17 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons auto params = OgaGeneratorParams::Create(model); params->SetSearchOption("max_length", static_cast(num_prompt_tokens)); params->SetSearchOption("min_length", static_cast(num_prompt_tokens)); - params->SetInputSequences(*base_prompt_sequences); + // params->SetInputSequences(*base_prompt_sequences); - auto output_sequences = model.Generate(*params); - const auto output_sequence_length = output_sequences->SequenceCount(0); - const auto* output_sequence_data = output_sequences->SequenceData(0); + // auto output_sequences = model.Generate(*params); + auto generator = OgaGenerator::Create(model, *params); + generator->AddInputSequences(*base_prompt_sequences); + while (!generator->IsDone()) { + generator->GenerateNextToken(); + } + + const auto output_sequence_length = generator->GetSequenceCount(0); + const auto* output_sequence_data = generator->GetSequenceData(0); return std::string{tokenizer.Decode(output_sequence_data, output_sequence_length)}; } @@ -151,7 +157,7 
@@ void RunBenchmark(const benchmark::Options& opts) { auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", static_cast(num_tokens)); params->SetSearchOption("min_length", static_cast(num_tokens)); - params->SetInputSequences(*prompt_sequences); + // params->SetInputSequences(*prompt_sequences); return params; }; @@ -160,13 +166,18 @@ void RunBenchmark(const benchmark::Options& opts) { // warmup if (opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n"; for (size_t i = 0; i < opts.num_warmup_iterations; ++i) { - auto output_sequences = model->Generate(*generator_params); + // auto output_sequences = model->Generate(*generator_params); + auto generator = OgaGenerator::Create(*model, *generator_params); + generator->AddInputSequences(*prompt_sequences); + while (!generator->IsDone()) { + generator->GenerateNextToken(); + } if (opts.verbose && i == 0) { // show prompt and output on first iteration std::cout << "Prompt:\n\t" << prompt << "\n"; - const auto output_sequence_length = output_sequences->SequenceCount(0); - const auto* output_sequence_data = output_sequences->SequenceData(0); + const auto output_sequence_length = generator->GetSequenceCount(0); + const auto* output_sequence_data = generator->GetSequenceData(0); const auto output = tokenizer->Decode(output_sequence_data, output_sequence_length); std::cout << "Output:\n\t" << output << "\n"; } @@ -188,7 +199,7 @@ void RunBenchmark(const benchmark::Options& opts) { { Timing prompt_processing_timing{prompt_processing_times}; - generator->ComputeLogits(); + generator->AddInputSequences(*prompt_sequences); } { @@ -199,13 +210,13 @@ void RunBenchmark(const benchmark::Options& opts) { while (!generator->IsDone()) { { Timing token_gen_timing{token_gen_times}; - generator->ComputeLogits(); - } - - { - Timing sampling_timing{sampling_times}; generator->GenerateNextToken(); } + + // { + // Timing sampling_timing{sampling_times}; + // generator->GenerateNextToken(); + // } } } } diff --git a/examples/c/src/main.cpp b/examples/c/src/main.cpp index 5017108e6..7c94b241b 100644 --- a/examples/c/src/main.cpp +++ b/examples/c/src/main.cpp @@ -27,12 +27,13 @@ void CXX_API(const char* model_path) { std::cout << "Generating response..." 
<< std::endl; auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", 1024); - params->SetInputSequences(*sequences); + // params->SetInputSequences(*sequences); auto generator = OgaGenerator::Create(*model, *params); + generator->AddInputSequences(*sequences); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); const auto num_tokens = generator->GetSequenceCount(0); @@ -82,13 +83,14 @@ void C_API(const char* model_path) { OgaGeneratorParams* params; CheckResult(OgaCreateGeneratorParams(model, ¶ms)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 1024)); - CheckResult(OgaGeneratorParamsSetInputSequences(params, sequences)); + // CheckResult(OgaGeneratorParamsSetInputSequences(params, sequences)); OgaGenerator* generator; CheckResult(OgaCreateGenerator(model, params, &generator)); + CheckResult(OgaGenerator_AddInputSequences(generator, sequences)); while (!OgaGenerator_IsDone(generator)) { - CheckResult(OgaGenerator_ComputeLogits(generator)); + // CheckResult(OgaGenerator_ComputeLogits(generator)); CheckResult(OgaGenerator_GenerateNextToken(generator)); const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0); diff --git a/examples/c/src/phi3v.cpp b/examples/c/src/phi3v.cpp index 3ba7fc93a..e6287e975 100644 --- a/examples/c/src/phi3v.cpp +++ b/examples/c/src/phi3v.cpp @@ -56,7 +56,7 @@ void CXX_API(const char* model_path) { auto generator = OgaGenerator::Create(*model, *params); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); const auto num_tokens = generator->GetSequenceCount(0); @@ -128,7 +128,7 @@ void C_API(const char* model_path) { CheckResult(OgaCreateGenerator(model, params, &generator)); while (!OgaGenerator_IsDone(generator)) { - CheckResult(OgaGenerator_ComputeLogits(generator)); + // CheckResult(OgaGenerator_ComputeLogits(generator)); CheckResult(OgaGenerator_GenerateNextToken(generator)); const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0); diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index 4532f307a..e10e359af 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -16,6 +16,7 @@ def main(args): if args.verbose: print() search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args} + search_options['batch_size'] = 1 if args.verbose: print(search_options) @@ -24,6 +25,10 @@ def main(args): print("Error, chat template must have exactly one pair of curly braces, e.g. 
'<|user|>\n{input} <|end|>\n<|assistant|>'") exit(1) + params = og.GeneratorParams(model) + params.set_search_options(**search_options) + generator = og.Generator(model, params) + # Keep asking for input prompts in a loop while True: text = input("Input: ") @@ -39,11 +44,8 @@ def main(args): prompt = f'{args.chat_template.format(input=text)}' input_tokens = tokenizer.encode(prompt) - - params = og.GeneratorParams(model) - params.set_search_options(**search_options) - params.input_ids = input_tokens - generator = og.Generator(model, params) + + generator.add_input_tokens(input_tokens) if args.verbose: print("Generator created") if args.verbose: print("Running generation loop ...") @@ -56,7 +58,7 @@ def main(args): try: while not generator.is_done(): - generator.compute_logits() + # generator.compute_logits() generator.generate_next_token() if args.timings: if first: @@ -64,6 +66,7 @@ def main(args): first = False new_token = generator.get_next_tokens()[0] + # print(new_token, end=' ') print(tokenizer_stream.decode(new_token), end='', flush=True) if args.timings: new_tokens.append(new_token) except KeyboardInterrupt: @@ -71,8 +74,10 @@ def main(args): print() print() + # print(generator.get_sequence(0)) + # Delete the generator to free the captured graph for the next generator, if graph capture is enabled - del generator + # del generator if args.timings: prompt_time = first_token_timestamp - started_timestamp diff --git a/src/config.cpp b/src/config.cpp index 00708faae..c4c9f1dcc 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -431,6 +431,8 @@ struct Search_Element : JSON::Element { v_.min_length = static_cast(value); } else if (name == "max_length") { v_.max_length = static_cast(value); + } else if (name == "batch_size") { + v_.batch_size = static_cast(value); } else if (name == "num_beams") { v_.num_beams = static_cast(value); } else if (name == "num_return_sequences") { @@ -555,7 +557,7 @@ void ParseConfig(const fs::path& filename, Config& config) { Config::Config(const fs::path& path) : config_path{path} { ParseConfig(path / "genai_config.json", *this); - + if (model.context_length == 0) throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0"); diff --git a/src/config.h b/src/config.h index 7263dbda3..5179dc90b 100644 --- a/src/config.h +++ b/src/config.h @@ -118,6 +118,7 @@ struct Config { bool do_sample{}; // True to do randomized sampling through top_k and top_p, if false, the top logit score is chosen int min_length{}; int max_length{}; // If omitted or 0 in json file, will be set to model.context_length on load + int batch_size{1}; int num_beams{1}; // 1 means no beam search. int num_return_sequences{1}; float repetition_penalty{1.0f}; // 1.0 means no penalty. 
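With batch_size now parsed as a regular search option, it can come from genai_config.json or be set at runtime. A minimal config sketch (values illustrative, surrounding fields elided; field names besides batch_size mirror the existing search options):

    {
      "search": {
        "batch_size": 1,
        "max_length": 256
      }
    }

model-qa.py above sets it at runtime instead, via search_options['batch_size'] = 1. Put together with the new Generator entry points, the per-prompt loop reduces to the following sketch against the Python bindings as changed in this patch (model path and prompt are placeholders; compute_logits is no longer called by the caller):

    import onnxruntime_genai as og

    model = og.Model("path/to/model")
    tokenizer = og.Tokenizer(model)

    params = og.GeneratorParams(model)
    params.set_search_options(batch_size=1, max_length=256)

    # One generator can now serve multiple prompts: tokens are appended with
    # add_input_tokens and logits are computed inside generate_next_token.
    generator = og.Generator(model, params)
    generator.add_input_tokens(tokenizer.encode("<prompt>"))
    while not generator.is_done():
        generator.generate_next_token()
    print(tokenizer.decode(generator.get_sequence(0)))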
diff --git a/src/generators.cpp b/src/generators.cpp index 554f4bba9..51ff84890 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -10,6 +10,8 @@ #include "search_cuda.h" #endif +#include + namespace Generators { static bool _ = (Ort::InitApi(), false); @@ -98,6 +100,7 @@ void GeneratorParams::TryGraphCapture(int max_bs) { } } +// TODO(aciddelgado): Almost certainly broken at this point but who knows void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { for (const auto& [name, tensor] : named_tensors) { if (name == Config::Defaults::InputIdsName) { @@ -189,8 +192,11 @@ void Generator::ComputeLogits(const RoamingArray& next_tokens) { } bool Generator::IsDone() const { - if (computed_logits_) - throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); + // TODO(aciddelgado): how do we deal with this now that it's addtokens and computelogits isn't in api + if (computed_logits_) { + return false; + } + // throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); return search_->IsDone(); } diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index b18944dcf..4e927f78b 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -1,5 +1,6 @@ #include "../generators.h" #include "decoder_only.h" +#include namespace Generators { DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env) diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index fd72df177..fcd5468c4 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -11,7 +11,9 @@ InputIDs::InputIDs(const Model& model, State& state) state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); shape_ = {state_.params_->search.num_beams * state_.params_->batch_size, 0}; - type_ = model_.session_info_->GetInputDataType(name_); + auto session_info = model_.session_info_.get(); + type_ = session_info->GetInputDataType(name_); + // type_ = model_.session_info_->GetInputDataType(name_); // If 64-bit, convert from 32-bit to 64-bit // if (type_ == Ort::TypeToTensorType) { diff --git a/src/models/model.cpp b/src/models/model.cpp index 3e17fd009..2cb7c5eab 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include #include +#include #include "../generators.h" #include "../search.h" @@ -38,7 +39,8 @@ namespace Generators { State::State(const GeneratorParams& params, const Model& model) : params_{params.shared_from_this()}, - model_{model} {} + model_{model}, + input_ids_{model, *this} {} void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) { if (first_run_) { diff --git a/src/models/model.h b/src/models/model.h index f68e96420..4e8a78f5a 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -43,7 +43,8 @@ struct State { std::vector input_names_, output_names_; std::vector inputs_, outputs_; - InputIDs input_ids_{model_, *this}; // TODO(aciddelgado): is this ok here? + // InputIDs input_ids_{model_, *this}; // TODO(aciddelgado): is this ok here? + InputIDs input_ids_; // TODO(aciddelgado): is this ok here? 
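+  // InputIDs reads model.session_info_ and state.params_ in its constructor
+  // (see input_ids.cpp above), hence the explicit init in State's ctor in model.cpp.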
protected: void Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size); // Uses the inputs below to run diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index bd2f72b48..d4f9f027f 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -470,6 +470,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { // attention_mask_ = std::move(attention_mask_next_); // #endif + // LEFT OFF: state_.inputs_[mask_input_index_] = attention_mask_.get(); not working state_.inputs_[mask_input_index_] = attention_mask_.get(); is_first_mask_update_ = false; @@ -643,9 +644,10 @@ void PositionInputs::CreateAndInitializePositionIDs(std::array shape } // Move tensors to appropriate device and expand by num_beams - model_.ExpandInputs(position_ids_, state_.params_->search.num_beams); - model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams); + position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams); + position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams); position_ids_shape_[0] *= state_.params_->search.num_beams; + state_.inputs_[posid_input_index_] = position_ids_.get(); } template @@ -670,8 +672,9 @@ void PositionInputs::CreateAndInitializeAttentionMask(std::array sha } // Move tensors to appropriate device and expand by num_beams - model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams); + attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams); attention_mask_shape_[0] *= state_.params_->search.num_beams; + state_.inputs_[mask_input_index_] = attention_mask_.get(); } // template diff --git a/src/ort_genai.h b/src/ort_genai.h index cab2b55d8..8c853eff7 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -226,8 +226,12 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } - void AddInputTokens(const OgaSequences& sequences) { - OgaCheckResult(OgaGenerator_AddInputTokens(this, &sequences)); + void AddInputSequences(const OgaSequences& sequences) { + OgaCheckResult(OgaGenerator_AddInputSequences(this, &sequences)); + } + + void AddInputTokens(int32_t* input_ids, size_t input_ids_count) { + OgaCheckResult(OgaGenerator_AddInputTokens(this, input_ids, input_ids_count)); } void ComputeLogits() { diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index a42b4e6ed..dc5c2dbe4 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -9,6 +9,7 @@ #include "generators.h" #include "models/model.h" #include "search.h" +#include "smartptrs.h" namespace Generators { @@ -206,7 +207,7 @@ bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) { return reinterpret_cast(generator)->IsDone(); } -OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const OgaSequences* p_sequences) { +OgaResult* OGA_API_CALL OgaGenerator_AddInputSequences(OgaGenerator* oga_generator, const OgaSequences* p_sequences) { OGA_TRY auto& generator = *reinterpret_cast(oga_generator); auto& params = *generator.state_->params_; @@ -223,10 +224,10 @@ OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count) { +OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count) { OGA_TRY auto& generator = 
*reinterpret_cast(oga_generator); - generator.AddTokens(std::span(input_ids, input_ids_count)); + generator.AddTokens(Generators::cpu_span(input_ids, input_ids_count)); return nullptr; OGA_CATCH } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index a8498cb5d..a953cd7a7 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -232,8 +232,8 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); * \param[in] p_sequences The input id sequences. * \return OgaResult containing the error message if the setting of the input ids failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const OgaSequences* p_sequences); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, const int32_t* input_ids, size_t input_ids_count); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputSequences(OgaGenerator* oga_generator, const OgaSequences* p_sequences); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count); /* * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. diff --git a/src/python/python.cpp b/src/python/python.cpp index 1340ef16c..d8bf74e52 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -9,6 +9,7 @@ #include "../search.h" #include "../models/model.h" #include "../logging.h" +#include "../smartptrs.h" using namespace pybind11::literals; @@ -31,7 +32,7 @@ struct npy_format_descriptor { } // namespace pybind11 template -std::span ToSpan(pybind11::array_t v) { +Generators::cpu_span ToSpan(pybind11::array_t v) { if constexpr (std::is_const_v) return {v.data(), static_cast(v.size())}; else @@ -222,19 +223,19 @@ struct PyGeneratorParams { // Turn the python py_input_ids_ into the low level parameters void Prepare() { // TODO: This will switch to using the variant vs being ifs - if (py_input_ids_.size() != 0) { - if (py_input_ids_.ndim() == 1) { // Just a 1D array - params_->batch_size = 1; - params_->sequence_length = static_cast(py_input_ids_.shape(0)); - } else { - if (py_input_ids_.ndim() != 2) - throw std::runtime_error("Input IDs can only be 1 or 2 dimensional"); - - params_->batch_size = static_cast(py_input_ids_.shape(0)); - params_->sequence_length = static_cast(py_input_ids_.shape(1)); - } - params_->input_ids = ToSpan(py_input_ids_); - } + // if (py_input_ids_.size() != 0) { + // if (py_input_ids_.ndim() == 1) { // Just a 1D array + // params_->batch_size = 1; + // params_->sequence_length = static_cast(py_input_ids_.shape(0)); + // } else { + // if (py_input_ids_.ndim() != 2) + // throw std::runtime_error("Input IDs can only be 1 or 2 dimensional"); + + // params_->batch_size = static_cast(py_input_ids_.shape(0)); + // params_->sequence_length = static_cast(py_input_ids_.shape(1)); + // } + // params_->input_ids = ToSpan(py_input_ids_); + // } if (py_whisper_input_features_.size() != 0) { GeneratorParams::Whisper& whisper = params_->inputs.emplace(); @@ -274,7 +275,7 @@ struct PyGeneratorParams { params_->TryGraphCapture(max_batch_size.cast()); } - pybind11::array_t py_input_ids_; + // pybind11::array_t py_input_ids_; pybind11::array_t py_whisper_input_features_; std::vector refs_; // References to data we want to ensure doesn't get garbage collected @@ -289,7 +290,7 @@ struct PyNamedTensors { struct PyGenerator { PyGenerator(Model& model, PyGeneratorParams& params) { - params.Prepare(); + // 
params.Prepare(); generator_ = CreateGenerator(model, params); } @@ -303,14 +304,19 @@ struct PyGenerator { return ToPython(py_sequence_.GetCPU()); } - void ComputeLogits() { - generator_->ComputeLogits(); - } + // void ComputeLogits() { + // generator_->ComputeLogits(); + // } pybind11::array GetOutput(const std::string& name) { return ToNumpy(generator_->state_->GetOutput(name.c_str()), *(generator_->model_)); } + // TODO(aciddelgado): Does this work with batch size > 1? + void AddTokens(pybind11::array_t tokens) { + generator_->AddTokens(ToSpan(tokens)); + } + void GenerateNextToken() { generator_->GenerateNextToken(); } @@ -374,7 +380,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->pad_token_id; }) .def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->eos_token_id; }) .def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->vocab_size; }) - .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_) + // .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_) .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_) .def("set_inputs", [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) { if (!named_tensors || !named_tensors->named_tensors_) @@ -412,9 +418,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_>(m, "Model") .def(pybind11::init([](const std::string& config_path) { + std::cout << "Loading model from: " << config_path << std::endl; return CreateModel(GetOrtEnv(), config_path.c_str()); })) - .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) + // .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) .def_property_readonly( "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on") .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); }); @@ -422,8 +429,9 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "Generator") .def(pybind11::init()) .def("is_done", &PyGenerator::IsDone) - .def("compute_logits", &PyGenerator::ComputeLogits) + // .def("compute_logits", &PyGenerator::ComputeLogits) .def("get_output", &PyGenerator::GetOutput) + .def("add_input_tokens", &PyGenerator::AddTokens) .def("generate_next_token", &PyGenerator::GenerateNextToken) .def("get_next_tokens", &PyGenerator::GetNextTokens) .def("get_sequence", &PyGenerator::GetSequence); diff --git a/src/search.cpp b/src/search.cpp index ca51f41e8..1770e5b9f 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -26,6 +26,7 @@ GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params) gen_.seed(seq); } + // TODO(aciddelgado): the reason we don't use next tokens for user input is that we'd have to allocate a new buffer for different input sizes and it would be a useless copy. 
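+  // next_tokens_ therefore stays fixed at batch_size entries (one sampled token
+  // per batch entry); variable-length user input flows through SetNextTokens() below.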
   next_tokens_buffer_ = AllocateArray<int32_t>(params.batch_size, &next_tokens_);
   memset(next_tokens_.data(), 0, next_tokens_.size_bytes());
 
@@ -253,6 +254,12 @@ void GreedySearch_Cpu::AppendNextTokensToSequences() {
 }
 
 void GreedySearch_Cpu::SetNextTokens(RoamingArray<int32_t> next_tokens) {
+  // Reset done count/state
+  done_ = false;
+  not_done_count_ = params_->batch_size;
+  memset(eos_seen_.data(), 0, eos_seen_.size_bytes());
+
+  // Set user-defined next tokens
   auto next_tokens_cpu = next_tokens.GetCPU();
   auto batch_size = params_->batch_size;
   auto tokens_count_per_batch = next_tokens_cpu.size() / batch_size;
diff --git a/src/search.h b/src/search.h
index 368dd417d..019f81cab 100644
--- a/src/search.h
+++ b/src/search.h
@@ -14,6 +14,7 @@ struct Search : LeakChecked<Search> {
   virtual RoamingArray<int32_t> GetSequenceLengths() = 0;
   virtual int GetSequenceLength() const = 0;
   virtual RoamingArray<int32_t> GetSequence(size_t index) = 0;
+  // TODO(aciddelgado): do we want a GetSequences() API?
 
   virtual void SetLogits(RoamingArray<float> logits) = 0;
   virtual bool IsDone() const = 0;
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index 35849f8f7..91eac8817 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -176,7 +177,9 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
   // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
   // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2)
 
+  std::cout << "Loading model..." << std::endl;
+  std::cout << "Model path: " << MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32" << std::endl;
   auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
 
   auto params = OgaGeneratorParams::Create(*model);
   params->SetSearchOption("max_length", max_length);
@@ -184,8 +187,9 @@
   // params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size);
 
   auto generator = OgaGenerator::Create(*model, *params);
+  std::cout << "Adding input tokens..." << std::endl;
   generator->AddInputTokens(input_ids.data(), input_ids.size());
-
+  std::cout << "Generating..."
<< std::endl; while (!generator->IsDone()) { // generator->ComputeLogits(); generator->GenerateNextToken(); @@ -196,6 +200,14 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { const auto sequence_length = generator->GetSequenceCount(i); const auto* sequence_data = generator->GetSequenceData(i); + std::cout << "Sequence length: " << sequence_length << std::endl; + std::cout << "Max length: " << max_length << std::endl; + std::cout << "Output sequence: "; + for (int j = 0; j < sequence_length; j++) { + std::cout << sequence_data[j] << " "; + } + std::cout << std::endl; + ASSERT_LE(sequence_length, max_length); const auto* expected_output_start = &expected_output[i * max_length]; @@ -307,7 +319,7 @@ struct Phi2Test { // Low level loop { auto generator = OgaGenerator::Create(*model_, *params_); - generator->AddInputTokens(input_sequences_); + generator->AddInputSequences(input_sequences_); while (!generator->IsDone()) { // generator->ComputeLogits(); @@ -321,16 +333,17 @@ struct Phi2Test { } } - // High level - { - auto output_sequences = model_->Generate(*params_); - - // Decode The Batch - for (size_t i = 0; i < output_sequences->Count(); i++) { - auto out_string = tokenizer_->Decode(output_sequences->Get(i)); - std::cout << "Decoded string:" << out_string << std::endl; - } - } + // TODO(aciddelgado): E2E API may be removed + we add tokens to generator directly now + // // High level + // { + // auto output_sequences = model_->Generate(*params_); + + // // Decode The Batch + // for (size_t i = 0; i < output_sequences->Count(); i++) { + // auto out_string = tokenizer_->Decode(output_sequences->Get(i)); + // std::cout << "Decoded string:" << out_string << std::endl; + // } + // } } std::unique_ptr model_; diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 6766fb892..e167081b9 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifndef MODEL_PATH #define MODEL_PATH "../../test/test_models/" #endif @@ -38,13 +39,14 @@ TEST(ModelTests, GreedySearchGptFp32) { auto params = Generators::CreateGeneratorParams(*model); params->search.max_length = 10; params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); - params->input_ids = input_ids; + // params->sequence_length = static_cast(input_ids_shape[1]); + // params->input_ids = input_ids; auto generator = Generators::CreateGenerator(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -77,19 +79,24 @@ TEST(ModelTests, BeamSearchGptFp32) { auto params = Generators::CreateGeneratorParams(*model); params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); - params->input_ids = input_ids; + // params->sequence_length = static_cast(input_ids_shape[1]); + // params->input_ids = input_ids; params->search.max_length = 20; params->search.length_penalty = 1.0f; params->search.num_beams = 4; auto generator = Generators::CreateGenerator(*model, *params); - auto result = Generators::Generate(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); + // auto result = Generators::Generate(*model, *params); + while (!generator->IsDone()) { + // generator->ComputeLogits(); + generator->GenerateNextToken(); + } // Verify outputs match expected outputs for (int i = 0; i < 
params->batch_size; i++) { - auto sequence = std::span(result[i].data(), params->search.max_length); + auto sequence = generator->GetSequence(i).GetCPU(); auto* expected_output_start = &expected_output[static_cast(i) * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); } @@ -110,14 +117,15 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) auto params = Generators::CreateGeneratorParams(*model); params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); + // params->sequence_length = static_cast(input_ids_shape[1]); params->search.max_length = 10; - params->input_ids = input_ids; + // params->input_ids = input_ids; auto generator = Generators::CreateGenerator(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -155,19 +163,24 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { auto params = Generators::CreateGeneratorParams(*model); params->batch_size = static_cast(input_ids_shape[0]); - params->sequence_length = static_cast(input_ids_shape[1]); - params->input_ids = input_ids; + // params->sequence_length = static_cast(input_ids_shape[1]); + // params->input_ids = input_ids; params->search.max_length = 20; params->search.num_beams = 4; params->search.length_penalty = 1.0f; auto generator = Generators::CreateGenerator(*model, *params); - auto result = Generators::Generate(*model, *params); + generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); + // auto result = Generators::Generate(*model, *params); + while (!generator->IsDone()) { + // generator->ComputeLogits(); + generator->GenerateNextToken(); + } // Verify outputs match expected outputs for (int i = 0; i < params->batch_size; i++) { - auto sequence = std::span(result[i].data(), params->search.max_length); + auto sequence = generator->GetSequence(i).GetCPU(); auto* expected_output_start = &expected_output[static_cast(i) * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); } @@ -196,14 +209,15 @@ Print all primes between 1 and n auto params = Generators::CreateGeneratorParams(*model); params->batch_size = 1; - params->sequence_length = static_cast(tokens.size()); - params->input_ids = tokens; + // params->sequence_length = static_cast(tokens.size()); + // params->input_ids = tokens; params->search.max_length = 128; // Generator version auto generator = Generators::CreateGenerator(*model, *params); + generator->AddInputTokens(Generators::cpu_span(tokens.data(), tokens.size())); while (!generator->IsDone()) { - generator->ComputeLogits(); + // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -213,32 +227,32 @@ Print all primes between 1 and n #endif } -TEST(ModelTests, TestHighLevelApiCuda) { -#if TEST_PHI2 - auto prompt = R"( -def print_prime(n): -''' -Print all primes between 1 and n -''' -)"; - - std::cout << "With prompt:" << prompt << "\r\n"; - - auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "phi-2"); - auto tokenizer = model->CreateTokenizer(); - auto tokens = tokenizer->Encode(prompt); - - auto params = Generators::CreateGeneratorParams(*model); - params->batch_size = 1; - 
params->sequence_length = static_cast(tokens.size()); - params->input_ids = tokens; - params->search.max_length = 128; - - // High level version - auto result = Generators::Generate(*model, *params); - - std::cout << tokenizer->Decode(result[0]) << "\r\n"; -#endif -} +// TEST(ModelTests, TestHighLevelApiCuda) { +// #if TEST_PHI2 +// auto prompt = R"( +// def print_prime(n): +// ''' +// Print all primes between 1 and n +// ''' +// )"; + +// std::cout << "With prompt:" << prompt << "\r\n"; + +// auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "phi-2"); +// auto tokenizer = model->CreateTokenizer(); +// auto tokens = tokenizer->Encode(prompt); + +// auto params = Generators::CreateGeneratorParams(*model); +// params->batch_size = 1; +// params->sequence_length = static_cast(tokens.size()); +// params->input_ids = tokens; +// params->search.max_length = 128; + +// // High level version +// auto result = Generators::Generate(*model, *params); + +// std::cout << tokenizer->Decode(result[0]) << "\r\n"; +// #endif +// } #endif \ No newline at end of file diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index e614b2b20..fceb9765b 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -25,9 +25,9 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCpu) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -60,9 +60,9 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCpu) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -98,9 +98,9 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCpu) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -137,9 +137,9 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; std::vector cpu_logits(vocab_size * batch_size); std::random_device rd; @@ -181,9 +181,9 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; 
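+  // sequence_length/input_ids no longer exist on GeneratorParams; the sampling
+  // paths only need batch_size and vocab_size since logits are supplied directly.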
params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); @@ -222,9 +222,9 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); @@ -265,9 +265,9 @@ TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 85c151d2e..861b953a3 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -29,9 +29,9 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { params->search.do_sample = true; params->search.top_p = 0.25f; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto generator = Generators::CreateGenerator(*model, *params); auto logits_span = Generators::cpu_span(logits_cpu); @@ -57,9 +57,9 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { params->search.do_sample = true; params->search.top_k = 2; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); auto logits_copy = logits_cpu; @@ -91,9 +91,9 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { params->search.top_k = 2; params->search.top_p = 0.25f; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); auto logits_copy = logits_cpu; @@ -141,9 +141,9 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { params->search.do_sample = true; params->search.top_p = 0.95f; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -179,9 +179,9 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { params->search.do_sample = true; params->search.top_k = k; params->batch_size = batch_size; - params->sequence_length = 1; 
+ // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -219,9 +219,9 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { params->search.top_k = k; params->search.top_p = p; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -266,9 +266,9 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { params->search.do_sample = true; params->search.top_p = 0.25f; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -296,9 +296,9 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { params->search.do_sample = true; params->search.top_k = 2; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -331,9 +331,9 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { params->search.top_k = 2; params->search.top_p = 0.25f; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -360,9 +360,9 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { params->search.do_sample = true; params->search.top_p = 0.95f; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); @@ -402,9 +402,9 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { params->search.do_sample = true; params->search.top_k = k; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); @@ -446,9 +446,9 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { params->search.top_k = k; 
params->search.top_p = p; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); @@ -485,9 +485,9 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - params->sequence_length = 1; + // params->sequence_length = 1; params->vocab_size = vocab_size; - params->input_ids = input_ids; + // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); From 9f7d0e04382b6091af1ad8e6f17b49cac04cf42f Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Fri, 20 Sep 2024 11:02:16 -0700 Subject: [PATCH 06/13] clean up comments --- benchmark/c/main.cpp | 9 - examples/c/src/main.cpp | 4 - examples/c/src/phi3v.cpp | 2 - examples/python/model-qa.py | 7 - src/config.cpp | 2 +- src/generators.cpp | 47 +-- src/generators.h | 10 +- src/models/decoder_only.cpp | 50 --- src/models/decoder_only.h | 6 - src/models/embeddings.cpp | 2 +- src/models/gpt.h | 1 - src/models/input_ids.cpp | 72 +---- src/models/input_ids.h | 6 +- src/models/kv_cache.cpp | 88 +----- src/models/kv_cache.h | 7 - src/models/logits.cpp | 208 +------------ src/models/logits.h | 7 +- src/models/model.cpp | 1 - src/models/model.h | 11 +- src/models/multi_modal_vision_model.cpp | 5 +- src/models/multi_modal_vision_model.h | 1 - src/models/position_inputs.cpp | 394 +++++------------------- src/models/position_inputs.h | 11 +- src/models/whisper.cpp | 4 +- src/models/whisper.h | 1 - src/ort_genai.h | 11 +- src/ort_genai_c.cpp | 51 +-- src/ort_genai_c.h | 29 +- src/python/python.cpp | 27 -- src/search.cpp | 3 +- src/search.h | 6 +- src/search_cuda.cpp | 2 +- src/sequences.cpp | 27 -- src/sequences.h | 4 +- src/sequences_cuda.cpp | 9 +- src/sequences_cuda.h | 2 +- test/c_api_tests.cpp | 81 +---- test/model_tests.cpp | 45 --- test/sampling_benchmark.cpp | 14 - test/sampling_tests.cpp | 26 -- 40 files changed, 136 insertions(+), 1157 deletions(-) diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp index f5b6d154b..b9b35728f 100644 --- a/benchmark/c/main.cpp +++ b/benchmark/c/main.cpp @@ -121,9 +121,7 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons auto params = OgaGeneratorParams::Create(model); params->SetSearchOption("max_length", static_cast(num_prompt_tokens)); params->SetSearchOption("min_length", static_cast(num_prompt_tokens)); - // params->SetInputSequences(*base_prompt_sequences); - // auto output_sequences = model.Generate(*params); auto generator = OgaGenerator::Create(model, *params); generator->AddInputSequences(*base_prompt_sequences); while (!generator->IsDone()) { @@ -157,7 +155,6 @@ void RunBenchmark(const benchmark::Options& opts) { auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", static_cast(num_tokens)); params->SetSearchOption("min_length", static_cast(num_tokens)); - // params->SetInputSequences(*prompt_sequences); return params; }; @@ -166,7 +163,6 @@ void RunBenchmark(const benchmark::Options& opts) { // warmup if 
(opts.verbose) std::cout << "Running warmup iterations (" << opts.num_warmup_iterations << ")...\n"; for (size_t i = 0; i < opts.num_warmup_iterations; ++i) { - // auto output_sequences = model->Generate(*generator_params); auto generator = OgaGenerator::Create(*model, *generator_params); generator->AddInputSequences(*prompt_sequences); while (!generator->IsDone()) { @@ -212,11 +208,6 @@ void RunBenchmark(const benchmark::Options& opts) { Timing token_gen_timing{token_gen_times}; generator->GenerateNextToken(); } - - // { - // Timing sampling_timing{sampling_times}; - // generator->GenerateNextToken(); - // } } } } diff --git a/examples/c/src/main.cpp b/examples/c/src/main.cpp index 7c94b241b..e925da654 100644 --- a/examples/c/src/main.cpp +++ b/examples/c/src/main.cpp @@ -27,13 +27,11 @@ void CXX_API(const char* model_path) { std::cout << "Generating response..." << std::endl; auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", 1024); - // params->SetInputSequences(*sequences); auto generator = OgaGenerator::Create(*model, *params); generator->AddInputSequences(*sequences); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); const auto num_tokens = generator->GetSequenceCount(0); @@ -83,14 +81,12 @@ void C_API(const char* model_path) { OgaGeneratorParams* params; CheckResult(OgaCreateGeneratorParams(model, ¶ms)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 1024)); - // CheckResult(OgaGeneratorParamsSetInputSequences(params, sequences)); OgaGenerator* generator; CheckResult(OgaCreateGenerator(model, params, &generator)); CheckResult(OgaGenerator_AddInputSequences(generator, sequences)); while (!OgaGenerator_IsDone(generator)) { - // CheckResult(OgaGenerator_ComputeLogits(generator)); CheckResult(OgaGenerator_GenerateNextToken(generator)); const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0); diff --git a/examples/c/src/phi3v.cpp b/examples/c/src/phi3v.cpp index e6287e975..890bff76b 100644 --- a/examples/c/src/phi3v.cpp +++ b/examples/c/src/phi3v.cpp @@ -56,7 +56,6 @@ void CXX_API(const char* model_path) { auto generator = OgaGenerator::Create(*model, *params); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); const auto num_tokens = generator->GetSequenceCount(0); @@ -128,7 +127,6 @@ void C_API(const char* model_path) { CheckResult(OgaCreateGenerator(model, params, &generator)); while (!OgaGenerator_IsDone(generator)) { - // CheckResult(OgaGenerator_ComputeLogits(generator)); CheckResult(OgaGenerator_GenerateNextToken(generator)); const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, 0); diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index e10e359af..059ed4d2e 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -58,7 +58,6 @@ def main(args): try: while not generator.is_done(): - # generator.compute_logits() generator.generate_next_token() if args.timings: if first: @@ -66,7 +65,6 @@ def main(args): first = False new_token = generator.get_next_tokens()[0] - # print(new_token, end=' ') print(tokenizer_stream.decode(new_token), end='', flush=True) if args.timings: new_tokens.append(new_token) except KeyboardInterrupt: @@ -74,11 +72,6 @@ def main(args): print() print() - # print(generator.get_sequence(0)) - - # Delete the generator to free the captured graph for the next generator, if graph capture is enabled - # del generator - if args.timings: prompt_time = 
first_token_timestamp - started_timestamp run_time = time.time() - first_token_timestamp diff --git a/src/config.cpp b/src/config.cpp index c4c9f1dcc..cd3f4a495 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -557,7 +557,7 @@ void ParseConfig(const fs::path& filename, Config& config) { Config::Config(const fs::path& path) : config_path{path} { ParseConfig(path / "genai_config.json", *this); - + if (model.context_length == 0) throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0"); diff --git a/src/generators.cpp b/src/generators.cpp index 51ff84890..665d8aa0e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -10,8 +10,6 @@ #include "search_cuda.h" #endif -#include - namespace Generators { static bool _ = (Ort::InitApi(), false); @@ -100,15 +98,10 @@ void GeneratorParams::TryGraphCapture(int max_bs) { } } -// TODO(aciddelgado): Almost certainly broken at this point but who knows +// TODO(aciddelgado): Does this work? void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { for (const auto& [name, tensor] : named_tensors) { - if (name == Config::Defaults::InputIdsName) { - // input_ids = std::span(tensor->ort_tensor_->GetTensorMutableData(), - // tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetElementCount()); - // batch_size = static_cast(tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0]); - // sequence_length = static_cast(input_ids.size()) / batch_size; - } else { + if (name != Config::Defaults::InputIdsName) { // If the nominal name is found in the map, use the graph name. // Else, use the nominal name as the graph name. [[maybe_unused]] const auto [graph_name, found] = config_->GetGraphName(name); @@ -150,14 +143,9 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ state_ = model.CreateState(search_->GetSequenceLengths(), params); // Search sequence lengths set when creating state } -// void Generator::AddInput(const std::string& name, const std::shared_ptr& tensor) { -// search_->AddInput(name, tensor); -// } - void Generator::AddTokens(cpu_span input_ids) { - // TODO(aciddelgado): check for first call after reset + // TODO(aciddelgado): batch_size > 1 requires full rewind search_->SetNextTokens(input_ids); - // state_->AddInputTokens(input_ids); // Do this in Run instead if (g_log.enabled && g_log.add_tokens) { auto& stream = Log("add_tokens"); @@ -176,7 +164,6 @@ void Generator::ComputeLogits(const RoamingArray& next_tokens) { if (computed_logits_) throw std::runtime_error("ComputeLogits called again without calling AddTokens or GenerateNextToken first"); - // auto logits = state_->Run(candidate_sequence, candidate_length_ + 1, search_->GetSequenceLength() - 1, candidate_length_ + 1); auto logits = state_->Run(search_->GetSequenceLength(), next_tokens, search_->GetNextIndices()); if (g_log.enabled && g_log.model_logits) { auto& stream = Log("model_logits"); @@ -185,24 +172,19 @@ void Generator::ComputeLogits(const RoamingArray& next_tokens) { } search_->SetLogits(logits); computed_logits_ = true; - - // auto& search = search_->params_->search; - // search_->ApplyMinLength(search.min_length); - // search_->ApplyRepetitionPenalty(search.repetition_penalty); } bool Generator::IsDone() const { - // TODO(aciddelgado): how do we deal with this now that it's addtokens and computelogits isn't in api + // TODO(aciddelgado): Is this the correct approach to handling computed_logits_ now? 
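+  // While computed_logits_ is set there are freshly computed logits that
+  // GenerateNextToken() has not consumed yet, so generation is not done.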
if (computed_logits_) { return false; } - // throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); return search_->IsDone(); } void Generator::GenerateNextToken() { - // TODO(aciddelgado): check that AddTokens has been called + // TODO(aciddelgado): check that AddTokens has been called at least once if (!computed_logits_) { ComputeLogits(search_->GetNextTokens()); } @@ -250,23 +232,4 @@ RoamingArray Generator::GetSequence(size_t index) const { return search_->GetSequence(index); } -TokenSequences Generate(const Model& model, const GeneratorParams& params) { - auto generator = CreateGenerator(model, params); - // generator->AddTokens(params.search.input_ids); - - while (!generator->IsDone()) { - generator->GenerateNextToken(); - } - - TokenSequences result; - for (int i = 0; i < params.batch_size * params.search.num_return_sequences; i++) { - auto sequence = generator->search_->GetSequence(i); - auto sequence_cpu = sequence.GetCPU(); - - auto& v = result.emplace_back(); - v.assign(sequence_cpu.begin(), sequence_cpu.end()); - } - return result; -} - } // namespace Generators diff --git a/src/generators.h b/src/generators.h index f46c39d04..3e34f6cdf 100644 --- a/src/generators.h +++ b/src/generators.h @@ -71,7 +71,6 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec int batch_size{1}; int max_batch_size{0}; bool use_cuda_graph{}; - // int sequence_length{}; int hidden_size{}; int BatchBeamSize() const { return search.num_beams * batch_size; } @@ -95,17 +94,12 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec #endif - // TODO: Move this to a separate GPT struct - // std::span input_ids; // Array of [batchsize][sequence_length] - struct Whisper { std::shared_ptr input_features; // float32 [batch_size, number_of_mels, something that is 3000] }; std::variant inputs; - std::vector input_ids_owner; // Backing memory of input_ids in some cases - std::shared_ptr external_owner_; // Set to 'this' when created by the C API to preserve lifetime struct Input { @@ -130,9 +124,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; - // virtual void ComputeLogits(); - // TODO(aciddelgado): Make this function work with batched inputs - virtual void AddTokens(cpu_span input_ids); // Add tokens to the input_ids + virtual void AddTokens(cpu_span input_ids); virtual void GenerateNextToken(); RoamingArray GetSequence(size_t index) const; diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index 4e927f78b..41d7274e2 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -1,6 +1,5 @@ #include "../generators.h" #include "decoder_only.h" -#include namespace Generators { DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env) @@ -26,15 +25,8 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra extra_inputs_.Add(); } -// void DecoderOnly_State::AddInputTokens(const RoamingArray& tokens) { -// input_ids_.AddInputTokens(tokens, reset_input_); -// reset_input_ = false; -// } - RoamingArray DecoderOnly_State::Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) { - // if (!first_run_) { UpdateInputsOutputs(next_tokens, next_indices, total_length); - // } int batch_size = static_cast(input_ids_.GetShape()[0]); State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); @@ -51,46 +43,4 @@ void DecoderOnly_State::UpdateInputsOutputs(const RoamingArray& next_to 
logits_.Update(new_length); } -// TODO(aciddelgado): Transition into a new paradigm -// RoamingArray DecoderOnly_State::Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { -// int batch_size = static_cast(input_ids_.GetShape()[0]); -// if (batch_size != 1) -// throw std::runtime_error("Speculative decoding only supports batch size 1, got " + std::to_string(batch_size)); - -// auto total_length = past_length + next_token_length; -// auto total_logits = first_run_ ? total_length : next_token_length; // TODO(aciddelgado): remove first_run -// // NB(bowenbao): workaround gqa limitation on token phase. -// // if (next_token_length > 1) { -// // total_logits = total_length; -// // } -// UpdateInputsOutputsFromSequence(sequence, next_token_length, past_length); -// State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); -// reset_input_ = true; - -// return logits_.Get(total_logits - return_last_logit_count, return_last_logit_count); -// } - -// TODO(aciddelgado): update should append, not replace for input_ids. ensure correct next and past lengths -// void DecoderOnly_State::UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length) { -// auto total_length = past_length + next_token_length; -// if (g_log.enabled && g_log.continuous_decoding) { -// auto& stream = Log("continuous_decoding"); -// stream << "UpdateInputsOutputsFromSequence: past_length=" << past_length << ", next_token_length=" << next_token_length << ", total_length=" << total_length << std::endl; -// } -// // TODO(aciddelgado): remove first_run -// if (first_run_) { -// // First run input ids includes prompt tokens. -// input_ids_.Update(sequence, 0, total_length); -// position_inputs_.Update(total_length, 0); -// kv_cache_.UpdatePresent(total_length); -// logits_.Update(); -// } else { -// // Subsequent runs input ids only include candidate tokens. 
-// input_ids_.Update(sequence, past_length, next_token_length); -// position_inputs_.Update(total_length, past_length); -// kv_cache_.UpdateAndResize(total_length, past_length); -// logits_.Update(); -// } -// } - } // namespace Generators diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 43bfb3449..073e99497 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -1,6 +1,5 @@ #pragma once #include "model.h" -// #include "input_ids.h" #include "logits.h" #include "kv_cache.h" #include "position_inputs.h" @@ -12,17 +11,13 @@ struct DecoderOnly_Model : Model { DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env); std::unique_ptr CreateState(RoamingArray sequence_lengths_unk, const GeneratorParams& params) const override; - // std::unique_ptr CreateState(const GeneratorParams& params) const override; std::unique_ptr session_decoder_; }; struct DecoderOnly_State : State { DecoderOnly_State(const DecoderOnly_Model& model, RoamingArray sequence_lengths_unk, const GeneratorParams& params); - // DecoderOnly_State(const DecoderOnly_Model& model, const GeneratorParams& params); RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) override; - // RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) override; - // void AddInputTokens(const RoamingArray& tokens) override; const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); }; protected: @@ -32,7 +27,6 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; - // InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache kv_cache_{model_, *this}; PositionInputs position_inputs_; diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index a25f2e115..fe11b3b4a 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -7,7 +7,7 @@ namespace Generators { -// TODO(aciddelgado): get this right what is this +// TODO(aciddelgado): initialize after addtokens is called Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode, const std::string& name) : model_{model}, state_{state}, diff --git a/src/models/gpt.h b/src/models/gpt.h index 265d379ea..50607ed64 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -25,7 +25,6 @@ struct Gpt_State : State { const Gpt_Model& model_; - // InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache_Combined kv_cache_{model_, *this}; PositionInputs position_inputs_; diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index fcd5468c4..67f6b8e89 100644 --- a/src/models/input_ids.cpp +++ b/src/models/input_ids.cpp @@ -5,7 +5,6 @@ namespace Generators { -// NOW IS 0-INITIALIZED InputIDs::InputIDs(const Model& model, State& state) : model_{model}, state_{state} { @@ -13,25 +12,6 @@ InputIDs::InputIDs(const Model& model, State& state) shape_ = {state_.params_->search.num_beams * state_.params_->batch_size, 0}; auto session_info = model_.session_info_.get(); type_ = session_info->GetInputDataType(name_); - // type_ = model_.session_info_->GetInputDataType(name_); - - // If 64-bit, convert from 32-bit to 64-bit - // if (type_ == Ort::TypeToTensorType) { - // value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_); - // auto* p_data = value_->GetTensorMutableData(); - // for (auto v : state_.params_->input_ids) { - // *p_data++ = v; - // } - // } else { - // if (type_ != 
Ort::TypeToTensorType) - // throw std::runtime_error("InputIDs must be int64 or int32"); - // value_ = OrtValue::CreateTensor(model.allocator_cpu_.GetInfo(), std::span(const_cast(state_.params_->input_ids.data()), shape_[0] * shape_[1]), shape_); - // } - - // value_ = OrtValue::CreateTensor(model.allocator_cpu_, shape_, type_); // TODO(aciddelgado): 0 initializing tensors allowed? - - // value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams); - // shape_[0] *= state_.params_->search.num_beams; if (state_.GetCapturedGraphInfo()) { sb_input_ids_ = state_.GetCapturedGraphInfo()->sb_input_ids_.get(); @@ -52,13 +32,7 @@ void InputIDs::Add() { } void InputIDs::Update(RoamingArray new_tokens) { - // // Resize input_ids shape once if it doesn't match the decoder shape - // if (shape_[1] != 1) { - // shape_[1] = 1; - // if (!sb_input_ids_) { - // value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - - // Resize input_ids shape to sequence_length of new_tokens + // Resize input_ids shape based on new_tokens size_t sequence_length = static_cast(new_tokens.GetCPU().size()) / shape_[0]; if (shape_[1] != sequence_length) { shape_[1] = sequence_length; @@ -140,48 +114,4 @@ void InputIDs::Update(RoamingArray new_tokens) { } } -// Add tokens to the end of input ids tensor -// void InputIDs::AddInputTokens(RoamingArray tokens, bool is_first_tokens) { -// switch (model_.device_type_) { -// case DeviceType::CPU: { -// break; -// } -// default: -// throw std::runtime_error("Add Tokens not supported for device type " + to_string(model_.device_type_)); -// } -// if (shape_[0] != 1) { -// throw std::runtime_error("Add Tokens only supported for batch size 1, got " + std::to_string(shape_[0])); -// } -// auto tokens_cpu = tokens.GetCPU(); -// int start = is_first_tokens ? 0 : shape_[1]; -// int token_count = tokens_cpu.size(); -// shape_[1] = start + token_count; - -// std::unique_ptr temp_value; -// if (!sb_input_ids_) { -// temp_value = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); -// } else { -// temp_value = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_); -// } -// if (type_ == Ort::TypeToTensorType) { -// auto* data = temp_value->GetTensorMutableData(); -// auto next_tokens_cpu = next_tokens.GetCPU(); -// for (int i = 0; i < start; i++) { -// data[i] = value_->GetTensorData()[i]; -// } -// for (int i = 0; i < token_count; i++) { -// data[start + i] = tokens_cpu[i]; -// } -// } else { -// auto* data = temp_value->GetTensorMutableData(); -// if (is_first_tokens) { -// memcpy(data, value_->GetTensorData(), start * sizeof(int32_t)); -// data += start; -// } -// memcpy(data, tokens.GetCPU().data(), token_count * sizeof(int32_t)); -// } -// value_ = std::move(temp_value); -// state_.inputs_[input_index_] = value_.get(); -// } - } // namespace Generators diff --git a/src/models/input_ids.h b/src/models/input_ids.h index 1f093de75..33dc029fe 100644 --- a/src/models/input_ids.h +++ b/src/models/input_ids.h @@ -12,11 +12,9 @@ struct InputIDs { // Register input_ids as ORT session input. // Called only once during initialization of state. void Add(); - // Resize input_ids to [1], update value with next_tokens. - // next_tokens is assumed to have length 1. + // Resize input_ids based on size of next_tokens. + // Update value with next_tokens. 
void Update(RoamingArray next_tokens); - // Add tokens to the end of input ids tensor - // void InputIDs::AddInputTokens(RoamingArray tokens, bool is_first_tokens); auto& GetShape() const { return shape_; } const char* name_; diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index cbdc1fae4..44cf25ff7 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -4,7 +4,7 @@ namespace Generators { -// TODO(aciddelgado): fix alternative kv cache implementations +// TODO(aciddelgado): check alternative kv cache implementations KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state) : model_{model}, state_{state}, @@ -26,7 +26,6 @@ KV_Cache_Combined::KV_Cache_Combined(const Model& model, State& state) type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]); empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - // shape_[3] = state_.params_->sequence_length; shape_[3] = 0; for (int i = 0; i < layer_count_; ++i) { @@ -148,8 +147,6 @@ KV_Cache::KV_Cache(const Model& model, State& state) // Set the size after empty_past_ has been created with 0 for this field if (past_present_share_buffer_) { shape_[2] = state_.params_->search.max_length; - // else - // shape_[2] = state_.params_->sequence_length; if (state_.GetCapturedGraphInfo()) { sb_kv_caches_.reserve(layer_count_ * 2); @@ -158,7 +155,6 @@ KV_Cache::KV_Cache(const Model& model, State& state) } } - // THIS USED TO BE DONE EVEN WITHOUT PAST_PRESENT_SHARE_BUFFER, MEANING DO IT ON FIRST UPDATE for (int i = 0; i < layer_count_ * 2; ++i) { presents_.push_back( sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) @@ -221,88 +217,6 @@ void KV_Cache::Update(std::span beam_indices, int total_length) { is_first_update_ = false; } -// void KV_Cache::UpdatePresent(int current_length) { -// // Used for speculative decoding main generator. -// // This can be later refactored to merge with tensor allocation during initialization. -// if (shape_[2] == current_length) -// return; -// shape_[2] = current_length; // TODO(aciddelgado): is it ok to set this if past_present_share_buffer_ is true? 
-// // If we're sharing past & present buffers there is nothing to do here, so early exit -// if (past_present_share_buffer_) -// return; -// for (int i = 0; i < layer_count_ * 2; i++) { -// presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); -// state_.outputs_[output_index_ + i] = presents_[i].get(); -// } -// } - -// void KV_Cache::UpdateAndResize(int current_length, int past_length) { -// // If we're sharing past & present buffers there is nothing to do here, so early exit -// if (past_present_share_buffer_) -// return; -// if (shape_[0] != 1) -// throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports batch size 1, got " + std::to_string(shape_[0])); -// if (model_.device_type_ != DeviceType::CPU) -// throw std::runtime_error("KV_Cache::Update(int current_length, int past_length) only supports CPU"); - -// auto element_type = presents_[0]->GetTensorTypeAndShapeInfo()->GetElementType(); -// auto element_size = SizeOf(element_type); -// auto new_shape = std::array({1, shape_[1], past_length, shape_[3]}); -// if (shape_[2] != past_length) { -// for (int i = 0; i < layer_count_ * 2; i++) { -// auto new_present = OrtValue::CreateTensor(*model_.allocator_device_, new_shape, type_); -// const auto* present_data = reinterpret_cast(presents_[i]->GetTensorRawData()); -// auto* new_present_data = reinterpret_cast(new_present->GetTensorMutableRawData()); - -// // Copy past_length kv-cache -// for (int j = 0; j < shape_[1]; j++) { -// memcpy( -// new_present_data + j * past_length * shape_[3] * element_size, -// present_data + j * shape_[2] * shape_[3] * element_size, -// past_length * shape_[3] * element_size); -// } - -// presents_[i] = std::move(new_present); -// } -// } - -// Update({}, current_length); -// } - -// TODO(aciddelgado): RewindTo function -// void KV_Cache::RewindTo(int new_length) { -// // If we're sharing past & present buffers there is nothing to do here, so early exit -// if (past_present_share_buffer_) -// return; -// if (shape_[0] != 1) -// throw std::runtime_error("KV_Cache::RewindTo(int new_length) only supports batch size 1, got " + std::to_string(shape_[0])); -// if (model_.device_type_ != DeviceType::CPU) -// throw std::runtime_error("KV_Cache::RewindTo(int new_length) only supports CPU"); - -// auto element_type = presents_[0]->GetTensorTypeAndShapeInfo()->GetElementType(); -// auto element_size = SizeOf(element_type); -// auto new_shape = std::array({1, shape_[1], new_length, shape_[3]}); -// if (shape_[2] != new_length) { -// for (int i = 0; i < layer_count_ * 2; i++) { -// auto new_present = OrtValue::CreateTensor(*model_.allocator_device_, new_shape, type_); -// const auto* present_data = reinterpret_cast(presents_[i]->GetTensorRawData()); -// auto* new_present_data = reinterpret_cast(new_present->GetTensorMutableRawData()); - -// // Copy new_length kv-cache -// for (int j = 0; j < shape_[1]; j++) { -// memcpy( -// new_present_data + j * new_length * shape_[3] * element_size, -// present_data + j * shape_[2] * shape_[3] * element_size, -// new_length * shape_[3] * element_size); -// } - -// presents_[i] = std::move(new_present); -// } -// } - -// shape_[2] = new_length; -// } - // Copy present state to past state reordered by the beam_indices template void KV_Cache::PickPastState(std::span beam_indices, int index) { diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index 1ba434811..c167831f3 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -37,13 +37,6 @@ struct KV_Cache { 
void Add(); // Move present to past. Prepare present output for next generation iteration. void Update(std::span beam_indices, int total_length); - // Used by speculative decoding - // Resize present to new sequence length. - // void UpdatePresent(int current_length); - // Resize past to new sequence length, and drop past that is > past_length. - // void UpdateAndResize(int current_length, int past_length); - // Rewind cache to new_length. - // void RewindTo(int new_length); template void PickPastState(std::span beam_indices, int index); void PickPastState(std::span beam_indices, int index); diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 86df70ccb..105c7f632 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -37,167 +37,6 @@ Logits::Logits(const Model& model, State& state) #pragma warning(push) #pragma warning(disable : 4189) // local variable is initialized but not referenced -// RoamingArray Logits::Get() { -// size_t element_count = shape_[0] * shape_[1] * shape_[2]; - -// // First iteration? Then copy the logits over to a {batch_beams, 1, vocab_size} tensor -// // The model's output logits are {batch_size*num_beams, input_seq_len, vocab_size} -// OrtValue* logits_of_last_token = output_raw_.get(); -// if (shape_[1] != 1) { -// const size_t seq_length = shape_[1]; -// const size_t vocab_size = shape_[2]; -// const size_t num_beams = state_.params_->search.num_beams; -// const size_t element_count_last_token = shape_[0] * shape_[2]; - -// shape_[1] = 1; - -// // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it -// output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); - -// #if USE_DML -// if (type_ == Ort::TypeToTensorType) { -// logits_of_last_token_fp32_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); -// } -// #endif - -// logits_of_last_token = output_last_tokens_.get(); - -// size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2; -// size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process - -// const auto* input_ids = state_.params_->input_ids.data(); -// for (int batch_index = 0; batch_index < state_.params_->batch_size; batch_index++) { -// // Find the first non pad token from the end -// size_t token_index = seq_length; -// while (token_index-- > 0) { -// if (input_ids[token_index] != state_.params_->pad_token_id) -// break; -// } - -// for (int beam_index = 0; beam_index < num_beams; beam_index++) { -// switch (model_.device_type_) { -// #if USE_DML -// case DeviceType::DML: { -// ComPtr source_resource; -// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, output_raw_->GetTensorMutableRawData(), &source_resource)); - -// ComPtr target_resource; -// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, logits_of_last_token->GetTensorMutableRawData(), &target_resource)); - -// uint64_t source_offset = (vocab_index * seq_length + token_index * vocab_size) * element_size; -// uint64_t target_offset = vocab_index * element_size; -// uint64_t size_in_bytes = vocab_size * element_size; - -// model_.GetDmlExecutionContext()->CopyBufferRegion( -// target_resource.Get(), -// target_offset, -// D3D12_RESOURCE_STATE_UNORDERED_ACCESS, -// source_resource.Get(), -// source_offset, -// D3D12_RESOURCE_STATE_UNORDERED_ACCESS, -// size_in_bytes); -// } break; -// #endif - -// case DeviceType::CPU: -// case DeviceType::CUDA: { -// auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size}; -// auto logits_last_tokens = std::span{logits_of_last_token->GetTensorMutableData(), element_count_last_token * element_size}; -// auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size); -// auto source = logits_raw.subspan((vocab_index * seq_length + token_index * vocab_size) * element_size, vocab_size * element_size); -// if (model_.device_type_ == DeviceType::CUDA) -// #if USE_CUDA -// CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), cudaMemcpyDeviceToDevice, state_.params_->cuda_stream); -// #else -// throw std::runtime_error("Unexpected CUDA device usage"); -// #endif -// else -// copy(source, target); -// } break; -// } - -// vocab_index += vocab_size; -// } - -// input_ids += seq_length; -// } - -// element_count = shape_[0] * shape_[2]; // shape_[1] is now 1, so the element count must be updated -// } - -// // Convert from float16 to float32 if necessary -// if (type_ == Ort::TypeToTensorType) { -// #if USE_DML -// if (model_.device_type_ == DeviceType::DML) { -// DmlHelpers::DmlCastInputToOutput( -// model_.GetDmlExecutionContext(), -// *model_.allocator_device_, -// *logits_of_last_token, -// logits_of_last_token_fp32_, -// model_.GetDmlDevice(), -// model_.GetOrtDmlApi(), -// logits_cast_command_list_state_); - -// logits_of_last_token = logits_of_last_token_fp32_.get(); -// } else -// #endif -// { -// std::unique_ptr logits_of_last_token_fp32; -// ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_last_token, logits_of_last_token_fp32, model_.device_type_, model_.cuda_stream_); -// output_last_tokens_ = std::move(logits_of_last_token_fp32); // use output_last_tokens_ to hold the fp32 logits -// logits_of_last_token = output_last_tokens_.get(); -// } -// } - -// #if USE_DML -// // DML doesn't support on-device scoring yet, so we need to download some data to the CPU 
-// if (model_.device_type_ == DeviceType::DML) { -// value32_cpu_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape_); -// } -// #endif - -// assert(shape_[1] == 1); - -// #if USE_CUDA -// if (model_.device_type_ == DeviceType::CUDA) { -// auto batched_logits_gpu = gpu_span{logits_of_last_token->GetTensorMutableData(), element_count}; -// if (cuda_eos_token_ids_ptr_) -// cuda::LaunchHandleEOSArray( -// batched_logits_gpu.data(), -// static_cast(shape_[0]) /* batch_beam_size*/, -// static_cast(shape_[2]) /* vocab_size */, -// cuda_eos_token_ids_.data(), -// static_cast(cuda_eos_token_ids_.size()), -// model_.cuda_stream_); -// return batched_logits_gpu; -// } -// #elif USE_DML -// if (model_.device_type_ == DeviceType::DML) { -// // DML doesn't support on-device scoring yet, so we transfer the data to the CPU -// ComPtr gpu_resource; -// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation( -// model_.allocator_device_, -// logits_of_last_token->GetTensorMutableData(), -// &gpu_resource)); -// auto cpu_tensor = value32_cpu_->GetTensorMutableData(); - -// model_.GetDmlReadbackHeap()->ReadbackFromGpu( -// std::span(reinterpret_cast(cpu_tensor), element_count * sizeof(float)), -// gpu_resource.Get(), -// 0, -// D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - -// auto batched_logits_cpu = cpu_span{cpu_tensor, element_count}; -// HandleEOSArray(batched_logits_cpu); -// return batched_logits_cpu; -// } -// #endif - -// auto batched_logits_cpu = cpu_span{logits_of_last_token->GetTensorMutableData(), element_count}; -// HandleEOSArray(batched_logits_cpu); -// return batched_logits_cpu; -// } - RoamingArray Logits::Get() { size_t element_count = shape_[0] * shape_[1] * shape_[2]; @@ -210,8 +49,6 @@ RoamingArray Logits::Get() { const size_t num_beams = state_.params_->search.num_beams; const size_t element_count_last_token = shape_[0] * shape_[2]; - // shape_[1] = 1; - // create new OrtValue for logits_of_last_token and use output_last_tokens_ to hold it output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_last, type_); @@ -226,8 +63,7 @@ RoamingArray Logits::Get() { size_t element_size = type_ == Ort::TypeToTensorType ? 4 : 2; size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process - // const auto* input_ids = state_.params_->input_ids.data(); - const auto* input_ids = state_.input_ids_.Get()->GetTensorData(); // TODO(aciddelgado): make sure on CPU + const auto* input_ids = state_.input_ids_.Get()->GetTensorData(); for (int batch_index = 0; batch_index < state_.params_->batch_size; batch_index++) { // Find the first non pad token from the end size_t token_index = seq_length; @@ -362,37 +198,6 @@ RoamingArray Logits::Get() { #pragma warning(pop) -// RoamingArray Logits::Get(size_t start, size_t size) { -// const size_t num_beams = state_.params_->search.num_beams; -// if (num_beams != 1) -// throw std::runtime_error("Get with start and size not supported for num_beams != 1, got " + std::to_string(num_beams)); -// if (shape_[0] != 1) -// throw std::runtime_error("Get with start and size not supported for batch size != 1, got " + std::to_string(shape_[0])); - -// size_t element_count = shape_[1] * shape_[2]; -// size_t element_size = type_ == Ort::TypeToTensorType ? 
4 : 2; -// size_t selected_element_count = size * shape_[2]; - -// output_last_tokens_ = OrtValue::CreateTensor(*model_.allocator_device_, std::array({1, static_cast(size), shape_[2]}), type_); -// OrtValue* logits_of_selected_tokens = output_last_tokens_.get(); - -// auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size}; -// auto logits_of_selected_tokens_raw = std::span{logits_of_selected_tokens->GetTensorMutableData(), selected_element_count * element_size}; -// auto source = logits_raw.subspan(start * shape_[2] * element_size, selected_element_count * element_size); -// copy(source, logits_of_selected_tokens_raw); - -// if (type_ == Ort::TypeToTensorType) { -// std::unique_ptr logits_of_selected_tokens_fp32; -// ConvertFp16ToFp32(*model_.allocator_device_, *logits_of_selected_tokens, logits_of_selected_tokens_fp32, model_.device_type_, model_.cuda_stream_); -// output_last_tokens_ = std::move(logits_of_selected_tokens_fp32); -// logits_of_selected_tokens = output_last_tokens_.get(); -// } - -// auto batched_logits_cpu = cpu_span{logits_of_selected_tokens->GetTensorMutableData(), selected_element_count}; -// HandleEOSArray(batched_logits_cpu); -// return batched_logits_cpu; -// } - void Logits::Update(int new_kv_length) { if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == new_kv_length) { return; @@ -405,17 +210,6 @@ void Logits::Update(int new_kv_length) { state_.outputs_[output_index_] = output_raw_.get(); } -// void Logits::Update() { -// if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == 1) { -// return; -// } - -// StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; -// output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) -// : sb_logits->CreateTensorOnStaticBuffer(shape_, type_); -// state_.outputs_[output_index_] = output_raw_.get(); -// } - void Logits::HandleEOSArray(cpu_span batched_logits) { if (model_.config_->model.eos_token_ids.empty()) return; diff --git a/src/models/logits.h b/src/models/logits.h index cd0fc25cc..6307c1ec7 100644 --- a/src/models/logits.h +++ b/src/models/logits.h @@ -12,15 +12,10 @@ struct Logits { // Register input_ids as ORT session input. void Add(); // For first iteration, find last token of each beam and store it in output_last_tokens_. - // Also resizes logits to [bz, 1, vocab_size] for subsequent calls. RoamingArray Get(); - // Retrieves logits[:, start:start + size, :]. - RoamingArray Get(size_t start, size_t size); // batch_size x size x vocab_size - // void Update(); + // Resize logits to [bz, token_count, vocab_size] if necessary. void Update(int new_kv_length); - // Resize logits to [bz, token_count, vocab_size]. - // void Update(size_t token_count); private: void HandleEOSArray(cpu_span logits); diff --git a/src/models/model.cpp b/src/models/model.cpp index 2cb7c5eab..dac18d4e2 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
#include #include -#include #include "../generators.h" #include "../search.h" diff --git a/src/models/model.h b/src/models/model.h index 4e8a78f5a..b57ef7725 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -29,11 +29,7 @@ struct State { State(const GeneratorParams& params, const Model& model_); virtual ~State() = default; - virtual RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices = {}) = 0; - // Used by continuous decoding - virtual RoamingArray Run(RoamingArray sequence, int next_token_length, int past_length, int return_last_logit_count) { throw std::runtime_error("Not implemented"); }; - // virtual void AddInputTokens(const RoamingArray& tokens) { throw std::runtime_error("Not implemented"); }; - + virtual RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices = {}) = 0; virtual const CapturedGraphInfo* GetCapturedGraphInfo() const { return nullptr; } OrtValue* GetOutput(const char* name); @@ -43,13 +39,11 @@ struct State { std::vector input_names_, output_names_; std::vector inputs_, outputs_; - // InputIDs input_ids_{model_, *this}; // TODO(aciddelgado): is this ok here? - InputIDs input_ids_; // TODO(aciddelgado): is this ok here? + InputIDs input_ids_; protected: void Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size); // Uses the inputs below to run void ClearIO(); // Clear all inputs/outputs - bool first_run_{true}; private: @@ -123,7 +117,6 @@ struct Model : std::enable_shared_from_this, LeakChecked { std::shared_ptr CreateMultiModalProcessor() const; virtual std::unique_ptr CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const = 0; - // virtual std::unique_ptr CreateState(const GeneratorParams& params) const = 0; std::unique_ptr ExpandInputs(std::unique_ptr& input, int num_beams) const; diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index 5270aac22..1ee305755 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -4,6 +4,8 @@ #include "../generators.h" #include "multi_modal_vision_model.h" +// TODO(aciddelgado): update to use new input logic + namespace Generators { namespace { @@ -223,7 +225,7 @@ RoamingArray DecoderState::Run(int current_length, RoamingArray } void DecoderState::UpdateInputsOutputs(int total_length, RoamingArray beam_indices) { - size_t new_length = input_ids_.GetShape()[1]; + size_t new_length = input_ids_.GetShape()[1]; // TODO(aciddelgado): looks like this input_ids_ is not updated by add_tokens position_inputs_.Update(total_length, new_length); kv_cache_.Update(beam_indices.GetCPU(), total_length); logits_.Update(new_length); @@ -257,7 +259,6 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra vision_state_->Run(current_length, next_tokens, next_indices); // Run the select logic - // TODO(aciddelgado): this may not work logically, done to get it to compile for decoder_only const auto* input_ids = decoder_state_->input_ids_.Get()->GetTensorData(); auto input_ids_span = std::span(input_ids, decoder_state_->input_ids_.GetShape()[1]); Select(model_, input_ids_span, embedding_state_->inputs_embeds_.Get(), diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index b00462b4d..2cb70d92e 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -42,7 +42,6 @@ struct EmbeddingState : State { const MultiModalVisionModel& model_; const 
CapturedGraphInfo* captured_graph_info_; - // InputIDs input_ids_{model_, *this}; // Model input Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output, // Model output model_.config_->model.embedding.outputs.embeddings}; }; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index d4f9f027f..1df8443ba 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -9,7 +9,6 @@ namespace Generators { -// TODO(aciddelgado): WE HAVE REMOVED THE INITIALIZATION WITH SEQUENCE LENGTH HERE PositionInputs::PositionInputs(const Model& model, State& state, RoamingArray& sequence_lengths_unk) : model_{model}, state_{state} { @@ -33,22 +32,12 @@ PositionInputs::PositionInputs(const Model& model, State& state, RoamingArray shape{state_.params_->batch_size, 0}; // Only batch_size initially, as we haven't expanded over the beams yet - // position_ids_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_); - // position_ids_next_ = OrtValue::CreateTensor(model.allocator_cpu_, std::array{shape[0], 1}, type_); - // attention_mask_ = OrtValue::CreateTensor(model.allocator_cpu_, shape, type_); - - // initial_sequence_lengths_.resize(state_.params_->BatchBeamSize()); if (type_ == Ort::TypeToTensorType) InitializeSequenceLengths(shape, sequence_lengths_unk); else InitializeSequenceLengths(shape, sequence_lengths_unk); - // TODO(aciddelgado): what is this? does it break with 0 length? - // position_ids_ = model_.ExpandInputs(position_ids_, state_.params_->search.num_beams); - // position_ids_next_ = model_.ExpandInputs(position_ids_next_, state_.params_->search.num_beams); - // attention_mask_ = model_.ExpandInputs(attention_mask_, state_.params_->search.num_beams); - // shape[0] *= state_.params_->search.num_beams; position_ids_shape_ = shape; attention_mask_shape_ = shape; @@ -77,15 +66,6 @@ void PositionInputs::Add() { } } -// void PositionInputs::Update(int current_length) { -// if (has_posid_input_) { -// UpdatePositionIDs(current_length); -// } -// if (has_mask_input_) { -// UpdateAttentionMask(current_length); -// } -// } - void PositionInputs::Update(int total_length, int new_length) { if (has_posid_input_) { // Initialize on first update @@ -237,7 +217,7 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { // Support batch_size == 1 only with current length > 0 and new kv length > 1 if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); - // Don't support DML + // Doesn't support DML at the moment if (model_.device_type_ == DeviceType::DML) throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML not supported for continuous decoding."); // Reallocate position_ids when new_kv_length changes @@ -249,9 +229,6 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { #if USE_CUDA position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_); assert(model_.device_type_ == DeviceType::CUDA); -// #elif USE_DML -// position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_); -// assert(model_.device_type_ == DeviceType::DML); #endif } state_.inputs_[posid_input_index_] = position_ids_.get(); @@ -259,27 +236,6 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { is_first_posid_update_ = false; // Just incrementing existing position IDs switch (model_.device_type_) { -// #if 
USE_DML -// case DeviceType::DML: { -// ComPtr target_resource; -// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, position_ids_->GetTensorMutableRawData(), &target_resource)); - -// // Lazily create the kernel only the first time it's needed -// if (!dml_update_position_ids_kernel_) { -// dml_update_position_ids_kernel_ = DmlIncrementValuesKernel( -// model_.GetD3D12Device(), -// model_.GetDmlExecutionContext(), -// static_cast(position_ids_shape_[0]), -// type_, -// target_resource.Get()); -// } - -// // Execute the cached command list -// ComPtr fence; -// uint64_t completion_value; -// model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value); -// } break; -// #endif case DeviceType::CPU: { if (type_ == Ort::TypeToTensorType) UpdatePositionIDsImpl(total_length, new_kv_length); @@ -300,183 +256,7 @@ void PositionInputs::UpdatePositionIDs(int total_length, int new_kv_length) { } } -// void PositionInputs::UpdatePositionIDs(int current_length, int new_length) { -// // if (model_.device_type_ != DeviceType::CPU) -// // throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported on CPU."); -// // if (position_ids_shape_[0] != 1) -// // throw std::runtime_error("PositionInputs::UpdatePositionIDs - past_length only supported for batch_size=1."); -// position_ids_shape_[1] = new_length; -// position_ids_ = OrtValue::CreateTensor(*model_.allocator_device_, position_ids_shape_, type_); -// if (type_ == Ort::TypeToTensorType) -// UpdatePositionIDsImpl(current_length, past_length); -// else -// UpdatePositionIDsImpl(current_length, past_length); -// state_.inputs_[posid_input_index_] = position_ids_.get(); -// } - -void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { - // Support batch_size == 1 only with current length > 0 and new kv length > 1 - if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - batch_size must be 1 for continuous decoding."); - // Don't support DML - if (model_.device_type_ == DeviceType::DML) - throw std::runtime_error("PositionInputs::UpdatePositionIDs - DML not supported for continuous decoding."); - // Update attention mask -// if (sb_attention_mask_) { -// #if USE_CUDA -// attention_mask_shape_[1] = state_.params_->search.max_length; -// attention_mask_next_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); -// if (is_first_mask_update_) { -// if (type_ == Ort::TypeToTensorType) { -// cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(), -// 0, -// sizeof(int32_t) * attention_mask_shape_[0] * attention_mask_shape_[1], -// model_.cuda_stream_); -// } else { -// cudaMemsetAsync(attention_mask_next_->GetTensorMutableRawData(), -// 0, -// sizeof(int64_t) * attention_mask_shape_[0] * attention_mask_shape_[1], -// model_.cuda_stream_); -// } -// } -// // #elif USE_DML -// // attention_mask_shape_[1] = state_.params_->search.max_length; -// // attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); -// // attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); -// #endif -// } else { -// attention_mask_shape_[1] = total_length; - -// // #if USE_DML -// // if (model_.device_type_ == DeviceType::DML) { -// // attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, 
attention_mask_shape_, type_); -// // } -// // #endif - -// attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); -// } - - if (sb_attention_mask_ && is_first_mask_update_) { -#if USE_CUDA - attention_mask_shape_[1] = state_.params_->search.max_length; - attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); - if (is_first_mask_update_) { - int past_length = total_length - new_kv_length; - if (type_ == Ort::TypeToTensorType) { - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), - 1, - sizeof(int32_t) * past_length, - model_.cuda_stream_); - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, - 0, - sizeof(int32_t) * (total_length - past_length), - model_.cuda_stream_); - } else { - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), - 1, - sizeof(int64_t) * past_length, - model_.cuda_stream_); - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, - 0, - sizeof(int64_t) * (total_length - past_length), - model_.cuda_stream_); - } - } -// #elif USE_DML -// attention_mask_shape_[1] = state_.params_->search.max_length; -// attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); -// attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); -#endif - } else { - attention_mask_shape_[1] = total_length; - attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); - } - - switch (model_.device_type_) { -// #if USE_DML -// case DeviceType::DML: { -// ComPtr attention_mask_resource; -// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_->GetTensorMutableRawData(), &attention_mask_resource)); - -// ComPtr attention_mask_next_resource; -// Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, attention_mask_next_->GetTensorMutableRawData(), &attention_mask_next_resource)); - -// if (is_first_mask_update_) { -// dml_update_mask_kernel_ = DmlUpdateMaskKernel( -// model_.GetD3D12Device(), -// model_.GetDmlExecutionContext(), -// static_cast(attention_mask_shape_[0]), -// static_cast(attention_mask_shape_[1]), -// type_, -// current_length, -// attention_mask_resource.Get(), -// attention_mask_next_resource.Get()); -// is_second_mask_update_ = true; -// } else if (is_second_mask_update_) { -// dml_update_mask_kernel_ = DmlUpdateMaskKernel( -// model_.GetD3D12Device(), -// model_.GetDmlExecutionContext(), -// static_cast(attention_mask_shape_[0]), -// static_cast(attention_mask_shape_[1]), -// type_, -// 1, -// attention_mask_resource.Get(), -// attention_mask_next_resource.Get()); -// is_second_mask_update_ = false; -// } - -// ComPtr fence; -// uint64_t completion_value; -// model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_mask_kernel_->GetCommandList(), &fence, &completion_value); -// break; -// } -// #endif - case DeviceType::CPU: { - if (type_ == Ort::TypeToTensorType) - UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData(), total_length); - else - UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData(), total_length); - break; - } -#if USE_CUDA - case DeviceType::CUDA: { - // int max_seq_len = sb_attention_mask_ ? 
state_.params_->search.max_length : total_length; - bool update_static = sb_attention_mask_; - if (type_ == Ort::TypeToTensorType) { - cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData(), - new_kv_length, - total_length, - update_static, - model_.cuda_stream_); - } else { - cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData(), - new_kv_length, - total_length, - update_static, - model_.cuda_stream_); - } - break; - } -#endif - default: - throw std::runtime_error("PositionInputs::Update - Unsupported device type"); - } - -// #if USE_DML -// if (model_.device_type_ != DeviceType::DML) { -// attention_mask_ = std::move(attention_mask_next_); -// } -// #else - // attention_mask_ = std::move(attention_mask_next_); -// #endif - - // LEFT OFF: state_.inputs_[mask_input_index_] = attention_mask_.get(); not working - state_.inputs_[mask_input_index_] = attention_mask_.get(); - - is_first_mask_update_ = false; -} - -void PositionInputs::UpdateAttentionMask(int current_length) { +void PositionInputs::UpdateAttentionMask(int total_length) { // Update attention mask if (sb_attention_mask_) { #if USE_CUDA @@ -501,8 +281,8 @@ void PositionInputs::UpdateAttentionMask(int current_length) { attention_mask_next_ = sb_attention_mask_next_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); #endif } else { - assert(attention_mask_shape_[1] == current_length - 1); // We should always be growing by 1 - attention_mask_shape_[1] = current_length; + assert(attention_mask_shape_[1] == total_length - 1); // We should always be growing by 1 + attention_mask_shape_[1] = total_length; #if USE_DML if (model_.device_type_ == DeviceType::DML) { @@ -528,7 +308,7 @@ void PositionInputs::UpdateAttentionMask(int current_length) { static_cast(attention_mask_shape_[0]), static_cast(attention_mask_shape_[1]), type_, - current_length, + total_length, attention_mask_resource.Get(), attention_mask_next_resource.Get()); is_second_mask_update_ = true; @@ -555,22 +335,22 @@ void PositionInputs::UpdateAttentionMask(int current_length) { if (type_ == Ort::TypeToTensorType) UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), - current_length); + total_length); else UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), - current_length); + total_length); break; } #if USE_CUDA case DeviceType::CUDA: { - int max_seq_len = sb_attention_mask_ ? state_.params_->search.max_length : current_length; + int max_seq_len = sb_attention_mask_ ? 
state_.params_->search.max_length : total_length; bool update_only = sb_attention_mask_ && !is_first_mask_update_; if (type_ == Ort::TypeToTensorType) { cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), static_cast(attention_mask_shape_[0]), - current_length, + total_length, max_seq_len, update_only, model_.cuda_stream_); } else { cuda::Launch_UpdateAttentionMask(attention_mask_next_->GetTensorMutableData(), attention_mask_->GetTensorData(), static_cast(attention_mask_shape_[0]), - current_length, + total_length, max_seq_len, update_only, model_.cuda_stream_); @@ -599,25 +379,83 @@ #endif state_.inputs_[mask_input_index_] = attention_mask_.get(); - is_first_mask_update_ = false; } -// void PositionInputs::UpdateAttentionMask(int current_length, int new_length) { -// if (model_.device_type_ != DeviceType::CPU) -// throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported on CPU."); -// if (attention_mask_shape_[0] != 1) -// throw std::runtime_error("PositionInputs::UpdateAttentionMask - past_length only supported for batch_size=1."); -// attention_mask_shape_[1] = current_length; -// attention_mask_next_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); -// if (type_ == Ort::TypeToTensorType) -// UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), current_length, past_length); -// else -// UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(), current_length, past_length); -// attention_mask_ = std::move(attention_mask_next_); -// state_.inputs_[mask_input_index_] = attention_mask_.get(); -// is_first_mask_update_ = false; -// } +void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { + // Batch size > 1 is only supported for the prompt (total_length == 0) or single-token steps (new_kv_length == 1) + if (position_ids_shape_[0] != 1 && !(total_length == 0 || new_kv_length == 1)) + throw std::runtime_error("PositionInputs::UpdateAttentionMask - batch_size must be 1 for continuous decoding."); + // Doesn't support DML at the moment + if (model_.device_type_ == DeviceType::DML) + throw std::runtime_error("PositionInputs::UpdateAttentionMask - DML not supported for continuous decoding."); + // Update attention mask + if (sb_attention_mask_ && is_first_mask_update_) { +#if USE_CUDA + attention_mask_shape_[1] = state_.params_->search.max_length; + attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); + if (is_first_mask_update_) { + int past_length = total_length - new_kv_length; + if (type_ == Ort::TypeToTensorType) { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int32_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int32_t) * (total_length - past_length), + model_.cuda_stream_); + } else { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int64_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int64_t) * (total_length - past_length), + model_.cuda_stream_); + } + } +#endif + } else { + attention_mask_shape_[1] = total_length; + attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); + } + + switch 
(model_.device_type_) { + case DeviceType::CPU: { + if (type_ == Ort::TypeToTensorType) + UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData(), total_length); + else + UpdateAttentionMaskImpl(attention_mask_->GetTensorMutableData(), total_length); + break; + } +#if USE_CUDA + case DeviceType::CUDA: { + bool update_static = sb_attention_mask_; + if (type_ == Ort::TypeToTensorType) { + cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData(), + new_kv_length, + total_length, + update_static, + model_.cuda_stream_); + } else { + cuda::Launch_UpdateAttentionMask(attention_mask_->GetTensorMutableData(), + new_kv_length, + total_length, + update_static, + model_.cuda_stream_); + } + break; + } +#endif + default: + throw std::runtime_error("PositionInputs::Update - Unsupported device type"); + } + + state_.inputs_[mask_input_index_] = attention_mask_.get(); + is_first_mask_update_ = false; +} template void PositionInputs::CreateAndInitializePositionIDs(std::array shape) { @@ -640,7 +478,6 @@ void PositionInputs::CreateAndInitializePositionIDs(std::array shape } position_data_next[i] = abs_position; - // initial_sequence_lengths_[i] = static_cast(abs_position); } // Move tensors to appropriate device and expand by num_beams @@ -667,8 +504,6 @@ void PositionInputs::CreateAndInitializeAttentionMask(std::array sha *mask = 1; } } - - // initial_sequence_lengths_[i * state_.params_->search.num_beams + k] = static_cast(abs_position); } // Move tensors to appropriate device and expand by num_beams @@ -677,71 +512,10 @@ void PositionInputs::CreateAndInitializeAttentionMask(std::array sha state_.inputs_[mask_input_index_] = attention_mask_.get(); } -// template -// void PositionInputs::InitializeTensors(std::array shape/*, cpu_span sequence_lengths_unk*/) { -// // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. -// // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens -// auto* mask_data = attention_mask_->GetTensorMutableData(); -// auto* position_data = position_ids_->GetTensorMutableData(); -// auto* position_data_next = position_ids_next_->GetTensorMutableData(); -// const auto* word_id = state_.params_->input_ids.data(); -// auto* mask = mask_data; -// auto* position = position_data; -// for (int i = 0; i < shape[0]; i++) { -// T abs_position = 0; -// for (int j = 0; j < shape[1]; j++, word_id++, mask++, position++) { -// if (*word_id == state_.params_->pad_token_id) { -// *mask = 0; -// *position = 0; -// } else { -// *mask = 1; -// *position = abs_position++; -// } -// } - -// position_data_next[i] = abs_position; -// for (int k = 0; k < state_.params_->search.num_beams; k++) { -// // sequence_lengths_unk[i * state_.params_->search.num_beams + k] = static_cast(abs_position); -// initial_sequence_lengths_[i * state_.params_->search.num_beams + k] = static_cast(abs_position); -// } -// } -// } - -// template -// void PositionInputs::InitializeTensors(std::array shape/*, cpu_span sequence_lengths_unk*/) { -// // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. 
-// // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens -// auto* mask_data = attention_mask_->GetTensorMutableData(); -// auto* position_data = position_ids_->GetTensorMutableData(); -// auto* position_data_next = position_ids_next_->GetTensorMutableData(); -// const auto* word_id = state_.params_->input_ids.data(); -// auto* mask = mask_data; -// auto* position = position_data; -// for (int i = 0; i < shape[0]; i++) { -// T abs_position = 0; -// for (int j = 0; j < shape[1]; j++, word_id++, mask++, position++) { -// if (*word_id == state_.params_->pad_token_id) { -// *mask = 0; -// *position = 0; -// } else { -// *mask = 1; -// *position = abs_position++; -// } -// } - -// position_data_next[i] = abs_position; -// for (int k = 0; k < state_.params_->search.num_beams; k++) { -// // sequence_lengths_unk[i * state_.params_->search.num_beams + k] = static_cast(abs_position); -// initial_sequence_lengths_[i * state_.params_->search.num_beams + k] = static_cast(abs_position); -// } -// } -// } - template void PositionInputs::InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk) { for (int i = 0; i < shape[0] * state_.params_->search.num_beams; i++) { sequence_lengths_unk[i] = 0; - // initial_sequence_lengths_[i] = 0; } } diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index de8875835..45321ef38 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -14,7 +14,6 @@ struct PositionInputs { PositionInputs(const Model& model, State& state); void Add(); - // void Update(int current_length); void Update(int total_length, int new_length); private: @@ -24,12 +23,10 @@ struct PositionInputs { // Batch size > 1 case void UpdatePositionIDs(); void UpdateAttentionMask(int total_length); - // Used by continuous decoding. + // Batch size == 1 case. void UpdatePositionIDs(int total_length, int new_length); void UpdateAttentionMask(int total_length, int new_length); - // template - // void InitializeTensors(std::array shape/*, cpu_span sequence_lengths*/); template void InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk); template @@ -42,7 +39,6 @@ struct PositionInputs { template void UpdateAttentionMaskImpl(T* data, const T* old_data, int current_length); - // Used by continuous decoding template void UpdatePositionIDsImpl(int total_length, int new_kv_length); template @@ -59,14 +55,13 @@ struct PositionInputs { bool has_mask_input_{false}; bool has_posid_input_{false}; - std::array position_ids_shape_{}; // {params.batch_size*params.beam_size, params.sequence_length} + std::array position_ids_shape_{}; std::unique_ptr position_ids_; - std::array attention_mask_shape_{}; // {params.batch_size*params.beam_size, params.sequence_length} + std::array attention_mask_shape_{}; std::unique_ptr attention_mask_; std::unique_ptr position_ids_next_; // Replaces position_ids_ after the first Run() call std::unique_ptr attention_mask_next_; // Replaces attention_mask_ after the first Run() call - // std::vector initial_sequence_lengths_; // Used for decoding runs with cuda graphs. 
StaticBuffer* sb_position_ids_{}; diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index 0eae1b196..d5e34bd4a 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -3,6 +3,8 @@ #include "../generators.h" #include "whisper.h" +// TODO(aciddelgado): update whisper to new paradigm + namespace Generators { Whisper_Model::Whisper_Model(std::unique_ptr config, OrtEnv& ort_env) @@ -31,7 +33,7 @@ Whisper_State::Whisper_State(const Whisper_Model& model, RoamingArray s auto sequence_lengths = sequence_lengths_unk.GetCPU(); for (int i = 0; i < decoder_input_ids_.GetShape()[0]; i++) { - sequence_lengths[i] = 0; // TODO(aciddelgado): what? static_cast(params_->sequence_length); + sequence_lengths[i] = 0; } input_names_.push_back("encoder_input_ids"); diff --git a/src/models/whisper.h b/src/models/whisper.h index aaf3a3fb4..5f3872e4c 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -33,7 +33,6 @@ struct Whisper_State : State { Decoder, } run_state_{RunState::Encoder_Decoder_Init}; - // TODO(aciddelgado): does decoder_input_ids behave differentely than input_ids_? InputIDs decoder_input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache kv_cache_{model_, *this}; diff --git a/src/ort_genai.h b/src/ort_genai.h index 8c853eff7..e53617fa5 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -189,16 +189,7 @@ struct OgaGeneratorParams : OgaAbstract { void SetSearchOptionBool(const char* name, bool value) { OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value)); - } - - // void SetInputIDs(const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { - // OgaCheckResult(OgaGeneratorParamsSetInputIDs(this, input_ids, input_ids_count, sequence_length, batch_size)); - // } - - // void SetInputSequences(const OgaSequences& sequences) { - // OgaCheckResult(OgaGeneratorParamsSetInputSequences(this, &sequences)); - // } - + } void SetModelInput(const char* name, OgaTensor& tensor) { OgaCheckResult(OgaGeneratorParamsSetModelInput(this, name, &tensor)); diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index dc5c2dbe4..49c1b9552 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -130,36 +130,6 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGen OGA_CATCH } -// OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* oga_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) { -// OGA_TRY -// auto& params = *reinterpret_cast(oga_params); -// params.input_ids = std::span(input_ids, input_ids_count); -// params.sequence_length = static_cast(sequence_length); -// params.batch_size = static_cast(batch_size); -// if (params.sequence_length * params.batch_size != input_ids_count) -// throw std::runtime_error("sequence length * batch size is not equal to input_ids_count"); -// return nullptr; -// OGA_CATCH -// } - -// OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* oga_params, const OgaSequences* p_sequences) { -// OGA_TRY -// auto& params = *reinterpret_cast(oga_params); -// auto& sequences = *reinterpret_cast(p_sequences); - -// std::vector> span_sequences; -// for (size_t i = 0; i < sequences.size(); i++) { -// span_sequences.emplace_back(sequences[i]); -// } - -// params.input_ids_owner = Generators::PadInputs(span_sequences, params.pad_token_id); -// params.batch_size = static_cast(sequences.size()); -// params.sequence_length = static_cast(params.input_ids_owner.size() / 
params.batch_size); -// params.input_ids = params.input_ids_owner; -// return nullptr; -// OGA_CATCH -// } - OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* oga_params, const OgaNamedTensors* p_named_tensors) { OGA_TRY auto& params = *reinterpret_cast(oga_params); @@ -188,13 +158,13 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorPa OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) { - OGA_TRY - auto result = Generators::Generate(*reinterpret_cast(model), *reinterpret_cast(generator_params)); - *out = reinterpret_cast(std::make_unique(std::move(result)).release()); - return nullptr; - OGA_CATCH -} +// OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) { +// OGA_TRY +// auto result = Generators::Generate(*reinterpret_cast(model), *reinterpret_cast(generator_params)); +// *out = reinterpret_cast(std::make_unique(std::move(result)).release()); +// return nullptr; +// OGA_CATCH +// } OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) { OGA_TRY @@ -232,13 +202,6 @@ OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, OGA_CATCH } -// OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) { -// OGA_TRY -// reinterpret_cast(generator)->ComputeLogits(); -// return nullptr; -// OGA_CATCH -// } - OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) { OGA_TRY reinterpret_cast(generator)->GenerateNextToken(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index a953cd7a7..33427f3e2 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -170,26 +170,6 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGenerato OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* generator_params, const char* name, bool value); OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size); -/* - * \brief Sets the input ids for the generator params. The input ids are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] input_ids The input ids array of size input_ids_count = batch_size * sequence_length. - * \param[in] input_ids_count The total number of input ids. - * \param[in] sequence_length The sequence length of the input ids. - * \param[in] batch_size The batch size of the input ids. - * \return OgaResult containing the error message if the setting of the input ids failed. - */ -// OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, -// size_t input_ids_count, size_t sequence_length, size_t batch_size); - -/* - * \brief Sets the input id sequences for the generator params. The input id sequences are used to seed the generation. - * \param[in] generator_params The generator params to set the input ids on. - * \param[in] sequences The input id sequences. - * \return OgaResult containing the error message if the setting of the input id sequences failed. 
- */ -// OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences); - OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputs(OgaGeneratorParams* generator_params, const OgaNamedTensors* named_tensors); /* @@ -227,12 +207,19 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); /* * \brief Adds the input ids to the generator. The input ids are used to seed the generation. - * \param[in] oga_params The generator params to get the pad token id. * \param[in] oga_generator The generator to add the input ids to. * \param[in] p_sequences The input id sequences. * \return OgaResult containing the error message if the setting of the input ids failed. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputSequences(OgaGenerator* oga_generator, const OgaSequences* p_sequences); + +/* + * \brief Adds the input ids to the generator. The input ids are used to seed the generation. + * \param[in] oga_generator The generator to add the input ids to. + * \param[in] input_ids The input ids to add. + * \param[in] input_ids_count The number of input ids to add (batch_size * sequence_length). + * \return OgaResult containing the error message if the setting of the input ids failed. + */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga_generator, int32_t* input_ids, size_t input_ids_count); /* diff --git a/src/python/python.cpp b/src/python/python.cpp index d8bf74e52..4768b830d 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -220,23 +220,7 @@ struct PyGeneratorParams { std::shared_ptr params_; - // Turn the python py_input_ids_ into the low level parameters void Prepare() { - // TODO: This will switch to using the variant vs being ifs - // if (py_input_ids_.size() != 0) { - // if (py_input_ids_.ndim() == 1) { // Just a 1D array - // params_->batch_size = 1; - // params_->sequence_length = static_cast(py_input_ids_.shape(0)); - // } else { - // if (py_input_ids_.ndim() != 2) - // throw std::runtime_error("Input IDs can only be 1 or 2 dimensional"); - - // params_->batch_size = static_cast(py_input_ids_.shape(0)); - // params_->sequence_length = static_cast(py_input_ids_.shape(1)); - // } - // params_->input_ids = ToSpan(py_input_ids_); - // } - if (py_whisper_input_features_.size() != 0) { GeneratorParams::Whisper& whisper = params_->inputs.emplace(); whisper.input_features = std::make_shared(ToOrtValue(py_whisper_input_features_)); @@ -275,7 +259,6 @@ struct PyGeneratorParams { params_->TryGraphCapture(max_batch_size.cast()); } - // pybind11::array_t py_input_ids_; pybind11::array_t py_whisper_input_features_; std::vector refs_; // References to data we want to ensure doesn't get garbage collected @@ -290,7 +273,6 @@ struct PyNamedTensors { struct PyGenerator { PyGenerator(Model& model, PyGeneratorParams& params) { - // params.Prepare(); generator_ = CreateGenerator(model, params); } @@ -304,15 +286,10 @@ struct PyGenerator { return ToPython(py_sequence_.GetCPU()); } - // void ComputeLogits() { - // generator_->ComputeLogits(); - // } - pybind11::array GetOutput(const std::string& name) { return ToNumpy(generator_->state_->GetOutput(name.c_str()), *(generator_->model_)); } - // TODO(aciddelgado): Does this work with batch size > 1? 
void AddTokens(pybind11::array_t tokens) { generator_->AddTokens(ToSpan(tokens)); } @@ -380,7 +357,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def_property_readonly("pad_token_id", [](const PyGeneratorParams& v) { return v.params_->pad_token_id; }) .def_property_readonly("eos_token_id", [](const PyGeneratorParams& v) { return v.params_->eos_token_id; }) .def_property_readonly("vocab_size", [](const PyGeneratorParams& v) { return v.params_->vocab_size; }) - // .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_) .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_) .def("set_inputs", [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) { if (!named_tensors || !named_tensors->named_tensors_) @@ -418,10 +394,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_>(m, "Model") .def(pybind11::init([](const std::string& config_path) { - std::cout << "Loading model from: " << config_path << std::endl; return CreateModel(GetOrtEnv(), config_path.c_str()); })) - // .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) .def_property_readonly( "device_type", [](const Model& model) { return to_string(model.device_type_); }, "The device type the model is running on") .def("create_multimodal_processor", [](const Model& model) { return model.CreateMultiModalProcessor(); }); @@ -429,7 +403,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "Generator") .def(pybind11::init()) .def("is_done", &PyGenerator::IsDone) - // .def("compute_logits", &PyGenerator::ComputeLogits) .def("get_output", &PyGenerator::GetOutput) .def("add_input_tokens", &PyGenerator::AddTokens) .def("generate_next_token", &PyGenerator::GenerateNextToken) diff --git a/src/search.cpp b/src/search.cpp index 1770e5b9f..3442c34a4 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -9,7 +9,7 @@ namespace Generators { Search_Cpu::Search_Cpu(const GeneratorParams& params) : Search{params}, - sequences_{/*params.input_ids,*/ params.batch_size, params.search.num_beams, params_->search.max_length} { + sequences_{params.batch_size, params.search.num_beams, params_->search.max_length} { auto batch_beam_size = params.BatchBeamSize(); sequence_lengths_buffer_ = AllocateArray(batch_beam_size, &sequence_lengths_); } @@ -26,7 +26,6 @@ GreedySearch_Cpu::GreedySearch_Cpu(const GeneratorParams& params) gen_.seed(seq); } - // TODO(aciddelgado): the reason we don't use next tokens for user input is that we'd have to allocate a new buffer for different input sizes and it would be a useless copy. next_tokens_buffer_ = AllocateArray(params.batch_size, &next_tokens_); memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); diff --git a/src/search.h b/src/search.h index 019f81cab..ee48320c7 100644 --- a/src/search.h +++ b/src/search.h @@ -14,7 +14,6 @@ struct Search : LeakChecked { virtual RoamingArray GetSequenceLengths() = 0; virtual int GetSequenceLength() const = 0; virtual RoamingArray GetSequence(size_t index) = 0; - // TODO(aciddelgado): do we want a GetSequences() API? 
virtual void SetLogits(RoamingArray logits) = 0; virtual bool IsDone() const = 0; @@ -28,9 +27,10 @@ struct Search : LeakChecked { virtual void ApplyMinLength(int min_length) = 0; virtual void ApplyRepetitionPenalty(float penalty) = 0; - // Used by Continuous Decoding - virtual void DropLastTokens(size_t num_tokens) { assert(false); }; + // Set user input tokens virtual void SetNextTokens(RoamingArray next_tokens) { assert(false); }; + // To be used for rewind + virtual void DropLastTokens(size_t num_tokens) { assert(false); }; std::shared_ptr params_; }; diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp index b4f4b2b2d..ba0918f1e 100644 --- a/src/search_cuda.cpp +++ b/src/search_cuda.cpp @@ -17,7 +17,7 @@ void OnCudaError(cudaError_t error) { Search_Cuda::Search_Cuda(const GeneratorParams& params) : Search{params}, - sequences_{/*params.input_ids,*/ params.batch_size, params.search.num_beams, params_->search.max_length, params_->cuda_stream} { + sequences_{params.batch_size, params.search.num_beams, params_->search.max_length, params_->cuda_stream} { auto batch_beam_size = params.BatchBeamSize(); sequence_lengths_buffer_ = std::make_unique(batch_beam_size); sequence_lengths_ = cpu_span(sequence_lengths_buffer_.get(), batch_beam_size); diff --git a/src/sequences.cpp b/src/sequences.cpp index ad1ca2d43..878688e9f 100644 --- a/src/sequences.cpp +++ b/src/sequences.cpp @@ -6,33 +6,6 @@ namespace Generators { -Sequences::Sequences(std::span input_sequences, int batch_size, int beam_size, int max_length) - : batch_beam_size_{batch_size * beam_size}, - max_length_{max_length}, - current_length_{static_cast(input_sequences.size()) / batch_size} { - assert(current_length_ * batch_size == input_sequences.size()); // Ensure size divided perfectly - const size_t sequences_size = static_cast(batch_beam_size_) * max_length; - - if (beam_size == 1) { - sequences_buffer_ = std::make_unique(sequences_size); - sequences_ = cpu_span(sequences_buffer_.get(), sequences_size); - } else { - sequences_buffer_ = std::make_unique(2 * sequences_size); - sequences_ = cpu_span(sequences_buffer_.get(), sequences_size); - sequences_next_ = cpu_span(sequences_buffer_.get() + sequences_size, sequences_size); - } - - // The original inputs are not expanded, this expands them in place into the sequences - for (size_t batch = 0; batch < batch_size; batch++) { - for (size_t beam = 0; beam < beam_size; beam++) { - for (int j = 0; j < current_length_; j++) { - sequences_[(batch * beam_size + beam) * max_length + j] = - static_cast(input_sequences[batch * current_length_ + j]); - } - } - } -} - Sequences::Sequences(int batch_size, int beam_size, int max_length) : batch_beam_size_{batch_size * beam_size}, max_length_{max_length}, diff --git a/src/sequences.h b/src/sequences.h index 538358200..dbe5f770a 100644 --- a/src/sequences.h +++ b/src/sequences.h @@ -3,7 +3,6 @@ namespace Generators { // This class keeps track of sequences generated. struct Sequences { - Sequences(std::span input_sequence, int batch_size, int beam_size, int max_length); Sequences(int batch_size, int beam_size, int max_length); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). 
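Roughly how these two hooks are intended to compose in a speculative decoding step, as a hypothetical sketch only: ApplyDraft, draft_tokens and num_accepted are illustrative names that are not part of this change, and the real wiring lands in the Generator changes elsewhere in this series.

#include <cstdint>
#include <vector>
// Hypothetical sketch: seed the search with the assistant model's draft
// tokens, then rewind whatever the target model rejects.
void ApplyDraft(Generators::Search& search,
                std::vector<int32_t>& draft_tokens,
                size_t num_accepted) {
  search.SetNextTokens(Generators::cpu_span<int32_t>(draft_tokens.data(), draft_tokens.size()));
  // ... the caller runs the target model over the draft and counts accepted tokens ...
  if (num_accepted < draft_tokens.size())
    search.DropLastTokens(draft_tokens.size() - num_accepted);  // drop the rejected tail
}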
@@ -20,8 +19,7 @@ struct Sequences { // Used by Greedy search: void AppendNextTokenToSequences(std::span next_tokens); - // TODO(aciddelgado): Rewind sequences function - // Used by Speculative search: + // TODO(aciddelgado): To be used for rewind void DropLastTokens(size_t num_tokens); private: diff --git a/src/sequences_cuda.cpp b/src/sequences_cuda.cpp index 1e6a2cfe6..33b74d96b 100644 --- a/src/sequences_cuda.cpp +++ b/src/sequences_cuda.cpp @@ -10,13 +10,13 @@ void Launch_ExpandInputSequences(std::span input_sequences, std:: void Launch_AppendNextTokenToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int current_length, int max_length, cudaStream_t stream); } // namespace cuda -// TODO(aciddelgado): make cuda sequences functional -Sequences_Cuda::Sequences_Cuda(/*std::span input_sequences,*/ int batch_size, int beam_size, int max_length, cudaStream_t stream) +// TODO(aciddelgado): update cuda sequences to new paradigm + +Sequences_Cuda::Sequences_Cuda(int batch_size, int beam_size, int max_length, cudaStream_t stream) : stream_{stream}, batch_beam_size_{batch_size * beam_size}, max_length_{max_length}, current_length_{0} { - // assert(current_length_ * batch_size == input_sequences.size()); // Ensure size divided perfectly size_t sequences_size = batch_beam_size_ * max_length; if (beam_size == 1) { @@ -31,10 +31,7 @@ Sequences_Cuda::Sequences_Cuda(/*std::span input_sequences,*/ int // TODO: input_sequences will be in cuda memory in the future, for now make a temp copy gpu_span input_sequences_gpu; - // auto input_sequences_temp = CudaMallocArray(input_sequences.size(), &input_sequences_gpu); - // cudaMemcpyAsync(input_sequences_gpu.data(), input_sequences.data(), input_sequences.size_bytes(), cudaMemcpyHostToDevice, stream); - // cuda::Launch_ExpandInputSequences(input_sequences_gpu, sequences_, batch_size, beam_size, current_length_, max_length, stream_); cudaStreamSynchronize(stream); // Until we remove the todo above, wait for this to complete as input_sequences_gpu is on the stack } diff --git a/src/sequences_cuda.h b/src/sequences_cuda.h index b787b6e9a..c1371bed9 100644 --- a/src/sequences_cuda.h +++ b/src/sequences_cuda.h @@ -3,7 +3,7 @@ namespace Generators { // This class keeps track of sequences generated. struct Sequences_Cuda { - Sequences_Cuda(/*std::span input_sequences,*/ int batch_size, int beam_size, int max_length, cudaStream_t stream); + Sequences_Cuda(int batch_size, int beam_size, int max_length, cudaStream_t stream); // Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size). 
RoamingArray GetSequence(size_t batch_beam_index); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 91eac8817..7d19a6912 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -4,7 +4,6 @@ #include #include #include -#define MODEL_PATH "/home/aciddelgado/ort-genai-source/test/test_models/" #ifndef MODEL_PATH #define MODEL_PATH "../../test/test_models/" #endif @@ -103,36 +102,6 @@ TEST(CAPITests, AppendTokensToSequence) { #endif } -// TODO(aciddelgado): E2E API may be removed + we add tokens to generator directly now -// TEST(CAPITests, EndToEndPhiBatch) { -// #if TEST_PHI2 -// auto model = OgaModel::Create(MODEL_PATH "phi-2"); -// auto tokenizer = OgaTokenizer::Create(*model); - -// const char* input_strings[] = { -// "This is a test.", -// "Rats are awesome pets!", -// "The quick brown fox jumps over the lazy dog.", -// }; - -// auto input_sequences = OgaSequences::Create(); -// for (auto& string : input_strings) -// tokenizer->Encode(string, *input_sequences); - -// auto params = OgaGeneratorParams::Create(*model); -// params->SetSearchOption("max_length", 20); -// params->SetInputSequences(*input_sequences); - -// auto output_sequences = model->Generate(*params); - -// // Decode The Batch -// for (size_t i = 0; i < output_sequences->Count(); i++) { -// auto out_string = tokenizer->Decode(output_sequences->Get(i)); -// std::cout << "Decoded string:" << out_string << std::endl; -// } -// #endif -// } - TEST(CAPITests, Tensor_And_AddExtraInput) { // Create a [3 4] shaped tensor std::array data{0, 1, 2, 3, @@ -177,21 +146,14 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20 // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2) - std::cout << "Loading model..." << std::endl; - std::cout << "Model path: " << MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32" << std::endl; - auto model = OgaModel::Create("/home/aciddelgado/ort-genai-source/test/test_models/hf-internal-testing/tiny-random-gpt2-fp32"); - + auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32"); auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", max_length); params->SetSearchOption("batch_size", batch_size); - // params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); auto generator = OgaGenerator::Create(*model, *params); - std::cout << "Adding input tokens..." << std::endl; generator->AddInputTokens(input_ids.data(), input_ids.size()); - std::cout << "Generating..." 
<< std::endl; while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -200,34 +162,11 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { const auto sequence_length = generator->GetSequenceCount(i); const auto* sequence_data = generator->GetSequenceData(i); - std::cout << "Sequence length: " << sequence_length << std::endl; - std::cout << "Max length: " << max_length << std::endl; - std::cout << "Output sequence: "; - for (int j = 0; j < sequence_length; j++) { - std::cout << sequence_data[j] << " "; - } - std::cout << std::endl; - ASSERT_LE(sequence_length, max_length); const auto* expected_output_start = &expected_output[i * max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } - - // TODO(aciddelgado): E2E API may be removed + we add tokens to generator directly now - // // Test high level API - // auto sequences = model->Generate(*params); - - // // Verify outputs match expected outputs - // for (int i = 0; i < batch_size; i++) { - // const auto sequence_length = sequences->SequenceCount(i); - // const auto* sequence_data = sequences->SequenceData(i); - - // ASSERT_LE(sequence_length, max_length); - - // const auto* expected_output_start = &expected_output[i * max_length]; - // EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); - // } } #endif @@ -247,7 +186,6 @@ TEST(CAPITests, GetOutputCAPI) { auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", max_length); - // params->SetInputIDs(input_ids.data(), input_ids.size(), input_sequence_length, batch_size); auto generator = OgaGenerator::Create(*model, *params); generator->AddInputTokens(input_ids.data(), input_ids.size()); @@ -263,7 +201,6 @@ TEST(CAPITests, GetOutputCAPI) { -0.04699047f, 0.17915794f, 0.20838135f, 0.10888482f, -0.00277808f, 0.2938929f, -0.10538938f, -0.00226692f, 0.12050669f, -0.10622668f}; - // generator->ComputeLogits(); auto prompt_logits_ptr = generator->GetOutput("logits"); auto prompt_logits = static_cast(prompt_logits_ptr->Data()); int num_prompt_outputs_to_check = 40; @@ -281,7 +218,6 @@ TEST(CAPITests, GetOutputCAPI) { std::vector expected_sampled_logits_token_gen{0.03742531f, -0.05752287f, 0.14159015f, 0.04210977f, -0.1484456f, 0.3041716f, -0.08701379f, -0.03778192f, 0.07471392f, -0.02049096f}; - // generator->ComputeLogits(); auto token_gen_logits_ptr = generator->GetOutput("logits"); auto token_gen_logits = static_cast(token_gen_logits_ptr->Data()); int num_token_gen_outputs_to_check = 10; @@ -289,7 +225,6 @@ TEST(CAPITests, GetOutputCAPI) { for (int i = 0; i < num_token_gen_outputs_to_check; i++) { EXPECT_NEAR(expected_sampled_logits_token_gen[i], token_gen_logits[i*sample_size], tolerance); } - // generator->GenerateNextToken(); } #if TEST_PHI2 @@ -311,7 +246,6 @@ struct Phi2Test { tokenizer_->Encode(string, *input_sequences_); params_ = OgaGeneratorParams::Create(*model_); - // params_->SetInputSequences(*input_sequences_); params_->SetSearchOption("max_length", 40); } @@ -322,7 +256,6 @@ struct Phi2Test { generator->AddInputSequences(input_sequences_); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -332,18 +265,6 @@ struct Phi2Test { std::cout << "Decoded string:" << out_string << std::endl; } } - - // TODO(aciddelgado): E2E API may be removed + we add tokens to generator directly now - // // High level - // { - // auto output_sequences = 
model_->Generate(*params_); - - // // Decode The Batch - // for (size_t i = 0; i < output_sequences->Count(); i++) { - // auto out_string = tokenizer_->Decode(output_sequences->Get(i)); - // std::cout << "Decoded string:" << out_string << std::endl; - // } - // } } std::unique_ptr model_; diff --git a/test/model_tests.cpp b/test/model_tests.cpp index e167081b9..6c1d1ba4c 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -39,14 +39,11 @@ TEST(ModelTests, GreedySearchGptFp32) { auto params = Generators::CreateGeneratorParams(*model); params->search.max_length = 10; params->batch_size = static_cast(input_ids_shape[0]); - // params->sequence_length = static_cast(input_ids_shape[1]); - // params->input_ids = input_ids; auto generator = Generators::CreateGenerator(*model, *params); generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -79,17 +76,13 @@ TEST(ModelTests, BeamSearchGptFp32) { auto params = Generators::CreateGeneratorParams(*model); params->batch_size = static_cast(input_ids_shape[0]); - // params->sequence_length = static_cast(input_ids_shape[1]); - // params->input_ids = input_ids; params->search.max_length = 20; params->search.length_penalty = 1.0f; params->search.num_beams = 4; auto generator = Generators::CreateGenerator(*model, *params); generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); - // auto result = Generators::Generate(*model, *params); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -117,15 +110,12 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) auto params = Generators::CreateGeneratorParams(*model); params->batch_size = static_cast(input_ids_shape[0]); - // params->sequence_length = static_cast(input_ids_shape[1]); params->search.max_length = 10; - // params->input_ids = input_ids; auto generator = Generators::CreateGenerator(*model, *params); generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -163,17 +153,13 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { auto params = Generators::CreateGeneratorParams(*model); params->batch_size = static_cast(input_ids_shape[0]); - // params->sequence_length = static_cast(input_ids_shape[1]); - // params->input_ids = input_ids; params->search.max_length = 20; params->search.num_beams = 4; params->search.length_penalty = 1.0f; auto generator = Generators::CreateGenerator(*model, *params); generator->AddTokens(Generators::cpu_span(input_ids.data(), input_ids.size())); - // auto result = Generators::Generate(*model, *params); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -209,15 +195,12 @@ Print all primes between 1 and n auto params = Generators::CreateGeneratorParams(*model); params->batch_size = 1; - // params->sequence_length = static_cast(tokens.size()); - // params->input_ids = tokens; params->search.max_length = 128; // Generator version auto generator = Generators::CreateGenerator(*model, *params); generator->AddInputTokens(Generators::cpu_span(tokens.data(), tokens.size())); while (!generator->IsDone()) { - // generator->ComputeLogits(); generator->GenerateNextToken(); } @@ -227,32 +210,4 @@ Print all primes between 1 and n #endif } -// TEST(ModelTests, 
TestHighLevelApiCuda) { -// #if TEST_PHI2 -// auto prompt = R"( -// def print_prime(n): -// ''' -// Print all primes between 1 and n -// ''' -// )"; - -// std::cout << "With prompt:" << prompt << "\r\n"; - -// auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "phi-2"); -// auto tokenizer = model->CreateTokenizer(); -// auto tokens = tokenizer->Encode(prompt); - -// auto params = Generators::CreateGeneratorParams(*model); -// params->batch_size = 1; -// params->sequence_length = static_cast(tokens.size()); -// params->input_ids = tokens; -// params->search.max_length = 128; - -// // High level version -// auto result = Generators::Generate(*model, *params); - -// std::cout << tokenizer->Decode(result[0]) << "\r\n"; -// #endif -// } - #endif \ No newline at end of file diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index fceb9765b..5c1a64b5c 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -25,9 +25,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCpu) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -60,9 +58,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCpu) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -98,9 +94,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCpu) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -137,9 +131,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; std::vector cpu_logits(vocab_size * batch_size); std::random_device rd; @@ -181,9 +173,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); @@ -222,9 +212,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * 
batch_size); auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); @@ -265,9 +253,7 @@ TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocArray(vocab_size * batch_size); diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 861b953a3..e5a06e427 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -29,9 +29,7 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { params->search.do_sample = true; params->search.top_p = 0.25f; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto generator = Generators::CreateGenerator(*model, *params); auto logits_span = Generators::cpu_span(logits_cpu); @@ -57,9 +55,7 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { params->search.do_sample = true; params->search.top_k = 2; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); auto logits_copy = logits_cpu; @@ -91,9 +87,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { params->search.top_k = 2; params->search.top_p = 0.25f; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; auto generator = Generators::CreateGenerator(*model, *params); auto logits_copy = logits_cpu; @@ -141,9 +135,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { params->search.do_sample = true; params->search.top_p = 0.95f; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -179,9 +171,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { params->search.do_sample = true; params->search.top_k = k; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -219,9 +209,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { params->search.top_k = k; params->search.top_p = p; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CPU; std::vector logits_cpu(vocab_size * batch_size); std::random_device rd; @@ -266,9 +254,7 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { params->search.do_sample = true; params->search.top_p = 0.25f; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), 
cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -296,9 +282,7 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { params->search.do_sample = true; params->search.top_k = 2; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -331,9 +315,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { params->search.top_k = 2; params->search.top_p = 0.25f; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; cudaMemcpyAsync(logits_gpu.get(), logits_cpu.data(), logits_cpu.size() * sizeof(float), cudaMemcpyHostToDevice, params->cuda_stream); cudaStreamSynchronize(params->cuda_stream); @@ -360,9 +342,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { params->search.do_sample = true; params->search.top_p = 0.95f; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); @@ -402,9 +382,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { params->search.do_sample = true; params->search.top_k = k; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); @@ -446,9 +424,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { params->search.top_k = k; params->search.top_p = p; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); @@ -485,9 +461,7 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) { auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; params->batch_size = batch_size; - // params->sequence_length = 1; params->vocab_size = vocab_size; - // params->input_ids = input_ids; params->device_type = Generators::DeviceType::CUDA; auto logits_gpu = Generators::CudaMallocArray(vocab_size * batch_size); auto indices_buffer = Generators::CudaMallocHostArray(vocab_size * batch_size); From b80878efc449fba839ad77c7eb0c37a2611840da Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Wed, 25 Sep 2024 14:51:29 -0700 Subject: [PATCH 07/13] cuda working i think --- src/generators.cpp | 11 +---------- src/logging.cpp | 2 -- src/logging.h | 2 -- src/models/decoder_only.cpp | 8 ++++---- src/models/decoder_only.h | 2 +- src/models/gpt.cpp | 6 +++--- src/models/gpt.h | 2 +- src/models/logits.cpp | 26 +++++++++++++++---------- src/models/logits.h | 4 +++- src/models/multi_modal_vision_model.cpp | 8 ++++---- src/models/multi_modal_vision_model.h | 2 +- 
src/models/position_inputs.cpp | 18 ++++++++--------- src/models/position_inputs.h | 6 +++--- src/models/whisper.cpp | 4 ++-- src/models/whisper.h | 2 +- src/search.cpp | 2 +- src/search.h | 4 ++-- src/search_cuda.cpp | 20 ++++++++++++++++++- src/search_cuda.cu | 6 +++--- src/search_cuda.cuh | 2 +- src/search_cuda.h | 2 ++ src/sequences_cuda.cpp | 13 +++++++++++++ src/sequences_cuda.cu | 14 +++++++++++++ src/sequences_cuda.h | 2 ++ 24 files changed, 106 insertions(+), 62 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 665d8aa0e..7004f33d8 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -145,16 +145,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ void Generator::AddTokens(cpu_span input_ids) { // TODO(aciddelgado): batch_size > 1 requires full rewind - search_->SetNextTokens(input_ids); - - if (g_log.enabled && g_log.add_tokens) { - auto& stream = Log("add_tokens"); - stream << "input_ids: "; - for (auto token : input_ids) { - stream << token << ' '; - } - stream << std::endl; - } + search_->SetUserTokens(input_ids); computed_logits_ = false; ComputeLogits(input_ids); diff --git a/src/logging.cpp b/src/logging.cpp index be6589ad7..26cb64033 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -36,8 +36,6 @@ void SetLogBool(std::string_view name, bool value) { g_log.model_output_values = value; else if (name == "model_logits") g_log.model_logits = value; - else if (name == "continuous_decoding") - g_log.continuous_decoding = value; else if (name == "ort_lib") g_log.ort_lib = value; else diff --git a/src/logging.h b/src/logging.h index e5f894131..9d16c57f9 100644 --- a/src/logging.h +++ b/src/logging.h @@ -42,8 +42,6 @@ struct LogItems { bool model_output_shapes{}; // Before the model runs there are only the output shapes, no values in them. Useful for pre Session::Run debugging bool model_output_values{}; // After the model runs the output tensor values can be displayed bool model_logits{}; // Same as model_output_values but only for the logits - bool continuous_decoding{}; // Log continuous decoding steps. - bool add_tokens{}; // Log the addition of tokens to the input. bool ort_lib{}; // Log the onnxruntime library loading and api calls. 
}; diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index 41d7274e2..20843b22f 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -35,12 +35,12 @@ RoamingArray DecoderOnly_State::Run(int total_length, RoamingArray& next_tokens_unk, RoamingArray beam_indices, int total_length) { - input_ids_.Update(next_tokens_unk); +void DecoderOnly_State::UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int total_length) { + input_ids_.Update(next_tokens); size_t new_length = input_ids_.GetShape()[1]; - position_inputs_.Update(total_length, new_length); + position_inputs_.Update(next_tokens, total_length, new_length); kv_cache_.Update(beam_indices.GetCPU(), total_length); - logits_.Update(new_length); + logits_.Update(next_tokens, new_length); } } // namespace Generators diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 073e99497..9fbaea13d 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -21,7 +21,7 @@ struct DecoderOnly_State : State { const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); }; protected: - void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, int current_length); + void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray next_indices, int current_length); void UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length); // Updates inputs/outputs from a token sequence (used by the speculative decoding path) const DecoderOnly_Model& model_; diff --git a/src/models/gpt.cpp b/src/models/gpt.cpp index e2ea18249..1c2211ca9 100644 --- a/src/models/gpt.cpp +++ b/src/models/gpt.cpp @@ -35,12 +35,12 @@ RoamingArray Gpt_State::Run(int current_length, RoamingArray nex return logits_.Get(); } -void Gpt_State::UpdateInputsOutputs(const RoamingArray& next_tokens_unk, RoamingArray beam_indices, int total_length) { +void Gpt_State::UpdateInputsOutputs(RoamingArray& next_tokens_unk, RoamingArray beam_indices, int total_length) { input_ids_.Update(next_tokens_unk); size_t new_length = input_ids_.GetShape()[1]; - position_inputs_.Update(total_length, new_length); + position_inputs_.Update(next_tokens_unk, total_length, new_length); kv_cache_.Update(beam_indices.GetCPU(), total_length); - logits_.Update(new_length); + logits_.Update(next_tokens_unk, new_length); } } // namespace Generators diff --git a/src/models/gpt.h b/src/models/gpt.h index 50607ed64..4dad06ad7 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -21,7 +21,7 @@ struct Gpt_State : State { RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override; private: - void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray beam_indices, int current_length); + void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int current_length); const Gpt_Model& model_; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 105c7f632..3b84e876d 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -32,6 +32,8 @@ Logits::Logits(const Model& model, State& state) cudaMemcpyAsync(cuda_eos_token_ids_.data(), cpu_ids.data(), cpu_ids.size() * sizeof(int32_t), ::cudaMemcpyHostToDevice, model_.cuda_stream_); } #endif + + input_sequence_lengths.resize(state_.params_->batch_size); } #pragma warning(push) @@ -63,15 +65,9 @@ RoamingArray Logits::Get() { size_t element_size = type_ == Ort::TypeToTensorType ?
4 : 2; size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process - const auto* input_ids = state_.input_ids_.Get()->GetTensorData(); for (int batch_index = 0; batch_index < state_.params_->batch_size; batch_index++) { // Find the first non pad token from the end - size_t token_index = seq_length; - while (token_index-- > 0) { - if (input_ids[token_index] != state_.params_->pad_token_id) - break; - } - + size_t token_index = input_sequence_lengths[batch_index] - 1; for (int beam_index = 0; beam_index < num_beams; beam_index++) { switch (model_.device_type_) { #if USE_DML @@ -116,8 +112,6 @@ RoamingArray Logits::Get() { vocab_index += vocab_size; } - - input_ids += seq_length; } element_count = shape_[0] * shape_[2]; // shape_[1] is now 1, so the element count must be updated @@ -198,11 +192,23 @@ RoamingArray Logits::Get() { #pragma warning(pop) -void Logits::Update(int new_kv_length) { +void Logits::Update(RoamingArray& next_tokens, int new_kv_length) { if (output_raw_.get()->GetTensorTypeAndShapeInfo()->GetShape()[1] == new_kv_length) { return; } + // Store length of input sequence for each batch for the get step + for (int b = 0; b < state_.params_->batch_size; b++) { + // Find the first non pad token from the end + size_t token_index = new_kv_length; + while (token_index-- > 0) { + auto next_token = next_tokens.GetCPU()[b * new_kv_length + token_index]; + if (next_token != state_.params_->pad_token_id) + break; + } + input_sequence_lengths[b] = token_index + 1; + } + shape_[1] = new_kv_length; StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) diff --git a/src/models/logits.h b/src/models/logits.h index 6307c1ec7..3d0d040af 100644 --- a/src/models/logits.h +++ b/src/models/logits.h @@ -15,7 +15,7 @@ struct Logits { RoamingArray Get(); // Resize logits to [bz, token_count, vocab_size] if necessary. - void Update(int new_kv_length); + void Update(RoamingArray& next_tokens_unk, int new_kv_length); private: void HandleEOSArray(cpu_span logits); @@ -34,6 +34,8 @@ struct Logits { std::unique_ptr output_raw_; // Raw logits output from model + std::vector input_sequence_lengths; + // Used for decoding runs with cuda graphs. 
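A note on the new input_sequence_lengths bookkeeping above: Update() scans each batch row from the end for the last non-pad token, so Get() no longer has to re-walk input_ids. A minimal standalone illustration of that scan; the function name and sample values below are made up for this sketch.

#include <cstddef>
#include <cstdint>
#include <vector>
// Mirrors the scan in Logits::Update: for a [batch_size, new_kv_length]
// token buffer, record each row's length up to its last non-pad token.
std::vector<size_t> InputSequenceLengths(const std::vector<int32_t>& tokens,
                                         size_t batch_size, size_t new_kv_length,
                                         int32_t pad_token_id) {
  std::vector<size_t> lengths(batch_size);
  for (size_t b = 0; b < batch_size; b++) {
    size_t token_index = new_kv_length;
    while (token_index-- > 0) {
      if (tokens[b * new_kv_length + token_index] != pad_token_id)
        break;
    }
    lengths[b] = token_index + 1;  // an all-pad row wraps token_index, so +1 yields length 0
  }
  return lengths;
}
// E.g. batch_size=2, new_kv_length=4, pad=0, tokens={5,6,7,0, 8,0,0,0} -> {3, 1}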
StaticBuffer* sb_logits32_{}; StaticBuffer* sb_logits16_{}; diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index 1ee305755..4b659d856 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -224,11 +224,11 @@ RoamingArray DecoderState::Run(int current_length, RoamingArray return logits_.Get(); } -void DecoderState::UpdateInputsOutputs(int total_length, RoamingArray beam_indices) { +void DecoderState::UpdateInputsOutputs(RoamingArray next_tokens, int total_length, RoamingArray beam_indices) { size_t new_length = input_ids_.GetShape()[1]; // TODO(aciddelgado): looks like this input_ids_ is not updated by add_tokens - position_inputs_.Update(total_length, new_length); + position_inputs_.Update(next_tokens, total_length, new_length); kv_cache_.Update(beam_indices.GetCPU(), total_length); - logits_.Update(new_length); + logits_.Update(next_tokens, new_length); } MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& model, @@ -276,7 +276,7 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra } embedding_state_->UpdateInputsAndOutputs(next_tokens); - decoder_state_->UpdateInputsOutputs(current_length, next_indices); + decoder_state_->UpdateInputsOutputs(next_tokens, current_length, next_indices); embedding_state_->Run(current_length, next_tokens, next_indices); decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_); diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index 2cb70d92e..68fa220ee 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -77,7 +77,7 @@ struct DecoderState : State { private: friend struct MultiModalPipelineState; - void UpdateInputsOutputs(int current_length, RoamingArray beam_indices); + void UpdateInputsOutputs(RoamingArray next_tokens, int current_length, RoamingArray beam_indices); const MultiModalVisionModel& model_; const CapturedGraphInfo* captured_graph_info_; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 1df8443ba..72bff2a40 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -66,15 +66,15 @@ void PositionInputs::Add() { } } -void PositionInputs::Update(int total_length, int new_length) { +void PositionInputs::Update(RoamingArray& next_tokens_unk, int total_length, int new_length) { if (has_posid_input_) { // Initialize on first update if (is_first_update_) { position_ids_shape_[1] = new_length; if (type_ == Ort::TypeToTensorType) - CreateAndInitializePositionIDs(position_ids_shape_); + CreateAndInitializePositionIDs(next_tokens_unk, position_ids_shape_); else - CreateAndInitializePositionIDs(position_ids_shape_); + CreateAndInitializePositionIDs(next_tokens_unk, position_ids_shape_); } else { // Batch size > 1 case if (position_ids_shape_[0] > 1) @@ -89,9 +89,9 @@ void PositionInputs::Update(int total_length, int new_length) { if (is_first_update_) { attention_mask_shape_[1] = new_length; if (type_ == Ort::TypeToTensorType) - CreateAndInitializeAttentionMask(attention_mask_shape_); + CreateAndInitializeAttentionMask(next_tokens_unk, attention_mask_shape_); else - CreateAndInitializeAttentionMask(attention_mask_shape_); + CreateAndInitializeAttentionMask(next_tokens_unk, attention_mask_shape_); } else { // Batch size > 1 case if (attention_mask_shape_[0] > 1) @@ -458,14 +458,14 @@ void PositionInputs::UpdateAttentionMask(int total_length, int 
new_kv_length) { } template -void PositionInputs::CreateAndInitializePositionIDs(std::array shape) { +void PositionInputs::CreateAndInitializePositionIDs(RoamingArray& next_tokens_unk, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens position_ids_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); position_ids_next_ = OrtValue::CreateTensor(model_.allocator_cpu_, std::array{shape[0], 1}, type_); auto* position_data = position_ids_->GetTensorMutableData(); auto* position_data_next = position_ids_next_->GetTensorMutableData(); - const auto* word_id = state_.input_ids_.Get()->GetTensorData(); + const auto* word_id = next_tokens_unk.GetCPU().data(); auto* position = position_data; for (int i = 0; i < shape[0]; i++) { T abs_position = 0; @@ -488,12 +488,12 @@ void PositionInputs::CreateAndInitializePositionIDs(std::array shape } template -void PositionInputs::CreateAndInitializeAttentionMask(std::array shape) { +void PositionInputs::CreateAndInitializeAttentionMask(RoamingArray& next_tokens_unk, std::array shape) { // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens attention_mask_ = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); auto* mask_data = attention_mask_->GetTensorMutableData(); - const auto* word_id = state_.input_ids_.Get()->GetTensorData(); + const auto* word_id = next_tokens_unk.GetCPU().data(); auto* mask = mask_data; for (int i = 0; i < shape[0]; i++) { T abs_position = 0; diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 45321ef38..36478aef1 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -14,7 +14,7 @@ struct PositionInputs { PositionInputs(const Model& model, State& state); void Add(); - void Update(int total_length, int new_length); + void Update(RoamingArray& next_tokens_unk, int total_length, int new_length); private: void AddAttentionMask(); @@ -30,9 +30,9 @@ struct PositionInputs { template void InitializeSequenceLengths(std::array shape, cpu_span sequence_lengths_unk); template - void CreateAndInitializePositionIDs(std::array shape); + void CreateAndInitializePositionIDs(RoamingArray& next_tokens_unk, std::array shape); template - void CreateAndInitializeAttentionMask(std::array shape); + void CreateAndInitializeAttentionMask(RoamingArray& next_tokens_unk, std::array shape); template void UpdatePositionIDsImpl(); diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index d5e34bd4a..747fa4843 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -79,11 +79,11 @@ RoamingArray Whisper_State::Run(int current_length, RoamingArray return logits_.Get(); } -void Whisper_State::UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray beam_indices, int current_length) { +void Whisper_State::UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int current_length) { decoder_input_ids_.Update(next_tokens); kv_cache_.Update(beam_indices.GetCPU(), current_length); size_t new_length = input_ids_.GetShape()[1]; - logits_.Update(new_length); + logits_.Update(next_tokens, new_length); } } // namespace Generators diff --git a/src/models/whisper.h b/src/models/whisper.h index 5f3872e4c..1b09a73bf 100644 --- a/src/models/whisper.h +++ b/src/models/whisper.h @@ -24,7 +24,7 @@ struct Whisper_State 
: State {
   RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) override;

  private:
-  void UpdateInputsOutputs(const RoamingArray& next_tokens, RoamingArray next_indices, int current_length);
+  void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray next_indices, int current_length);

   const Whisper_Model& model_;

   enum struct RunState {
diff --git a/src/search.cpp b/src/search.cpp
index 3442c34a4..42c9241c1 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -252,7 +252,7 @@ void GreedySearch_Cpu::AppendNextTokensToSequences() {
   }
 }
 
-void GreedySearch_Cpu::SetNextTokens(RoamingArray next_tokens) {
+void GreedySearch_Cpu::SetUserTokens(RoamingArray next_tokens) {
   // Reset done count/state
   done_ = false;
   not_done_count_ = params_->batch_size;
diff --git a/src/search.h b/src/search.h
index ee48320c7..4695119c1 100644
--- a/src/search.h
+++ b/src/search.h
@@ -28,7 +28,7 @@ struct Search : LeakChecked {
   virtual void ApplyRepetitionPenalty(float penalty) = 0;
 
   // Set user input tokens
-  virtual void SetNextTokens(RoamingArray next_tokens) { assert(false); };
+  virtual void SetUserTokens(RoamingArray next_tokens) { assert(false); };
   // To be used for rewind
   virtual void DropLastTokens(size_t num_tokens) { assert(false); };
 
@@ -74,7 +74,7 @@ struct GreedySearch_Cpu : Search_Cpu {
   void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override;
 
   // Used by continuous decoding search.
-  void SetNextTokens(RoamingArray next_tokens) override;
+  void SetUserTokens(RoamingArray next_tokens) override;
   void DropLastTokens(size_t num_tokens) override;
 
 protected:
diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp
index ba0918f1e..a8d044c2d 100644
--- a/src/search_cuda.cpp
+++ b/src/search_cuda.cpp
@@ -178,7 +178,8 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) {
 
 void GreedySearch_Cuda::CheckForEOS() {
   assert(next_tokens_.size() == eos_meet_.size());
-  cuda::Launch_CheckForEOS(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->eos_token_id, params_->pad_token_id, done_cpu_.get(), params_->cuda_stream);
+  // Don't replace EOS with pad for batch_size == 1 for continuous decoding mode
+  cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_meet_.data(), params_->eos_token_id, params_->batch_size > 1 ? params_->pad_token_id : params_->eos_token_id, done_cpu_.get(), params_->cuda_stream);
 }
 
 void GreedySearch_Cuda::AppendNextTokensToSequences() {
@@ -253,6 +254,23 @@ std::span Search_Cuda::GetScores() {
   return next_token_scores_;
 }
 
+// Set user input tokens (batch_beam_size, sequence_length)
+void GreedySearch_Cuda::SetUserTokens(RoamingArray next_tokens) {
+  cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream);
+  *done_cpu_ = false;
+
+  auto next_tokens_gpu = next_tokens.GetGPU();
+  auto batch_size = params_->batch_size;
+  auto tokens_count_per_batch = next_tokens_gpu.size() / batch_size;
+  sequences_.AppendUserTokensToSequences(next_tokens_gpu);
+
+  if (sequences_.GetSequenceLength() == params_->search.max_length) {
+    if (g_log.enabled && g_log.hit_max_length)
+      Log("hit_max_length", "greedy cuda hit");
+    *done_cpu_ = true;
+  }
+}
+
 void Search_Cuda::ApplyMinLength(int min_length) {
   if (sequences_.GetSequenceLength() >= min_length)
     return;
diff --git a/src/search_cuda.cu b/src/search_cuda.cu
index 24fbaf914..0ea68c369 100644
--- a/src/search_cuda.cu
+++ b/src/search_cuda.cu
@@ -38,7 +38,7 @@ struct ArgMaxDataImpl : ArgMaxData {
   cuda_unique_ptr> argmaxen_owner_;
 };
 
-__global__ void CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu) {
+__global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu) {
   // Look for EOS tokens, if seen set EOS flag and replace with pad token
   for (size_t batch_id = 0; batch_id < next_tokens_count; ++batch_id) {
     if (next_tokens[batch_id] == eos_token_id || eos_meet[batch_id] == true) {
@@ -64,8 +64,8 @@ __global__ void CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* e
   }
 }
 
-void Launch_CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream) {
-  CheckForEOS<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_meet, eos_token_id, pad_token_id, done_cpu);
+void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream) {
+  CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_meet, eos_token_id, pad_token_id, done_cpu);
 }
 
 __global__ void AddProbsKernel(float* log_probs,
diff --git a/src/search_cuda.cuh b/src/search_cuda.cuh
index d3662ff17..a6db5e500 100644
--- a/src/search_cuda.cuh
+++ b/src/search_cuda.cuh
@@ -6,7 +6,7 @@ struct ArgMaxData {
   virtual ~ArgMaxData() = default;
 };
 
-void Launch_CheckForEOS(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream);
+void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_meet, int eos_token_id, int pad_token_id, bool* done_cpu, cudaStream_t stream);
 void LaunchAddProbsKernel(float* log_probs, float* cum_log_probs, const int batch_size, const int num_beams, const int vocab_size, cudaStream_t stream);
 void LaunchSetScoreProcessor(float* next_token_scores, int batch_beam_size, int vocab_size, int token, float score, cudaStream_t stream);
 void LaunchRepetitionPenaltyProcessor(const int32_t* sequences, float* next_token_scores, int batch_size, int num_beams, int vocab_size, int max_sequence_length, int current_sequence_length, float repetition_penalty, cudaStream_t stream);
diff --git a/src/search_cuda.h b/src/search_cuda.h
index 8a699b880..b70c00653 100644
--- a/src/search_cuda.h
+++ b/src/search_cuda.h
@@ -52,6 +52,8 @@ struct GreedySearch_Cuda : Search_Cuda {
   void SampleTopK(int k, float t) override;
   void SampleTopP(float p, float t) override;
   void SampleTopKTopP(int k, float p, float t) override;
+  void SetUserTokens(RoamingArray next_tokens) override;  // shape (batch_size, sequence_length)
+
 private:
   void CheckForEOS();
diff --git a/src/sequences_cuda.cpp b/src/sequences_cuda.cpp
index 33b74d96b..8e2583c50 100644
--- a/src/sequences_cuda.cpp
+++ b/src/sequences_cuda.cpp
@@ -8,6 +8,7 @@ namespace Generators {
 namespace cuda {
 void Launch_ExpandInputSequences(std::span input_sequences, std::span sequences, int batch_size, int beam_size, int current_length, int max_length, cudaStream_t stream);
 void Launch_AppendNextTokenToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int current_length, int max_length, cudaStream_t stream);
+void Launch_AppendUserTokensToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int past_length, int new_length, int max_length, cudaStream_t stream);
 }  // namespace cuda
 
 // TODO(aciddelgado): update cuda sequences to new paradigm
@@ -55,6 +56,18 @@ void Sequences_Cuda::AppendNextTokenToSequences(std::span next_to
   ++current_length_;
 }
 
+void Sequences_Cuda::AppendUserTokensToSequences(gpu_span user_tokens) {
+  // if (g_log.enabled && g_log.set_next_tokens) {
+  //   auto& stream = Log("set_next_tokens");
+  //   DumpCudaSpan(stream, next_tokens_span);
+  //   stream << std::endl;
+  // }
+  size_t new_length = user_tokens.size() / batch_beam_size_;
+  size_t past_length = current_length_;
+  cuda::Launch_AppendUserTokensToSequences(user_tokens, sequences_, batch_beam_size_, past_length, new_length, max_length_, stream_);
+  current_length_ += new_length;
+}
+
 void Sequences_Cuda::AfterDeviceAppendedNextToken() {
   ++current_length_;
diff --git a/src/sequences_cuda.cu b/src/sequences_cuda.cu
index dfbc8aae1..d376a4c5f 100644
--- a/src/sequences_cuda.cu
+++ b/src/sequences_cuda.cu
@@ -32,5 +32,19 @@ void Launch_AppendNextTokenToSequences(std::span next_tokens, std
   AppendNextTokenToSequences<<<1, 1, 0, stream>>>(next_tokens.data(), sequences.data(), batch_beam_size, current_length, max_length);
 }
 
+// TODO(aciddelgado): parallelize this kernel.
+__global__ void AppendUserTokensToSequences(const int32_t* user_tokens, int32_t* sequences, int batch_beam_size, int past_length, int new_length, int max_length) {
+  // Append user tokens to each sequence.
+  for (int i = 0; i < batch_beam_size; i++) {
+    for (int j = 0; j < new_length; j++) {
+      sequences[i * max_length + past_length + j] = user_tokens[i * new_length + j];
+    }
+  }
+}
+
+void Launch_AppendUserTokensToSequences(std::span user_tokens, std::span sequences, int batch_beam_size, int past_length, int new_length, int max_length, cudaStream_t stream) {
+  AppendUserTokensToSequences<<<1, 1, 0, stream>>>(user_tokens.data(), sequences.data(), batch_beam_size, past_length, new_length, max_length);
+}
+
 }  // namespace cuda
 }  // namespace Generators
diff --git a/src/sequences_cuda.h b/src/sequences_cuda.h
index c1371bed9..ef1bb1eb4 100644
--- a/src/sequences_cuda.h
+++ b/src/sequences_cuda.h
@@ -11,6 +11,8 @@ struct Sequences_Cuda {
   gpu_span GetNextSequences() { return sequences_next_; }
 
   void AppendNextTokenToSequences(std::span next_tokens);
+  void AppendUserTokensToSequences(gpu_span user_tokens);
+  void SetNextTokens(gpu_span next_tokens_span);
 
   // Returns current sequence length.
int GetSequenceLength() const; From 9e03c4fde854b38cce317b20ba9631ade26fa41b Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Wed, 25 Sep 2024 16:34:04 -0700 Subject: [PATCH 08/13] move input ids back to where they were --- src/models/decoder_only.h | 2 ++ src/models/gpt.h | 1 + src/models/model.cpp | 3 +-- src/models/model.h | 3 --- src/models/multi_modal_vision_model.cpp | 7 +++---- src/models/multi_modal_vision_model.h | 1 + src/models/whisper.cpp | 2 +- 7 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 9fbaea13d..455614ff9 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -1,5 +1,6 @@ #pragma once #include "model.h" +#include "input_ids.h" #include "logits.h" #include "kv_cache.h" #include "position_inputs.h" @@ -27,6 +28,7 @@ struct DecoderOnly_State : State { const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; + InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache kv_cache_{model_, *this}; PositionInputs position_inputs_; diff --git a/src/models/gpt.h b/src/models/gpt.h index 4dad06ad7..0a692b8b9 100644 --- a/src/models/gpt.h +++ b/src/models/gpt.h @@ -25,6 +25,7 @@ struct Gpt_State : State { const Gpt_Model& model_; + InputIDs input_ids_{model_, *this}; Logits logits_{model_, *this}; KV_Cache_Combined kv_cache_{model_, *this}; PositionInputs position_inputs_; diff --git a/src/models/model.cpp b/src/models/model.cpp index dac18d4e2..3e17fd009 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -38,8 +38,7 @@ namespace Generators { State::State(const GeneratorParams& params, const Model& model) : params_{params.shared_from_this()}, - model_{model}, - input_ids_{model, *this} {} + model_{model} {} void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) { if (first_run_) { diff --git a/src/models/model.h b/src/models/model.h index b57ef7725..2733dbf66 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -5,7 +5,6 @@ #include "captured_graph_pool.h" #include "utils.h" #include "prompt_image_processor.h" -#include "input_ids.h" #if USE_DML #include "dml_provider_factory.h" @@ -39,8 +38,6 @@ struct State { std::vector input_names_, output_names_; std::vector inputs_, outputs_; - InputIDs input_ids_; - protected: void Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size); // Uses the inputs below to run void ClearIO(); // Clear all inputs/outputs diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index 4b659d856..eb38c769f 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -225,7 +225,8 @@ RoamingArray DecoderState::Run(int current_length, RoamingArray } void DecoderState::UpdateInputsOutputs(RoamingArray next_tokens, int total_length, RoamingArray beam_indices) { - size_t new_length = input_ids_.GetShape()[1]; // TODO(aciddelgado): looks like this input_ids_ is not updated by add_tokens + int batch_size = static_cast(inputs_embeds_.GetShape()[0]); + size_t new_length = next_tokens.GetCPU().size() / batch_size; position_inputs_.Update(next_tokens, total_length, new_length); kv_cache_.Update(beam_indices.GetCPU(), total_length); logits_.Update(next_tokens, new_length); @@ -259,9 +260,7 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra vision_state_->Run(current_length, next_tokens, next_indices); // Run the select logic - const auto* input_ids = 
decoder_state_->input_ids_.Get()->GetTensorData(); - auto input_ids_span = std::span(input_ids, decoder_state_->input_ids_.GetShape()[1]); - Select(model_, input_ids_span, embedding_state_->inputs_embeds_.Get(), + Select(model_, next_tokens.GetCPU(), embedding_state_->inputs_embeds_.Get(), vision_state_->visual_features_.get(), vision_state_->num_image_tokens_, params_->hidden_size, params_->device_type, params_->cuda_stream); } diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index 68fa220ee..e7928d04c 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -42,6 +42,7 @@ struct EmbeddingState : State { const MultiModalVisionModel& model_; const CapturedGraphInfo* captured_graph_info_; + InputIDs input_ids_{model_, *this}; Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output, // Model output model_.config_->model.embedding.outputs.embeddings}; }; diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index 747fa4843..fd407d227 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -82,7 +82,7 @@ RoamingArray Whisper_State::Run(int current_length, RoamingArray void Whisper_State::UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int current_length) { decoder_input_ids_.Update(next_tokens); kv_cache_.Update(beam_indices.GetCPU(), current_length); - size_t new_length = input_ids_.GetShape()[1]; + size_t new_length = decoder_input_ids_.GetShape()[1]; logits_.Update(next_tokens, new_length); } From fc3a0d38a6875b678b59026193a3f9fc8fcdf16c Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Fri, 27 Sep 2024 10:21:23 -0700 Subject: [PATCH 09/13] fix batch_size > 1 --- examples/python/model-generate.py | 21 ++++++++++++++------- src/generators.cpp | 5 +++-- src/generators.h | 6 +++--- src/models/input_ids.cpp | 2 +- src/models/logits.cpp | 2 +- src/ort_genai.h | 2 +- src/ort_genai_c.cpp | 4 ++-- src/ort_genai_c.h | 2 +- src/python/python.cpp | 2 +- src/search.cpp | 6 +++--- src/search.h | 2 +- 11 files changed, 31 insertions(+), 23 deletions(-) diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py index 0a97f25b4..3e7cd2769 100644 --- a/examples/python/model-generate.py +++ b/examples/python/model-generate.py @@ -12,9 +12,9 @@ def main(args): if hasattr(args, 'prompts'): prompts = args.prompts else: - prompts = ["I like walking my cute dog", - "What is the best restaurant in town?", - "Hello, how are you today?"] + prompts = ["The first 4 digits of pi are", + "The square root of 2 is", + "The first 6 numbers of the Fibonacci sequence are",] if args.chat_template: if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1: @@ -28,6 +28,7 @@ def main(args): params = og.GeneratorParams(model) search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args} + search_options['batch_size'] = 3 if (args.verbose): print(f'Args: {args}') if (args.verbose): print(f'Search options: {search_options}') @@ -37,22 +38,28 @@ def main(args): params.try_graph_capture_with_max_batch_size(len(prompts)) if args.batch_size_for_cuda_graph: params.try_graph_capture_with_max_batch_size(args.batch_size_for_cuda_graph) - params.input_ids = input_tokens if args.verbose: print("GeneratorParams created") + generator = og.Generator(model, params) + if args.verbose: print("Generator created") + + generator.add_input_tokens(input_tokens) + if 
args.verbose: print("Input tokens added") + if args.verbose: print("Generating tokens ...\n") start_time = time.time() - output_tokens = model.generate(params) + while not generator.is_done(): + generator.generate_next_token() run_time = time.time() - start_time for i in range(len(prompts)): print(f'Prompt #{i}: {prompts[i]}') print() - print(tokenizer.decode(output_tokens[i])) + print(tokenizer.decode(generator.get_sequence(i))) print() print() - total_tokens = sum(len(x) for x in output_tokens) + total_tokens = sum(len(generator.get_sequence(i)) for i in range(len(prompts))) print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}") print() diff --git a/src/generators.cpp b/src/generators.cpp index 7004f33d8..b7a9029e1 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -110,7 +110,7 @@ void GeneratorParams::SetInputs(const NamedTensors& named_tensors) { } } -std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params) { +std::unique_ptr CreateGenerator(const Model& model, GeneratorParams& params) { return std::make_unique(model, params); } @@ -129,11 +129,12 @@ std::unique_ptr CreateSearch(const GeneratorParams& params) { return std::make_unique(params); } -Generator::Generator(const Model& model, const GeneratorParams& params) : model_{model.shared_from_this()} { +Generator::Generator(const Model& model, GeneratorParams& params) : model_{model.shared_from_this()} { if (params.search.max_length == 0) throw std::runtime_error("search max_length is 0"); if (params.search.max_length > model.config_->model.context_length) throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")"); + params.batch_size = params.search.batch_size; // TEMP: bad overlap between search and generator params if (params.batch_size < 1) throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.batch_size)); if (params.vocab_size < 1) diff --git a/src/generators.h b/src/generators.h index 3e34f6cdf..9bf8fe022 100644 --- a/src/generators.h +++ b/src/generators.h @@ -121,7 +121,7 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec }; struct Generator : LeakChecked { - Generator(const Model& model, const GeneratorParams& params); + Generator(const Model& model, GeneratorParams& params); bool IsDone() const; virtual void AddTokens(cpu_span input_ids); @@ -158,8 +158,8 @@ OrtEnv& GetOrtEnv(); std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path); std::shared_ptr CreateGeneratorParams(const Model& model); std::shared_ptr CreateGeneratorParams(); // For benchmarking purposes only -std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params); -std::vector> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence +std::unique_ptr CreateGenerator(const Model& model, GeneratorParams& params); +// std::vector> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction void top_k_indices(std::span top_k, std::span inputs); diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp index 67f6b8e89..b31f35fb9 100644 --- a/src/models/input_ids.cpp +++ 
b/src/models/input_ids.cpp @@ -9,7 +9,7 @@ InputIDs::InputIDs(const Model& model, State& state) : model_{model}, state_{state} { name_ = model_.config_->model.decoder.inputs.input_ids.c_str(); - shape_ = {state_.params_->search.num_beams * state_.params_->batch_size, 0}; + shape_ = {state_.params_->BatchBeamSize(), 0}; auto session_info = model_.session_info_.get(); type_ = session_info->GetInputDataType(name_); diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 3b84e876d..d89a1ca19 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -12,7 +12,7 @@ namespace Generators { Logits::Logits(const Model& model, State& state) : model_{model}, state_{state}, - shape_{static_cast(state_.params_->batch_size) * state_.params_->search.num_beams, 0, state_.params_->vocab_size}, + shape_{state_.params_->BatchBeamSize(), 0, state_.params_->vocab_size}, type_{model_.session_info_->GetOutputDataType(model_.config_->model.decoder.outputs.logits)} { output_raw_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); diff --git a/src/ort_genai.h b/src/ort_genai.h index e53617fa5..b8930729b 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -207,7 +207,7 @@ struct OgaGeneratorParams : OgaAbstract { }; struct OgaGenerator : OgaAbstract { - static std::unique_ptr Create(const OgaModel& model, const OgaGeneratorParams& params) { + static std::unique_ptr Create(const OgaModel& model, OgaGeneratorParams& params) { OgaGenerator* p; OgaCheckResult(OgaCreateGenerator(&model, ¶ms, &p)); return std::unique_ptr(p); diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 49c1b9552..3e87676cc 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -166,9 +166,9 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorPa // OGA_CATCH // } -OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) { +OgaResult* OgaCreateGenerator(const OgaModel* model, OgaGeneratorParams* generator_params, OgaGenerator** out) { OGA_TRY - *out = reinterpret_cast(CreateGenerator(*reinterpret_cast(model), *reinterpret_cast(generator_params)).release()); + *out = reinterpret_cast(CreateGenerator(*reinterpret_cast(model), *reinterpret_cast(generator_params)).release()); return nullptr; OGA_CATCH } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 33427f3e2..417db677f 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -190,7 +190,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(Oga * \param[out] out The created generator. * \return OgaResult containing the error message if the generator creation failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, OgaGeneratorParams* params, OgaGenerator** out); /* * \brief Destroys the given generator. 
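Taken together, the signature changes in this patch move callers from the one-shot model.generate(params) API to an explicit generator loop: create a generator, feed it the prompt tokens, then step it until it reports completion. A minimal sketch of that flow in Python, kept to the calls exercised by the updated examples in this series (the model path and prompt are placeholders):

import onnxruntime_genai as og

model = og.Model("path/to/model")  # placeholder model folder
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options(max_length=128, batch_size=1)

generator = og.Generator(model, params)
generator.add_input_tokens(tokenizer.encode("The first 4 digits of pi are"))

# Step the search until EOS or max_length is hit.
while not generator.is_done():
    generator.generate_next_token()

print(tokenizer.decode(generator.get_sequence(0)))

The C API mirrors the same loop through OgaCreateGenerator, OgaGenerator_AddInputTokens, and OgaGenerator_GenerateNextToken.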
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 4768b830d..20486ac24 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -273,7 +273,7 @@ struct PyNamedTensors {
 struct PyGenerator {
   PyGenerator(Model& model, PyGeneratorParams& params) {
-    generator_ = CreateGenerator(model, params);
+    generator_ = CreateGenerator(model, *params.params_);
   }
 
   pybind11::array_t GetNextTokens() {
diff --git a/src/search.cpp b/src/search.cpp
index 42c9241c1..4d8e93144 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -230,9 +230,9 @@ bool GreedySearch_Cpu::PadIfAlreadyEOS(size_t batch_id) {
   return true;
 }
 
-void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token) {
+void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token, bool check_eos) {
   next_tokens_[batch_id] = token;
-  if (token == params_->eos_token_id) {
+  if (check_eos && token == params_->eos_token_id) {
     eos_seen_[batch_id] = true;
     if (g_log.enabled && g_log.hit_eos)
       Log("hit_eos", "EOS seen on batch " + std::to_string(batch_id));
@@ -264,7 +264,7 @@ void GreedySearch_Cpu::SetUserTokens(RoamingArray next_tokens) {
   auto tokens_count_per_batch = next_tokens_cpu.size() / batch_size;
   for (size_t j = 0; j < tokens_count_per_batch; j++) {
     for (size_t i = 0; i < batch_size; i++) {
-      SetNextToken(i, next_tokens_cpu[i * tokens_count_per_batch + j]);
+      SetNextToken(i, next_tokens_cpu[i * tokens_count_per_batch + j], false);
     }
     AppendNextTokensToSequences();
   }
diff --git a/src/search.h b/src/search.h
index 4695119c1..6aced4076 100644
--- a/src/search.h
+++ b/src/search.h
@@ -78,7 +78,7 @@ struct GreedySearch_Cpu : Search_Cpu {
   void DropLastTokens(size_t num_tokens) override;
 
 protected:
-  void SetNextToken(size_t batch_id, int32_t token);
+  void SetNextToken(size_t batch_id, int32_t token, bool check_eos = true);
   void AppendNextTokensToSequences();
 
 private:

From 0d971fd487ea5f880bfd82d516d2bd6d33ccf2e8 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Tue, 1 Oct 2024 09:59:18 -0700
Subject: [PATCH 10/13] working on rewind

---
 src/generators.cpp             | 13 +++++
 src/generators.h               |  1 +
 src/models/decoder_only.cpp    |  6 +++
 src/models/decoder_only.h      |  2 +
 src/models/kv_cache.cpp        | 38 +++++++++++++++
 src/models/kv_cache.h          |  3 ++
 src/models/model.h             |  2 +
 src/models/position_inputs.cpp | 86 +++++++++++++++++++++++++---------
 src/models/position_inputs.h   |  6 +++
 src/search.cpp                 | 12 +++++
 src/search.h                   |  3 ++
 src/sequences.cpp              | 13 +++++
 src/sequences.h                |  4 ++
 13 files changed, 167 insertions(+), 22 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index b7a9029e1..f983922f9 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -220,6 +220,19 @@ void Generator::GenerateNextToken() {
   }
 }
 
+void Generator::RewindToLength(size_t new_length) {
+  if (new_length > search_->GetSequenceLength())
+    throw std::runtime_error("Cannot rewind to a length greater than the current sequence length");
+  if (new_length == search_->GetSequenceLength())
+    return;
+  size_t batch_size = search_->params_->search.batch_size;
+  if (batch_size > 1 && new_length != 0)
+    throw std::runtime_error("RewindToLength must be called with new_length=0 when batch_size > 1");
+  search_->RewindTo(new_length);
+  state_->RewindTo(new_length);
+  computed_logits_ = false;
+}
+
 RoamingArray Generator::GetSequence(size_t index) const {
   return search_->GetSequence(index);
 }
diff --git a/src/generators.h b/src/generators.h
index 9bf8fe022..91bc89a33 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -126,6 +126,7 @@ struct Generator : LeakChecked {
   bool IsDone() const;
   virtual void AddTokens(cpu_span input_ids);
   virtual void GenerateNextToken();
+  virtual void RewindToLength(size_t new_length);  // Rewind state to new_length
 
   RoamingArray GetSequence(size_t index) const;
 
diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 20843b22f..b5952f836 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -35,6 +35,12 @@ RoamingArray DecoderOnly_State::Run(int total_length, RoamingArray
   return logits_.Get();
 }
 
+void DecoderOnly_State::RewindTo(size_t index) {
+  reset_input_ = true;
+  position_inputs_.RewindTo(index);
+  kv_cache_.RewindTo(index);
+}
+
 void DecoderOnly_State::UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray beam_indices, int total_length) {
   input_ids_.Update(next_tokens);
   size_t new_length = input_ids_.GetShape()[1];
diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h
index 455614ff9..2da346d0f 100644
--- a/src/models/decoder_only.h
+++ b/src/models/decoder_only.h
@@ -21,6 +21,8 @@ struct DecoderOnly_State : State {
   RoamingArray Run(int total_length, RoamingArray next_tokens, RoamingArray next_indices) override;
   const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_.get(); };
 
+  void RewindTo(size_t index) override;
+
 protected:
   void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray next_indices, int current_length);
   void UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length);  // what this does
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 44cf25ff7..49ba2fdf9 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -217,6 +217,44 @@ void KV_Cache::Update(std::span beam_indices, int total_length) {
   is_first_update_ = false;
 }
 
+void KV_Cache::RewindTo(size_t index) {
+  if (past_present_share_buffer_) {
+    return;
+  }
+
+  is_first_update_ = true;
+  if (index == 0) {
+    for (int i = 0; i < layer_count_ * 2; i++) {
+      pasts_[i] = nullptr;
+    }
+  } else {
+    RewindPastTensorsTo(index);
+  }
+}
+
+void KV_Cache::RewindPastTensorsTo(size_t index) {
+  assert(index > 0 && !past_present_share_buffer_);
+  auto new_shape = shape_;
+  new_shape[2] = static_cast(index);
+  auto batch_x_num_heads = new_shape[0] * new_shape[1];
+  auto length_x_head_size = new_shape[2] * new_shape[3];
+  for (int i = 0; i < layer_count_ * 2; i++) {
+    OrtValue& present_value = *presents_[i];
+    std::unique_ptr past_value = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    for (int j = 0; j < batch_x_num_heads; j++) {
+#if USE_CUDA
+      if (model.device_type == DeviceType::CUDA) {
+        cudaMemcpyAsync(past_value->GetTensorMutableData() , present_value.GetTensorData(), length_x_head_size, cudaMemcpyDeviceToDevice, model.cuda_stream_);
+      } else
+#endif
+      {
+        copy(present_value, *past_value);
+      }
+    }
+    state_.inputs_[input_index_ + i] = pasts_[i].get();
+  }
+}
+
 // Copy present state to past state reordered by the beam_indices
 template
 void KV_Cache::PickPastState(std::span beam_indices, int index) {
diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h
index c167831f3..c7e6482a8 100644
--- a/src/models/kv_cache.h
+++ b/src/models/kv_cache.h
@@ -37,11 +37,14 @@ struct KV_Cache {
   void Add();
   // Move present to past. Prepare present output for next generation iteration.
void Update(std::span beam_indices, int total_length); + void RewindTo(size_t index); template void PickPastState(std::span beam_indices, int index); void PickPastState(std::span beam_indices, int index); private: + void RewindPastTensorsTo(size_t index); + const Model& model_; State& state_; int layer_count_; diff --git a/src/models/model.h b/src/models/model.h index 2733dbf66..62f9facd7 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -33,6 +33,8 @@ struct State { OrtValue* GetOutput(const char* name); + virtual void RewindTo(size_t index); + std::shared_ptr params_; std::vector input_names_, output_names_; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 72bff2a40..1349d97ff 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -104,6 +104,23 @@ void PositionInputs::Update(RoamingArray& next_tokens_unk, int total_le is_first_update_ = false; } +void PositionInputs::RewindTo(size_t index) { + // Reset the state of the position inputs + if (index == 0) { + is_first_update_ = true; + is_first_posid_update_ = true; + is_first_mask_update_ = true; + // Rewind the mask input to a previous state + } else if (has_mask_input_) { + if (attention_mask_shape_[0] == 1) +#if USE_CUDA + RewindMask(index); + else +#endif + throw std::runtime_error("PositionInputs::RewindTo - Unsupported batch size"); + } +} + void PositionInputs::AddAttentionMask() { mask_input_index_ = state_.inputs_.size(); @@ -394,30 +411,28 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { #if USE_CUDA attention_mask_shape_[1] = state_.params_->search.max_length; attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); - if (is_first_mask_update_) { - int past_length = total_length - new_kv_length; - if (type_ == Ort::TypeToTensorType) { - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), - 1, - sizeof(int32_t) * past_length, - model_.cuda_stream_); - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, - 0, - sizeof(int32_t) * (total_length - past_length), - model_.cuda_stream_); - } else { - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), - 1, - sizeof(int64_t) * past_length, - model_.cuda_stream_); - cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, - 0, - sizeof(int64_t) * (total_length - past_length), - model_.cuda_stream_); - } + int past_length = total_length - new_kv_length; + if (type_ == Ort::TypeToTensorType) { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int32_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int32_t) * (total_length - past_length), + model_.cuda_stream_); + } else { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int64_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int64_t) * (total_length - past_length), + model_.cuda_stream_); } #endif - } else { + } else if (!sb_attention_mask_) { attention_mask_shape_[1] = total_length; attention_mask_ = OrtValue::CreateTensor(*model_.allocator_device_, attention_mask_shape_, type_); } @@ -552,4 +567,31 @@ void PositionInputs::UpdateAttentionMaskImpl(T* data, int total_length) { } }; +#if USE_CUDA +void PositionInputs::RewindMask(size_t index) { + if (sb_attention_mask_ && !is_first_mask_update_) { + int past_length = 
static_cast(index); + if (type_ == Ort::TypeToTensorType) { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int32_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int32_t) * (total_length - past_length), + model_.cuda_stream_); + } else { + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), + 1, + sizeof(int64_t) * past_length, + model_.cuda_stream_); + cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, + 0, + sizeof(int64_t) * (total_length - past_length), + model_.cuda_stream_); + } + } +} +#endif + } // namespace Generators diff --git a/src/models/position_inputs.h b/src/models/position_inputs.h index 36478aef1..92ed84224 100644 --- a/src/models/position_inputs.h +++ b/src/models/position_inputs.h @@ -16,6 +16,8 @@ struct PositionInputs { void Add(); void Update(RoamingArray& next_tokens_unk, int total_length, int new_length); + void RewindTo(size_t index); + private: void AddAttentionMask(); void AddPositionIDs(); @@ -44,6 +46,10 @@ struct PositionInputs { template void UpdateAttentionMaskImpl(T* data, int total_length); +#if USE_CUDA + void RewindMask(size_t index); +#endif + const Model& model_; State& state_; diff --git a/src/search.cpp b/src/search.cpp index 4d8e93144..aaee79a4e 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -270,6 +270,18 @@ void GreedySearch_Cpu::SetUserTokens(RoamingArray next_tokens) { } } +void GreedySearch_Cpu::RewindTo(size_t index) { + sequences_.RewindTo(index); + done_ = false; + not_done_count_ = params_->batch_size; + memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + // Set next tokens to the last tokens in the sequence + if (index > 0) + next_tokens_ = sequences_.GetLastTokens(); + else + memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); +} + void GreedySearch_Cpu::DropLastTokens(size_t num_tokens) { auto sequences_cpu = sequences_.GetSequences(); auto new_sequence_length = sequences_.GetSequenceLength() - num_tokens; diff --git a/src/search.h b/src/search.h index 6aced4076..b6ce0f7a7 100644 --- a/src/search.h +++ b/src/search.h @@ -30,6 +30,8 @@ struct Search : LeakChecked { // Set user input tokens virtual void SetUserTokens(RoamingArray next_tokens) { assert(false); }; // To be used for rewind + virtual void RewindTo(size_t index) { assert(false); }; + // To be used for rewind virtual void DropLastTokens(size_t num_tokens) { assert(false); }; std::shared_ptr params_; @@ -75,6 +77,7 @@ struct GreedySearch_Cpu : Search_Cpu { // Used by continuous decoding search. 
void SetUserTokens(RoamingArray next_tokens) override; + void RewindTo(size_t index) override; void DropLastTokens(size_t num_tokens) override; protected: diff --git a/src/sequences.cpp b/src/sequences.cpp index 878688e9f..ac95501cb 100644 --- a/src/sequences.cpp +++ b/src/sequences.cpp @@ -62,6 +62,19 @@ void Sequences::AppendNextTokenToSequences(std::span next_tokens) ++current_length_; } +cpu_span Sequences::GetLastTokens() { + std::vector last_tokens(batch_beam_size_); + for (int i = 0; i < batch_beam_size_; i++) { + last_tokens[i] = sequences_[i * max_length_ + current_length_ - 1]; + } + return cpu_span{last_tokens.data(), last_tokens.size()}; +} + +void Sequences::RewindTo(size_t index) { + current_length_ = static_cast(index); + assert(current_length_ >= 0); +} + void Sequences::DropLastTokens(size_t num_tokens) { current_length_ -= static_cast(num_tokens); assert(current_length_ >= 0); diff --git a/src/sequences.h b/src/sequences.h index dbe5f770a..11caa5952 100644 --- a/src/sequences.h +++ b/src/sequences.h @@ -19,6 +19,10 @@ struct Sequences { // Used by Greedy search: void AppendNextTokenToSequences(std::span next_tokens); + // Return Token IDs of last token in each sequence + cpu_span GetLastTokens(); + // Rewind sequences to ith token + void RewindTo(size_t index); // TODO(aciddelgado): To be used for rewind void DropLastTokens(size_t num_tokens); From 2fed10c1ff7aa4eafcb0bda0c1c1c2c6501190ed Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Wed, 2 Oct 2024 15:06:38 -0700 Subject: [PATCH 11/13] b size 1 cpu reverse working --- examples/python/model-qa.py | 10 +++++++++- src/models/decoder_only.cpp | 2 -- src/models/decoder_only.h | 2 -- src/models/kv_cache.cpp | 32 +++++++++++++++++++++++--------- src/models/kv_cache.h | 1 + src/models/logits.cpp | 2 ++ src/models/model.h | 2 +- src/models/position_inputs.cpp | 14 ++++++++------ src/ort_genai.h | 8 ++++---- src/ort_genai_c.cpp | 7 +++++++ src/ort_genai_c.h | 10 +++++++++- src/python/python.cpp | 5 +++++ src/search.cpp | 8 ++++++-- 13 files changed, 75 insertions(+), 28 deletions(-) diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index 059ed4d2e..f01185691 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -29,6 +29,12 @@ def main(args): params.set_search_options(**search_options) generator = og.Generator(model, params) + # Set system prompt + system_prompt = "You are a helpful assistant. You are friendly, courteous, and professional. All your responses must end with an exclamation point!" 
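+ # Encode the system prompt once and keep its token count; rewind_to_length() below can return the generator to this cached prefix between turns.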
+ system_tokens = tokenizer.encode(system_prompt) + generator.add_input_tokens(system_tokens) + system_prompt_length = len(system_tokens) + # Keep asking for input prompts in a loop while True: text = input("Input: ") @@ -76,7 +82,9 @@ def main(args): prompt_time = first_token_timestamp - started_timestamp run_time = time.time() - first_token_timestamp print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps") - + + # Rewind the generator to the system prompt + generator.rewind_to_length(system_prompt_length) if __name__ == "__main__": parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai") diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index b5952f836..c82f81eb1 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -30,13 +30,11 @@ RoamingArray DecoderOnly_State::Run(int total_length, RoamingArray(input_ids_.GetShape()[0]); State::Run(*model_.session_decoder_, *model_.run_options_, batch_size); - reset_input_ = true; return logits_.Get(); } void DecoderOnly_State::RewindTo(size_t index) { - reset_input_ = true; position_inputs_.RewindTo(index); kv_cache_.RewindTo(index); } diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index 2da346d0f..c6d32855f 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -35,8 +35,6 @@ struct DecoderOnly_State : State { KV_Cache kv_cache_{model_, *this}; PositionInputs position_inputs_; ExtraInputs extra_inputs_{model_, *this}; - - bool reset_input_{true}; }; } // namespace Generators diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 49ba2fdf9..7b7bb5b7a 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -220,6 +220,8 @@ void KV_Cache::Update(std::span beam_indices, int total_length) { void KV_Cache::RewindTo(size_t index) { if (past_present_share_buffer_) { return; + } else if (shape_[2] <= static_cast(index)) { + throw std::runtime_error("Requested length of rewind is greater than the current length."); } is_first_update_ = true; @@ -227,30 +229,42 @@ void KV_Cache::RewindTo(size_t index) { for (int i = 0; i < layer_count_ * 2; i++) { pasts_[i] = nullptr; } + } else if (type_ == Ort::TypeToTensorType) { + RewindPastTensorsTo(index); } else { - RewindPastTensorsTo(index); + RewindPastTensorsTo(index); } } +// LEFT OFF: +// check if we even have a present tensor +// use spans and size bytes to ensure correct sizes of data +// make sure we're copying the data correctly +template void KV_Cache::RewindPastTensorsTo(size_t index) { - assert(index > 0 && !past_present_share_buffer_); - auto new_shape = shape_; + assert(index > 0 && shape_[2] >= index && !past_present_share_buffer_); + std::array new_shape = shape_; new_shape[2] = static_cast(index); auto batch_x_num_heads = new_shape[0] * new_shape[1]; - auto length_x_head_size = new_shape[2] * new_shape[3]; + auto new_length_x_head_size = new_shape[2] * new_shape[3]; + auto old_length_x_head_size = shape_[2] * new_shape[3]; + for (int i = 0; i < layer_count_ * 2; i++) { - OrtValue& present_value = *presents_[i]; - std::unique_ptr past_value = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + OrtValue& present = *presents_[i]; + std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); for 
(int j = 0; j < batch_x_num_heads; j++) { + auto present_data = present.GetTensorData() + j * old_length_x_head_size; + auto past_data = past->GetTensorMutableData() + j * new_length_x_head_size; #if USE_CUDA - if (model.device_type == DeviceType::CUDA) { - cudaMemcpyAsync(past_value->GetTensorMutableData() , present_value.GetTensorData(), length_x_head_size, cudaMemcpyDeviceToDevice, model.cuda_stream_); + if (model_.device_type_ == DeviceType::CUDA) { + cudaMemcpyAsync(past_data, present_data, new_length_x_head_size * sizeof(T), cudaMemcpyDeviceToDevice, model_.cuda_stream_); } else #endif { - copy(present_value, *past_value); + copy(std::span(present_data, new_length_x_head_size), std::span(past_data, new_length_x_head_size)); } } + pasts_[i] = std::move(past); state_.inputs_[input_index_ + i] = pasts_[i].get(); } } diff --git a/src/models/kv_cache.h b/src/models/kv_cache.h index c7e6482a8..a708a4b8d 100644 --- a/src/models/kv_cache.h +++ b/src/models/kv_cache.h @@ -43,6 +43,7 @@ struct KV_Cache { void PickPastState(std::span beam_indices, int index); private: + template void RewindPastTensorsTo(size_t index); const Model& model_; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index d89a1ca19..89cfd5e79 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -209,6 +209,8 @@ void Logits::Update(RoamingArray& next_tokens, int new_kv_length) { input_sequence_lengths[b] = token_index + 1; } + std::cout << "new_kv_length: " << new_kv_length << std::endl; + shape_[1] = new_kv_length; StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; output_raw_ = !sb_logits ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) diff --git a/src/models/model.h b/src/models/model.h index 62f9facd7..5c998df08 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -33,7 +33,7 @@ struct State { OrtValue* GetOutput(const char* name); - virtual void RewindTo(size_t index); + virtual void RewindTo(size_t index) { (void)index; }; std::shared_ptr params_; diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index 1349d97ff..8201a1f8e 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -409,9 +409,10 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { // Update attention mask if (sb_attention_mask_ && is_first_mask_update_) { #if USE_CUDA - attention_mask_shape_[1] = state_.params_->search.max_length; - attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); int past_length = total_length - new_kv_length; + int max_length = state_.params_->search.max_length; + attention_mask_shape_[1] = max_length; + attention_mask_ = sb_attention_mask_->CreateTensorOnStaticBuffer(attention_mask_shape_, type_); if (type_ == Ort::TypeToTensorType) { cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), 1, @@ -419,7 +420,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { model_.cuda_stream_); cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, 0, - sizeof(int32_t) * (total_length - past_length), + sizeof(int32_t) * (max_length - past_length), model_.cuda_stream_); } else { cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), @@ -428,7 +429,7 @@ void PositionInputs::UpdateAttentionMask(int total_length, int new_kv_length) { model_.cuda_stream_); cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, 0, - sizeof(int64_t) * (total_length - past_length), 
+ sizeof(int64_t) * (max_length - past_length), model_.cuda_stream_); } #endif @@ -571,6 +572,7 @@ void PositionInputs::UpdateAttentionMaskImpl(T* data, int total_length) { void PositionInputs::RewindMask(size_t index) { if (sb_attention_mask_ && !is_first_mask_update_) { int past_length = static_cast(index); + int max_length = static_cast(state_.params_->search.max_length); if (type_ == Ort::TypeToTensorType) { cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), 1, @@ -578,7 +580,7 @@ void PositionInputs::RewindMask(size_t index) { model_.cuda_stream_); cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, 0, - sizeof(int32_t) * (total_length - past_length), + sizeof(int32_t) * (max_length - past_length), model_.cuda_stream_); } else { cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(), @@ -587,7 +589,7 @@ void PositionInputs::RewindMask(size_t index) { model_.cuda_stream_); cudaMemsetAsync(attention_mask_->GetTensorMutableRawData() + past_length, 0, - sizeof(int64_t) * (total_length - past_length), + sizeof(int64_t) * (max_length - past_length), model_.cuda_stream_); } } diff --git a/src/ort_genai.h b/src/ort_genai.h index b8930729b..bc9ec4181 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -225,14 +225,14 @@ struct OgaGenerator : OgaAbstract { OgaCheckResult(OgaGenerator_AddInputTokens(this, input_ids, input_ids_count)); } - void ComputeLogits() { - OgaCheckResult(OgaGenerator_ComputeLogits(this)); - } - void GenerateNextToken() { OgaCheckResult(OgaGenerator_GenerateNextToken(this)); } + void RewindToLength(size_t length) { + OgaCheckResult(OgaGenerator_RewindToLength(this, length)); + } + size_t GetSequenceCount(size_t index) const { return OgaGenerator_GetSequenceCount(this, index); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 3e87676cc..7176adff9 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -209,6 +209,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_RewindToLength(OgaGenerator* generator, size_t new_length) { + OGA_TRY + reinterpret_cast(generator)->RewindToLength(new_length); + return nullptr; + OGA_CATCH +} + OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) { OGA_TRY auto& generator = *reinterpret_cast(oga_generator); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 417db677f..b83f79a3d 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -227,9 +227,17 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AddInputTokens(OgaGenerator* oga * \param[in] generator The generator to compute the logits for. * \return OgaResult containing the error message if the computation of the logits failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); +/* + * \brief Rewinds the generator to the given length. This is useful when the user wants to rewind the generator to a specific length + * and continue generating from that point. + * \param[in] generator The generator to rewind to the given length. + * \param[in] new_length The new length to rewind the generator to. + * \return OgaResult containing the error message if the rewinding failed. 
+ */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_RewindToLength(OgaGenerator* generator, size_t new_length); + /* * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor * and will be released when the OgaTensor is destroyed diff --git a/src/python/python.cpp b/src/python/python.cpp index 20486ac24..7274e1428 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -298,6 +298,10 @@ struct PyGenerator { generator_->GenerateNextToken(); } + void RewindToLength(size_t new_length) { + generator_->RewindToLength(new_length); + } + bool IsDone() const { return generator_->IsDone(); } @@ -406,6 +410,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("get_output", &PyGenerator::GetOutput) .def("add_input_tokens", &PyGenerator::AddTokens) .def("generate_next_token", &PyGenerator::GenerateNextToken) + .def("rewind_to_length", &PyGenerator::RewindToLength) .def("get_next_tokens", &PyGenerator::GetNextTokens) .def("get_sequence", &PyGenerator::GetSequence); diff --git a/src/search.cpp b/src/search.cpp index aaee79a4e..63d6ed9f9 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -276,8 +276,12 @@ void GreedySearch_Cpu::RewindTo(size_t index) { not_done_count_ = params_->batch_size; memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); // Set next tokens to the last tokens in the sequence - if (index > 0) - next_tokens_ = sequences_.GetLastTokens(); + if (index > 0) { + auto last_tokens = sequences_.GetLastTokens(); + for (size_t i = 0; i < params_->batch_size; i++) { + SetNextToken(i, last_tokens[i]); + } + } else memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); } From 1c86984c5662c428bf81493f995d9cc4abd1489a Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Thu, 3 Oct 2024 11:13:52 -0700 Subject: [PATCH 12/13] rewind working on cuda --- examples/python/model-qa.py | 7 +++++-- src/models/kv_cache.cpp | 1 + src/models/logits.cpp | 2 -- src/search.cpp | 5 +---- src/search_cuda.cpp | 11 +++++++++++ src/search_cuda.h | 2 +- src/sequences.cpp | 4 +--- src/sequences.h | 2 +- src/sequences_cuda.cpp | 11 +++++++++++ src/sequences_cuda.cu | 13 +++++++++++++ src/sequences_cuda.h | 5 ++++- 11 files changed, 49 insertions(+), 14 deletions(-) diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index f01185691..23ce32391 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -30,7 +30,7 @@ def main(args): generator = og.Generator(model, params) # Set system prompt - system_prompt = "You are a helpful assistant. You are friendly, courteous, and professional. All your responses must end with an exclamation point!" 
+ system_prompt = args.system_prompt system_tokens = tokenizer.encode(system_prompt) generator.add_input_tokens(system_tokens) system_prompt_length = len(system_tokens) @@ -84,7 +84,8 @@ def main(args): print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps") # Rewind the generator to the system prompt - generator.rewind_to_length(system_prompt_length) + if args.rewind: + generator.rewind_to_length(system_prompt_length) if __name__ == "__main__": parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai") @@ -99,5 +100,7 @@ def main(args): parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false') parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false') parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}') + parser.add_argument('-s', '--system_prompt', type=str, default='You are a helpful assistant. You are friendly, courteous, and professional. All your responses must end with an exclamation point!', help='System prompt to use for the prompt.') + parser.add_argument('-re', '--rewind', action='store_true', default=False, help='Rewind to the system prompt after each generation. Defaults to false') args = parser.parse_args() main(args) \ No newline at end of file diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 7b7bb5b7a..d3ecacb84 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -217,6 +217,7 @@ void KV_Cache::Update(std::span beam_indices, int total_length) { is_first_update_ = false; } +// TODO(aciddelgado): test with past_present_share_buffer_ = false void KV_Cache::RewindTo(size_t index) { if (past_present_share_buffer_) { return; diff --git a/src/models/logits.cpp b/src/models/logits.cpp index 89cfd5e79..d89a1ca19 100644 --- a/src/models/logits.cpp +++ b/src/models/logits.cpp @@ -209,8 +209,6 @@ void Logits::Update(RoamingArray& next_tokens, int new_kv_length) { input_sequence_lengths[b] = token_index + 1; } - std::cout << "new_kv_length: " << new_kv_length << std::endl; - shape_[1] = new_kv_length; StaticBuffer* sb_logits = type_ == Ort::TypeToTensorType ? sb_logits16_ : sb_logits32_; output_raw_ = !sb_logits ? 
OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_) diff --git a/src/search.cpp b/src/search.cpp index 63d6ed9f9..89e039063 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -277,10 +277,7 @@ void GreedySearch_Cpu::RewindTo(size_t index) { memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); // Set next tokens to the last tokens in the sequence if (index > 0) { - auto last_tokens = sequences_.GetLastTokens(); - for (size_t i = 0; i < params_->batch_size; i++) { - SetNextToken(i, last_tokens[i]); - } + sequences_.GetLastTokens(next_tokens_); } else memset(next_tokens_.data(), 0, next_tokens_.size_bytes()); diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp index a8d044c2d..0f321eb97 100644 --- a/src/search_cuda.cpp +++ b/src/search_cuda.cpp @@ -271,6 +271,17 @@ void GreedySearch_Cuda::SetUserTokens(RoamingArray next_tokens) { } } +void GreedySearch_Cuda::RewindTo(size_t index) { + sequences_.RewindTo(index); + cudaMemsetAsync(eos_meet_.data(), 0, eos_meet_.size_bytes(), params_->cuda_stream); + *done_cpu_ = false; + if (index > 0) { + sequences_.GetLastTokens(next_tokens_); + } + else + cudaMemsetAsync(next_tokens_.data(), 0, params_->batch_size * sizeof(int32_t), params_->cuda_stream); +} + void Search_Cuda::ApplyMinLength(int min_length) { if (sequences_.GetSequenceLength() >= min_length) return; diff --git a/src/search_cuda.h b/src/search_cuda.h index b70c00653..bcccb14ca 100644 --- a/src/search_cuda.h +++ b/src/search_cuda.h @@ -53,7 +53,7 @@ struct GreedySearch_Cuda : Search_Cuda { void SampleTopP(float p, float t) override; void SampleTopKTopP(int k, float p, float t) override; void SetUserTokens(RoamingArray next_tokens) override; // shape (batch_size, sequence_length) - + void RewindTo(size_t index) override; private: void CheckForEOS(); diff --git a/src/sequences.cpp b/src/sequences.cpp index ac95501cb..b436bc3c6 100644 --- a/src/sequences.cpp +++ b/src/sequences.cpp @@ -62,12 +62,10 @@ void Sequences::AppendNextTokenToSequences(std::span next_tokens) ++current_length_; } -cpu_span Sequences::GetLastTokens() { - std::vector last_tokens(batch_beam_size_); +void Sequences::GetLastTokens(cpu_span& last_tokens) { for (int i = 0; i < batch_beam_size_; i++) { last_tokens[i] = sequences_[i * max_length_ + current_length_ - 1]; } - return cpu_span{last_tokens.data(), last_tokens.size()}; } void Sequences::RewindTo(size_t index) { diff --git a/src/sequences.h b/src/sequences.h index 11caa5952..86ca444ba 100644 --- a/src/sequences.h +++ b/src/sequences.h @@ -20,7 +20,7 @@ struct Sequences { void AppendNextTokenToSequences(std::span next_tokens); // Return Token IDs of last token in each sequence - cpu_span GetLastTokens(); + void GetLastTokens(cpu_span& last_tokens); // Rewind sequences to ith token void RewindTo(size_t index); // TODO(aciddelgado): To be used for rewind diff --git a/src/sequences_cuda.cpp b/src/sequences_cuda.cpp index 8e2583c50..a35864c4a 100644 --- a/src/sequences_cuda.cpp +++ b/src/sequences_cuda.cpp @@ -9,6 +9,7 @@ namespace cuda { void Launch_ExpandInputSequences(std::span input_sequences, std::span sequences, int batch_size, int beam_size, int current_length, int max_length, cudaStream_t stream); void Launch_AppendNextTokenToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int current_length, int max_length, cudaStream_t stream); void Launch_AppendUserTokensToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int past_length, int new_length, int max_length, cudaStream_t stream); +void 
Launch_GetLastTokens(std::span sequences, std::span last_tokens, int batch_beam_size, int current_length, int max_length, cudaStream_t stream); } // namespace cuda // TODO(aciddelgado): update cuda sequences to new paradigm @@ -68,6 +69,16 @@ void Sequences_Cuda::AppendUserTokensToSequences(gpu_span user_tokens) current_length_ += new_length; } +void Sequences_Cuda::RewindTo(size_t index) { + current_length_ = index; + assert(current_length_ >= 0); +} + +void Sequences_Cuda::GetLastTokens(gpu_span& last_tokens) { + // TODO(aciddelgado): throw error when no last tokens + cuda::Launch_GetLastTokens(sequences_, last_tokens, batch_beam_size_, current_length_, max_length_, stream_); +} + void Sequences_Cuda::AfterDeviceAppendedNextToken() { ++current_length_; diff --git a/src/sequences_cuda.cu b/src/sequences_cuda.cu index d376a4c5f..65e2d6e4b 100644 --- a/src/sequences_cuda.cu +++ b/src/sequences_cuda.cu @@ -46,5 +46,18 @@ void Launch_AppendUserTokensToSequences(std::span user_tokens, st AppendUserTokensToSequences<<<1, 1, 0, stream>>>(user_tokens.data(), sequences.data(), batch_beam_size, past_length, new_length, max_length); } +// TODO(aciddelgado): parallelize this kernel. +__global__ void GetLastTokens(const int32_t* sequences, int32_t* last_tokens, int batch_beam_size, int current_length, int max_length) { + // Get the last token of each sequence. + for (int i = 0; i < batch_beam_size; i++) { + last_tokens[i] = sequences[i * max_length + current_length - 1]; + } +} + +void Launch_GetLastTokens(std::span sequences, std::span last_tokens, int batch_beam_size, int current_length, int max_length, cudaStream_t stream) { + // Get the last token of each sequence. + GetLastTokens<<<1, 1, 0, stream>>>(sequences.data(), last_tokens.data(), batch_beam_size, current_length, max_length); +} + } // namespace cuda } // namespace Generators diff --git a/src/sequences_cuda.h b/src/sequences_cuda.h index ef1bb1eb4..1c3be27dd 100644 --- a/src/sequences_cuda.h +++ b/src/sequences_cuda.h @@ -12,7 +12,10 @@ struct Sequences_Cuda { void AppendNextTokenToSequences(std::span next_tokens); void AppendUserTokensToSequences(gpu_span user_tokens); - void SetNextTokens(gpu_span next_tokens_span); + // void SetNextTokens(gpu_span next_tokens_span); + + void GetLastTokens(gpu_span& last_tokens); + void RewindTo(size_t index); // Returns current sequence length. 
int GetSequenceLength() const; From 60c42c61a9547d5ab416f6a9831826355660279e Mon Sep 17 00:00:00 2001 From: aciddelgado Date: Thu, 3 Oct 2024 12:58:30 -0700 Subject: [PATCH 13/13] small stuff --- src/generators.cpp | 2 +- src/generators.h | 1 - src/models/decoder_only.h | 1 - src/models/kernels.cu | 4 +--- src/models/kv_cache.cpp | 4 ---- src/ort_genai_c.cpp | 8 -------- src/sequences_cuda.cpp | 7 ------- src/sequences_cuda.h | 1 - 8 files changed, 2 insertions(+), 26 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index f983922f9..b1c713d24 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -145,7 +145,7 @@ Generator::Generator(const Model& model, GeneratorParams& params) : model_{model } void Generator::AddTokens(cpu_span input_ids) { - // TODO(aciddelgado): batch_size > 1 requires full rewind + // TODO(aciddelgado): check for batch_size > 1 requires full rewind search_->SetUserTokens(input_ids); computed_logits_ = false; diff --git a/src/generators.h b/src/generators.h index 91bc89a33..4f246a83d 100644 --- a/src/generators.h +++ b/src/generators.h @@ -160,7 +160,6 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path); std::shared_ptr CreateGeneratorParams(const Model& model); std::shared_ptr CreateGeneratorParams(); // For benchmarking purposes only std::unique_ptr CreateGenerator(const Model& model, GeneratorParams& params); -// std::vector> Generate(const Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction void top_k_indices(std::span top_k, std::span inputs); diff --git a/src/models/decoder_only.h b/src/models/decoder_only.h index c6d32855f..1b81ce065 100644 --- a/src/models/decoder_only.h +++ b/src/models/decoder_only.h @@ -25,7 +25,6 @@ struct DecoderOnly_State : State { protected: void UpdateInputsOutputs(RoamingArray& next_tokens, RoamingArray next_indices, int current_length); - void UpdateInputsOutputsFromSequence(const RoamingArray& sequence, size_t next_token_length, int past_length); // what this does const DecoderOnly_Model& model_; CapturedGraphInfoPtr captured_graph_info_; diff --git a/src/models/kernels.cu b/src/models/kernels.cu index ef60193de..a697de637 100644 --- a/src/models/kernels.cu +++ b/src/models/kernels.cu @@ -99,9 +99,7 @@ __global__ void UpdateAttentionMask(T* mask_data, int total_length) { } template -void Launch_UpdateAttentionMask(T* mask_data, int new_kv_length , int total_length, bool update_static, cudaStream_t stream) { - // LEFT OFF ABOUT THE UPDATE THING AND HOW SOMETIMES WE'LL JUST WANT TO UPDATE IN PLACE AND HAVE ACTUAL 0'S AND OTHER TIMES IT'S JUST 1'S ALL THE WAY THROUGH ON A NEW TENSOR SO WE DON'T NEEDT HE OLD ONE - +void Launch_UpdateAttentionMask(T* mask_data, int new_kv_length , int total_length, bool update_static, cudaStream_t stream) { if (update_static) { int threads = std::min(256, new_kv_length); int blocks = (new_kv_length + threads - 1) / threads; diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index d3ecacb84..4cc3a3e6e 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -237,10 +237,6 @@ void KV_Cache::RewindTo(size_t index) { } } -// LEFT OFF: -// check if we even have a present tensor -// use spans and size bytes to ensure correct sizes of data -// make sure we're copying the data correctly template void KV_Cache::RewindPastTensorsTo(size_t index) { assert(index > 0 
&& shape_[2] >= index && !past_present_share_buffer_);
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 7176adff9..dd1cf0a7d 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -158,14 +158,6 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorPa
   OGA_CATCH
 }
 
-// OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) {
-//   OGA_TRY
-//   auto result = Generators::Generate(*reinterpret_cast(model), *reinterpret_cast(generator_params));
-//   *out = reinterpret_cast(std::make_unique(std::move(result)).release());
-//   return nullptr;
-//   OGA_CATCH
-// }
-
 OgaResult* OgaCreateGenerator(const OgaModel* model, OgaGeneratorParams* generator_params, OgaGenerator** out) {
   OGA_TRY
   *out = reinterpret_cast(CreateGenerator(*reinterpret_cast(model), *reinterpret_cast(generator_params)).release());
diff --git a/src/sequences_cuda.cpp b/src/sequences_cuda.cpp
index a35864c4a..ef00da7b4 100644
--- a/src/sequences_cuda.cpp
+++ b/src/sequences_cuda.cpp
@@ -12,8 +12,6 @@ void Launch_AppendUserTokensToSequences(std::span next_tokens, st
 void Launch_GetLastTokens(std::span sequences, std::span last_tokens, int batch_beam_size, int current_length, int max_length, cudaStream_t stream);
 }  // namespace cuda
 
-// TODO(aciddelgado): update cuda sequences to new paradigm
-
 Sequences_Cuda::Sequences_Cuda(int batch_size, int beam_size, int max_length, cudaStream_t stream)
     : stream_{stream},
       batch_beam_size_{batch_size * beam_size},
@@ -58,11 +56,6 @@ void Sequences_Cuda::AppendNextTokenToSequences(std::span next_to
 }
 
 void Sequences_Cuda::AppendUserTokensToSequences(gpu_span user_tokens) {
-  // if (g_log.enabled && g_log.set_next_tokens) {
-  //   auto& stream = Log("set_next_tokens");
-  //   DumpCudaSpan(stream, next_tokens_span);
-  //   stream << std::endl;
-  // }
   size_t new_length = user_tokens.size() / batch_beam_size_;
   size_t past_length = current_length_;
   cuda::Launch_AppendUserTokensToSequences(user_tokens, sequences_, batch_beam_size_, past_length, new_length, max_length_, stream_);
diff --git a/src/sequences_cuda.h b/src/sequences_cuda.h
index 1c3be27dd..dd1d14da9 100644
--- a/src/sequences_cuda.h
+++ b/src/sequences_cuda.h
@@ -12,7 +12,6 @@ struct Sequences_Cuda {
   void AppendNextTokenToSequences(std::span next_tokens);
   void AppendUserTokensToSequences(gpu_span user_tokens);
-  // void SetNextTokens(gpu_span next_tokens_span);
 
   void GetLastTokens(gpu_span& last_tokens);
   void RewindTo(size_t index);
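With the whole series applied, the rewind path ties these pieces together: Generator::RewindToLength truncates the search sequences, rewinds (or clears) the per-layer KV tensors, and resets the attention-mask and position-id state, after which decoding resumes from the retained prefix. A short sketch of the intended usage, mirroring the updated model-qa.py example (prompt text is illustrative; rewinding to a nonzero length assumes batch_size == 1, which Generator::RewindToLength enforces):

# Cache a system prompt once; its length is the rewind point.
system_tokens = tokenizer.encode("You are a helpful assistant.")
generator.add_input_tokens(system_tokens)

# One chat turn: append the user's tokens, then decode until done.
generator.add_input_tokens(tokenizer.encode("What is the capital of France?"))
while not generator.is_done():
    generator.generate_next_token()
print(tokenizer.decode(generator.get_sequence(0)))

# Drop everything after the system prompt; sequences, KV cache, and
# attention mask are all rewound, so the next turn reuses the cached prefix.
generator.rewind_to_length(len(system_tokens))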