Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 122 additions & 8 deletions diskann-quantization/src/multi_vector/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ pub unsafe trait NewOwned<T>: ReprOwned {
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;

/// Create a new owning [`Mat`] by cloning the contents seen through a [`MatRef`].
///
/// This trait is the single hook used by the cloning entry points in this file: the
/// [`Clone`] implementation for [`Mat`] and the `to_owned` methods on [`MatRef`] and
/// [`MatMut`] all forward to [`NewCloned::new_cloned`].
pub trait NewCloned: ReprOwned {
/// Clone the contents behind `v`, returning a new owning [`Mat`].
///
/// The returned [`Mat`] carries no lifetime, so it is not tied to the borrow held by `v`.
///
/// Implementations should ensure the returned [`Mat`] is "semantically the same" as `v`.
fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}

//////////////
// Standard //
//////////////
Expand Down Expand Up @@ -312,6 +320,24 @@ impl<T: Copy> Standard<T> {
Ok(())
}
}

/// Wrap the boxed slice `b` in a [`Mat`] described by `self`, performing **no** validation.
///
/// The box is released into a raw pointer which is then handed to
/// [`Mat::from_raw_parts`] together with `self`.
///
/// # Safety
///
/// `b.len()` must be exactly [`Standard::num_elements`] for `self`. The length is only
/// checked by a debug assertion.
unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
    debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");

    // Give up the box; from here on the allocation is tracked through the raw pointer.
    let raw: *mut [T] = Box::into_raw(b);

    // SAFETY: Box [guarantees](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.into_raw)
    // the returned pointer is non-null.
    let base = unsafe { NonNull::new_unchecked(raw) }.cast::<u8>();

    // SAFETY: `base` is properly aligned and points to a slice of the required length.
    // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining
    // it from `Box::into_raw`.
    unsafe { Mat::from_raw_parts(self, base) }
}
}

/// Error for [`Standard::new`].
Expand Down Expand Up @@ -444,15 +470,10 @@ where
{
type Error = crate::error::Infallible;
/// Build a [`Mat`] holding `self.num_elements()` copies of `value`.
///
/// This always returns `Ok`; the buffer-to-matrix conversion is delegated to
/// `box_to_mat` so the raw-pointer handling lives in exactly one place.
fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
    let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();

    // SAFETY: By construction, `b` has length `self.num_elements()`.
    Ok(unsafe { self.box_to_mat(b) })
}
}

Expand Down Expand Up @@ -503,6 +524,18 @@ where
}
}

impl<T> NewCloned for Standard<T>
where
    T: Copy,
{
    /// Copy every element of every row of `v` into a new buffer and wrap it in a [`Mat`]
    /// that uses the same representation as `v`.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
        let repr = v.repr();

        // Reserve the whole buffer up front. Collecting through `flatten()` loses the size
        // hint, which forces incremental growth followed by a shrink when boxing.
        let mut elems: Vec<T> = Vec::with_capacity(repr.num_elements());
        for row in v.rows() {
            elems.extend(row);
        }
        let b: Box<[T]> = elems.into_boxed_slice();

        // SAFETY: By construction, `b` has length `v.repr().num_elements()`.
        unsafe { repr.box_to_mat(b) }
    }
}

/////////
// Mat //
/////////
Expand Down Expand Up @@ -636,6 +669,12 @@ impl<T: ReprOwned> Drop for Mat<T> {
}
}

impl<T> Clone for Mat<T>
where
    T: NewCloned,
{
    /// Clone this matrix by viewing it and handing the view to
    /// [`NewCloned::new_cloned`].
    fn clone(&self) -> Self {
        let view = self.as_view();
        <T as NewCloned>::new_cloned(view)
    }
}

impl<T: Copy> Mat<Standard<T>> {
/// Returns the raw dimension (columns) of the vectors in the matrix.
#[inline]
Expand Down Expand Up @@ -722,6 +761,14 @@ impl<'a, T: Repr> MatRef<'a, T> {
Rows::new(*self)
}

/// Produce an owning [`Mat`] holding the same contents as this view.
///
/// The copy itself is performed by [`NewCloned::new_cloned`].
pub fn to_owned(&self) -> Mat<T>
where
    T: NewCloned,
{
    // The view is copied out of the reference and passed on by value.
    let view = *self;
    <T as NewCloned>::new_cloned(view)
}

/// Construct a new [`MatRef`] over the raw pointer and representation without performing
/// any validity checks.
///
Expand Down Expand Up @@ -893,6 +940,14 @@ impl<'a, T: ReprMut> MatMut<'a, T> {
RowsMut::new(self.reborrow_mut())
}

/// Produce an owning [`Mat`] holding the same contents as this mutable view.
///
/// The view is first reborrowed immutably via `as_view`, then copied by
/// [`NewCloned::new_cloned`].
pub fn to_owned(&self) -> Mat<T>
where
    T: NewCloned,
{
    let view = self.as_view();
    <T as NewCloned>::new_cloned(view)
}

/// Construct a new [`MatMut`] over the raw pointer and representation without performing
/// any validity checks.
///
Expand Down Expand Up @@ -1435,6 +1490,65 @@ mod tests {
}
}

/// Exercise every cloning entry point (`Mat::clone`, `MatRef::to_owned`, and
/// `MatMut::to_owned`) for each shape in `ROWS` x `COLS`, checking both the contents of
/// the copy and that a non-empty copy does not share its pointer with the source.
#[test]
fn test_mat_clone() {
    for &nrows in ROWS {
        for &ncols in COLS {
            let repr = Standard::<usize>::new(nrows, ncols).unwrap();
            let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

            let mut mat = Mat::new(repr, Defaulted).unwrap();
            fill_mat(&mut mat, repr);

            // Path 1: the `Clone` implementation on `Mat`.
            {
                let label = &lazy_format!("{ctx} - Mat::clone");
                let copy = mat.clone();

                assert_eq!(copy.num_vectors(), nrows);
                assert_eq!(copy.vector_dim(), ncols);

                check_mat(&copy, repr, label);
                check_mat_ref(copy.reborrow(), repr, label);
                check_rows(copy.rows(), repr, label);

                // Pointer identity is only compared when the matrix holds elements.
                if repr.num_elements() != 0 {
                    assert_ne!(mat.as_ptr(), copy.as_ptr());
                }
            }

            // Path 2: `to_owned` on an immutable view.
            {
                let label = &lazy_format!("{ctx} - MatRef::to_owned");
                let copy = mat.as_view().to_owned();

                check_mat(&copy, repr, label);
                check_mat_ref(copy.reborrow(), repr, label);
                check_rows(copy.rows(), repr, label);

                if repr.num_elements() != 0 {
                    assert_ne!(mat.as_ptr(), copy.as_ptr());
                }
            }

            // Path 3: `to_owned` on a mutable view.
            {
                let label = &lazy_format!("{ctx} - MatMut::to_owned");
                let copy = mat.as_view_mut().to_owned();

                check_mat(&copy, repr, label);
                check_mat_ref(copy.reborrow(), repr, label);
                check_rows(copy.rows(), repr, label);

                if repr.num_elements() != 0 {
                    assert_ne!(mat.as_ptr(), copy.as_ptr());
                }
            }
        }
    }
}

#[test]
fn test_mat_refmut() {
for nrows in ROWS {
Expand Down