Instruction fine-tuning (IFT) is one of the most effective techniques for improving a language model’s performance on domain-specific tasks. Full-parameter IFT on top open-source models like Mistral and Llama 3 requires hundreds of gigabytes of GPU VRAM. Even with top-end hardware, it’s impractical to train such a model on a single machine simply because the model and training state won’t fit in VRAM.
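To see why, a rough back-of-the-envelope estimate using the standard accounting for mixed-precision training with Adam: each parameter needs about 2 bytes for the fp16 weights, 2 bytes for the gradient, and 12 bytes for the fp32 master copy plus the two optimizer moments, or roughly 16 bytes in total. A 7B-parameter model therefore carries on the order of 112 GB of model and optimizer state before counting activations, and a 70B model carries over 1 TB.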
That’s where provisioning compute nodes with a cloud service like Amazon SageMaker comes into play. When you start a training job, SageMaker will spin up a virtual cluster of nodes connected by a high-speed InfiniBand interconnect. In practice, that interconnect lets you treat all of your GPUs as if they belonged to a single node in your training topology.
Unless you’re adventurous, you don’t want to hand roll code to partition your model and optimizer across all the GPUs in your cluster. I say leave that to big tech! Microsoft developed a PyTorch wrapper called DeepSpeed that distributes your training with minimal changes to your training loop and model architecture.
DeepSpeed’s ZeRO is a memory-efficient architecture for data-parallel training that partitions model parameters, gradients, and optimizer states across all GPUs, gathering each layer’s parameters on demand during the forward and backward passes. This slightly increases communication overhead but makes it possible to fit an arbitrarily large model given enough compute nodes.
In this blog post, I want to share how you can deploy a customized version of DeepSpeed on SageMaker. I struggled through this in my first 2 weeks onboarding as an MLE at Lindy, and hope that this post will save you time.
The magic of SageMaker
SageMaker is a suite of MLOps tools from AWS that includes an on-demand model training platform. A SageMaker training job will automatically provision compute nodes and orchestrate a training process per GPU.
The Huggingface Trainer has an integration with SageMaker, which can fine-tune language models. However, it’s quite rigid about pulling in extra dependencies and customizing your training loop.
So instead, we’ll use SageMaker to orchestrate a custom Docker container with DeepSpeed pre-installed. Throughout the entire process, SageMaker will:
- Provision the specified number of training nodes
- Grant the training nodes an execution IAM role
- Start the custom Docker image our training script runs in
- Download datasets from S3 and mount them into /opt/ml/input/data
- Download the training script into /opt/ml/code
- Set up the OpenMPI environment variables that tell DeepSpeed how to communicate with other nodes
- Set any environment variables and hyper-parameters provided
- Run your training script
- Copy the training output from /opt/ml/checkpoints and /opt/ml/model and upload it to S3
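Inside the container, these locations are also exposed as environment variables by the training toolkit, so your script doesn’t need to hard-code paths. A small sketch of the ones we’ll rely on (the names are the toolkit’s standard SM_* variables):

import os

# Standard variables set by the SageMaker training toolkit inside the container.
train_dir = os.environ["SM_CHANNEL_TRAIN"]  # e.g. /opt/ml/input/data/train
model_dir = os.environ["SM_MODEL_DIR"]      # /opt/ml/model, uploaded to S3 at the end
num_gpus = int(os.environ["SM_NUM_GPUS"])   # GPUs available on this node
hosts = os.environ["SM_HOSTS"]              # JSON list of all node hostnames

print(train_dir, model_dir, num_gpus, hosts)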
Granting prerequisite permissions with an execution role
Your training script will run in a sandboxed environment on SageMaker nodes. The “execution role” specified at launch time defines which AWS services those nodes can access.
A training script requires access to the following services:
- AWS SageMaker – to control the training job
- CloudWatch – to stream training logs
- ECR – to pull the DeepSpeed image from an internal registry
- S3 – to download the dataset and model, and to upload training outputs
You can create this role manually in the AWS Console or apply my sagemaker.tf with Terraform. The role name will be passed to CreateTrainingJob later.
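If you’d rather script the role than click through the console (and you don’t use Terraform), here’s a minimal boto3 sketch. The role name matches the one used later in the launch script, and attaching the broad AmazonSageMakerFullAccess managed policy is just a convenient starting point; scope permissions down for production and grant S3 access to your training bucket separately.

import json
import boto3

iam = boto3.client("iam")

# Trust policy that lets the SageMaker service assume the role.
trust_policy = {
    "Version": "2012-10-17",
    "Statement": [{
        "Effect": "Allow",
        "Principal": {"Service": "sagemaker.amazonaws.com"},
        "Action": "sts:AssumeRole",
    }],
}

iam.create_role(
    RoleName="YourSageMakerRole",
    AssumeRolePolicyDocument=json.dumps(trust_policy),
)

# A broad managed policy as a starting point; add a bucket-scoped S3 policy as needed.
iam.attach_role_policy(
    RoleName="YourSageMakerRole",
    PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
)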
Docker image with pre-installed DeepSpeed
After SageMaker provisions the compute nodes for a training job, it runs your Docker image on each node to start the training loop. The entrypoint program is responsible for reading SageMaker’s environment variables, communicating with the other nodes, and setting up a unified distributed training process.
Fortunately, we can install the sagemaker-training-toolkit library and configure it to start an OpenMPI session when the training job starts. The toolkit will launch the session from the master node by connecting to each worker node through SSH. The worker nodes then poll the OpenMPI session until it terminates.
The DeepSpeed repository already has a Dockerfile for setting up a PyTorch + DeepSpeed environment. To use it on SageMaker, you can use a modified version that adds sagemaker-training-toolkit and passwordless SSH – openblitz/deepspeed.
Putting the model and dataset on S3
SageMaker will download your training and validation datasets – and, optionally, your model – from S3. I recommend uploading Hugging Face models to S3 to avoid Hugging Face Hub rate limits and to speed up training setup.
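If you want to mirror a Hugging Face model into S3, here’s a minimal sketch (the bucket name and key prefix are placeholders, and it assumes huggingface_hub and boto3 are installed locally):

import boto3
from pathlib import Path
from huggingface_hub import snapshot_download

# Download the model snapshot locally, then mirror every file into S3.
local_dir = Path(snapshot_download("mistralai/Mistral-7B-Instruct-v0.2"))

s3 = boto3.client("s3")
for file in local_dir.rglob("*"):
    if file.is_file():
        s3.upload_file(
            str(file),
            "yourbucket",
            f"models/mistral-7b/{file.relative_to(local_dir)}",
        )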
The contents of the dataset and model S3 URIs are copied into a read-only filesystem at /opt/ml/input/data in each Docker container.
The training script
Now that we’ve squared away the setup for SageMaker, we can finally get to the fun stuff – writing the training script!
To initialize DeepSpeed, we need to specify the master host’s address and port:
import json
import os

# SageMaker serializes the job topology into the SM_TRAINING_ENV variable.
training_env = json.loads(os.environ["SM_TRAINING_ENV"])

# torch.distributed reads these to rendezvous; os.environ values must be strings.
os.environ["MASTER_ADDR"] = training_env["master_addr"]
os.environ["MASTER_PORT"] = str(training_env["master_port"])
Then, the model can be loaded and switched into training mode:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

model.config.use_cache = False
model.train()
Finally, we initialize a DeepSpeed engine that wraps the model:
import deepspeed

# Hand the model's parameters to DeepSpeed so it can build the optimizer
# defined in deepspeed_config.json.
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="deepspeed_config.json",
)
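The deepspeed_config.json itself isn’t shown in this post. Here’s a minimal ZeRO stage 3 sketch of what it might contain, with illustrative, untuned values. It’s written as a Python dict because deepspeed.initialize also accepts a dict directly via config=; serialize it with json.dump if you prefer to keep the separate file.

# Illustrative ZeRO stage 3 config; tune batch sizes and learning rate for your setup.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 2e-5},
    },
}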
Last thing before the training loop – loading the dataset! A distributed sampler should be used to ensure that data-parallel workers draw examples from different partitions of the dataset.
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset

dataset = load_dataset(
    "json",
    data_files="/opt/ml/input/data/train/train.jsonl",
    split="train",
)

# Tokenize on the fly as batches are requested.
dataset.set_transform(lambda batch: tokenizer(
    batch["text"],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
))

dataloader = DataLoader(
    dataset=dataset,
    sampler=DistributedSampler(
        dataset=dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
    ),
    batch_size=1,
)
After loading the model and dataset, we can run a training loop that performs full-parameter fine-tuning:
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(10):
    # Reshuffle the sampler's data partition each epoch.
    dataloader.sampler.set_epoch(epoch)
    for step, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(engine.device)
        attention_mask = batch["attention_mask"].to(engine.device)
        logits = engine(input_ids=input_ids, attention_mask=attention_mask)["logits"]
        # Shift labels by one to the left for autoregressive prediction
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]
        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten both
        loss = loss_fn(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )
        engine.backward(loss)
        engine.step()
        # engine.step() will automatically zero the gradients
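Optionally, you can also write resumable DeepSpeed checkpoints to /opt/ml/checkpoints, which SageMaker uploads to S3 alongside /opt/ml/model as noted earlier. A hedged sketch, meant to sit inside the outer loop at the end of each epoch (assuming the container user can write to that directory; otherwise use the same sudo workaround described below):

    # All ranks must call save_checkpoint; the tag name is an illustrative choice.
    engine.save_checkpoint(save_dir="/opt/ml/checkpoints", tag=f"epoch-{epoch}")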
At the end, all workers must save the model into a temporary directory, since our Docker image uses a non-root user that doesn’t have write access to /opt/ml/model. This is mostly a relic of DeepSpeed’s original Dockerfile. A sudo mv command then moves the model into the output directory.
import subprocess

engine.save_16bit_model(save_dir="/home/deepspeed/tmp/model")
subprocess.run(
    ["bash", "-c", f"sudo mv /home/deepspeed/tmp/model/* {os.environ['SM_MODEL_DIR']}"],
    check=True,
)
Here’s the sample program: fine_tune.py
Launching a training job
To start a training job, we’ll invoke the CreateTrainingJob endpoint using the SageMaker Python SDK.
import sagemaker
from os import path

estimator = sagemaker.estimator.Estimator(
    image_uri="https://registry.hub.docker.com/v2/openblitz/deepspeed",
    role="YourSageMakerRole",
    instance_count=1,
    instance_type="ml.p4d.24xlarge",
    output_path="s3://yourbucket/output",
    base_job_name="test-fine-tune",
    hyperparameters={
        # This tells the training toolkit to use OpenMPI to launch the distributed training session
        "sagemaker_mpi_enabled": True,
        # This is the number of GPUs in ml.p4d.24xlarge instances
        "sagemaker_mpi_num_of_processes_per_host": 8,
        "sagemaker_mpi_custom_mpi_options": "--NCCL_DEBUG=WARN",
    },
    environment={
        # This fixes an issue with /tmp becoming corrupted
        "TMPDIR": "/home/deepspeed/tmp",
    },
    # SageMaker will upload fine_tune.py in your current directory
    source_dir=path.dirname(path.abspath(__file__)),
    entry_point="fine_tune.py",
)

estimator.fit(
    {"train": "s3://yourbucket/train.jsonl"},
    wait=True,
)
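Once the job finishes, SageMaker packages whatever the script wrote to /opt/ml/model into a model.tar.gz under output_path. The estimator exposes its S3 URI, which you can pull down locally; a short sketch (the local path is arbitrary):

from sagemaker.s3 import S3Downloader

# S3 URI of the packaged model artifact produced by the training job.
print(estimator.model_data)

# Download the artifact to inspect it locally.
S3Downloader.download(estimator.model_data, "./fine-tuned-model")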
And there you go, now you can fine-tune any large language model on SageMaker! If you made it this far, I’d love to know what you’re training in the comments.