"""
https://www.oumi.ai/docs/en/latest/_modules/oumi/core/callbacks/hf_mfu_callback.html
"""
import time
from typing import Optional

import torch
import wandb
from transformers import TrainerCallback, TrainerState, TrainerControl
# Theoretical peak Tensor Core performance (BF16/FP16), in TFLOPS.
DEVICES = {
    'NVIDIA GeForce RTX 3090': 142,
    'NVIDIA GeForce RTX 3090 Ti': 160,
    'NVIDIA H100 80GB HBM3': 990,
    'NVIDIA GH200 480GB': 990,
}
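# MFU relates achieved throughput to the hardware ceiling:
#     MFU = achieved FLOPs/s / theoretical peak FLOPs/s
# For example (illustrative numbers only), an H100 run sustaining 396 TFLOPS
# would log an MFU of 396 / 990 = 0.4.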
class WandbMFUCallback(TrainerCallback):
    def __init__(self, device_name: Optional[str] = None):
        self._time_of_second_step: Optional[float] = None
        self._flops_at_second_step: Optional[float] = None
        self._time_for_train_steps = 0.0
        self._first_step_finished = False
        self._device_name = device_name
        if self._device_name is None:
            self._device_name = torch.cuda.get_device_name()
        if self._device_name not in DEVICES:
            raise ValueError(
                f'device name {self._device_name!r} is not recognized; '
                'add its peak TFLOPS to DEVICES.'
            )
    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self._step_start_time = time.time()
        if not self._first_step_finished:
            return
        # Skip the first step (warmup/compilation overhead); measurement
        # starts at the beginning of the second step.
        if self._time_of_second_step is None:
            self._time_of_second_step = self._step_start_time
            if state is not None:
                self._flops_at_second_step = state.total_flos
    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        delta_time_seconds = time.time() - self._step_start_time
        if not self._first_step_finished:
            self._first_step_finished = True
            return
        # Accumulate time spent strictly inside train steps (excludes
        # evaluation, checkpointing, and time between steps).
        self._time_for_train_steps += delta_time_seconds
    def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if self._time_of_second_step is None:
            return
        # Two denominators: wall-clock time since the second step began
        # (train_mfu) and time spent inside steps only (train_step_mfu).
        delta_time_seconds_train = time.time() - self._time_of_second_step
        delta_time_seconds_step = self._time_for_train_steps
        if delta_time_seconds_step <= 0.0 or delta_time_seconds_train <= 0.0:
            return  # no full step measured yet; avoid division by zero
        if self._flops_at_second_step is not None and (
            state is not None and state.total_flos > 0.0
        ):
            flops_since_second_step_on_all_devices = (
                state.total_flos - self._flops_at_second_step
            )
            flops_step = flops_since_second_step_on_all_devices / delta_time_seconds_step
            flops_train = flops_since_second_step_on_all_devices / delta_time_seconds_train
            # DEVICES is in TFLOPS; convert to FLOPs/s. This is the peak of a
            # single device, so the ratios assume single-GPU training.
            device_flops_per_second = DEVICES[self._device_name] * 1e12
            train_step_mfu = flops_step / device_flops_per_second
            train_mfu = flops_train / device_flops_per_second
            wandb.log({
                "train_step_mfu": train_step_mfu,
                "train_mfu": train_mfu,
            }, step=state.global_step)
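
# Usage sketch (not from the original gist; `model` and `train_dataset` are
# placeholders you must supply, and the hyperparameters are illustrative):
#
#     import wandb
#     from transformers import Trainer, TrainingArguments
#
#     wandb.init(project='mfu-demo')  # wandb.log requires an active run
#     training_args = TrainingArguments(
#         output_dir='out',
#         per_device_train_batch_size=8,
#         logging_steps=10,  # on_log (and hence MFU logging) fires at this cadence
#     )
#     trainer = Trainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         callbacks=[WandbMFUCallback()],  # device name auto-detected via torch.cuda
#     )
#     trainer.train()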