
Sampling Strategy Examples

pita provides advanced sampling strategies that aim to improve the output quality and reasoning performance of language models at inference time.

Power Sampling

Power Sampling uses Metropolis-Hastings MCMC to iteratively refine generated tokens. It operates at the token level, proposing and accepting/rejecting token replacements based on a decision metric.

from pita.inference.LLM_backend import AutoregressiveSampler

# Initialize sampler
sampler = AutoregressiveSampler(
    engine="vllm",
    model="Qwen/Qwen2.5-0.5B-Instruct",
    logits_processor=True  # Required for power sampling
)

# Enable power sampling
sampler.enable_power_sampling(
    block_size=250,          # Tokens generated per block
    MCMC_steps=3,            # Number of MCMC refinement steps
    token_metric="power_distribution"  # Metric for accept/reject decisions
)

# Use token sampling
prompt = "Solve the equation: 3x + 7 = 22"
output = sampler.token_sample(prompt)
generated_text = sampler.tokenizer.decode(output.output_ids)
print(generated_text)
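
To make the accept/reject step concrete, here is a minimal, self-contained sketch of Metropolis-Hastings on a toy three-token vocabulary. It is not pita's implementation: the function name `mh_power_sample`, the exponent `alpha`, and the choice of target (the base distribution raised to the power `alpha`) are assumptions made for illustration only.

```python
import random
from collections import Counter


def mh_power_sample(probs, alpha, steps, rng):
    """Independence Metropolis-Hastings targeting probs[i] ** alpha (normalized).

    Hypothetical helper for illustration; not part of pita.
    """
    tokens = list(range(len(probs)))
    current = rng.choices(tokens, weights=probs)[0]
    samples = []
    for _ in range(steps):
        # Propose a replacement token from the base distribution.
        proposal = rng.choices(tokens, weights=probs)[0]
        # Because the proposal is drawn from probs itself, the MH ratio
        # [p(x')^a * p(x)] / [p(x)^a * p(x')] simplifies to (p(x') / p(x))^(a - 1).
        accept_prob = min(1.0, (probs[proposal] / probs[current]) ** (alpha - 1))
        if rng.random() < accept_prob:
            current = proposal
        samples.append(current)
    return samples


probs = [0.5, 0.3, 0.2]   # toy base distribution over three tokens
alpha = 4.0               # assumed sharpening exponent

samples = mh_power_sample(probs, alpha, steps=20000, rng=random.Random(0))
counts = Counter(samples)
empirical = [counts[i] / len(samples) for i in range(len(probs))]

# Exact target for comparison: probs ** alpha, renormalized.
z = sum(p ** alpha for p in probs)
target = [p ** alpha / z for p in probs]

print("empirical:", [round(e, 3) for e in empirical])
print("target:   ", [round(t, 3) for t in target])
```

With `alpha` greater than 1 the chain spends more time on high-probability tokens than plain sampling from `probs` would, which is the sharpening effect a power distribution is meant to provide.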

Available Token Metrics for Power Sampling

  • "logprobs": Standard log probability scoring
  • "power_distribution": Temperature-scaled power distribution (recommended)
  • "entropy": Entropy-based metric
  • "likelihood_confidence": Combined probability and confidence

Sequential Monte Carlo (SMC)

SMC maintains multiple candidate sequences (particles) and selectively prunes/extends them based on quality metrics. It operates at the chain level.

from pita.inference.LLM_backend import AutoregressiveSampler

# Initialize sampler
sampler = AutoregressiveSampler(
    engine="vllm",
    model="Qwen/Qwen2.5-0.5B-Instruct",
    logits_processor=True  # Required for SMC metrics
)

# Enable SMC
sampler.enable_smc(
    num_particles=5,         # Number of candidate sequences to maintain
    tokens_per_step=50,      # Tokens generated per SMC step
    stop_on_eos=True,        # Stop when EOS token is generated
    token_metric="likelihood_confidence",  # Metric for particle scoring
    aggregation="last"       # How to aggregate token scores ("last", "minimum", "product")
)

# Use chain sampling
prompt = "Write a detailed explanation of photosynthesis."
output = sampler.chain_sample(prompt)
generated_text = sampler.tokenizer.decode(output.output_ids)
print(generated_text)
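
The prune/extend loop can be illustrated with a toy particle filter over lists of digits. This is a generic SMC sketch under stated assumptions, not pita's code: `smc_step`, `resample`, `extend_with_digit`, and `score_last_digit` are hypothetical names, and multinomial resampling is just one common choice.

```python
import random


def resample(particles, weights, rng):
    """Multinomial resampling: draw len(particles) survivors in proportion to weight."""
    return rng.choices(particles, weights=weights, k=len(particles))


def smc_step(particles, extend, score, rng):
    """Extend every particle, score the results, then resample by score."""
    extended = [extend(p, rng) for p in particles]
    weights = [score(p) for p in extended]
    return resample(extended, weights, rng)


def extend_with_digit(particle, rng):
    # Stand-in for "generate the next chunk of tokens".
    return particle + [rng.randrange(10)]


def score_last_digit(particle):
    # Stand-in quality metric: larger final digits score higher (always > 0).
    return 1 + particle[-1]


rng = random.Random(0)
num_particles = 5
num_steps = 4
particles = [[] for _ in range(num_particles)]
for _ in range(num_steps):
    particles = smc_step(particles, extend_with_digit, score_last_digit, rng)

for p in particles:
    print(p)
```

Low-scoring particles tend to be dropped at each step and high-scoring ones duplicated, so the population size stays fixed while the candidates drift toward higher scores.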

SMC Aggregation Methods

  • "last": Use only the last token's metric for scoring
  • "minimum": Use the minimum metric across all tokens
  • "product": Multiply metrics across all tokens
  • "model_aggregate": Custom model-based aggregation (WIP)
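
As a rough illustration of how the first three methods differ, the sketch below reduces a list of per-token scores to a single particle score. The function `aggregate_scores` is hypothetical and only mirrors the descriptions above; "model_aggregate" is omitted because its behavior is not specified.

```python
import math


def aggregate_scores(token_scores, method="last"):
    """Reduce per-token scores to one particle score (illustrative only)."""
    if not token_scores:
        raise ValueError("token_scores must not be empty")
    if method == "last":
        return token_scores[-1]
    if method == "minimum":
        return min(token_scores)
    if method == "product":
        return math.prod(token_scores)
    raise ValueError(f"unknown aggregation method: {method!r}")


scores = [0.9, 0.5, 0.8]  # toy per-token scores for one particle
for method in ("last", "minimum", "product"):
    print(method, aggregate_scores(scores, method))
```

"last" reacts only to the most recent token, "minimum" penalizes a particle for its single weakest token, and "product" shrinks as the sequence grows, so it tends to favor short, uniformly confident particles.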

Combining Strategies (Advanced)

You can combine chain-level and token-level strategies on the same sampler. The example below enables only the token-level strategy (Power Sampling).

from pita.inference.LLM_backend import AutoregressiveSampler

# Initialize sampler
sampler = AutoregressiveSampler(
    engine="vllm",
    model="Qwen/Qwen2.5-0.5B-Instruct",
    logits_processor=True
)

# Enable token-level (Power Sampling)
sampler.enable_power_sampling(
    block_size=200,
    MCMC_steps=2,
    token_metric="power_distribution"
)

# Use power sampling
prompt = "Solve the equation: 3x + 7 = 22"
output_power = sampler.token_sample(prompt)

Using Sampling Strategies via API

You can trigger sampling strategies through the API server by starting the system prompt with a control prefix:

import openai

client = openai.OpenAI(
    base_url="http://localhost:8001/v1",
    api_key="none"
)

# Power Sampling via API: ITS PS_<max_tokens>_<block_size>_<MCMC_steps>
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[
        {"role": "system", "content": "ITS PS_1000_250_3 You are a helpful assistant."},
        {"role": "user", "content": "Solve: 5x - 3 = 17"}
    ]
)

print(response.choices[0].message.content)
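
If you build these system prompts in several places, a small helper can keep the prefix format consistent. The sketch below assumes only the `ITS PS_<max_tokens>_<block_size>_<MCMC_steps>` layout shown in the comment above; the function name `power_sampling_system_prompt` is hypothetical and not part of pita.

```python
def power_sampling_system_prompt(max_tokens, block_size, mcmc_steps, instructions=""):
    """Build a system prompt with the Power Sampling control prefix.

    Hypothetical helper; follows the ITS PS_<max_tokens>_<block_size>_<MCMC_steps> layout.
    """
    prefix = f"ITS PS_{int(max_tokens)}_{int(block_size)}_{int(mcmc_steps)}"
    # Append the regular instructions, if any, after a single space.
    return f"{prefix} {instructions}".strip()


system_prompt = power_sampling_system_prompt(
    max_tokens=1000,
    block_size=250,
    mcmc_steps=3,
    instructions="You are a helpful assistant.",
)
print(system_prompt)
```

The resulting string can be passed as the "system" message content in the request.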

Disabling Sampling Strategies

To revert to standard sampling:

# Disable token sampling (if enabled)
if hasattr(sampler, 'token_sample_name'):
    sampler.token_sample_name = None
    sampler.token_sample_fn = None

# Disable chain sampling (if enabled)
if hasattr(sampler, 'chain_sample_name'):
    sampler.chain_sample_name = None
    sampler.chain_sample_fn = None

# Now sampler.sample() will use standard autoregressive sampling
output = sampler.sample(prompt)