上記記事ではVRAM12GBのRTX3060を利用して、400Mパラメータ程度のMoE付きTransformerを構築して、ランダム初期値のパラメータから事前学習を行い、テキストを出力させるところまで実施しました。
とはいえ、自宅にGPUがない方もいらっしゃるかと思いますので、今回は無料版のGoogle Colabでどこまでできるかを試してみた記事になります
無料版のGoogle Colabで事前学習する上で課題となるのは下記の部分です。
例えば、保存領域が最大15GBしか利用できないため、学習済みモデルの重みのサイズを気をつける必要があります、最低限15GB以下にする必要があります。
また、重みのチェックポイントというのは、一定間隔で常に保存し続ける必要があります。
なぜなら、Google Colabの無料版は3時間程度でセッションが切れてしまうため、学習の途中経過を定期的に保存し、再度Google Colabの利用制限が解除されたら途中から学習を実施できるようにしなければなりません。
(3時間程度で学習は完了しません)
自宅PCでの処理であれば、チェックポイントが貯まってきたら古いものを自動消去することで、ストレージを圧迫しないようにすることは可能です。
しかし、Google Colabでは、削除データは自動的にゴミ箱に貯まってしまい、ゴミ箱から手動で消さなければストレージを圧迫し続けます。
また、このストレージの中には、キャッシュしたデータセットもおく必要があります。
(学習のたびにデータセットのダウンロードや、トークン化処理を実施していたら時間がもったいない)
前回の記事で使っていたデータセットのキャッシュサイズは約30GB、モデルのチェックポイントのサイズが約5GBであるため、このままでは無料版のGoogle Colabでは学習させることはできません。
したがって、学習データ量を小さくすること、モデルサイズを小さくすることは必須になります。
また、T4 GPUは非常に古いGPUのため、bfloat16
が利用できません。それも課題になってきます。
無料版Google Colab
GPU:T4 GPU
Google Driveストレージ:まっさらな状態を推奨(15GBの空き)
なお、本リポジトリやコードの詳細などは前回の記事をご覧ください。
git clone https://github.com/personabb/LightLM_public_repo.git
マイドライブ直下に「LightLM
」フォルダを作成し、リポジトリの中身のファイルを全てLightLM
フォルダにアップロードしてください。
学習したモデルをHuggingFaceにアップロードしたい場合は、HF tokenを取得してください。
Huggingfaceページの右上アイコンをクリックして、「Access Tokens」から作成することができます。
取得したtokenは、下記の部分に設定してください。
これにより、同じアカウントを利用している限り、ノートブックが変わったとしても利用可能です。
Google ColabにてLightLM_Colab_train.ipynb
を開いて、すべてのセルを実行してください。
実行すると、250stepごとに、モデルのチェックポイントの保存(model_testing-small
)と、検証データによる評価が保存(log/eval-small.txt
)されます。
また、このチェックポイントには学習中の全てのデータ(モデル重みだけでなく、OptimizerやScheduler、lr、dataset idxなど)が保存されます。従って途中からの学習の再開も可能です。
以下に学習設定や、モデル設定を記載します。
修正しても問題ないですが、無料版のGoogle Colab, Google Drive上で学習する都合上、大きく増加させることはできないかなと思います。
LightLM_Colab_train.ipynb
train_config = TrainerConfig(
・・・
use_dtype="float16" if device == 'cuda' else "float32",
・・・
checkpoints_frequency=250,
path_to_checkpoints="/content/drive/MyDrive/LightLM/model_testing-small",
max_checkpoints_to_keep=4,
tokenized_dataset_path = "HuggingFaceFW/fineweb-edu",
sub_target_files = [
"data/CC-MAIN-2025-26/000_00048.parquet",
"data/CC-MAIN-2025-26/000_00049.parquet"
],
eval_log_file="/content/drive/MyDrive/LightLM/log/eval-small.txt",
・・・
)
Google Colabの無料枠で利用できるGPUであるT4 GPUは、bfloat16を利用できないので、float16を利用しています。
また、学習用のデータセットに関しても、CC-MAIN-2025-26
フォルダすらすべてダウンロードしたキャッシュをすることは、無料版のGoogle Driveのストレージでは無理なので、2ファイル分(000_00048.parquet
、000_00049.parquet
)だけで学習することにします。
この2ファイルだけだと、215,939,584token(約2億token)の学習データになります。
LightLM_Colab_train.ipynb
config = ModelConfig(
vocab_size=tokenizer.vocab_size,
num_dims=512,
num_heads=16,
num_kv_heads=4,
num_layers=12,
ffn_hidden_dims=512 * 4,
rmsnorm_eps=1e-6,
rope_theta=1e5,
context_len=512,
use_cache=False,
use_flash=True,
use_moe=True,
moe_num_experts=3,
moe_active_experts=1,
moe_eps=1e-6,
moe_aux_loss_coef=0.01,
moe_shared_experts=1,
use_lossfreebalance=False,
)
前回の記事で利用したモデルよりも、Transformerブロックの数や、MoEのExpert数を減らしています。
モデルパラメータとしては、184.06M(Active:108.55M)のTransformerになります。
以下の4セル目を適切に修正したのちに、HF_Colab.ipynb
のすべてのセルを実行してください。
HF_Colab.ipynb
default_checkpoint = lightlm_path + "/model_testing-small/model.checkpoint.epoch0_step23500_global23500.pt"
model_dir = lightlm_path + "/hf_model-small"
repo_name = "your_username/your_repo_name"
private = False
default_checkpoint
は学習した上で、評価データの損失が最も低いチェックポイントを指定してください。
「model.checkpoint.epoch0_step23500_global23500.pt
」の部分は、学習の進捗によって変わります」model_dir
はHF形式に変換したデータを保存するディレクトリです。変更不要です。repo_name
は自身のユーザネームと保存したいリポジトリ名を指定してください。private
はHFのリポジトリが公開か非公開かを選択します。ストレージ圧迫するのもどうかと思うので公開で良いかと思います。
本ノートブックを実行後、Huggingfaceにアップロードがなされているかと思います。
学習途中のチェックポイント(例えばmodel.checkpoint.epoch0_step16000_global16000.pt
など)を利用して推論をする場合はLightLM_Colab_infer.ipynb
を実行します。
実行前に4セル名のモデルパラメータを学習時と同じものに設定し、5セル目のcheckpoint_path
を利用したいチェックポイントを指定します。
また、6セル目でプロンプトと生成パラメータ(temperature
など)を設定し実行してください。
Huggingfaceへアップロードしたモデルを推論に利用する場合は、HF_inference_Colab.ipynb
を実行します。
実行前に、2セル目でリポジトリ名の指定や、プロンプトの設定、3セル目のgenerate
メソッドでtemperature
などのパラメータを設定してください。
設定可能なパラメータはLightLM_Colab_infer.ipynb
と同様です。
今回は、学習データ量を減らしたため、1epochあたり3279stepのみでした。
自宅のPCであれば、丸一日学習できれば4epochの学習が完了します。
Global Step: 250, Epoch: 0, Step: 250, val_loss: 5.8105, norm: 0.6994, lr: 1.1594202899e-04, time: 5.91s, tok/s: 11067.0 | dataset idx: 421649/421757
Global Step: 500, Epoch: 0, Step: 500, val_loss: 4.9471, norm: 0.7153, lr: 2.0652173913e-04, time: 5.63s, tok/s: 11610.4 | dataset idx: 421541/421757
Global Step: 750, Epoch: 0, Step: 750, val_loss: 4.4959, norm: 0.5326, lr: 2.9710144928e-04, time: 5.57s, tok/s: 11751.1 | dataset idx: 421433/421757
Global Step: 1000, Epoch: 0, Step: 1000, val_loss: 4.1430, norm: 0.5438, lr: 3.8768115942e-04, time: 5.63s, tok/s: 11617.0 | dataset idx: 421325/421757
Global Step: 1250, Epoch: 0, Step: 1250, val_loss: 3.9000, norm: 0.4086, lr: 4.7826086957e-04, time: 5.61s, tok/s: 11659.7 | dataset idx: 421217/421757
Global Step: 1500, Epoch: 0, Step: 1500, val_loss: 3.6893, norm: 0.3595, lr: 4.9971243571e-04, time: 5.76s, tok/s: 11359.8 | dataset idx: 421109/421757
Global Step: 1750, Epoch: 0, Step: 1750, val_loss: 3.5402, norm: 0.3253, lr: 4.9845925999e-04, time: 5.65s, tok/s: 11570.8 | dataset idx: 421001/421757
Global Step: 2000, Epoch: 0, Step: 2000, val_loss: 3.4457, norm: 0.3139, lr: 4.9621733556e-04, time: 5.57s, tok/s: 11738.2 | dataset idx: 420893/421757
Global Step: 2250, Epoch: 0, Step: 2250, val_loss: 3.3438, norm: 0.3039, lr: 4.9299658233e-04, time: 5.58s, tok/s: 11713.8 | dataset idx: 420785/421757
Global Step: 2500, Epoch: 0, Step: 2500, val_loss: 3.2853, norm: 0.2956, lr: 4.8881125131e-04, time: 5.59s, tok/s: 11705.9 | dataset idx: 420677/421757
Global Step: 2750, Epoch: 0, Step: 2750, val_loss: 3.2458, norm: 0.2930, lr: 4.8367986147e-04, time: 5.65s, tok/s: 11580.7 | dataset idx: 420569/421757
Global Step: 3000, Epoch: 0, Step: 3000, val_loss: 3.1882, norm: 0.2930, lr: 4.7762511788e-04, time: 5.62s, tok/s: 11648.0 | dataset idx: 420461/421757
Global Step: 3250, Epoch: 0, Step: 3250, val_loss: 3.1629, norm: 0.3240, lr: 4.7067381120e-04, time: 5.61s, tok/s: 11654.2 | dataset idx: 420353/421757
Global Step: 3278, Epoch: 0, Step: 3278, val_loss: 3.1273, norm: 0.2951, lr: 4.6984074466e-04, time: 5.72s, tok/s: 11425.0 | dataset idx: 420245/421757
Global Step: 3500, Epoch: 1, Step: 221, val_loss: 3.1290, norm: 0.3036, lr: 4.6285669913e-04, time: 5.61s, tok/s: 11651.3 | dataset idx: 420137/421757
Global Step: 3750, Epoch: 1, Step: 471, val_loss: 3.0477, norm: 0.2877, lr: 4.5420837035e-04, time: 5.57s, tok/s: 11743.0 | dataset idx: 420029/421757
Global Step: 4000, Epoch: 1, Step: 721, val_loss: 3.0432, norm: 0.2976, lr: 4.4476709145e-04, time: 5.58s, tok/s: 11719.1 | dataset idx: 419921/421757
Global Step: 4250, Epoch: 1, Step: 971, val_loss: 3.0405, norm: 0.2892, lr: 4.3457463762e-04, time: 5.64s, tok/s: 11588.3 | dataset idx: 419813/421757
Global Step: 4500, Epoch: 1, Step: 1221, val_loss: 3.0247, norm: 0.2925, lr: 4.2367610780e-04, time: 5.59s, tok/s: 11693.3 | dataset idx: 419705/421757
Global Step: 4750, Epoch: 1, Step: 1471, val_loss: 2.9962, norm: 0.3036, lr: 4.1211972513e-04, time: 5.59s, tok/s: 11699.4 | dataset idx: 421705/421757
Global Step: 5000, Epoch: 1, Step: 1721, val_loss: 2.9384, norm: 0.2959, lr: 3.9995662357e-04, time: 5.58s, tok/s: 11729.6 | dataset idx: 421597/421757
Global Step: 5250, Epoch: 1, Step: 1971, val_loss: 2.9432, norm: 0.3000, lr: 3.8724062167e-04, time: 5.56s, tok/s: 11760.1 | dataset idx: 421489/421757
Global Step: 5500, Epoch: 1, Step: 2221, val_loss: 2.8901, norm: 0.2942, lr: 3.7402798440e-04, time: 5.61s, tok/s: 11654.3 | dataset idx: 421381/421757
Global Step: 5750, Epoch: 1, Step: 2471, val_loss: 2.8764, norm: 0.3046, lr: 3.6037717423e-04, time: 5.61s, tok/s: 11669.1 | dataset idx: 421273/421757
Global Step: 6000, Epoch: 1, Step: 2721, val_loss: 2.8534, norm: 0.3114, lr: 3.4634859242e-04, time: 5.56s, tok/s: 11755.0 | dataset idx: 421165/421757
Global Step: 6250, Epoch: 1, Step: 2971, val_loss: 2.8333, norm: 0.3161, lr: 3.3200431176e-04, time: 5.55s, tok/s: 11789.0 | dataset idx: 421057/421757
Global Step: 6500, Epoch: 1, Step: 3221, val_loss: 2.7851, norm: 0.3120, lr: 3.1740780195e-04, time: 5.57s, tok/s: 11732.9 | dataset idx: 420949/421757
Global Step: 6557, Epoch: 1, Step: 3278, val_loss: 2.8190, norm: 0.3156, lr: 3.1405118376e-04, time: 5.71s, tok/s: 11449.7 | dataset idx: 420841/421757
Global Step: 6750, Epoch: 2, Step: 192, val_loss: 2.7864, norm: 0.3142, lr: 3.0262364872e-04, time: 5.60s, tok/s: 11674.0 | dataset idx: 420733/421757
Global Step: 7000, Epoch: 2, Step: 442, val_loss: 2.7736, norm: 0.3155, lr: 2.8771726808e-04, time: 5.59s, tok/s: 11708.7 | dataset idx: 420625/421757
Global Step: 7250, Epoch: 2, Step: 692, val_loss: 2.7894, norm: 0.3266, lr: 2.7275461685e-04, time: 5.61s, tok/s: 11658.2 | dataset idx: 420517/421757
Global Step: 7500, Epoch: 2, Step: 942, val_loss: 2.7963, norm: 0.3200, lr: 2.5780190086e-04, time: 5.64s, tok/s: 11599.9 | dataset idx: 420409/421757
Global Step: 7750, Epoch: 2, Step: 1192, val_loss: 2.7613, norm: 0.3261, lr: 2.4292528196e-04, time: 5.54s, tok/s: 11809.7 | dataset idx: 420301/421757
Global Step: 8000, Epoch: 2, Step: 1442, val_loss: 2.7409, norm: 0.3304, lr: 2.2819058528e-04, time: 5.55s, tok/s: 11793.6 | dataset idx: 420193/421757
Global Step: 8250, Epoch: 2, Step: 1692, val_loss: 2.7475, norm: 0.3265, lr: 2.1366300801e-04, time: 5.60s, tok/s: 11676.8 | dataset idx: 420085/421757
Global Step: 8500, Epoch: 2, Step: 1942, val_loss: 2.7326, norm: 0.3301, lr: 1.9940683087e-04, time: 5.61s, tok/s: 11656.2 | dataset idx: 419977/421757
Global Step: 8750, Epoch: 2, Step: 2192, val_loss: 2.7202, norm: 0.3333, lr: 1.8548513371e-04, time: 5.56s, tok/s: 11754.6 | dataset idx: 419869/421757
Global Step: 9000, Epoch: 2, Step: 2442, val_loss: 2.6826, norm: 0.3494, lr: 1.7195951639e-04, time: 5.56s, tok/s: 11757.7 | dataset idx: 419761/421757
Global Step: 9250, Epoch: 2, Step: 2692, val_loss: 2.6671, norm: 0.3477, lr: 1.5888982624e-04, time: 5.58s, tok/s: 11718.8 | dataset idx: 419653/421757
Global Step: 9500, Epoch: 2, Step: 2942, val_loss: 2.6512, norm: 0.3331, lr: 1.4633389321e-04, time: 5.59s, tok/s: 11704.0 | dataset idx: 421653/421757
Global Step: 9750, Epoch: 2, Step: 3192, val_loss: 2.6311, norm: 0.3370, lr: 1.3434727402e-04, time: 5.55s, tok/s: 11784.4 | dataset idx: 421545/421757
Global Step: 9836, Epoch: 2, Step: 3278, val_loss: 2.6606, norm: 0.3489, lr: 1.3036514311e-04, time: 5.61s, tok/s: 11651.6 | dataset idx: 421437/421757
Global Step: 10000, Epoch: 3, Step: 163, val_loss: 2.6370, norm: 0.3509, lr: 1.2298300631e-04, time: 5.58s, tok/s: 11728.8 | dataset idx: 421329/421757
Global Step: 10250, Epoch: 3, Step: 413, val_loss: 2.6324, norm: 0.3576, lr: 1.1229137400e-04, time: 5.54s, tok/s: 11805.5 | dataset idx: 421221/421757
Global Step: 10500, Epoch: 3, Step: 663, val_loss: 2.6414, norm: 0.3553, lr: 1.0231968476e-04, time: 5.65s, tok/s: 11584.2 | dataset idx: 421113/421757
Global Step: 10750, Epoch: 3, Step: 913, val_loss: 2.6010, norm: 0.3533, lr: 9.3112060706e-05, time: 5.54s, tok/s: 11796.9 | dataset idx: 421005/421757
Global Step: 11000, Epoch: 3, Step: 1163, val_loss: 2.5890, norm: 0.3563, lr: 8.4709243161e-05, time: 5.53s, tok/s: 11836.0 | dataset idx: 420897/421757
Global Step: 11250, Epoch: 3, Step: 1413, val_loss: 2.5630, norm: 0.3537, lr: 7.7148412395e-05, time: 5.55s, tok/s: 11784.6 | dataset idx: 420789/421757
Global Step: 11500, Epoch: 3, Step: 1663, val_loss: 2.5723, norm: 0.3672, lr: 7.0463023103e-05, time: 5.59s, tok/s: 11704.5 | dataset idx: 420681/421757
Global Step: 11750, Epoch: 3, Step: 1913, val_loss: 2.5652, norm: 0.3678, lr: 6.4682656383e-05, time: 5.57s, tok/s: 11751.5 | dataset idx: 420573/421757
Global Step: 12000, Epoch: 3, Step: 2163, val_loss: 2.5753, norm: 0.3747, lr: 5.9832888844e-05, time: 5.53s, tok/s: 11832.4 | dataset idx: 420465/421757
Global Step: 12250, Epoch: 3, Step: 2413, val_loss: 2.5752, norm: 0.3672, lr: 5.5935179439e-05, time: 5.65s, tok/s: 11583.0 | dataset idx: 420357/421757
Global Step: 12500, Epoch: 3, Step: 2663, val_loss: 2.5639, norm: 0.3732, lr: 5.3006774510e-05, time: 5.69s, tok/s: 11495.3 | dataset idx: 420249/421757
Global Step: 12750, Epoch: 3, Step: 2913, val_loss: 2.5446, norm: 0.3675, lr: 5.1060631482e-05, time: 5.55s, tok/s: 11790.5 | dataset idx: 420141/421757
Global Step: 13000, Epoch: 3, Step: 3163, val_loss: 2.5736, norm: 0.3794, lr: 5.0105361530e-05, time: 5.58s, tok/s: 11719.2 | dataset idx: 420033/421757
Global Step: 13115, Epoch: 3, Step: 3278, val_loss: 2.5512, norm: 0.3927, lr: 5.0000000000e-05, time: 5.55s, tok/s: 11791.4 | dataset idx: 419925/421757
Google ColabでHF_inference_Colab.ipynb
を実行した際の出力結果になります。
意外と、無料版のGoogle Colabでも小さなTransformerの事前学習が可能なんだなというのがわかって少し驚きました。
小さめのLanguage Modelであれば個人でもどんどん作れそうなので、何か特化させて処理させるAIを作るのも面白そうですね。(自然言語処理ではなくとも)
Views: 0