Skip to content

vLLM Backend

check_token_metric_compatibility(sampler: AutoregressiveSampler, token_metric: str)

Check that the vLLM engine can support the given token metric with the given configuration.

Parameters:

Name Type Description Default
sampler AutoregressiveSampler

The sampler object containing sampling parameters and the LLM engine.

required
token_metric str

The token metric to check compatibility for.

required

Raises:

Type Description
ValueError

If logits_per_token is not set.

ValueError

If vLLM engine logprobs_mode is not 'raw_logits'.

ValueError

If 'req_id' is not in extra_args.

Source code in pita/inference/vllm_backend.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def check_token_metric_compatibility(
    sampler: AutoregressiveSampler, 
    token_metric: str) -> None:
    """
    Check that the vLLM engine can support the given token metric with the given configuration.

    For logit-based metrics this validates the sampler/engine configuration and
    enables the extra per-token statistics the metric needs; for any other
    metric it is a no-op.

    Args:
        sampler: The sampler object containing sampling parameters and the LLM engine.
        token_metric: The token metric to check compatibility for.

    Raises:
        ValueError: If logits_per_token is not set.
        ValueError: If vLLM engine logprobs_mode is not 'raw_logits'.
        ValueError: If 'req_id' is not in extra_args.
    """
    # Only these metrics need raw logits / the normalization side channel.
    logit_based_metrics = {"logprobs", "power_distribution", "entropy", "likelihood_confidence"}
    if token_metric not in logit_based_metrics:
        # Nothing to validate or enable for other metrics.
        return

    # Make sure the user has actually set logits_per_token
    if sampler.sampling_params.logits_per_token < 1:
        raise ValueError("LLM engine logits_per_token must be set to at least 1 to enable power sampling.")

    # For vLLM, make sure that logprobs_mode is set to 'raw_logits' to get unprocessed logits
    logprobs_mode = sampler.llm.llm_engine.model_config.logprobs_mode
    if logprobs_mode != 'raw_logits':
        raise ValueError(
            "vLLM engine logprobs_mode must be set to 'raw_logits' to enable power sampling."
            f"\nvLLM engine logprobs_mode is set to {logprobs_mode}."
            "\nThis is done by setting logits=True when creating the LLM object."
        )

    # Print all the extra_args of the vLLM SamplingParams (debug aid)
    print("vLLM SamplingParams extra_args:", sampler.sampling_params.engine_params.extra_args)

    # Make sure the user has enabled the logits processor
    if 'req_id' not in sampler.sampling_params.engine_params.extra_args:
        raise ValueError("req_id must be set to use power sampling with vLLM.")

    # Metrics derived from the partition function need normalization constants.
    if token_metric in {"logprobs", "power_distribution", "likelihood_confidence"}:
        sampler.sampling_params.enable_normalization_constants = True
        print("Enabled normalization constants in vLLM SamplingParams for power sampling.")

    # Entropy-based metrics additionally need per-token entropy collection.
    if token_metric in {"entropy", "likelihood_confidence"}:
        sampler.sampling_params.enable_entropy = True
        print("Enabled entropy in vLLM SamplingParams for power sampling.")

create_LLM_object(model_name: str, model_type: str | None = None, dtype: str = 'auto', gpu_memory_utilization: float = 0.85, max_model_len: int = 2048, max_probs: int = 1000, logits_processor: bool = False, **kwargs: Any) -> LLM

Create the LLM object given the model name and engine parameters.

Parameters:

Name Type Description Default
model_name str

The name of the model to load.

required
model_type str

The type of model (e.g., 'safetensors', 'gguf'). Defaults to None.

None
dtype str

The data type to use. Defaults to "auto".

'auto'
gpu_memory_utilization float

The fraction of GPU memory to use. Defaults to 0.85.

0.85
max_model_len int

The maximum length of the model context. Defaults to 2048.

2048
max_probs int

Controls how many logprobs or logits are stored for each token. Defaults to 1000.

1000
logits_processor bool

Whether to enable the Redis logging logits processor. Defaults to False.

False
**kwargs Any

Additional keyword arguments passed to the LLM constructor.

{}

Returns:

Name Type Description
LLM LLM

The initialized vLLM LLM object.

Source code in pita/inference/vllm_backend.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def create_LLM_object(
        model_name: str,
        model_type: str | None = None,
        dtype: str = "auto",
        gpu_memory_utilization: float = 0.85,
        max_model_len: int = 2048,
        max_probs: int = 1000,
        logits_processor: bool = False,
        **kwargs: Any
    ) -> LLM:
    """
    Create the LLM object given the model name and engine parameters.

    Args:
        model_name (str): The name of the model to load.
        model_type (str, optional): The type of model (e.g., 'safetensors', 'gguf'). Defaults to None.
        dtype (str, optional): The data type to use. Defaults to "auto".
        gpu_memory_utilization (float, optional): The fraction of GPU memory to use. Defaults to 0.85.
        max_model_len (int, optional): The maximum length of the model context. Defaults to 2048.
        max_probs (int, optional): Controls how many logprobs or logits are stored for each token. Defaults to 1000.
        logits_processor (bool, optional): Whether to enable the Redis logging logits processor. Defaults to False.
        **kwargs: Additional keyword arguments passed to the LLM constructor.

    Returns:
        LLM: The initialized vLLM LLM object.
    """
    # NOTE(review): model_type is accepted but never used in this function —
    # confirm whether gguf/safetensors-specific handling was intended.

    if logits_processor:
        # Enable the Valkey logging logits processor. Append to any processors
        # the caller already passed instead of silently overwriting them.
        kwargs["logits_processors"] = [*kwargs.get("logits_processors", []), LogitsLoggingProcessor]
        ValkeyManager.start()
        print("LogitsLoggingProcessor enabled. Logits will be logged.")
    else:
        print("LogitsLoggingProcessor not enabled. Logits will not be logged.")

    # Initialize VLLM locally for performance (as done in power_sample.py main)
    llm = LLM(model=model_name,
              dtype=dtype,
              gpu_memory_utilization=gpu_memory_utilization,
              max_model_len=max_model_len,
              max_logprobs=max_probs, # Controls how many logprobs or logits are stored for each token
              logprobs_mode='raw_logits', # Required so downstream metrics see unprocessed logits
              **kwargs)

    return llm

create_vllm_engine_params() -> SamplingParams

Create the vLLM SamplingParams object from the common Sampling_Params.

Returns:

Name Type Description
SamplingParams SamplingParams

A new instance of vLLM SamplingParams.

Source code in pita/inference/vllm_backend.py
142
143
144
145
146
147
148
149
150
151
def create_vllm_engine_params() -> SamplingParams:
    """
    Build a fresh vLLM SamplingParams object from the common Sampling_Params.

    Returns:
        SamplingParams: A newly constructed vLLM SamplingParams instance.
    """
    # Default-construct the engine-level sampling parameters and hand them back.
    return SamplingParams()

sample(self, context: str | list[str], **kwargs: Any) -> Output

Generate text from the given context using the vLLM engine.

Parameters:

Name Type Description Default
context str | list[str]

The input context string to generate from.

required
**kwargs Any

Additional keyword arguments passed to the vLLM generate function.

{}

Returns:

Name Type Description
Output Output

An Output object containing:

- tokens: The generated token IDs.
- top_k_logits: The top_k logits (if logits_per_token is set).
- top_k_logprobs: The top_k logprobs (if logprobs is set).
- unprocessed_log_normalization_constant: The log(Normalization Constants - Unprocessed) for each token.
- temp_processed_log_normalization_constant: The log(Normalization Constants - Temperature Processed) for each token.
- entropy: The entropy for each token.

Source code in pita/inference/vllm_backend.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def sample(
        self,
        context: str | list[str],
        **kwargs: Any
    ) -> Output:
    """
    Generate text from the given context using the vLLM engine.

    Note: only the first prompt's first completion is read below
    (llm_output[0].outputs[0]), so a list-valued ``context`` yields
    results for its first element only.

    Args:
        context (str | list[str]): The input context string to generate from.
        **kwargs: Additional keyword arguments passed to the vLLM generate function.

    Returns:
        Output: An Output object containing:
            - tokens: The generated token IDs.
            - top_k_logits: The top_k logits (if logits_per_token is set).
            - top_k_logprobs: The top_k logprobs (if logprobs is set).
            - unprocessed_log_normalization_constant: The log(Normalization Constants - Unprocessed) for each token.
            - temp_processed_log_normalization_constant: The log(Normalization Constants - Temperature Processed) for each token.
            - entropy: The entropy for each token.
    """

    # Generate a new response from the LLM
    llm_output = self.llm.generate(
        context, 
        sampling_params=self.sampling_params.engine_params, 
        **kwargs
    )

    # Get the generated tokens (first prompt, first completion)
    tokens = llm_output[0].outputs[0].token_ids

    # Create a 2D array of NaNs to hold the logits.
    # Width is 1 + max(logprobs_per_token, logits_per_token) — presumably the
    # extra column leaves room for the sampled token's own entry; confirm.
    logits_expected = max(self.sampling_params.logprobs_per_token or 0, self.sampling_params.logits_per_token or 0)
    logits = np.full((len(tokens), 1 + logits_expected), np.nan, dtype=float)
    # Copy each token's returned logprob values into the row; positions past
    # what vLLM returned stay NaN.
    for token_idx in range(len(tokens)):
        for logit_idx, values in enumerate(llm_output[0].outputs[0].logprobs[token_idx].values()):
            logits[token_idx][logit_idx] = values.logprob

    # Get the Normalization Constants from Redis
    unprocessed_log_normalization_constant = []
    temp_processed_log_normalization_constant = []
    entropy = []
    if (hasattr(self.sampling_params.engine_params, 'extra_args') and 'req_id' in self.sampling_params.engine_params.extra_args):        
        # Set the req_id used to store the normalization constants in Redis
        req_id = self.sampling_params.engine_params.extra_args["req_id"]

        # Create a local Valkey client to retrieve the normalization constants
        # NOTE(review): this client is never explicitly closed — confirm whether
        # it should be wrapped in a context manager or closed after use.
        valkey_client = valkey.Valkey(host=VALKEY_HOST, port=VALKEY_PORT, db=0, decode_responses=True)

        # Retrieve the normalization constants from Valkey using the req_id
        normalization_terms = valkey_client.lrange(req_id, 0, -1)

        # Clean up the Valkey key after retrieval
        valkey_client.delete(req_id)

        # Parse the normalization terms (format: "norm_val,norm_temp_val,max_val")
        for term in normalization_terms:
            parts = term.split(',')
            unprocessed_log_normalization_constant.append(float(parts[0]))
            temp_processed_log_normalization_constant.append(float(parts[1]))
            entropy.append(float(parts[2]))

    # Find the logprobs for each token with the logits and temp_processed_log_normalization_constant
    # NOTE(review): if the Valkey branch above did not run (no 'req_id' in
    # extra_args), this list is empty and the resulting (0, 1) array cannot
    # broadcast against `logits` — confirm callers always set 'req_id'.
    logprobs = (logits / self.sampling_params.engine_params.temperature) - np.array(temp_processed_log_normalization_constant)[:, np.newaxis]    

    # Create the output object (slices keep only the configured top-k columns)
    output = Output(
        tokens=tokens,
        top_k_logits=logits[:, :self.sampling_params.logits_per_token],
        top_k_logprobs=logprobs[:, :self.sampling_params.logprobs_per_token],
        unprocessed_log_normalization_constant=unprocessed_log_normalization_constant,
        temp_processed_log_normalization_constant=temp_processed_log_normalization_constant,
        entropy=entropy
    )

    # Returns the output object
    return output