Commit 68e2218
fix pylint and fix uts.
charles9304 committed Dec 26, 2024
1 parent b4ad394 commit 68e2218
Showing 12 changed files with 13 additions and 152 deletions.
41 changes: 1 addition & 40 deletions chatlearn/models/base_module.py
@@ -39,22 +39,6 @@
from chatlearn.launcher import dlc_utils


from pynvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetComputeRunningProcesses, nvmlShutdown

def get_gpu_usage():
nvmlInit()
device_count = nvmlDeviceGetCount()
gpu_usage = {}
for i in range(device_count):
handle = nvmlDeviceGetHandleByIndex(i)
processes = nvmlDeviceGetComputeRunningProcesses(handle)
for process in processes:
if process.pid == os.getpid():
gpu_usage[i] = process.pid
nvmlShutdown()
return gpu_usage


class BaseModule:
"""BaseModule is the base class for Base models.
@@ -146,7 +130,6 @@ def __init__(self, name, args=None, replica_id=0):
self._sync_buffer = defaultdict(list)
self._expert_sync_buffer = {}
self._synchronizer = None
self.col_groups = {}

def get_sync_buffer(self):
return self._sync_buffer
@@ -502,8 +485,7 @@ def setup_collective_group(self, rank, world_size, backend, group_name):
"""
self._group_names.append(group_name)
self._world_size = world_size
# breakpoint()
self.col_groups[group_name] = col.init_collective_group(
col.init_collective_group(
world_size, rank, backend=backend, group_name=group_name)

def _destroy_collective_group(self, group_name):
@@ -635,7 +617,6 @@ def _set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_t
if self._synchronizer is not None:
params_to_sync_list = self._synchronizer.transform_parameters(params_to_sync_list)
parameters_to_sync[pipe_stage] = params_to_sync_list
print(f"{self} params_to_sync_list: {len(params_to_sync_list)} {[name for name,_ in params_to_sync_list]}")
return parameters_to_sync

def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
@@ -790,39 +771,19 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
:meta private:
"""
tensors = []
# breakpoint()
print(f"debug from {src_rank} to {rank} self._parameters_to_sync[{pipe_stage}]: {len(self._parameters_to_sync[pipe_stage])} {self._parameters_to_sync[pipe_stage][0][0]}", flush=True)
print(f"{self} get_visible_gpus: {self.get_visible_gpus()}")
for name, param in self._parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer and \
(self._synchronizer and self._synchronizer.is_parameter_changed):
tensors.append(self._expert_sync_buffer[name])
else:
tensors.append(param.data)
# if src_rank != rank:
# print(f"self.worker.device: {self.worker.device} current_device: {torch.cuda.current_device()} tensors[0].device: {tensors[0].device}")

# breakpoint()
print(f"debug from {src_rank} to {rank} tensors: {len(tensors)} {[tensors[0].shape]}", flush=True)
print(f"debug get_gpu_usage {self}: {get_gpu_usage()}", flush=True)
tensor_device = tensors[0].device
print(f"debug tensor is on device {self}: {tensor_device}")


assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
tensor_changed = rank != src_rank

print(f"debug from {src_rank} to {rank} dense_buckets: {len(dense_buckets)} sparse_bucket: {len(sparse_bucket)}", flush=True)
# breakpoint()
print(f"self.col_groups: {self.col_groups}")
for bucket in dense_buckets:
# for ele in bucket:
# breakpoint()
# col.broadcast(ele, src_rank, group_name)
# coalesced_comm_dense([ele], col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)

coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)

for param in sparse_bucket:
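
Note: the broadcast path kept by this commit buckets dense tensors and broadcasts each bucket over the Ray collective group created by col.init_collective_group above. The following is a minimal sketch of that pattern; the flatten/broadcast/unflatten shown here is an assumption about what chatlearn's bucket_tensors / coalesced_comm_dense helpers do, not their actual implementation.

import torch
# torch._utils helpers are private but widely used for bucketed comms.
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import ray.util.collective as col

def broadcast_bucket(bucket, src_rank, group_name, tensor_changed):
    # Flatten the bucket into one contiguous buffer so a single
    # collective call moves every tensor in the bucket.
    flat = _flatten_dense_tensors(bucket)
    col.broadcast(flat, src_rank, group_name=group_name)
    if tensor_changed:
        # On receiving ranks, copy the broadcast buffer back into the
        # original per-parameter tensors.
        for tensor, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
            tensor.copy_(synced)

As in the diff, tensor_changed would be True on every rank except src_rank, so only receivers copy the broadcast buffer back into their parameters.
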
29 changes: 1 addition & 28 deletions chatlearn/models/vllm/hooks/ray_gpu_executor.py
@@ -27,7 +27,6 @@
get_ip, get_open_port, get_vllm_instance_id)

from chatlearn.utils.global_vars import get_vllm_actors
from chatlearn.models.vllm_module_v2 import VLLMModuleV2

logger = init_logger(__name__)

@@ -57,7 +56,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
driver_ip = get_ip()
# driver_actor_id = ray.get_runtime_context().get_actor_id()
vllm_workers = get_vllm_actors()
print(f"debug 1111 self.workers 0: {[id(ele) for ele in vllm_workers]} {vllm_workers}")
worker_wrapper_kwargs = self._get_worker_wrapper_args()
if self.use_ray_spmd_worker:
self.workers = vllm_workers
@@ -70,15 +68,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
# self.driver_worker = worker
self.driver_worker = RayWorkerWrapper(
# worker.model.__class__,
**worker_wrapper_kwargs)
# ray.get(worker.set_driver_worker.remote(self.driver_worker))
else:
# Else, added to the list of workers.
self.workers.append(worker)
print(f"debug 1111 self.workers 1: {[id(ele) for ele in self.workers]} {self.workers}")
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
@@ -110,17 +104,11 @@ def sort_by_driver_then_worker_ip(worker):
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
# print(f"debug 1111 self.workers 2: {[id(ele) for ele in self.workers]} {self.workers}")
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# print(f"debug 1111 self.workers 3: {[id(ele) for ele in self.workers]} {self.workers}")

# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
# worker_node_and_gpu_ids = [ray.get(worker.get_node_and_gpu_ids.remote()) for worker in self.workers]
# print(f"debug hahahaha self.driver_dummy_worker: {self.driver_dummy_worker}")
# worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids")
print(f"debug worker_node_and_gpu_ids: {worker_node_and_gpu_ids}")

node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
@@ -169,7 +157,6 @@ def sort_by_driver_then_worker_ip(worker):
all_args_to_update_environment_variables)

self._run_workers("update_environment_variables",
# use_dummy_driver=True,
all_args=self._get_env_vars_to_be_updated())

if len(node_gpus) == 1:
@@ -195,21 +182,12 @@ def sort_by_driver_then_worker_ip(worker):
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

print(f"debug nanana init_device.....", flush=True)
print(f"debug 1111 self.workers 4: {[id(ele) for ele in self.workers]} {self.workers}")


self._run_workers("init_device")
from chatlearn.utils import future
refs = [self.workers[rank].init_device.remote() for rank in range(len(self.workers))]
future.wait(refs)


self._run_workers("load_model",
# use_dummy_driver=True,
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
print(f"debug nanana load_model.....", flush=True)
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
Expand All @@ -223,11 +201,6 @@ def sort_by_driver_then_worker_ip(worker):
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])

tp_pp_pairs = self._run_workers('get_tp_and_pp_rank')
print(f"debug tp_pp_pairs: {tp_pp_pairs}")
for worker, (tp_rank, pp_rank) in zip(self.workers, tp_pp_pairs):
ray.get(worker.set_tp_pp_ranks.remote(tp_rank, pp_rank))

# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
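
Note: the sorting step retained above places workers on the driver's node ahead of the others and keeps same-node workers adjacent. The following is a minimal sketch of such a sort key, not vLLM's actual implementation; the get_worker_ip callable is a hypothetical stand-in for however the executor looks up a worker's node IP.

def make_sort_key(driver_ip, get_worker_ip):
    # Workers on the driver's node sort first (False < True); within a
    # node, grouping by IP keeps same-node workers next to each other.
    def sort_by_driver_then_worker_ip(worker):
        ip = get_worker_ip(worker)
        return (ip != driver_ip, ip)
    return sort_by_driver_then_worker_ip

# usage sketch:
# self.workers = sorted(self.workers, key=make_sort_key(driver_ip, get_worker_ip))
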
3 changes: 0 additions & 3 deletions chatlearn/models/vllm/hooks/worker_base.py
@@ -24,7 +24,6 @@
from vllm.worker.worker_base import logger

def get_tp_and_pp_rank(self):
print(f"debug get_tp_and_pp_rank...")
return parallel_state.get_tensor_model_parallel_rank(), \
parallel_state.get_pp_group().rank_in_group

@@ -39,8 +38,6 @@ def execute_method(self, method, *args, **kwargs):
target = self.worker
else:
target = self
if method == "init_device":
print(f"debug target: {target} self.worker: {self.worker}")
executor = getattr(target, method)
return executor(*args, **kwargs)
except Exception as e:
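
Note: the execute_method hook above, with the debug print removed, simply dispatches a method name to either the wrapped worker or the wrapper itself. A standalone sketch of that dispatch pattern is below; it is an illustration, not vLLM's actual WorkerWrapperBase, and the exact condition used to pick the target is an assumption.

class MethodDispatcher:
    def __init__(self, worker=None):
        self.worker = worker

    def execute_method(self, method, *args, **kwargs):
        # Prefer the wrapped worker when it defines the method;
        # otherwise fall back to the wrapper itself.
        if self.worker is not None and hasattr(self.worker, method):
            target = self.worker
        else:
            target = self
        executor = getattr(target, method)
        return executor(*args, **kwargs)
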
27 changes: 2 additions & 25 deletions chatlearn/models/vllm_module_v2.py
@@ -45,9 +45,7 @@ def __init__(self, *args, **kwargs):
# common method is '__init__'
assert common_methods == {'__init__'}, \
f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
# TorchModule.__init__(self, *args)
self.local_rank = 0
# os.environ['LOCAL_RANK'] = '0'
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()
@@ -81,14 +79,6 @@ def init(self):
initialize_vllm(extra_args_provider=self.add_extra_args,
ignore_unknown_args=True,
args_dict=self.model_args)
def init_device(self):
return
self.worker.device = torch.device(f"cuda:{torch.cuda.current_device()}")
torch.cuda.set_device(self.device)
init_worker_distributed_environment(self.worker.parallel_config, self.worker.rank,
self.worker.distributed_init_method,
self.worker.local_rank)
# return self.worker.init_device()

def setup(self):
"""Set up model and load checkpoint"""
@@ -100,7 +90,6 @@ def setup(self):
def setup_vllm(self, workers):
# setup vllm engine in rank 0
os.environ['VLLM_HOST_IP'] = self.get_address()
print(f"debug 1111 workers: {[id(ele) for ele in workers]} {workers}")
set_vllm_actors(workers)

dtype = self.model_args.get("dtype", "bfloat16")
@@ -196,9 +185,6 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids):
return inputs

def generate_vllm(self, query, is_eval):
print(f"debug aaaa tensor_parallel_rank: {id(self)} {self}")
print(f"tensor_parallel_rank: {self.tensor_parallel_rank()}", flush=True)
print(f"debug aaaa pipeline_parallel_rank: {id(self)} {self} {self.pipeline_parallel_rank()}", flush=True)
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")

@@ -233,8 +219,6 @@ def generate_vllm(self, query, is_eval):
sampling_params,
use_tqdm=True,
)
print(f"debug aaaa tensor_parallel_rank: {self.tensor_parallel_rank()}", flush=True)
print(f"debug aaaa pipeline_parallel_rank: {self.pipeline_parallel_rank()}", flush=True)
final_outputs = sorted(outputs, key=lambda x: int(x.request_id))
return final_outputs

@@ -275,22 +259,15 @@ def model(self):
f"please set env variables `VLLM_USE_RAY_SPMD_WORKER` and `VLLM_USE_RAY_COMPILED_DAG` first."
self._model = self.worker.model_runner.model
return self._model
def set_tp_pp_ranks(self, tp_rank, pp_rank):
"""
:meta private:
"""
print(f"aaaa debug set_tp_pp_ranks: {tp_rank} {pp_rank}", flush=True)
self._tp_rank = tp_rank
self._pp_rank = pp_rank

def tensor_parallel_rank(self):
"""
:meta private:
"""
return self._tp_rank
return parallel_state.get_tensor_model_parallel_rank()

def pipeline_parallel_rank(self):
"""
:meta private:
"""
return self._pp_rank
return get_pipeline_model_parallel_rank()
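
Note: the last two hunks replace the cached _tp_rank/_pp_rank fields with direct queries of the parallel state, matching the get_tp_and_pp_rank hook in worker_base.py above, so the ranks cannot go stale. A minimal sketch of that accessor pattern follows, assuming the vLLM distributed environment is already initialized in the worker process; the vllm.distributed import path is an assumption, since the diff does not show the actual import.

from vllm.distributed import parallel_state  # import path assumed

def current_tp_and_pp_rank():
    # Query the live parallel state instead of caching ranks at setup time.
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    pp_rank = parallel_state.get_pp_group().rank_in_group
    return tp_rank, pp_rank
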
7 changes: 1 addition & 6 deletions chatlearn/runtime/dist_actor.py
@@ -240,13 +240,8 @@ def create_actor(self, num_gpus, placement_group, group_index):
"worker_class_fn": None,
"trust_remote_code": True,
}
# if self.vllm_engine is None:
# self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)
# self.model.engine = self.vllm_engine
# else:
# self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)

self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)

def create_engine_actor(self, num_gpus, placement_group, group_index):
self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
self.model.engine = self.vllm_engine
6 changes: 3 additions & 3 deletions chatlearn/runtime/engine.py
@@ -228,6 +228,9 @@ def setup(self):
executor.update_models(self.remote_models)
if self.env:
self.env.set_dataset(self._dataset)
self.model_manager.build_parameter_group()
self.model_manager.start_error_monitor()

def set_dataset(self, dataset):
"""
Set prompt dataset.
@@ -279,8 +282,6 @@ def learn(self):
for executor in self._executors:
if executor:
executor.setup()
self.model_manager.build_parameter_group()
self.model_manager.start_error_monitor()
self.timers("setup").stop()
logger.info(f"{LOG_START} {self._name} setup summary {self.timers.log(names=['setup'])}")
self.logging_memory()
@@ -292,7 +293,6 @@ def learn(self):
self.runtime_args.max_relay_episode,
self.runtime_args.relay_episode_offset)
logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync'))
# breakpoint()
self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync'))
self._data_loader = data_loader
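
Note: the engine change moves build_parameter_group() and start_error_monitor() out of learn() and into setup(), so the parameter sync groups already exist when learn() performs the first parameter sync. A minimal sketch of the resulting ordering; the class and constructor here are illustrative, not chatlearn's actual Engine.

class EngineOrderSketch:
    def __init__(self, model_manager):
        self.model_manager = model_manager

    def setup(self):
        # Sync groups and the error monitor are now built once during setup.
        self.model_manager.build_parameter_group()
        self.model_manager.start_error_monitor()

    def learn(self):
        # By the time learn() runs its first parameter sync, the groups
        # built in setup() are already in place.
        self.model_manager.sync_parameters(requires_grad=False)
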
1 change: 0 additions & 1 deletion chatlearn/runtime/environment.py
@@ -76,7 +76,6 @@ def setup(self):

if isinstance(model.model, VLLMModuleV2):
for replica in model_node.model.replicas:
print(f"debug 1111 replica.all_actors: {[id(ele) for ele in replica.all_actors]} {replica.all_actors}")
ret = replica.vllm_engine.setup_vllm.remote(replica.all_actors)
future.wait(ret)

10 changes: 0 additions & 10 deletions chatlearn/schedule/model_manager.py
@@ -175,22 +175,12 @@ def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False):
sync_group: ParameterSyncGroup = sync_group

src_model, dst_model = sync_group.src_model, sync_group.dst_model
print(f"sync from {src_model} to {dst_model}", flush=True)
# breakpoint()
refs = src_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False)
future.wait(refs)
print(f"having onload {src_model}", flush=True)
# breakpoint()

refs = dst_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False)
future.wait(refs)
print(f"having onload {dst_model}", flush=True)
# breakpoint()

print(f"sync_group {sync_group} start to sync from {src_model} to {dst_model}", flush=True)
# breakpoint()
sync_group.sync(requires_grad, validate)
# breakpoint()

refs = src_model.offload()
future.wait(refs)
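
Note: with the debug prints and breakpoints gone, the retained body of sync_parameters follows an onload, sync, offload sequence per parameter sync group. A minimal sketch of that sequence using the calls visible in the diff; future.wait is chatlearn's helper for blocking on Ray object refs, and the wrapper function itself is illustrative.

from chatlearn.utils import future  # same helper used elsewhere in this commit

def sync_one_group(sync_group, requires_grad=None, validate=False):
    src_model, dst_model = sync_group.src_model, sync_group.dst_model

    # Load only the weights needed for syncing; skip grad buffers,
    # main weights, and optimizer states.
    future.wait(src_model.onload(to_build_grad_buffers=False,
                                 to_onload_main_weights=False,
                                 to_onload_optimizer_states=False))
    future.wait(dst_model.onload(to_build_grad_buffers=False,
                                 to_onload_main_weights=False,
                                 to_onload_optimizer_states=False))

    # Broadcast parameters from the source model to the destination.
    sync_group.sync(requires_grad, validate)

    # Release GPU memory on the source once syncing is done.
    future.wait(src_model.offload())
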
(Diffs for the remaining 4 changed files were not loaded.)
