Skip to content

vLLM Logits Processor

LogitsLoggingProcessor

Bases: LogitsProcessor

Custom vLLM logits processor that logs normalization constants and entropy to Valkey.

This processor intercepts logits during generation to calculate and store normalization constants and entropy values. These values are stored in Valkey for retrieval by the main process after generation completes.

Attributes:

Name Type Description
active_req_ids Dict[int, sampling_params]

Dictionary mapping request indices to their sampling parameters.

valkey_client

Valkey client for storing computed values.

temperature

Default temperature value.

Source code in pita/inference/vllm_logits_processor.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class LogitsLoggingProcessor(LogitsProcessor):
    """
    Custom vLLM logits processor that logs normalization constants and entropy to Valkey.

    This processor intercepts logits during generation to calculate and store normalization
    constants and entropy values. These values are stored in Valkey for retrieval by the
    main process after generation completes.

    Attributes:
        active_req_ids: Dictionary mapping request indices to their sampling parameters.
        valkey_client: Valkey client for storing computed values (lazily connected).
        temperature: Default temperature value.
    """
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        is_pin_memory: bool
    ) -> None:
        """
        Initialize the LogitsLoggingProcessor.

        Args:
            vllm_config: vLLM configuration object (required by the vLLM
                LogitsProcessor interface; not consumed here).
            device: PyTorch device for tensor operations (not consumed here).
            is_pin_memory: Whether to use pinned memory for tensors (not consumed here).
        """
        self.active_req_ids: Dict[int, sampling_params] = {}
        # The Valkey connection is opened lazily in _ensure_valkey() so that
        # constructing the processor never fails on a missing/slow server.
        self.valkey_client = None
        self.temperature = 1.0  # Default temperature, can be configured per request

    def _ensure_valkey(self) -> None:
        """
        Ensure Valkey client is initialized and connected.

        Lazily initializes the Valkey connection on first use to avoid connection
        issues during processor instantiation. On failure, ``valkey_client``
        remains None and callers must degrade gracefully.
        """
        if self.valkey_client is None:
            try:
                self.valkey_client = valkey.Valkey(
                    host=VALKEY_HOST, port=VALKEY_PORT, db=0, decode_responses=True
                )
            except Exception as e:
                # If we can't log to Valkey, we are flying blind, but try printing just in case
                print(f"CRITICAL WORKER ERROR: Valkey connect failed: {e}")

    def is_argmax_invariant(self) -> bool:
        """
        Indicate whether this processor changes which token has the highest probability.

        Returns:
            False to ensure apply() is always called, even when it doesn't change argmax.
        """
        return False  # Must be False to ensure apply() is called

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        """
        Update processor state when requests are added, removed, or moved in the batch.

        This method is called by vLLM to notify the processor of batch changes. It tracks
        request IDs and their associated sampling parameters.

        Args:
            batch_update: Information about requests added, removed, or moved in the batch.
                Can be None if no updates occurred.
        """
        self._ensure_valkey()

        if batch_update is None:
            return

        for req_index, params, _, _ in batch_update.added:
            # Debug: Check if extra_args survived the trip
            args_str = str(params.extra_args) if params.extra_args else "None"

            # Update the req_id map
            if params.extra_args and "req_id" in params.extra_args:
                req_id = params.extra_args["req_id"]
                self.active_req_ids[req_index] = sampling_params(
                    req_id,
                    params.extra_args.get("normalization_constants", False),
                    params.temperature,
                    params.extra_args.get("entropy", False),
                    params.extra_args.get("entropy_inference", False),
                    params.extra_args.get("gradient_steps", 0),
                    params.extra_args.get("learning_rate", 0.0),
                    params.extra_args.get("delta", 0.0)
                )
            else:
                print(f"WARNING: No req_id found in extra_args for req_index {req_index}. extra_args: {args_str}. Logits logging will be skipped for this request.")

        # Handle removals to keep map clean
        for req_index in batch_update.removed:
            self.active_req_ids.pop(req_index, None)

        # Handle index movements
        for from_idx, to_idx, direction in batch_update.moved:
            if direction == MoveDirectionality.SWAP:
                # Swap entries, tolerating either side being untracked (a
                # request without logging params must not raise KeyError).
                a = self.active_req_ids.pop(from_idx, None)
                b = self.active_req_ids.pop(to_idx, None)
                if a is not None:
                    self.active_req_ids[to_idx] = a
                if b is not None:
                    self.active_req_ids[from_idx] = b
            else:
                if from_idx in self.active_req_ids:
                    self.active_req_ids[to_idx] = self.active_req_ids[from_idx]
                    del self.active_req_ids[from_idx]
                else:
                    # The destination slot is being overwritten by an untracked
                    # request; drop any stale entry so it is not misattributed.
                    self.active_req_ids.pop(to_idx, None)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Process logits to calculate and log normalization constants and entropy.

        This method is called by vLLM for each token generation step. It calculates
        normalization constants (logsumexp) and entropy values, then stores them in
        Valkey for later retrieval.

        Args:
            logits: Raw logits tensor of shape (batch_size, vocab_size).

        Returns:
            The unmodified logits tensor (this processor only observes, doesn't modify).
        """
        self._ensure_valkey()

        if not self.active_req_ids:
            print("WARNING: active_req_ids is empty in apply()!")
            return logits

        batch_size = logits.size(0)
        # Size the per-row buffers by the actual batch and index them by batch
        # row, so a tracked index beyond the current batch (possible after
        # removals leave sparse indices) can never raise IndexError here.
        log_norm_constant = torch.zeros(batch_size, device=logits.device)
        log_norm_constant_temp_scaled = torch.zeros(batch_size, device=logits.device)
        entropy = torch.zeros(batch_size, device=logits.device)

        # Calculate the Normalization Constants if normalization_constants = True or entropy = True
        for row_idx, params in self.active_req_ids.items():
            if row_idx >= batch_size:
                continue  # reported in the push loop below
            if params.normalization_constants:
                # Calculate the Normalization Constants if required
                log_norm_constant[row_idx] = torch.logsumexp(logits[row_idx], dim=-1)
                log_norm_constant_temp_scaled[row_idx] = torch.logsumexp(
                    logits[row_idx] / params.temperature, dim=-1
                )
            # If entropy = True, calculate the entropy
            if params.entropy:
                entropy[row_idx] = Categorical(logits=logits[row_idx]).entropy()

        # If the Valkey connection could not be established we cannot log;
        # return the logits untouched rather than raising AttributeError.
        if self.valkey_client is None:
            return logits

        # Prepare pipeline for batch Valkey operations
        pipe = self.valkey_client.pipeline()

        found_any = False
        for row_idx, params in self.active_req_ids.items():
            if row_idx < batch_size:
                # Store as a comma-separated string with all normalization info
                data = f"{log_norm_constant[row_idx]},{log_norm_constant_temp_scaled[row_idx]},{entropy[row_idx]}"
                pipe.rpush(params.req_id, data)
                found_any = True
            else:
                print(f"row_idx {row_idx} >= batch size {batch_size}, skipping")

        # Push values to Valkey if we found any valid ones
        if found_any:
            pipe.execute()

        return logits

__init__(vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool) -> None

Initialize the LogitsLoggingProcessor.

Parameters:

Name Type Description Default
vllm_config VllmConfig

vLLM configuration object.

required
device device

PyTorch device for tensor operations.

required
is_pin_memory bool

Whether to use pinned memory for tensors.

required
Source code in pita/inference/vllm_logits_processor.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    is_pin_memory: bool
) -> None:
    """
    Set up an empty processor; the Valkey connection is opened lazily later.

    Args:
        vllm_config: vLLM configuration object (required by the processor
            interface; not consumed in this initializer).
        device: PyTorch device for tensor operations (not consumed here).
        is_pin_memory: Whether to use pinned memory for tensors (not consumed here).
    """
    # Request-index -> sampling_params map, populated by update_state().
    self.active_req_ids: Dict[int, sampling_params] = {}
    # Connection is deferred to _ensure_valkey() on first use.
    self.valkey_client = None
    # Default temperature; individual requests may carry their own.
    self.temperature = 1.0

apply(logits: torch.Tensor) -> torch.Tensor

Process logits to calculate and log normalization constants and entropy.

This method is called by vLLM for each token generation step. It calculates normalization constants (logsumexp) and entropy values, then stores them in Valkey for later retrieval.

Parameters:

Name Type Description Default
logits Tensor

Raw logits tensor of shape (batch_size, vocab_size).

required

Returns:

Type Description
Tensor

The unmodified logits tensor (this processor only observes, doesn't modify).

Source code in pita/inference/vllm_logits_processor.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    """
    Process logits to calculate and log normalization constants and entropy.

    This method is called by vLLM for each token generation step. It calculates
    normalization constants (logsumexp) and entropy values, then stores them in
    Valkey for later retrieval.

    Args:
        logits: Raw logits tensor of shape (batch_size, vocab_size).

    Returns:
        The unmodified logits tensor (this processor only observes, doesn't modify).
    """
    self._ensure_valkey()

    if not self.active_req_ids:
        print("WARNING: active_req_ids is empty in apply()!")
        return logits

    batch_size = logits.size(0)
    # Size the per-row buffers by the actual batch and index them by batch
    # row, so a tracked index beyond the current batch (possible after
    # removals leave sparse indices) can never raise IndexError here.
    log_norm_constant = torch.zeros(batch_size, device=logits.device)
    log_norm_constant_temp_scaled = torch.zeros(batch_size, device=logits.device)
    entropy = torch.zeros(batch_size, device=logits.device)

    # Calculate the Normalization Constants if normalization_constants = True or entropy = True
    for row_idx, params in self.active_req_ids.items():
        if row_idx >= batch_size:
            continue  # reported in the push loop below
        if params.normalization_constants:
            # Calculate the Normalization Constants if required
            log_norm_constant[row_idx] = torch.logsumexp(logits[row_idx], dim=-1)
            log_norm_constant_temp_scaled[row_idx] = torch.logsumexp(
                logits[row_idx] / params.temperature, dim=-1
            )
        # If entropy = True, calculate the entropy
        if params.entropy:
            entropy[row_idx] = Categorical(logits=logits[row_idx]).entropy()

    # If the Valkey connection could not be established we cannot log;
    # return the logits untouched rather than raising AttributeError.
    if self.valkey_client is None:
        return logits

    # Prepare pipeline for batch Valkey operations
    pipe = self.valkey_client.pipeline()

    found_any = False
    for row_idx, params in self.active_req_ids.items():
        if row_idx < batch_size:
            # Store as a comma-separated string with all normalization info
            data = f"{log_norm_constant[row_idx]},{log_norm_constant_temp_scaled[row_idx]},{entropy[row_idx]}"
            pipe.rpush(params.req_id, data)
            found_any = True
        else:
            print(f"row_idx {row_idx} >= batch size {batch_size}, skipping")

    # Push values to Valkey if we found any valid ones
    if found_any:
        pipe.execute()

    return logits

is_argmax_invariant() -> bool

Indicate whether this processor changes which token has the highest probability.

Returns:

Type Description
bool

False to ensure apply() is always called, even when it doesn't change argmax.

Source code in pita/inference/vllm_logits_processor.py
86
87
88
89
90
91
92
93
def is_argmax_invariant(self) -> bool:
    """
    Report whether this processor can change which token has the highest probability.

    vLLM may skip apply() for argmax-invariant processors during greedy
    sampling, so this deliberately returns False to guarantee apply()
    runs on every generation step.
    """
    return False

update_state(batch_update: Optional[BatchUpdate]) -> None

Update processor state when requests are added, removed, or moved in the batch.

This method is called by vLLM to notify the processor of batch changes. It tracks request IDs and their associated sampling parameters.

Parameters:

Name Type Description Default
batch_update Optional[BatchUpdate]

Information about requests added, removed, or moved in the batch. Can be None if no updates occurred.

required
Source code in pita/inference/vllm_logits_processor.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
    """
    Update processor state when requests are added, removed, or moved in the batch.

    This method is called by vLLM to notify the processor of batch changes. It tracks
    request IDs and their associated sampling parameters.

    Args:
        batch_update: Information about requests added, removed, or moved in the batch.
            Can be None if no updates occurred.
    """
    self._ensure_valkey()

    if batch_update is None:
        return

    for req_index, params, _, _ in batch_update.added:
        # Debug: Check if extra_args survived the trip
        args_str = str(params.extra_args) if params.extra_args else "None"

        # Update the req_id map
        if params.extra_args and "req_id" in params.extra_args:
            req_id = params.extra_args["req_id"]
            self.active_req_ids[req_index] = sampling_params(
                req_id,
                params.extra_args.get("normalization_constants", False),
                params.temperature,
                params.extra_args.get("entropy", False),
                params.extra_args.get("entropy_inference", False),
                params.extra_args.get("gradient_steps", 0),
                params.extra_args.get("learning_rate", 0.0),
                params.extra_args.get("delta", 0.0)
            )
        else:
            print(f"WARNING: No req_id found in extra_args for req_index {req_index}. extra_args: {args_str}. Logits logging will be skipped for this request.")

    # Handle removals to keep map clean
    for req_index in batch_update.removed:
        self.active_req_ids.pop(req_index, None)

    # Handle index movements
    for from_idx, to_idx, direction in batch_update.moved:
        if direction == MoveDirectionality.SWAP:
            # Swap entries, tolerating either side being untracked (a
            # request without logging params must not raise KeyError,
            # unlike the original unconditional double assignment).
            a = self.active_req_ids.pop(from_idx, None)
            b = self.active_req_ids.pop(to_idx, None)
            if a is not None:
                self.active_req_ids[to_idx] = a
            if b is not None:
                self.active_req_ids[from_idx] = b
        else:
            if from_idx in self.active_req_ids:
                self.active_req_ids[to_idx] = self.active_req_ids[from_idx]
                del self.active_req_ids[from_idx]
            else:
                # The destination slot is being overwritten by an untracked
                # request; drop any stale entry so it is not misattributed.
                self.active_req_ids.pop(to_idx, None)

sampling_params dataclass

Sampling parameters for logits processing requests.

Attributes:

Name Type Description
req_id str

Unique identifier for the request.

normalization_constants bool

Whether to calculate normalization constants.

temperature float

Sampling temperature value.

entropy bool

Whether to calculate entropy.

entropy_inference bool

Whether entropy is used for inference decisions.

gradient_steps int

Number of gradient steps for optimization.

learning_rate float

Learning rate for optimization.

delta float

Delta value for optimization adjustments.

Source code in pita/inference/vllm_logits_processor.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
@dataclass
class sampling_params:
    """
    Per-request configuration consumed by the logits processor.

    Attributes:
        req_id: Unique identifier for the request.
        normalization_constants: If True, log-normalization constants are computed.
        temperature: Sampling temperature applied when scaling logits.
        entropy: If True, per-step entropy is computed.
        entropy_inference: If True, entropy is used for inference decisions.
        gradient_steps: Number of gradient steps for optimization.
        learning_rate: Learning rate for optimization.
        delta: Delta value for optimization adjustments.
    """
    req_id: str
    normalization_constants: bool
    temperature: float
    entropy: bool
    entropy_inference: bool
    gradient_steps: int
    learning_rate: float
    delta: float