Commit 7b48c38 · calc v0
Parent(s): edcba35

Changed files:
- .github/workflows/sync_to_hub.yaml  +1 -1
- README.md  +3 -3
- app.py  +39 -33
- dashboard_utils/main_metrics.py  +0 -33
- dashboard_utils/time_tracker.py  +0 -32
- mem_calc.py  +237 -0
- models.py  +97 -0
.github/workflows/sync_to_hub.yaml
CHANGED

@@ -17,4 +17,4 @@ jobs:
       - name: Push to hub
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
-        run: git push https://training-transformers-together:[email protected]/spaces/training-transformers-together/
+        run: git push https://training-transformers-together:[email protected]/spaces/training-transformers-together/calc main --force
README.md
CHANGED

@@ -1,8 +1,8 @@
 ---
-title:
+title: Memory calculator
 emoji: ⚡
-colorFrom:
-colorTo:
+colorFrom: blue
+colorTo: blue
 sdk: streamlit
 app_file: app.py
 pinned: false
app.py
CHANGED

@@ -5,41 +5,47 @@ If you're not a hedgehog, you shouldn't reuse this code. Use this instead: https
 
 import streamlit as st
 
-
-
-st.set_page_config(page_title="
+import mem_calc
+from models import models
+st.set_page_config(page_title="Memory calculator", layout="centered")
 st.markdown("""<style>
 .reportview-container {
     top: -80px;
 }
 </style>""", unsafe_allow_html=True)
-… [old dashboard code removed; its content was not captured in this view]
-)
+
+models = list(models.keys())  # respect the original order because py37
+model = st.selectbox('Model architecture', models, index=models.index("gpt2-l"))
+
+optimizers_names = ('32-bit', '16-bit', '8-bit', 'factorized')
+optimizers_values = ['adam', '16-bit-adam', '8-bit-adam', 'adafactor']
+optimizer = st.radio('Adam / LAMB states', optimizers_names)
+checkpoint = st.checkbox("Gradient checkpointing", value=True)
+offload = st.checkbox("Offload optimizer", value=False)
+share_params = st.checkbox("Share parameters", value=False)
+
+with st.expander("More options"):
+    precisions_names = ('Full', 'Mixed ("O1")', 'Pure 16-bit')
+    precisions_values = ('O0', 'O1', 'O3')
+    precision = st.selectbox('Precision', precisions_names, index=1)
+
+    vocab_size = int(st.number_input('Vocabulary size', min_value=1, step=1, value=50257, format="%i"))
+
+args = mem_calc.parse_args(f"""
+    --model {model} --vocab_size {vocab_size} --optimizer {optimizers_values[optimizers_names.index(optimizer)]}
+    {'--checkpoint' if checkpoint else ''} {'--offload' if offload else ''} {'--albert' if share_params else ''}
+    --fp16-level {precisions_values[precisions_names.index(precision)]}
+""".split())
+
+memory = mem_calc.calculate_memory(args)
+
+cols = st.columns(3)
+cols[0].metric("Parameters (GPU)", f"{memory['model']:.2f} GB", f"{memory['model']/memory['total_mem'] * 100:.2f} %", delta_color="off")
+cols[1].metric(f"Optimizer ({'CPU' if offload else 'GPU'})", f"{memory['optim']:.2f} GB", f"{memory['optim']/memory['total_mem'] * 100:.2f} %", delta_color="off")
+cols[2].metric("Activations (GPU)", f"{memory['grad']:.2f} GB", f"{memory['grad']/memory['total_mem'] * 100:.2f} %", delta_color="off")
+cols = st.columns(3)
+cols[0].metric("GPU total", f"{memory['total_mem']:.2f} GB")
+cols[1].metric("Offloaded to RAM", f"{memory['cpu_mem']:.2f} GB")
+cols[2].metric("Communication overhead", f"{memory['overhead'] * 1000:.2f} ms")
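To make the widget-to-flag mapping above concrete, here is a small sketch (not part of the commit) of the argument string app.py assembles for its default widget state: gpt2-l, 32-bit optimizer states, gradient checkpointing on, offload and parameter sharing off, mixed precision, and a 50257-token vocabulary.

# Sketch only: the flag string app.py builds for the default widget state.
# The empty placeholders from the unticked checkboxes vanish after .split().
argv = """
    --model gpt2-l --vocab_size 50257 --optimizer adam
    --checkpoint
    --fp16-level O1
""".split()
# argv == ['--model', 'gpt2-l', '--vocab_size', '50257', '--optimizer', 'adam',
#          '--checkpoint', '--fp16-level', 'O1']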
dashboard_utils/main_metrics.py
DELETED

@@ -1,33 +0,0 @@
import datetime

import streamlit as st
import pandas as pd

import wandb

from dashboard_utils.time_tracker import _log, simple_time_tracker

WANDB_REPO = "learning-at-home/dalle-hivemind"
CACHE_TTL = 120  # note: in the text, we claim that this plot is updated every few minutes


@st.cache(ttl=CACHE_TTL)
@simple_time_tracker(_log)
def get_main_metrics():
    wandb.login(anonymous="must")
    api = wandb.Api()
    runs = api.runs(WANDB_REPO)
    run = runs[0]
    history = run.history(keys=["step", "loss", "alive peers", "_timestamp"])

    steps = []
    losses = []
    alive_peers = []
    dates = []
    for _, row in history.iterrows():
        steps.append(row["step"])
        losses.append(row["loss"])
        alive_peers.append(row["alive peers"])
        dates.append(datetime.datetime.utcfromtimestamp(row["_timestamp"]))

    return pd.DataFrame({"steps": steps, "training loss": losses, "active participants": alive_peers, "wall time": dates})
dashboard_utils/time_tracker.py
DELETED

@@ -1,32 +0,0 @@
from functools import wraps
from time import time


def simple_time_tracker(log_fun):
    def _simple_time_tracker(fn):
        @wraps(fn)
        def wrapped_fn(*args, **kwargs):
            start_time = time()

            try:
                result = fn(*args, **kwargs)
            finally:
                elapsed_time = time() - start_time

                # log the result
                log_fun(
                    {
                        "function_name": fn.__name__,
                        "total_time": elapsed_time,
                    }
                )

            return result

        return wrapped_fn

    return _simple_time_tracker


def _log(message):
    print("[SimpleTimeTracker] {function_name} {total_time:.3f}".format(**message))
mem_calc.py
ADDED

@@ -0,0 +1,237 @@
import argparse
import math

from models import models


def get_GB(nbytes):
    return nbytes/(1024**3)


def vocab(bsz, seqlen, dmodel, vocab_dim):
    # assumes tied embeddings
    w = vocab_dim*dmodel
    emb = seqlen*bsz*dmodel
    emb_norm = seqlen*bsz*dmodel
    pos_emb = seqlen*bsz*dmodel
    out_emb = seqlen*bsz*vocab_dim
    softmax_emb = seqlen*bsz*vocab_dim

    model = w + dmodel
    grad = emb + emb_norm + pos_emb + out_emb + softmax_emb
    return model, grad


def transformer(bsz, seqlen, dmodel, nlayers, vocab_type, dhid=None,
                checkpoint=False, albert=False):
    if dhid is None: dhid = 4*dmodel
    model = 0
    grad = 0
    for i in range(nlayers):
        m, g = transformer_layer(bsz, seqlen, dmodel, dhid, checkpoint=checkpoint)
        model += m
        grad += g

    if albert:
        model = model / nlayers

    m, g = vocab(bsz, seqlen, dmodel, vocab_type)
    model += m
    grad += g

    return model, grad


def layer_norm(bsz, seqlen, dmodel):
    w = dmodel
    x_grad = bsz*seqlen*dmodel
    return w, x_grad


def transformer_layer(bsz, seqlen, dmodel, dhid, checkpoint=False):
    model = 0
    grad = 0

    m, g = ffn(bsz, seqlen, dmodel, dhid, 'gelu')
    model += m
    grad += g*3

    m, g = attention_layer(bsz, seqlen, dmodel)
    model += m
    grad += g*5.0

    m, g = layer_norm(bsz, seqlen, dmodel)
    model += m
    grad += g*1.0

    if checkpoint:
        grad = bsz * seqlen * dmodel

    return model, grad


def attention_layer(bsz, seqlen, dmodel):
    w_proj = dmodel*3*dmodel
    w_out = dmodel*dmodel

    x_residual = bsz*seqlen*dmodel
    x_proj = bsz*seqlen*dmodel*3
    # x_proj_contiguous = bsz*seqlen*dmodel*3
    x_proj_contiguous = 0

    x_qscaled = bsz*seqlen*dmodel
    x_qk = bsz*seqlen*seqlen*2  # we need to store both input sequence directions for gradient computation
    x_softmax = bsz*seqlen*seqlen
    x_softmax_v = bsz*seqlen*dmodel*2  # we need to store both input sequence directions for gradient computation
    # x_out_contiguous = bsz*seqlen*dmodel
    x_out_contiguous = 0
    x_out = bsz*seqlen*dmodel

    model = w_proj + w_out
    grad = x_residual + x_proj + x_proj_contiguous + x_qscaled + x_qk + x_softmax + x_softmax_v + x_out_contiguous + x_out
    return model, grad


def ffn(bsz, seqlen, dmodel, dhid, func='relu'):
    # out = linear(relu(linear(x), inplace=True)) + x
    w1 = dmodel*dhid
    w2 = dhid*dmodel
    model = w1 + w2
    wgrad = model
    x1 = bsz*seqlen*dhid
    if func != 'relu': x1 *= 2  # inplace not possible with most other functions
    x2 = bsz*seqlen*dmodel
    residual = bsz*seqlen*dmodel
    grad = x1 + x2 + residual

    return model, grad


OPTIMIZERS = ['adam', 'adafactor', 'adafactor-fac-only', '8-bit-adam', '16-bit-adam']


def parse_args(args=None):
    parser = argparse.ArgumentParser('Memory calculator')

    parser.add_argument('--nlayers', type=int, help='The number of transformer layers.')
    parser.add_argument('--bsz', type=int, default=1, help='The batch size. Default: 1')
    parser.add_argument('--seqlen', type=int, help='The sequence length.')
    parser.add_argument('--dmodel', type=int, help='The core model size.')
    parser.add_argument('--dhid', type=int, default=None,
                        help='The hidden size of the FFN layer. Default: 4x model size.')
    parser.add_argument('--fp16-level', type=str, default='O1',
                        help='FP16 level to use. O0 = FP32; O1 = mixed precision (16+32); O3 = fp16. Default: O1.')
    parser.add_argument('--model', default='', choices=list(models.keys()), help='Predefined NLP transformer models.')
    parser.add_argument('--optimizer', default='adam', choices=OPTIMIZERS, help='The optimizer to use.')
    parser.add_argument('--vocab_size', type=int, default=50257, help='The vocabulary size to use.')
    parser.add_argument('--offload', action='store_true', help='Whether to use optimizer offload.')
    parser.add_argument('--ngpus', type=int, default=1, help='The number of GPUs. Default: 1')
    parser.add_argument('--zero', type=int, default=0,
                        help='The ZeRO level (1: optimizer, 2: optimizer+weights, 3: everything). Default: 0')
    parser.add_argument('--albert', action='store_true', help='Use parameter sharing.')
    parser.add_argument('--checkpoint', action='store_true', help='Use gradient checkpointing.')

    return parser.parse_args(args)


def calculate_memory(args):
    if args.model != '':
        if args.model not in models:
            raise ValueError(f'{args.model} is not supported')
        else:
            for key, value in models[args.model].items():
                if getattr(args, key, None) is None:
                    setattr(args, key, value)

    model, grad = transformer(args.bsz, args.seqlen, args.dmodel, args.nlayers, args.vocab_size, args.dhid, args.checkpoint, args.albert)
    parameters = model

    if args.optimizer == 'adam':
        optim = 8*model
    elif args.optimizer == '8-bit-adam':
        optim = 2*model
    elif args.optimizer in ['16-bit-adam', 'adafactor']:
        optim = 4*model
    elif args.optimizer in ['adafactor-fac-only']:
        optim = math.log(model)

    if args.fp16_level == 'O0':
        # fp32 weights
        wgrad = 4*model
        model = 4*model
        grad = 4*grad  # fp32
    elif args.fp16_level in ['O1', 'O2']:
        # fp16 weights + fp32 master weights
        wgrad = 2*model
        model = 4*model + (2*model)
        grad = 2*grad  # fp16
    elif args.fp16_level == 'O3':
        wgrad = 2*model
        model = 2*model  # fp16
        grad = 2*grad  # fp16

    model = get_GB(model)
    grad = get_GB(grad)
    optim = get_GB(optim)
    wgrad = get_GB(wgrad)

    cpu_mem = 0
    overhead = 0

    if args.zero == 1:
        if not args.offload:
            # assumes PCIe 4.0 infiniband (200 Gbit/s = 25 GB/s)
            overhead += optim/25

        optim = optim / args.ngpus
    elif args.zero == 2:
        if not args.offload:
            # assumes PCIe 4.0 infiniband (200 Gbit/s = 25 GB/s)
            overhead += optim/25
            overhead += wgrad/25

        optim = optim / args.ngpus
        wgrad = wgrad / args.ngpus
    elif args.zero == 3:
        if not args.offload:
            # assumes PCIe 4.0 infiniband (200 Gbit/s = 25 GB/s)
            overhead += optim/25
            overhead += model/25
            overhead += wgrad/25

        optim = optim / args.ngpus
        model = model / args.ngpus
        wgrad = wgrad / args.ngpus

    if args.offload:
        cpu_mem = optim + wgrad
        optim = 0
        wgrad = 0
        if args.ngpus <= 2:
            # 12 GB/s for PCIe 3.0 and 1-2x GPU setup (16 lanes, 16 GB/s theoretical)
            overhead = cpu_mem/12
        else:
            # 6 GB/s for PCIe 3.0 and 4x GPU setup
            overhead = cpu_mem/6

    total_mem = model + grad + optim + wgrad
    return locals()


if __name__ == '__main__':
    args = parse_args()
    mem = calculate_memory(args)
    print('')
    print(f'Model: {args.model} with batch size {args.bsz} and sequence length {args.seqlen} and a total of {mem["parameters"]/1e9:.4f}B parameters.')
    print('='*80)
    print('Weight memory: {0:.2f} GB ({1:.2f}%)'.format(mem['model'], 100*mem['model']/mem['total_mem']))
    print('Weight gradient memory: {0:.2f} GB ({1:.2f}%)'.format(mem['wgrad'], 100*mem['wgrad']/mem['total_mem']))
    print('Input gradient memory: {0:.2f} GB ({1:.2f}%)'.format(mem['grad'], 100*mem['grad']/mem['total_mem']))
    print('Optimizer memory: {0:.2f} GB ({1:.2f}%)'.format(mem['optim'], 100*mem['optim']/mem['total_mem']))
    print('Total GPU memory: {0:.2f} GB'.format(mem['total_mem']))
    if mem['cpu_mem'] > 0:
        print('Total CPU memory: {0:.2f} GB'.format(mem['cpu_mem']))
    if mem['overhead'] > 0:
        print('Overhead: {0:.2f} seconds per update (can be partially overlapped with compute)'.format(mem['overhead']))
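As a usage sketch (not part of the commit), mem_calc can be driven either from the command line via the __main__ block above or programmatically, which is how app.py uses it; every flag below is defined in parse_args.

# Minimal programmatic use of mem_calc.py, mirroring what app.py does.
# Assumes mem_calc.py and models.py from this commit are on the import path.
import mem_calc

args = mem_calc.parse_args(
    "--model gpt2-l --optimizer 8-bit-adam --fp16-level O1 --checkpoint".split()
)
mem = mem_calc.calculate_memory(args)
print(f"GPU total: {mem['total_mem']:.2f} GB, "
      f"offloaded to CPU: {mem['cpu_mem']:.2f} GB, "
      f"overhead: {mem['overhead']:.3f} s/update")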
models.py
ADDED

@@ -0,0 +1,97 @@
models = {}
models['bert-s'] = {}
models['bert-s']['seqlen'] = 512
models['bert-s']['dmodel'] = 768
models['bert-s']['dhidden'] = 3072
models['bert-s']['nlayers'] = 12

models['bert-l'] = {}
models['bert-l']['seqlen'] = 512
models['bert-l']['dmodel'] = 1024
models['bert-l']['dhidden'] = 4096
models['bert-l']['nlayers'] = 24

models['t5-3b'] = {}
models['t5-3b']['seqlen'] = 512
models['t5-3b']['dmodel'] = 1024
models['t5-3b']['dhidden'] = 16384
models['t5-3b']['nlayers'] = 48

models['t5-11b'] = {}
models['t5-11b']['seqlen'] = 512
models['t5-11b']['dmodel'] = 1024
models['t5-11b']['dhidden'] = 64*1024
models['t5-11b']['nlayers'] = 48

models['gpt2-s'] = {}
models['gpt2-s']['seqlen'] = 1024
models['gpt2-s']['dmodel'] = 768
models['gpt2-s']['dhidden'] = 768*4
models['gpt2-s']['nlayers'] = 12

models['gpt2-m'] = {}
models['gpt2-m']['seqlen'] = 1024
models['gpt2-m']['dmodel'] = 1024
models['gpt2-m']['dhidden'] = 1024*4
models['gpt2-m']['nlayers'] = 24

models['gpt2-l'] = {}
models['gpt2-l']['seqlen'] = 1024
models['gpt2-l']['dmodel'] = 1280
models['gpt2-l']['dhidden'] = 1280*4
models['gpt2-l']['nlayers'] = 36

models['gpt2-xl'] = {}
models['gpt2-xl']['seqlen'] = 1024
models['gpt2-xl']['dmodel'] = 1600
models['gpt2-xl']['dhidden'] = 1600*4
models['gpt2-xl']['nlayers'] = 48


models['gpt3-s'] = {}
models['gpt3-s']['seqlen'] = 2048
models['gpt3-s']['dmodel'] = 768
models['gpt3-s']['dhidden'] = 768*4
models['gpt3-s']['nlayers'] = 12

models['gpt3-m'] = {}
models['gpt3-m']['seqlen'] = 2048
models['gpt3-m']['dmodel'] = 1024
models['gpt3-m']['dhidden'] = 1024*4
models['gpt3-m']['nlayers'] = 24

models['gpt3-l'] = {}
models['gpt3-l']['seqlen'] = 2048
models['gpt3-l']['dmodel'] = 1536
models['gpt3-l']['dhidden'] = 1536*4
models['gpt3-l']['nlayers'] = 24

models['gpt3-xl'] = {}
models['gpt3-xl']['seqlen'] = 2048
models['gpt3-xl']['dmodel'] = 2560
models['gpt3-xl']['dhidden'] = 2560*4
models['gpt3-xl']['nlayers'] = 24

models['gpt3-3b'] = {}
models['gpt3-3b']['seqlen'] = 2048
models['gpt3-3b']['dmodel'] = 2560
models['gpt3-3b']['dhidden'] = 2560*4
models['gpt3-3b']['nlayers'] = 32

models['gpt3-7b'] = {}
models['gpt3-7b']['seqlen'] = 2048
models['gpt3-7b']['dmodel'] = 4096
models['gpt3-7b']['dhidden'] = 4096*4
models['gpt3-7b']['nlayers'] = 32

models['gpt3-13b'] = {}
models['gpt3-13b']['seqlen'] = 2048
models['gpt3-13b']['dmodel'] = 5120
models['gpt3-13b']['dhidden'] = 5120*4
models['gpt3-13b']['nlayers'] = 40

models['gpt3-175b'] = {}
models['gpt3-175b']['seqlen'] = 2048
models['gpt3-175b']['dmodel'] = 12288
models['gpt3-175b']['dhidden'] = 12288*4
models['gpt3-175b']['nlayers'] = 96