CoLLiE

CoLLiE (Collaborative Tuning of Large Language Models in an Efficient Way) is a comprehensive toolbox that helps you train large language models from scratch.


[ 简体中文 ] | [ English ]

Latest News

Table of Contents

Why CoLLiE

CoLLiE (Collaborative Tuning of Large Language Models in an Efficient Way) is a complete toolkit for training large models from scratch, providing data preprocessing, model fine-tuning, model saving and monitoring of training metrics. CoLLiE integrates existing parallel strategies, efficient parameter fine-tuning methods and high-efficiency optimizers to speed up training, improve training quality and reduce training cost. CoLLiE supports a wide range of models (e.g. MOSS, InternLM, LLaMA, ChatGLM, etc.). In addition, CoLLiE provides rich documentation so that beginners can get started quickly, while its highly customisable features and flexible configuration options allow experienced users to adapt it to their needs. Whether you are a beginner or an experienced professional, CoLLiE has a solution for you.

Features

CoLLiE provides collaborative and efficient tuning methods for large language models based on DeepSpeed and PyTorch. It primarily includes the following four features:

Full Features

CoLLiE Supported Models

Evaluation

Memory Requirements

The memory requirements are profiled with tensor parallelism. The results with batch size 1, sequence length 2048 and gradient accumulation steps 2 are shown below:
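For a rough sense of where the parameter part of that memory comes from, here is a back-of-envelope sketch; the 7B parameter count and tp_size below are illustrative assumptions, not the profiled setting:

# Back-of-envelope estimate of fp16 parameter memory per GPU under tensor parallelism.
# The model size and tp_size are illustrative, not the profiled configuration.
num_params = 7e9           # e.g. a 7B-parameter model
bytes_per_param = 2        # fp16
tp_size = 2                # tensor parallel degree
params_gb = num_params * bytes_per_param / tp_size / 1024**3
print(f"~{params_gb:.1f} GB of parameters per GPU, before gradients, optimizer states and activations")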

Throughput

The throughput with the Adam optimizer at different batch sizes on A100 and RTX-3090 is shown below:
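Throughput here is easiest to read as tokens processed per GPU per second, the same quantity tracked by the TGSMonitor used in the Quick Start. A minimal sketch of that calculation under plain data parallelism, with made-up numbers:

# Tokens per GPU per second (TGS) under plain data parallelism; all values are made up.
batch_size = 16        # per-GPU micro batch size
seq_len = 2048         # sequence length
step_time = 4.0        # measured seconds per training step
tokens_per_gpu_per_second = batch_size * seq_len / step_time
print(f"~{tokens_per_gpu_per_second:.0f} tokens per GPU per second")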

Installation

pip install git+https://github.com/OpenLMLab/collie.git

Use Docker

Usage

Quick Start

Here we provide an example of training MOSS with CoLLiE, using the LOMO optimizer and enabling ZeRO3 to reduce memory consumption.

So, follow the steps below to start your LLM training journey~

Step 1: Import the necessary packages

from transformers import AutoTokenizer
from collie.config import CollieConfig
from collie.data import CollieDatasetForTraining
from collie.data import CollieDataLoader
from collie.optim.lomo import Lomo
from collie.controller.trainer import Trainer
from collie.controller.evaluator import EvaluatorForPerplexity, EvaluatorForGeneration
from collie.models.moss_moon import Moss003MoonForCausalLM
from collie.utils.monitor import StepTimeMonitor, TGSMonitor, MemoryMonitor, LossMonitor, EvalMonitor
from collie.metrics import DecodeMetric, PPLMetric
from collie.module import GPTLMLoss
from collie.utils.data_provider import GradioProvider

Step 2: Set your path (Optional)

Here, we choose MOSS as the pretrained model.

pretrained_model = "fnlp/moss-moon-003-sft"

Step 3: Set the CoLLiE Configuration

config = CollieConfig.from_pretrained(pretrained_model, trust_remote_code=True)
# Note that tp_size * dp_size * pp_size = the number of GPUs
# Tensor Parallel
config.tp_size = 2
# Data Parallel
config.dp_size = 1
# Pipeline Parallel
config.pp_size = 1
# the number of training epochs
config.train_epochs = 1
# evaluate every 100 steps
config.eval_per_n_steps = 100
# evaluate every 1 epoch
config.eval_per_n_epochs = 1
# the training batch size per GPU is 16
config.train_micro_batch_size = 16
# the evaluation batch size is 1
config.eval_batch_size = 1
# DeepSpeed Configuration
config.ds_config = {
        "fp16": {
            "enabled": True
        },
        "zero_allow_untested_optimizer": True,
        "zero_force_ds_cpu_optimizer": False,
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": False
            }
        },
        "monitor_config": {
            "enabled": True,
            "tag": "adan",
            "csv_monitor": {
                "enabled": True,
                "output_path": "./ds_logs/"
            }
        }
}
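As the comment above notes, the product of the three parallel sizes determines how many processes (GPUs) the launch command must start. A tiny plain-Python check you can add yourself (not a CoLLiE API):

# Plain Python, not a CoLLiE API: the product of the parallel sizes is the number of
# processes (GPUs) that the torchrun command in the final step must launch.
expected_world_size = config.tp_size * config.dp_size * config.pp_size
print(f"This configuration expects {expected_world_size} processes")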

Step 4: Set Tokenizer

tokenizer = AutoTokenizer.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True)

Step 5: Load datasets

Here we build a custom dataset; the data can be provided in either of two formats (see the sketch after the code below). You can refer to the tutorials for more details.

train_dataset = [
    {
        'input': 'Collie is a python package for ',
        'output': 'finetuning large language models.'
    } for _ in range(10000)
]
train_dataset = CollieDatasetForTraining(train_dataset, tokenizer)
eval_dataset = train_dataset[:32]
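The other accepted format is, as far as the linked tutorials describe it, a single 'text' field per sample for plain language-modeling data; treat the field name as an assumption and check the tutorials for the version you use:

# Alternative data format: one 'text' field per sample (field name assumed from the
# CoLLiE tutorials; verify against the documentation of the version you use).
pretrain_dataset = [
    {
        'text': 'Collie is a python package for finetuning large language models.'
    } for _ in range(10000)
]
pretrain_dataset = CollieDatasetForTraining(pretrain_dataset, tokenizer)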

Step 6: Load Pretrained Model

model = Moss003MoonForCausalLM.from_pretrained(pretrained_model, config=config)

Step 7: Set Optimizer

optimizer = Lomo(
    model,
    lr = 0.001,
    clip_grad_norm = 5.0
)

Step 8: Set Monitors

monitors = [
    # Time used per step
    StepTimeMonitor(config),
    # Tokens generated per gpu per second
    TGSMonitor(config),
    # Memory used
    MemoryMonitor(config),
    # Loss
    LossMonitor(config),
    # Evaluation Results
    EvalMonitor(config)
]

Step 9: Add Evaluators

Two evaluators are added here to calculate PPL (Perplexity) and to save Decode results.

evaluator_ppl = EvaluatorForPerplexity(
    model = model,
    config = config,
    dataset = eval_dataset,
    monitors = [
        EvalMonitor(config)
    ],
    metrics = {
        'ppl': PPLMetric()
    }
)
evaluator_decode = EvaluatorForGeneration(
    model = model,
    config = config,
    tokenizer = tokenizer,
    dataset = eval_dataset,
    monitors = [
        EvalMonitor(config)
    ],
    metrics = {
        'decode': DecodeMetric()
    }
)

Step 10: Instantiate the Trainer

trainer = Trainer(
    model = model,
    config = config,
    loss_fn = GPTLMLoss(-100),
    optimizer = optimizer,
    train_dataset = train_dataset,
    monitors = monitors,
    evaluators = [evaluator_ppl, evaluator_decode],
)
# Start training/evaluation
trainer.train()

Final step: launch it from the command line and start training! 👍

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --rdzv_backend=c10d --rdzv_endpoint=localhost:29402 --nnodes=1 --nproc_per_node=4 finetune_moss_for_training.py

If the following progress bar appears on your command line, then congratulations, you have successfully started training your Large Language model!

Documentation and Examples

CoLLiE provides online documentation. More examples are available at examples.

You can find complete codes at examples/finetune_moss_for_training.py.

Fun Plugins

CoLLiE provides a number of plug-and-play plugins. The following introduces the "Monitor" and the "Asynchronous DataProvider"; more plugins are waiting to be explored and developed...

Monitor

You can add a monitor configuration to CollieConfig.ds_config and pass the corresponding monitors to the Trainer to turn monitoring on during training (a wiring sketch follows the configuration below).

    "monitor_config": {
        # Turn on Monitor
        "enabled": True,
        # prefix of saved files
        "tag": "adan",
        # file format: csv
        "csv_monitor": {
            "enabled": True,
            # folders saved
            "output_path": "./ds_logs/"
        }
    }
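A minimal sketch of attaching this block to the config object from the Quick Start; the keys simply mirror the configuration above, and the monitors themselves are passed to the Trainer exactly as in Step 8 and Step 10:

# Attach the monitor block to the DeepSpeed config that CoLLiE uses (same keys as above).
config.ds_config["monitor_config"] = {
    "enabled": True,
    "tag": "adan",
    "csv_monitor": {
        "enabled": True,
        "output_path": "./ds_logs/"
    }
}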

After enabling the monitor, you will find the relevant files in the ds_logs folder, for example:

Asynchronous DataProvider

You just need to add a data_provider to the Trainer to start an asynchronous DataProvider during training, which lets you perform human evaluation at any time.

trainer = Trainer(
    model = model,
    config = config,
    loss_fn = GPTLMLoss(-100),
    optimizer = optimizer,
    train_dataset = train_dataset,
    monitors = monitors,
    evaluators = [evaluator_ppl, evaluator_decode],
    # Add
    data_provider = GradioProvider(tokenizer)
)

More Examples and Complete Tutorials

CoLLiE provides complete Tutorials. More Examples can be found in examples.

Community

Contributors

Cite Us