Skip to content

Test Time Coding

decode(system_string: str) -> tuple[Optional[Sequential_Monte_Carlo], Optional[Power_Sampling]]

Decode test-time scaling parameters from a system prompt string.

Parameters:

Name Type Description Default
system_string str

The encoded string containing test-time scaling parameters. Format: "ITS_<chain>_<chain_params...>_<token>_<token_params...>"

required

Returns:

Type Description
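Optional[Sequential_Monte_Carlo]

The chain sampling configuration (first element of the returned (chain_sampling, token_sampling) tuple), or None if chain sampling was not specified.

Optional[Power_Sampling]

The token sampling configuration (second element of the returned tuple), or None if token sampling was not specified.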

Raises:

Type Description
ValueError

If the format is invalid or parameter values are non-numeric.

Source code in pita/api/test_time_coding.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def decode(system_string: str) -> tuple[Optional[Sequential_Monte_Carlo], Optional[Power_Sampling]]:
    """
    Decode test-time scaling parameters from a system prompt string.

    Only the first whitespace-delimited token of ``system_string`` is parsed,
    so the ITS header may be followed by a space, tab, or newline and then
    arbitrary prompt text.

    Args:
        system_string: The encoded string containing test-time scaling parameters.
            Format: "ITS_<chain>_<chain_params...>_<token>_<token_params...>"
            where <chain> is "SMC" followed by num_particles, tokens_per_step
            and stop_on_eos, or "NONE"; and <token> is "PS" followed by
            block_size and MCMC_steps, or "NONE".

    Returns:
        A tuple of (chain_sampling, token_sampling) where each element is either
        a parameter object or None if that sampling technique was not specified.

    Raises:
        ValueError: If the format is invalid (missing "ITS" prefix, unknown
            method, missing parameters, or unexpected trailing segments) or
            parameter values are non-numeric.
    """
    # Take only the first whitespace-delimited token. Splitting on any
    # whitespace (not just " ") keeps a following newline/tab from being glued
    # onto the last parameter. Empty / all-whitespace input yields "".
    tokens = system_string.split(maxsplit=1)
    header = tokens[0] if tokens else ""
    parts = header.split("_")

    if len(parts) < 2 or parts[0] != "ITS":
        raise ValueError("Invalid system string format. Must start with 'ITS'.")

    chain_sampling = None
    token_sampling = None

    i = 1  # Start after "ITS"

    # Parse chain sampling method
    if parts[i] == "NONE":
        i += 1
    elif parts[i] == "SMC":
        # SMC requires 3 params: num_particles, tokens_per_step, stop_on_eos
        if i + 3 >= len(parts):
            raise ValueError("SMC requires 3 parameters: num_particles, tokens_per_step, stop_on_eos")
        # isdecimal() accepts exactly the digit characters int() can parse
        # (isdigit() also accepts e.g. superscripts, which int() rejects).
        if not all(parts[i + j].isdecimal() for j in range(1, 4)):
            raise ValueError("Invalid SMC parameters: expected 3 integers after 'SMC'")
        chain_sampling = Sequential_Monte_Carlo(
            num_particles=int(parts[i + 1]),
            tokens_per_step=int(parts[i + 2]),
            stop_on_eos=bool(int(parts[i + 3]))
        )
        i += 4
    else:
        raise ValueError(f"Unknown chain sampling method: '{parts[i]}'. Expected 'SMC' or 'NONE'.")

    # Parse token sampling method
    if i >= len(parts):
        raise ValueError("Missing token sampling specification. Expected 'PS' or 'NONE'.")

    if parts[i] == "NONE":
        i += 1  # token_sampling stays None
    elif parts[i] == "PS":
        # PS requires 2 params: block_size, MCMC_steps
        if i + 2 >= len(parts):
            raise ValueError("PS requires 2 parameters: block_size, MCMC_steps")
        if not all(parts[i + j].isdecimal() for j in range(1, 3)):
            raise ValueError("Invalid PS parameters: expected 2 integers after 'PS'")
        token_sampling = Power_Sampling(
            block_size=int(parts[i + 1]),
            MCMC_steps=int(parts[i + 2])
        )
        i += 3
    else:
        raise ValueError(f"Unknown token sampling method: '{parts[i]}'. Expected 'PS' or 'NONE'.")

    # Every segment must be consumed; leftovers mean a malformed or
    # mis-encoded header, which should not be silently ignored.
    if i != len(parts):
        raise ValueError(
            f"Unexpected trailing segments in system string: {'_'.join(parts[i:])!r}"
        )

    return chain_sampling, token_sampling

encode(chain_sampling: Optional[Sequential_Monte_Carlo] = None, token_sampling: Optional[Power_Sampling] = None) -> str

Encode test-time scaling parameters into a string for embedding in system prompts.

Parameters:

Name Type Description Default
chain_sampling Optional[Sequential_Monte_Carlo]

Chain sampling configuration (SMC).

None
token_sampling Optional[Power_Sampling]

Token sampling configuration (Power Sampling).

None

Returns:

Type Description
str

A formatted string: "ITS_<chain>_<chain_params>_<token>_<token_params>", or an empty string if no parameters are provided.

Source code in pita/api/test_time_coding.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def encode(
    chain_sampling: Optional[Sequential_Monte_Carlo] = None,
    token_sampling: Optional[Power_Sampling] = None
) -> str:
    """
    Encode test-time scaling parameters into a string for embedding in system prompts.

    Args:
        chain_sampling: Chain sampling configuration (SMC), or None to encode
            the chain section as "NONE".
        token_sampling: Token sampling configuration (Power Sampling), or None
            to encode the token section as "NONE".

    Returns:
        A formatted string: "ITS_<chain>_<chain_params>_<token>_<token_params>"
        where an SMC section carries num_particles, tokens_per_step and
        int(stop_on_eos), and a PS section carries block_size and MCMC_steps.
        Returns empty string if no parameters are provided.

    Raises:
        ValueError: If an argument is neither None nor the expected
            configuration type.
    """
    # Nothing requested at all: emit no header.
    if chain_sampling is None and token_sampling is None:
        return ""

    # Chain section is resolved (and validated) before the token section so
    # errors surface in the same order as the fields appear in the output.
    if chain_sampling is None:
        chain_fields = ("NONE",)
    elif not isinstance(chain_sampling, Sequential_Monte_Carlo):
        raise ValueError(f"Unknown chain sampling type: {type(chain_sampling)}")
    else:
        chain_fields = (
            "SMC",
            str(chain_sampling.num_particles),
            str(chain_sampling.tokens_per_step),
            str(int(chain_sampling.stop_on_eos)),  # flag written as an int (0/1 for a bool)
        )

    if token_sampling is None:
        token_fields = ("NONE",)
    elif not isinstance(token_sampling, Power_Sampling):
        raise ValueError(f"Unknown token sampling type: {type(token_sampling)}")
    else:
        token_fields = (
            "PS",
            str(token_sampling.block_size),
            str(token_sampling.MCMC_steps),
        )

    return "_".join(("ITS", *chain_fields, *token_fields))