Skip to content

Commit

Permalink
[BE] add integration test for the generation script
Browse files Browse the repository at this point in the history
ghstack-source-id: 42013cd4bc32a60e5c79c4b3ed976584eb17efad
Pull Request resolved: #741
  • Loading branch information
tianyu-l committed Dec 17, 2024
1 parent 915da67 commit 896b7ca
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 23 deletions.
21 changes: 12 additions & 9 deletions scripts/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh):
},
)

for layer_id, transformer_block in model.layers.items():
for _, transformer_block in model.layers.items():
layer_plan = {
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
Expand Down Expand Up @@ -81,6 +81,7 @@ def test_generate(
batch_size: int = 1,
top_k: Optional[int] = None,
seed: Optional[int] = None,
deterministic: bool = False,
):
init_logger()
color = utils.Color
Expand All @@ -95,13 +96,6 @@ def test_generate(
"The input prompt is empty, model will respond from a empty sequence."
)

utils.set_determinism(seed)

if seed is None:
logger.info("Deterministic sampling off")
else:
logger.info(f"Deterministic sampling on. Using seed: {seed}")

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device(f"{device_type}:{local_rank}")
Expand All @@ -128,6 +122,7 @@ def test_generate(
logger.info(f"Init model on init_device: {init_device}")
model = model_cls.from_model_args(model_config)

world_mesh = None
# Init distributed env
if world_size > 1:
utils.init_distributed(config)
Expand All @@ -147,6 +142,8 @@ def test_generate(
# sequences would require https://github.com/pytorch/torchtitan/pull/686
apply_tp_minus_sp(model, world_mesh["tp"])

utils.set_determinism(world_mesh, device, seed, deterministic)

# materalize model
model.to_empty(device=device_type)
model.eval()
Expand Down Expand Up @@ -276,8 +273,13 @@ def test_generate(
"--top_k", type=int, help="Prune to select from top_k probabilities. Optional"
)
parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
parser.add_argument(
"--deterministic",
action="store_true",
help="Use deterministic algorithms wherever possible, may be slower",
)

parser.add_argument("--prompt", type=str, help="Input prompt")
parser.add_argument("--prompt", type=str, default="", help="Input prompt")

parser.add_argument(
"--out",
Expand All @@ -297,6 +299,7 @@ def test_generate(
batch_size=args.batch_size,
top_k=args.top_k,
seed=args.seed,
deterministic=args.deterministic,
)

if torch.distributed.is_initialized():
Expand Down
37 changes: 30 additions & 7 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,19 @@ def build_test_list():
"fsdp+tp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--training.enable_cpu_offload True",
"--optimizer.early_step_in_backward",
],
],
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
),
OverrideDefinitions(
[
[
Expand All @@ -376,14 +389,14 @@ def build_test_list():
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--training.enable_cpu_offload True",
"--optimizer.early_step_in_backward",
],
[
# placeholder for the generation script's generate step
],
],
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
"Generation script test",
"test_generate",
ngpu=2,
),
]
return integration_tests_flavors
Expand All @@ -406,7 +419,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
model_flavor_arg = f"--model.flavor {test_flavor.model_flavor}"
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))

for override_arg in test_flavor.override_args:
for idx, override_arg in enumerate(test_flavor.override_args):
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
# dump compile trace for debugging purpose
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
Expand All @@ -422,6 +435,16 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
logger.info(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)

# save checkpoint (idx == 0) and load it for generation (idx == 1)
if test_name == "test_generate" and idx == 1:
cmd = (
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 "
"PROMPT='What is the meaning of life?' "
f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json"
)

result = _run_cmd(cmd)
logger.info(result.stdout)
if result.returncode != 0:
Expand Down
18 changes: 12 additions & 6 deletions torchtitan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger


Expand Down Expand Up @@ -54,9 +53,10 @@ def _warn_overwrite_env(env, val):


def set_determinism(
world_mesh: DeviceMesh,
world_mesh: Optional[DeviceMesh],
device: torch.device,
job_config: JobConfig,
seed: Optional[int] = None,
deterministic: bool = False,
) -> None:
"""
Set the same DTensor manual seed for all ranks within the same DTensor SPMD group, but different
Expand All @@ -67,18 +67,24 @@ def set_determinism(
Set Determinism flags for increased reproducibility with loss of performance.
"""
if job_config.training.deterministic:
logger.info("Deterministic training enabled (expect perf degradation).")
if deterministic:
logger.info("Deterministic algorithm enabled (expect perf degradation).")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# env var for deterministic CuBLAS
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

if not world_mesh:
if seed is not None:
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
logger.debug(f"Single-process job using seed: {seed}")
return

# to ensure we can control which ranks have same or different seeds, all ranks agree on a starting seed.
# if user provides one, we use this. Otherwise rank 0 rolls the dice and everyone else uses that.
seed = job_config.training.seed
if seed is None:
# Extract the seed for torch's main generator on rank 0 and standardizes on using that to build
# seeds for unique SPMD groups
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def main(job_config: JobConfig):
pp_mesh = world_mesh["pp"]

# Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
utils.set_determinism(world_mesh, device, job_config)
utils.set_determinism(
world_mesh, device, job_config.training.seed, job_config.training.deterministic
)
model_name = job_config.model.name

# build tokenizer
Expand Down

0 comments on commit 896b7ca

Please sign in to comment.