Enforcing a Thinking Budget During GRPO in verl

December 20, 2025

I recently started a new project in my lab, where I was tasked with implementing a thinking budget in GRPO (Group Relative Policy Optimization). In other words, we needed to control how many thinking tokens a reasoning model is allowed to produce during the rollout phase.

Since I obviously wasn't going to write a training stack from scratch, I decided to use verl as a jumping-off point. Although I was familiar with TRL (Hugging Face's Transformer Reinforcement Learning library) and its GRPO implementation, I found it unable to support the longer sequence lengths we needed. Since we planned to run experiments with 6K+ tokens, TRL was no longer a viable option without some hacky workarounds. Additionally, most of the papers I've read used verl for their experiments, so I decided it was time to see what the hype was all about.

Background

Model families like Qwen3 and Olmo3 have explicit thinking modes, where the LLM reasons and corrects itself before outputting its final answer. Typically, this span is delimited by special tokens like <think> and </think>. We'll refer to the latter as the thinking delimiter from here on out.
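To make the terminology concrete, here is a minimal sketch that splits a list of generated token ids at the thinking delimiter. It assumes </think> is encoded as a single token id (which holds when the tokenizer registers it as a special token); the function name split_at_delimiter and the id 99 used below are hypothetical and not part of verl or vLLM.

```python
from typing import Optional


def split_at_delimiter(
    token_ids: list[int], delimiter_id: int
) -> tuple[list[int], Optional[list[int]]]:
    """Split generated ids into (thinking tokens, response tokens).

    Assumes the closing delimiter (e.g. </think>) is a single token id.
    Returns (token_ids, None) when the delimiter never appears, i.e. the
    model is still "thinking" when generation stopped.
    """
    try:
        # First occurrence, matching output_ids.index(delimiter_id).
        pos = token_ids.index(delimiter_id)
    except ValueError:
        return list(token_ids), None
    # The thinking span keeps the delimiter; the response is everything after it.
    return list(token_ids[: pos + 1]), list(token_ids[pos + 1 :])
```

For example, split_at_delimiter([1, 2, 99, 5, 6], 99) returns ([1, 2, 99], [5, 6]), while a completion that never emits 99 comes back with None as its response part.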

This thinking phase often lasts for thousands of tokens, eating up the entire max_tokens limit set by the user. Even though it would be extremely useful to control how long the model can think, this typically isn't supported natively by inference or training libraries.

Lucky for us, this isn't a difficult fix at all!

Implementation

The only part of the code we need to change is the rollout phase, where we now make 2 calls to vLLM instead of just 1. Let's say we have a thinking budget of $n$ tokens and a response budget of $m$ tokens, meaning the longest possible output is $n+m$ tokens (plus the delimiter token, if we have to append it). Our first call to vLLM generates up to $n$ tokens from the base prompt. If the thinking delimiter is NOT found in the output ids, we append it and generate $m$ new tokens. If the thinking delimiter is in the output ids, let delimiter_pos = output_ids.index(delimiter_id). Assuming the first call used its full $n$ tokens, the $n - \text{delimiter\_pos} - 1$ tokens after the delimiter already count toward the response, so we only generate $m - (n - \text{delimiter\_pos} - 1)$ more tokens. In both cases, the prompt for the second call is the base prompt followed by the output tokens from the first call (including the appended delimiter, if any).
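The budget arithmetic above can be sketched as a small pure function. This is a hypothetical helper (plan_second_call is not from verl): it uses the actual length of the first output, which equals $n$ whenever the first call hits its cap, and clamps the remaining budget at zero. The assumption is that a result of 0 means the response budget is already spent and the caller skips the second call.

```python
def plan_second_call(
    output_ids: list[int], delimiter_id: int, m: int
) -> tuple[list[int], int]:
    """Return (tokens to append before call 2, max_tokens for call 2).

    output_ids: tokens from the first call (len == n if it hit the cap).
    m: response budget. A returned budget of 0 means "skip call 2".
    """
    if delimiter_id not in output_ids:
        # Thinking was cut off: force-close it and grant the full response budget.
        return [delimiter_id], m
    delimiter_pos = output_ids.index(delimiter_id)
    # Tokens after the delimiter are already part of the response.
    already_used = len(output_ids) - delimiter_pos - 1
    return [], max(m - already_used, 0)
```

As a worked example, with $n = 8$, $m = 5$ and the delimiter at index 5 of an 8-token first output, indices 6 and 7 are already response tokens, so the second call gets $5 - (8 - 5 - 1) = 3$ tokens.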

The original function can be found here and my implementation here, but for the sake of this blog post, I'll write a simplified version of the code.

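from typing import Any, Optional

from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput

# TokenOutput, _qwen2_5_vl_dedup_image_tokens and the VLLM_LORA_* constants come from verl itself (not shown).

# Must be async: the body awaits the engine and iterates an async generator.
async def generate(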
    self,
    prompt_ids: list[int],
    sampling_params: dict[str, Any],
    request_id: str,
    image_data: Optional[list[Any]] = None,
) -> TokenOutput:
    """Generate sequence with token-in-token-out."""
    # TODO(@wuxibin): switch to `/generate` http endpoint once multi-modal support ready.
    max_tokens = self.config.max_model_len - len(prompt_ids)
    sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None
    sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
    sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
    prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
    prompt = TokensPrompt(
        prompt_token_ids=prompt_ids, 
        multi_modal_data={"image": image_data} if image_data else None
    )

    # Add lora request
    lora_request = None
    if self.model_config.lora_rank > 0:
        # Make sure we also check that the lora is already loaded in the engine
        lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras()
        if lora_loaded:
            lora_request = LoRARequest(
                lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH
            )

    generator = self.engine.generate(
        prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request
    )

    # Get final response
    final_res: Optional[RequestOutput] = None
    async for output in generator:
        final_res = output
    assert final_res is not None

    token_ids = final_res.outputs[0].token_ids
    log_probs = None
    if sampling_params.logprobs is not None:
        log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(final_res.outputs[0].logprobs)]

    routed_experts = None
    if self.config.enable_rollout_routing_replay:
        routed_experts = final_res.outputs[0].routed_experts

    # Determine stop reason from finish_reason
    finish_reason = final_res.outputs[0].finish_reason
    if finish_reason == "abort":
        stop_reason = "aborted"
    elif finish_reason in ("stop", "length"):
        stop_reason = "completed"
    else:
        stop_reason = finish_reason  # for more stop reason in the future

    return TokenOutput(
        token_ids=token_ids, log_probs=log_probs, routed_experts=routed_experts, stop_reason=stop_reason
    )
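The function above makes a single vLLM call. Here is a hedged sketch of how the two calls described earlier could be composed. To keep it self-contained and testable, it takes a hypothetical generate_fn(prompt_ids, max_tokens) that returns (token_ids, finish_reason); in verl this would wrap one self.engine.generate(...) call with SamplingParams(max_tokens=...), as in the function above. It leaves out log-probabilities, LoRA and multimodal inputs, and it assumes that any first-pass finish reason other than "length" means the model stopped on its own, so no second pass is needed.

```python
from typing import Callable

# Hypothetical wrapper around a single vLLM call:
# (prompt_ids, max_tokens) -> (generated token ids, finish_reason)
GenerateFn = Callable[[list[int], int], tuple[list[int], str]]


def generate_with_thinking_budget(
    generate_fn: GenerateFn,
    prompt_ids: list[int],
    delimiter_id: int,
    n: int,
    m: int,
) -> list[int]:
    """Two-pass rollout: think for at most n tokens, then answer in at most m."""
    # Pass 1: generate from the base prompt, capped at the thinking budget.
    first, finish_reason = generate_fn(prompt_ids, n)
    if finish_reason != "length":
        # Assumption: the model stopped on its own (e.g. "stop"), so we are done.
        return first

    if delimiter_id in first:
        # Thinking already closed; tokens after the delimiter count toward m.
        pos = first.index(delimiter_id)
        forced: list[int] = []
        budget = m - (len(first) - pos - 1)
    else:
        # Thinking was cut off by the budget: close it ourselves.
        forced = [delimiter_id]
        budget = m

    if budget <= 0:
        # Response budget already spent inside the first pass.
        return first + forced

    # Pass 2: continue from the base prompt plus everything produced so far.
    second, _ = generate_fn(prompt_ids + first + forced, budget)
    return first + forced + second
```

Keeping the forced delimiter in the returned ids means the final sequence matches exactly what the second call was conditioned on.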
