Support torch_bfloat16 (#17)
* Support torch_bfloat16

* Cast scores properly

* Fix imports
ljvmiranda921 authored Aug 5, 2024
1 parent b57e690 commit 41a8292
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions scripts/run_rewardbench.py
@@ -39,6 +39,24 @@
 
 from scripts.utils import load_multilingual_eval_dataset
 
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+
+def torch_dtype_mapping(dtype_str):
+    """
+    Helper function for argparse to map string to torch dtype.
+    """
+    dtype_map = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+    }
+    if dtype_str not in dtype_map:
+        raise argparse.ArgumentTypeError(f"Invalid torch dtype: {dtype_str}")
+    return dtype_map[dtype_str]
+
 
 def main():
     parser = argparse.ArgumentParser(description="Evaluate a reward model.")
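
The tf32 toggles at the top of this hunk are independent of the new dtype flag; they enable TensorFloat-32 matmuls on Ampere-and-newer GPUs. As for the helper: because it raises argparse.ArgumentTypeError, it could also be handed straight to argparse's type= hook so invalid values fail at parse time; the commit instead keeps choices= on plain strings and converts after parsing (next hunk). A minimal, self-contained sketch of that type= alternative:

import argparse

import torch


def torch_dtype_mapping(dtype_str):
    """Map a CLI string to a torch dtype (same helper as in the hunk above)."""
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
        "float64": torch.float64,
    }
    if dtype_str not in dtype_map:
        raise argparse.ArgumentTypeError(f"Invalid torch dtype: {dtype_str}")
    return dtype_map[dtype_str]


parser = argparse.ArgumentParser()
# type= runs the converter during parsing; argparse also applies it to string defaults
parser.add_argument("--torch_dtype", type=torch_dtype_mapping, default="float16")
args = parser.parse_args(["--torch_dtype", "bfloat16"])
assert args.torch_dtype == torch.bfloat16
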
@@ -62,8 +80,10 @@ def main():
     parser.add_argument("--output_dir", type=str, default="results/", help="the output directory to save results")
     parser.add_argument("--save_all", action="store_true", default=False, help="save all results (include scores per instance)")
     parser.add_argument("--force_truncation", action="store_true", default=False, help="force truncation (if model errors)")
+    parser.add_argument("--torch_dtype", type=str, default="float16", choices=["float16", "bfloat16", "float32", "float64"], help="set PyTorch dtype (default: float16)")
     # fmt: on
     args = parser.parse_args()
+    args.torch_dtype = torch_dtype_mapping(args.torch_dtype)
 
     ###############
     # Setup logging
@@ -111,6 +131,14 @@ def main():
         config = MODEL_CONFIGS["default"]
     logger.info(f"Using reward model config: {config}")
 
+    torch_dtype = config.get("torch_dtype", None)
+    if torch_dtype is None:
+        # if datatype is bfloat16, then manually turn off quantization (done with bitsandbytes)
+        if args.torch_dtype == torch.bfloat16:
+            quantized = False
+            logger.info("Disabling quantization for bfloat16 datatype")
+        torch_dtype = args.torch_dtype
+
     # Default entries
     # "model_builder": AutoModelForSequenceClassification.from_pretrained,
     # "pipeline_builder": pipeline,
@@ -126,6 +154,7 @@ def main():
         or ("Llama3" in args.model)
         or ("Llama-3" in args.model)
         or ("LLaMA3" in args.model)
+        or ("llama3" in args.model)
         or args.not_quantized
     ):
         quantized = False
@@ -184,7 +213,7 @@ def main():
         model_kwargs = {
             "load_in_8bit": True,
             "device_map": "auto",
-            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
         }
         model = model_builder(
             args.model,
@@ -247,11 +276,14 @@ def main():
         model_kwargs = {
             "load_in_8bit": True,
             "device_map": {"": current_device},
-            "torch_dtype": torch.float16 if torch.cuda.is_available() else None,
+            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
         }
     else:
         # note, device map auto does not work for quantized models
-        model_kwargs = {"device_map": "auto"}
+        model_kwargs = {
+            "device_map": "auto",
+            "torch_dtype": torch_dtype,
+        }
 
     model = model_builder(args.model, **model_kwargs, trust_remote_code=args.trust_remote_code)
     reward_pipe = pipeline_builder(
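
Both loading paths above now receive the resolved dtype; before this commit the non-quantized branch set no torch_dtype at all, so transformers would load weights in float32, its default. A condensed sketch of the kwargs each branch produces (current_device stands in for the accelerator index, as in the script):

import torch


def build_model_kwargs(quantized: bool, torch_dtype: torch.dtype, current_device: int = 0) -> dict:
    if quantized:
        # 8-bit bitsandbytes load, pinned to a single device
        return {
            "load_in_8bit": True,
            "device_map": {"": current_device},
            "torch_dtype": torch_dtype if torch.cuda.is_available() else None,
        }
    # non-quantized: shard automatically and pass the dtype straight through
    return {"device_map": "auto", "torch_dtype": torch_dtype}


print(build_model_kwargs(quantized=False, torch_dtype=torch.bfloat16))
# -> {'device_map': 'auto', 'torch_dtype': torch.bfloat16}
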
@@ -306,8 +338,9 @@ def main():
             score_rejected_batch = [result["score"] for result in rewards_rejected]
         # for classes that directly output scores (custom code)
         else:
-            score_chosen_batch = rewards_chosen.cpu().numpy().tolist()
-            score_rejected_batch = rewards_rejected.cpu().numpy().tolist()
+            # Cast to float in case of bfloat16
+            score_chosen_batch = rewards_chosen.float().cpu().numpy().tolist()
+            score_rejected_batch = rewards_rejected.float().cpu().numpy().tolist()
 
         # log results
         [
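The .float() casts in the last hunk are the "Cast scores properly" part of the commit message: NumPy has no bfloat16 dtype, so calling .numpy() on a bf16 tensor raises a TypeError. A minimal repro of the failure and the fix:

import torch

scores = torch.tensor([0.25, -1.5], dtype=torch.bfloat16)
# scores.cpu().numpy()  # raises TypeError: Got unsupported ScalarType BFloat16
safe = scores.float().cpu().numpy().tolist()  # upcast to float32 first, then convert
print(safe)  # [0.25, -1.5]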
