-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix NetworkProfiler import and prepare for v0.2 release
- Loading branch information
Showing
10 changed files
with
456 additions
and
149 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,80 @@ | ||
import functools | ||
from typing import Dict, Any | ||
import logging | ||
import asyncio | ||
from typing import Callable, Any, Optional | ||
from .config import config | ||
from .logging_config import setup_logging | ||
from .integration import get_framework_adapter | ||
from .analysis import Analyzer | ||
from .reporting import ReportGenerator | ||
from .exceptions import MemoraithError | ||
from .visualization.real_time_visualizer import RealTimeVisualizer | ||
from memoraith.data_collection.cpu_memory import CPUMemoryTracker | ||
from memoraith.data_collection.gpu_memory import GPUMemoryTracker | ||
from memoraith.data_collection.time_tracking import TimeTracker | ||
from memoraith.data_collection.network_profiler import NetworkProfiler | ||
|
||
def profile_model( | ||
memory: bool = True, | ||
computation: bool = True, | ||
gpu: bool = False, | ||
save_report: bool = True, | ||
report_format: str = 'html', | ||
real_time_viz: bool = False | ||
) -> Callable: | ||
""" | ||
Decorator to profile a model's training or inference function. | ||
class ModelProfiler: | ||
def __init__(self): | ||
self.cpu_tracker = CPUMemoryTracker() | ||
self.gpu_tracker = GPUMemoryTracker() | ||
self.time_tracker = TimeTracker() | ||
self.network_profiler = NetworkProfiler() | ||
self.logger = logging.getLogger(__name__) | ||
|
||
Args: | ||
memory (bool): Enable memory profiling | ||
computation (bool): Enable computation time profiling | ||
gpu (bool): Enable GPU profiling | ||
save_report (bool): Save the profiling report | ||
report_format (str): Format of the saved report ('html' or 'pdf') | ||
real_time_viz (bool): Enable real-time visualization | ||
async def start_profiling(self): | ||
self.logger.info("Starting model profiling") | ||
await self.cpu_tracker.start() | ||
if self.gpu_tracker: | ||
await self.gpu_tracker.start() | ||
self.time_tracker.start('training') | ||
self.network_profiler.start() | ||
|
||
Returns: | ||
Callable: Decorated function | ||
""" | ||
def decorator(func: Callable) -> Callable: | ||
@functools.wraps(func) | ||
async def wrapper(*args: Any, **kwargs: Any) -> Any: | ||
setup_logging(config.log_level) | ||
logger = logging.getLogger('memoraith') | ||
logger.info("Starting Memoraith Profiler...") | ||
async def stop_profiling(self): | ||
self.logger.info("Stopping model profiling") | ||
cpu_memory = await self.cpu_tracker.get_peak_memory() | ||
gpu_memory = await self.gpu_tracker.get_peak_memory() if self.gpu_tracker else None | ||
duration = self.time_tracker.get_duration('training') | ||
network_usage = self.network_profiler.stop() | ||
|
||
config.enable_memory = memory | ||
config.enable_time = computation | ||
config.enable_gpu = gpu | ||
profiling_results = { | ||
'cpu_memory': cpu_memory, | ||
'gpu_memory': gpu_memory, | ||
'training_time': duration, | ||
'network_usage': network_usage | ||
} | ||
self.logger.info(f"Profiling results: {profiling_results}") | ||
return profiling_results | ||
|
||
try: | ||
model = kwargs.get('model') or args[0] | ||
adapter = get_framework_adapter(model) | ||
async def profile_step(self, step_name: str): | ||
self.time_tracker.start(step_name) | ||
cpu_memory_before = await self.cpu_tracker.get_current_memory() | ||
gpu_memory_before = await self.gpu_tracker.get_current_memory() if self.gpu_tracker else None | ||
network_usage_before = self.network_profiler.get_current_usage() | ||
|
||
visualizer = RealTimeVisualizer() if real_time_viz else None | ||
yield # Yield control to allow the step to execute | ||
|
||
async with adapter: | ||
if asyncio.iscoroutinefunction(func): | ||
result = await func(*args, **kwargs) | ||
else: | ||
result = await asyncio.to_thread(func, *args, **kwargs) | ||
cpu_memory_after = await self.cpu_tracker.get_current_memory() | ||
gpu_memory_after = await self.gpu_tracker.get_current_memory() if self.gpu_tracker else None | ||
network_usage_after = self.network_profiler.get_current_usage() | ||
duration = self.time_tracker.stop(step_name) | ||
|
||
if visualizer: | ||
await visualizer.update(adapter.data) | ||
step_profile = { | ||
'name': step_name, | ||
'duration': duration, | ||
'cpu_memory_used': cpu_memory_after - cpu_memory_before, | ||
'gpu_memory_used': gpu_memory_after - gpu_memory_before if gpu_memory_after and gpu_memory_before else None, | ||
'network_sent': network_usage_after['bytes_sent'] - network_usage_before['bytes_sent'], | ||
'network_recv': network_usage_after['bytes_recv'] - network_usage_before['bytes_recv'], | ||
} | ||
|
||
analysis_results = await Analyzer(adapter.data).run_analysis() | ||
self.logger.info(f"Step profile for {step_name}: {step_profile}") | ||
yield step_profile | ||
|
||
if save_report: | ||
await ReportGenerator(analysis_results).generate(format=report_format) | ||
def get_summary(self) -> Dict[str, Any]: | ||
return { | ||
'total_time': self.time_tracker.get_total_duration(), | ||
'peak_cpu_memory': self.cpu_tracker.get_peak_memory(), | ||
'peak_gpu_memory': self.gpu_tracker.get_peak_memory() if self.gpu_tracker else None, | ||
'average_network_usage': self.network_profiler.get_average_usage(), | ||
} | ||
|
||
logger.info("Memoraith Profiling Completed.") | ||
return result | ||
|
||
except MemoraithError as e: | ||
logger.error(f"MemoraithError: {e}") | ||
raise | ||
except Exception as e: | ||
logger.exception("An unexpected error occurred during profiling.") | ||
raise | ||
|
||
return wrapper | ||
return decorator | ||
|
||
def set_output_path(path: str) -> None: | ||
"""Set the output path for profiling reports.""" | ||
config.set_output_path(path) | ||
def reset(self): | ||
self.cpu_tracker.reset() | ||
if self.gpu_tracker: | ||
self.gpu_tracker.reset() | ||
self.time_tracker.reset() | ||
self.network_profiler.reset() | ||
self.logger.info("All profilers reset") |
Oops, something went wrong.