4 changes: 2 additions & 2 deletions Cargo.lock

Generated file; diff not rendered.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -67,7 +67,7 @@ diskann-tools = { path = "diskann-tools", version = "0.45.0" }
anyhow = "1.0.98"
approx = "0.5.1"
arc-swap = "1.7.1"
bf-tree = "0.4.7"
bf-tree = "0.4.8"
bincode = "1.3.3"
bit-set = "0.8.0"
bytemuck = "1.23.0"

@@ -47,6 +47,11 @@ impl<I: VectorId> NeighborProvider<I> {
}
}

/// Access the BfTree config
pub(crate) fn config(&self) -> &Config {
self.adjacency_list_index.config()
}

/// Create a snapshot of the adjacency list index
///
pub fn snapshot(&self) {
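
The config() accessor added above is crate-private plumbing: it lets the provider layer read back the sizing a live BfTree was opened with, so save_with (in provider.rs below) can persist it. A minimal sketch of that pattern, assuming crate-internal code; capture_params is hypothetical and u32 merely stands in for the VectorId parameter:

// Sketch only: capture a live tree's sizing into the serializable
// BfTreeParams form that save_with writes out below.
fn capture_params(neighbors: &NeighborProvider<u32>) -> BfTreeParams {
    let cfg = neighbors.config();
    BfTreeParams {
        bytes: cfg.get_cb_size_byte(),
        max_record_size: cfg.get_cb_max_record_size(),
        leaf_page_size: cfg.get_leaf_page_size(),
    }
}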
168 changes: 99 additions & 69 deletions diskann-providers/src/model/graph/provider/async_/bf_tree/provider.rs
@@ -1778,11 +1778,29 @@ where
}
}

#[derive(Serialize, Deserialize, Clone)]
pub struct BfTreeParams {
pub bytes: usize,
pub max_record_size: usize,
pub leaf_page_size: usize,
}

impl BfTreeParams {
/// Build a BfTree Config from the saved parameters and a file path.
pub fn to_config(&self, path: &std::path::Path) -> Config {
let mut config = Config::new(path, self.bytes);
config.cb_max_record_size(self.max_record_size);
config.leaf_page_size(self.leaf_page_size);
config.storage_backend(bf_tree::StorageBackend::Std);
config
}
}
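
BfTreeParams is what allows load_with to reopen each tree with the same sizing it was created with. A minimal round-trip sketch, not part of this diff: the function name and path are hypothetical, and the sizes mirror the values used in the tests further down.

// Sketch: persist the sizing as JSON and rebuild an equivalent Config,
// the same way load_with does after reading the params file.
fn rebuild_config_sketch() -> Config {
    let params = BfTreeParams {
        bytes: 1024 * 1024,
        max_record_size: 1024,
        leaf_page_size: 8192,
    };
    let json = serde_json::to_string(&params).unwrap();
    let restored: BfTreeParams = serde_json::from_str(&json).unwrap();
    restored.to_config(std::path::Path::new("index_vectors.bftree"))
}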

#[derive(Serialize, Deserialize, Clone)]
pub struct QuantParams {
pub num_pq_bytes: usize,
pub max_fp_vecs_per_fill: usize,
pub bytes_quant: usize,
pub params_quant: BfTreeParams,
}

#[derive(Serialize, Deserialize, Clone)]
@@ -1793,8 +1811,8 @@ pub struct SavedParams {
pub metric: String,
pub max_degree: u32,
pub prefix: String,
pub bytes_vector: usize,
pub bytes_neighbor: usize,
pub params_vector: BfTreeParams,
pub params_neighbor: BfTreeParams,
pub quant_params: Option<QuantParams>,
}
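
With params_vector and params_neighbor in place of the old flat byte counts, the params JSON written by save_with now carries full per-tree sizing. An illustrative deserialization sketch; every value below is hypothetical, only the field names come from this struct:

// Sketch: the new on-disk shape, nested sizing instead of bytes_vector /
// bytes_neighbor. Values are made up for illustration.
let saved: SavedParams = serde_json::from_str(
    r#"{
        "max_points": 1000,
        "frozen_points": 1,
        "dim": 128,
        "metric": "L2",
        "max_degree": 64,
        "prefix": "index",
        "params_vector": {"bytes": 1048576, "max_record_size": 1024, "leaf_page_size": 8192},
        "params_neighbor": {"bytes": 1048576, "max_record_size": 1024, "leaf_page_size": 8192},
        "quant_params": null
    }"#,
)
.unwrap();
assert_eq!(saved.params_vector.leaf_page_size, 8192);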

@@ -1836,25 +1854,42 @@ impl BfTreePaths {

// SaveWith/LoadWith for BfTreeProvider with TableDeleteProviderAsync

impl<T> SaveWith<SavedParams> for BfTreeProvider<T, NoStore, TableDeleteProviderAsync>
impl<T> SaveWith<String> for BfTreeProvider<T, NoStore, TableDeleteProviderAsync>
where
T: VectorRepr,
{
type Ok = usize;
type Error = ANNError;

async fn save_with<P>(
&self,
storage: &P,
saved_params: &SavedParams,
) -> Result<Self::Ok, Self::Error>
async fn save_with<P>(&self, storage: &P, prefix: &String) -> Result<Self::Ok, Self::Error>
where
P: StorageWriteProvider,
{
let saved_params = SavedParams {
max_points: self.max_points(),
frozen_points: NonZeroUsize::new(self.num_start_points())
.ok_or_else(|| ANNError::log_index_error("num_start_points is zero"))?,
dim: self.dim(),
metric: self.metric().as_str().to_string(),
max_degree: self.max_degree(),
prefix: prefix.clone(),
params_vector: BfTreeParams {
bytes: self.full_vectors.config().get_cb_size_byte(),
max_record_size: self.full_vectors.config().get_cb_max_record_size(),
leaf_page_size: self.full_vectors.config().get_leaf_page_size(),
},
params_neighbor: BfTreeParams {
bytes: self.neighbor_provider.config().get_cb_size_byte(),
max_record_size: self.neighbor_provider.config().get_cb_max_record_size(),
leaf_page_size: self.neighbor_provider.config().get_leaf_page_size(),
},
quant_params: None, // No quantization parameters
};

// Save only essential parameters as JSON
{
let params_filename = BfTreePaths::params_json(&saved_params.prefix);
let params_json = serde_json::to_string(saved_params).map_err(|e| {
let params_json = serde_json::to_string(&saved_params).map_err(|e| {
ANNError::log_index_error(format!("Failed to serialize params: {}", e))
})?;
let mut params_writer = storage.create_for_write(&params_filename)?;
@@ -1902,13 +1937,12 @@ where
let metric = Metric::from_str(&saved_params.metric)
.map_err(|e| ANNError::log_index_error(format!("Failed to parse metric: {}", e)))?;

let vector_path = BfTreePaths::vectors_bftree(&saved_params.prefix);
let mut vector_config = Config::new(&vector_path, saved_params.bytes_vector);
vector_config.storage_backend(bf_tree::StorageBackend::Std);

let neighbor_path = BfTreePaths::neighbors_bftree(&saved_params.prefix);
let mut neighbor_config = Config::new(&neighbor_path, saved_params.bytes_neighbor);
neighbor_config.storage_backend(bf_tree::StorageBackend::Std);
let vector_config = saved_params
.params_vector
.to_config(&BfTreePaths::vectors_bftree(&saved_params.prefix));
let neighbor_config = saved_params
.params_neighbor
.to_config(&BfTreePaths::neighbors_bftree(&saved_params.prefix));

let vector_index =
BfTree::new_from_snapshot(vector_config.clone(), None).map_err(super::ConfigError)?;
@@ -1950,25 +1984,50 @@ where
}
}

impl<T> SaveWith<SavedParams> for BfTreeProvider<T, QuantVectorProvider, TableDeleteProviderAsync>
impl<T> SaveWith<String> for BfTreeProvider<T, QuantVectorProvider, TableDeleteProviderAsync>
where
T: VectorRepr,
{
type Ok = usize;
type Error = ANNError;

async fn save_with<P>(
&self,
storage: &P,
saved_params: &SavedParams,
) -> Result<Self::Ok, Self::Error>
async fn save_with<P>(&self, storage: &P, prefix: &String) -> Result<Self::Ok, Self::Error>
where
P: StorageWriteProvider,
{
let saved_params = SavedParams {
max_points: self.max_points(),
frozen_points: NonZeroUsize::new(self.num_start_points())
.ok_or_else(|| ANNError::log_index_error("num_start_points is zero"))?,
dim: self.dim(),
metric: self.metric().as_str().to_string(),
max_degree: self.max_degree(),
prefix: prefix.clone(),
params_vector: BfTreeParams {
bytes: self.full_vectors.config().get_cb_size_byte(),
max_record_size: self.full_vectors.config().get_cb_max_record_size(),
leaf_page_size: self.full_vectors.config().get_leaf_page_size(),
},
params_neighbor: BfTreeParams {
bytes: self.neighbor_provider.config().get_cb_size_byte(),
max_record_size: self.neighbor_provider.config().get_cb_max_record_size(),
leaf_page_size: self.neighbor_provider.config().get_leaf_page_size(),
},
quant_params: Some(QuantParams {
num_pq_bytes: self.quant_vectors.pq_chunks(),
max_fp_vecs_per_fill: self.max_fp_vecs_per_fill,
params_quant: BfTreeParams {
bytes: self.quant_vectors.config().get_cb_size_byte(),
max_record_size: self.quant_vectors.config().get_cb_max_record_size(),
leaf_page_size: self.quant_vectors.config().get_leaf_page_size(),
},
}),
};

// Save only essential parameters as JSON
{
let params_filename = BfTreePaths::params_json(&saved_params.prefix);
let params_json = serde_json::to_string(saved_params).map_err(|e| {
let params_json = serde_json::to_string(&saved_params).map_err(|e| {
ANNError::log_index_error(format!("Failed to serialize params: {}", e))
})?;
let mut params_writer = storage.create_for_write(&params_filename)?;
@@ -2035,17 +2094,15 @@ where
let metric = Metric::from_str(&saved_params.metric)
.map_err(|e| ANNError::log_index_error(format!("Failed to parse metric: {}", e)))?;

let vector_path = BfTreePaths::vectors_bftree(&saved_params.prefix);
let mut vector_config = Config::new(&vector_path, saved_params.bytes_vector);
vector_config.storage_backend(bf_tree::StorageBackend::Std);

let neighbor_path = BfTreePaths::neighbors_bftree(&saved_params.prefix);
let mut neighbor_config = Config::new(&neighbor_path, saved_params.bytes_neighbor);
neighbor_config.storage_backend(bf_tree::StorageBackend::Std);

let quant_path = BfTreePaths::quant_bftree(&saved_params.prefix);
let mut quant_config = Config::new(&quant_path, quant_params.bytes_quant);
quant_config.storage_backend(bf_tree::StorageBackend::Std);
let vector_config = saved_params
.params_vector
.to_config(&BfTreePaths::vectors_bftree(&saved_params.prefix));
let neighbor_config = saved_params
.params_neighbor
.to_config(&BfTreePaths::neighbors_bftree(&saved_params.prefix));
let quant_config = quant_params
.params_quant
.to_config(&BfTreePaths::quant_bftree(&saved_params.prefix));

let vector_index =
BfTree::new_from_snapshot(vector_config.clone(), None).map_err(super::ConfigError)?;
@@ -2333,6 +2390,8 @@ mod tests {

let bytes_vector = 1024 * 1024;
let mut vector_config = Config::new(&vector_path, bytes_vector);
vector_config.leaf_page_size(8192);
vector_config.cb_max_record_size(1024);
vector_config.storage_backend(bf_tree::StorageBackend::Std);

let bytes_neighbor = 1024 * 1024;
@@ -2391,22 +2450,12 @@ mod tests {
);
}

let storage = FileStorageProvider;
assert_eq!(vector_config.get_leaf_page_size(), 8192);
assert_eq!(vector_config.get_cb_max_record_size(), 1024);

let metric_str = params.metric.as_str();
let saved_params = SavedParams {
max_points: params.max_points,
frozen_points: params.num_start_points,
dim: params.dim,
metric: metric_str.to_string(),
max_degree: params.max_degree,
prefix: prefix.clone(),
bytes_vector,
bytes_neighbor,
quant_params: None,
};
let storage = FileStorageProvider;

provider.save_with(&storage, &saved_params).await.unwrap();
provider.save_with(&storage, &prefix).await.unwrap();

// Load using trait method (includes delete bitmap)
let loaded_provider = BfTreeProvider::<f32, NoStore, TableDeleteProviderAsync>::load_with(
@@ -2574,26 +2623,7 @@ mod tests {

let storage = FileStorageProvider;

// Create SavedParamsQuant outside of save_with
let metric_str = params.metric.as_str();
let num_pq_bytes = pq_table.get_num_chunks();
let saved_params = SavedParams {
max_points: params.max_points,
frozen_points: params.num_start_points,
dim: params.dim,
metric: metric_str.to_string(),
max_degree: params.max_degree,
prefix: prefix.clone(),
bytes_vector,
bytes_neighbor,
quant_params: Some(QuantParams {
num_pq_bytes,
max_fp_vecs_per_fill: params.max_fp_vecs_per_fill.unwrap_or(0),
bytes_quant,
}),
};

provider.save_with(&storage, &saved_params).await.unwrap();
provider.save_with(&storage, &prefix).await.unwrap();

// Load using trait method (includes delete bitmap and quantization)
let loaded_provider =
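
The net effect of switching the trait parameter from SavedParams to String, as the tests above exercise: callers pass only the destination prefix, and the provider assembles SavedParams itself from its live configs. A condensed sketch of the call pattern shown in the tests (error handling elided):

// Save: only a prefix is needed; SavedParams (including per-tree
// BfTreeParams) is built inside save_with from the providers' configs.
let storage = FileStorageProvider;
provider.save_with(&storage, &prefix).await.unwrap();
// Load: load_with reads the JSON written via BfTreePaths::params_json(&prefix)
// and rebuilds each Config through BfTreeParams::to_config.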
@@ -69,6 +69,11 @@ impl QuantVectorProvider {
self.metric
}

/// Access the BfTree config
pub(crate) fn config(&self) -> &Config {
self.quant_vector_index.config()
}

/// Create a snapshot of the quant vector index
///
pub fn snapshot(&self) {

@@ -96,6 +96,11 @@ impl<T: VectorRepr, I: VectorId> VectorProvider<T, I> {
.collect()
}

/// Access the BfTree config
pub(crate) fn config(&self) -> &Config {
self.vector_index.config()
}

/// Create a snapshot of the vector index
///
#[inline(always)]