泛型(Generics)
泛型(Generics)允许我们编写可以处理多种类型的代码,而不需要为每种类型重复编写相同的逻辑。这提高了代码的复用性和类型安全性。
函数中的泛型
基本泛型函数
/// Returns a reference to the greatest element of `list` according to `PartialOrd`.
///
/// The comparison is strict (`>`), so when several elements tie for the maximum the first
/// one wins, and elements that are unordered (e.g. NaN) never replace the current best.
///
/// # Panics
///
/// Panics (index out of bounds) if `list` is empty, because `list[0]` seeds the search.
fn largest<T: PartialOrd>(list: &[T]) -> &T {
    list.iter().fold(&list[0], |best, candidate| {
        if candidate > best {
            candidate
        } else {
            best
        }
    })
}
fn main() {
    // The same generic `largest` works for integers...
    let number_list = [34, 50, 25, 100, 65];
    println!("最大的数字是 {}", largest(&number_list));

    // ...and for chars, with no extra code.
    let char_list = ['y', 'm', 'a', 'q'];
    println!("最大的字符是 {}", largest(&char_list));
}
多个泛型参数
/// Formats two values, which may have different types, as `"x = <x>, y = <y>"`.
///
/// `T` and `U` are independent type parameters; each only needs to implement `Display`.
fn compare<T: std::fmt::Display, U: std::fmt::Display>(x: T, y: U) -> String {
    format!("x = {}, y = {}", x, y)
}
fn main() {
    // `T` is inferred as an integer type and `U` as `&str`.
    println!("{}", compare(5, "hello"));
}
结构体中的泛型
单个泛型参数
/// A 2-D point whose two coordinates share the same type `T`.
#[derive(Debug)]
struct Point<T> {
    x: T,
    y: T,
}

impl<T> Point<T> {
    /// Creates a point from its `x` and `y` coordinates.
    fn new(x: T, y: T) -> Self {
        Point { x, y }
    }

    /// Returns a reference to the `x` coordinate.
    fn x(&self) -> &T {
        &self.x
    }
}

// Methods can also be implemented for just one concrete instantiation of a generic type:
// only `Point<f32>` gets `distance_from_origin`.
impl Point<f32> {
    /// Euclidean distance from the origin `(0, 0)`.
    ///
    /// Uses `f32::hypot` rather than `(x² + y²).sqrt()`. The naive form overflows in the
    /// intermediate squares: for `x = 1e30`, `x²` is already `inf` in `f32`. `hypot` returns
    /// the finite, correctly scaled result.
    fn distance_from_origin(&self) -> f32 {
        self.x.hypot(self.y)
    }
}
fn main() {
let integer_point = Point::new(5, 10);
let float_point = Point::new(1.0, 4.0);
println!("integer_point: {:?}", integer_point);
println!("float_point: {:?}", float_point);
println!("distance: {}", float_point.distance_from_origin());
}
多个泛型参数
/// A 2-D point whose coordinates may have two different types, `T` and `U`.
#[derive(Debug)]
struct Point<T, U> {
    x: T,
    y: U,
}

impl<T, U> Point<T, U> {
    /// Creates a point from its `x` and `y` coordinates.
    fn new(x: T, y: U) -> Self {
        Self { x, y }
    }

    /// Consumes both points and builds a new one from `self`'s `x` and `other`'s `y`.
    ///
    /// `V` and `W` belong to the method, not the struct. This shows that a method can introduce
    /// generic parameters of its own.
    fn mixup<V, W>(self, other: Point<V, W>) -> Point<T, W> {
        let Point { x, .. } = self;
        let Point { y, .. } = other;
        Point { x, y }
    }
}
fn main() {
let p1 = Point::new(5, 10.4);
let p2 = Point::new("Hello", 'c');
let p3 = p1.mixup(p2);
println!("p3.x = {}, p3.y = {}", p3.x, p3.y);
}
枚举中的泛型
/// Simplified copies of how the standard library declares `Option` and `Result`, for
/// illustration only.
///
/// They live in their own module on purpose. Declaring `enum Option<T>` / `enum Result<T, E>`
/// at the top level would shadow the prelude types, while the prelude constructors `Some`,
/// `None`, `Ok` and `Err` would still build the *standard* enums. Any function written as
/// `-> Option<_>` that returns `Some(..)` would then fail to type-check.
// Illustrative definitions that nothing constructs; silence the resulting dead-code warnings.
#[allow(dead_code)]
mod std_like {
    /// An optional value: either `Some(T)` or `None`.
    pub enum Option<T> {
        Some(T),
        None,
    }

    /// The outcome of a fallible operation: `Ok(T)` on success or `Err(E)` on failure.
    pub enum Result<T, E> {
        Ok(T),
        Err(E),
    }
}
// 自定义泛型枚举
/// A value of one of two possible types: `Left(L)` or `Right(R)`.
#[derive(Debug)]
enum Either<L, R> {
    Left(L),
    Right(R),
}

impl<L, R> Either<L, R> {
    /// Returns `true` if this is the `Left` variant.
    fn is_left(&self) -> bool {
        match self {
            Either::Left(_) => true,
            Either::Right(_) => false,
        }
    }

    /// Returns `true` if this is the `Right` variant.
    fn is_right(&self) -> bool {
        // With exactly two variants, "right" is simply "not left".
        !self.is_left()
    }

    /// Consumes `self` and returns the left value, or `None` if this is `Right`.
    fn left(self) -> Option<L> {
        if let Either::Left(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Consumes `self` and returns the right value, or `None` if this is `Left`.
    fn right(self) -> Option<R> {
        match self {
            Either::Right(value) => Some(value),
            Either::Left(_) => None,
        }
    }
}
fn main() {
    let left: Either<i32, String> = Either::Left(42);
    let right: Either<i32, String> = Either::Right("hello".to_string());

    let (left_is_left, right_is_right) = (left.is_left(), right.is_right());
    println!("left is left: {}", left_is_left);
    println!("right is right: {}", right_is_right);
}
Trait Bounds
基本 Trait Bounds
use std::fmt::Display;
/// Prints `value` using its `Display` implementation, then hands ownership back to the caller.
fn print_and_return<T>(value: T) -> T
where
    T: Display,
{
    println!("Value: {}", value);
    value
}
// Multiple trait bounds on one type parameter.
/// Prints either `"x >= y"` or `"x < y"`, formatting both values with `Display`.
///
/// Values that are unordered (e.g. NaN) fail the `>=` test and are reported as `<`.
fn compare_and_print<T>(x: T, y: T)
where
    T: Display + PartialOrd,
{
    let relation = if x >= y { ">=" } else { "<" };
    println!("{} {} {}", x, relation, y);
}
fn main() {
    // The returned values are not used afterwards; the leading underscore marks that
    // deliberately and silences the unused-variable warning.
    let _number = print_and_return(42);
    let _text = print_and_return("Hello");

    compare_and_print(10, 20);
    compare_and_print("apple", "banana");
}
where 子句
use std::fmt::Display;
/// Prints `t` with `Display` and `u` with `Debug`, then returns the constant `42`.
///
/// The body only formats its arguments through shared references, so no `Clone` bound is
/// needed. `?Sized` lets callers pass unsized values directly, such as `"hello"` (`T = str`)
/// or a slice (`U = [i32]`), in addition to sized ones such as `&"hello"` or `&vec![..]`.
fn some_function<T, U>(t: &T, u: &U) -> i32
where
    T: Display + ?Sized,
    U: std::fmt::Debug + ?Sized,
{
    println!("t: {}", t);
    println!("u: {:?}", u);
    42
}
// A more involved `where` clause.
/// Builds `T::default()`, prints it with `Debug`, and returns it.
///
/// The value is returned by move. Cloning it here would only copy it (an allocation for
/// types such as `String` or `Vec`) right before dropping the original, so `T` does not need
/// to be `Clone`.
fn complex_function<T>() -> T
where
    T: Default + std::fmt::Debug,
{
    let value = T::default();
    println!("Default value: {:?}", value);
    value
}
fn main() {
    let numbers = vec![1, 2, 3];
    let result = some_function(&"hello", &numbers);
    println!("Result: {}", result);

    // The return type selects which `Default` impl is used.
    let default_string = complex_function::<String>();
    let default_vec = complex_function::<Vec<i32>>();
    println!("Default string: {:?}", default_string);
    println!("Default vec: {:?}", default_vec);
}
关联类型
/// A simplified, hand-written version of an iterator trait, used to demonstrate associated types.
///
/// NOTE(review): because this is a top-level item named `Iterator`, it shadows the prelude's
/// `std::iter::Iterator` in this module. Types that implement only this trait cannot be used
/// in `for` loops or with the standard iterator adapters.
trait Iterator {
    /// The element type produced by `next`; each implementor fixes exactly one `Item`.
    type Item;
    /// Returns the next element wrapped in `Some`, or `None` (the `Counter` impl in this file
    /// returns `None` once it has produced all of its values).
    fn next(&mut self) -> Option<Self::Item>;
}
/// Presumably builds a `Self` collection from an iterator yielding `T` values (inferred from
/// the name; nothing in this example implements or calls this trait).
///
/// Shows an associated-type constraint inside a bound: `Iterator<Item = T>` accepts only
/// iterators whose `Item` is exactly `T`.
trait Collect<T> {
    fn collect<I: Iterator<Item = T>>(iter: I) -> Self;
}
// 实现一个简单的迭代器
/// Counts upward from `0`, yielding each value strictly below `max` exactly once.
struct Counter {
    /// The next value to hand out.
    current: usize,
    /// Exclusive upper bound on the values produced.
    max: usize,
}

impl Counter {
    /// Creates a counter that yields `0, 1, ..., max - 1`.
    fn new(max: usize) -> Counter {
        Self { max, current: 0 }
    }
}

impl Iterator for Counter {
    type Item = usize;

    /// Returns the current value and advances; keeps returning `None` once `max` is reached.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.max {
            return None;
        }
        let value = self.current;
        self.current += 1;
        Some(value)
    }
}
fn main() {
    // `Counter` is driven by calling `next` directly until it returns `None`.
    let mut ticks = Counter::new(5);
    while let Some(tick) = ticks.next() {
        println!("Value: {}", tick);
    }
}
生命周期参数
// 泛型生命周期参数
/// Returns whichever of `x` and `y` is longer in bytes; on a tie, returns `y`.
///
/// The shared lifetime `'a` says the result borrows from one of the inputs, so it is only
/// valid while both inputs are.
fn longest<'a>(x: &'a str, y: &'a str) -> &'a str {
    if y.len() >= x.len() {
        y
    } else {
        x
    }
}
// 结构体中的生命周期
/// Holds a borrowed slice of text. The lifetime `'a` prevents the struct from outliving the
/// string it borrows from.
struct ImportantExcerpt<'a> {
    part: &'a str,
}

impl ImportantExcerpt<'_> {
    /// Always returns `3`.
    fn level(&self) -> i32 {
        3
    }

    /// Prints `announcement`, then returns the stored excerpt.
    ///
    /// By the elision rules, the returned `&str` borrows from `&self`, not from `announcement`.
    fn announce_and_return_part(&self, announcement: &str) -> &str {
        println!("Attention please: {}", announcement);
        self.part
    }
}
fn main() {
    let string1 = String::from("abcd");
    let string2 = "xyz";
    // `&string1` deref-coerces from `&String` to `&str`.
    println!("The longest string is {}", longest(&string1, string2));

    let novel = String::from("Call me Ishmael. Some years ago...");
    let first_sentence = novel
        .split('.')
        .next()
        .expect("Could not find a '.'");
    let excerpt = ImportantExcerpt {
        part: first_sentence,
    };
    println!("Important excerpt: {}", excerpt.part);
}
高级泛型特性
高阶 trait bounds (HRTB)
/// Calls `func` once with the string `"hello"` and returns its result.
///
/// The higher-ranked bound `for<'a> Fn(&'a str)` requires `func` to accept a `&str` of *any*
/// lifetime, rather than one particular lifetime chosen by the caller.
fn call_with_one<F>(func: F) -> usize
where
    F: for<'a> Fn(&'a str) -> usize,
{
    let input = "hello";
    func(input)
}
fn main() {
    // The closure's explicit `&str` annotation makes it usable for every lifetime.
    let result = call_with_one(|s: &str| s.len());
    println!("Length: {}", result);
}
泛型常量参数
// Rust 1.51+ 支持泛型常量参数
/// A fixed-length array wrapper whose length `N` is a const generic parameter.
struct Array<T, const N: usize> {
    data: [T; N],
}

impl<T, const N: usize> Array<T, N>
where
    T: Default + Copy,
{
    /// Creates an array with every slot set to `T::default()`.
    ///
    /// `Copy` is required because the `[value; N]` repeat expression copies one value into
    /// every slot.
    fn new() -> Self {
        Array {
            data: [T::default(); N],
        }
    }
}

// `len`, `get` and `set` work for any `T`, so they live in an unbounded impl. They are
// usable even for non-`Copy` element types such as `String`.
impl<T, const N: usize> Array<T, N> {
    /// Returns the compile-time length `N`.
    fn len(&self) -> usize {
        N
    }

    /// Returns a reference to the element at `index`, or `None` if `index >= N`.
    fn get(&self, index: usize) -> Option<&T> {
        self.data.get(index)
    }

    /// Overwrites the element at `index` with `value`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Index out of bounds")` if `index >= N`; `value` is dropped in that case.
    fn set(&mut self, index: usize, value: T) -> Result<(), &'static str> {
        match self.data.get_mut(index) {
            Some(slot) => {
                *slot = value;
                Ok(())
            }
            None => Err("Index out of bounds"),
        }
    }
}
fn main() {
let mut arr: Array<i32, 5> = Array::new();
arr.set(0, 10).unwrap();
arr.set(1, 20).unwrap();
println!("Length: {}", arr.len());
println!("arr[0]: {:?}", arr.get(0));
println!("arr[1]: {:?}", arr.get(1));
}
实际应用示例
泛型容器
use std::fmt::Debug;
/// A growable collection of `T` values backed by a `Vec<T>`.
///
/// Extra methods are unlocked by the element type's capabilities: `print_all` needs
/// `T: Debug`, while `contains`/`remove` need `T: PartialEq`.
#[derive(Debug)]
struct Container<T> {
    items: Vec<T>,
}

// Written by hand: `#[derive(Default)]` would add an unnecessary `T: Default` bound.
impl<T> Default for Container<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Container<T> {
    /// Creates an empty container.
    fn new() -> Self {
        Container { items: Vec::new() }
    }

    /// Appends `item` to the end of the container.
    fn add(&mut self, item: T) {
        self.items.push(item);
    }

    /// Returns the number of stored items.
    fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` if the container holds no items.
    fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns the item at `index`, or `None` if out of range.
    fn get(&self, index: usize) -> Option<&T> {
        self.items.get(index)
    }

    /// Iterates over the items in insertion order.
    ///
    /// The `'_` makes explicit that the iterator borrows from `self`.
    fn iter(&self) -> std::slice::Iter<'_, T> {
        self.items.iter()
    }
}

impl<T: Debug> Container<T> {
    /// Prints every item with its index, one per line.
    fn print_all(&self) {
        for (i, item) in self.items.iter().enumerate() {
            println!("Item {}: {:?}", i, item);
        }
    }
}

impl<T: PartialEq> Container<T> {
    /// Returns `true` if any stored item equals `item`.
    fn contains(&self, item: &T) -> bool {
        self.items.contains(item)
    }

    /// Removes the first item equal to `item`, preserving the order of the rest.
    ///
    /// Returns `true` if an item was removed.
    fn remove(&mut self, item: &T) -> bool {
        match self.items.iter().position(|x| x == item) {
            Some(pos) => {
                self.items.remove(pos);
                true
            }
            None => false,
        }
    }
}
fn main() {
    let mut numbers = Container::new();
    for n in 1..=3 {
        numbers.add(n);
    }
    println!("Numbers container:");
    numbers.print_all();
    println!("Contains 2: {}", numbers.contains(&2));

    let mut words = Container::new();
    for word in ["hello", "world"].iter() {
        words.add((*word).to_string());
    }
    println!("\nWords container:");
    words.print_all();
}
泛型缓存
use std::collections::HashMap;
use std::hash::Hash;
/// A minimal key/value cache interface, generic over the key type `K` and value type `V`.
///
/// `get` takes `&self`. An implementation that wants reads to update internal bookkeeping
/// (e.g. recency) therefore needs interior mutability.
trait Cache<K, V> {
    /// Looks up `key`, returning a reference to its value if present.
    fn get(&self, key: &K) -> Option<&V>;
    /// Stores `value` under `key`. What happens to an existing entry, or when the cache is
    /// full, is up to the implementation.
    fn insert(&mut self, key: K, value: V);
    /// Removes `key`, returning its value if it was present.
    fn remove(&mut self, key: &K) -> Option<V>;
    /// Removes every entry.
    fn clear(&mut self);
}
/// A cache backed directly by a `HashMap`. This type enforces no capacity limit.
struct SimpleCache<K, V> {
    data: HashMap<K, V>,
}

impl<K, V> SimpleCache<K, V>
where
    K: Eq + Hash,
{
    /// Creates an empty cache.
    fn new() -> Self {
        Self {
            data: HashMap::new(),
        }
    }

    /// Returns the number of cached entries.
    fn len(&self) -> usize {
        self.data.len()
    }
}
impl<K, V> Cache<K, V> for SimpleCache<K, V>
where
    K: Eq + Hash,
{
    /// Plain `HashMap` lookup; reads have no side effects.
    fn get(&self, key: &K) -> Option<&V> {
        self.data.get(key)
    }

    /// Inserts or overwrites `key`. Any previous value is returned by `HashMap::insert` and
    /// dropped immediately.
    fn insert(&mut self, key: K, value: V) {
        let _ = self.data.insert(key, value);
    }

    /// Removes `key`, handing back its value if it was present.
    fn remove(&mut self, key: &K) -> Option<V> {
        self.data.remove(key)
    }

    /// Drops every entry.
    fn clear(&mut self) {
        self.data.clear();
    }
}
// LRU 缓存实现
use std::collections::VecDeque;
/// A bounded cache that evicts the least-recently-used entry once `capacity` entries are held.
///
/// Both `insert` and `get` count as a "use" of a key. `Cache::get` only receives `&self`, so
/// the recency list is kept in a `RefCell` and updated through a shared reference. As a
/// consequence this type is not `Sync`.
///
/// Recency bookkeeping is a linear scan of `order`, so `get`, `insert` and `remove` are
/// O(capacity). That is fine for the small caches this example targets.
struct LRUCache<K, V> {
    /// The cached key/value pairs.
    data: HashMap<K, V>,
    /// Keys ordered from most recently used (front) to least recently used (back).
    order: std::cell::RefCell<VecDeque<K>>,
    /// Maximum number of entries; a capacity of `0` stores nothing.
    capacity: usize,
}

impl<K, V> LRUCache<K, V>
where
    K: Eq + Hash + Clone,
{
    /// Creates an empty cache that holds at most `capacity` entries.
    fn new(capacity: usize) -> Self {
        LRUCache {
            data: HashMap::new(),
            order: std::cell::RefCell::new(VecDeque::new()),
            capacity,
        }
    }

    /// Marks `key` as the most recently used key, if it is tracked.
    fn touch(&self, key: &K) {
        let mut order = self.order.borrow_mut();
        if let Some(pos) = order.iter().position(|k| k == key) {
            if let Some(k) = order.remove(pos) {
                order.push_front(k);
            }
        }
    }
}

impl<K, V> Cache<K, V> for LRUCache<K, V>
where
    K: Eq + Hash + Clone,
{
    /// Returns the value for `key` and marks the key as most recently used.
    fn get(&self, key: &K) -> Option<&V> {
        let value = self.data.get(key)?;
        self.touch(key);
        Some(value)
    }

    /// Inserts or updates `key` and marks it as most recently used.
    ///
    /// Adding a new key to a full cache first evicts the least recently used key. A cache
    /// with capacity `0` ignores the insert instead of growing past its capacity.
    fn insert(&mut self, key: K, value: V) {
        if self.capacity == 0 {
            return;
        }
        if self.data.contains_key(&key) {
            // Updating an existing key counts as a use: move it to the front.
            self.touch(&key);
        } else {
            if self.data.len() >= self.capacity {
                // The back of `order` is the least recently used key.
                if let Some(oldest) = self.order.get_mut().pop_back() {
                    self.data.remove(&oldest);
                }
            }
            self.order.get_mut().push_front(key.clone());
        }
        self.data.insert(key, value);
    }

    /// Removes `key` from both the map and the recency list, returning its value if present.
    fn remove(&mut self, key: &K) -> Option<V> {
        let value = self.data.remove(key)?;
        let order = self.order.get_mut();
        if let Some(pos) = order.iter().position(|k| k == key) {
            order.remove(pos);
        }
        Some(value)
    }

    /// Removes every entry and resets the recency list.
    fn clear(&mut self) {
        self.data.clear();
        self.order.get_mut().clear();
    }
}
fn main() {
    // Simple cache: a thin wrapper over `HashMap`.
    let mut simple_cache = SimpleCache::new();
    simple_cache.insert("key1", "value1");
    simple_cache.insert("key2", "value2");
    let first = simple_cache.get(&"key1");
    println!("Simple cache get key1: {:?}", first);

    // LRU cache with room for two entries.
    let mut lru_cache = LRUCache::new(2);
    lru_cache.insert("a", 1);
    lru_cache.insert("b", 2);
    // Inserting a third key into a full cache evicts the least recently used one, "a".
    lru_cache.insert("c", 3);
    println!("LRU cache get a: {:?}", lru_cache.get(&"a")); // None
    println!("LRU cache get b: {:?}", lru_cache.get(&"b")); // Some(2)
    println!("LRU cache get c: {:?}", lru_cache.get(&"c")); // Some(3)
}
性能考虑
单态化 (Monomorphization)
// 这个泛型函数
/// Prints any `Display` value on its own line.
///
/// Each distinct `T` used at a call site gets its own monomorphized copy at compile time.
fn generic_function<T>(x: T)
where
    T: std::fmt::Display,
{
    println!("{}", x);
}
// 在编译时会被单态化为:
// fn generic_function_i32(x: i32) {
// println!("{}", x);
// }
//
// fn generic_function_str(x: &str) {
// println!("{}", x);
// }
fn main() {
    generic_function::<i32>(42); // uses the `i32` instantiation
    generic_function::<&str>("hello"); // uses the `&str` instantiation
}
最佳实践
- 合理使用泛型:不要过度泛化
- 使用有意义的类型参数名:`T` 用于类型,`E` 用于错误,`K`/`V` 用于键值
- 优先使用 trait bounds 表达约束:而不是在函数体中检查
- 考虑使用关联类型:当类型关系明确时
- 避免过多的泛型参数:保持代码可读性
// 好的泛型设计示例
use std::fmt::Debug;
use std::hash::Hash;
/// A generic storage interface for entities of type `Entity` identified by values of type `Id`.
///
/// Each implementation picks its own error type through the associated `Error` type.
trait Repository<Entity, Id> {
    /// Error type returned by every operation of this repository.
    type Error;
    /// Looks up the entity with `id`; `Ok(None)` presumably means "no such entity".
    fn find_by_id(&self, id: Id) -> Result<Option<Entity>, Self::Error>;
    /// Persists `entity` and returns the stored version. That version may differ from the
    /// input; for example, `InMemoryUserRepository` assigns an id.
    fn save(&mut self, entity: Entity) -> Result<Entity, Self::Error>;
    /// Deletes the entity with `id`; the `bool` presumably reports whether one existed.
    fn delete(&mut self, id: Id) -> Result<bool, Self::Error>;
}
/// A stored user; `id` is the key used by `InMemoryUserRepository` (`0` means "not yet saved").
///
/// `Clone` is required: the repository returns copies from `find_by_id` (`.cloned()`) and keeps
/// its own copy on `save` (`user.clone()`). Without it those calls do not compile.
#[derive(Debug, Clone, PartialEq, Eq)]
struct User {
    id: u32,
    name: String,
}
/// Errors a `Repository` implementation can report.
///
/// NOTE(review): no code shown here constructs either variant yet.
#[derive(Debug)]
enum RepositoryError {
    /// The requested entity was not found.
    NotFound,
    /// The storage layer failed; carries a message describing the failure.
    DatabaseError(String),
}

// `Display` + `Error` let this type be used with `?` in functions returning
// `Box<dyn std::error::Error>` and printed for users, not just debug-formatted.
impl std::fmt::Display for RepositoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RepositoryError::NotFound => write!(f, "entity not found"),
            RepositoryError::DatabaseError(msg) => write!(f, "database error: {}", msg),
        }
    }
}

impl std::error::Error for RepositoryError {}
/// Keeps `User`s in a `HashMap` keyed by id; `next_id` starts at `1`.
struct InMemoryUserRepository {
    users: std::collections::HashMap<u32, User>,
    /// The id the repository will hand out to the next user saved with `id == 0`.
    next_id: u32,
}

impl InMemoryUserRepository {
    /// Creates an empty repository whose first auto-assigned id will be `1`.
    fn new() -> Self {
        Self {
            next_id: 1,
            users: Default::default(),
        }
    }
}
impl Repository<User, u32> for InMemoryUserRepository {
    type Error = RepositoryError;

    /// Returns a copy of the user stored under `id`, or `Ok(None)` if there is none.
    fn find_by_id(&self, id: u32) -> Result<Option<User>, Self::Error> {
        Ok(self.users.get(&id).cloned())
    }

    /// Inserts or replaces `user` and returns the stored copy.
    ///
    /// An `id` of `0` means "assign the next free id". When the caller supplies an explicit id,
    /// `next_id` is moved past it, so a later auto-assigned id cannot collide with (and
    /// silently overwrite) that user.
    /// NOTE(review): an explicit id of `u32::MAX` still exhausts the id space.
    fn save(&mut self, mut user: User) -> Result<User, Self::Error> {
        if user.id == 0 {
            user.id = self.next_id;
            self.next_id += 1;
        } else {
            self.next_id = self.next_id.max(user.id.saturating_add(1));
        }
        self.users.insert(user.id, user.clone());
        Ok(user)
    }

    /// Removes the user with `id`; returns `Ok(true)` if one was stored.
    fn delete(&mut self, id: u32) -> Result<bool, Self::Error> {
        Ok(self.users.remove(&id).is_some())
    }
}
fn main() {
    let mut repo = InMemoryUserRepository::new();

    // `id: 0` asks the repository to assign an id.
    let saved_user = repo
        .save(User {
            id: 0,
            name: "Alice".to_string(),
        })
        .unwrap();
    println!("Saved user: {:?}", saved_user);

    let found_user = repo.find_by_id(saved_user.id).unwrap();
    println!("Found user: {:?}", found_user);
}
泛型是 Rust 中实现代码复用和类型安全的重要工具,通过合理使用泛型可以编写出既灵活又高效的代码。