Train a model to solve coding problems using GRPO and TRL
This example demonstrates how to run GRPO on Modal using the TRL GRPO trainer. GRPO is a reinforcement learning algorithm introduced by DeepSeek, and was used to train DeepSeek R1. TRL is a reinforcement learning training library by Hugging Face.
First we perform the imports and then define the app.
from __future__ import annotations
import os
import re
import subprocess
from pathlib import Path
from typing import Iterable, Sequence
import modal
# The Modal App that the functions below (and the reward Sandboxes) are registered on.
app: modal.App = modal.App("grpo-trl-example")
We define an image in which we install the TRL library, together with vLLM (used in the next part of this example) and Weights & Biases for logging.
# Container image for the training functions: TRL (with its vLLM extra),
# Hugging Face datasets, and wandb for logging, all pinned to exact versions.
image: modal.Image = modal.Image.debian_slim().pip_install(
    "trl[vllm]==0.19.1", "datasets==3.5.1", "wandb==0.17.6"
)
We import the libraries that are needed in the context of the image.
# These imports are intended for containers built from `image`, where the
# packages are installed; the local client launching the app need not have them.
with image.imports():
    from datasets import Dataset, load_dataset
    from trl import GRPOConfig, GRPOTrainer
We also define a Modal Volume for storing model checkpoints.
# Mount point of the checkpoint Volume inside containers: the trainer uses it
# as its output_dir, and the inference server loads checkpoints from it.
MODELS_DIR = Path("/models")
# Persistent Modal Volume holding model checkpoints; created if it does not exist.
checkpoints_volume: modal.Volume = modal.Volume.from_name(
    "grpo-trl-example-checkpoints", create_if_missing=True
)
Defining the reward function
In this example, we use the OpenCoder-LLM/opc-sft-stage2 dataset to train a model to solve coding problems.
In reinforcement learning, we define a reward function for the model. Since we are evaluating code that is generated by a model, we use Modal Sandboxes to evaluate the code securely.
For each model completion and its test cases, we define a simple reward function: it returns 1 if the generated code runs without errors (i.e. all test assertions pass), and 0 otherwise. You might want to refine this reward function, as the model is unlikely to learn well from such a sparse, binary signal.
@app.function()
def compute_reward(completion: str, testcase: Sequence[str]) -> int:
    """Score one completion by running it against its test cases in a Sandbox.

    The completion's code and the test-case statements are combined into a
    single program, which is executed with ``python -c`` inside a fresh Modal
    Sandbox with a 30 second exec timeout.

    Args:
        completion: Raw text generated by the model.
        testcase: Test statements (assertions) appended after the code.

    Returns:
        1 if the program exits with return code 0, otherwise 0. Exceptions
        raised while executing inside the sandbox are printed and scored 0.
    """
    # Build the program *before* creating the sandbox: if this step raised
    # after the sandbox existed, nothing would ever terminate it.
    code_to_execute = get_generated_code_and_test_cases(completion, testcase)

    sb: modal.Sandbox = modal.Sandbox.create(app=app)
    score = 0
    try:
        process = sb.exec("python", "-c", code_to_execute, timeout=30)
        process.wait()
        if process.returncode == 0:
            score = 1
    except Exception as e:
        # Best effort: an execution failure counts as a failed solution
        # instead of aborting the whole batch of rewards.
        print(e)
    finally:
        # Always release the sandbox, whatever happened above.
        sb.terminate()
    return score
We write a function that constructs a program from the model completion. How the code is extracted depends on the format of the data: completions are expected to contain a fenced code block of the form ```python ... ```. The test cases are a list of assert statements. More details here.
def get_generated_code_and_test_cases(completion: str, testcase: Sequence[str]) -> str:
    """Build an executable program from a model completion and its test cases.

    The code is taken from the first fenced block in ``completion``. A block
    opened with ```python is preferred. Otherwise the first ``` fence is used,
    and its info string (e.g. ``py``) on the opening line is skipped. If the
    completion contains no fence, the whole completion is treated as code. An
    unterminated fence runs to the end of the completion.

    Args:
        completion: Raw model output.
        testcase: Test statements (e.g. ``assert`` lines) to run after the code.

    Returns:
        The extracted code, a blank line, then the test statements joined by
        newlines.
    """
    code = _extract_code(completion)
    test_cases = "\n".join(testcase)
    return f"{code}\n\n{test_cases}"


def _extract_code(completion: str) -> str:
    """Return the stripped contents of the first fenced code block, or the whole text."""
    fence = "```"
    python_fence = fence + "python"

    python_open = completion.find(python_fence)
    if python_open != -1:
        body_start = python_open + len(python_fence)
    else:
        generic_open = completion.find(fence)
        if generic_open == -1:
            return completion.strip()
        body_start = generic_open + len(fence)
        # Skip an info string such as "py" on the opening line, but only when
        # that line ends before the closing fence (so "```x = 1```" still works).
        line_end = completion.find("\n", body_start)
        close = completion.find(fence, body_start)
        if line_end != -1 and (close == -1 or line_end < close):
            body_start = line_end + 1

    body_end = completion.find(fence, body_start)
    if body_end == -1:
        body_end = len(completion)
    return completion[body_start:body_end].strip()
Finally, we define the function that is passed into the GRPOTrainer, which takes in a list of completions. Custom reward functions must conform to a specific signature.
def reward_helper_function(
    completions: Sequence[str], testcases: Sequence[Sequence[str]], **kwargs: object
) -> list[float]:
    """Reward function passed to GRPOTrainer.

    Scores every completion against the test cases of the prompt it was
    generated from, fanning the work out to remote ``compute_reward`` calls.

    Args:
        completions: Model completions for the current batch.
        testcases: The dataset's ``testcases`` column, aligned with ``completions``.
        **kwargs: Other batch columns supplied by the trainer (unused).

    Returns:
        One reward (1.0 or 0.0) per completion, in input order. TRL expects a
        list of floats, so the lazy ``starmap`` results are materialised here.

    Raises:
        ValueError: If ``completions`` and ``testcases`` differ in length
            (``zip`` would otherwise silently drop rewards).
    """
    if len(completions) != len(testcases):
        raise ValueError(
            f"Got {len(completions)} completions but {len(testcases)} test-case lists"
        )
    # starmap runs one compute_reward call per (completion, testcase) pair in
    # parallel and yields results in input order.
    return [float(score) for score in compute_reward.starmap(zip(completions, testcases))]
Kicking off a training run
Preprocess the data, preparing the columns that GRPOTrainer expects.
We use the OpenCoder-LLM educational instruct dataset, which has (instruction, code, test case) triples validated through a Python compiler.
More details here.
def start_grpo_trainer(
    use_vllm: bool = False,
    vllm_mode: str | None = None,
    *,
    model: str = "Qwen/Qwen2-0.5B-Instruct",
    num_train_samples: int | None = 128,
    max_steps: int = 5,
) -> None:
    """Prepare the OpenCoder educational-instruct dataset and run GRPO training.

    Checkpoints are saved every step to MODELS_DIR, and metrics are reported
    to Weights & Biases.

    Args:
        use_vllm: Generate completions with vLLM during training.
        vllm_mode: ``"server"`` or ``"colocate"``; only relevant when
            ``use_vllm`` is True. When ``None``, GRPOConfig's own default is
            kept rather than being overridden with ``None``.
        model: Hugging Face model id to fine-tune.
        num_train_samples: Train on only the first this-many examples (clamped
            to the dataset size) to keep test runs short; ``None`` uses the
            whole split.
        max_steps: Number of optimizer steps. The default keeps test runs
            short; pass ``-1`` to train for the configured number of epochs.
    """
    dataset: Dataset = load_dataset(
        "OpenCoder-LLM/opc-sft-stage2", "educational_instruct", split="train"
    )
    # GRPOTrainer reads prompts from the "prompt" column.
    dataset = dataset.rename_column("instruction", "prompt")
    # reward_helper_function receives this column as its `testcases` argument.
    dataset = dataset.rename_column("testcase", "testcases")
    if num_train_samples is not None:
        # Clamp so a split smaller than the requested size does not raise.
        dataset = dataset.select(range(min(num_train_samples, len(dataset))))

    optional_args = {}
    if vllm_mode is not None:
        optional_args["vllm_mode"] = vllm_mode

    training_args: GRPOConfig = GRPOConfig(
        output_dir=str(MODELS_DIR),
        report_to="wandb",
        use_vllm=use_vllm,
        save_steps=1,
        max_steps=max_steps,
        **optional_args,
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_helper_function,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
We use Weights & Biases for logging, hence we use a Modal Secret with wandb credentials.
@app.function(
    image=image,
    gpu="H100",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    # Mount at MODELS_DIR so the trainer's output_dir lands on the Volume.
    volumes={str(MODELS_DIR): checkpoints_volume},
)
def train() -> None:
    """Run GRPO training on a single H100 without vLLM.

    wandb credentials come from the ``wandb-secret`` Modal Secret; checkpoints
    are written to ``checkpoints_volume``.
    """
    start_grpo_trainer()
To run: modal run --detach grpo_trl.py::train
Speeding up training with vLLM
vLLM can be used either in server mode (a vLLM server running on separate GPUs) or in colocate mode (within the training process). In server mode, vLLM runs in a separate process (using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference. More details here. Here, we use 2 GPUs: the vLLM server runs on one of them and the GRPOTrainer on the other.
@app.function(
    image=image,
    gpu="H100:2",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={str(MODELS_DIR): checkpoints_volume},
)
def train_vllm_server_mode() -> None:
    """Run GRPO with a standalone vLLM server for generation, on two GPUs.

    GPU 0 runs ``trl vllm-serve``; GPU 1 runs the trainer, which connects to
    that server. The server subprocess is terminated when training finishes
    or fails, so it does not outlive the run.
    """
    # The server only sees GPU 0.
    server_env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
    # Start vllm-serve in the background; it must serve the same model that
    # start_grpo_trainer trains.
    vllm_server = subprocess.Popen(
        ["trl", "vllm-serve", "--model", "Qwen/Qwen2-0.5B-Instruct"],
        env=server_env,
    )
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Run the training process on GPU 1
        start_grpo_trainer(use_vllm=True, vllm_mode="server")
    finally:
        vllm_server.terminate()
        try:
            vllm_server.wait(timeout=60)
        except subprocess.TimeoutExpired:
            # The server ignored SIGTERM; force it down.
            vllm_server.kill()
            vllm_server.wait()
You can execute this using modal run --detach grpo_trl.py::train_vllm_server_mode
In colocate mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. More details here.
@app.function(
    image=image,
    gpu="H100",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={str(MODELS_DIR): checkpoints_volume},
)
def train_vllm_colocate_mode() -> None:
    """Run GRPO with vLLM colocated in the training process on one H100."""
    # Colocate mode expects the torch.distributed environment variables that a
    # launcher such as torchrun would normally export; set them by hand for a
    # single-process, single-node run.
    os.environ.update(
        {
            "RANK": "0",  # rank of this process (0 for single-process training)
            "LOCAL_RANK": "0",  # rank of the process on this node
            "WORLD_SIZE": "1",  # total number of processes
            "MASTER_ADDR": "localhost",  # address of the master node
            "MASTER_PORT": "12355",  # port for inter-process communication
        }
    )
    start_grpo_trainer(use_vllm=True, vllm_mode="colocate")
You can execute this using modal run --detach grpo_trl.py::train_vllm_colocate_mode
Performing inference on the trained model
We use vLLM to perform inference on the trained model.
# Port the vLLM server listens on inside the `serve` container; the same value
# is passed to modal.web_server so requests are routed to it.
VLLM_PORT: int = 8000
Once you have the model checkpoints in your Modal Volume, you can load the weights and perform inference using vLLM.
The trainer saves each checkpoint in a directory named checkpoint-n inside the Volume, where n is the training step (e.g. checkpoint-5).
The helper below picks the checkpoint directory with the highest step number.
def get_latest_checkpoint_file_path(models_dir: Path | None = None) -> str:
    """Return the path of the highest-numbered ``checkpoint-N`` directory.

    Args:
        models_dir: Directory to search; defaults to ``MODELS_DIR``.

    Returns:
        The path, as a string, of the actual directory entry with the largest
        N. Steps are compared numerically, so ``checkpoint-10`` beats
        ``checkpoint-9``.

    Raises:
        FileNotFoundError: If the directory does not exist or contains no
            ``checkpoint-N`` subdirectories.
    """
    search_dir = MODELS_DIR if models_dir is None else Path(models_dir)
    if not search_dir.is_dir():
        raise FileNotFoundError(f"Models directory {search_dir} does not exist")

    # fullmatch: unlike `^...$`, it does not accept a trailing newline.
    pattern = re.compile(r"checkpoint-(\d+)")
    latest_step, latest_dir = -1, None
    for entry in search_dir.iterdir():
        match = pattern.fullmatch(entry.name)
        if match is None or not entry.is_dir():
            continue
        step = int(match.group(1))
        if step > latest_step:
            latest_step, latest_dir = step, entry

    if latest_dir is None:
        raise FileNotFoundError("No checkpoint directories found in models dir")
    # Return the real entry rather than rebuilding the name from the integer,
    # which would break for names like "checkpoint-05".
    return str(latest_dir)
We provide the code for setting up an OpenAI-compatible inference endpoint here. For more details about serving models with vLLM, check out this example.
# Image for the inference server: Python 3.12 with pinned vLLM and FlashInfer,
# using the PyTorch CUDA 12.8 wheel index as an extra package index.
vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "vllm==0.9.1",
        "flashinfer-python==0.2.6.post1",
        extra_index_url="https://download.pytorch.org/whl/cu128",
    )
    # Select vLLM's V1 engine.
    .env({"VLLM_USE_V1": "1"})
)
# Volume mounted at vLLM's cache directory (/root/.cache/vllm) in `serve`, so
# its contents persist across containers.
vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
@app.function(
    image=vllm_image,
    gpu="H100",
    scaledown_window=15 * 60,  # How long should we stay up with no requests?
    timeout=10 * 60,  # How long should we wait for container start?
    volumes={
        "/root/.cache/vllm": vllm_cache_vol,
        # str key, consistent with the training functions' mounts.
        str(MODELS_DIR): checkpoints_volume,
    },
)
@modal.concurrent(max_inputs=32)  # How many requests can one replica handle? tune carefully!
@modal.web_server(port=VLLM_PORT, startup_timeout=10 * 60)
def serve():
    """Serve the latest GRPO checkpoint with vLLM's OpenAI-compatible server.

    The server is started as a background process listening on VLLM_PORT,
    which ``modal.web_server`` exposes as this function's URL.
    """
    latest_checkpoint_file_path = get_latest_checkpoint_file_path()
    cmd = [
        "vllm",
        "serve",
        latest_checkpoint_file_path,
        "--uvicorn-log-level=info",
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
    ]
    # Pass argv directly (no shell) so the checkpoint path is never re-split
    # or interpreted by a shell.
    subprocess.Popen(cmd)
You can then deploy the server using modal deploy grpo_trl.py, which gives you a custom URL. You can then query it using the following curl command:
curl -X POST <url>/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"messages": [
{"role": "system", "content": "You are a helpful assistant for solving math problems."},
{"role": "user", "content": "James had 4 apples. Mary gave him 2 and he ate 1. How many does he have left?"}
],
"temperature": 0.7
}'
or in the following ways.