Skip to content

Math Grader

Answer checker API that uses sympy to simplify expressions and check for equality.

Call grade_answer(given_answer: str, ground_truth: str).

are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str) -> bool

Check if two expressions are mathematically equivalent using sympy.

Subtracts the two expressions and simplifies the result. If the simplified difference equals zero, the expressions are equivalent.

Parameters:

Name Type Description Default
ground_truth_normalized str

The normalized ground truth expression.

required
given_normalized str

The normalized given expression to check.

required

Returns:

Type Description
bool

True if the expressions are mathematically equivalent, False otherwise.

Source code in pita/utils/grading_utils/math/math_grader.py
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str) -> bool:
    """
    Check if two expressions are mathematically equivalent using sympy.

    Builds the difference ``(ground_truth) - (given)``; if
    ``should_allow_eval`` accepts that expression, it is parsed and
    simplified. The expressions are treated as equivalent only when the
    simplified difference compares equal to zero.

    Args:
        ground_truth_normalized: The normalized ground truth expression.
        given_normalized: The normalized given expression to check.

    Returns:
        True if the simplified difference equals zero. False otherwise,
        including when the expression is rejected by ``should_allow_eval``
        or when parsing/simplification raises an ordinary exception.
    """
    try:
        expr = f"({ground_truth_normalized})-({given_normalized})"
        if not should_allow_eval(expr):
            return False
        sympy_diff = _sympy_parse(expr)
        return bool(sympy.simplify(sympy_diff) == 0)
    except Exception:
        # Best-effort comparison: any parse or simplification failure means
        # "could not prove equal". Catching Exception (rather than a bare
        # except) lets KeyboardInterrupt and SystemExit propagate instead of
        # being silently converted into a wrong-answer verdict.
        return False

count_unknown_letters_in_expr(expr: str) -> int

Count the number of unknown letters in an expression.

Removes known mathematical function names (sqrt, frac) and counts the distinct alphabetic characters that remain.

Parameters:

Name Type Description Default
expr str

A mathematical expression string.

required

Returns:

Type Description
int

The number of distinct unknown alphabetic characters in the expression.

Source code in pita/utils/grading_utils/math/math_grader.py
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
def count_unknown_letters_in_expr(expr: str) -> int:
    """
    Return how many distinct letters remain once known names are removed.

    Every occurrence of "sqrt" is deleted first, then every occurrence of
    "frac"; the distinct alphabetic characters left in the string are
    counted.

    Args:
        expr: A mathematical expression string.

    Returns:
        The number of distinct alphabetic characters left in the expression.
    """
    remainder = expr
    # Order matters: removing one name can expose the other, so keep
    # "sqrt" before "frac".
    for known_name in ("sqrt", "frac"):
        remainder = remainder.replace(known_name, "")
    return len({ch for ch in remainder if ch.isalpha()})

grade_answer(given_answer: str, ground_truth: str) -> bool

Grade a given answer against the ground truth.

The answer will be considered correct if: (a) it normalizes to the same string as the ground truth answer OR (b) sympy can simplify the difference between the expressions to 0

Special handling for: - Tuples and intervals (must match structure and order) - Fractions (must be in reduced form) - Integers (must be exact matches, no decimal equivalents)

Parameters:

Name Type Description Default
given_answer str

The answer to grade.

required
ground_truth str

The correct answer to compare against.

required

Returns:

Type Description
bool

True if the answer is correct, False otherwise.

Source code in pita/utils/grading_utils/math/math_grader.py
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
def grade_answer(given_answer: str, ground_truth: str) -> bool:
    """
    Grade a given answer against the ground truth.

    The answer will be considered correct if:
    (a) it normalizes to the same string as the ground truth answer
    OR
    (b) sympy can simplify the difference between the expressions to 0

    Special handling for:
    - Tuples and intervals (must match structure and order)
    - Fractions (must be in reduced form)
    - Integers (must be exact matches, no decimal equivalents)

    Args:
        given_answer: The answer to grade.
        ground_truth: The correct answer to compare against.

    Returns:
        True if the answer is correct, False otherwise. A ``None`` or
        empty (after normalization) given answer is never correct.
    """
    if given_answer is None:
        return False

    ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
    given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)

    # be at least as lenient as mathd
    if ground_truth_normalized_mathd == given_answer_normalized_mathd:
        return True

    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)

    if ground_truth_normalized is None:
        return False

    if ground_truth_normalized == given_normalized:
        return True

    # Falsy check covers the empty string and also a None result, which
    # would otherwise raise TypeError in the indexing / len() calls below.
    if not given_normalized:
        return False

    ground_truth_elems = split_tuple(ground_truth_normalized)
    given_elems = split_tuple(given_normalized)

    # For tuples/intervals the opening and closing brackets must agree
    # (e.g. "(1,2]" is not the same interval as "[1,2)").
    if len(ground_truth_elems) > 1 and (
        ground_truth_normalized[0] != given_normalized[0]
        or ground_truth_normalized[-1] != given_normalized[-1]
    ):
        return False

    if len(ground_truth_elems) != len(given_elems):
        return False

    # Default to False so that the (degenerate) case of two empty element
    # lists yields a verdict instead of an UnboundLocalError.
    is_correct = False
    for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
        if _is_frac(ground_truth_elem) and _is_frac(given_elem):
            # if fractions aren't reduced, then shouldn't be marked as correct
            # so, we don't want to allow sympy.simplify in this case
            is_correct = ground_truth_elem == given_elem
        elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
            # exactly one side is an integer: require a strict match
            # (no sympy.simplify), so e.g. a decimal equivalent is rejected
            is_correct = False
        else:
            is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
        if not is_correct:
            break

    return is_correct

should_allow_eval(expr: str) -> bool

Determine if an expression is safe to evaluate with sympy.

Checks for potentially problematic patterns that might cause sympy to hang or fail, including too many unknown variables and known bad patterns.

Parameters:

Name Type Description Default
expr str

A mathematical expression string.

required

Returns:

Type Description
bool

True if the expression is safe to evaluate, False otherwise.

Source code in pita/utils/grading_utils/math/math_grader.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def should_allow_eval(expr: str) -> bool:
    """
    Decide whether an expression may be handed to sympy.

    The expression is rejected when it contains more than two distinct
    unknown letters, when it contains any entry of ``BAD_SUBSTRINGS``, or
    when any pattern in ``BAD_REGEXES`` matches somewhere in it.

    Args:
        expr: A mathematical expression string.

    Returns:
        True if none of the rejection rules apply, False otherwise.
    """
    # we don't want to try parsing unknown text or functions of more than two variables
    if count_unknown_letters_in_expr(expr) > 2:
        return False

    if any(fragment in expr for fragment in BAD_SUBSTRINGS):
        return False

    matches_bad_pattern = any(
        re.search(pattern, expr) is not None for pattern in BAD_REGEXES
    )
    return not matches_bad_pattern

split_tuple(expr: str) -> List[str]

Split the elements in a tuple or interval.

Handles well-formatted commas in large numbers while splitting tuple elements. Recognizes tuples by their bracketing characters.

Parameters:

Name Type Description Default
expr str

A string representing a tuple, interval, or single value.

required

Returns:

Type Description
List[str]

A list of string elements. If the expression is a tuple/interval, returns the individual elements; otherwise returns a list containing the original expression.

Source code in pita/utils/grading_utils/math/math_grader.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
def split_tuple(expr: str) -> List[str]:
    """
    Break a tuple or interval string into its comma-separated elements.

    Properly formatted thousands-separator commas are stripped first. The
    result is treated as a tuple/interval only when it is longer than two
    characters, starts and ends with a character from ``TUPLE_CHARS``, and
    has no ``TUPLE_CHARS`` character in between.

    Args:
        expr: A string representing a tuple, interval, or single value.

    Returns:
        An empty list for an empty expression; the stripped individual
        elements for a tuple/interval; otherwise a one-element list holding
        the (comma-stripped) expression.
    """
    cleaned = _strip_properly_formatted_commas(expr)
    if len(cleaned) == 0:
        return []

    # Too short to be "<open> ... <close>" with any content.
    if len(cleaned) <= 2:
        return [cleaned]

    if cleaned[0] not in TUPLE_CHARS or cleaned[-1] not in TUPLE_CHARS:
        return [cleaned]

    interior = cleaned[1:-1]
    # Nested brackets inside mean this is not a flat tuple; leave it whole.
    if any(ch in interior for ch in TUPLE_CHARS):
        return [cleaned]

    return [piece.strip() for piece in interior.split(",")]