Unit Testing with pytest
Core Philosophy
Unit testing validates that individual pieces of code work correctly in isolation. The key insight: code that is easy to test is usually well-designed code. Writing tests forces you to think about function interfaces, inputs, outputs, and edge cases.
Why Test?
- Catch bugs early - Before they compound into harder problems
- Enable safe refactoring - Change code confidently knowing tests verify behavior
- Document expected behavior - Tests serve as executable specifications
- Improve code design - Testable code tends to be modular and well-structured
Pure Functions: The Foundation of Testable Code
Pure functions are the easiest to test because they:
- Always return the same output for the same input
- Have no side effects (don't modify external state)
- Don't depend on external state (files, databases, time, random)
# ✅ Pure function - easy to test
def calculate_tax(income, rate):
    return income * rate

# ❌ Impure function - hard to test
def calculate_tax_and_log(income, rate):
    result = income * rate
    with open("log.txt", "a") as f:  # Side effect
        f.write(f"{result}\n")
    return result
Design principle: Extract the pure logic from impure operations. Test the pure part thoroughly.
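One way to apply this principle to calculate_tax_and_log is sketched here. The helper name append_to_log and the log_path parameter are illustrative assumptions, not part of the original function.

```python
def calculate_tax(income, rate):
    """Pure: the result depends only on the arguments."""
    return income * rate


def append_to_log(value, log_path):
    """Impure: the only function that touches the file system (hypothetical helper)."""
    with open(log_path, "a") as f:
        f.write(f"{value}\n")


def calculate_tax_and_log(income, rate, log_path="log.txt"):
    """Thin wrapper that combines the pure calculation with the side effect."""
    result = calculate_tax(income, rate)
    append_to_log(result, log_path)
    return result


def test_calculate_tax():
    # No files and no cleanup: the pure part is tested with plain values.
    assert calculate_tax(1000, 0.5) == 500.0
```

The wrapper stays so small that most of the testing effort can go into the pure function.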
pytest Mechanics
File and Function Discovery
pytest automatically discovers tests using these conventions:
| Pattern | What pytest finds |
|---|---|
| test_*.py or *_test.py | Test files |
| test_* functions | Test functions |
| Test* classes | Test classes |
project/
├── src/
│   └── calculations.py
└── tests/
    ├── test_calculations.py  # ✅ discovered
    └── helper_utils.py       # ❌ not discovered (no test_ prefix)
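As a rough illustration of the naming rules inside a discovered file, the sketch below shows what pytest would and would not collect. The add function is a stand-in defined inline (an assumption) so the example is self-contained; in the layout above it would be imported from src/calculations.py.

```python
# tests/test_calculations.py (the file name matches test_*.py, so it is collected)


def add(a, b):
    """Stand-in for a function that would live in src/calculations.py."""
    return a + b


def test_add_positive_numbers():  # collected: name starts with test_
    assert add(2, 3) == 5


class TestAdd:  # collected: class name starts with Test and has no __init__
    def test_add_negative_numbers(self):  # collected: method starts with test_
        assert add(-2, -3) == -5


def check_add_zero():  # not collected: name does not start with test_
    assert add(0, 0) == 0
```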
Running Tests
# Run all tests
pixi run pytest

# Verbose output (shows each test name)
pixi run pytest -v

# Run specific file
pixi run pytest tests/test_calculations.py

# Run specific test function
pixi run pytest tests/test_calculations.py::test_add_positive_numbers

# Stop on first failure
pixi run pytest -x

# Run with debugger on failure
pixi run pytest --pdb
Writing Tests
Basic Assertions
pytest uses plain assert statements. On failure, pytest shows a detailed comparison of the values involved:
import pytest

def test_add_numbers():
    result = add(2, 3)
    assert result == 5

def test_list_contains_item():
    items = get_items()
    assert "apple" in items
    assert len(items) == 3

def test_approximate_equality():
    # For floats, use pytest.approx
    assert calculate_pi() == pytest.approx(3.14159, rel=1e-5)
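The reason for pytest.approx is that binary floating point cannot represent most decimal fractions exactly. A small sketch of the effect, using only literal values:

```python
import pytest


def test_exact_float_comparison_is_fragile():
    # 0.1 + 0.2 evaluates to 0.30000000000000004, so == 0.3 would fail.
    assert 0.1 + 0.2 != 0.3


def test_float_comparison_with_approx():
    # approx compares within a tolerance (relative 1e-6 by default).
    assert 0.1 + 0.2 == pytest.approx(0.3)
```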
Testing DataFrames (pandas)
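import pandas as pd
import pandas.testing as tm

def test_clean_data():
    raw = pd.DataFrame({"value": ["-99", "10", "-77"]})
    result = clean_missing_codes(raw)
    expected = pd.DataFrame({"value": [pd.NA, 10, pd.NA]})
    # Compares values, dtypes, index, and column labels
    tm.assert_frame_equal(result, expected)

def test_column_types():
    df = process_data(input_df)
    # isinstance holds for any set of categories; in recent pandas versions,
    # == pd.CategoricalDtype() only matches a dtype without categories
    assert isinstance(df["category"].dtype, pd.CategoricalDtype)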
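Because assert_frame_equal also compares dtypes, two frames holding the same numbers can still be reported as different. The sketch below uses two small literal frames (made up for illustration) to show the strict default and the check_dtype=False escape hatch.

```python
import pandas as pd
import pandas.testing as tm
import pytest


def test_dtype_mismatch_is_reported():
    as_int = pd.DataFrame({"value": [1, 2, 3]})  # integer column
    as_float = pd.DataFrame({"value": [1.0, 2.0, 3.0]})  # float column

    # Same numbers, different dtypes: the strict comparison fails.
    with pytest.raises(AssertionError):
        tm.assert_frame_equal(as_int, as_float)

    # check_dtype=False compares the values only.
    tm.assert_frame_equal(as_int, as_float, check_dtype=False)
```

Keeping the strict default is usually the safer choice, since an unexpected dtype change is often a real bug in a cleaning function.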
Testing NumPy Arrays
import numpy as np
import numpy.testing as npt

def test_array_calculation():
    result = normalize(np.array([1, 2, 3]))
    expected = np.array([0.0, 0.5, 1.0])
    npt.assert_array_almost_equal(result, expected)
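To make the example above concrete, here is a minimal sketch that assumes normalize performs min-max scaling; that behaviour is a guess based on the expected values, not something the text specifies. It uses assert_allclose, which the NumPy documentation recommends over assert_array_almost_equal for floating point comparisons.

```python
import numpy as np
import numpy.testing as npt


def normalize(values):
    """Hypothetical min-max scaling to the range [0, 1]."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        raise ValueError("cannot normalize a constant array")
    return (values - values.min()) / span


def test_normalize_scales_to_unit_interval():
    result = normalize(np.array([1, 2, 3]))
    npt.assert_allclose(result, np.array([0.0, 0.5, 1.0]))
```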
Testing Exceptions
Use pytest.raises to verify that code raises the expected errors:
import pytest

def test_divide_by_zero_raises():
    with pytest.raises(ZeroDivisionError):
        divide(10, 0)

def test_invalid_input_message():
    # Also verify the error message
    with pytest.raises(ValueError, match="must be positive"):
        calculate_sqrt(-5)

def test_type_error_for_wrong_input():
    with pytest.raises(TypeError) as exc_info:
        process_data("not a list")
    # Inspect the exception after the with block has exited
    assert "must be a list" in str(exc_info.value)
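One detail worth knowing about match: the pattern is treated as a regular expression and searched for in the error message, so characters such as parentheses need escaping when they should be taken literally. The sketch below assumes a hypothetical calculate_sqrt whose message contains parentheses.

```python
import math
import re

import pytest


def calculate_sqrt(x):
    """Hypothetical implementation with parentheses in its error message."""
    if x < 0:
        raise ValueError(f"x must be positive or zero (got {x})")
    return math.sqrt(x)


def test_message_with_special_characters():
    # match is applied with re.search; re.escape makes "(" and ")" literal
    # characters instead of a regex group.
    with pytest.raises(ValueError, match=re.escape("(got -5)")):
        calculate_sqrt(-5)
```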
What to Test
Test Categories
- Normal cases - Expected inputs produce expected outputs
- Edge cases - Boundary conditions, empty inputs, single elements
- Error cases - Invalid inputs raise appropriate exceptions
Example: Testing a Data Cleaning Function
import pandas as pd

def clean_agreement_scale(sr):
    """Convert survey responses to ordered categorical."""
    # Missing-value codes become NA before the conversion
    sr = sr.replace({"-77": pd.NA, "-99": pd.NA})
    categories = ["strongly disagree", "disagree", "neutral", "agree", "strongly agree"]
    return sr.astype(pd.CategoricalDtype(categories=categories, ordered=True))
Tests to write:
def test_clean_agreement_normal_values():
    """Normal case: valid responses are preserved."""
    sr = pd.Series(["agree", "disagree"])
    result = clean_agreement_scale(sr)
    assert list(result) == ["agree", "disagree"]

def test_clean_agreement_missing_codes():
    """Edge case: missing codes become NA."""
    sr = pd.Series(["-77", "-99", "agree"])
    result = clean_agreement_scale(sr)
    assert pd.isna(result.iloc[0])
    assert pd.isna(result.iloc[1])
    assert result.iloc[2] == "agree"

def test_clean_agreement_returns_ordered_categorical():
    """Output type: categorical with correct ordering."""
    sr = pd.Series(["agree"])
    result = clean_agreement_scale(sr)
    assert result.dtype.ordered
    assert result.dtype.categories.tolist() == [
        "strongly disagree", "disagree", "neutral", "agree", "strongly agree"
    ]

def test_clean_agreement_empty_series():
    """Edge case: empty input returns empty categorical."""
    sr = pd.Series([], dtype=str)
    result = clean_agreement_scale(sr)
    assert len(result) == 0
Reusing Test Code
Fixtures: Shared Setup
Fixtures provide reusable test data and setup. They are defined with the @pytest.fixture decorator:
import pandas as pd
import pytest

@pytest.fixture
def sample_survey_data():
    """Survey data with various response patterns."""
    return pd.DataFrame({
        "q1": ["agree", "-77", "disagree"],
        "q2": ["-99", "strongly agree", "neutral"]
    })

@pytest.fixture
def empty_dataframe():
    return pd.DataFrame()

# Tests automatically receive fixtures by parameter name
def test_clean_survey(sample_survey_data):
    result = clean_survey(sample_survey_data)
    assert result.shape == sample_survey_data.shape

def test_handle_empty(empty_dataframe):
    result = process(empty_dataframe)
    assert len(result) == 0
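With the default (function) scope, pytest calls the fixture function again for every test that requests it, so changes made in one test do not leak into another. A small sketch with a hypothetical responses fixture, using a plain list to keep it dependency-free:

```python
import pytest


@pytest.fixture
def responses():
    """Hypothetical fixture: builds a new list for every test that requests it."""
    return ["agree", "-77", "disagree"]


def test_mutating_the_fixture_value(responses):
    responses.append("neutral")
    assert len(responses) == 4


def test_fixture_value_is_fresh(responses):
    # Passes regardless of test order, because the fixture function ran again.
    assert len(responses) == 3
```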
Parametrization: Multiple Test Cases
Test the same logic with different inputs using @pytest.mark.parametrize:
import pytest

@pytest.mark.parametrize("input_val,expected", [
    (0, 1),  # 0! = 1 by definition
    (1, 1),
    (5, 120),  # 5! = 120
    (10, 3628800),  # 10! = 3628800
])
def test_factorial(input_val, expected):
    assert factorial(input_val) == expected

@pytest.mark.parametrize("invalid_input", [-1, -5, -100])
def test_factorial_negative_raises(invalid_input):
    with pytest.raises(ValueError, match="must be non-negative"):
        factorial(invalid_input)
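For readers who want to run the parametrized tests, the sketch below supplies a hypothetical factorial implementation (an assumption, since the text does not show one) and uses pytest.param with id to give each case a readable name.

```python
import pytest


def factorial(n):
    """Hypothetical implementation, included so the tests can run."""
    if n < 0:
        raise ValueError("n must be non-negative")
    result = 1
    for k in range(2, n + 1):
        result *= k
    return result


@pytest.mark.parametrize(
    "input_val,expected",
    [
        pytest.param(0, 1, id="zero"),
        pytest.param(1, 1, id="one"),
        pytest.param(5, 120, id="five"),
    ],
)
def test_factorial(input_val, expected):
    assert factorial(input_val) == expected
```

With -v, the cases are then listed as test_factorial[zero], test_factorial[one], and test_factorial[five], which makes a failing case easier to spot.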
Combining Fixtures and Parametrization
@pytest.fixture
def calculator():
    return Calculator()

@pytest.mark.parametrize("a,b,expected", [
    (1, 2, 3),
    (0, 0, 0),
    (-1, 1, 0),
])
def test_calculator_add(calculator, a, b, expected):
    assert calculator.add(a, b) == expected
Test Organization
File Structure
project/
├── src/
│   ├── data_cleaning.py
│   └── analysis.py
└── tests/
    ├── conftest.py  # Shared fixtures
    ├── test_data_cleaning.py
    └── test_analysis.py
conftest.py: Shared Fixtures
Place fixtures used across multiple test files in conftest.py:
# tests/conftest.py
import pytest
import pandas as pd

@pytest.fixture
def sample_data():
    # Relative path: resolved against the directory pytest is run from
    return pd.read_csv("tests/data/sample.csv")

@pytest.fixture
def database_connection():
    conn = create_connection()
    yield conn  # Test runs here
    conn.close()  # Cleanup after test
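A runnable variant of the yield fixture is sketched below. It assumes create_connection opens an in-memory SQLite database, which is only a stand-in for whatever connection the real project uses; the table and test names are likewise made up.

```python
import sqlite3

import pytest


def create_connection():
    """Hypothetical stand-in: an in-memory SQLite database."""
    return sqlite3.connect(":memory:")


@pytest.fixture
def database_connection():
    conn = create_connection()  # setup
    yield conn  # the test body runs while the fixture is suspended here
    conn.close()  # teardown, executed whether the test passed or failed


def test_insert_and_count(database_connection):
    database_connection.execute("CREATE TABLE t (x INTEGER)")
    database_connection.execute("INSERT INTO t VALUES (1), (2)")
    count = database_connection.execute("SELECT COUNT(*) FROM t").fetchone()[0]
    assert count == 2
```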
Debugging Test Failures
Understanding pytest Output
FAILED tests/test_calc.py::test_add - AssertionError: assert 4 == 5
pytest shows:
- Which test failed
- The assertion that failed
- Actual vs expected values
Using the Debugger
# Drop into debugger on failure
pytest --pdb

# Drop into debugger at test start
pytest --pdb --trace
Inspecting Failures
def test_data_processing():
    result = complex_processing(data)
    # Add debug prints (captured output is shown on failure)
    print(f"Result shape: {result.shape}")
    print(f"Result columns: {result.columns.tolist()}")
    assert result.shape == (100, 5)
Error Handling Design
Good error handling makes code easier to test and debug.
Pattern: Fail Early with Clear Messages
Validate inputs at function entry, before any processing. This makes debugging easier and prevents partial work.
- Identify risky inputs - Those from users or not validated elsewhere
- List failure modes - Start easy (wrong type) → specific (wrong structure)
- Write _fail_if_... functions - One per condition
- Call validators early - Before any processing
- Test error messages - Ensure they're helpful
def create_markdown_table(data):
    """Create markdown table from list of dicts or dict of lists."""
    _fail_if_neither_dict_nor_list(data)
    if isinstance(data, dict):
        _fail_if_dict_of_wrong_types(data)
        _fail_if_dict_of_lists_with_different_lengths(data)
        data = convert_dol_to_lod(data)
    else:
        _fail_if_list_of_wrong_types(data)
        _fail_if_list_of_dicts_with_different_keys(data)
    return _format_table(data)

def _fail_if_neither_dict_nor_list(data):
    if not isinstance(data, (list, dict)):
        raise TypeError(
            f"data must be a list of dicts or dict of lists. Got {type(data)}"
        )

def _fail_if_list_of_wrong_types(data):
    invalid_rows = [i for i, row in enumerate(data) if not isinstance(row, dict)]
    if invalid_rows:
        report = "The following rows are not dictionaries:\n"
        for i in invalid_rows:
            report += f"  Row {i} has type {type(data[i])}\n"
        raise TypeError(report)
Testing Error Handling
def test_create_table_rejects_string():
    with pytest.raises(TypeError, match="must be a list of dicts"):
        create_markdown_table("not valid")

def test_create_table_reports_invalid_rows():
    with pytest.raises(TypeError) as exc_info:
        create_markdown_table([{"a": 1}, "invalid", {"a": 2}])
    assert "Row 1" in str(exc_info.value)
Best Practices Summary
- Test pure functions - Extract pure logic, test it thoroughly
- One concept per test - Each test verifies one behavior
- Descriptive names - test_factorial_negative_raises, not test_1
- Test edge cases - Empty inputs, boundaries, single elements
- Test error conditions - Verify exceptions and messages
- Use fixtures - Share setup code, keep tests DRY
- Use parametrization - Test multiple inputs without duplication
- Assert specific expectations - Not just "no error", but correct values
- Keep tests fast - Slow tests get run less often
- Tests document behavior - Write tests that explain what code does
Quick Reference
# Basic test
def test_something():
    assert function() == expected

# Testing exceptions
with pytest.raises(ValueError, match="pattern"):
    bad_function()

# Fixture
@pytest.fixture
def data():
    return setup_data()

# Parametrization (avoid naming a parameter "input": it shadows the builtin)
@pytest.mark.parametrize("value,expected", [(1, 2), (3, 4)])
def test_func(value, expected):
    assert func(value) == expected

# DataFrame comparison
pd.testing.assert_frame_equal(result, expected)

# Array comparison
np.testing.assert_array_almost_equal(result, expected)

# Float comparison
assert result == pytest.approx(expected, rel=1e-5)