diff --git a/bin/export-model-arch/src/export-model-arch/main.cc b/bin/export-model-arch/src/export-model-arch/main.cc index 82aebd2b2c..29be28b0ef 100644 --- a/bin/export-model-arch/src/export-model-arch/main.cc +++ b/bin/export-model-arch/src/export-model-arch/main.cc @@ -1,6 +1,6 @@ #include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" -#include "export_model_arch/json_sp_model_export.dtg.h" +#include "export-model-arch/json_sp_model_export.dtg.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/dlrm/dlrm.h" diff --git a/flake.lock b/flake.lock index fdde6ce02c..9cd1e4bbae 100644 --- a/flake.lock +++ b/flake.lock @@ -66,11 +66,11 @@ ] }, "locked": { - "lastModified": 1763685681, - "narHash": "sha256-VFtDhrXx49yQS2r5Oxz2mvw/60uIAZhy0Y0rDBMvEno=", + "lastModified": 1769666654, + "narHash": "sha256-YFbOVi+Se3KDGFAoofYwYPUpEkEhsvdGdlYDR2I2XmI=", "ref": "refs/heads/master", - "rev": "72f7bd4008671613237681e29c9c90403a421ce0", - "revCount": 138, + "rev": "64620d82f03478496eb00188184dbf48d56b560d", + "revCount": 143, "type": "git", "url": "https://git.sr.ht/~lockshaw/proj" }, diff --git a/lib/compiler/include/compiler/algorithm_config.dtg.toml b/lib/compiler/include/compiler/algorithm_config.dtg.toml new file mode 100644 index 0000000000..df08841384 --- /dev/null +++ b/lib/compiler/include/compiler/algorithm_config.dtg.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "AlgorithmConfig" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/data_parallelism/data_parallelism_config.dtg.h", + "compiler/unity_algorithm/unity_search_config.dtg.h", +] + +[[values]] +type = "::FlexFlow::DataParallelismConfig" + +[[values]] +type = "::FlexFlow::UnitySearchConfig" diff --git a/lib/compiler/include/compiler/compiler.h 
b/lib/compiler/include/compiler/compiler.h new file mode 100644 index 0000000000..44a405f383 --- /dev/null +++ b/lib/compiler/include/compiler/compiler.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_COMPILER_COMPILER_H +#define _FLEXFLOW_COMPILER_COMPILER_H + +#include "compiler/algorithm_config.dtg.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/search_result.dtg.h" +#include "pcg/computation_graph.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +SearchResult optimize(ComputationGraph const &, + MachineSpecification const &, + CostEstimator const &, + AlgorithmConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/data_parallelism/data_parallelism_config.dtg.toml b/lib/compiler/include/compiler/data_parallelism/data_parallelism_config.dtg.toml new file mode 100644 index 0000000000..862dba45e6 --- /dev/null +++ b/lib/compiler/include/compiler/data_parallelism/data_parallelism_config.dtg.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "DataParallelismConfig" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "degree" +type = "int" diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h deleted file mode 100644 index 62c9f97331..0000000000 --- a/lib/compiler/include/compiler/graph_optimize_state.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H -#define _FLEXFLOW_COMPILER_MCMC_STATE_H - -#include "compiler/graph_optimize_result.dtg.h" - -namespace FlexFlow { - -struct GraphOptimizeState { - explicit GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, - float runtime); - - GraphOptimizeResult graph_optimize_result; - float runtime; - - bool operator==(GraphOptimizeState const &other) const; - bool operator!=(GraphOptimizeState const &other) const; - bool operator<(GraphOptimizeState const &other) const; -}; - -} // 
namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::GraphOptimizeState> { - size_t operator()(::FlexFlow::GraphOptimizeState const &) const; -}; - -} // namespace std - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml index 1685fca931..d540f97207 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_device.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml index e36c8de82d..34f653a9ec 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml index 3658afe154..f8a8280e9f 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", 
"hash", "fmt", + "json", ] includes = [ diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml index b030692260..7b1468f4c9 100644 --- a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/compiler/include/compiler/allowed_machine_views.h b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h similarity index 57% rename from lib/compiler/include/compiler/allowed_machine_views.h rename to lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h index 2a3de47b0d..5201f7fa31 100644 --- a/lib/compiler/include/compiler/allowed_machine_views.h +++ b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h @@ -1,18 +1,19 @@ -#ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H -#define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ALLOWED_MACHINE_VIEWS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ALLOWED_MACHINE_VIEWS_H #include "compiler/machine_mapping/machine_view.dtg.h" #include "op-attrs/operator_task_space.dtg.h" +#include "pcg/machine_compute_resource_slice.dtg.h" #include "pcg/machine_compute_specification.dtg.h" namespace FlexFlow { bool is_valid_machine_view(MachineView const &mv, OperatorTaskSpace const &task, - MachineComputeSpecification const &ms); + MachineComputeResourceSlice const &ms); std::unordered_set - get_allowed_machine_views(MachineComputeSpecification const &machine_spec, + get_allowed_machine_views(MachineComputeResourceSlice const &machine_spec, 
OperatorTaskSpace const &task, DeviceType device_type); diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 3e49899003..89b7cedf7c 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H -#include "compiler/machine_mapping/machine_compute_resource_slice.dtg.h" #include "compiler/machine_mapping/machine_mapping_cache.dtg.h" #include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" @@ -9,6 +8,7 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include "pcg/machine_compute_resource_slice.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.h b/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.h deleted file mode 100644 index 99187999ec..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_COMPUTE_RESOURCE_SLICE_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_COMPUTE_RESOURCE_SLICE_H - -#include "compiler/machine_mapping/machine_compute_resource_slice.dtg.h" -#include "pcg/machine_compute_specification.dtg.h" - -namespace FlexFlow { - -MachineComputeResourceSlice - 
compute_slice_from_specification(MachineComputeSpecification const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h index 3e5b9238dd..28d1fc943b 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -2,6 +2,8 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_H #include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_specification.dtg.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" @@ -18,6 +20,9 @@ MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapping(ParallelComputationGraph const &, MachineMapping const &); +std::optional get_machine_mapping_from_machine_mapping_result( + PCGBinarySPDecomposition const &, MachineMappingResult const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h index d314ab493b..15d02dd64a 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -14,8 +14,13 @@ MachineMappingConstraints get_unconstrained_solution_for_layers( std::unordered_set const &); std::unordered_set - get_all_layers(MachineMappingConstraints const &, - IncludeUnconstrained const &); + get_unconstrained_layers(MachineMappingConstraints const &); + +std::unordered_set + get_constrained_layers(MachineMappingConstraints const &); + +std::unordered_set + get_all_layers(MachineMappingConstraints const &); 
std::optional get_machine_view_for_layer(MachineMappingConstraints const &, diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml index ae5299ecdd..04d2eb1378 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.dtg.toml @@ -6,7 +6,7 @@ features = [] includes = [ "compiler/cost_estimator/runtime_only_cost_estimator.h", "compiler/machine_mapping/machine_view.dtg.h", - "compiler/machine_mapping/machine_compute_resource_slice.dtg.h", + "pcg/machine_compute_resource_slice.dtg.h", "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", ] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h index be8a4b9afa..c3f6a7e1c2 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -9,6 +9,9 @@ namespace FlexFlow { +bool is_valid_machine_mapping_problem_tree( + MachineMappingProblemTree const &problem_tree); + MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, PCGBinarySPDecomposition const &sp); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index abd77bfa7b..62a4206d54 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ 
b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H -#include "compiler/machine_mapping/abstracted_tensor_set_movement/machine_space_stencil.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" @@ -32,6 +31,9 @@ std::optional std::unordered_map mm_problem_tree_get_path_to_leaf_map(MachineMappingProblemTree const &); +std::string as_dot(MachineMappingProblemTree const &); +void debug_print_dot(MachineMappingProblemTree const &); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h index c1de7cb956..43f8373cf9 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h @@ -3,6 +3,7 @@ #include "compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h" +#include "op-attrs/operator_task_space.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" @@ -17,6 +18,9 @@ RuntimeOnlyOpCostEstimateKey 
map_unmapped_runtime_only_op_cost_estimate_key( UnmappedRuntimeOnlyOpCostEstimateKey const &unmapped, MachineView const &machine_view); +OperatorTaskSpace get_operator_task_space_for_runtime_only_op_cost_estimate_key( + UnmappedRuntimeOnlyOpCostEstimateKey const &unmapped); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index f7b52ec574..fd48f1b02c 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -35,6 +35,9 @@ FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); make_singleton_machine_mapping_result(milliseconds_t runtime, MachineView const &machine_view); +[[nodiscard]] milliseconds_t + get_runtime_cost(MachineMappingResult const &mm_result); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml index 369cbfd851..fece560df6 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.dtg.toml @@ -9,7 +9,7 @@ features = [ includes = [ "compiler/machine_mapping/machine_mapping_constraints.dtg.h", - "compiler/machine_mapping/machine_compute_resource_slice.dtg.h", + "pcg/machine_compute_resource_slice.dtg.h", "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", ] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h b/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h index 7573276b82..ce8c2029c6 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_resource_split.h @@ 
-1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_RESOURCE_SPLIT_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_RESOURCE_SPLIT_H -#include "compiler/machine_mapping/machine_compute_resource_slice.dtg.h" #include "compiler/machine_mapping/machine_resource_split.dtg.h" #include "compiler/machine_mapping/machine_view.dtg.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "pcg/machine_compute_resource_slice.dtg.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml index fc47dff0ba..7c31c0d16b 100644 --- a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_context.dtg.toml @@ -6,7 +6,7 @@ features = [] includes = [ "compiler/cost_estimator/cost_estimator.h", "compiler/machine_mapping/machine_view.dtg.h", - "compiler/machine_mapping/machine_compute_resource_slice.dtg.h", + "pcg/machine_compute_resource_slice.dtg.h", "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h", "pcg/optimizer_attrs.dtg.h", ] diff --git a/lib/compiler/include/compiler/search_result.dtg.toml b/lib/compiler/include/compiler/search_result.dtg.toml new file mode 100644 index 0000000000..36e516ce6e --- /dev/null +++ b/lib/compiler/include/compiler/search_result.dtg.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "SearchResult" +type = "struct" +features = [ +] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "compiler/machine_mapping/machine_mapping.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + 
+[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/search_result.h b/lib/compiler/include/compiler/search_result.h new file mode 100644 index 0000000000..eb4fd2e95f --- /dev/null +++ b/lib/compiler/include/compiler/search_result.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SEARCH_RESULT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SEARCH_RESULT_H + +#include "compiler/search_result.dtg.h" + +namespace FlexFlow { + +MappedParallelComputationGraph + get_mapped_pcg_from_search_result(SearchResult const &); + +std::string format_as(SearchResult const &); +std::ostream &operator<<(std::ostream &, SearchResult const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h index 3ac0908659..c680644f30 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -22,11 +22,13 @@ GenericBinarySPDecompositionTreeImplementation - get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); std::unordered_multiset get_parallel_layers(PCGBinarySPDecomposition const &); +PCGBinarySPDecomposition + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + BinarySPDecompositionTree const &); + SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); std::unordered_set @@ -36,9 +38,6 @@ std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &, parallel_layer_guid_t const &); -PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree( - BinarySPDecompositionTree const &spd_tree); - std::unordered_map pcg_sp_tree_get_path_to_leaf_map(PCGBinarySPDecomposition const &); diff --git a/lib/compiler/include/compiler/unity_algorithm.h 
b/lib/compiler/include/compiler/unity_algorithm.h deleted file mode 100644 index d8ba9158a6..0000000000 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_H - -#include "compiler/cost_estimator/cost_estimator.h" -#include "compiler/graph_optimize_result.dtg.h" -#include "optimizer_config.dtg.h" -#include "pcg/computation_graph.h" -#include "pcg/machine_specification.dtg.h" -#include "substitutions/sub_parallel_computation_graph.h" - -namespace FlexFlow { - -GraphOptimizeResult graph_optimize( - ParallelComputationGraph &pcg, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimizerConfig const &opt_config); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h b/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h new file mode 100644 index 0000000000..c0952c0684 --- /dev/null +++ b/lib/compiler/include/compiler/unity_algorithm/graph_optimize_state.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_GRAPH_OPTIMIZE_STATE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_GRAPH_OPTIMIZE_STATE_H + +#include "compiler/graph_optimize_result.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "utils/units/milliseconds_t.h" + +namespace FlexFlow { + +struct GraphOptimizeState { + GraphOptimizeState() = delete; + explicit GraphOptimizeState( + ParallelComputationGraph const &parallel_computation_graph, + milliseconds_t runtime); + + bool operator==(GraphOptimizeState const &other) const; + bool operator!=(GraphOptimizeState const &other)
const; + bool operator<(GraphOptimizeState const &other) const; + +public: + ParallelComputationGraph pcg; + milliseconds_t runtime; +}; + +std::string format_as(GraphOptimizeState const &); +std::ostream &operator<<(std::ostream &, GraphOptimizeState const &); + +// TODO(@lockshaw)(#pr): Delete this if still unused +// std::optional +// graph_optimize_state_from_machine_mapping_result(ParallelComputationGraph +// const &, +// PCGBinarySPDecomposition +// const &, +// MachineMappingResult const +// &); + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::GraphOptimizeState> { + size_t operator()(::FlexFlow::GraphOptimizeState const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/compiler/include/compiler/unity_algorithm/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm/unity_algorithm.h new file mode 100644 index 0000000000..9d071d1e0e --- /dev/null +++ b/lib/compiler/include/compiler/unity_algorithm/unity_algorithm.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_UNITY_ALGORITHM_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_UNITY_ALGORITHM_UNITY_ALGORITHM_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/runtime_only_cost_estimator.h" +#include "compiler/search_result.dtg.h" +#include "compiler/unity_algorithm/unity_search_config.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "substitutions/substitution.h" + +namespace FlexFlow { + +SearchResult graph_optimize(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineComputeSpecification const &resources, + UnitySearchConfig const &search_config); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/optimizer_config.dtg.toml b/lib/compiler/include/compiler/unity_algorithm/unity_search_config.dtg.toml similarity index 76% rename from lib/compiler/include/compiler/optimizer_config.dtg.toml rename to 
lib/compiler/include/compiler/unity_algorithm/unity_search_config.dtg.toml index 395b22f46b..f4a0640f73 100644 --- a/lib/compiler/include/compiler/optimizer_config.dtg.toml +++ b/lib/compiler/include/compiler/unity_algorithm/unity_search_config.dtg.toml @@ -1,10 +1,11 @@ namespace = "FlexFlow" -name = "OptimizerConfig" +name = "UnitySearchConfig" type = "struct" features = [ "eq", "hash", "fmt", + "json", ] includes = [ @@ -18,10 +19,6 @@ type = "float" name = "budget" type = "int" -[[fields]] -name = "threshold" -type = "float" - [[fields]] name = "max_num_ops" type = "int" diff --git a/lib/compiler/src/compiler/compiler.cc b/lib/compiler/src/compiler/compiler.cc new file mode 100644 index 0000000000..714cda3f86 --- /dev/null +++ b/lib/compiler/src/compiler/compiler.cc @@ -0,0 +1,30 @@ +#include "compiler/compiler.h" +#include "compiler/cost_estimator/runtime_only_cost_estimator_from_cost_estimator.h" +#include "compiler/unity_algorithm/unity_algorithm.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +SearchResult optimize(ComputationGraph const &computation_graph, + MachineSpecification const &machine_specification, + CostEstimator const &cost_estimator, + AlgorithmConfig const &search_config) { + return search_config.visit(overload{ + [&](DataParallelismConfig const &config) -> SearchResult { + throw std::runtime_error( + "Data parallel search algorithm is not implemented yet"); + }, + [&](UnitySearchConfig const &config) { + ParallelComputationGraph pcg = + pcg_from_computation_graph(computation_graph); + return graph_optimize( + pcg, + runtime_only_cost_estimator_from_cost_estimator(cost_estimator), + machine_specification.compute_specification, + config); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc deleted file mode 100644 index bf40df2f11..0000000000 --- 
a/lib/compiler/src/compiler/graph_optimize_state.cc +++ /dev/null @@ -1,98 +0,0 @@ -#include "compiler/graph_optimize_state.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "utils/hash/tuple.h" -#include "utils/hash/unordered_map.h" -#include "utils/hash/unordered_multiset.h" - -namespace FlexFlow { - -GraphOptimizeState::GraphOptimizeState( - GraphOptimizeResult const &graph_optimize_result, float runtime) - : graph_optimize_result(graph_optimize_result), runtime(runtime) {} - -static std::unordered_multiset>, - std::unordered_map>> - get_layer_signature_set(MappedParallelComputationGraph const &mapped_pcg) { - - auto get_layer_signature = [&](parallel_layer_guid_t l) - -> std::tuple>, - std::unordered_map> { - ParallelLayerAttrs layer_attrs = - get_parallel_layer_attrs(mapped_pcg.pcg, l); - - std::unordered_map< - TensorSlotName, - std::tuple> - inputs = - map_values(get_incoming_tensors(mapped_pcg.pcg, l), - [&](parallel_tensor_guid_t const &i) { - parallel_layer_guid_t src = get_source_layer(i); - TensorSlotName src_slot = i.raw_graph_output.slot_name; - ParallelTensorAttrs tensor_attrs = - get_parallel_tensor_attrs(mapped_pcg.pcg, i); - - return std::tuple{ - get_parallel_layer_attrs(mapped_pcg.pcg, src), - src_slot, - tensor_attrs, - }; - }); - - std::unordered_map outputs = - map_values(get_layer_outputs(mapped_pcg.pcg, l), - [&](parallel_tensor_guid_t const &o) { - return get_parallel_tensor_attrs(mapped_pcg.pcg, o); - }); - - return { - layer_attrs, - mapped_pcg.mapped_tasks.at(l), - inputs, - outputs, - }; - }; - - return transform(unordered_multiset_of(get_parallel_layers(mapped_pcg.pcg)), - get_layer_signature); -} - -bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { - return get_layer_signature_set(this->graph_optimize_result.mapped_pcg) == - get_layer_signature_set(other.graph_optimize_result.mapped_pcg); -} - -bool 
GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { - return !(*this == other); -} - -bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { - return runtime < other.runtime; -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::GraphOptimizeState>::operator()( - ::FlexFlow::GraphOptimizeState const &state) const { - // TODO(@wmdi): Eventually it might be good to use a proper graph hash like - // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash - using namespace ::FlexFlow; - - auto layers = get_layer_signature_set(state.graph_optimize_result.mapped_pcg); - - return get_std_hash(layers); -} - -} // namespace std diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc index 1088d02adb..1aeb83d202 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc @@ -66,6 +66,7 @@ TensorSetMovement concretize_abstracted_single_tensor_movement( std::unordered_map const &post_machine_stencils) { + ASSERT(contains_key(pre_machine_stencils, abstracted.src_op_tree_path)); MachineSpaceStencil pre_machine_stencil = pre_machine_stencils.at(abstracted.src_op_tree_path); diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc similarity index 72% rename from lib/compiler/src/compiler/allowed_machine_views.cc rename to lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc index 558f383adc..ec369f2f03 100644 --- 
a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/machine_mapping/allowed_machine_views.cc @@ -1,7 +1,8 @@ -#include "compiler/allowed_machine_views.h" +#include "compiler/machine_mapping/allowed_machine_views.h" #include "compiler/machine_mapping/machine_view.h" #include "compiler/machine_mapping/multi_dimensional_stride.dtg.h" #include "op-attrs/operator_task_space.h" +#include "pcg/machine_compute_resource_slice.h" #include "pcg/machine_compute_specification.h" #include "utils/containers/all_of.h" #include "utils/containers/cartesian_product.h" @@ -26,7 +27,7 @@ namespace FlexFlow { bool is_valid_machine_view(MachineView const &mv, OperatorTaskSpace const &task_space, - MachineComputeSpecification const &ms) { + MachineComputeResourceSlice const &ms) { if (mv_get_expected_task_space_num_dims(mv) != op_task_space_num_dims(task_space)) { return false; @@ -35,7 +36,7 @@ bool is_valid_machine_view(MachineView const &mv, MachineSpaceCoordinate maximum_device_coord = get_machine_space_coordinate( task_space, mv, get_task_space_maximum_coordinate(task_space)); - return is_valid_machine_space_coordinate(ms, maximum_device_coord); + return is_valid_machine_space_coordinate_in_slice(ms, maximum_device_coord); } /* @@ -47,7 +48,7 @@ bool is_valid_machine_view(MachineView const &mv, * the returned `MachineView`s to be invalid) */ static std::unordered_set - get_candidate_machine_views(MachineComputeSpecification const &machine_spec, + get_candidate_machine_views(MachineComputeResourceSlice const &machine_spec, OperatorTaskSpace const &task_space, DeviceType const &device_type) { @@ -62,8 +63,8 @@ static std::unordered_set positive_int{min_num_devices_with_full_stride_volume}); }; - auto candidate_strides = [&](std::vector const &tensor_dims, - positive_int total_devices) + auto get_candidate_strides = [&](std::vector const &tensor_dims, + positive_int total_devices) -> std::unordered_multiset { positive_int max_stride_upper_bound = 
get_max_stride_upper_bound(tensor_dims, total_devices); @@ -73,23 +74,29 @@ static std::unordered_set 1_n, max_stride_upper_bound.nonnegative_int_from_positive_int() + 1_n), [](nonnegative_int stride) { return stride_t{positive_int{stride}}; }); + std::unordered_multiset> raw_stride_vectors = cartesian_product( repeat_element(/*num_times=*/num_elements(tensor_dims), /*element=*/single_stride_range)); + std::unordered_multiset strides = transform(raw_stride_vectors, [](auto const &stride_vec) { return MultiDimensionalStride{stride_vec}; }); + return strides; }; - auto candidate_starts = [](MachineComputeSpecification const &ms, - DeviceType const &device_type) { + auto get_candidate_starts = [](MachineComputeResourceSlice const &slice, + DeviceType const &device_type) + -> std::unordered_set { + ASSERT(device_type == DeviceType::GPU); + std::unordered_set result; - for (nonnegative_int node_idx : nonnegative_range(ms.num_nodes)) { + for (nonnegative_int node_idx : nonnegative_range(slice.num_nodes)) { for (nonnegative_int device_idx : - nonnegative_range(get_num_devices_per_node(ms, device_type))) { + nonnegative_range(slice.num_gpus_per_node)) { result.insert( MachineSpaceCoordinate{node_idx, device_idx, device_type}); } @@ -97,7 +104,8 @@ static std::unordered_set return result; }; - auto candidate_dimensions = [](OperatorTaskSpace const &task_space) { + auto get_candidate_dimensions = [](OperatorTaskSpace const &task_space) + -> std::unordered_multiset> { std::unordered_set options = { MachineSpecificationDimension::INTER_NODE, MachineSpecificationDimension::INTRA_NODE}; @@ -110,16 +118,26 @@ static std::unordered_set return dim.positive_int_from_int_ge_two(); }); - positive_int total_devices = get_num_devices(machine_spec, device_type); + positive_int total_devices = get_total_num_devices_in_slice(machine_spec); + + std::unordered_multiset candidate_strides = + get_candidate_strides(tensor_dims, total_devices); + ASSERT(candidate_strides.size() > 0); + + 
std::unordered_set candidate_starts = + get_candidate_starts(machine_spec, device_type); + ASSERT(candidate_starts.size() > 0); + + std::unordered_multiset> + candidate_dimensions = get_candidate_dimensions(task_space); + ASSERT(candidate_dimensions.size() > 0); std::unordered_set machine_views; - for (MultiDimensionalStride const &strides : - candidate_strides(tensor_dims, total_devices)) { - for (MachineSpaceCoordinate start : - candidate_starts(machine_spec, device_type)) { + for (MultiDimensionalStride const &strides : candidate_strides) { + for (MachineSpaceCoordinate start : candidate_starts) { for (std::vector const &dims : - candidate_dimensions(task_space)) { + candidate_dimensions) { machine_views.insert( machine_view_from_strides_and_machine_spec_dimensions( start, strides.raw_strides, dims)); @@ -130,7 +148,7 @@ static std::unordered_set } std::unordered_set - get_allowed_machine_views(MachineComputeSpecification const &machine_spec, + get_allowed_machine_views(MachineComputeResourceSlice const &machine_spec, OperatorTaskSpace const &task_space, DeviceType device_type) { diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 2407297322..77e50740aa 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -21,9 +21,12 @@ #include "pcg/machine_specification.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" +#include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_all_assignments.h" +#include "utils/containers/keys.h" +#include "utils/containers/set_minus.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" #include "utils/overload.h" @@ -37,6 +40,8 
@@ MachineMappingResult MachineComputeResourceSlice const &resources, MachineMappingConstraints const &constraints) { + ASSERT(get_all_layers(constraints) == get_all_leaf_paths(problem_tree)); + MachineMappingState state = MachineMappingState{ problem_tree, resources, @@ -85,12 +90,21 @@ MachineMappingResult ¶llel_split_transformation) { auto get_boundary_machine_view_assignments = - [&](MachineMappingProblemTree const &root, - std::unordered_set const &boundary_layers) + [&](std::unordered_set const &boundary_layers, + MachineMappingProblemTree const &root, + BinaryTreePathEntry const &prefix) -> std::unordered_set { + MachineMappingConstraints sub_constraints = + restrict_to_child(constraints, prefix); + + ASSERT(get_all_layers(sub_constraints) == get_all_leaf_paths(root)); + + std::unordered_set unconstrained_boundary_layers = + set_minus(boundary_layers, get_constrained_layers(sub_constraints)); + std::unordered_map> allowed = generate_map( - boundary_layers, + unconstrained_boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = mm_problem_tree_get_subtree_at_path(root, l) @@ -98,8 +112,12 @@ MachineMappingResult .get(); return context.allowed_machine_views(leaf, resources); }); + + std::unordered_set> + assignments = get_all_assignments(allowed); + return transform( - get_all_assignments(allowed), + assignments, [](std::unordered_map const &m) { return ParallelLayerGuidObliviousMachineMapping{m}; }); @@ -143,29 +161,42 @@ MachineMappingResult for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : - get_boundary_machine_view_assignments(series_split.get_left_child(), - get_src_layers(tensor_movement))) { + get_boundary_machine_view_assignments(get_src_layers(tensor_movement), + series_split.get_left_child(), + BinaryTreePathEntry::LEFT_CHILD)) { MachineMappingResult pre_result = eval_pre_boundary_mapping(assigned_pre_machine_views); + if (is_infeasible(pre_result)) { + 
continue; + } + for (ParallelLayerGuidObliviousMachineMapping const &assigned_post_machine_views : get_boundary_machine_view_assignments( - series_split.get_right_child(), get_dst_layers(tensor_movement))) { + get_dst_layers(tensor_movement), + series_split.get_right_child(), + BinaryTreePathEntry::RIGHT_CHILD)) { MachineMappingResult post_result = eval_post_boundary_mapping(assigned_post_machine_views); + if (is_infeasible(post_result)) { + continue; + } + TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement( tensor_movement, /*pre_machine_stencils=*/ get_machine_stencils_for_partially_mapped_mm_problem_tree( - series_split.get_left_child(), assigned_pre_machine_views), + series_split.get_left_child(), + require_feasible(pre_result).machine_mapping), /*post_machine_stencils=*/ get_machine_stencils_for_partially_mapped_mm_problem_tree( - series_split.get_right_child(), assigned_post_machine_views)); + series_split.get_right_child(), + require_feasible(post_result).machine_mapping)); milliseconds_t cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_compute_resource_slice.cc b/lib/compiler/src/compiler/machine_mapping/machine_compute_resource_slice.cc deleted file mode 100644 index 46614269fc..0000000000 --- a/lib/compiler/src/compiler/machine_mapping/machine_compute_resource_slice.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "compiler/machine_mapping/machine_compute_resource_slice.h" - -namespace FlexFlow { - -MachineComputeResourceSlice - compute_slice_from_specification(MachineComputeSpecification const &spec) { - - return MachineComputeResourceSlice{ - /*num_nodes=*/spec.num_nodes, - /*num_gpus_per_node=*/spec.num_gpus_per_node, - }; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 1676a6929d..8a16ff9dda 100644 --- 
a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,6 +1,8 @@ #include "compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_view.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/keys.h" @@ -50,4 +52,27 @@ bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); } +std::optional get_machine_mapping_from_machine_mapping_result( + PCGBinarySPDecomposition const &sp_decomposition, + MachineMappingResult const &mm_result) { + + FeasibleMachineMappingResult feasible_mapping = ({ + if (is_infeasible(mm_result)) { + return std::nullopt; + } + + require_feasible(mm_result); + }); + + bidict path_to_leaf_map = + bidict_from_map(pcg_sp_tree_get_path_to_leaf_map(sp_decomposition)); + + return MachineMapping{ + map_keys(feasible_mapping.machine_mapping.raw_mapping, + [&](BinaryTreePath const &p) -> parallel_layer_guid_t { + return path_to_leaf_map.at_l(p); + }), + }; +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 20683777d5..8278d5511c 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -1,5 +1,6 @@ #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "utils/containers/filter.h" +#include "utils/containers/filter_values.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/flatmap.h" #include 
"utils/containers/generate_map.h" @@ -21,18 +22,24 @@ MachineMappingConstraints get_unconstrained_solution_for_layers( } std::unordered_set - get_all_layers(MachineMappingConstraints const &partial_solution, - IncludeUnconstrained const &include_unconstrained) { - std::unordered_set with_unconstrained = - keys(partial_solution.machine_views); - - if (include_unconstrained.raw_bool) { - return with_unconstrained; - } else { - return filter(with_unconstrained, [&](BinaryTreePath const &l) { - return partial_solution.machine_views.at(l).has_value(); - }); - } + get_unconstrained_layers(MachineMappingConstraints const &constraints) { + + return keys(filter_values( + constraints.machine_views, + [](std::optional const &mv) { return !mv.has_value(); })); +} + +std::unordered_set + get_constrained_layers(MachineMappingConstraints const &constraints) { + + return keys(filter_values( + constraints.machine_views, + [](std::optional const &mv) { return mv.has_value(); })); +} + +std::unordered_set + get_all_layers(MachineMappingConstraints const &partial_solution) { + return keys(partial_solution.machine_views); } std::optional get_machine_view_for_layer( diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index da6b7b91e5..88d64f6359 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,14 +1,50 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" #include 
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/all_of.h" #include "utils/overload.h" namespace FlexFlow { +bool is_valid_machine_mapping_problem_tree( + MachineMappingProblemTree const &problem_tree) { + return problem_tree.visit(overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; + + auto contains_paths = + [](MachineMappingProblemTree const &t, + std::unordered_set const &paths) { + return all_of(paths, [&](BinaryTreePath const &p) { + return mm_problem_tree_get_subtree_at_path(t, p).has_value(); + }); + }; + + return contains_paths(series_split.get_left_child(), + get_src_layers(tensor_movement)) && + contains_paths(series_split.get_right_child(), + get_dst_layers(tensor_movement)) && + is_valid_machine_mapping_problem_tree( + series_split.get_left_child()) && + is_valid_machine_mapping_problem_tree( + series_split.get_right_child()); + }, + [&](MMProblemTreeParallelSplit const ¶llel_split) { + return is_valid_machine_mapping_problem_tree( + parallel_split.get_left_child()) && + is_valid_machine_mapping_problem_tree( + parallel_split.get_right_child()); + }, + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &leaf) { return true; }, + }); +} + MachineMappingProblemTree get_machine_mapping_problem_tree( ParallelComputationGraph const &pcg, PCGBinarySPDecomposition const &sp_decomposition_tree) { @@ -23,32 +59,42 @@ MachineMappingProblemTree get_machine_mapping_problem_tree( [&](PCGBinarySeriesSplit const &series) { AbstractedTensorSetMovement tensor_movement = 
get_abstracted_tensor_set_movement_across_split(tr_pcg, series); - return MachineMappingProblemTree{ + MachineMappingProblemTree result = MachineMappingProblemTree{ MMProblemTreeSeriesSplit{ /*tensor_set_movement=*/tensor_movement, /*lhs=*/to_problem_tree(series.get_left_child()), /*rhs=*/to_problem_tree(series.get_right_child()), }, }; + ASSERT(is_valid_machine_mapping_problem_tree(result)); + return result; }, [&](PCGBinaryParallelSplit const ¶llel) { - return MachineMappingProblemTree{ + MachineMappingProblemTree result = MachineMappingProblemTree{ MMProblemTreeParallelSplit{ to_problem_tree(parallel.get_left_child()), to_problem_tree(parallel.get_right_child()), }, }; + ASSERT(is_valid_machine_mapping_problem_tree(result)); + return result; }, [&](parallel_layer_guid_t const &leaf) { - return MachineMappingProblemTree{ + MachineMappingProblemTree result = MachineMappingProblemTree{ get_unmapped_runtime_only_op_cost_estimate_key_for_layer(pcg, leaf), }; + ASSERT(is_valid_machine_mapping_problem_tree(result)); + return result; }, }); }; - return to_problem_tree(sp_decomposition_tree); + MachineMappingProblemTree mm_tree = to_problem_tree(sp_decomposition_tree); + + ASSERT(is_valid_machine_mapping_problem_tree(mm_tree)); + + return mm_tree; } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 340c448275..8fc6d52c0f 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -1,4 +1,6 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include 
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h" @@ -97,4 +99,57 @@ std::unordered_map generic_binary_sp_impl_for_mm_problem_tree()); } +std::string as_dot(MachineMappingProblemTree const &tree) { + std::function + get_series_label = + [](MMProblemTreeSeriesSplit const &series) -> std::string { + auto path_as_dot = [](BinaryTreePath const &path) -> std::string { + return "(" + + join_strings(path.entries, + ", ", + [](BinaryTreePathEntry const &entry) -> std::string { + if (entry == BinaryTreePathEntry::LEFT_CHILD) { + return "l"; + } else { + assert(entry == BinaryTreePathEntry::RIGHT_CHILD); + return "r"; + } + }) + + ")"; + }; + + auto path_set_as_dot = + [&](std::unordered_set const &path_set) -> std::string { + return "(" + join_strings(path_set, ", ", path_as_dot) + ")"; + }; + + return fmt::format( + "srcs={} dsts={}", + path_set_as_dot(get_src_layers(series.tensor_set_movement)), + path_set_as_dot(get_dst_layers(series.tensor_set_movement))); + }; + + std::function + get_parallel_label = + [](MMProblemTreeParallelSplit const ¶llel) -> std::string { + return "P"; + }; + + std::function + get_leaf_label = + [](UnmappedRuntimeOnlyOpCostEstimateKey const &leaf) -> std::string { + return ""; + }; + + return as_dot(tree, + generic_binary_sp_impl_for_mm_problem_tree(), + get_series_label, + get_parallel_label, + get_leaf_label); +} + +void debug_print_dot(MachineMappingProblemTree const &tree) { + std::cout << as_dot(tree) << std::endl; +} + } // namespace FlexFlow diff --git 
a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc index 9d84f2ca81..bf8d9fb70b 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.cc @@ -1,5 +1,9 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_operator_task_space.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/map_values.h" namespace FlexFlow { @@ -36,4 +40,16 @@ RuntimeOnlyOpCostEstimateKey map_unmapped_runtime_only_op_cost_estimate_key( }; } +OperatorTaskSpace get_operator_task_space_for_runtime_only_op_cost_estimate_key( + UnmappedRuntimeOnlyOpCostEstimateKey const &unmapped) { + + return get_operator_task_space( + assert_unwrap(compgraph_op_attrs_from_pcg_op_attrs(unmapped.op_attrs)), + map_values(unmapped.input_shapes, + [](ParallelTensorShape const &input_shape) + -> ParallelTensorDimDegrees { + return get_parallel_degrees(input_shape); + })); +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc index 7c9c7951eb..33d550474c 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -138,4 +138,14 @@ MachineMappingResult }; } +milliseconds_t get_runtime_cost(MachineMappingResult const &mm_result) { + if (mm_result.raw_result == std::nullopt) { + return milliseconds_t{ + 
std::numeric_limits::infinity(), + }; + } else { + return mm_result.raw_result.value().runtime; + } +} + } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/search_result.cc b/lib/compiler/src/compiler/search_result.cc new file mode 100644 index 0000000000..28eec7f247 --- /dev/null +++ b/lib/compiler/src/compiler/search_result.cc @@ -0,0 +1,21 @@ +#include "compiler/search_result.h" + +namespace FlexFlow { + +MappedParallelComputationGraph + get_mapped_pcg_from_search_result(SearchResult const &search_result) { + return mapped_pcg_from_pcg_and_mapping(search_result.pcg, + search_result.machine_mapping); +} + +std::string format_as(SearchResult const &r) { + return fmt::format("", + as_dot(r.pcg), + r.machine_mapping); +} + +std::ostream &operator<<(std::ostream &s, SearchResult const &r) { + return (s << fmt::to_string(r)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc index 4ab1fb152d..70fcbc644a 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.cc @@ -17,7 +17,7 @@ std::optional return std::nullopt; } - return pcg_binary_sp_decomposition_from_binary_sp_tree( + return pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( balanced_binary_sp_tree_from_nary(spd.value())); } diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 003af5a2dc..cd8e634f2c 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -1,9 +1,11 @@ #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" 
+#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" #include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_path_to_leaf_map.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/overload.h" namespace FlexFlow { @@ -80,25 +82,42 @@ BinarySPDecompositionTree }); } -PCGBinarySPDecomposition pcg_binary_sp_decomposition_from_binary_sp_tree( - BinarySPDecompositionTree const &spd_tree) { - return spd_tree.visit(overload{ +PCGBinarySeriesSplit pcg_binary_series_split_from_binary_series_split( + BinarySeriesSplit const &split) { + return PCGBinarySeriesSplit{ + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_right_child()), + }; +} + +PCGBinaryParallelSplit pcg_binary_parallel_split_from_binary_parallel_split( + BinaryParallelSplit const &split) { + return PCGBinaryParallelSplit{ + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_left_child()), + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + split.get_right_child()), + }; +} + +PCGBinarySPDecomposition + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + BinarySPDecompositionTree const &sp_tree) { + + return sp_tree.visit(overload{ [](BinarySeriesSplit const &series) -> PCGBinarySPDecomposition { return PCGBinarySPDecomposition{ - PCGBinarySeriesSplit{ - pcg_binary_sp_decomposition_from_binary_sp_tree( - series.get_left_child()), - 
pcg_binary_sp_decomposition_from_binary_sp_tree( - series.get_right_child()), - }, + pcg_binary_series_split_from_binary_series_split(series), }; }, [](BinaryParallelSplit const ¶llel) -> PCGBinarySPDecomposition { return PCGBinarySPDecomposition{ PCGBinaryParallelSplit{ - pcg_binary_sp_decomposition_from_binary_sp_tree( + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( parallel.get_left_child()), - pcg_binary_sp_decomposition_from_binary_sp_tree( + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( parallel.get_right_child()), }, }; diff --git a/lib/compiler/src/compiler/unity_algorithm.cc b/lib/compiler/src/compiler/unity_algorithm.cc deleted file mode 100644 index 9ae824c62a..0000000000 --- a/lib/compiler/src/compiler/unity_algorithm.cc +++ /dev/null @@ -1,77 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "compiler/graph_optimize_state.h" -#include "compiler/machine_mapping/get_optimal_machine_mapping.h" -#include "pcg/machine_specification.dtg.h" -#include "substitutions/substitution.h" -#include "utils/deduplicated_priority_queue.h" -#include "utils/graph/node/algorithms.h" - -namespace FlexFlow { - -GraphOptimizeResult graph_optimize( - ParallelComputationGraph &pcg, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimizerConfig const &opt_config) { - NOT_IMPLEMENTED(); - - // std::vector substitutions = - // get_all_applicable_substitutions(pcg); - // - // MachineMappingCache cached_subgraph_costs; - // DeduplicatedPriorityQueue candidates; - // - // MachineMappingResult original_pcg_cost = - // get_optimal_machine_mapping(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // - // GraphOptimizeState initial_state = { - // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), - // original_pcg_cost.runtime}; - // - // 
GraphOptimizeState best_state = initial_state; - // candidates.push(initial_state); - // - // for (int iteration = 0; !candidates.empty() && iteration < - // opt_config.budget; - // ++iteration) { - // GraphOptimizeState current_state = candidates.top(); - // candidates.pop(); - // - // if (current_state.runtime < best_state.runtime) { - // best_state = current_state; - // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) - // { - // continue; - // } - // - // for (Substitution const &substitution : substitutions) { - // for (ParallelComputationGraph const &new_pcg : apply_substitution( - // current_state.graph_optimize_result.pcg, substitution)) { - // MachineMappingResult new_pcg_cost = - // get_optimal_machine_mapping(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // GraphOptimizeState new_state{ - // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), - // new_pcg_cost.runtime}; - // if (new_pcg_cost.runtime <= opt_config.threshold && - // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_state); - // } - // } - // } - // } - - // return best_state.graph_optimize_result; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc b/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc new file mode 100644 index 0000000000..6187ce7b60 --- /dev/null +++ b/lib/compiler/src/compiler/unity_algorithm/graph_optimize_state.cc @@ -0,0 +1,167 @@ +#include "compiler/unity_algorithm/graph_optimize_state.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_view.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" 
+#include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/containers/zip_values_strict.h" +#include "utils/containers/zip_values_strict_with.h" +#include "utils/hash/tuple.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_multiset.h" + +namespace FlexFlow { + +GraphOptimizeState::GraphOptimizeState(ParallelComputationGraph const &pcg, + milliseconds_t runtime) + : pcg(pcg), runtime(runtime) {} + +static std::unordered_multiset>, + std::unordered_map>> + get_layer_signature_set(ParallelComputationGraph const &pcg) { + + auto get_layer_signature = [&](parallel_layer_guid_t l) + -> std::tuple>, + std::unordered_map> { + ParallelLayerAttrs layer_attrs = get_parallel_layer_attrs(pcg, l); + + std::unordered_map< + TensorSlotName, + std::tuple> + inputs = map_values( + get_incoming_tensors(pcg, l), [&](parallel_tensor_guid_t const &i) { + parallel_layer_guid_t src = get_source_layer(i); + TensorSlotName src_slot = i.raw_graph_output.slot_name; + ParallelTensorAttrs tensor_attrs = + get_parallel_tensor_attrs(pcg, i); + + return std::tuple{ + get_parallel_layer_attrs(pcg, src), + src_slot, + tensor_attrs, + }; + }); + + std::unordered_map outputs = + map_values(get_layer_outputs(pcg, l), + [&](parallel_tensor_guid_t const &o) { + return get_parallel_tensor_attrs(pcg, o); + }); + + return { + layer_attrs, + inputs, + outputs, + }; + }; + + return transform(unordered_multiset_of(get_parallel_layers(pcg)), + get_layer_signature); +} + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + return get_layer_signature_set(this->pcg) == + get_layer_signature_set(other.pcg); +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime < other.runtime; +} + +std::string format_as(GraphOptimizeState const &s) { + return fmt::format( + "", s.runtime, as_dot(s.pcg)); +} + 
+std::ostream &operator<<(std::ostream &s, GraphOptimizeState const &x) { + return (s << fmt::to_string(x)); +} + +// TODO(@lockshaw)(#pr): Delete this if still unused +// std::optional +// graph_optimize_state_from_machine_mapping_result(ParallelComputationGraph +// const &pcg, +// PCGBinarySPDecomposition +// const +// &binary_sp_decomposition, +// MachineMappingResult const +// &machine_mapping_result) { +// +// FeasibleMachineMappingResult feasible_mapping = ({ +// if (is_infeasible(machine_mapping_result)) { +// return std::nullopt; +// } +// +// require_feasible(machine_mapping_result); +// }); +// +// bidict path_to_leaf_map = +// bidict_from_map(pcg_sp_tree_get_path_to_leaf_map(binary_sp_decomposition)); +// +// std::unordered_map +// mapped_tasks_by_path = zip_values_strict_with( +// path_to_leaf_map.as_unordered_map(), +// feasible_mapping.machine_mapping.raw_mapping, +// [&](parallel_layer_guid_t const &layer_guid, MachineView const &mv) +// -> MappedOperatorTaskGroup +// { +// ComputationGraphOpAttrs comp_graph_op_attrs = +// assert_unwrap(compgraph_op_attrs_from_pcg_op_attrs(pcg_get_op_attrs(pcg, +// layer_guid))); +// +// return mapped_operator_task_group_from_machine_view( +// comp_graph_op_attrs, +// get_incoming_input_degrees(pcg, layer_guid), +// mv); +// }); +// +// std::unordered_map +// mapped_tasks = map_keys(mapped_tasks_by_path, +// [&](BinaryTreePath const &path) -> +// parallel_layer_guid_t { +// return path_to_leaf_map.at_l(path); +// }); +// +// GraphOptimizeResult result = GraphOptimizeResult{ +// MappedParallelComputationGraph{ +// pcg, +// mapped_tasks, +// }, +// }; +// +// return GraphOptimizeState{ +// result, +// feasible_mapping.runtime, +// }; +// } + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph hash like + // 
https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash + using namespace ::FlexFlow; + + auto layers = get_layer_signature_set(state.pcg); + + return get_std_hash(layers); +} + +} // namespace std diff --git a/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc new file mode 100644 index 0000000000..be8c7c4f98 --- /dev/null +++ b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc @@ -0,0 +1,148 @@ +#include "compiler/unity_algorithm/unity_algorithm.h" +#include "compiler/machine_mapping/allowed_machine_views.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" +#include "compiler/unity_algorithm/graph_optimize_state.h" +#include "op-attrs/operator_task_space.h" +#include "pcg/machine_compute_resource_slice.h" +#include "pcg/machine_specification.dtg.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" 
+#include "substitutions/substitution.h" +#include "substitutions/unity_substitution_set.h" +#include "utils/containers/generate_map.h" +#include "utils/deduplicated_priority_queue.h" +#include "utils/graph/node/algorithms.h" +#include "utils/optional.h" + +namespace FlexFlow { + +/* + * Applies a substitution to all possible positions in PCG + */ +std::vector + all_pcgs_obtained_by_applying_a_substitution( + ParallelComputationGraph const &pcg, + std::vector const &substitutions) { + std::vector results; + SubParallelComputationGraph subpcg = sub_pcg_from_full_pcg(pcg); + for (Substitution const &substitution : substitutions) { + for (PCGPatternMatch const &pattern_match : + find_pattern_matches(substitution.pcg_pattern, subpcg)) { + SubParallelComputationGraph subpcg_from_substitution = + apply_substitution(subpcg, substitution, pattern_match); + results.push_back( + pcg_from_sub_pcg_by_dropping_inputs(subpcg_from_substitution)); + } + } + return results; +} + +SearchResult graph_optimize(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineComputeSpecification const &resources, + UnitySearchConfig const &search_config) { + + std::vector substitutions = get_substitution_set(resources); + + MachineMappingCache cached_subgraph_costs = empty_machine_mapping_cache(); + DeduplicatedPriorityQueue candidates; + + MachineMappingContext context = MachineMappingContext{ + /*cost_estimator=*/cost_estimator, + /*allowed_machine_views=*/ + [&](UnmappedRuntimeOnlyOpCostEstimateKey const &key, + MachineComputeResourceSlice const &resources) + -> std::unordered_set { + OperatorTaskSpace op_task_space = + get_operator_task_space_for_runtime_only_op_cost_estimate_key(key); + + return get_allowed_machine_views( + resources, op_task_space, DeviceType::GPU); + }, + }; + + auto optimize_pcg = [&](ParallelComputationGraph const &pcg) + -> std::pair> { + PCGBinarySPDecomposition sp_decomp = + expect(get_pcg_balanced_binary_sp_decomposition(pcg), + 
"Failed to get SP decomposition of PCG"); + + MachineMappingProblemTree problem_tree = + get_machine_mapping_problem_tree(pcg, sp_decomp); + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers(get_all_leaf_paths(problem_tree)); + + MachineMappingResult mm_result = + get_optimal_machine_mapping(cached_subgraph_costs, + context, + problem_tree, + compute_slice_from_specification(resources), + constraints); + + return { + GraphOptimizeState{ + /*pcg=*/pcg, + /*runtime=*/get_runtime_cost(mm_result), + }, + get_machine_mapping_from_machine_mapping_result(sp_decomp, mm_result), + }; + }; + + GraphOptimizeState best_state = optimize_pcg(pcg).first; + candidates.push(best_state); + + for (int iteration = 0; + !candidates.empty() && iteration < search_config.budget; + ++iteration) { + GraphOptimizeState current_state = candidates.top(); + candidates.pop(); + + if (current_state < best_state) { + best_state = current_state; + } else if (current_state.runtime > + best_state.runtime * search_config.alpha) { + continue; + } + + for (ParallelComputationGraph const &new_pcg : + all_pcgs_obtained_by_applying_a_substitution(current_state.pcg, + substitutions)) { + + std::optional new_pcg_optimize_result = + optimize_pcg(new_pcg).first; + + if (new_pcg_optimize_result == std::nullopt) { + continue; + } + + GraphOptimizeState new_state = new_pcg_optimize_result.value(); + if (new_state.runtime <= best_state.runtime * search_config.alpha && + get_nodes(new_pcg.raw_graph).size() <= search_config.max_num_ops) { + candidates.push(new_state); + } + } + } + + std::optional best_mapping = + optimize_pcg(best_state.pcg).second; + + ASSERT(best_mapping != std::nullopt, "Failed to find any solutions"); + + return SearchResult{ + /*pcg=*/best_state.pcg, + /*machine_mapping=*/best_mapping.value(), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/allowed_machine_views.cc b/lib/compiler/test/src/compiler/allowed_machine_views.cc deleted 
file mode 100644 index e768d8540c..0000000000 --- a/lib/compiler/test/src/compiler/allowed_machine_views.cc +++ /dev/null @@ -1,107 +0,0 @@ -#include "compiler/allowed_machine_views.h" -#include "utils/containers/extend.h" -#include "utils/containers/range.h" -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/containers/zip.h" -#include "utils/fmt/unordered_set.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("get_allowed_machine_views") { - - SUBCASE("1 degree of parallelism") { - MachineComputeSpecification ms = MachineComputeSpecification{ - /*num_nodes=*/1_p, - /*num_cpus_per_node=*/5_p, - /*num_gpus_per_node=*/5_p, - }; - - OperatorTaskSpace task = OperatorTaskSpace{MinimalOrthotope{{3_ge2}}}; - - std::unordered_set correct = { - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/1_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/2_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{1_p}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - MachineView{ - MachineSpaceCoordinate{ - /*node_idx=*/0_n, /*device_idx=*/0_n, DeviceType::GPU}, - {MachineViewDimension{stride_t{2_p}, - MachineSpecificationDimension::INTRA_NODE}}, - }, - }; - - std::unordered_set result = - get_allowed_machine_views(ms, task, DeviceType::GPU); - - CHECK(correct == result); - } - - SUBCASE("2 degrees of parallelism") { - - MachineComputeSpecification ms = MachineComputeSpecification{ - /*num_nodes=*/3_p, - /*num_cpus_per_node=*/3_p, - /*num_gpus_per_node=*/3_p, - }; - OperatorTaskSpace task = - OperatorTaskSpace{MinimalOrthotope{{2_ge2, 3_ge2}}}; 
- - auto make_2d_view = [&](nonnegative_int start_node_idx, - nonnegative_int start_device_idx, - positive_int stride1, - positive_int stride2, - MachineSpecificationDimension m1, - MachineSpecificationDimension m2) { - return MachineView{ - MachineSpaceCoordinate{ - start_node_idx, start_device_idx, DeviceType::GPU}, - {MachineViewDimension{stride_t{stride1}, m1}, - MachineViewDimension{stride_t{stride2}, m2}}, - }; - }; - - auto intra = MachineSpecificationDimension::INTRA_NODE; - auto inter = MachineSpecificationDimension::INTER_NODE; - std::unordered_set correct = { - make_2d_view( - 0_n, 0_n, /*stride1=*/1_p, /*stride2=*/1_p, inter, intra), - make_2d_view( - 1_n, 0_n, /*stride1=*/1_p, /*stride2=*/1_p, inter, intra), - make_2d_view( - 0_n, 0_n, /*stride1=*/2_p, /*stride2=*/1_p, inter, intra), - - make_2d_view( - 0_n, 0_n, /*stride1=*/1_p, /*stride2=*/1_p, intra, inter), - make_2d_view( - 0_n, 1_n, /*stride1=*/1_p, /*stride2=*/1_p, intra, inter), - make_2d_view( - 0_n, 0_n, /*stride1=*/2_p, /*stride2=*/1_p, intra, inter), - }; - - std::unordered_set result = - get_allowed_machine_views(ms, task, DeviceType::GPU); - - CHECK(correct == result); - } - } -} diff --git a/lib/compiler/test/src/compiler/machine_mapping/allowed_machine_views.cc b/lib/compiler/test/src/compiler/machine_mapping/allowed_machine_views.cc new file mode 100644 index 0000000000..2a0402a791 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/allowed_machine_views.cc @@ -0,0 +1,143 @@ +#include "compiler/machine_mapping/allowed_machine_views.h" +#include "doctest/doctest.h" +#include "utils/containers/extend.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip.h" +#include "utils/fmt/unordered_set.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_allowed_machine_views") { + + auto make_machine_view = + [&](nonnegative_int 
start_node_idx, + nonnegative_int start_device_idx, + std::optional stride_1 = std::nullopt, + std::optional m1 = std::nullopt, + std::optional stride_2 = std::nullopt, + std::optional m2 = std::nullopt) { + std::vector strides; + + if (stride_1.has_value()) { + ASSERT(m1.has_value()); + strides.push_back( + MachineViewDimension{stride_t{stride_1.value()}, m1.value()}); + } + + if (stride_2.has_value()) { + ASSERT(stride_1.has_value()); + ASSERT(m2.has_value()); + strides.push_back( + MachineViewDimension{stride_t{stride_2.value()}, m2.value()}); + } + + return MachineView{ + MachineSpaceCoordinate{ + start_node_idx, + start_device_idx, + DeviceType::GPU, + }, + strides, + }; + }; + + auto intra = MachineSpecificationDimension::INTRA_NODE; + auto inter = MachineSpecificationDimension::INTER_NODE; + + SUBCASE("1 degree of parallelism") { + MachineComputeResourceSlice ms = MachineComputeResourceSlice{ + /*num_nodes=*/1_p, + /*num_gpus_per_node=*/5_p, + }; + + OperatorTaskSpace task = OperatorTaskSpace{MinimalOrthotope{{3_ge2}}}; + + std::unordered_set correct = { + make_machine_view(0_n, 0_n, 1_p, intra), + make_machine_view(0_n, 1_n, 1_p, intra), + make_machine_view(0_n, 2_n, 1_p, intra), + make_machine_view(0_n, 0_n, 2_p, intra), + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + + SUBCASE("2 degrees of parallelism") { + + MachineComputeResourceSlice ms = MachineComputeResourceSlice{ + /*num_nodes=*/3_p, + /*num_gpus_per_node=*/3_p, + }; + OperatorTaskSpace task = + OperatorTaskSpace{MinimalOrthotope{{2_ge2, 3_ge2}}}; + + std::unordered_set correct = { + make_machine_view( + 0_n, 0_n, /*stride_1=*/1_p, inter, /*stride_2=*/1_p, intra), + make_machine_view( + 1_n, 0_n, /*stride_1=*/1_p, inter, /*stride_2=*/1_p, intra), + make_machine_view( + 0_n, 0_n, /*stride_1=*/2_p, inter, /*stride_2=*/1_p, intra), + + make_machine_view( + 0_n, 0_n, /*stride_1=*/1_p, intra, /*stride_2=*/1_p, inter), + 
make_machine_view( + 0_n, 1_n, /*stride_1=*/1_p, intra, /*stride_2=*/1_p, inter), + make_machine_view( + 0_n, 0_n, /*stride_1=*/2_p, intra, /*stride_2=*/1_p, inter), + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + + SUBCASE("2D operator task space, dimensions (1,1)") { + MachineComputeResourceSlice full_machine_spec = + MachineComputeResourceSlice{ + /*num_nodes=*/2_p, + /*num_gpus_per_node=*/1_p, + }; + OperatorTaskSpace task = OperatorTaskSpace{MinimalOrthotope{{}}}; + + std::unordered_set result = + get_allowed_machine_views(full_machine_spec, task, DeviceType::GPU); + + std::unordered_set correct = { + make_machine_view(0_n, 0_n), + make_machine_view(1_n, 0_n), + }; + + CHECK(correct == result); + } + + SUBCASE("2D operator task space, dimensions (2,1)") { + MachineComputeResourceSlice full_machine_spec = + MachineComputeResourceSlice{ + /*num_nodes=*/2_p, + /*num_gpus_per_node=*/2_p, + }; + OperatorTaskSpace task = OperatorTaskSpace{MinimalOrthotope{{2_ge2}}}; + + std::unordered_set result = + get_allowed_machine_views(full_machine_spec, task, DeviceType::GPU); + + std::unordered_set correct = { + make_machine_view(0_n, 0_n, /*stride_1=*/1_p, intra), + make_machine_view(0_n, 0_n, /*stride_1=*/1_p, inter), + make_machine_view(1_n, 0_n, /*stride_1=*/1_p, intra), + make_machine_view(0_n, 1_n, /*stride_1=*/1_p, inter)}; + + CHECK(correct == result); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index e70c0b75d2..392e16bec5 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -2,7 +2,7 @@ #include "compiler/cost_estimator/runtime_only_op_cost_estimate_key.dtg.h" #include 
"compiler/cost_estimator/runtime_only_op_cost_metrics.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" -#include "compiler/machine_mapping/machine_compute_resource_slice.h" +#include "compiler/machine_mapping/allowed_machine_views.h" #include "compiler/machine_mapping/machine_mapping_cache.h" #include "compiler/machine_mapping/machine_mapping_constraints.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" @@ -315,17 +315,28 @@ TEST_SUITE(FF_TEST_SUITE) { auto allowed_machine_views = [&](UnmappedRuntimeOnlyOpCostEstimateKey const &k, - MachineComputeResourceSlice const &resources) { - if (resources == four_nodes_resources) { - return std::unordered_set{mv_stride_1, mv_stride_2}; - } else if (resources == three_nodes_resources) { - return std::unordered_set{mv_stride_1, mv_stride_2}; - } else if (resources == two_nodes_resources) { - return std::unordered_set{mv_stride_1}; - } else { - return std::unordered_set{}; - } - }; + MachineComputeResourceSlice const &resources) + -> std::unordered_set { + std::unordered_set result; + + if (resources == four_nodes_resources) { + result = std::unordered_set{mv_stride_1, mv_stride_2}; + } else if (resources == three_nodes_resources) { + result = std::unordered_set{mv_stride_1, mv_stride_2}; + } else if (resources == two_nodes_resources) { + result = std::unordered_set{mv_stride_1}; + } else { + result = std::unordered_set{}; + } + + for (MachineView const &mv : result) { + OperatorTaskSpace op_task_space = + get_operator_task_space_for_runtime_only_op_cost_estimate_key(k); + ASSERT(is_valid_machine_view(mv, op_task_space, resources)); + } + + return result; + }; MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc 
index e6af704cf1..b3901e08ca 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -222,8 +222,10 @@ TEST_SUITE(FF_TEST_SUITE) { TensorSlotName::LHS_INPUT, require_only_key(relu_1.outputs, TensorSlotName::OUTPUT), }, - {TensorSlotName::RHS_INPUT, - require_only_key(relu_2.outputs, TensorSlotName::OUTPUT)}, + { + TensorSlotName::RHS_INPUT, + require_only_key(relu_2.outputs, TensorSlotName::OUTPUT), + }, }, /*weights=*/{}); diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 26a643f327..ce01355f66 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -1,8 +1,13 @@ #include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_runtime_only_op_cost_estimate_key.dtg.h" +#include "compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h" #include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph_builder.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/containers/extend.h" #include "utils/containers/get_only.h" #include "utils/containers/require_only_key.h" #include "utils/full_binary_tree/binary_tree_path.h" @@ -368,4 +373,45 @@ TEST_SUITE(FF_TEST_SUITE) { 
CHECK(result == correct); } } + + TEST_CASE("from pcg") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 32_p, + 64_p, + }, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/16_p, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/8_p, + /*activation=*/Activation::RELU); + return b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + PCGBinarySPDecomposition sp_decomp = + expect(get_pcg_balanced_binary_sp_decomposition(pcg), + "Failed to get SP decomposition of PCG"); + + MachineMappingProblemTree problem_tree = + get_machine_mapping_problem_tree(pcg, sp_decomp); + } } diff --git a/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 8c1d3221d3..ced2634000 100644 --- a/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -6,7 +6,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("pcg_binary_sp_decomposition_from_binary_sp_tree") { + TEST_CASE("pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree") { Node n1 = Node{1}; Node n2 = Node{2}; Node n3 = Node{3}; @@ -43,7 +43,7 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree input = make_binary_leaf(n1); PCGBinarySPDecomposition result = - pcg_binary_sp_decomposition_from_binary_sp_tree(input); + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree(input); PCGBinarySPDecomposition expected 
= make_pcg_leaf(n1); @@ -55,7 +55,7 @@ TEST_SUITE(FF_TEST_SUITE) { make_binary_series_split(make_binary_leaf(n1), make_binary_leaf(n2)); PCGBinarySPDecomposition result = - pcg_binary_sp_decomposition_from_binary_sp_tree(input); + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree(input); PCGBinarySPDecomposition expected = make_pcg_series_split(make_pcg_leaf(n1), make_pcg_leaf(n2)); @@ -68,7 +68,7 @@ TEST_SUITE(FF_TEST_SUITE) { make_binary_leaf(n1), make_binary_leaf(n2)); PCGBinarySPDecomposition result = - pcg_binary_sp_decomposition_from_binary_sp_tree(input); + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree(input); PCGBinarySPDecomposition expected = make_pcg_parallel_split(make_pcg_leaf(n1), make_pcg_leaf(n2)); @@ -82,7 +82,8 @@ TEST_SUITE(FF_TEST_SUITE) { make_binary_leaf(n3)); PCGBinarySPDecomposition pcg_tree = - pcg_binary_sp_decomposition_from_binary_sp_tree(original); + pcg_binary_sp_decomposition_from_binary_sp_decomposition_tree( + original); BinarySPDecompositionTree converted = binary_sp_tree_from_pcg_sp_tree(pcg_tree); diff --git a/lib/compiler/test/src/compiler/graph_optimize_state.cc b/lib/compiler/test/src/compiler/unity_algorithm/graph_optimize_state.cc similarity index 70% rename from lib/compiler/test/src/compiler/graph_optimize_state.cc rename to lib/compiler/test/src/compiler/unity_algorithm/graph_optimize_state.cc index d99f754609..bf7884df29 100644 --- a/lib/compiler/test/src/compiler/graph_optimize_state.cc +++ b/lib/compiler/test/src/compiler/unity_algorithm/graph_optimize_state.cc @@ -1,8 +1,9 @@ -#include "compiler/graph_optimize_state.h" +#include "compiler/unity_algorithm/graph_optimize_state.h" #include "compiler/machine_mapping/machine_mapping.dtg.h" #include "compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_view.dtg.h" #include "compiler/machine_mapping/machine_view.h" +#include "doctest/doctest.h" #include 
"pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/doctest/check_without_stringify.h" @@ -52,41 +53,18 @@ TEST_SUITE(FF_TEST_SUITE) { return builder.pcg; }; - auto create_machine_mapping_for_pcg = - [](ParallelComputationGraph const &pcg) -> MachineMapping { - MachineSpaceCoordinate device = MachineSpaceCoordinate{ - /*node_idx=*/0_n, - /*device_idx=*/0_n, - /*device_type=*/DeviceType::GPU, - }; - - MachineView machine_view = make_single_device_machine_view(device); - - return MachineMapping{ - generate_map(get_parallel_layers(pcg), - [&](parallel_layer_guid_t) { return machine_view; }), - }; - }; - ParallelComputationGraph pcg1 = create_pcg(); - MachineMapping machine_mapping_1 = create_machine_mapping_for_pcg(pcg1); SUBCASE("returns true if the PCGs are isomorphic") { ParallelComputationGraph pcg2 = create_pcg(); - MachineMapping machine_mapping_2 = create_machine_mapping_for_pcg(pcg2); GraphOptimizeState state1 = GraphOptimizeState{ - GraphOptimizeResult{ - mapped_pcg_from_pcg_and_mapping(pcg1, machine_mapping_1), - }, - 0, + pcg1, + 0_ms, }; - GraphOptimizeState state2 = GraphOptimizeState{ - GraphOptimizeResult{ - mapped_pcg_from_pcg_and_mapping(pcg2, machine_mapping_2), - }, - 0, + pcg2, + 0_ms, }; CHECK_WITHOUT_STRINGIFY(state1 == state2); @@ -109,24 +87,31 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelComputationGraph other_pcg = builder_.pcg; - MachineMapping other_machine_mapping = - create_machine_mapping_for_pcg(other_pcg); - GraphOptimizeState state1 = GraphOptimizeState{ - GraphOptimizeResult{ - mapped_pcg_from_pcg_and_mapping(pcg1, machine_mapping_1), - }, - 0, + pcg1, + 0_ms, }; GraphOptimizeState state_ = GraphOptimizeState{ - GraphOptimizeResult{ - mapped_pcg_from_pcg_and_mapping(other_pcg, other_machine_mapping), - }, - 0, + other_pcg, + 0_ms, }; CHECK_FALSE_WITHOUT_STRINGIFY(state1 == state_); } } + + 
TEST_CASE("GraphOptimizeState::operator<") { + ParallelComputationGraph pcg1 = empty_parallel_computation_graph(); + ParallelComputationGraph pcg2 = empty_parallel_computation_graph(); + GraphOptimizeState state1 = GraphOptimizeState{ + pcg1, + 1.0_ms, + }; + GraphOptimizeState state2 = GraphOptimizeState{ + pcg2, + 2.0_ms, + }; + CHECK(state1 < state2); + } } diff --git a/lib/compiler/test/src/compiler/unity_algorithm/unity_algorithm.cc b/lib/compiler/test/src/compiler/unity_algorithm/unity_algorithm.cc new file mode 100644 index 0000000000..f5278612aa --- /dev/null +++ b/lib/compiler/test/src/compiler/unity_algorithm/unity_algorithm.cc @@ -0,0 +1,91 @@ +#include "compiler/unity_algorithm/unity_algorithm.h" +#include "compiler/cost_estimator/runtime_only_cost_estimator_from_cost_estimator.h" +#include "doctest/doctest.h" +#include "internal/cost_estimator_for_test.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_type.dtg.h" +#include "op-attrs/shard_parallel_dim.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/integer_conversions.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("graph_optimize") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{ + 32_p, + 64_p, + }, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/16_p, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/8_p, + /*activation=*/Activation::RELU); + return 
b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + RuntimeOnlyCostEstimator cost_estimator = + runtime_only_cost_estimator_from_cost_estimator( + make_fake_cost_estimator( + [](OpCostEstimateKey const &k) -> OpCostMetrics { + return OpCostMetrics{ + /*forward_runtime=*/1.0_ms, + /*backward_runtime=*/2.0_ms, + /*memory=*/1_bytes, + }; + }, + [](TensorSetMovement const &) -> milliseconds_t { + return 1.0_ms; + })); + + MachineComputeSpecification full_machine_spec = MachineComputeSpecification{ + /*num_nodes=*/2_p, + /*num_cpus_per_node=*/1_p, + /*num_gpus_per_node=*/1_p, + }; + + SUBCASE("do not apply substitution") { + UnitySearchConfig search_config = UnitySearchConfig{ + /*alpha=*/1.0, + /*budget=*/0, + /*max_num_ops=*/100, + }; + SearchResult result = + graph_optimize(pcg, cost_estimator, full_machine_spec, search_config); + CHECK(pcgs_are_isomorphic(pcg, result.pcg)); + } + + SUBCASE("apply substitution") { + UnitySearchConfig search_config = UnitySearchConfig{ + /*alpha=*/1.0, + /*budget=*/1, + /*max_num_ops=*/100, + }; + SearchResult result = + graph_optimize(pcg, cost_estimator, full_machine_spec, search_config); + } + } +} diff --git a/lib/compiler/test/src/internal/cost_estimator_for_test.h b/lib/compiler/test/src/internal/cost_estimator_for_test.h index 6a0094839c..12708210f3 100644 --- a/lib/compiler/test/src/internal/cost_estimator_for_test.h +++ b/lib/compiler/test/src/internal/cost_estimator_for_test.h @@ -3,6 +3,7 @@ #include "compiler/cost_estimator/cost_estimator.h" #include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/runtime_only_cost_estimator.h" #include "compiler/cost_estimator/tensor_set_movement.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" #include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" diff --git 
a/lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.dtg.toml b/lib/pcg/include/pcg/machine_compute_resource_slice.dtg.toml similarity index 100% rename from lib/compiler/include/compiler/machine_mapping/machine_compute_resource_slice.dtg.toml rename to lib/pcg/include/pcg/machine_compute_resource_slice.dtg.toml diff --git a/lib/pcg/include/pcg/machine_compute_resource_slice.h b/lib/pcg/include/pcg/machine_compute_resource_slice.h new file mode 100644 index 0000000000..f3dee01132 --- /dev/null +++ b/lib/pcg/include/pcg/machine_compute_resource_slice.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_COMPUTE_RESOURCE_SLICE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_COMPUTE_RESOURCE_SLICE_H + +#include "pcg/machine_compute_resource_slice.dtg.h" +#include "pcg/machine_compute_specification.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" + +namespace FlexFlow { + +MachineComputeResourceSlice + compute_slice_from_specification(MachineComputeSpecification const &); + +positive_int + get_total_num_devices_in_slice(MachineComputeResourceSlice const &); + +bool is_valid_machine_space_coordinate_in_slice( + MachineComputeResourceSlice const &slice, + MachineSpaceCoordinate const &coord); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/src/pcg/machine_compute_resource_slice.cc b/lib/pcg/src/pcg/machine_compute_resource_slice.cc new file mode 100644 index 0000000000..2cb2386926 --- /dev/null +++ b/lib/pcg/src/pcg/machine_compute_resource_slice.cc @@ -0,0 +1,29 @@ +#include "pcg/machine_compute_resource_slice.h" +#include + +namespace FlexFlow { + +MachineComputeResourceSlice + compute_slice_from_specification(MachineComputeSpecification const &spec) { + + return MachineComputeResourceSlice{ + /*num_nodes=*/spec.num_nodes, + /*num_gpus_per_node=*/spec.num_gpus_per_node, + }; +} + +positive_int + get_total_num_devices_in_slice(MachineComputeResourceSlice const &slice) { + return slice.num_nodes * 
slice.num_gpus_per_node; +} + +bool is_valid_machine_space_coordinate_in_slice( + MachineComputeResourceSlice const &slice, + MachineSpaceCoordinate const &coord) { + ASSERT(coord.device_type == DeviceType::GPU); + + return (coord.node_idx < slice.num_nodes) && + (coord.device_idx < slice.num_gpus_per_node); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml index f440a9f90b..b9043284d0 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.toml @@ -36,9 +36,6 @@ src_includes = [ [[values]] type = "::FlexFlow::nonnegative_int" -[[values]] -type = "::FlexFlow::positive_int" - [[values]] type = "bool" diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 19349823b7..9183278fe1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -62,7 +62,8 @@ std::optional get_attribute(CombineAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DIM: return OperatorAttributeValue{p.combine_dim}; case OperatorAttributeKey::PARALLEL_DIM: - return OperatorAttributeValue{p.combine_degree}; + return OperatorAttributeValue{ + p.combine_degree.nonnegative_int_from_positive_int()}; default: return std::nullopt; } @@ -84,23 +85,29 @@ std::optional get_attribute(Conv2DAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: - return OperatorAttributeValue{p.out_channels}; + return OperatorAttributeValue{ + p.out_channels.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::OP_TYPE: return 
OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::KERNEL_H: - return OperatorAttributeValue{p.kernel_h}; + return OperatorAttributeValue{ + p.kernel_h.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::KERNEL_W: - return OperatorAttributeValue{p.kernel_w}; + return OperatorAttributeValue{ + p.kernel_w.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::STRIDE_H: - return OperatorAttributeValue{p.stride_h}; + return OperatorAttributeValue{ + p.stride_h.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::STRIDE_W: - return OperatorAttributeValue{p.stride_w}; + return OperatorAttributeValue{ + p.stride_w.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::PADDING_H: return OperatorAttributeValue{p.padding_h}; case OperatorAttributeKey::PADDING_W: return OperatorAttributeValue{p.padding_w}; case OperatorAttributeKey::GROUPS: - return OperatorAttributeValue{p.groups}; + return OperatorAttributeValue{ + p.groups.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::ACTIVATION: return OperatorAttributeValue{p.activation}; case OperatorAttributeKey::USE_BIAS: @@ -158,9 +165,11 @@ std::optional get_attribute(EmbeddingAttrs const &p, case OperatorAttributeKey::AGGR: return OperatorAttributeValue{p.aggr}; case OperatorAttributeKey::NUM_ENTRIES: - return OperatorAttributeValue{p.num_entries}; + return OperatorAttributeValue{ + p.num_entries.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::OUT_CHANNELS: - return OperatorAttributeValue{p.out_channels}; + return OperatorAttributeValue{ + p.out_channels.nonnegative_int_from_positive_int()}; default: return std::nullopt; } @@ -218,7 +227,8 @@ std::optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::OUT_CHANNELS: - return OperatorAttributeValue{p.out_channels}; + return OperatorAttributeValue{ + 
p.out_channels.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::USE_BIAS: return OperatorAttributeValue{p.use_bias}; case OperatorAttributeKey::DATA_TYPE: @@ -238,13 +248,15 @@ std::optional case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::EMBED_DIM: - return OperatorAttributeValue{p.embed_dim}; + return OperatorAttributeValue{ + p.embed_dim.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::KDIM: - return OperatorAttributeValue{p.kdim}; + return OperatorAttributeValue{p.kdim.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::VDIM: - return OperatorAttributeValue{p.vdim}; + return OperatorAttributeValue{p.vdim.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::NUM_HEADS: - return OperatorAttributeValue{p.num_heads}; + return OperatorAttributeValue{ + p.num_heads.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::BIAS: return OperatorAttributeValue{p.bias}; case OperatorAttributeKey::ADD_BIAS_KV: @@ -274,13 +286,17 @@ std::optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::KERNEL_H: - return OperatorAttributeValue{p.kernel_h}; + return OperatorAttributeValue{ + p.kernel_h.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::KERNEL_W: - return OperatorAttributeValue{p.kernel_w}; + return OperatorAttributeValue{ + p.kernel_w.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::STRIDE_H: - return OperatorAttributeValue{p.stride_h}; + return OperatorAttributeValue{ + p.stride_h.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::STRIDE_W: - return OperatorAttributeValue{p.stride_w}; + return OperatorAttributeValue{ + p.stride_w.nonnegative_int_from_positive_int()}; case OperatorAttributeKey::PADDING_H: return OperatorAttributeValue{p.padding_h}; case OperatorAttributeKey::PADDING_W: @@ -310,7 +326,8 
@@ std::optional get_attribute(ReductionAttrs const &p, case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::PARALLEL_OP_DEGREE: - return OperatorAttributeValue{p.reduction_degree}; + return OperatorAttributeValue{ + p.reduction_degree.nonnegative_int_from_positive_int()}; default: return std::nullopt; } @@ -324,7 +341,8 @@ std::optional get_attribute(RepartitionAttrs const &p, case OperatorAttributeKey::PARALLEL_OP_DIM: return OperatorAttributeValue{p.repartition_dim}; case OperatorAttributeKey::PARALLEL_OP_DEGREE: - return OperatorAttributeValue{p.repartition_degree}; + return OperatorAttributeValue{ + p.repartition_degree.nonnegative_int_from_positive_int()}; default: return std::nullopt; } @@ -336,7 +354,8 @@ std::optional get_attribute(ReplicateAttrs const &p, case OperatorAttributeKey::OP_TYPE: return OperatorAttributeValue{get_op_type(p)}; case OperatorAttributeKey::PARALLEL_OP_DEGREE: - return OperatorAttributeValue{p.replicate_degree}; + return OperatorAttributeValue{ + p.replicate_degree.nonnegative_int_from_positive_int()}; default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc index a45af1e7d4..ac5687feb3 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc @@ -25,7 +25,7 @@ OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey key, return OperatorAttributeConstraint{ ConstraintType::DIVISIBLE_BY, OperatorAttributeExpr{key}, - OperatorAttributeValue{denominator}, + OperatorAttributeValue{denominator.nonnegative_int_from_positive_int()}, }; } diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc 
b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 96c33989fe..15d3a68ec7 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -17,6 +17,15 @@ bool operator_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + ASSERT(expr_val.value().has() && + constraint.attribute_value.has(), + "DIVISIBLE_BY constraint requires nonnegative_int values"); + + return expr_val.value().get() % + constraint.attribute_value.get() == + 0; + } default: PANIC("Unknown constraint type", constraint.constraint_type); } diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc index a765556c63..ce5094190c 100644 --- a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -13,13 +13,19 @@ struct Accessor { std::unordered_map const &m; template - T const &get(OperatorAttributeKey k) const { + T get(OperatorAttributeKey k) const { if (contains_key(this->m, k)) { return this->m.at(k).get(); } else { PANIC("Could not find key in attrs map", k, this->m); } } + + positive_int get_positive_int(OperatorAttributeKey k) const { + return positive_int{ + this->get(k), + }; + } }; PCGOperatorAttrs materialize_operator_from_attrs_map( @@ -33,11 +39,11 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( switch (op_type) { case OperatorType::MULTIHEAD_ATTENTION: return PCGOperatorAttrs{MultiHeadAttentionAttrs{ - /*embed_dim=*/acc.get(OperatorAttributeKey::EMBED_DIM), + /*embed_dim=*/acc.get_positive_int(OperatorAttributeKey::EMBED_DIM), /*num_heads=*/ - 
acc.get(OperatorAttributeKey::NUM_HEADS), - /*kdim=*/acc.get(OperatorAttributeKey::KDIM), - /*vdim=*/acc.get(OperatorAttributeKey::VDIM), + acc.get_positive_int(OperatorAttributeKey::NUM_HEADS), + /*kdim=*/acc.get_positive_int(OperatorAttributeKey::KDIM), + /*vdim=*/acc.get_positive_int(OperatorAttributeKey::VDIM), /*dropout=*/acc.get(OperatorAttributeKey::DROPOUT), /*bias=*/acc.get(OperatorAttributeKey::BIAS), /*add_bias_kv=*/acc.get(OperatorAttributeKey::ADD_BIAS_KV), @@ -45,10 +51,10 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( }}; case OperatorType::POOL2D: return PCGOperatorAttrs{Pool2DAttrs{ - /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), - /*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), - /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), - /*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*kernel_h=*/acc.get_positive_int(OperatorAttributeKey::KERNEL_H), + /*kernel_w=*/acc.get_positive_int(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get_positive_int(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get_positive_int(OperatorAttributeKey::STRIDE_W), /*padding_h=*/ acc.get(OperatorAttributeKey::PADDING_H), /*padding_w=*/ @@ -64,7 +70,7 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( case OperatorType::DROPOUT: case OperatorType::LINEAR: return PCGOperatorAttrs{LinearAttrs{ - /*out_channels=*/acc.get( + /*out_channels=*/acc.get_positive_int( OperatorAttributeKey::OUT_CHANNELS), /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), /*data_type=*/acc.get(OperatorAttributeKey::DATA_TYPE), @@ -76,17 +82,17 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( }}; case OperatorType::CONV2D: return PCGOperatorAttrs{Conv2DAttrs{ - /*out_channels=*/acc.get( + /*out_channels=*/acc.get_positive_int( OperatorAttributeKey::OUT_CHANNELS), - /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), - /*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), - /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), - 
/*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*kernel_h=*/acc.get_positive_int(OperatorAttributeKey::KERNEL_H), + /*kernel_w=*/acc.get_positive_int(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get_positive_int(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get_positive_int(OperatorAttributeKey::STRIDE_W), /*padding_h=*/ acc.get(OperatorAttributeKey::PADDING_H), /*padding_w=*/ acc.get(OperatorAttributeKey::PADDING_W), - /*groups=*/acc.get(OperatorAttributeKey::GROUPS), + /*groups=*/acc.get_positive_int(OperatorAttributeKey::GROUPS), /*activation=*/ acc.get>(OperatorAttributeKey::ACTIVATION), /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), @@ -109,7 +115,7 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( }}; case OperatorType::REPLICATE: return PCGOperatorAttrs{ReplicateAttrs{ - /*replicate_degree=*/acc.get( + /*replicate_degree=*/acc.get_positive_int( OperatorAttributeKey::PARALLEL_DEGREE), }}; case OperatorType::REPARTITION: @@ -117,17 +123,17 @@ PCGOperatorAttrs materialize_operator_from_attrs_map( /*repartition_dim=*/acc.get( OperatorAttributeKey::PARALLEL_DIM), /*repartition_Degree=*/ - acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + acc.get_positive_int(OperatorAttributeKey::PARALLEL_DEGREE), }}; case OperatorType::COMBINE: return PCGOperatorAttrs{CombineAttrs{ /*combine_dim=*/acc.get(OperatorAttributeKey::PARALLEL_DIM), /*combine_degree=*/ - acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + acc.get_positive_int(OperatorAttributeKey::PARALLEL_DEGREE), }}; case OperatorType::REDUCTION: return PCGOperatorAttrs{ReductionAttrs{ - acc.get(OperatorAttributeKey::PARALLEL_DEGREE), + acc.get_positive_int(OperatorAttributeKey::PARALLEL_DEGREE), }}; case OperatorType::BATCHMATMUL: case OperatorType::SCALAR_MULTIPLY: diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index e2f2e211fa..d3ca15a3fa 100644 --- 
a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -13,6 +13,15 @@ bool parallel_tensor_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + ASSERT(expr_val.has() && + constraint.attribute_value.has(), + "DIVISIBLE_BY constraint requires nonnegative_int values"); + + return expr_val.get() % + constraint.attribute_value.get() == + 0; + } default: PANIC("Unknown constraint type", constraint.constraint_type); } diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index f1d808c5fd..f1dabd9554 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -104,7 +104,8 @@ static OutputGraphExprValue { set_op_type_attr(op_type), set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), + OperatorAttributeValue{ + degree.nonnegative_int_from_positive_int()}), }}; return insert_single_output_op( @@ -139,8 +140,10 @@ static OutputGraphExprValue std::nullopt, { set_op_type_attr(op_type), - set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE, - OperatorAttributeValue{degree}), + set_attr_to_constant( + OperatorAttributeKey::PARALLEL_DEGREE, + OperatorAttributeValue{ + degree.nonnegative_int_from_positive_int()}), set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM, OperatorAttributeValue{dim}), }}; diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index 7d207d9c90..e9087b5718 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ 
b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -64,7 +64,7 @@ static std::optional graph_node_inputs = get_incoming_open_kwarg_dataflow_values_for_node(graph, graph_node); - if (graph_node_inputs.size() != pattern_node_inputs.size()) { + if (keys(graph_node_inputs) != keys(pattern_node_inputs)) { return std::nullopt; } diff --git a/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h index 6201845a64..d8160e2f6a 100644 --- a/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h +++ b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h @@ -22,7 +22,7 @@ std::unordered_multiset> std::unordered_multiset> result; if (container.empty() || n == 0) { - return result; + return {{}}; } std::vector elements(std::begin(container), std::end(container)); diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 1690955286..084e55ea8a 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -7,8 +7,10 @@ namespace fmt { template -struct formatter, Char> - /* std::enable_if_t>::value>> */ +struct formatter< + ::std::variant, + Char, + std::enable_if_t>::value>> : formatter<::std::string> { template auto format(std::variant const &m, FormatContext &ctx) const @@ -25,8 +27,8 @@ struct formatter, Char> namespace FlexFlow { -template -std::ostream &operator<<(std::ostream &s, std::variant const &v) { +template +std::ostream &operator<<(std::ostream &s, std::variant const &v) { return s << fmt::to_string(v); } diff --git a/lib/utils/include/utils/full_binary_tree/as_dot.h b/lib/utils/include/utils/full_binary_tree/as_dot.h new file mode 100644 index 0000000000..e104d05e06 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/as_dot.h @@ -0,0 +1,81 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_AS_DOT_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_AS_DOT_H + +#include "utils/containers/get_only.h" +#include "utils/dot_file.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include +#include +#include + +namespace FlexFlow { + +template +LabelledDataflowGraph as_labelled_dataflow_graph( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + std::function const &get_parent_label, + std::function const &get_leaf_label) { + auto g = LabelledDataflowGraph::template create< + UnorderedSetLabelledOpenDataflowGraph>(); + + FullBinaryTreeVisitor visitor = + FullBinaryTreeVisitor{ + [&](Parent const &parent) -> DataflowOutput { + DataflowOutput left_child_output = + visit(impl.get_left_child(parent), impl, visitor); + DataflowOutput right_child_output = + visit(impl.get_right_child(parent), impl, visitor); + NodeLabel parent_label = get_parent_label(parent); + NodeAddedResult parent_added = + g.add_node(parent_label, + {left_child_output, right_child_output}, + {std::monostate{}}); + return get_only(parent_added.outputs); + }, + [&](Leaf const &leaf) -> DataflowOutput { + NodeLabel leaf_label = get_leaf_label(leaf); + NodeAddedResult leaf_added = + g.add_node(leaf_label, {}, {std::monostate{}}); + return 
get_only(leaf_added.outputs); + }, + }; + + visit(tree, impl, visitor); + + return g; +} + +template +std::string + as_dot(Tree const &tree, + FullBinaryTreeImplementation const &impl, + std::function const &get_parent_label, + std::function const &get_leaf_label) { + + LabelledDataflowGraphView g = + as_labelled_dataflow_graph(tree, impl, get_parent_label, get_leaf_label); + + std::function get_node_label = + [](std::string const &s) { return s; }; + std::function get_input_label = + [](std::monostate const &) { return ""; }; + + return as_dot( + view_as_labelled_open_dataflow_graph(g), get_node_label, get_input_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index e49bad7cbf..34b77f4d37 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -1,12 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include 
"utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include "utils/nonnegative_int/nonnegative_int.h" +#include #include namespace FlexFlow { @@ -26,6 +28,10 @@ SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); nonnegative_int get_tree_height(BinarySPDecompositionTree const &); +std::optional + binary_sp_decomposition_tree_get_subtree_at_path( + BinarySPDecompositionTree const &, BinaryTreePath const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h new file mode 100644 index 0000000000..9c999d8f6e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_AS_DOT_H + +#include "utils/full_binary_tree/as_dot.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +std::string as_dot( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + std::function const &get_series_label, + std::function const &get_parallel_label, + std::function const &get_leaf_label) { + FullBinaryTreeImplementation, Leaf> + full_binary_tree_impl = 
get_full_binary_impl_from_generic_sp_impl(impl); + + std::function const &)> + get_parent_label = + [&](std::variant const &parent) -> std::string { + return std::visit(overload{ + [&](Series const &series) -> std::string { + return get_series_label(series); + }, + [&](Parallel const ¶llel) -> std::string { + return get_parallel_label(parallel); + }, + }, + parent); + }; + + return as_dot(tree, full_binary_tree_impl, get_parent_label, get_leaf_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index cec3734907..5fb33e79c9 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -32,6 +32,12 @@ T const &assert_unwrap(std::optional const &o) { return o.value(); } +template +T expect(std::optional const &x, std::string const &err) { + ASSERT(x.has_value(), err); + return x.value(); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/units/milliseconds_t.h b/lib/utils/include/utils/units/milliseconds_t.h index ed3d5776a3..9eb6e3abf8 100644 --- a/lib/utils/include/utils/units/milliseconds_t.h +++ b/lib/utils/include/utils/units/milliseconds_t.h @@ -22,6 +22,9 @@ struct milliseconds_t { milliseconds_t operator+(milliseconds_t const &other) const; + milliseconds_t operator*(float rhs) const; + friend milliseconds_t operator*(float lhs, milliseconds_t const &rhs); + float unwrap_milliseconds() const; private: diff --git a/lib/utils/src/utils/fmt/variant.cc b/lib/utils/src/utils/fmt/variant.cc index e2d387eedb..f3b9094fbe 100644 --- a/lib/utils/src/utils/fmt/variant.cc +++ b/lib/utils/src/utils/fmt/variant.cc @@ -1 +1,10 @@ #include "utils/fmt/variant.h" + +namespace FlexFlow { + +template std::ostream &operator<<(std::ostream &, std::variant const &); + +template std::ostream &operator<<(std::ostream &, + std::variant const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/as_dot.cc 
b/lib/utils/src/utils/full_binary_tree/as_dot.cc new file mode 100644 index 0000000000..12a1ab5533 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/as_dot.cc @@ -0,0 +1,16 @@ +#include "utils/full_binary_tree/as_dot.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::string + as_dot(Tree const &, + FullBinaryTreeImplementation const &, + std::function const &, + std::function const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 75a8bcf1ad..c11968e8b9 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,8 +1,10 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_tree_height.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + namespace FlexFlow { GenericBinarySPDecompositionTreeImplementation + binary_sp_decomposition_tree_get_subtree_at_path( + BinarySPDecompositionTree const &tree, BinaryTreePath const &path) { + return 
get_subtree_at_path(tree, generic_impl_for_binary_sp_tree(), path); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.cc new file mode 100644 index 0000000000..f557515c83 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.cc @@ -0,0 +1,21 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/as_dot.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::string + as_dot(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + std::function const &, + std::function const &, + std::function const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/units/milliseconds_t.cc b/lib/utils/src/utils/units/milliseconds_t.cc index fb0dd01d64..fc0e977d6b 100644 --- a/lib/utils/src/utils/units/milliseconds_t.cc +++ b/lib/utils/src/utils/units/milliseconds_t.cc @@ -38,6 +38,14 @@ milliseconds_t milliseconds_t::operator+(milliseconds_t const &other) const { }; } +milliseconds_t milliseconds_t::operator*(float rhs) const { + return milliseconds_t{this->value * rhs}; +} + +milliseconds_t operator*(float lhs, milliseconds_t const &rhs) { + return milliseconds_t{lhs * rhs.value}; +} + float milliseconds_t::unwrap_milliseconds() const { return this->value; } diff --git a/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc b/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc index 9fb4048691..3ec51ec2f6 100644 --- a/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc +++ 
b/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc @@ -71,5 +71,15 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + + SUBCASE("n == 0") { + std::vector input = {1, 2, 3}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 0_n); + std::unordered_multiset> correct = {{}}; + + CHECK(result == correct); + } } }