diff --git a/diskann-quantization/src/multi_vector/matrix.rs b/diskann-quantization/src/multi_vector/matrix.rs index 7a755c16b..e8789b2c8 100644 --- a/diskann-quantization/src/multi_vector/matrix.rs +++ b/diskann-quantization/src/multi_vector/matrix.rs @@ -242,6 +242,14 @@ pub unsafe trait NewOwned: ReprOwned { #[derive(Debug, Clone, Copy)] pub struct Defaulted; +/// Create a new [`Mat`] cloned from a view. +pub trait NewCloned: ReprOwned { + /// Clone the contents behind `v`, returning a new owning [`Mat`]. + /// + /// Implementations should ensure the returned [`Mat`] is "semantically the same" as `v`. + fn new_cloned(v: MatRef<'_, Self>) -> Mat; +} + ////////////// // Standard // ////////////// @@ -312,6 +320,24 @@ impl Standard { Ok(()) } } + + /// Create a new [`Mat`] around the contents of `b` **without** any checks. + /// + /// # Safety + /// + /// The length of `b` must be exactly [`Standard::num_elements`]. + unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat { + debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated"); + + // SAFETY: Box [guarantees](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.into_raw) + // the returned pointer is non-null. + let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::(); + + // SAFETY: `ptr` is properly aligned and points to a slice of the required length. + // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining + // it from `Box::into_raw`. + unsafe { Mat::from_raw_parts(self, ptr) } + } } /// Error for [`Standard::new`]. @@ -444,15 +470,10 @@ where { type Error = crate::error::Infallible; fn new_owned(self, value: T) -> Result, Self::Error> { - let b: Box<[T]> = (0..self.nrows() * self.ncols()).map(|_| value).collect(); - // SAFETY: Box [guarantees](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.into_raw) - // the returned pointer is non-null. - let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::(); + let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect(); - // SAFETY: `ptr` is properly aligned and points to a slice of the required length. - // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining - // it from `Box::into_raw`. - Ok(unsafe { Mat::from_raw_parts(self, ptr) }) + // SAFETY: By construction, `b` has length `self.num_elements()`. + Ok(unsafe { self.box_to_mat(b) }) } } @@ -503,6 +524,18 @@ where } } +impl NewCloned for Standard +where + T: Copy, +{ + fn new_cloned(v: MatRef<'_, Self>) -> Mat { + let b: Box<[T]> = v.rows().flatten().copied().collect(); + + // SAFETY: By construction, `b` has length `v.repr().num_elements()`. + unsafe { v.repr().box_to_mat(b) } + } +} + ///////// // Mat // ///////// @@ -636,6 +669,12 @@ impl Drop for Mat { } } +impl Clone for Mat { + fn clone(&self) -> Self { + T::new_cloned(self.as_view()) + } +} + impl Mat> { /// Returns the raw dimension (columns) of the vectors in the matrix. #[inline] @@ -722,6 +761,14 @@ impl<'a, T: Repr> MatRef<'a, T> { Rows::new(*self) } + /// Return a [`Mat`] with the same contents as `self`. + pub fn to_owned(&self) -> Mat + where + T: NewCloned, + { + T::new_cloned(*self) + } + /// Construct a new [`MatRef`] over the raw pointer and representation without performing /// any validity checks. /// @@ -893,6 +940,14 @@ impl<'a, T: ReprMut> MatMut<'a, T> { RowsMut::new(self.reborrow_mut()) } + /// Return a [`Mat`] with the same contents as `self`. + pub fn to_owned(&self) -> Mat + where + T: NewCloned, + { + T::new_cloned(self.as_view()) + } + /// Construct a new [`MatMut`] over the raw pointer and representation without performing /// any validity checks. /// @@ -1435,6 +1490,65 @@ mod tests { } } + #[test] + fn test_mat_clone() { + for nrows in ROWS { + for ncols in COLS { + let repr = Standard::::new(*nrows, *ncols).unwrap(); + let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols); + + let mut mat = Mat::new(repr, Defaulted).unwrap(); + fill_mat(&mut mat, repr); + + // Clone via Mat::clone + { + let ctx = &lazy_format!("{ctx} - Mat::clone"); + let cloned = mat.clone(); + + assert_eq!(cloned.num_vectors(), *nrows); + assert_eq!(cloned.vector_dim(), *ncols); + + check_mat(&cloned, repr, ctx); + check_mat_ref(cloned.reborrow(), repr, ctx); + check_rows(cloned.rows(), repr, ctx); + + // Cloned allocation is independent. + if repr.num_elements() > 0 { + assert_ne!(mat.as_ptr(), cloned.as_ptr()); + } + } + + // Clone via MatRef::to_owned + { + let ctx = &lazy_format!("{ctx} - MatRef::to_owned"); + let owned = mat.as_view().to_owned(); + + check_mat(&owned, repr, ctx); + check_mat_ref(owned.reborrow(), repr, ctx); + check_rows(owned.rows(), repr, ctx); + + if repr.num_elements() > 0 { + assert_ne!(mat.as_ptr(), owned.as_ptr()); + } + } + + // Clone via MatMut::to_owned + { + let ctx = &lazy_format!("{ctx} - MatMut::to_owned"); + let owned = mat.as_view_mut().to_owned(); + + check_mat(&owned, repr, ctx); + check_mat_ref(owned.reborrow(), repr, ctx); + check_rows(owned.rows(), repr, ctx); + + if repr.num_elements() > 0 { + assert_ne!(mat.as_ptr(), owned.as_ptr()); + } + } + } + } + } + #[test] fn test_mat_refmut() { for nrows in ROWS {