Created
July 23, 2025 04:06
-
-
Save huseinzol05/c9373b2d0f80f4270d6a82398e47bf54 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
https://www.oumi.ai/docs/en/latest/_modules/oumi/core/callbacks/hf_mfu_callback.html
"""
import time
from typing import Optional

import torch
import wandb
from transformers import TrainerCallback, TrainerState, TrainerControl
# Theoretical Peak Tensor Core Performance (BF16 / FP16).
# Keys are device-name strings; values are expressed in TFLOP/s (they are
# multiplied by 1e12 where they are consumed to obtain FLOP/s).
DEVICES = {
    'NVIDIA GeForce RTX 3090': 142,
    'NVIDIA GeForce RTX 3090 Ti': 160,
    'NVIDIA H100 80GB HBM3': 990,
    'NVIDIA GH200 480GB': 990,
}
class WandbMFUCallback(TrainerCallback):
    """Trainer callback that logs Model FLOPs Utilization (MFU) to wandb.

    Two ratios are logged on every ``on_log`` event, both computed from the
    growth of ``state.total_flos`` since the start of the second step:

    * ``train_step_mfu`` -- FLOPs divided by the accumulated time spent
      between ``on_step_begin`` and ``on_step_end`` (step time only).
    * ``train_mfu`` -- FLOPs divided by the wall-clock time elapsed since
      the second step began.

    Each rate is then divided by the peak FLOP/s looked up in ``DEVICES``.
    The first training step is excluded from all measurements (presumably
    because it includes one-off warm-up overhead).

    NOTE(review): the FLOP count appears to cover all devices while the
    peak used is that of a single device -- confirm the metric is what you
    want for multi-GPU runs.

    Args:
        device_name: Key into ``DEVICES``. When ``None``, the name reported
            by ``torch.cuda.get_device_name()`` is used.

    Raises:
        ValueError: If the resolved device name is not a key of ``DEVICES``.
    """

    def __init__(self, device_name = None):
        super().__init__()
        # Wall-clock timestamp / FLOP counter captured when step 2 begins.
        self._time_of_second_step: Optional[float] = None
        self._flops_at_second_step: Optional[float] = None
        # Timestamp of the most recent on_step_begin; None until a step starts.
        self._step_start_time: Optional[float] = None
        # Sum of step durations, excluding the first step.
        self._time_for_train_steps = 0.0
        self._first_step_finished = False

        self._device_name = device_name
        if self._device_name is None:
            self._device_name = torch.cuda.get_device_name()
        if self._device_name not in DEVICES:
            raise ValueError('device name is not recognized.')

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Record the step start time and, on the second step, the baseline."""
        self._step_start_time = time.time()
        if not self._first_step_finished:
            # The first step is not measured.
            return

        if self._time_of_second_step is None:
            self._time_of_second_step = self._step_start_time
            if state is not None:
                self._flops_at_second_step = state.total_flos

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Accumulate the duration of every step after the first one."""
        if not self._first_step_finished:
            self._first_step_finished = True
            return

        # Guard against on_step_end firing without a matching on_step_begin.
        if self._step_start_time is not None:
            self._time_for_train_steps += time.time() - self._step_start_time

    def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Compute both MFU ratios and send them to wandb."""
        if self._time_of_second_step is None:
            # Nothing measurable until the second step has started.
            return
        if self._flops_at_second_step is None or state is None:
            return
        if not state.total_flos > 0.0:
            return

        delta_time_seconds_train = time.time() - self._time_of_second_step
        delta_time_seconds_step = self._time_for_train_steps

        flops_since_second_step_on_all_devices = (
            state.total_flos - self._flops_at_second_step
        )
        # DEVICES stores TFLOP/s; convert to FLOP/s.
        device_flops_per_second = DEVICES[self._device_name] * 1e12

        # Skip a ratio whose elapsed time is not positive instead of
        # raising ZeroDivisionError (e.g. a log event before any measured
        # step has completed).
        metrics = {}
        if delta_time_seconds_step > 0.0:
            flops_step = flops_since_second_step_on_all_devices / delta_time_seconds_step
            metrics["train_step_mfu"] = flops_step / device_flops_per_second
        if delta_time_seconds_train > 0.0:
            flops_train = flops_since_second_step_on_all_devices / delta_time_seconds_train
            metrics["train_mfu"] = flops_train / device_flops_per_second

        if metrics:
            wandb.log(metrics, step=state.global_step)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment