diff --git a/src/ops/add_bias_residual_layer_norm.cc b/src/ops/add_bias_residual_layer_norm.cc
index 88a34b7eb5..a8a9e05e3d 100644
--- a/src/ops/add_bias_residual_layer_norm.cc
+++ b/src/ops/add_bias_residual_layer_norm.cc
@@ -618,12 +618,15 @@ void AddBiasResidualLayerNorm::inference_task(
   assert(task->regions.size() == regions.size());
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
-  if (bc->num_tokens == 0) {
-    return;
-  }
+  AddBiasResidualLayerNormMeta *m = *((AddBiasResidualLayerNormMeta **)task->local_args);
+  std::string op_name_without_uid = AddBiasResidualLayerNorm::get_op_name_without_uid(m);
+  std::cout << "INF " << op_name_without_uid << std::endl;
+  if (bc->num_tokens == 0) {
+    return;
+  }
   assert(regions.size() ==
          5 + (m->elementwise_affine ? (m->use_bias ? 2 : 1) : 0));
@@ -1003,6 +1006,8 @@ void AddBiasResidualLayerNorm::peft_bwd_task(
                                            ctx,
                                            runtime);
   }
+  std::string op_name_without_uid = AddBiasResidualLayerNorm::get_op_name_without_uid(m);
+  std::cout << "BWD " << op_name_without_uid << " reset_in_grad[0]: " << m->reset_input_grads[0] << " reset_in_grad[1]: " << m->reset_input_grads[1] << std::endl;
   AddBiasResidualLayerNorm::peft_bwd_kernel_wrapper(
       m, output_grad, input_grad, residual_grad, gamma);
diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc
index dd0e2bb822..cabb8b204f 100644
--- a/src/ops/argmax.cc
+++ b/src/ops/argmax.cc
@@ -392,11 +392,6 @@ InferenceResult
   GenericTensorAccessorW parent;
   int batch_size = bc->num_active_infr_tokens();
   ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);
-  // Note that we free activation allocator here since argmax is the
-  // last operator in forward
-  if (m->handle.peft_activation_allocator != nullptr) {
-    m->handle.peft_activation_allocator->free_all();
-  }
   InferenceResult ir;
   if (m->inference_debugging) {
     assert(task->index_point.get_dim() == 1);
diff --git a/src/ops/fused.cc b/src/ops/fused.cc
index 8afd61aece..4f277f2a41 100644
--- a/src/ops/fused.cc
+++ b/src/ops/fused.cc
@@ -487,6 +487,11 @@ FutureMap FusedOp::inference(FFModel const &ff,
   // so we transfer the maximum of them
   // size_t batch_config_size =
   //     std::max(sizeof(TreeVerifyBatchConfig), sizeof(BeamSearchBatchConfig));
+  printf("FUSED! INFERENCE! %i ops\n", numOperators);
+  for (int i=0; i<numOperators; i++) {
+    Op *oppp = operators[i];
+    std::cout << oppp->op_type << " " << oppp->name << std::endl;
+  }
   IndexLauncher launcher(FUSEDOP_INF_TASK_ID,
                          parallel_is,
                          TaskArgument(nullptr, 0),
diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc
index 5d52034575..2491634a76 100644
--- a/src/ops/inc_multihead_self_attention.cc
+++ b/src/ops/inc_multihead_self_attention.cc
@@ -818,12 +818,16 @@ void IncMultiHeadSelfAttention::inference_task(
   log_inc_mha.debug("BatchConfig, num_tokens: %d, num_requests: %d",
                     bc->num_tokens,
                     bc->num_active_requests());
-  if (bc->num_tokens == 0) {
-    return;
-  }
+  IncMultiHeadSelfAttentionMeta *m = *((IncMultiHeadSelfAttentionMeta **)task->local_args);
+  std::string op_name_without_uid = IncMultiHeadSelfAttention::get_op_name_without_uid(m);
+  std::cout << "INF " << op_name_without_uid << std::endl;
+
+  if (bc->num_tokens == 0) {
+    return;
+  }
   assert(((*m->qkv_bias || *m->final_bias) ?
           regions.size() == 4 : regions.size() == 3));
@@ -876,6 +880,37 @@ void IncMultiHeadSelfAttention::inference_task(
   }
 }
 
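+// Debugging helper (a second copy lives in residual_rms_norm.cc): reads
+// `size` raw elements of type DT from the binary file at `filepath` into a
+// host buffer, checks that the file held exactly sizeof(DT) * size bytes,
+// then copies the buffer to the device pointer `ptr`.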
+template <typename DT>
+void load_tensor_from_file(DT *ptr, size_t size, std::string filepath) {
+  std::ifstream in(filepath, std::ios::in | std::ios::binary);
+  if (!in.good()) {
+    std::cout << "Could not open file: " << filepath << std::endl;
+  }
+  assert(in.good() && "incorrect weight file path");
+  std::vector<DT> host_array(size);
+  size_t loaded_data_size = sizeof(DT) * size;
+  in.seekg(0, in.end);
+  in.seekg(0, in.beg);
+  in.read((char *)host_array.data(), loaded_data_size);
+
+  size_t in_get_size = in.gcount();
+  if (in_get_size != loaded_data_size) {
+    std::cout << "load weight data error " << in_get_size << ", "
+              << loaded_data_size << ", " << sizeof(DT) << std::endl;
+    assert(false);
+  }
+  assert(size == host_array.size());
+
+  copy_tensor_host_to_dev(ptr, host_array.data(), size);
+
+  // // normal
+  // long data_index = 0;
+  // for (auto v : host_array) {
+  //   ptr[data_index++] = v;
+  // }
+  in.close();
+}
+
 FutureMap IncMultiHeadSelfAttention::peft_bwd(
     FFModel const &ff,
     BatchConfigFuture const &bc,
@@ -992,6 +1027,17 @@ void IncMultiHeadSelfAttention::peft_bwd_task(
 
   assert(task->index_point.get_dim() == 1);
 
+  std::string op_name_without_uid = IncMultiHeadSelfAttention::get_op_name_without_uid(m);
+  std::cout << "BWD " << op_name_without_uid << std::endl;
+
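+  // Debugging override: for this one layer, replace the incoming output
+  // gradient with the reference tensor dumped by hf_finetune.py, so the
+  // backward steps below start from known-good values. The absolute path
+  // and the (volume/128)*24 element count are specific to the model and
+  // machine used in this debugging session.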
+  if (op_name_without_uid == "layers_11_attention") {
+    load_tensor_from_file(
+        output_grad.get_float_ptr(),
+        (output_grad.domain.get_volume()/128)*24,
+        "/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.self_attn.o_proj.go_0.flexflow");
+  }
+
   IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper(
       m,
       bc,
diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index 452a8c09f6..5d6f2bc186 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -641,6 +641,8 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
                                  m->hidden_size);
   }
   if (*m->apply_rotary_embedding) {
+    printf("ROTARY EMBEDDING: num_tokens: %i, q_array_size: %i, m->hidden_size: %i\n",
+           num_tokens, q_array_size, m->hidden_size);
     /*q&k*/
     parallelism = num_tokens * m->hidden_size;
     apply_rotary_embedding_hf<<<GET_BLOCKS(parallelism),
@@ ... @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
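+  // Strip the trailing "_<uid>" from the operator name so that dump
+  // filenames are stable across runs, then build the common filename
+  // prefix shared by all of the save_tensor calls below.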
+  std::string op_name_without_uid = std::string(m->op_name);
+  size_t last_underscore = op_name_without_uid.length() - 1;
+  for (int i = op_name_without_uid.length() - 1; i > 0; i--) {
+    if (!(std::isdigit(m->op_name[i]) || m->op_name[i] == '_')) {
+      break;
+    } else if (m->op_name[i] == '_') {
+      last_underscore = i;
+    }
+  }
+  op_name_without_uid.erase(last_underscore);
+
+  std::string base_filepath =
+      "./inference_tensors/model_" + std::to_string(m->layer_guid.model_id) +
+      "_bwd-step_" + std::to_string(m->bwd_step) +
+      "_layer-num_" + std::to_string(m->layer_guid.transformer_layer_id) +
+      "_layer-name_" + op_name_without_uid + "_shard-id_" +
+      std::to_string(shard_id);
+
+
   for (int i = 0; i < bc->max_requests_per_batch(); i++) {
     if (bc->request_completed[i]) {
       continue;
@@ -995,6 +1017,12 @@
                              ldc,
                              compute_type,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      if (m->inference_debugging) {
+        // save result to file for checking
+        std::string filename = base_filepath + "_o_proj_in_grad";
+        std::cout << "FILENAME: " << filename << std::endl;
+        save_tensor(C, m_*n_, filename.c_str());
+      }
     }
     // Step 2: compute gradients w.r.t. value
     {
@@ -1046,6 +1074,15 @@
                                          m->num_q_heads,
                                          compute_type,
                                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      // save result to file for checking
+      if (m->inference_debugging) {
+        std::string filename = base_filepath + "_v_proj_in_grad";
+        std::cout << "FILENAME: " << filename << std::endl;
+        save_tensor(C, m_*n_*m->num_q_heads, filename.c_str());
+        std::string filename2 = base_filepath + "_qk_prods_softmax";
+        std::cout << "FILENAME: " << filename2 << std::endl;
+        save_tensor(A, m_*k_*m->num_q_heads, filename2.c_str());
+      }
     }
     // Step 3: compute gradients w.r.t. the qk_prods_softmax tensor
     {
@@ -1094,6 +1131,14 @@
                                          m->num_q_heads,
                                          compute_type,
                                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      if (m->inference_debugging) {
+        std::string filename4 = base_filepath + "_qk_prods_softmax_grad";
+        std::cout << "FILENAME: " << filename4 << std::endl;
+        save_tensor(C, num_tokens * num_tokens * m->num_q_heads, filename4.c_str());
+        std::string filename5 = base_filepath + "_vcache";
+        std::cout << "FILENAME: " << filename5 << std::endl;
+        save_tensor(B, m->vProjSize * m->num_q_heads * num_tokens, filename5.c_str());
+      }
     }
     // Step 4: softmax backpropagation
     {
@@ -1120,6 +1165,14 @@
                                     &beta,
                                     m->qk_tensor,
                                     m->qk_prods));
+
+      if (m->inference_debugging) {
+        DT *C = static_cast<DT *>(m->qk_prods);
+        std::string filename6 = base_filepath + "_qk_prods_softmax_grad_in";
+        std::cout << "FILENAME: " << filename6 << std::endl;
+        save_tensor(C, num_tokens * num_tokens * m->num_q_heads, filename6.c_str());
+      }
+
       // TODO: fill all elements above diagonal to force causal attention
       size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2;
       if (entries_above_diagonal > 0) {
@@ -1135,6 +1188,12 @@
             entries_above_diagonal,
             DT(0.0f));
       }
+      if (m->inference_debugging) {
+        DT *C = static_cast<DT *>(m->qk_prods);
+        std::string filename7 = base_filepath + "_qk_prods_softmax_grad_in_masked";
+        std::cout << "FILENAME: " << filename7 << std::endl;
+        save_tensor(C, num_tokens * num_tokens * m->num_q_heads, filename7.c_str());
+      }
     }
     // Step 5: compute gradients w.r.t. key
     {
@@ -1189,6 +1248,14 @@
                                          m->num_q_heads,
                                          compute_type,
                                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      if (m->inference_debugging) {
+        std::string filename8 = base_filepath + "_query_activation";
+        std::cout << "FILENAME: " << filename8 << std::endl;
+        save_tensor(B, m->qProjSize * m->num_q_heads * num_tokens, filename8.c_str());
+        std::string filename9 = base_filepath + "_devkproj_pre";
+        std::cout << "FILENAME: " << filename9 << std::endl;
+        save_tensor(C, num_tokens * (m->qProjSize * m->num_q_heads), filename9.c_str());
+      }
     }
     // Step 6: compute gradients w.r.t query
     {
@@ -1208,7 +1275,7 @@
       // after transposition & striding
       int m_ = num_tokens; // num_new_tokens
       int n_ = m->qProjSize;
-      int k_ = num_tokens;
+      int k_ = num_tokens;
       // before transposition and striding
       int lda = num_tokens; // num_new_tokens
       int ldb = m->qProjSize * m->num_q_heads;
@@ -1239,6 +1306,47 @@
                                          m->num_q_heads,
                                          compute_type,
                                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      if (m->inference_debugging) {
+        std::string filename3 = base_filepath + "_devQKVPRojArray_pre";
+        std::cout << "FILENAME: " << filename3 << std::endl;
+        save_tensor(C, num_tokens * m->qProjSize * m->num_q_heads * 3, filename3.c_str());
+      }
+    }
+
+    // Compute rotary embeddings bwd
+    {
+      if (*m->apply_rotary_embedding) {
+        assert(m->hidden_size == m->qProjSize * m->num_q_heads);
+        assert(m->qProjSize == m->kProjSize);
+        printf("ROTARY EMBEDDING bwd: num_tokens: %i, m->hidden_size: %i\n",
+               num_tokens, m->hidden_size);
+        /*q&k*/
+        int parallelism = num_tokens * m->hidden_size;
+        DT *A = static_cast<DT *>(m->devQKVProjArray);
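+        // Undo the RoPE rotation on the Q/K gradient slabs of
+        // devQKVProjArray (the backward counterpart of
+        // apply_rotary_embedding_hf above).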
+        apply_rotary_embedding_bwd<<<GET_BLOCKS(parallelism),
+                                     min(CUDA_NUM_THREADS, parallelism),
+                                     0,
+                                     stream>>>(A,
+                                               m->complex_input,
+                                               m->token_infos,
+                                               m->qProjSize,
+                                               num_tokens,
+                                               m->hidden_size);
+        DT *C = static_cast<DT *>(m->devQKVProjArray);
+        if (m->inference_debugging) {
+          std::string filename3 = base_filepath + "_devQKVPRojArray";
+          std::cout << "FILENAME: " << filename3 << std::endl;
+          save_tensor(C, num_tokens * m->qProjSize * m->num_q_heads * 3, filename3.c_str());
+        }
+      }
+
+      // matrix C: gradients for key (saved as part of m->devQKVProjArray)
+      // matrix C's layout: [num_tokens, qProjsize * num_heads, 3]
+      DT *C = static_cast<DT *>(m->devQKVProjArray) +
+              num_tokens * (m->qProjSize * m->num_q_heads); // skip over regions reserved for Q gradients
+      if (m->inference_debugging) {
+        std::string filename9 = base_filepath + "_devkproj";
+        std::cout << "FILENAME: " << filename9 << std::endl;
+        save_tensor(C, num_tokens * (m->qProjSize * m->num_q_heads), filename9.c_str());
+      }
     }
     // Step 7: perform rotary position embeddings (RoPE) bwd
     {
@@ -1300,6 +1408,11 @@
                              ldc,
                              compute_type,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+      if (m->inference_debugging) {
+        std::string filename12 = base_filepath + "_attn_final_grad_in";
+        std::cout << "FILENAME: " << filename12 << std::endl;
+        save_tensor(C, num_tokens * m->qSize, filename12.c_str());
+      }
     }
   }
 }
diff --git a/src/ops/kernels/softmax.cu b/src/ops/kernels/softmax.cu
index 271a291b09..1624c0458d 100644
--- a/src/ops/kernels/softmax.cu
+++ b/src/ops/kernels/softmax.cu
@@ -290,10 +290,11 @@ __global__ void sparse_categorical_crossentropy_loss_peft_backward(
     int num_tokens,
     int num_classes) {
   CUDA_KERNEL_LOOP(i, num_tokens * num_classes) {
-    input_grad[i] = output_grad[i];
-    if (i % num_classes == token_ids[i / num_classes]) {
-      input_grad[i] -= 1.0f;
-    }
+    input_grad[i] = 0.5;
+    // input_grad[i] = output_grad[i];
+    // if (i % num_classes == token_ids[i / num_classes]) {
+    //   input_grad[i] -= 1.0f;
+    // }
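+    // NOTE: constant stub for debugging; the commented-out lines above are
+    // the real softmax cross-entropy gradient (probability minus one-hot).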
   }
 }
 
@@ -345,14 +346,14 @@ void peft_bwd_kernel(SoftmaxMeta const *m,
           num_bwd_tokens,
           num_classes);
       // scale
-      scale_kernel<<<GET_BLOCKS(num_bwd_tokens * num_classes),
-                     CUDA_NUM_THREADS,
-                     0,
-                     stream>>>(input_grad_ptr +
-                                   tokens_previous_requests * num_classes,
-                               num_bwd_tokens * num_classes,
-                               DT(0.0),
-                               scale_factor);
+      // scale_kernel<<<GET_BLOCKS(num_bwd_tokens * num_classes),
+      //                CUDA_NUM_THREADS,
+      //                0,
+      //                stream>>>(input_grad_ptr +
+      //                              tokens_previous_requests * num_classes,
+      //                          num_bwd_tokens * num_classes,
+      //                          DT(0.0),
+      //                          scale_factor);
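+      // Scaling is disabled here together with the loss-gradient stub in
+      // sparse_categorical_crossentropy_loss_peft_backward above.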
 
       tokens_previous_requests += num_bwd_tokens;
     }
diff --git a/src/ops/linear.cc b/src/ops/linear.cc
index 15789ae2e9..595b8d24e9 100644
--- a/src/ops/linear.cc
+++ b/src/ops/linear.cc
@@ -621,6 +621,8 @@ void Linear::inference_task(Task const *task,
       ctx, task->regions[0].region.get_index_space());
   LinearMeta *m = *((LinearMeta **)task->local_args);
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
+  std::string op_name_without_uid = Linear::get_op_name_without_uid(m);
+  printf("INF %s\n", op_name_without_uid.c_str());
   if (bc->num_tokens == 0) {
     return;
   }
@@ -757,6 +759,9 @@ void Linear::peft_bwd_task(Task const *task,
   int in_dim = input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1;
   int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
 
+  std::string op_name_without_uid = Linear::get_op_name_without_uid(m);
+  std::cout << "BWD " << op_name_without_uid << std::endl;
+
   int num_infr_tokens = bc->num_active_infr_tokens();
   int num_peft_tokens = bc->num_active_peft_tokens();
   if (m->inference_debugging) {
diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc
index e39b444af4..fb13dc99cb 100644
--- a/src/ops/lora_linear.cc
+++ b/src/ops/lora_linear.cc
@@ -449,6 +449,8 @@ void LoraLinear::inference_task(Task const *task,
                                 Context ctx,
                                 Runtime *runtime) {
   LoraLinearMeta *m = *((LoraLinearMeta **)task->local_args);
+  std::string op_name_without_uid = LoraLinear::get_op_name_without_uid(m);
+  std::cout << "INF " << op_name_without_uid << std::endl;
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
   if (bc->num_active_tokens() == 0) {
     return;
diff --git a/src/ops/residual_layer_norm.cc b/src/ops/residual_layer_norm.cc
index 8563c299ab..61d021b6c0 100644
--- a/src/ops/residual_layer_norm.cc
+++ b/src/ops/residual_layer_norm.cc
@@ -823,6 +823,9 @@ void ResidualLayerNorm::peft_bwd_task(
                                      ctx,
                                      runtime);
   }
+  std::string op_name_without_uid = ResidualLayerNorm::get_op_name_without_uid(m);
+  std::cout << "BWD " << op_name_without_uid << " reset_in_grad[0]: " << m->reset_input_grads[0] << " reset_in_grad[1]: " << m->reset_input_grads[1] << std::endl;
+
   ResidualLayerNorm::peft_bwd_kernel_wrapper(
       m, output_grad, input_grad, residual1_grad, residual2_grad, gamma);
 
@@ -951,7 +954,10 @@ void ResidualLayerNorm::inference_task(
   assert(task->regions.size() == regions.size());
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
   ResidualLayerNormMeta *m = *((ResidualLayerNormMeta **)task->local_args);
+  std::string op_name_without_uid = ResidualLayerNorm::get_op_name_without_uid(m);
+  std::cout << "INF " << op_name_without_uid << std::endl;
   if (bc->num_tokens == 0) {
+    printf("Zero tokens\n");
     return;
   }
diff --git a/src/ops/residual_layer_norm.cu b/src/ops/residual_layer_norm.cu
index 1f87949234..329c996839 100644
--- a/src/ops/residual_layer_norm.cu
+++ b/src/ops/residual_layer_norm.cu
@@ -254,6 +254,7 @@ void ResidualLayerNorm::inference_kernel_wrapper(
       MemoryAllocator *allocator = m->handle.peft_activation_allocator;
       m->input_activation = allocator->allocate_instance_untyped(
           data_type_size(m->input_type[0]) * num_peft_tokens * in_dim);
+      printf("Allocating input_activation (%p) of size: %zu*%i*%i=%zu for %s...\n", m->input_activation, data_type_size(m->input_type[0]), num_peft_tokens, in_dim, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, m->op_name);
       // copy input activation
       if (m->input_type[0] == DT_FLOAT) {
         checkCUDA(cudaMemcpyAsync(
diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc
index c2fbe11544..ff72b2273a 100644
--- a/src/ops/residual_rms_norm.cc
+++ b/src/ops/residual_rms_norm.cc
@@ -673,6 +673,37 @@ Legion::FutureMap
   return runtime->execute_index_space(ctx, launcher);
 }
 
+template <typename DT>
+void load_tensor_from_file(DT *ptr, size_t size, std::string filepath) {
+  std::ifstream in(filepath, std::ios::in | std::ios::binary);
+  if (!in.good()) {
+    std::cout << "Could not open file: " << filepath << std::endl;
+  }
+  assert(in.good() && "incorrect weight file path");
+  std::vector<DT> host_array(size);
+  size_t loaded_data_size = sizeof(DT) * size;
+  in.seekg(0, in.end);
+  in.seekg(0, in.beg);
+  in.read((char *)host_array.data(), loaded_data_size);
+
+  size_t in_get_size = in.gcount();
+  if (in_get_size != loaded_data_size) {
+    std::cout << "load weight data error " << in_get_size << ", "
+              << loaded_data_size << ", " << sizeof(DT) << std::endl;
+    assert(false);
+  }
+  assert(size == host_array.size());
+
+  copy_tensor_host_to_dev(ptr, host_array.data(), size);
+
+  // // normal
+  // long data_index = 0;
+  // for (auto v : host_array) {
+  //   ptr[data_index++] = v;
+  // }
+  in.close();
+}
+
 /*
   regions[0](I): RMS output_grad
   regions[1](I/O): Residual input 0 grad
@@ -710,7 +741,46 @@ void ResidualRMSNorm::peft_bwd_task(Task const *task,
       m->weight_type[0], regions[3], task->regions[3], FID_DATA, ctx, runtime);
   peft_bwd_kernel_wrapper(
       m, bc, output_grad, residual_input0_grad, residual_input1_grad, weight);
-
+
+  // get name
+  std::string op_name_without_uid = ResidualRMSNorm::get_op_name_without_uid(m);
+  std::cout << "BWD " << op_name_without_uid << " reset_in_grad[0]: " << m->reset_input_grads[0] << " reset_in_grad[1]: " << m->reset_input_grads[1] << std::endl;
+  // print shape
+  int numdims = residual_input0_grad.domain.get_dim();
+  std::cout << "in grad dims: ";
+  for (int i=0; i<numdims; i++) {
+    std::cout << residual_input0_grad.domain.hi()[i] - residual_input0_grad.domain.lo()[i] + 1 << ", ";
+  }
+  std::cout << std::endl;
+
   if (m->inference_debugging) {
     assert(task->index_point.get_dim() == 1);
     int shard_id = task->index_point.point_data[0];
diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc
index 1d062b552b..335c7f99d4 100644
--- a/src/ops/softmax.cc
+++ b/src/ops/softmax.cc
@@ -374,12 +374,16 @@ void Softmax::inference_task(Task const *task,
   assert(regions.size() == 3);
   assert(task->regions.size() == 3);
   BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
+  SoftmaxMeta *m = *((SoftmaxMeta **)task->local_args);
+
+  std::string op_name_without_uid = Softmax::get_op_name_without_uid(m);
+  std::cout << "INF " << op_name_without_uid << std::endl;
   if (bc->num_tokens == 0) {
     return;
   }
   Domain in_domain = runtime->get_index_space_domain(
       ctx, task->regions[0].region.get_index_space());
-  SoftmaxMeta *m = *((SoftmaxMeta **)task->local_args);
+
   GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
       m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
   GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(
@@ -439,6 +443,7 @@ void Softmax::peft_bwd_task(Task const *task,
                             std::vector<PhysicalRegion> const &regions,
                             Context ctx,
                             Runtime *runtime) {
+  printf("BWD softmax\n");
   assert(task->regions.size() == regions.size());
   assert(regions.size() == 2);
   assert(task->regions.size() == 2);
diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py
index 7836633b30..818e0b9085 100644
--- a/tests/peft/hf_finetune.py
+++ b/tests/peft/hf_finetune.py
@@ -72,6 +72,8 @@ def peft_backward_hook(module, grad_input, grad_output):
             print("\t", go.shape)
             print(f"\t\tSaving to {dst_filepath}")
             torch.save(go, dst_filepath)
+            if dst_filepath == "./hf_peft_tensors/bwd_step_0_layers.11.self_attn.o_proj.go_0":
+                go.detach().cpu().numpy().tofile(f"{dst_filepath}.flexflow")
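+                # .numpy().tofile() writes just the raw row-major buffer with
+                # no header, the layout load_tensor_from_file expects on the
+                # FlexFlow side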
         else:
             print(go)
     print("Backward GRAD Input:")
@@ -81,6 +83,8 @@ def peft_backward_hook(module, grad_input, grad_output):
             print("\t", gi.shape)
             print(f"\t\tSaving to {dst_filepath}")
             torch.save(gi, dst_filepath)
+            if dst_filepath == "./hf_peft_tensors/bwd_step_0_layers.11.post_attention_layernorm.gi_0" or dst_filepath == "./hf_peft_tensors/bwd_step_0_norm.gi_0":
+                gi.detach().cpu().numpy().tofile(f"{dst_filepath}.flexflow")
         else:
             print(gi)
 
@@ -225,6 +229,8 @@ def main():
             torch.save(params, f"./hf_peft_tensors/{name}")
         if "lm_head" in name or "norm" in name:
            torch.save(params, f"./hf_peft_tensors/{name}")
+        if "down_proj" in name or "self_attn" in name:
+            torch.save(params, f"./hf_peft_tensors/{name}")
 
     # Load fine-tuning dataset
     data = load_dataset("Abirate/english_quotes")