跳到主要内容

单元测试(Unit Tests)

单元测试是测试代码中最小可测试单元的自动化测试。在 Rust 中,单元测试通常与被测试的代码放在同一个文件中,使用 #[cfg(test)] 模块。

基本单元测试

测试函数结构

// src/lib.rs
pub fn add(a: i32, b: i32) -> i32 {
a + b
}

pub fn subtract(a: i32, b: i32) -> i32 {
a - b
}

pub fn multiply(a: i32, b: i32) -> i32 {
a * b
}

pub fn divide(a: i32, b: i32) -> Result<i32, String> {
if b == 0 {
Err("除零错误".to_string())
} else {
Ok(a / b)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_add() {
assert_eq!(add(2, 3), 5);
assert_eq!(add(-1, 1), 0);
assert_eq!(add(0, 0), 0);
}

#[test]
fn test_subtract() {
assert_eq!(subtract(5, 3), 2);
assert_eq!(subtract(0, 5), -5);
}

#[test]
fn test_multiply() {
assert_eq!(multiply(3, 4), 12);
assert_eq!(multiply(-2, 3), -6);
assert_eq!(multiply(0, 100), 0);
}

#[test]
fn test_divide_success() {
assert_eq!(divide(10, 2), Ok(5));
assert_eq!(divide(7, 3), Ok(2));
}

#[test]
fn test_divide_by_zero() {
assert_eq!(divide(10, 0), Err("除零错误".to_string()));
}
}

断言宏

基本断言

#[cfg(test)]
mod tests {
#[test]
fn test_assertions() {
// assert! - 检查布尔值
assert!(true);
assert!(2 + 2 == 4);

// assert_eq! - 检查相等性
assert_eq!(2 + 2, 4);
assert_eq!("hello".to_uppercase(), "HELLO");

// assert_ne! - 检查不相等性
assert_ne!(2 + 2, 5);
assert_ne!("hello", "world");
}

#[test]
fn test_custom_messages() {
let x = 10;
let y = 20;

// 自定义错误消息
assert!(x < y, "x 应该小于 y,但 x={}, y={}", x, y);
assert_eq!(x + y, 30, "加法计算错误");
assert_ne!(x, y, "x 和 y 不应该相等");
}
}

浮点数比较

#[cfg(test)]
mod tests {
const EPSILON: f64 = 1e-10;

fn approx_equal(a: f64, b: f64) -> bool {
(a - b).abs() < EPSILON
}

#[test]
fn test_floating_point() {
let result = 0.1 + 0.2;
let expected = 0.3;

// 不要直接比较浮点数
// assert_eq!(result, expected); // 这可能失败

// 使用近似比较
assert!(approx_equal(result, expected));

// 或者使用 epsilon 比较
assert!((result - expected).abs() < EPSILON);
}
}

测试 panic

should_panic 属性

pub fn divide_panic(a: i32, b: i32) -> i32 {
if b == 0 {
panic!("不能除以零!");
}
a / b
}

pub fn access_array(arr: &[i32], index: usize) -> i32 {
if index >= arr.len() {
panic!("索引超出范围:{}", index);
}
arr[index]
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
#[should_panic]
fn test_divide_panic() {
divide_panic(10, 0);
}

#[test]
#[should_panic(expected = "索引超出范围")]
fn test_array_access_panic() {
let arr = [1, 2, 3];
access_array(&arr, 5);
}

#[test]
fn test_no_panic() {
let result = divide_panic(10, 2);
assert_eq!(result, 5);
}
}

使用 catch_unwind 测试 panic

use std::panic;

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_panic_with_catch_unwind() {
let result = panic::catch_unwind(|| {
divide_panic(10, 0)
});

assert!(result.is_err());
}

#[test]
fn test_panic_message() {
let result = panic::catch_unwind(|| {
panic!("自定义错误消息");
});

assert!(result.is_err());

if let Err(panic_info) = result {
if let Some(message) = panic_info.downcast_ref::<&str>() {
assert_eq!(*message, "自定义错误消息");
}
}
}
}

测试 Result 类型

测试返回 Result 的函数

use std::fs::File;
use std::io::{self, Write};

pub fn write_to_file(filename: &str, content: &str) -> io::Result<()> {
let mut file = File::create(filename)?;
file.write_all(content.as_bytes())?;
Ok(())
}

pub fn parse_number(s: &str) -> Result<i32, std::num::ParseIntError> {
s.parse()
}

#[cfg(test)]
mod tests {
use super::*;
use std::fs;

#[test]
fn test_parse_number_success() {
assert_eq!(parse_number("42"), Ok(42));
assert_eq!(parse_number("-10"), Ok(-10));
assert_eq!(parse_number("0"), Ok(0));
}

#[test]
fn test_parse_number_error() {
assert!(parse_number("not_a_number").is_err());
assert!(parse_number("").is_err());
assert!(parse_number("12.34").is_err());
}

#[test]
fn test_write_to_file() -> io::Result<()> {
let filename = "test_output.txt";
let content = "Hello, test!";

// 写入文件
write_to_file(filename, content)?;

// 验证文件内容
let read_content = fs::read_to_string(filename)?;
assert_eq!(read_content, content);

// 清理测试文件
fs::remove_file(filename)?;

Ok(())
}
}

测试私有函数

在同一模块中测试

mod calculator {
// 私有函数
fn validate_input(a: i32, b: i32) -> Result<(), String> {
if a > 1000 || b > 1000 {
Err("输入值过大".to_string())
} else {
Ok(())
}
}

pub fn safe_add(a: i32, b: i32) -> Result<i32, String> {
validate_input(a, b)?;
Ok(a + b)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_validate_input() {
assert_eq!(validate_input(10, 20), Ok(()));
assert_eq!(validate_input(1001, 20), Err("输入值过大".to_string()));
assert_eq!(validate_input(10, 1001), Err("输入值过大".to_string()));
}

#[test]
fn test_safe_add() {
assert_eq!(safe_add(10, 20), Ok(30));
assert_eq!(safe_add(1001, 20), Err("输入值过大".to_string()));
}
}
}

测试组织和命名

测试模块组织

pub struct User {
pub name: String,
pub age: u32,
pub email: String,
}

impl User {
pub fn new(name: String, age: u32, email: String) -> Result<Self, String> {
if name.is_empty() {
return Err("姓名不能为空".to_string());
}
if age > 150 {
return Err("年龄不合理".to_string());
}
if !email.contains('@') {
return Err("邮箱格式不正确".to_string());
}

Ok(User { name, age, email })
}

pub fn is_adult(&self) -> bool {
self.age >= 18
}

pub fn update_email(&mut self, new_email: String) -> Result<(), String> {
if !new_email.contains('@') {
return Err("邮箱格式不正确".to_string());
}
self.email = new_email;
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

mod user_creation {
use super::*;

#[test]
fn test_valid_user_creation() {
let user = User::new(
"Alice".to_string(),
25,
"alice@example.com".to_string()
);
assert!(user.is_ok());

let user = user.unwrap();
assert_eq!(user.name, "Alice");
assert_eq!(user.age, 25);
assert_eq!(user.email, "alice@example.com");
}

#[test]
fn test_empty_name_error() {
let result = User::new(
"".to_string(),
25,
"alice@example.com".to_string()
);
assert_eq!(result, Err("姓名不能为空".to_string()));
}

#[test]
fn test_invalid_age_error() {
let result = User::new(
"Alice".to_string(),
200,
"alice@example.com".to_string()
);
assert_eq!(result, Err("年龄不合理".to_string()));
}

#[test]
fn test_invalid_email_error() {
let result = User::new(
"Alice".to_string(),
25,
"invalid_email".to_string()
);
assert_eq!(result, Err("邮箱格式不正确".to_string()));
}
}

mod user_methods {
use super::*;

#[test]
fn test_is_adult() {
let adult = User::new(
"Adult".to_string(),
25,
"adult@example.com".to_string()
).unwrap();
assert!(adult.is_adult());

let minor = User::new(
"Minor".to_string(),
16,
"minor@example.com".to_string()
).unwrap();
assert!(!minor.is_adult());
}

#[test]
fn test_update_email_success() {
let mut user = User::new(
"Alice".to_string(),
25,
"alice@example.com".to_string()
).unwrap();

let result = user.update_email("newalice@example.com".to_string());
assert!(result.is_ok());
assert_eq!(user.email, "newalice@example.com");
}

#[test]
fn test_update_email_error() {
let mut user = User::new(
"Alice".to_string(),
25,
"alice@example.com".to_string()
).unwrap();

let result = user.update_email("invalid_email".to_string());
assert_eq!(result, Err("邮箱格式不正确".to_string()));
assert_eq!(user.email, "alice@example.com"); // 邮箱应该保持不变
}
}
}

运行测试

基本测试命令

# 运行所有测试
cargo test

# 运行特定测试
cargo test test_add

# 运行匹配模式的测试
cargo test user_creation

# 显示测试输出
cargo test -- --nocapture

# 运行被忽略的测试
cargo test -- --ignored

# 运行单个测试线程(用于调试)
cargo test -- --test-threads=1

# 显示测试统计信息
cargo test -- --show-output

测试配置

// 忽略耗时的测试
#[test]
#[ignore]
fn expensive_test() {
// 这个测试默认不会运行
std::thread::sleep(std::time::Duration::from_secs(1));
assert!(true);
}

// 只在特定条件下运行的测试
#[test]
#[cfg(target_os = "linux")]
fn linux_specific_test() {
// 只在 Linux 上运行
assert!(true);
}

#[test]
#[cfg(feature = "expensive_tests")]
fn feature_gated_test() {
// 只在启用 expensive_tests 特性时运行
// cargo test --features expensive_tests
assert!(true);
}

单元测试是确保代码质量的基础,应该覆盖所有的公共 API 和关键的私有函数。下一节我们将学习集成测试。