From 2be849a61a3ceaab3da67f2fb4f9dd14f79fedff Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 9 Oct 2024 21:18:53 -0700 Subject: [PATCH 01/16] Graph Testing initial cleanup --- lib/utils/include/utils/graph/README.md | 168 ++++------- .../graph/dataflow_graph/dataflow_graph.h | 2 - .../graph/digraph/algorithms/get_dominators.h | 15 + .../include/utils/graph/digraph/digraph.h | 2 - .../utils/graph/digraph/digraph_view.h | 2 - .../instances/hashmap_undirected_graph.h | 0 .../utils/graph/multidigraph/multidigraph.h | 2 - lib/utils/include/utils/graph/node/graph.h | 2 - .../include/utils/graph/node/graph_view.h | 2 - .../include/utils/graph/node/node.struct.toml | 1 + .../utils/graph/undirected/undirected_edge.h | 27 +- .../undirected/undirected_edge.struct.toml | 18 ++ .../utils/graph/undirected/undirected_graph.h | 2 - .../graph/undirected/undirected_graph_view.h | 2 - lib/utils/src/utils/graph/algorithms.cc | 5 +- .../instances/hashmap_undirected_graph.cc | 26 +- .../unordered_set_undirected_graph.cc | 4 +- .../algorithms/get_neighboring_nodes.cc | 2 +- .../utils/graph/undirected/undirected_edge.cc | 33 +-- .../graph/undirected/undirected_edge_query.cc | 3 +- lib/utils/src/utils/graph/views/views.cc | 35 ++- .../graph/digraph/algorithms/algorithms.cc | 106 +++++++ .../utils/graph/digraph/algorithms/digraph.cc | 85 ++++++ .../digraph/algorithms/directed_edge_query.cc | 70 +++++ .../digraph/algorithms/get_dominators.cc | 68 +++++ .../algorithms/get_topological_ordering.cc | 36 +++ .../graph/digraph/algorithms/traversal.cc | 112 ++++++++ .../algorithms/get_incoming_edges.cc | 36 +++ .../algorithms/get_outgoing_edges.cc | 40 +++ .../algorithms/get_connected_components.cc | 26 ++ .../src/utils/graph/undirected/undirected.cc | 75 +++++ lib/utils/test/src/utils/graph/views/views.cc | 262 ++++++++++++++++++ 32 files changed, 1049 insertions(+), 220 deletions(-) rename lib/utils/{src => include}/utils/graph/instances/hashmap_undirected_graph.h (100%) create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc create mode 100644 lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc create mode 100644 lib/utils/test/src/utils/graph/undirected/undirected.cc create mode 100644 lib/utils/test/src/utils/graph/views/views.cc diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 25b0103f9c..f3d31e7bc8 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -15,10 +15,16 @@ There is no single type of graph. Should it be directed? Allow multiple edges be Because there is no single answer to this question, similar to [networkx](https://networkx.org/) we provide a number of different graph variants. At their core, they are as follows: -- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected -- `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. - +- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. +- `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) +- `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. +- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: + - The edges entering, exiting a given nodes now have a well-defined order. + - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. + - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. + +Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. + Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -37,7 +43,7 @@ flowchart TD D --- B ``` -Example of `DirectedGraph`: +Example of `DiGraph`: ```mermaid flowchart TD A(" ") @@ -58,98 +64,34 @@ flowchart TD Example of `MultiDiGraph`: ```mermaid flowchart TD - A("A") - B("B") - C("C") - D("D") - E("E") - F("F") - - A -->|"(■, ★)"| B - B -->|"(●, ★)"| C - C -->|"(♥, ▲)"| D - D -->|"(●, ■)"| A - B -->|"(★, ●)"| E - E -->|"(■, ■)"| B - D -->|"(●, ●)"| A - A -->|"(●, ■)"| E - D -->|"(■, ●)"| D - E -->|"(■, ■)"| E -``` -or visualized a different way, -```mermaid -flowchart TD - Acirc("●") - Asqua("■") - Bcirc("●") - Bstar("★") - Bsqua("■") - Chear("♥") - Cstar("★") - Dsqua("■") - Dcirc("●") - Dtria("▲") - Ecirc("●") - Esqua("■") - Fplaceholder(" ") - - style Fplaceholder fill:#0000,stroke:#0000 - - subgraph "A" - Acirc - Asqua - end - - subgraph "B" - Bsqua - Bcirc - Bstar - end - - subgraph "C" - Chear - Cstar - end - - subgraph "D" - Dsqua - Dcirc - Dtria - end - - subgraph "E" - Ecirc - Esqua - end - - subgraph "F" - Fplaceholder - end - - Asqua --> Bstar - Bcirc --> Cstar - Chear --> Dtria - Dcirc --> Asqua - Bstar --> Ecirc - Esqua --> Bsqua - Dcirc --> Acirc - Acirc --> Esqua - Dsqua --> Dcirc - Esqua --> Esqua + A + B + C + D + E + F + + A --> B + B --> C + C --> D + D --> A + B --> E + E --> B + D --> A + A --> E + D --> D + E --> E ``` -Note that the nodes and source/destination indices are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. -This is the case as well with `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`. +Note that the node names are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. +This is the case with all of the 4 core graph classes. Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. -`MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. -Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. -`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. +In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -158,11 +100,13 @@ The argument to `query_nodes` is a `NodeQuery` (which is simply a set of `Node`s The set of nodes in the query is actually an `optional`, so `nullopt` could also be passed, which would simply retrieve all nodes from the target graph (essentially `nullopt` acts as the set of all nodes that could ever exist). `query_edges` functions similarly, but as with `add_edge` its behavior is differs slightly between the three graph variants. `UndirectedGraph::query_edges` simply takes an optional set of nodes and returns all edges that touch any of those nodes. -`DirectedGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. +`DiGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. In practice you will rarely ever use `query_nodes` and `query_edges` as the graph library provides a large number of algorithms that do that work for you, but it can be helpful to understand this base layer if you ever need to implement your own algorithms. -The layer users will most commonly interact with is the interface provided by [algorithms.h](./algorithms.h), which provides a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. -You may notice that the most of the functions declared in `algorithms.h` take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but actually operator on `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. +The layer users will most commonly interact with is the interface provided within either the `algorithms.h` header files or the `algorithms` folders, present in their respective graph class folders. +They provide a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. +Note that, due to the internal virtual inheritance structure, some functions for more privitive classes can be employed by the derived classes. (For example, `get_nodes` present in `node/algorithms.h` can be used by `DiGraph`). +You may notice that the most of algorithms present take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but rather `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. These `GraphView` objects represent read-only (i.e., immutable) graphs. Similar to C++'s `const` semantics, `Graph`s can be coerced[^2] to `GraphView`s but not the other way around. To transform a `GraphView` to a `Graph`, we can perform an explicit copy with `materialize_view`. @@ -171,40 +115,35 @@ This may seem wasteful (oftentimes graphs are large objects that are passed arou At this point, however, we still have not discussed how to create a graph. The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGraph` which internally uses a representation `MyDiGraphImpl`: +For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp -DiGraph g = DiGraph::create(); +DiGraph g = DiGraph::create(); ``` Generally users will use underlying representations provided by the graph library, but advanced users can create their own implementations (see the [Internals](#internals) section). [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open, Upward, Downward +### Open DataFlow Variant `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. -We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). - -![Open graphs inheritance diagram](docs/open.svg) +This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. -Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info, see [cow_ptr](#cow_ptr-and-interfaces)) - - -### Labelled Graphs +### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels via [labelled\_graphs.h](./labelled_graphs.h): examples include `NodeLabelledMultiDiGraph` (nodes have labels of type `T` and edges are unlabelled) and `OutputLabelledMultiDiGraph` (nodes have labels of type `T` and source indices have labels of type `U`). -While the interfaces of these graphs differ slightly from the core graph variants, they still have corresponding `GraphView` types, `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. -Note that all of the labelled graph types require that each element of the labelled types have a label (e.g., every node in a `NodeLabelledMultiDiGraph` must have a label of type `T`)., which is enforced via the interfaces they provide. +Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`, which allow users to label different components of the graph. +- `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. +- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). + +While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) +Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. -Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes (or other types depending in which labelled graph type is used) to labels. -As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants for use in functions provided by `algorithms.h`, etc. +Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes/edges to labels. +As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants. [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. -![Labelled Graphs Inheritance Diagram](docs/labelled.svg) - - ## Internals @@ -236,12 +175,7 @@ To address this, graph classes store a `cow_ptr` as a member variable, which poi All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. -To create graphs within the library, we thus use the following syntax: -`BaseGraph obj = BaseGraph::create();` - -Resulting in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph` - ### Virtual Inheritance -Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, which in turn inherit from both `NodeLabelledMultiDiGraphView`). -In the case of a diamond inheritance pattern C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +Due to the complexity of the graph library, diamond-style inheritance patterns emerge. +In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 6a1898dd13..d73175c7dd 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -42,8 +42,6 @@ struct DataflowGraph : virtual public DataflowGraphView { private: IDataflowGraph &get_interface(); IDataflowGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index 1e4d09d3ae..96e8864bc1 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -5,7 +5,22 @@ namespace FlexFlow { +/** + * @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory) + * + * @note By definition, the root node dominates every node and every node + * dominates itself. + * + */ std::unordered_set get_dominators(DiGraphView const &, Node const &); + +/** + * @brief Returns the intersection of the dominators of the given set of nodes. + * @note This is conceptually equivalent to merging the given set of nodes and + * then finding the set of dominators of the new merged node (where merged means + * that all edges belonging to the set of nodes now pass through a single + * unified node). + */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/digraph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h index e36b90d4bf..3d320b1c06 100644 --- a/lib/utils/include/utils/graph/digraph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView { private: IDiGraph &get_ptr(); IDiGraph const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h index 54f84f8d2c..0380751c55 100644 --- a/lib/utils/include/utils/graph/digraph/digraph_view.h +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView { private: IDiGraphView const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h b/lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h similarity index 100% rename from lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h rename to lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h index 69080b9348..692ee33783 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView { private: IMultiDiGraph &get_interface(); IMultiDiGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h index bddefdacb3..1d94d1a65e 100644 --- a/lib/utils/include/utils/graph/node/graph.h +++ b/lib/utils/include/utils/graph/node/graph.h @@ -31,8 +31,6 @@ struct Graph : virtual GraphView { private: IGraph const &get_ptr() const; IGraph &get_ptr(); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index fce3177ef1..8d904e05f2 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -22,8 +22,6 @@ struct GraphView { GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index d5c22e5d3d..46e0255de3 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -6,6 +6,7 @@ features = [ "hash", "fmt", "json", + "rapidcheck", ] includes = [ diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index 33d50192cb..d051413faa 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -2,33 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H #include "utils/graph/node/node.dtg.h" -namespace FlexFlow { - -struct UndirectedEdge { -public: - UndirectedEdge() = delete; - UndirectedEdge(Node const &src, Node const &dst); +#include "utils/graph/undirected/undirected_edge.dtg.h" - bool operator==(UndirectedEdge const &) const; - bool operator!=(UndirectedEdge const &) const; - bool operator<(UndirectedEdge const &) const; - -public: - Node smaller; - Node bigger; -}; +namespace FlexFlow { -bool is_connected_to(UndirectedEdge const &, Node const &); +bool is_connected_to(UndirectedEdge const &e, Node const &n); } // namespace FlexFlow -namespace std { - -template <> -struct hash<::FlexFlow::UndirectedEdge> { - size_t operator()(::FlexFlow::UndirectedEdge const &) const; -}; - -} // namespace std - #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml new file mode 100644 index 0000000000..f5258b0bfd --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "UndirectedEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck" +] + +includes = [ + "utils/commutative_pair.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "endpoints" +type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h index 69975991ce..09b6495699 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -34,8 +34,6 @@ struct UndirectedGraph : virtual UndirectedGraphView { using UndirectedGraphView::UndirectedGraphView; - friend struct GraphInternal; - private: IUndirectedGraph const &get_ptr() const; IUndirectedGraph &get_ptr(); diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h index c2df96abc0..90dd5dd5d8 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView { using GraphView::GraphView; - friend struct GraphInternal; - private: IUndirectedGraphView const &get_ptr() const; }; diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 6ed41daf43..79c4fc9964 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -184,7 +184,8 @@ bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; + UndirectedEdgeQuery q = + UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; return contains(g.query_edges(q), e); } @@ -212,7 +213,7 @@ void remove_edges(UndirectedGraph &g, } std::unordered_set get_endpoints(UndirectedEdge const &e) { - return {e.smaller, e.bigger}; + return {e.endpoints.min(), e.endpoints.max()}; } // std::unordered_set get_edges(MultiDiGraphView const &g) { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 61c4f80763..df84683a6b 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -22,23 +22,25 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { } void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { - if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.bigger)); + if (!contains_key(this->adjacency, e.endpoints.max())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.max())); } - if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.smaller)); + if (!contains_key(this->adjacency, e.endpoints.min())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.min())); } - this->adjacency.at(e.bigger).insert(e.smaller); - this->adjacency.at(e.smaller).insert(e.bigger); + this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min()); + this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max()); } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.bigger); - m.erase(e.smaller); - m.erase(e.bigger); + std::unordered_set &m = this->adjacency.at(e.endpoints.max()); + m.erase(e.endpoints.min()); + m.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( @@ -46,7 +48,7 @@ std::unordered_set HashmapUndirectedGraph::query_edges( std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { for (auto const &dst : src_kv.second) { - result.insert({src_kv.first, dst}); + result.insert(UndirectedEdge{{src_kv.first, dst}}); } } return result; diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc index 6f6722f635..cb44f4636d 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -27,8 +27,8 @@ void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { } void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { - assert(contains(this->nodes, e.bigger)); - assert(contains(this->nodes, e.smaller)); + assert(contains(this->nodes, e.endpoints.min())); + assert(contains(this->nodes, e.endpoints.max())); this->edges.insert(e); } diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc index 3c05b9d5d5..726fda8af7 100644 --- a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -10,7 +10,7 @@ std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, std::unordered_set result = set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { - return std::unordered_set{e.bigger, e.smaller}; + return std::unordered_set{e.endpoints.max(), e.endpoints.max()}; })); result.erase(n); return result; diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 0a575e115c..4cfc6aaaa8 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,40 +1,11 @@ #include "utils/graph/undirected/undirected_edge.h" #include "utils/hash/tuple.h" +#include namespace FlexFlow { -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -static std::tuple tie(UndirectedEdge const &e) { - return std::tie(e.smaller, e.bigger); -} - -bool UndirectedEdge::operator==(UndirectedEdge const &other) const { - return tie(*this) == tie(other); -} - -bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { - return tie(*this) != tie(other); -} - -bool UndirectedEdge::operator<(UndirectedEdge const &other) const { - return tie(*this) < tie(other); -} - bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; + return e.endpoints.min() == n || e.endpoints.max() == n; } } // namespace FlexFlow - -namespace std { - -using namespace FlexFlow; - -size_t hash::operator()(UndirectedEdge const &e) const { - std::tuple members = ::FlexFlow::tie(e); - return std::hash{}(members); -} - -} // namespace std diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 3cccf1c6eb..e9e948aa40 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -7,7 +7,8 @@ UndirectedEdgeQuery undirected_edge_query_all() { } bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { - return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); + return includes(q.nodes, e.endpoints.max()) && + includes(q.nodes, e.endpoints.min()); } UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 9b5353de9f..c29d478f1e 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,4 +1,5 @@ #include "utils/graph/views/views.h" +#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" #include "utils/disjoint_set.h" @@ -7,8 +8,8 @@ #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" +#include "utils/graph/query_set.h" #include "utils/graph/undirected/undirected_edge_query.h" - namespace FlexFlow { UndirectedSubgraphView::UndirectedSubgraphView( @@ -78,9 +79,13 @@ JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { std::unordered_set JoinedNodeView::query_nodes(NodeQuery const &query) const { - // TODO @lockshaw this is going to be reimplemented in 984, so don't bother - // fixing it for now - NOT_IMPLEMENTED(); + std::unordered_set nodes = right_entries(this->mapping); + if (query == node_query_all()) { + return nodes; + } + return filter(nodes, [&](Node const &n) { + return contains(allowed_values(query.nodes), n); + }); } std::pair, std::unordered_set> @@ -146,17 +151,18 @@ std::unordered_set JoinedUndirectedGraphView::query_edges( UndirectedEdge JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return { - this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; + return UndirectedEdge{{this->joined_nodes.at_join_key( + JoinNodeKey{e.endpoints.min(), LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.max(), LRDirection::LEFT})}}; } UndirectedEdge JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key( - JoinNodeKey{e.smaller, LRDirection::RIGHT}), - this->joined_nodes.at_join_key( - JoinNodeKey{e.bigger, LRDirection::RIGHT})}; + return UndirectedEdge{{this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.min(), LRDirection::RIGHT}), + this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.max(), LRDirection::RIGHT})}}; } JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, @@ -208,7 +214,7 @@ DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return {e.src, e.dst}; + return UndirectedEdge{{e.src, e.dst}}; } std::unordered_set to_undirected_edges( @@ -218,8 +224,9 @@ std::unordered_set to_undirected_edges( } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return std::unordered_set{DirectedEdge{e.smaller, e.bigger}, - DirectedEdge{e.bigger, e.smaller}}; + return std::unordered_set{ + DirectedEdge{e.endpoints.min(), e.endpoints.max()}, + DirectedEdge{e.endpoints.max(), e.endpoints.min()}}; } std::unordered_set to_directed_edges( diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc new file mode 100644 index 0000000000..0817c69e06 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -0,0 +1,106 @@ +#include "utils/graph/digraph/algorithms.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DiGraph - algorithms.cc") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + }; + add_edges(g, e); + + SUBCASE("get_edges") { + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[3], n[1]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n[0], n[3]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + } + + SUBCASE("get_sinks") { + SUBCASE("Base") { + std::unordered_set correct = {n[2], n[3]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a sink") { + g.add_edge(DirectedEdge{n[3], n[2]}); + std::unordered_set correct = {n[2]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sinks(g); + std::unordered_set correct = {n[3]}; + CHECK(result == correct); + } + } + + SUBCASE("get_sources") { + SUBCASE("Base") { + std::unordered_set correct = {n[0]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set correct = {}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n[0], n[1]}); + std::unordered_set correct = {n[0], n[1]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sources(g); + std::unordered_set correct = {}; + CHECK(result.empty()); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc new file mode 100644 index 0000000000..3a3648eec8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc @@ -0,0 +1,85 @@ +#include "utils/graph/digraph/digraph.h" +#include "utils/containers/repeat.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/node_query.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DiGraph implementations", T, AdjacencyDiGraph) { + /* + graph TD + + n0 --> n1 + n0 --> n2 + n1 --> n2 + n2 --> n4 + n1 --> n3 + */ + + DiGraph g = DiGraph::create(); + std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector e = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[1], n[3]}}; + for (DirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); + + std::unordered_set queried_edges = + g.query_edges(directed_edge_query_all()); + std::unordered_set expected = { + e[0], e[1], e[2], e[3], e[4]}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges( + DirectedEdgeQuery{query_set{{n[0]}}, query_set{{n[1]}}}); + expected = std::unordered_set{e[0]}; + CHECK(queried_edges == expected); + } + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n[0]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[1], n[2], n[3], n[4]}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[3], e[4]}); + + g.remove_node_unsafe(n[1]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[2], n[3], n[4]}); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[3]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[1], e[2], e[3], e[4]}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + g.remove_edge(e[1]); + g.remove_edge(e[3]); + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[4]}); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc new file mode 100644 index 0000000000..1dde5c8f69 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc @@ -0,0 +1,70 @@ +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("directed_edge_query") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}}); + + SUBCASE("directed_edge_query_all") { + + DirectedEdgeQuery result = directed_edge_query_all(); + + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); + } + + SUBCASE("matches_edge") { + DirectedEdgeQuery q = + DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; + + CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); + CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); + } + + SUBCASE("query_intersection") { + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = DirectedEdgeQuery{ + query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; + DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, + query_set{n.at(2), n.at(3)}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1)}, + query_set{n.at(2)}, + }; + + CHECK(result == correct); + } + SUBCASE("intersection with std::nullopt") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc new file mode 100644 index 0000000000..e9151b53e5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -0,0 +1,68 @@ +#include "utils/graph/digraph/algorithms/get_dominators.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dominators") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("single node") { + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); + } + + SUBCASE("multiple nodes") { + std::unordered_set nodes = {n.at(1), n.at(3)}; + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n.at(0)}; + CHECK(correct == result); + } + + SUBCASE("graph with cycles") { + // example from + // https://en.wikipedia.org/w/index.php?title=Dominator_(graph_theory)&oldid=1189814332 + + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 6); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(1)}, + }); + + SUBCASE("node 1") { + std::unordered_set result = get_dominators(g, n.at(1)); + std::unordered_set correct = {n.at(0), n.at(1)}; + CHECK(result == correct); + } + + SUBCASE("node 3") { + std::unordered_set result = get_dominators(g, n.at(3)); + std::unordered_set correct = {n.at(0), n.at(1), n.at(3)}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc new file mode 100644 index 0000000000..de6953fad4 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -0,0 +1,36 @@ +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/containers.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc new file mode 100644 index 0000000000..0d8e7ca53a --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc @@ -0,0 +1,112 @@ +#include "utils/graph/traversal.h" +#include "utils/fmt/vector.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/hash/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_unchecked_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + } + + TEST_CASE("get_bfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}}); + + SUBCASE("branching path") { + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3], n[4], n[5]}, + {n[0], n[2], n[1], n[3], n[4], n[5]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("isolated node") { + std::vector correct = {n[5]}; + std::vector result = get_bfs_ordering(g, {n[5]}); + CHECK(correct == result); + } + + SUBCASE("graph with cycle") { + g = DiGraph::create(); + n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[0]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[0]}, + DirectedEdge{n[2], n[1]}}); + std::unordered_set> corrects = {{n[0], n[1], n[2]}, + {n[0], n[2], n[1]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + } + + TEST_CASE("get_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("with cycle") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("branching") { + g.add_edge(DirectedEdge{n[1], n[3]}); + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("disconnected") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[0], n[1], n[2]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("isolated node") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[3]}; + std::vector result = get_dfs_ordering(g, {n[3]}); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..b5943cd99f --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -0,0 +1,36 @@ +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector edges = add_edges(g, + {{n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}}); + + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc new file mode 100644 index 0000000000..d4748e8422 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -0,0 +1,40 @@ +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector> input = { + {n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}, + }; + + std::vector edges = add_edges(g, input); + + SUBCASE("node has outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc new file mode 100644 index 0000000000..7f6f0dd064 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -0,0 +1,26 @@ +#include "utils/graph/undirected/algorithms/get_connected_components.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/undirected/undirected_graph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); + + std::unordered_set> correct = { + {n[0], n[1], n[2]}, + {n[3]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc new file mode 100644 index 0000000000..7973cf8af5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -0,0 +1,75 @@ +#include "test/utils/rapidcheck.h" +#include "test/utils/rapidcheck/visitable.h" +#include "utils/commutative_pair.h" +#include "utils/containers/repeat.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" +#include "utils/graph/undirected/undirected_graph.h" + +using namespace FlexFlow; + +using namespace rc; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + RC_SUBCASE("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); + }); + } +} +/* static_assert(is_fmtable::value, ""); */ + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + RC_SUBCASE("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); + }); + } +} diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc new file mode 100644 index 0000000000..58f2e35cb5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -0,0 +1,262 @@ +#include "utils/graph/views/views.h" +#include "utils/containers/set_union.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/undirected/undirected_graph_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + // TEST_CASE("UndirectedSubgraphView") { + // UndirectedGraph g = UndirectedGraph::create(); + // std::vector n = add_nodes(g, 5); + // add_edges(g, + // {UndirectedEdge{{n.at(0), n.at(3)}}, + // UndirectedEdge{{n.at(1), n.at(1)}}, + // UndirectedEdge{{n.at(1), n.at(2)}}, + // UndirectedEdge{{n.at(1), n.at(3)}}, + // UndirectedEdge{{n.at(2), n.at(3)}}, + // UndirectedEdge{{n.at(2), n.at(4)}}}); + // std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + // UndirectedGraphView view = get_subgraph(g, sub_nodes); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // UndirectedEdge{{n.at(0), n.at(3)}}, + // UndirectedEdge{{n.at(1), n.at(1)}}, + // UndirectedEdge{{n.at(1), n.at(3)}}, + // }; + + // std::unordered_set result = get_edges(view); + + // // TODO(@pietro) TODO(@lockshaw) current BUG, get_edges also + // CHECK(result == expected); + // } + // } + + TEST_CASE("DiSubgraphView") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + DiGraphView view = get_subgraph(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("JoinedNodeView") { + UndirectedGraph g1 = UndirectedGraph::create(); + UndirectedGraph g2 = UndirectedGraph::create(); + + std::vector n1 = add_nodes(g1, 3); + std::vector n2 = add_nodes(g2, 2); + std::unordered_set joined_nodes = + set_union(unordered_set_of(n1), unordered_set_of(n2)); + add_edges(g1, + {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], n1[2]}}}); + add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); + + JoinedNodeView joined_view(g1, g2); + + SUBCASE("trace_nodes") { + std::pair, std::unordered_set> result = + joined_view.trace_nodes(joined_nodes); + std::pair, std::unordered_set> correct = { + {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; + + CHECK(result == correct); + } + + SUBCASE("query_nodes") { + SUBCASE("matchall") {} + SUBCASE("subset") {} + } + } + + // TEST_CASE("JoinedUndirectedGraphView") { + // UndirectedGraph g1 = UndirectedGraph::create(); + // UndirectedGraph g2 = UndirectedGraph::create(); + + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 3); + + // add_edges(g1, + // {UndirectedEdge{{n1.at(0), n1.at(1)}}, + // UndirectedEdge{{n1.at(1), n1.at(2)}}}); + // add_edges(g2, + // {UndirectedEdge{{n2.at(0), n2.at(2)}}, + // UndirectedEdge{{n2.at(1), n2.at(2)}}}); + + // UndirectedGraphView view = join(g1, g2); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // UndirectedEdge{{n1.at(0), n1.at(1)}}, + // UndirectedEdge{{n1.at(1), n1.at(2)}}, + // UndirectedEdge{{n2.at(0), n2.at(2)}}, + // UndirectedEdge{{n2.at(1), n2.at(2)}}}; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + // TEST_CASE("JoinedDigraphView") { + // DiGraph g1 = DiGraph::create(); + // DiGraph g2 = DiGraph::create(); + + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 3); + + // add_edges( + // g1, + // {DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), + // n1.at(2)}}); + // add_edges( + // g2, + // {DirectedEdge{n2.at(0), n2.at(2)}, DirectedEdge{n2.at(1), + // n2.at(2)}}); + + // DiGraphView view = join(g1, g2); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // DirectedEdge{n1.at(0), n1.at(1)}, + // DirectedEdge{n1.at(1), n1.at(2)}, + // DirectedEdge{n2.at(0), n2.at(2)}, + // DirectedEdge{n2.at(1), n2.at(2)}}; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + TEST_CASE("ViewDiGraphAsUndirectedGraph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}); + + UndirectedGraphView view = as_undirected(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("ViewUndirectedGraphAsDiGraph") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(0)}}, + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}); + + DiGraphView view = as_digraph(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(0)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } +} From 0e8c9625195c322b29960a1fa074191bf6dfbc2d Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 9 Oct 2024 21:21:57 -0700 Subject: [PATCH 02/16] fmt --- lib/utils/test/src/utils/graph/views/views.cc | 56 ++++++++++--------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index 58f2e35cb5..fb5374bd6a 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -87,34 +87,36 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("JoinedNodeView") { - UndirectedGraph g1 = UndirectedGraph::create(); - UndirectedGraph g2 = UndirectedGraph::create(); - - std::vector n1 = add_nodes(g1, 3); - std::vector n2 = add_nodes(g2, 2); - std::unordered_set joined_nodes = - set_union(unordered_set_of(n1), unordered_set_of(n2)); - add_edges(g1, - {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], n1[2]}}}); - add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); - - JoinedNodeView joined_view(g1, g2); - - SUBCASE("trace_nodes") { - std::pair, std::unordered_set> result = - joined_view.trace_nodes(joined_nodes); - std::pair, std::unordered_set> correct = { - {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; - - CHECK(result == correct); - } + // TEST_CASE("JoinedNodeView") { + // UndirectedGraph g1 = UndirectedGraph::create(); + // UndirectedGraph g2 = UndirectedGraph::create(); - SUBCASE("query_nodes") { - SUBCASE("matchall") {} - SUBCASE("subset") {} - } - } + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 2); + // std::unordered_set joined_nodes = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + // add_edges(g1, + // {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], + // n1[2]}}}); + // add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); + + // JoinedNodeView joined_view(g1, g2); + + // SUBCASE("trace_nodes") { + // std::pair, std::unordered_set> result = + // joined_view.trace_nodes(joined_nodes); + // std::pair, std::unordered_set> correct = + // { + // {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; + + // CHECK(result == correct); + // } + + // SUBCASE("query_nodes") { + // SUBCASE("matchall") {} + // SUBCASE("subset") {} + // } + // } // TEST_CASE("JoinedUndirectedGraphView") { // UndirectedGraph g1 = UndirectedGraph::create(); From 8e11b0b2e0db8777bfc75f526fa57b8945a8575f Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 17:09:00 -0700 Subject: [PATCH 03/16] removed unneccesary views, fixed views adjacent bugs --- lib/utils/include/utils/graph/views/views.h | 126 +------------ lib/utils/src/utils/graph/algorithms.cc | 9 - .../instances/hashmap_undirected_graph.cc | 9 +- lib/utils/src/utils/graph/views/views.cc | 155 +--------------- .../algorithms/get_connected_components.cc | 41 ++++- lib/utils/test/src/utils/graph/views/views.cc | 167 +++--------------- 6 files changed, 77 insertions(+), 430 deletions(-) diff --git a/lib/utils/include/utils/graph/views/views.h b/lib/utils/include/utils/graph/views/views.h index aaa1e033f4..5e0109ed5b 100644 --- a/lib/utils/include/utils/graph/views/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -41,104 +41,11 @@ struct DiSubgraphView : public IDiGraphView { std::unordered_set subgraph_nodes; }; -struct JoinedNodeView { -public: - JoinedNodeView() = delete; - explicit JoinedNodeView(GraphView const &lhs, GraphView const &rhs); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::pair, std::unordered_set> - trace_nodes(std::unordered_set const &) const; - - Node at_join_key(JoinNodeKey const &) const; - JoinNodeKey at_node(Node const &) const; - -private: - bidict mapping; - NodeSource node_source; -}; - -struct JoinedUndirectedGraphView : public IUndirectedGraphView { -public: - JoinedUndirectedGraphView() = delete; - explicit JoinedUndirectedGraphView(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs); - - std::unordered_set - query_edges(UndirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedUndirectedGraphView *clone() const override; - -private: - UndirectedEdge fix_lhs_edge(UndirectedEdge const &) const; - UndirectedEdge fix_rhs_edge(UndirectedEdge const &) const; - -private: - UndirectedGraphView lhs; - UndirectedGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct JoinedDigraphView : virtual public IDiGraphView { -public: - JoinedDigraphView() = delete; - explicit JoinedDigraphView(DiGraphView const &lhs, DiGraphView const &rhs); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedNodeView const &joined_nodes_view() const; +UndirectedGraphView view_subgraph(UndirectedGraphView const &, + std::unordered_set const &); - JoinedDigraphView *clone() const override; - -private: - DirectedEdge fix_lhs_edge(DirectedEdge const &) const; - DirectedEdge fix_rhs_edge(DirectedEdge const &) const; - -private: - DiGraphView lhs; - DiGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct AddDirectedEdgesView : public IDiGraphView { -public: - AddDirectedEdgesView() = delete; - - explicit AddDirectedEdgesView(DiGraphView const &g, - std::unordered_set const &edges); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - AddDirectedEdgesView *clone() const override; - -private: - DiGraphView g; - std::unordered_set edges; -}; - -struct SingleSourceNodeView : public IDiGraphView { -public: - SingleSourceNodeView() = delete; - - explicit SingleSourceNodeView(DiGraphView const &g) : g(g) {} - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - SingleSourceNodeView *clone() const override; - -private: - DiGraphView g; - std::optional singleton_src; - std::optional joined_view; - std::unique_ptr added_edges_view; -}; +DiGraphView view_subgraph(DiGraphView const &, + std::unordered_set const &); UndirectedEdge to_undirected_edge(DirectedEdge const &); std::unordered_set @@ -176,31 +83,6 @@ struct ViewUndirectedGraphAsDiGraph : public IDiGraphView { UndirectedGraphView g; }; -std::unordered_map - flatten_contraction(std::unordered_map const &); - -template -Impl materialize_view(View const &g) { - Impl result; - for (Node const &n : get_nodes(g)) { - result.add_node_unsafe(n); - } - for (auto const &e : get_edges(g)) { - result.add_edge(e); - } - return result; -} - -template -Impl materialize_undirected_graph_view(IUndirectedGraphView const &g) { - return materialize_view(g); -} - -template -Impl materialize_digraph_view(IDiGraphView const &g) { - return materialize_view(g); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 79c4fc9964..d7cd979f14 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -481,15 +481,6 @@ DiGraphView get_subgraph(DiGraphView const &g, // return MultiDiGraphView::create(lhs, rhs); // } -DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs) { - return DiGraphView::create(lhs, rhs); -} - -UndirectedGraphView join(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs) { - return UndirectedGraphView::create(lhs, rhs); -} - UndirectedGraphView as_undirected(DiGraphView const &g) { return UndirectedGraphView::create(g); } diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index df84683a6b..5d16304701 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -38,16 +38,17 @@ void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.endpoints.max()); - m.erase(e.endpoints.min()); - m.erase(e.endpoints.max()); + std::unordered_set &max_map = this->adjacency.at(e.endpoints.max()); + max_map.erase(e.endpoints.min()); + std::unordered_set &min_map = this->adjacency.at(e.endpoints.min()); + min_map.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( UndirectedEdgeQuery const &query) const { std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { - for (auto const &dst : src_kv.second) { + for (auto const &dst : apply_query(query.nodes, src_kv.second)) { result.insert(UndirectedEdge{{src_kv.first, dst}}); } } diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index c29d478f1e..7bb039d314 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -66,153 +66,6 @@ DiGraphView view_subgraph(DiGraphView const &g, return DiGraphView::create(g, subgraph_nodes); } -JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { - for (Node const &n : get_nodes(lhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::LEFT}, - this->node_source.new_node()); - } - for (Node const &n : get_nodes(rhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::RIGHT}, - this->node_source.new_node()); - } -} - -std::unordered_set - JoinedNodeView::query_nodes(NodeQuery const &query) const { - std::unordered_set nodes = right_entries(this->mapping); - if (query == node_query_all()) { - return nodes; - } - return filter(nodes, [&](Node const &n) { - return contains(allowed_values(query.nodes), n); - }); -} - -std::pair, std::unordered_set> - JoinedNodeView::trace_nodes(std::unordered_set const &nodes) const { - std::unordered_set left_nodes, right_nodes; - - for (Node const &n : nodes) { - JoinNodeKey k = this->at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - return {left_nodes, right_nodes}; -} - -Node JoinedNodeView::at_join_key(JoinNodeKey const &k) const { - return this->mapping.at_l(k); -} - -JoinNodeKey JoinedNodeView::at_node(Node const &n) const { - return this->mapping.at_r(n); -} - -JoinedUndirectedGraphView::JoinedUndirectedGraphView( - UndirectedGraphView const &lhs, UndirectedGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -std::unordered_set - JoinedUndirectedGraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set JoinedUndirectedGraphView::query_edges( - UndirectedEdgeQuery const &query) const { - std::unordered_set nodes = this->query_nodes(NodeQuery{query.nodes}); - std::unordered_set left_nodes, right_nodes; - for (Node const &n : nodes) { - JoinNodeKey k = this->joined_nodes.at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - std::unordered_set result; - for (UndirectedEdge const &e : - this->lhs.query_edges(UndirectedEdgeQuery{left_nodes})) { - result.insert(this->fix_lhs_edge(e)); - } - for (UndirectedEdge const &e : - this->rhs.query_edges(UndirectedEdgeQuery{right_nodes})) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return UndirectedEdge{{this->joined_nodes.at_join_key( - JoinNodeKey{e.endpoints.min(), LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.max(), LRDirection::LEFT})}}; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return UndirectedEdge{{this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.min(), LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.max(), LRDirection::RIGHT})}}; -} - -JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, - DiGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -JoinedDigraphView *JoinedDigraphView::clone() const { - return new JoinedDigraphView(lhs, rhs); -} - -std::unordered_set - JoinedDigraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set - JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { - - std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); - std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); - auto traced_srcs = this->joined_nodes.trace_nodes(srcs); - auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - DirectedEdgeQuery left_query = - DirectedEdgeQuery{traced_srcs.first, traced_dsts.first}; - DirectedEdgeQuery right_query = - DirectedEdgeQuery{traced_srcs.second, traced_dsts.second}; - - std::unordered_set result; - for (DirectedEdge const &e : this->lhs.query_edges(left_query)) { - result.insert(this->fix_lhs_edge(e)); - } - for (DirectedEdge const &e : this->rhs.query_edges(right_query)) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -DirectedEdge JoinedDigraphView::fix_lhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT})}; -} - -DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT})}; -} - UndirectedEdge to_undirected_edge(DirectedEdge const &e) { return UndirectedEdge{{e.src, e.dst}}; } @@ -265,8 +118,8 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), + g.query_edges(UndirectedEdgeQuery{q.dsts})); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); @@ -279,8 +132,4 @@ std::unordered_set return g.query_nodes(q); } -JoinedUndirectedGraphView *JoinedUndirectedGraphView::clone() const { - return new JoinedUndirectedGraphView(lhs, rhs); -} - } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc index 7f6f0dd064..179cce7db7 100644 --- a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -7,10 +7,24 @@ using namespace FlexFlow; -TEST_SUITE(FF_TEST_SUITE) { +TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); - TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); + SUBCASE("disjoint nodes") { + std::vector n = add_nodes(g, 3); + + std::unordered_set> correct = { + {n[0]}, + {n[1]}, + {n[2]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("2 components") { std::vector n = add_nodes(g, 4); add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); @@ -23,4 +37,25 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(correct == result); } + + SUBCASE("3 components") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + UndirectedEdge{{n[0], n[1]}}, + UndirectedEdge{{n[0], n[2]}}, + UndirectedEdge{{n[1], n[2]}}, + UndirectedEdge{{n[3], n[4]}}, + }); + + std::unordered_set> correct = { + {n[0], n[1], n[2]}, + {n[3], n[4]}, + {n[5]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } } diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index fb5374bd6a..8a6a44d1cc 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -14,41 +14,39 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("UndirectedSubgraphView") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(4)}}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + UndirectedGraphView view = view_subgraph(g, sub_nodes); - // TEST_CASE("UndirectedSubgraphView") { - // UndirectedGraph g = UndirectedGraph::create(); - // std::vector n = add_nodes(g, 5); - // add_edges(g, - // {UndirectedEdge{{n.at(0), n.at(3)}}, - // UndirectedEdge{{n.at(1), n.at(1)}}, - // UndirectedEdge{{n.at(1), n.at(2)}}, - // UndirectedEdge{{n.at(1), n.at(3)}}, - // UndirectedEdge{{n.at(2), n.at(3)}}, - // UndirectedEdge{{n.at(2), n.at(4)}}}); - // std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; - // UndirectedGraphView view = get_subgraph(g, sub_nodes); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; - // std::unordered_set result = get_nodes(view); + std::unordered_set result = get_nodes(view); - // CHECK(result == expected); - // } + CHECK(result == expected); + } - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // UndirectedEdge{{n.at(0), n.at(3)}}, - // UndirectedEdge{{n.at(1), n.at(1)}}, - // UndirectedEdge{{n.at(1), n.at(3)}}, - // }; + SUBCASE("get_edges") { + std::unordered_set expected = { + UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + }; - // std::unordered_set result = get_edges(view); + std::unordered_set result = get_edges(view); - // // TODO(@pietro) TODO(@lockshaw) current BUG, get_edges also - // CHECK(result == expected); - // } - // } + CHECK(result == expected); + } + } TEST_CASE("DiSubgraphView") { DiGraph g = DiGraph::create(); @@ -63,7 +61,7 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(3), n.at(2)}, DirectedEdge{n.at(2), n.at(4)}}); std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; - DiGraphView view = get_subgraph(g, sub_nodes); + DiGraphView view = view_subgraph(g, sub_nodes); SUBCASE("get_nodes") { std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; @@ -87,115 +85,6 @@ TEST_SUITE(FF_TEST_SUITE) { } } - // TEST_CASE("JoinedNodeView") { - // UndirectedGraph g1 = UndirectedGraph::create(); - // UndirectedGraph g2 = UndirectedGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 2); - // std::unordered_set joined_nodes = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - // add_edges(g1, - // {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], - // n1[2]}}}); - // add_edges(g2, {UndirectedEdge{{n2[0], n2[1]}}}); - - // JoinedNodeView joined_view(g1, g2); - - // SUBCASE("trace_nodes") { - // std::pair, std::unordered_set> result = - // joined_view.trace_nodes(joined_nodes); - // std::pair, std::unordered_set> correct = - // { - // {n1[0], n1[1], n1[2]}, {n2[0], n2[1]}}; - - // CHECK(result == correct); - // } - - // SUBCASE("query_nodes") { - // SUBCASE("matchall") {} - // SUBCASE("subset") {} - // } - // } - - // TEST_CASE("JoinedUndirectedGraphView") { - // UndirectedGraph g1 = UndirectedGraph::create(); - // UndirectedGraph g2 = UndirectedGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 3); - - // add_edges(g1, - // {UndirectedEdge{{n1.at(0), n1.at(1)}}, - // UndirectedEdge{{n1.at(1), n1.at(2)}}}); - // add_edges(g2, - // {UndirectedEdge{{n2.at(0), n2.at(2)}}, - // UndirectedEdge{{n2.at(1), n2.at(2)}}}); - - // UndirectedGraphView view = join(g1, g2); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - - // std::unordered_set result = get_nodes(view); - - // CHECK(result == expected); - // } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // UndirectedEdge{{n1.at(0), n1.at(1)}}, - // UndirectedEdge{{n1.at(1), n1.at(2)}}, - // UndirectedEdge{{n2.at(0), n2.at(2)}}, - // UndirectedEdge{{n2.at(1), n2.at(2)}}}; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - // } - - // TEST_CASE("JoinedDigraphView") { - // DiGraph g1 = DiGraph::create(); - // DiGraph g2 = DiGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 3); - - // add_edges( - // g1, - // {DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), - // n1.at(2)}}); - // add_edges( - // g2, - // {DirectedEdge{n2.at(0), n2.at(2)}, DirectedEdge{n2.at(1), - // n2.at(2)}}); - - // DiGraphView view = join(g1, g2); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - - // std::unordered_set result = get_nodes(view); - - // CHECK(result == expected); - // } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // DirectedEdge{n1.at(0), n1.at(1)}, - // DirectedEdge{n1.at(1), n1.at(2)}, - // DirectedEdge{n2.at(0), n2.at(2)}, - // DirectedEdge{n2.at(1), n2.at(2)}}; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - // } - TEST_CASE("ViewDiGraphAsUndirectedGraph") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 3); From 87e767cdafed8d520180c73e49a3cb692ff87200 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 18:54:44 -0700 Subject: [PATCH 04/16] minor optimization --- .../series_parallel/parallel_reduction.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 12a6630bf0..609d065660 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,5 +1,9 @@ #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" namespace FlexFlow { @@ -10,13 +14,17 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, std::optional find_parallel_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 != e2 && g.get_multidiedge_src(e1) == g.get_multidiedge_src(e2) && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); + for (auto const &[directed_edge, count] : get_edge_counts(g)) { + + if (count <= 1) {continue;} + + std::unordered_set const &outgoing_edges = get_outgoing_edges(g, directed_edge.src); + for (MultiDiEdge const &e1 : outgoing_edges) { + for (MultiDiEdge const &e2 : outgoing_edges) { + if (e1 != e2 && g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { + return make_parallel_reduction(e1, e2); + } } } } From 44b32f88c47b1a10eeaab6d42843afb9d090d7c0 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 19:04:17 -0700 Subject: [PATCH 05/16] fmt --- .../graph/series_parallel/parallel_reduction.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 609d065660..78265f6856 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,9 +1,9 @@ #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" -#include "utils/graph/node/algorithms.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" -#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -17,12 +17,16 @@ std::optional for (auto const &[directed_edge, count] : get_edge_counts(g)) { - if (count <= 1) {continue;} + if (count <= 1) { + continue; + } - std::unordered_set const &outgoing_edges = get_outgoing_edges(g, directed_edge.src); + std::unordered_set const &outgoing_edges = + get_outgoing_edges(g, directed_edge.src); for (MultiDiEdge const &e1 : outgoing_edges) { for (MultiDiEdge const &e2 : outgoing_edges) { - if (e1 != e2 && g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { + if (e1 != e2 && + g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { return make_parallel_reduction(e1, e2); } } From e7055ad8ac1da6dce17e5691600bb4aa45cf8b7d Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 14 Oct 2024 19:40:37 -0700 Subject: [PATCH 06/16] small fix --- .../utils/graph/digraph/algorithms/get_topological_ordering.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index de6953fad4..5adc0cc4df 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" -#include "utils/containers.h" +#include "utils/containers/index_of.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" From 8564b8b979419ced5a67774c89777d380061e300 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 27 Nov 2024 11:38:54 -0800 Subject: [PATCH 07/16] get_series_parallel_decomposition fix --- lib/utils/include/utils/containers/find.h | 7 +++ .../series_parallel/parallel_reduction.h | 7 +++ .../series_parallel_decomposition.h | 19 ++++++ .../get_series_parallel_decomposition.cc | 61 +++++++++--------- .../series_parallel/parallel_reduction.cc | 63 +++++++++++++------ .../series_parallel_decomposition.cc | 61 ++++++++++++++++++ .../graph/series_parallel/series_reduction.cc | 33 ++++------ .../test/src/utils/containers/contains.cc | 15 ++++- 8 files changed, 192 insertions(+), 74 deletions(-) diff --git a/lib/utils/include/utils/containers/find.h b/lib/utils/include/utils/containers/find.h index eed5f8453c..7b103fed16 100644 --- a/lib/utils/include/utils/containers/find.h +++ b/lib/utils/include/utils/containers/find.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H #include +#include namespace FlexFlow { @@ -11,6 +12,12 @@ typename Container::const_iterator return std::find(c.cbegin(), c.cend(), e); } +template +typename std::unordered_set::const_iterator + find(std::unordered_set const &c, V const &e) { + return c.find(e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 3fc1347ee5..0b3c7f3619 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -12,8 +12,15 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &, std::optional find_parallel_reduction(MultiDiGraphView const &); +std::unordered_map> + find_all_extended_parallel_reductions(MultiDiGraphView const &); + MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &); +MultiDiEdge + apply_extended_parallel_reduction(MultiDiGraph &, + std::unordered_set const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index 52d2cb7236..d56d4a55f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -17,6 +17,25 @@ std::unordered_multiset get_nodes(SeriesSplit const &); std::unordered_multiset get_nodes(ParallelSplit const &); std::unordered_multiset get_nodes(Node const &); +bool is_empty(Node const &node); +bool is_empty(SeriesSplit const &serial); +bool is_empty(ParallelSplit const ¶llel); +bool is_empty(SeriesParallelDecomposition const &sp); + +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, + Node const &node); + +// duplicate nodes within `sp` are counted multiple times +size_t num_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition serial_composition( + std::vector const &sp_compositions); +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index cd29af59a0..7a5cb1ea82 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -2,14 +2,18 @@ #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/series_reduction.h" @@ -26,39 +30,18 @@ std::optional if (!maybe_line_graph.has_value()) { return std::nullopt; } - maybe_line_graph.value(); }); MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map + std::unordered_map ttsp_edge_to_sp_tree = map_values( inverse_line_graph_result.inverse_edge_to_line_node_bidict .as_unordered_map(), - [](Node const &n) { return BinarySPDecompositionTree{n}; }); + [](Node const &n) { return SeriesParallelDecomposition{n}; }); while (true) { - assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); - std::optional maybe_parallel_reduction = - find_parallel_reduction(ttsp); - if (maybe_parallel_reduction.has_value()) { - ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); - auto [e1, e2] = parallel_reduction.edges.ordered(); - MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinaryParallelSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, - }; - ttsp_edge_to_sp_tree.erase(e1); - ttsp_edge_to_sp_tree.erase(e2); - ttsp_edge_to_sp_tree.insert({merged, new_tree}); - - continue; - } - std::optional maybe_series_reduction = find_series_reduction(ttsp); if (maybe_series_reduction.has_value()) { @@ -66,15 +49,33 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinarySeriesSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, - }; + + SeriesParallelDecomposition new_tree = serial_composition({ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }); + ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + continue; + } + std::unordered_map> + parallel_reductions = find_all_extended_parallel_reductions(ttsp); + if (!parallel_reductions.empty()) { + for (auto const &[_, parallel_reduction] : parallel_reductions) { + MultiDiEdge merged = + apply_extended_parallel_reduction(ttsp, parallel_reduction); + + SeriesParallelDecomposition new_tree = parallel_composition(transform( + unordered_multiset_of(parallel_reduction), + [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); + for (MultiDiEdge const &e : parallel_reduction) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + } continue; } @@ -87,7 +88,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); + return ttsp_edge_to_sp_tree.at(e); } } } diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 78265f6856..c7eb866b62 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,9 +1,17 @@ #include "utils/graph/series_parallel/parallel_reduction.h" -#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" +#include "utils/containers/get_one_of.h" +#include "utils/containers/group_by.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" -#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" -#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" +#include +#include namespace FlexFlow { @@ -15,31 +23,48 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, std::optional find_parallel_reduction(MultiDiGraphView const &g) { - for (auto const &[directed_edge, count] : get_edge_counts(g)) { - - if (count <= 1) { - continue; - } - - std::unordered_set const &outgoing_edges = - get_outgoing_edges(g, directed_edge.src); - for (MultiDiEdge const &e1 : outgoing_edges) { - for (MultiDiEdge const &e2 : outgoing_edges) { - if (e1 != e2 && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); - } - } + std::unordered_map seen; + for (MultiDiEdge const &edge : get_edges(g)) { + DirectedEdge diedge = get_directed_edge(g, edge); + if (seen.find(diedge) != seen.end()) { + return make_parallel_reduction(seen.at(diedge), edge); } + seen.emplace(diedge, edge); } - return std::nullopt; } +std::unordered_map> + find_all_extended_parallel_reductions(MultiDiGraphView const &g) { + std::unordered_map> + parallel_groups = group_by(get_edges(g), [&](MultiDiEdge const &edge) { + return get_directed_edge(g, edge); + }); + + return filter( + parallel_groups, + [](std::pair> const + &group) { return group.second.size() > 1; }); +} + MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, ParallelReduction const &r) { g.remove_edge(r.edges.max()); return r.edges.min(); } +MultiDiEdge apply_extended_parallel_reduction( + MultiDiGraph &g, std::unordered_set const ¶llel_edges) { + + MultiDiEdge keep_edge = get_one_of(parallel_edges); + + for (MultiDiEdge const ¶llel_edge : parallel_edges) { + if (parallel_edge != keep_edge) { + g.remove_edge(parallel_edge); + } + } + + return keep_edge; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index b7a84b871a..dc99ef6c5a 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,12 +1,17 @@ #include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/all_of.h" +#include "utils/containers/extend.h" #include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" +#include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -74,4 +79,60 @@ std::unordered_multiset get_nodes(Node const &node) { return {node}; } +bool is_empty(Node const &node) { + return false; +} + +bool is_empty(SeriesSplit const &serial) { + return all_of(serial.children, [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(ParallelSplit const ¶llel) { + return all_of(parallel.get_children(), [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(SeriesParallelDecomposition const &sp) { + return sp.visit([](auto const &t) { return is_empty(t); }); +} + +SeriesParallelDecomposition serial_composition( + std::vector const &sp_compositions) { + std::vector> composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition, sp_comp.get().children); + } else if (sp_comp.has()) { + composition.push_back(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.push_back(sp_comp.get()); + } + } + return SeriesParallelDecomposition{SeriesSplit{composition}}; +} + +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions) { + std::unordered_multiset< + std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>> + composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + composition = multiset_union(composition, + sp_comp.get().get_children()); + } else if (sp_comp.has()) { + composition.insert(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.insert(sp_comp.get()); + } + } + return SeriesParallelDecomposition(ParallelSplit{composition}); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index 7300c93fb0..c312bb4a6b 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,8 +1,14 @@ #include "utils/graph/series_parallel/series_reduction.h" +#include "utils/containers/contains.h" +#include "utils/containers/get_only.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/node/algorithms.h" +#include namespace FlexFlow { @@ -26,30 +32,13 @@ SeriesReduction make_series_reduction(MultiDiEdge const &e1, std::optional find_series_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 == e2) { - continue; - } - Node e1_dst = g.get_multidiedge_dst(e1); - Node e2_src = g.get_multidiedge_src(e2); - if (e1_dst != e2_src) { - continue; - } - - std::unordered_set outgoing = get_outgoing_edges(g, e1_dst); - std::unordered_set incoming = get_incoming_edges(g, e1_dst); - - if (outgoing.size() > 1 || incoming.size() > 1) { - continue; - } - - return SeriesReduction{e1, e2}; + for (Node const &node : get_nodes(g)) { + if (get_incoming_edges(g, node).size() == 1 && + get_outgoing_edges(g, node).size() == 1) { + return make_series_reduction(get_only(get_incoming_edges(g, node)), + get_only(get_outgoing_edges(g, node))); } } - return std::nullopt; } diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc index 6e0a84c7ab..fc42d25eea 100644 --- a/lib/utils/test/src/utils/containers/contains.cc +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -1,13 +1,22 @@ #include "utils/containers/contains.h" #include +#include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); + SUBCASE("std::vector") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } + + SUBCASE("std::unordered_set") { + std::unordered_set s = {1, 2, 3, 4, 5}; + CHECK(contains(s, 3)); + CHECK(!contains(s, 6)); + } } } From 8a04a2ee3bcde144165c269dbfed4417d8df94d9 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 20 Dec 2024 14:55:52 -0800 Subject: [PATCH 08/16] updates to graph documentation + get_series_parallel_decomposition speedup --- lib/utils/include/utils/graph/README.md | 114 +++++++++++-- .../algorithms/get_incoming_edges.h | 3 + .../algorithms/get_outgoing_edges.h | 3 + .../graph/series_parallel/series_reduction.h | 7 + .../algorithms/get_incoming_edges.cc | 17 ++ .../algorithms/get_outgoing_edges.cc | 17 +- .../get_series_parallel_decomposition.cc | 51 +++--- .../graph/series_parallel/series_reduction.cc | 51 ++++++ .../graph/series_parallel/series_reduction.cc | 152 ++++++++++++++++++ 9 files changed, 382 insertions(+), 33 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index f3d31e7bc8..41777b9b9a 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -18,13 +18,8 @@ At their core, they are as follows: - `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. - `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) - `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. -- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: - - The edges entering, exiting a given nodes now have a well-defined order. - - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. - - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. +- `DataflowGraph`: used to model computation graphs. See the [DataflowGraph](#dataflowgraph) section for a detailed explanation. -Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. - Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -89,7 +84,9 @@ Nodes are of type `Node`, and from a user perspective are simply opaque handles, In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. -To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). +To add a node to an `UndirectedGraph g`, simply call `g.add_node()`, which will return a `Node` object. +For semantics closer to `networkx`'s method of adding nodes, `g.add_node_unsafe(my_node)` can be used. This is useful when constructing a modified copy of an existing graph (given that it maintains node bijection), though it is not generally recommended. +The interface for node addition is identical for `DiGraph` and `MultiDiGraph`. To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. @@ -114,8 +111,8 @@ Both `Graph` and `GraphView` types follow normal value semantics. This may seem wasteful (oftentimes graphs are large objects that are passed around via reference to avoid making additional copies), but the `Graph` and `GraphView` types internally implement copy-on-write optimizations to only perform the minimum number of actual copies while maintaining immutability and lifetime safety (if you allocate a `DiGraph` use for example `get_subgraph` to get a `DiGraphView` representing a part of this graph, modifications to the underlying `DiGraph` will not be mirrored in the `DiGraphView` and the `DiGraphView` will remain valid even after the base `DiGraph` leaves scope. At this point, however, we still have not discussed how to create a graph. -The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: +The user-facing graph interface is intentionally separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. +For example, to construct a `DiGraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp DiGraph g = DiGraph::create(); ``` @@ -124,7 +121,104 @@ Generally users will use underlying representations provided by the graph librar [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open DataFlow Variant +### DataflowGraph + +The primary abstraction for representing computation graphs / task graphs is the `DataflowGraph` interface (along with its variants, `OpenDataflowGraph`, `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`). +At a high level, nodes represent multivariate functions (from tuples of inputs to tuple of outputs), while edges represent value uses of such functions. + +`DataflowGraph` is similar to `MultiDiGraph`, but with the following important differences: + - The edges entering, exiting a given nodes have a well-defined order. + - `DataflowGraph`s are directed acyclic graphs. This is enforced by the interface used to construct them, since a node can only be added to the graph after all of its predecessor nodes have already been added. + +The main components of `DataflowGraph` are as follows: +- `DataflowInput`: used to represent the ordered sequence of incoming dependencies (arguments) of a given node (operator). +- `DataflowOutput`: used to represent the ordered sequence of outgoing results (value uses) from a given node (operator). +- `DataflowEdge`: wrapper around a `DataflowInput`, `DataflowOutput` pair between 2 nodes. +- `NodeAddedResult`: returned upon adding a new node. Contains the newly generated `Node` and the vector of `DataflowOutput`s for the given node. + +`DataflowGraph`s are constructed as follows: + +```cpp + auto g = DataflowGraph::create(); + + // Node with no inputs and 2 outputs + NodeAddedResult n1_result = g.add_node({}, 2); + Node n1 = n1_result.node; + DataflowOutput n1_o1 = n1_result.outputs[0]; + DataflowOutput n1_o2 = n1_result.outputs[1]; + + // Node with 2 inputs and 1 output + NodeAddedResult n2_result = g.add_node({n1_o1, n1_o2}, 1); + Node n2 = n2_result.node; + DataflowOutput n2_o1 = n2_result.outputs[0]; + + // Node with 1 input and 2 outputs + NodeAddedResult n3_result = g.add_node({n1_o2}, 1); + Node n3 = n3_result.node; + DataflowOutput n3_o1 = n3_result.outputs[0]; + DataflowOutput n3_o2 = n3_result.outputs[1]; + + // Node with 2 inputs and 1 output + NodeAddedResult n4_result = g.add_node({n2_o1, n3_o1}, 1); + Node n4 = n4_result.node; + DataflowOutput n4_o1 = n4_result.outputs[0]; +``` + +which generates the following graph + +```mermaid +flowchart TD + subgraph Node1[ ] + direction TB + N1Process[n1] + n1_o1((n1_o1)) + n1_o2((n1_o2)) + N1Process --> n1_o1 + N1Process --> n1_o2 + end + + subgraph Node2[ ] + direction TB + n2_i1((n2_i1)) + n2_i2((n2_i2)) + N2Process[n2] + n2_o1((o1)) + n2_i1 --> N2Process + n2_i2 --> N2Process + N2Process --> n2_o1 + end + + subgraph Node3[ ] + direction TB + n3_i1((n3_i1)) + N3Process[n3] + n3_o1((n3_o1)) + n3_o2((n3_o2)) + n3_i1 --> N3Process + N3Process --> n3_o1 + N3Process --> n3_o2 + end + + subgraph Node4[ ] + direction TB + n4_i1((n4_i1)) + n4_i2((n4_i2)) + N4Process[n4] + n4_o1((n4_o1)) + n4_i1 --> N4Process + n4_i2 --> N4Process + N4Process --> n4_o1 + end + + n1_o1 --> n2_i1 + n1_o2 --> n2_i2 + n1_o2 --> n3_i1 + n2_o1 --> n4_i1 + n3_o1 --> n4_i2 +``` + + +### Open Dataflow Variant `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h index df5662804a..76be999b54 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h @@ -8,6 +8,9 @@ namespace FlexFlow { std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); +std::unordered_map> + get_incoming_edges(MultiDiGraphView const &g); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h index 6bc73533e7..6a8474673e 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h @@ -8,6 +8,9 @@ namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &, Node const &); +std::unordered_map> + get_outgoing_edges(MultiDiGraphView const &g); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h index a7d53fecfc..0de8aecc19 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -4,6 +4,7 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/series_parallel/series_reduction.dtg.h" +#include "utils/hash/vector.h" namespace FlexFlow { @@ -14,8 +15,14 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &); SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &); std::optional find_series_reduction(MultiDiGraphView const &); +std::unordered_set> + find_all_extended_series_reductions(MultiDiGraphView const &g); + MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &); +MultiDiEdge apply_extended_series_reduction( + MultiDiGraph &g, std::vector const &series_edges); + } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index be39dc158f..50818dea2f 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -1,4 +1,8 @@ #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/containers/group_by.h" +#include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -7,4 +11,17 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &g, return g.query_edges(MultiDiEdgeQuery{query_set::matchall(), {n}}); } +std::unordered_map> + get_incoming_edges(MultiDiGraphView const &g) { + std::unordered_map> result = + group_by(get_edges(g), + [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }); + + for (Node const &n : get_nodes(g)) { + result[n]; + } + + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index f98c599614..55847cf2af 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -1,5 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" - +#include "utils/containers/group_by.h" +#include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, @@ -7,4 +9,17 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, return g.query_edges(MultiDiEdgeQuery{{n}, query_set::matchall()}); } +std::unordered_map> + get_outgoing_edges(MultiDiGraphView const &g) { + std::unordered_map> result = + group_by(get_edges(g), + [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }); + + for (Node const &n : get_nodes(g)) { + result[n]; + } + + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 7a5cb1ea82..908743fae1 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -35,6 +35,7 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); + std::unordered_map ttsp_edge_to_sp_tree = map_values( inverse_line_graph_result.inverse_edge_to_line_node_bidict @@ -42,27 +43,11 @@ std::optional [](Node const &n) { return SeriesParallelDecomposition{n}; }); while (true) { - std::optional maybe_series_reduction = - find_series_reduction(ttsp); - if (maybe_series_reduction.has_value()) { - SeriesReduction series_reduction = maybe_series_reduction.value(); - MultiDiEdge e1 = series_reduction.first; - MultiDiEdge e2 = series_reduction.second; - MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - - SeriesParallelDecomposition new_tree = serial_composition({ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }); - - ttsp_edge_to_sp_tree.erase(e1); - ttsp_edge_to_sp_tree.erase(e2); - ttsp_edge_to_sp_tree.insert({merged, new_tree}); + int reductions = 0; - continue; - } std::unordered_map> parallel_reductions = find_all_extended_parallel_reductions(ttsp); + if (!parallel_reductions.empty()) { for (auto const &[_, parallel_reduction] : parallel_reductions) { MultiDiEdge merged = @@ -71,18 +56,40 @@ std::optional SeriesParallelDecomposition new_tree = parallel_composition(transform( unordered_multiset_of(parallel_reduction), [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); + for (MultiDiEdge const &e : parallel_reduction) { ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); } - continue; + reductions++; } - if (get_nodes(ttsp).size() != 2) { - return std::nullopt; + std::unordered_set> series_reductions = + find_all_extended_series_reductions(ttsp); + if (!series_reductions.empty()) { + for (std::vector series_reduction : series_reductions) { + MultiDiEdge merged = + apply_extended_series_reduction(ttsp, series_reduction); + + SeriesParallelDecomposition new_tree = serial_composition( + transform(series_reduction, [&](MultiDiEdge const &e) { + return ttsp_edge_to_sp_tree.at(e); + })); + + for (MultiDiEdge const &e : series_reduction) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + } + reductions++; } - if (get_edges(ttsp).size() != 1) { + + if (reductions > 0) { + continue; + } + + if (get_nodes(ttsp).size() != 2 || get_edges(ttsp).size() != 1) { return std::nullopt; } diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index c312bb4a6b..26fabe593c 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,13 +1,21 @@ #include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/contains.h" +#include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/require_same.h" +#include "utils/containers/subvec.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/multidigraph/multidigraph_view.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" #include namespace FlexFlow { @@ -42,6 +50,34 @@ std::optional return std::nullopt; } +std::unordered_set> + find_all_extended_series_reductions(MultiDiGraphView const &g) { + std::unordered_map> incoming_edges = + get_incoming_edges(g); + std::unordered_map> outgoing_edges = + get_outgoing_edges(g); + std::unordered_map> strands; + std::unordered_map node_to_head_of_strand; + for (Node const &n : get_topological_ordering(g)) { + if ((incoming_edges.at(n).size() == 1) && + (outgoing_edges.at(n).size() == 1)) { + MultiDiEdge incoming = get_only(incoming_edges.at(n)); + MultiDiEdge outgoing = get_only(outgoing_edges.at(n)); + Node pre = g.get_multidiedge_src(incoming); + if (contains_key(node_to_head_of_strand, pre)) { + Node head = node_to_head_of_strand.at(pre); + node_to_head_of_strand.emplace(n, head); + strands.at(head).push_back(outgoing); + } else { + node_to_head_of_strand.emplace(n, n); + strands[n].push_back(incoming); + strands[n].push_back(outgoing); + } + } + } + return unordered_set_of(values(strands)); +} + MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { Node pre_node = get_pre_node(g, r); Node center_node = get_center_node(g, r); @@ -51,4 +87,19 @@ MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { return g.add_edge(pre_node, post_node); } +MultiDiEdge apply_extended_series_reduction( + MultiDiGraph &g, std::vector const &series_edges) { + + Node first = g.get_multidiedge_src(series_edges.at(0)); + Node last = g.get_multidiedge_dst(series_edges.at(series_edges.size() - 1)); + + std::vector internal_nodes; + for (MultiDiEdge const &e : subvec(series_edges, std::nullopt, -1)) { + internal_nodes.push_back(g.get_multidiedge_dst(e)); + } + for (Node const &n : internal_nodes) { + g.remove_node(n); + } + return g.add_edge(first, last); +} } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index c6b45ec6ce..3a8a5e9a60 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -1,9 +1,12 @@ #include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/set_minus.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include @@ -234,6 +237,155 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(returned_edge_src == correct_src); } + SUBCASE("dst") { + Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); + Node correct_dst = n.at(5); + CHECK(returned_edge_dst == correct_dst); + } + } + } + } + TEST_CASE("find_all_extended_series_reductions") { + MultiDiGraph g = MultiDiGraph::create(); + + SUBCASE("linear graph") { + std::vector n = add_nodes(g, 4); + std::vector e = add_edges(g, + { + {n.at(0), n.at(1)}, + {n.at(1), n.at(2)}, + {n.at(2), n.at(3)}, + }); + + std::unordered_set> result = + find_all_extended_series_reductions(g); + std::unordered_set> correct = { + {e[0], e[1], e[2]}}; + CHECK(result == correct); + } + + SUBCASE("2 linear strands") { + std::vector n = add_nodes(g, 4); + std::vector e = add_edges(g, + {{n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(3)}, + {n.at(2), n.at(3)}}); + + std::unordered_set> result = + find_all_extended_series_reductions(g); + std::unordered_set> correct = {{e[0], e[2]}, + {e[1], e[3]}}; + CHECK(result == correct); + } + + SUBCASE("graph with multiple separate serial strands") { + std::vector n = add_nodes(g, 9); + std::vector e = add_edges(g, + {{n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(4)}, + {n.at(2), n.at(3)}, + {n.at(2), n.at(5)}, + {n.at(2), n.at(6)}, + {n.at(3), n.at(5)}, + {n.at(4), n.at(7)}, + {n.at(5), n.at(7)}, + {n.at(6), n.at(8)}, + {n.at(7), n.at(8)}}); + + std::unordered_set> result = + find_all_extended_series_reductions(g); + std::unordered_set> correct = { + {e[0], e[2], e[7]}, {e[3], e[6]}, {e[5], e[9]}}; + CHECK(result == correct); + } + } + + TEST_CASE("apply_extended_series_reduction") { + MultiDiGraph g = MultiDiGraph::create(); + + SUBCASE("base case") { + std::vector n = add_nodes(g, 4); + std::vector e = add_edges( + g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); + + std::vector reduction = {e.at(0), e.at(1), e.at(2)}; + + MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(g); + std::unordered_set correct_nodes = {n.at(0), n.at(3)}; + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(g); + std::unordered_set correct_edges = {returned_edge}; + CHECK(result_edges == correct_edges); + } + + SUBCASE("returned edge") { + SUBCASE("src") { + Node returned_edge_src = g.get_multidiedge_src(returned_edge); + Node correct_src = n.at(0); + CHECK(returned_edge_src == correct_src); + } + + SUBCASE("dst") { + Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); + Node correct_dst = n.at(3); + CHECK(returned_edge_dst == correct_dst); + } + } + } + + SUBCASE("in larger graph") { + std::vector n = add_nodes(g, 8); + std::vector e = add_edges(g, + { + {n.at(0), n.at(2)}, + {n.at(1), n.at(2)}, + {n.at(2), n.at(5)}, + {n.at(2), n.at(3)}, + {n.at(3), n.at(4)}, + {n.at(4), n.at(5)}, + {n.at(5), n.at(6)}, + {n.at(5), n.at(7)}, + }); + + std::vector reduction = {e.at(3), e.at(4), e.at(5)}; + + MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(g); + std::unordered_set correct_nodes = + set_minus(unordered_set_of(n), {n.at(4), n.at(3)}); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(g); + std::unordered_set correct_edges = [&] { + std::unordered_set new_edges = unordered_set_of(e); + new_edges.erase(e.at(3)); + new_edges.erase(e.at(4)); + new_edges.erase(e.at(5)); + new_edges.insert(returned_edge); + return new_edges; + }(); + CHECK(result_edges == correct_edges); + } + + SUBCASE("returned edge") { + SUBCASE("src") { + Node returned_edge_src = g.get_multidiedge_src(returned_edge); + Node correct_src = n.at(2); + CHECK(returned_edge_src == correct_src); + } + SUBCASE("dst") { Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); Node correct_dst = n.at(5); From 3cc3407dc6d05a84ef845888cffe12adba60cb30 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 11 Jan 2025 11:33:52 -0800 Subject: [PATCH 09/16] PR fixes --- lib/utils/include/utils/graph/algorithms.h | 4 +- .../include/utils/graph/digraph/algorithms.h | 4 +- .../algorithms/get_incoming_edges.h | 3 +- .../algorithms/get_outgoing_edges.h | 4 +- .../include/utils/graph/node/node.struct.toml | 1 - .../extended_parallel_reduction.struct.toml | 21 ++ .../extended_series_reduction.struct.toml | 21 ++ .../get_series_parallel_decomposition.h | 2 - .../series_parallel/parallel_reduction.h | 19 +- .../series_parallel_decomposition.h | 2 +- .../graph/series_parallel/series_reduction.h | 44 ++++- .../undirected/undirected_edge.struct.toml | 1 - .../src/utils/graph/digraph/algorithms.cc | 6 +- .../get_cbc_decomposition.cc | 4 +- .../is_complete_bipartite_digraph.cc | 2 +- .../digraph/algorithms/get_dominators_map.cc | 4 +- .../algorithms/get_topological_ordering.cc | 2 +- .../get_inverse_line_graph.cc | 8 +- .../graph/digraph/algorithms/is_acyclic.cc | 6 +- .../algorithms/get_incoming_edges.cc | 10 +- .../algorithms/get_outgoing_edges.cc | 9 +- .../algorithms/find_isomorphisms.cc | 8 +- .../get_series_parallel_decomposition.cc | 21 +- .../series_parallel/parallel_reduction.cc | 29 +-- .../series_parallel_decomposition.cc | 2 +- .../graph/series_parallel/series_reduction.cc | 40 ++-- lib/utils/src/utils/graph/views/views.cc | 7 +- .../test/src/utils/containers/contains.cc | 4 +- .../graph/digraph/algorithms/algorithms.cc | 184 ++++++++++-------- .../digraph/algorithms/directed_edge_query.cc | 98 +++++----- .../digraph/algorithms/get_dominators.cc | 41 ++-- .../graph/digraph/algorithms/traversal.cc | 111 ++++++----- .../algorithms/get_incoming_edges.cc | 34 +++- .../algorithms/get_outgoing_edges.cc | 48 +++-- .../graph/series_parallel/series_reduction.cc | 32 +-- .../algorithms/get_connected_components.cc | 117 +++++++---- .../src/utils/graph/undirected/undirected.cc | 32 --- 37 files changed, 587 insertions(+), 398 deletions(-) create mode 100644 lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 3f170b5652..ff7a7dcad2 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -139,10 +139,10 @@ std::unordered_set get_neighbors(DiGraphView const &, Node const &); // &); // return the set of nodes without incoming edges -std::unordered_set get_sources(DiGraphView const &); +std::unordered_set get_initial_nodes(DiGraphView const &); // return the set of nodes without outgoing edges -std::unordered_set get_sinks(DiGraphView const &); +std::unordered_set get_terminal_nodes(DiGraphView const &); // std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); // std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h index 370f181c37..fdced8a05c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -6,8 +6,8 @@ namespace FlexFlow { std::unordered_set get_edges(DiGraphView const &); -std::unordered_set get_sources(DiGraphView const &); -std::unordered_set get_sinks(DiGraphView const &); +std::unordered_set get_initial_nodes(DiGraphView const &); +std::unordered_set get_terminal_nodes(DiGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h index 76be999b54..471a12a44b 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h @@ -9,7 +9,8 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); std::unordered_map> - get_incoming_edges(MultiDiGraphView const &g); + get_incoming_edges(MultiDiGraphView const &g, + std::unordered_set const &nodes); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h index 6a8474673e..bd8c364f7e 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H #include "utils/graph/multidigraph/multidigraph_view.h" +#include namespace FlexFlow { @@ -9,7 +10,8 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &, Node const &); std::unordered_map> - get_outgoing_edges(MultiDiGraphView const &g); + get_outgoing_edges(MultiDiGraphView const &g, + std::unordered_set const &ns); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 46e0255de3..d5c22e5d3d 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -6,7 +6,6 @@ features = [ "hash", "fmt", "json", - "rapidcheck", ] includes = [ diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml new file mode 100644 index 0000000000..9c1ed68730 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ExtendedParallelReduction" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "" +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml new file mode 100644 index 0000000000..f1cf0ccde3 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ExtendedSeriesReduction" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "" +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "edges" +type = "std::vector<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h index f2a006d899..5f492c1aeb 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -4,8 +4,6 @@ #include "utils/graph/digraph/digraph.h" #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/optional.h" -#include -#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 0b3c7f3619..7a3a7a021c 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -2,24 +2,39 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h" #include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include +#include namespace FlexFlow { ParallelReduction make_parallel_reduction(MultiDiEdge const &, MultiDiEdge const &); + std::optional find_parallel_reduction(MultiDiGraphView const &); -std::unordered_map> +/** + * @brief Finds all ExtendedParallelReduction for a given MultiDiGraph + * @details An ExtendedParallelReduction is a unordered collection of + * `MultiDiEdge`s such that they share a common source and destination node. + */ +std::unordered_set find_all_extended_parallel_reductions(MultiDiGraphView const &); MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &); +/** + * @brief Applies a given ExtendedParallelReduction in place to a given + * MultiDiGraph + * @details The reduction removes all but one `MultiDiEdge`, so that the source, + * destination nodes associated with the reduction become connected by a single + * edge. + */ MultiDiEdge apply_extended_parallel_reduction(MultiDiGraph &, - std::unordered_set const &); + ExtendedParallelReduction const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index d56d4a55f7..b3fc201ca5 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -30,7 +30,7 @@ SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, // duplicate nodes within `sp` are counted multiple times size_t num_nodes(SeriesParallelDecomposition const &sp); -SeriesParallelDecomposition serial_composition( +SeriesParallelDecomposition series_composition( std::vector const &sp_compositions); SeriesParallelDecomposition parallel_composition( std::unordered_multiset const diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h index 0de8aecc19..3e281066d4 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -3,6 +3,7 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/graph/series_parallel/series_reduction.dtg.h" #include "utils/hash/vector.h" @@ -15,13 +16,50 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &); SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &); std::optional find_series_reduction(MultiDiGraphView const &); -std::unordered_set> +/** + * @brief Finds all the ExtendedSeriesReduction structures in a given graph. + * + * @details An `ExtendedSeriesReduction` is an ordered collection of + * `MultiDiEdges` such that: + * - The destination node of the nth edge is the same as the source node of the + * (n+1)th edge. + * - Such a node (intermediate node) has exactly two edges: one incoming (nth + * edge) and one outgoing ((n+1)th edge). + * + * For example, in the following graph: + * + * A -> B -> D -> E + * \ / + * -> C -> + * + * We have that [(A,B), (B,D), (D,E)] and [(A,C), (C,E)] both constitute + * `ExtendedSeriesReduction`. + */ +std::unordered_set find_all_extended_series_reductions(MultiDiGraphView const &g); MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &); -MultiDiEdge apply_extended_series_reduction( - MultiDiGraph &g, std::vector const &series_edges); +/** + * @brief Applies a given ExtendedSeriesReduction in-place to a given graph. + * + * For example, in the following graph: + * + * A -> B -> D -> E + * \ / + * -> C -> + * + * Given the ExtendedSeriesReduction [(A,B), (B,D), (D,E)], the intermediate + *nodes B, D, will be deleted, and the resulting graph will be: + * + * A ----> E + * \ / + * -> C -> + * + **/ +MultiDiEdge + apply_extended_series_reduction(MultiDiGraph &g, + ExtendedSeriesReduction const &reduction); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml index f5258b0bfd..0ad8232339 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -5,7 +5,6 @@ features = [ "ord", "hash", "fmt", - "rapidcheck" ] includes = [ diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc index 8cd685e5c6..84798b2f62 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -15,11 +15,11 @@ std::unordered_set get_edges(DiGraphView const &g) { return g.query_edges(directed_edge_query_all()); } -std::unordered_set get_sinks(DiGraphView const &g) { - return get_sources(flipped(g)); +std::unordered_set get_terminal_nodes(DiGraphView const &g) { + return get_initial_nodes(flipped(g)); } -std::unordered_set get_sources(DiGraphView const &g) { +std::unordered_set get_initial_nodes(DiGraphView const &g) { std::unordered_set all_nodes = get_nodes(g); std::unordered_set with_incoming_edge = transform(get_edges(g), [](DirectedEdge const &e) { return e.dst; }); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 92bd1e32ca..9a2f9cb019 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -71,8 +71,8 @@ std::optional extend(already_in_a_tail, tail); } - assert(already_in_a_head == set_minus(get_nodes(g), get_sinks(g))); - assert(already_in_a_tail == set_minus(get_nodes(g), get_sources(g))); + assert(already_in_a_head == set_minus(get_nodes(g), get_terminal_nodes(g))); + assert(already_in_a_tail == set_minus(get_nodes(g), get_initial_nodes(g))); return result; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc index ccd2808603..bf428ed26b 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -6,7 +6,7 @@ namespace FlexFlow { bool is_complete_bipartite_digraph(DiGraphView const &g) { - return is_complete_bipartite_digraph(g, get_sources(g)); + return is_complete_bipartite_digraph(g, get_initial_nodes(g)); } bool is_complete_bipartite_digraph(DiGraphView const &g, diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index 3dd9de73f0..1d909150cc 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -15,11 +15,11 @@ namespace FlexFlow { std::unordered_map> get_dominators_map(DiGraphView const &g) { - std::unordered_set sources = get_sources(g); + std::unordered_set initial_nodes = get_initial_nodes(g); std::queue queue; - for (Node src : get_sources(g)) { + for (Node src : get_initial_nodes(g)) { queue.push(src); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index 41fe3b67d5..fea799b3e9 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -9,7 +9,7 @@ namespace FlexFlow { static std::vector get_unchecked_topological_ordering(DiGraphView const &g) { - auto dfs_view = unchecked_dfs(g, get_sources(g)); + auto dfs_view = unchecked_dfs(g, get_initial_nodes(g)); std::vector order; std::unordered_set seen; std::unordered_map> predecessors = diff --git a/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index 77f04f2efd..ccf943c4d3 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -42,11 +42,11 @@ std::optional return get_component_containing_node_in_tail(cbc_decomposition, n).value(); }; - std::unordered_set sources = get_sources(view); - std::unordered_set sinks = get_sinks(view); + std::unordered_set initial_nodes = get_initial_nodes(view); + std::unordered_set terminal_nodes = get_terminal_nodes(view); auto src_for_node = [&](Node const &v) -> Node { - if (contains(sources, v)) { + if (contains(initial_nodes, v)) { return alpha; } else { return component_nodes.at_l(t(v)); @@ -54,7 +54,7 @@ std::optional }; auto dst_for_node = [&](Node const &v) -> Node { - if (contains(sinks, v)) { + if (contains(terminal_nodes, v)) { return omega; } else { return component_nodes.at_l(h(v)); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index dd660f193d..018d07163d 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -9,11 +9,11 @@ std::optional is_acyclic(DiGraphView const &g) { if (num_nodes(g) == 0) { return std::nullopt; } - std::unordered_set sources = get_sources(g); - if (sources.size() == 0) { + std::unordered_set initial_nodes = get_initial_nodes(g); + if (initial_nodes.size() == 0) { return false; } - auto dfs_view = unchecked_dfs(g, sources); + auto dfs_view = unchecked_dfs(g, initial_nodes); std::unordered_set seen; for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); it++) { diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index 50818dea2f..7a5ba695f9 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -2,7 +2,9 @@ #include "utils/containers/group_by.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidiedge_query.dtg.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/query_set.h" namespace FlexFlow { @@ -12,12 +14,14 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &g, } std::unordered_map> - get_incoming_edges(MultiDiGraphView const &g) { + get_incoming_edges(MultiDiGraphView const &g, + std::unordered_set const &ns) { std::unordered_map> result = - group_by(get_edges(g), + group_by(g.query_edges(MultiDiEdgeQuery{query_set::matchall(), + query_set{ns}}), [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }); - for (Node const &n : get_nodes(g)) { + for (Node const &n : ns) { result[n]; } diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 55847cf2af..d183b44137 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -2,6 +2,7 @@ #include "utils/containers/group_by.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" +#include namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, @@ -10,12 +11,14 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, } std::unordered_map> - get_outgoing_edges(MultiDiGraphView const &g) { + get_outgoing_edges(MultiDiGraphView const &g, + std::unordered_set const &ns) { std::unordered_map> result = - group_by(get_edges(g), + group_by(g.query_edges(MultiDiEdgeQuery{query_set{ns}, + query_set::matchall()}), [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }); - for (Node const &n : get_nodes(g)) { + for (Node const &n : ns) { result[n]; } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index fa17678943..375aaa3762 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -34,14 +34,14 @@ static std::optional { std::unordered_set already_mapped_src_nodes = left_entries(sink_node_mapping); - std::unordered_set src_g_sink_nodes = get_sinks(src_g); + std::unordered_set src_g_sink_nodes = get_terminal_nodes(src_g); assert(already_mapped_src_nodes == src_g_sink_nodes); } { std::unordered_set already_mapped_dst_nodes = right_entries(sink_node_mapping); - std::unordered_set dst_g_sink_nodes = get_sinks(dst_g); + std::unordered_set dst_g_sink_nodes = get_terminal_nodes(dst_g); assert(already_mapped_dst_nodes == dst_g_sink_nodes); } @@ -201,8 +201,8 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = vector_of(get_sinks(src)); - std::unordered_set dst_sink_nodes = get_sinks(dst); + std::vector src_sink_nodes = vector_of(get_terminal_nodes(src)); + std::unordered_set dst_sink_nodes = get_terminal_nodes(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { return {}; diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 908743fae1..b45e62eae7 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -12,6 +12,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" @@ -45,19 +46,19 @@ std::optional while (true) { int reductions = 0; - std::unordered_map> - parallel_reductions = find_all_extended_parallel_reductions(ttsp); + std::unordered_set parallel_reductions = + find_all_extended_parallel_reductions(ttsp); if (!parallel_reductions.empty()) { - for (auto const &[_, parallel_reduction] : parallel_reductions) { + for (ExtendedParallelReduction parallel_reduction : parallel_reductions) { MultiDiEdge merged = apply_extended_parallel_reduction(ttsp, parallel_reduction); SeriesParallelDecomposition new_tree = parallel_composition(transform( - unordered_multiset_of(parallel_reduction), + unordered_multiset_of(parallel_reduction.edges), [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); - for (MultiDiEdge const &e : parallel_reduction) { + for (MultiDiEdge const &e : parallel_reduction.edges) { ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -65,19 +66,19 @@ std::optional reductions++; } - std::unordered_set> series_reductions = + std::unordered_set series_reductions = find_all_extended_series_reductions(ttsp); if (!series_reductions.empty()) { - for (std::vector series_reduction : series_reductions) { + for (ExtendedSeriesReduction series_reduction : series_reductions) { MultiDiEdge merged = apply_extended_series_reduction(ttsp, series_reduction); - SeriesParallelDecomposition new_tree = serial_composition( - transform(series_reduction, [&](MultiDiEdge const &e) { + SeriesParallelDecomposition new_tree = series_composition( + transform(series_reduction.edges, [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); - for (MultiDiEdge const &e : series_reduction) { + for (MultiDiEdge const &e : series_reduction.edges) { ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index c7eb866b62..3aa677a2f7 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,6 +1,7 @@ #include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/containers/get_one_of.h" #include "utils/containers/group_by.h" +#include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" #include "utils/graph/digraph/directed_edge.dtg.h" @@ -9,6 +10,7 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h" #include "utils/hash/unordered_set.h" #include #include @@ -34,17 +36,22 @@ std::optional return std::nullopt; } -std::unordered_map> +std::unordered_set find_all_extended_parallel_reductions(MultiDiGraphView const &g) { std::unordered_map> - parallel_groups = group_by(get_edges(g), [&](MultiDiEdge const &edge) { - return get_directed_edge(g, edge); - }); + reduction_groups; + for (MultiDiEdge const &edge : get_edges(g)) { + reduction_groups[get_directed_edge(g, edge)].insert(edge); + } + + std::unordered_set> reductions = filter( + unordered_set_of(values(reduction_groups)), + [](std::unordered_set const &s) { return s.size() > 1; }); - return filter( - parallel_groups, - [](std::pair> const - &group) { return group.second.size() > 1; }); + return transform(reductions, + [&](std::unordered_set const &edges) { + return ExtendedParallelReduction{edges}; + }); } MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, @@ -54,11 +61,11 @@ MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, } MultiDiEdge apply_extended_parallel_reduction( - MultiDiGraph &g, std::unordered_set const ¶llel_edges) { + MultiDiGraph &g, ExtendedParallelReduction const &reduction) { - MultiDiEdge keep_edge = get_one_of(parallel_edges); + MultiDiEdge keep_edge = get_one_of(reduction.edges); - for (MultiDiEdge const ¶llel_edge : parallel_edges) { + for (MultiDiEdge const ¶llel_edge : reduction.edges) { if (parallel_edge != keep_edge) { g.remove_edge(parallel_edge); } diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index dc99ef6c5a..937fc1254e 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -99,7 +99,7 @@ bool is_empty(SeriesParallelDecomposition const &sp) { return sp.visit([](auto const &t) { return is_empty(t); }); } -SeriesParallelDecomposition serial_composition( +SeriesParallelDecomposition series_composition( std::vector const &sp_compositions) { std::vector> composition{}; for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index 26fabe593c..5b9b592444 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -15,6 +15,7 @@ #include "utils/graph/multidigraph/multidigraph.h" #include "utils/graph/multidigraph/multidigraph_view.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/hash/unordered_set.h" #include @@ -50,24 +51,28 @@ std::optional return std::nullopt; } -std::unordered_set> +std::unordered_set find_all_extended_series_reductions(MultiDiGraphView const &g) { - std::unordered_map> incoming_edges = - get_incoming_edges(g); - std::unordered_map> outgoing_edges = - get_outgoing_edges(g); + + auto incoming_edges_map = get_incoming_edges(g, get_nodes(g)); + auto outgoing_edges_map = get_outgoing_edges(g, get_nodes(g)); + std::unordered_map> strands; std::unordered_map node_to_head_of_strand; + for (Node const &n : get_topological_ordering(g)) { - if ((incoming_edges.at(n).size() == 1) && - (outgoing_edges.at(n).size() == 1)) { - MultiDiEdge incoming = get_only(incoming_edges.at(n)); - MultiDiEdge outgoing = get_only(outgoing_edges.at(n)); + if ((incoming_edges_map.at(n).size() == 1) && + (outgoing_edges_map.at(n).size() == 1)) { + + MultiDiEdge incoming = get_only(incoming_edges_map.at(n)); + MultiDiEdge outgoing = get_only(outgoing_edges_map.at(n)); Node pre = g.get_multidiedge_src(incoming); + if (contains_key(node_to_head_of_strand, pre)) { Node head = node_to_head_of_strand.at(pre); node_to_head_of_strand.emplace(n, head); strands.at(head).push_back(outgoing); + } else { node_to_head_of_strand.emplace(n, n); strands[n].push_back(incoming); @@ -75,7 +80,10 @@ std::unordered_set> } } } - return unordered_set_of(values(strands)); + + return transform(unordered_set_of(values(strands)), [&](auto const &edges) { + return ExtendedSeriesReduction{edges}; + }); } MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { @@ -87,16 +95,18 @@ MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { return g.add_edge(pre_node, post_node); } -MultiDiEdge apply_extended_series_reduction( - MultiDiGraph &g, std::vector const &series_edges) { +MultiDiEdge + apply_extended_series_reduction(MultiDiGraph &g, + ExtendedSeriesReduction const &reduction) { - Node first = g.get_multidiedge_src(series_edges.at(0)); - Node last = g.get_multidiedge_dst(series_edges.at(series_edges.size() - 1)); + Node first = g.get_multidiedge_src(reduction.edges.at(0)); + Node last = g.get_multidiedge_dst(reduction.edges.back()); std::vector internal_nodes; - for (MultiDiEdge const &e : subvec(series_edges, std::nullopt, -1)) { + for (MultiDiEdge const &e : subvec(reduction.edges, std::nullopt, -1)) { internal_nodes.push_back(g.get_multidiedge_dst(e)); } + for (Node const &n : internal_nodes) { g.remove_node(n); } diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 7bb039d314..e8f0a443c4 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,12 +1,8 @@ #include "utils/graph/views/views.h" -#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" -#include "utils/disjoint_set.h" -#include "utils/exception.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/directed_edge_query.h" -#include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" #include "utils/graph/query_set.h" #include "utils/graph/undirected/undirected_edge_query.h" @@ -118,8 +114,7 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + g.query_edges(UndirectedEdgeQuery{query_union(q.srcs, q.dsts)}); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc index fc42d25eea..9d686ab814 100644 --- a/lib/utils/test/src/utils/containers/contains.cc +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -10,13 +10,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("std::vector") { std::vector v = {1, 2, 3, 4, 5}; CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); + CHECK_FALSE(contains(v, 6)); } SUBCASE("std::unordered_set") { std::unordered_set s = {1, 2, 3, 4, 5}; CHECK(contains(s, 3)); - CHECK(!contains(s, 6)); + CHECK_FALSE(contains(s, 6)); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc index 0817c69e06..fd39449c2c 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -8,99 +8,119 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("DiGraph - algorithms.cc") { + TEST_CASE("get_edges(DiGraph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); std::vector e = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[1], n[2]}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, }; add_edges(g, e); - SUBCASE("get_edges") { - SUBCASE("Base") { - std::unordered_set correct = unordered_set_of(e); - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge") { - g.add_edge(DirectedEdge{n[3], n[1]}); - std::unordered_set correct = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[3], n[1]}, - }; - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - - SUBCASE("Removing an edge") { - g.remove_edge(DirectedEdge{n[0], n[3]}); - std::unordered_set correct = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[2]}, - }; - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); } - SUBCASE("get_sinks") { - SUBCASE("Base") { - std::unordered_set correct = {n[2], n[3]}; - std::unordered_set result = get_sinks(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge to remove a sink") { - g.add_edge(DirectedEdge{n[3], n[2]}); - std::unordered_set correct = {n[2]}; - std::unordered_set result = get_sinks(g); - CHECK(result == correct); - } - - SUBCASE("Creating a cycle") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set result = get_sinks(g); - std::unordered_set correct = {n[3]}; - CHECK(result == correct); - } + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n.at(3), n.at(1)}); + std::unordered_set correct = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(3), n.at(1)}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); } - SUBCASE("get_sources") { - SUBCASE("Base") { - std::unordered_set correct = {n[0]}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge to remove a source") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set correct = {}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Removing an edge to create a new source") { - g.remove_edge(DirectedEdge{n[0], n[1]}); - std::unordered_set correct = {n[0], n[1]}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Creating a cycle") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set result = get_sources(g); - std::unordered_set correct = {}; - CHECK(result.empty()); - } + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n.at(0), n.at(3)}); + std::unordered_set correct = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + } + + TEST_CASE("get_terminal_nodes(DiGraph)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = {n.at(2), n.at(3)}; + std::unordered_set result = get_terminal_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a terminal node") { + g.add_edge(DirectedEdge{n.at(3), n.at(2)}); + std::unordered_set correct = {n.at(2)}; + std::unordered_set result = get_terminal_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set result = get_terminal_nodes(g); + std::unordered_set correct = {n.at(3)}; + CHECK(result == correct); + } + } + + TEST_CASE("get_initial_nodes(DiGraph)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = {n.at(0)}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set correct = {}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n.at(0), n.at(1)}); + std::unordered_set correct = {n.at(0), n.at(1)}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set result = get_initial_nodes(g); + std::unordered_set correct = {}; + CHECK(result.empty()); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc index 1dde5c8f69..ee7ead009e 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc @@ -1,70 +1,68 @@ #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/algorithms.h" -#include "utils/graph/digraph/algorithms/get_successors.h" -#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("directed_edge_query") { - DiGraph g = DiGraph::create(); + TEST_CASE("directed_edge_query_all") { + Node n1{0}, n2{1}, n3{2}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; - std::vector n = add_nodes(g, 5); + DirectedEdgeQuery result = directed_edge_query_all(); - add_edges(g, - {DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(1), n.at(3)}}); + CHECK(matches_edge(result, e1)); + CHECK(matches_edge(result, e2)); + } - SUBCASE("directed_edge_query_all") { + TEST_CASE("matches_edge") { + Node n1{0}, n2{1}, n3{2}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; - DirectedEdgeQuery result = directed_edge_query_all(); + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{n1}, query_set{n2}}; - CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); - CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); - CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); - } + CHECK(matches_edge(query, e1)); + CHECK_FALSE(matches_edge(query, e2)); + + DirectedEdge flipped_edge = DirectedEdge{n2, n1}; + CHECK_FALSE(matches_edge(query, flipped_edge)); + } - SUBCASE("matches_edge") { - DirectedEdgeQuery q = - DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; + TEST_CASE("query_intersection") { + Node n1{0}, n2{1}, n3{2}, n4{3}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; + DirectedEdge e3 = DirectedEdge{n3, n4}; - CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); - CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n1, n2}, query_set{n2, n3}}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{query_set{n2, n3}, query_set{n3, n4}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery expected = + DirectedEdgeQuery{query_set{n2}, query_set{n3}}; + + CHECK(result == expected); } - SUBCASE("query_intersection") { - SUBCASE("standard intersection") { - DirectedEdgeQuery q1 = DirectedEdgeQuery{ - query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; - DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, - query_set{n.at(2), n.at(3)}}; - - DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery correct = DirectedEdgeQuery{ - query_set{n.at(1)}, - query_set{n.at(2)}, - }; - - CHECK(result == correct); - } - SUBCASE("intersection with std::nullopt") { - DirectedEdgeQuery q1 = - DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; - DirectedEdgeQuery q2 = - DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; - - DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery correct = DirectedEdgeQuery{ - query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; - CHECK(result == correct); - } + SUBCASE("intersection with matchall") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n1, n2}, matchall()}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{matchall(), query_set{n3, n4}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery expected = + DirectedEdgeQuery{query_set{n1, n2}, query_set{n3, n4}}; + + CHECK(result == expected); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc index e9151b53e5..17bea2210f 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -9,28 +9,29 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_dominators") { DiGraph g = DiGraph::create(); + SUBCASE("acyclic graph") { + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); - std::vector n = add_nodes(g, 4); - std::vector e = { - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - }; - add_edges(g, e); - - SUBCASE("single node") { - Node node = n.at(2); - std::unordered_set correct = {n.at(0), n.at(2)}; - std::unordered_set result = get_dominators(g, node); - CHECK(correct == result); - } + SUBCASE("get_dominators(DiGraph, Node)") { + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); + } - SUBCASE("multiple nodes") { - std::unordered_set nodes = {n.at(1), n.at(3)}; - std::unordered_set result = get_dominators(g, nodes); - std::unordered_set correct = {n.at(0)}; - CHECK(correct == result); + SUBCASE("get_dominators(DiGraph, std::unordered_set)") { + std::unordered_set nodes = {n.at(1), n.at(3)}; + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n.at(0)}; + CHECK(correct == result); + } } SUBCASE("graph with cycles") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc index 0d8e7ca53a..f778cfbd22 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc @@ -1,4 +1,5 @@ #include "utils/graph/traversal.h" +#include "utils/containers/contains.h" #include "utils/fmt/vector.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" @@ -12,40 +13,55 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_unchecked_dfs_ordering") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); - add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}); - SUBCASE("simple path") { - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); + SUBCASE("linear path") { + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(0)}); CHECK(correct == result); } + + SUBCASE("diamond path") { + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(3), n.at(2), n.at(3)}, + {n.at(0), n.at(2), n.at(3), n.at(1), n.at(3)}}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(0)}); + CHECK(contains(corrects, result)); + } } TEST_CASE("get_bfs_ordering") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 6); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[3]}, - DirectedEdge{n[2], n[3]}, - DirectedEdge{n[3], n[4]}, - DirectedEdge{n[4], n[5]}}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}); SUBCASE("branching path") { std::unordered_set> corrects = { - {n[0], n[1], n[2], n[3], n[4], n[5]}, - {n[0], n[2], n[1], n[3], n[4], n[5]}}; - std::vector result = get_bfs_ordering(g, {n[0]}); + {n.at(0), n.at(1), n.at(2), n.at(3), n.at(4), n.at(5)}, + {n.at(0), n.at(2), n.at(1), n.at(3), n.at(4), n.at(5)}}; + std::vector result = get_bfs_ordering(g, {n.at(0)}); CHECK(contains(corrects, result)); } SUBCASE("isolated node") { - std::vector correct = {n[5]}; - std::vector result = get_bfs_ordering(g, {n[5]}); + std::vector correct = {n.at(5)}; + std::vector result = get_bfs_ordering(g, {n.at(5)}); CHECK(correct == result); } @@ -53,15 +69,15 @@ TEST_SUITE(FF_TEST_SUITE) { g = DiGraph::create(); n = add_nodes(g, 3); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[0]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[0]}, - DirectedEdge{n[2], n[1]}}); - std::unordered_set> corrects = {{n[0], n[1], n[2]}, - {n[0], n[2], n[1]}}; - std::vector result = get_bfs_ordering(g, {n[0]}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(2), n.at(1)}}); + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(2)}, {n.at(0), n.at(2), n.at(1)}}; + std::vector result = get_bfs_ordering(g, {n.at(0)}); CHECK(contains(corrects, result)); } } @@ -70,42 +86,49 @@ TEST_SUITE(FF_TEST_SUITE) { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); SUBCASE("simple path") { - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_dfs_ordering(g, {n[0]}); + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("start from non-initial node") { + std::vector correct = {n.at(1), n.at(2), n.at(3)}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(1)}); CHECK(correct == result); } SUBCASE("with cycle") { - g.add_edge(DirectedEdge{n[3], n[1]}); - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_dfs_ordering(g, {n[0]}); + g.add_edge(DirectedEdge{n.at(3), n.at(1)}); + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); CHECK(correct == result); } SUBCASE("branching") { - g.add_edge(DirectedEdge{n[1], n[3]}); + g.add_edge(DirectedEdge{n.at(1), n.at(3)}); std::unordered_set> corrects = { - {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; - std::vector result = get_dfs_ordering(g, {n[0]}); + {n.at(0), n.at(1), n.at(2), n.at(3)}, + {n.at(0), n.at(1), n.at(3), n.at(2)}}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); CHECK(contains(corrects, result)); } SUBCASE("disconnected") { - g.remove_edge(DirectedEdge{n[2], n[3]}); - std::vector correct = {n[0], n[1], n[2]}; - std::vector result = get_dfs_ordering(g, {n[0]}); + g.remove_edge(DirectedEdge{n.at(2), n.at(3)}); + std::vector correct = {n.at(0), n.at(1), n.at(2)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); CHECK(correct == result); } SUBCASE("isolated node") { - g.remove_edge(DirectedEdge{n[2], n[3]}); - std::vector correct = {n[3]}; - std::vector result = get_dfs_ordering(g, {n[3]}); + g.remove_edge(DirectedEdge{n.at(2), n.at(3)}); + std::vector correct = {n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(3)}); CHECK(correct == result); } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index b5943cd99f..b15b8a9d7d 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -11,7 +11,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { + TEST_CASE("get_incoming_edges") { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); @@ -19,17 +19,33 @@ TEST_SUITE(FF_TEST_SUITE) { {{n.at(0), n.at(0)}, {n.at(0), n.at(1)}, {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}}); + {n.at(1), n.at(0)}, + {n.at(2), n.at(0)}}); - SUBCASE("node has incoming edges") { - std::unordered_set result = get_incoming_edges(g, n.at(1)); - std::unordered_set correct = {edges.at(1), edges.at(2)}; - CHECK(result == correct); + SUBCASE("get_incoming_edges(MultiDiGraphView, Node)") { + + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } } - SUBCASE("node has no incoming edges") { - std::unordered_set result = get_incoming_edges(g, n.at(2)); - std::unordered_set correct = {}; + SUBCASE("get_incoming_edges(MultiDiGraphView, std::unordered_set)") { + + std::unordered_set ns = {n.at(0), n.at(2)}; + std::unordered_map> result = + get_incoming_edges(g, ns); + + std::unordered_map> correct = { + {n.at(0), {edges.at(0), edges.at(3), edges.at(4)}}, {n.at(2), {}}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index d4748e8422..69b38090d3 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -11,29 +11,45 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { + TEST_CASE("get_outgoing_edges") { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); - std::vector> input = { - {n.at(0), n.at(0)}, - {n.at(0), n.at(1)}, - {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}, - }; + std::vector edges = add_edges(g, + { + {n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(0)}, + }); - std::vector edges = add_edges(g, input); + SUBCASE("get_outgoing_edges(MultiDiGraphView, Node)") { - SUBCASE("node has outgoing edges") { - std::unordered_set result = get_outgoing_edges(g, n.at(0)); - std::unordered_set correct = { - edges.at(0), edges.at(1), edges.at(2)}; - CHECK(result == correct); + SUBCASE("node has outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2), edges.at(3)}; + CHECK(result == correct); + } + + SUBCASE("node has no outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } } - SUBCASE("node has no outgoing edges") { - std::unordered_set result = get_outgoing_edges(g, n.at(2)); - std::unordered_set correct = {}; + SUBCASE("get_outgoing_edges(MultiDiGraphView, std::unordered_set)") { + + std::unordered_set ns = {n.at(0), n.at(1)}; + std::unordered_map> result = + get_outgoing_edges(g, ns); + + std::unordered_map> correct = { + {n.at(0), {edges.at(0), edges.at(1), edges.at(2), edges.at(3)}}, + {n.at(1), {edges.at(4)}}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index 3a8a5e9a60..51606bc9d6 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -245,6 +245,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } } + TEST_CASE("find_all_extended_series_reductions") { MultiDiGraph g = MultiDiGraph::create(); @@ -257,14 +258,14 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(2), n.at(3)}, }); - std::unordered_set> result = + std::unordered_set result = find_all_extended_series_reductions(g); - std::unordered_set> correct = { - {e[0], e[1], e[2]}}; + std::unordered_set correct = { + ExtendedSeriesReduction({e.at(0), e.at(1), e.at(2)})}; CHECK(result == correct); } - SUBCASE("2 linear strands") { + SUBCASE("2 linear strands with a common terminal node") { std::vector n = add_nodes(g, 4); std::vector e = add_edges(g, {{n.at(0), n.at(1)}, @@ -272,10 +273,11 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(1), n.at(3)}, {n.at(2), n.at(3)}}); - std::unordered_set> result = + std::unordered_set result = find_all_extended_series_reductions(g); - std::unordered_set> correct = {{e[0], e[2]}, - {e[1], e[3]}}; + std::unordered_set correct = { + ExtendedSeriesReduction({e.at(0), e.at(2)}), + ExtendedSeriesReduction({e.at(1), e.at(3)})}; CHECK(result == correct); } @@ -294,10 +296,12 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(6), n.at(8)}, {n.at(7), n.at(8)}}); - std::unordered_set> result = + std::unordered_set result = find_all_extended_series_reductions(g); - std::unordered_set> correct = { - {e[0], e[2], e[7]}, {e[3], e[6]}, {e[5], e[9]}}; + std::unordered_set correct = { + ExtendedSeriesReduction({e.at(0), e.at(2), e.at(7)}), + ExtendedSeriesReduction({e.at(3), e.at(6)}), + ExtendedSeriesReduction({e.at(5), e.at(9)})}; CHECK(result == correct); } } @@ -310,7 +314,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector e = add_edges( g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); - std::vector reduction = {e.at(0), e.at(1), e.at(2)}; + ExtendedSeriesReduction reduction({e.at(0), e.at(1), e.at(2)}); MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); @@ -355,7 +359,7 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(5), n.at(7)}, }); - std::vector reduction = {e.at(3), e.at(4), e.at(5)}; + ExtendedSeriesReduction reduction({e.at(3), e.at(4), e.at(5)}); MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); @@ -370,9 +374,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set result_edges = get_edges(g); std::unordered_set correct_edges = [&] { std::unordered_set new_edges = unordered_set_of(e); - new_edges.erase(e.at(3)); - new_edges.erase(e.at(4)); - new_edges.erase(e.at(5)); + new_edges = set_minus(new_edges, {e.at(3), e.at(4), e.at(5)}); new_edges.insert(returned_edge); return new_edges; }(); diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc index 179cce7db7..e6b0575ff5 100644 --- a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -7,55 +7,86 @@ using namespace FlexFlow; -TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); - SUBCASE("disjoint nodes") { - std::vector n = add_nodes(g, 3); + SUBCASE("disjoint nodes") { + std::vector n = add_nodes(g, 3); - std::unordered_set> correct = { - {n[0]}, - {n[1]}, - {n[2]}, - }; - std::unordered_set> result = - get_connected_components(g); + std::unordered_set> correct = { + {n.at(0)}, + {n.at(1)}, + {n.at(2)}, + }; + std::unordered_set> result = + get_connected_components(g); - CHECK(correct == result); - } + CHECK(correct == result); + } - SUBCASE("2 components") { - std::vector n = add_nodes(g, 4); - add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); + SUBCASE("1 component") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(3)}}, + UndirectedEdge{{n.at(3), n.at(0)}}, + }); - std::unordered_set> correct = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - std::unordered_set> result = - get_connected_components(g); + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2), n.at(3)}, + }; + std::unordered_set> result = + get_connected_components(g); - CHECK(correct == result); - } + CHECK(correct == result); + } + + SUBCASE("2 components") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(2), n.at(1)}}}); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2)}, + {n.at(3)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("3 components") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(0), n.at(2)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(3), n.at(4)}}, + }); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2)}, + {n.at(3), n.at(4)}, + {n.at(5)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("empty graph") { + std::unordered_set> correct = {}; + std::unordered_set> result = + get_connected_components(g); - SUBCASE("3 components") { - std::vector n = add_nodes(g, 6); - add_edges(g, - { - UndirectedEdge{{n[0], n[1]}}, - UndirectedEdge{{n[0], n[2]}}, - UndirectedEdge{{n[1], n[2]}}, - UndirectedEdge{{n[3], n[4]}}, - }); - - std::unordered_set> correct = { - {n[0], n[1], n[2]}, - {n[3], n[4]}, - {n[5]}, - }; - std::unordered_set> result = - get_connected_components(g); - - CHECK(correct == result); + CHECK(correct == result); + } } } diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc index 7973cf8af5..6454379118 100644 --- a/lib/utils/test/src/utils/graph/undirected/undirected.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -41,35 +41,3 @@ TEST_SUITE(FF_TEST_SUITE) { }); } } -/* static_assert(is_fmtable::value, ""); */ - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE( - "UndirectedGraph implementations", T, HashmapUndirectedGraph) { - - RC_SUBCASE("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct( - gen::construct>(gen::elementOf(n), - gen::elementOf(n)))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); - }); - } -} From a17c76e65506a3d92e9d0a30e8202c960a43dbc9 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 11 Jan 2025 15:42:31 -0800 Subject: [PATCH 10/16] added strict=true to zip --- lib/utils/include/utils/containers/zip.h | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 0f6dbed1d3..78e449494d 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -4,12 +4,21 @@ #include #include #include +#include "utils/exception.h" +#include "fmt/format.h" namespace FlexFlow { template std::vector> zip(std::vector const &l, - std::vector const &r) { + std::vector const &r, + bool strict = false) { + if (strict && l.size() != r.size()) { + throw mk_runtime_error(fmt::format( + "When strict = true, vector sizes must match. Got vectors of length {} and {}", + l.size(), r.size())); + } + std::vector> result; for (int i = 0; i < std::min(l.size(), r.size()); i++) { result.push_back(std::make_pair(l.at(i), r.at(i))); @@ -19,8 +28,15 @@ std::vector> zip(std::vector const &l, template std::vector> zip(std::vector const &a, - std::vector const &b, - std::vector const &c) { + std::vector const &b, + std::vector const &c, + bool strict = false) { + if (strict && (a.size() != b.size() || b.size() != c.size())) { + throw std::runtime_error(fmt::format( + "When strict = true, vectors sizes must match. Got vectors of length {}, {} and {}", + a.size(), b.size(), c.size())); + } + std::vector> result; for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); From a56201c2d3e88a72e9b04bbb273035df87743883 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 11 Jan 2025 15:45:04 -0800 Subject: [PATCH 11/16] fmt --- lib/utils/include/utils/containers/zip.h | 33 +++++++++++++----------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 78e449494d..65722e049f 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,22 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H +#include "fmt/format.h" +#include "utils/exception.h" #include #include #include -#include "utils/exception.h" -#include "fmt/format.h" namespace FlexFlow { template -std::vector> zip(std::vector const &l, - std::vector const &r, - bool strict = false) { +std::vector> + zip(std::vector const &l, std::vector const &r, bool strict = false) { if (strict && l.size() != r.size()) { - throw mk_runtime_error(fmt::format( - "When strict = true, vector sizes must match. Got vectors of length {} and {}", - l.size(), r.size())); + throw mk_runtime_error(fmt::format("When strict = true, vector sizes must " + "match. Got vectors of length {} and {}", + l.size(), + r.size())); } std::vector> result; @@ -28,15 +28,18 @@ std::vector> zip(std::vector const &l, template std::vector> zip(std::vector const &a, - std::vector const &b, - std::vector const &c, - bool strict = false) { + std::vector const &b, + std::vector const &c, + bool strict = false) { if (strict && (a.size() != b.size() || b.size() != c.size())) { - throw std::runtime_error(fmt::format( - "When strict = true, vectors sizes must match. Got vectors of length {}, {} and {}", - a.size(), b.size(), c.size())); + throw std::runtime_error( + fmt::format("When strict = true, vectors sizes must match. Got vectors " + "of length {}, {} and {}", + a.size(), + b.size(), + c.size())); } - + std::vector> result; for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); From cc41dbb4a8fceeae0e48efec3b5c46b6e76f942b Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Tue, 21 Jan 2025 16:43:50 -0800 Subject: [PATCH 12/16] minor fixes --- lib/utils/include/utils/graph/README.md | 14 ++++++++------ lib/utils/include/utils/graph/algorithms.h | 5 ----- lib/utils/include/utils/graph/digraph/algorithms.h | 12 ++++++++++-- .../extended_parallel_reduction.struct.toml | 7 +++++++ .../extended_series_reduction.struct.toml | 8 ++++++++ .../graph/series_parallel/parallel_reduction.h | 2 -- 6 files changed, 33 insertions(+), 15 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 41777b9b9a..5cf0c88015 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -78,7 +78,7 @@ flowchart TD E --> E ``` -Note that the node names are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. +Note that the node names are completely arbitrary: they have no apparent ordering or other meaning besides representing the topology of the graph. This is the case with all of the 4 core graph classes. Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. @@ -127,12 +127,13 @@ The primary abstraction for representing computation graphs / task graphs is the At a high level, nodes represent multivariate functions (from tuples of inputs to tuple of outputs), while edges represent value uses of such functions. `DataflowGraph` is similar to `MultiDiGraph`, but with the following important differences: - - The edges entering, exiting a given nodes have a well-defined order. + - The edges entering, exiting a given nodes have a well-defined order. + - The outputs of a given node also have a well-defined order. - `DataflowGraph`s are directed acyclic graphs. This is enforced by the interface used to construct them, since a node can only be added to the graph after all of its predecessor nodes have already been added. The main components of `DataflowGraph` are as follows: -- `DataflowInput`: used to represent the ordered sequence of incoming dependencies (arguments) of a given node (operator). -- `DataflowOutput`: used to represent the ordered sequence of outgoing results (value uses) from a given node (operator). +- `DataflowInput`: used to denote an entry in the ordered sequence of incoming dependencies (arguments) of a given node (operator). +- `DataflowOutput`: used to denote an entry in the ordered sequence of outgoing results (value uses) from a given node (operator). - `DataflowEdge`: wrapper around a `DataflowInput`, `DataflowOutput` pair between 2 nodes. - `NodeAddedResult`: returned upon adding a new node. Contains the newly generated `Node` and the vector of `DataflowOutput`s for the given node. @@ -220,15 +221,16 @@ flowchart TD ### Open Dataflow Variant -`Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. +`Open` should be interpreted in the topological sense: that is, a graph that contains some edges where one of the edge's 2 nodes is not present in the graph itself. This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. +`DataflowGraphInput` is used to represent the open (incoming) inputs to the graph. Note that, unlike `DataFlowInput`, `DataflowGraphInput`s are unordered (given that they are inputs to possibly several different nodes within the graph). ### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`, which allow users to label different components of the graph. - `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. -- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). +- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s. While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index ff7a7dcad2..c1ebd7b534 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -138,11 +138,6 @@ std::unordered_set get_neighbors(DiGraphView const &, Node const &); // std::unordered_set get_neighbors(MultiDiGraphView const &, Node const // &); -// return the set of nodes without incoming edges -std::unordered_set get_initial_nodes(DiGraphView const &); - -// return the set of nodes without outgoing edges -std::unordered_set get_terminal_nodes(DiGraphView const &); // std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); // std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h index fdced8a05c..67cfba13ff 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -6,8 +6,16 @@ namespace FlexFlow { std::unordered_set get_edges(DiGraphView const &); -std::unordered_set get_initial_nodes(DiGraphView const &); -std::unordered_set get_terminal_nodes(DiGraphView const &); + +/** + * @brief Returns the set of nodes in the graph with no incoming edges. + */ +std::unordered_set get_initial_nodes(DiGraphView const &graph); + +/** + * @brief Returns the set of nodes in the graph with no outgoing edges. + */ +std::unordered_set get_terminal_nodes(DiGraphView const &graph); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml index 9c1ed68730..8e30b75009 100644 --- a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml @@ -6,6 +6,13 @@ features = [ "fmt", ] +docstring = """\ +/** + * @brief An ExtendedParallelReduction is a unordered collection of + * `MultiDiEdge`s such that they share a common source and destination node. + */ +""" + includes = [ "utils/graph/multidigraph/multidiedge.dtg.h", "" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml index f1cf0ccde3..b58a0d5068 100644 --- a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml @@ -1,5 +1,13 @@ namespace = "FlexFlow" name = "ExtendedSeriesReduction" + +docstring = """\ +/** + * @brief An ExtendedParallelReduction is a unordered collection of + * `MultiDiEdge`s such that they share a common source and destination node. + */ +""" + features = [ "eq", "hash", diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 7a3a7a021c..598548bec1 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -17,8 +17,6 @@ std::optional /** * @brief Finds all ExtendedParallelReduction for a given MultiDiGraph - * @details An ExtendedParallelReduction is a unordered collection of - * `MultiDiEdge`s such that they share a common source and destination node. */ std::unordered_set find_all_extended_parallel_reductions(MultiDiGraphView const &); From 070e961f58ce533db431499979dab2ba92355501 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Tue, 21 Jan 2025 16:46:24 -0800 Subject: [PATCH 13/16] fix --- .../series_parallel/extended_parallel_reduction.struct.toml | 6 ++---- .../series_parallel/extended_series_reduction.struct.toml | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml index 8e30b75009..ca43a987e2 100644 --- a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml @@ -7,10 +7,8 @@ features = [ ] docstring = """\ -/** - * @brief An ExtendedParallelReduction is a unordered collection of - * `MultiDiEdge`s such that they share a common source and destination node. - */ +@brief An ExtendedParallelReduction is a unordered collection of +`MultiDiEdge`s such that they share a common source and destination node. """ includes = [ diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml index b58a0d5068..98b5b4c40f 100644 --- a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml @@ -2,10 +2,8 @@ namespace = "FlexFlow" name = "ExtendedSeriesReduction" docstring = """\ -/** - * @brief An ExtendedParallelReduction is a unordered collection of - * `MultiDiEdge`s such that they share a common source and destination node. - */ +@brief An ExtendedParallelReduction is a unordered collection of +`MultiDiEdge`s such that they share a common source and destination node. """ features = [ From 2afc768c78c1456fc3d019f53cb75e36521683cc Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Tue, 21 Jan 2025 19:23:13 -0800 Subject: [PATCH 14/16] fixes --- lib/utils/include/utils/commutative_pair.h | 2 +- lib/utils/include/utils/graph/algorithms.h | 1 - .../extended_series_reduction.struct.toml | 8 +- .../get_series_parallel_decomposition.h | 7 + .../graph/series_parallel/series_reduction.h | 7 - .../algorithms/make_undirected_edge.h | 12 ++ .../instances/hashmap_undirected_graph.cc | 7 +- .../get_series_parallel_decomposition.cc | 125 +++++++++++++++--- .../series_parallel/parallel_reduction.cc | 6 +- .../algorithms/make_undirected_edge.cc | 10 ++ lib/utils/src/utils/graph/views/views.cc | 3 +- lib/utils/test/src/utils/commutative_pair.cc | 6 +- .../graph/digraph/algorithms/algorithms.cc | 6 +- .../utils/graph/digraph/algorithms/digraph.cc | 22 +-- .../graph/series_parallel/series_reduction.cc | 22 +-- .../algorithms/get_connected_components.cc | 23 ++-- .../src/utils/graph/undirected/undirected.cc | 98 ++++++++++---- lib/utils/test/src/utils/graph/views/views.cc | 33 ++--- 18 files changed, 286 insertions(+), 112 deletions(-) create mode 100644 lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h create mode 100644 lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc diff --git a/lib/utils/include/utils/commutative_pair.h b/lib/utils/include/utils/commutative_pair.h index 12cc16f90e..1a4009fedb 100644 --- a/lib/utils/include/utils/commutative_pair.h +++ b/lib/utils/include/utils/commutative_pair.h @@ -13,7 +13,7 @@ template struct commutative_pair { public: commutative_pair() = delete; - commutative_pair(T const &x, T const &y) : first(x), second(y) {} + explicit commutative_pair(T const &x, T const &y) : first(x), second(y) {} bool operator==(commutative_pair const &other) const { return this->tie() == other.tie() || this->rtie() == other.tie(); diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index c1ebd7b534..ca59f997c7 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -138,7 +138,6 @@ std::unordered_set get_neighbors(DiGraphView const &, Node const &); // std::unordered_set get_neighbors(MultiDiGraphView const &, Node const // &); - // std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); // std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); // std::unordered_set get_open_sources(OpenMultiDiGraphView const &g); diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml index 98b5b4c40f..ed999a22df 100644 --- a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml @@ -2,8 +2,12 @@ namespace = "FlexFlow" name = "ExtendedSeriesReduction" docstring = """\ -@brief An ExtendedParallelReduction is a unordered collection of -`MultiDiEdge`s such that they share a common source and destination node. +@details An `ExtendedSeriesReduction` is an ordered collection of +`MultiDiEdges` such that: +- The destination node of the nth edge is the same as the source node of the + (n+1)th edge. +- Such a node (intermediate node) has exactly two edges: one incoming (nth + edge) and one outgoing ((n+1)th edge). """ features = [ diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h index 5f492c1aeb..7e9aa5606d 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -10,6 +10,13 @@ namespace FlexFlow { std::optional get_series_parallel_decomposition(DiGraphView const &); +/** + * @brief Unoptimized version of get_series_parallel_decomposition, used for + * reference. + */ +std::optional + get_series_parallel_decomposition_unoptimized(DiGraphView const &g); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h index 3e281066d4..9d11e2bdfb 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -19,13 +19,6 @@ std::optional find_series_reduction(MultiDiGraphView const &); /** * @brief Finds all the ExtendedSeriesReduction structures in a given graph. * - * @details An `ExtendedSeriesReduction` is an ordered collection of - * `MultiDiEdges` such that: - * - The destination node of the nth edge is the same as the source node of the - * (n+1)th edge. - * - Such a node (intermediate node) has exactly two edges: one incoming (nth - * edge) and one outgoing ((n+1)th edge). - * * For example, in the following graph: * * A -> B -> D -> E diff --git a/lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h b/lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h new file mode 100644 index 0000000000..e6c834f4fb --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_MAKE_UNDIRECTED_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_MAKE_UNDIRECTED_EDGE_H + +#include "utils/graph/undirected/undirected_edge.dtg.h" + +namespace FlexFlow { + +UndirectedEdge make_undirected_edge(Node const &, Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 5d16304701..6713fafe41 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -1,7 +1,9 @@ #include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/commutative_pair.h" #include "utils/containers/contains_key.h" #include "utils/containers/keys.h" #include "utils/exception.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" namespace FlexFlow { @@ -18,6 +20,9 @@ void HashmapUndirectedGraph::add_node_unsafe(Node const &node) { } void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { + for (Node const &neighbor : this->adjacency.at(n)) { + this->adjacency.at(neighbor).erase(n); + } this->adjacency.erase(n); } @@ -49,7 +54,7 @@ std::unordered_set HashmapUndirectedGraph::query_edges( std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { for (auto const &dst : apply_query(query.nodes, src_kv.second)) { - result.insert(UndirectedEdge{{src_kv.first, dst}}); + result.insert(make_undirected_edge(src_kv.first, dst)); } } return result; diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index b45e62eae7..a01c94a708 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -43,14 +43,8 @@ std::optional .as_unordered_map(), [](Node const &n) { return SeriesParallelDecomposition{n}; }); - while (true) { - int reductions = 0; - - std::unordered_set parallel_reductions = - find_all_extended_parallel_reductions(ttsp); - - if (!parallel_reductions.empty()) { - for (ExtendedParallelReduction parallel_reduction : parallel_reductions) { + auto handle_parallel_reduction = + [&](ExtendedParallelReduction const ¶llel_reduction) { MultiDiEdge merged = apply_extended_parallel_reduction(ttsp, parallel_reduction); @@ -62,14 +56,12 @@ std::optional ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); - } - reductions++; - } - std::unordered_set series_reductions = - find_all_extended_series_reductions(ttsp); - if (!series_reductions.empty()) { - for (ExtendedSeriesReduction series_reduction : series_reductions) { + return new_tree; + }; + + auto handle_series_reduction = + [&](ExtendedSeriesReduction const &series_reduction) { MultiDiEdge merged = apply_extended_series_reduction(ttsp, series_reduction); @@ -82,11 +74,33 @@ std::optional ttsp_edge_to_sp_tree.erase(e); } ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + return new_tree; + }; + + while (true) { + bool reduction_has_happened = false; + + std::unordered_set parallel_reductions = + find_all_extended_parallel_reductions(ttsp); + + if (!parallel_reductions.empty()) { + for (ExtendedParallelReduction parallel_reduction : parallel_reductions) { + handle_parallel_reduction(parallel_reduction); } - reductions++; + reduction_has_happened = true; } - if (reductions > 0) { + std::unordered_set series_reductions = + find_all_extended_series_reductions(ttsp); + if (!series_reductions.empty()) { + for (ExtendedSeriesReduction series_reduction : series_reductions) { + handle_series_reduction(series_reduction); + } + reduction_has_happened = true; + } + + if (reduction_has_happened) { continue; } @@ -101,4 +115,81 @@ std::optional } } +std::optional + get_series_parallel_decomposition_unoptimized(DiGraphView const &g) { + + DiGraphView transitively_reduced = transitive_reduction(g); + + InverseLineGraphResult inverse_line_graph_result = ({ + std::optional maybe_line_graph = + get_inverse_line_graph(transitively_reduced); + if (!maybe_line_graph.has_value()) { + return std::nullopt; + } + + maybe_line_graph.value(); + }); + + MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( + inverse_line_graph_result.graph); + std::unordered_map + ttsp_edge_to_sp_tree = map_values( + inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return BinarySPDecompositionTree{n}; }); + + while (true) { + assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); + std::optional maybe_parallel_reduction = + find_parallel_reduction(ttsp); + if (maybe_parallel_reduction.has_value()) { + ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); + auto [e1, e2] = parallel_reduction.edges.ordered(); + MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; + ttsp_edge_to_sp_tree.erase(e1); + ttsp_edge_to_sp_tree.erase(e2); + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + continue; + } + + std::optional maybe_series_reduction = + find_series_reduction(ttsp); + if (maybe_series_reduction.has_value()) { + SeriesReduction series_reduction = maybe_series_reduction.value(); + MultiDiEdge e1 = series_reduction.first; + MultiDiEdge e2 = series_reduction.second; + MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; + ttsp_edge_to_sp_tree.erase(e1); + ttsp_edge_to_sp_tree.erase(e2); + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + continue; + } + + if (get_nodes(ttsp).size() != 2) { + return std::nullopt; + } + if (get_edges(ttsp).size() != 1) { + return std::nullopt; + } + + MultiDiEdge e = get_only(get_edges(ttsp)); + if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { + return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); + } + } +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 3aa677a2f7..cf03db0e8a 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,6 @@ #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/commutative_pair.h" +#include "utils/containers/contains_key.h" #include "utils/containers/get_one_of.h" #include "utils/containers/group_by.h" #include "utils/containers/transform.h" @@ -19,7 +21,7 @@ namespace FlexFlow { ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, MultiDiEdge const &e2) { - return ParallelReduction{{e1, e2}}; + return ParallelReduction{commutative_pair{e1, e2}}; } std::optional @@ -28,7 +30,7 @@ std::optional std::unordered_map seen; for (MultiDiEdge const &edge : get_edges(g)) { DirectedEdge diedge = get_directed_edge(g, edge); - if (seen.find(diedge) != seen.end()) { + if (contains_key(seen, diedge)) { return make_parallel_reduction(seen.at(diedge), edge); } seen.emplace(diedge, edge); diff --git a/lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc b/lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc new file mode 100644 index 0000000000..1c1eb4ae07 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc @@ -0,0 +1,10 @@ +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" +#include "utils/commutative_pair.h" + +namespace FlexFlow { + +UndirectedEdge make_undirected_edge(Node const &n1, Node const &n2) { + return UndirectedEdge{commutative_pair{n1, n2}}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index e8f0a443c4..74234033b3 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -5,6 +5,7 @@ #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/node/node_query.h" #include "utils/graph/query_set.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.h" namespace FlexFlow { @@ -63,7 +64,7 @@ DiGraphView view_subgraph(DiGraphView const &g, } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return UndirectedEdge{{e.src, e.dst}}; + return make_undirected_edge(e.src, e.dst); } std::unordered_set to_undirected_edges( diff --git a/lib/utils/test/src/utils/commutative_pair.cc b/lib/utils/test/src/utils/commutative_pair.cc index 2b91c8b843..af015e2b8c 100644 --- a/lib/utils/test/src/utils/commutative_pair.cc +++ b/lib/utils/test/src/utils/commutative_pair.cc @@ -7,9 +7,9 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("commutative_pair") { - commutative_pair x = {1, 2}; - commutative_pair y = {2, 1}; - commutative_pair z = {1, 1}; + commutative_pair x = commutative_pair{1, 2}; + commutative_pair y = commutative_pair{2, 1}; + commutative_pair z = commutative_pair{1, 1}; SUBCASE("max and min") { SUBCASE("max") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc index fd39449c2c..f17f0cb106 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -8,7 +8,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_edges(DiGraph)") { + TEST_CASE("get_edges(DiGraphView)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -51,7 +51,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("get_terminal_nodes(DiGraph)") { + TEST_CASE("get_terminal_nodes(DiGraphView)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -84,7 +84,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("get_initial_nodes(DiGraph)") { + TEST_CASE("get_initial_nodes(DiGraphView)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc index 3a3648eec8..e3471a3031 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc @@ -31,23 +31,25 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("query_nodes") { - CHECK(g.query_nodes(node_query_all()) == std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == std::unordered_set{n[0], n[2]}); - std::unordered_set queried_edges = - g.query_edges(directed_edge_query_all()); - std::unordered_set expected = { - e[0], e[1], e[2], e[3], e[4]}; - CHECK(queried_edges == expected); + SUBCASE("query_edges") { + + std::unordered_set queried_edges = + g.query_edges(directed_edge_query_all()); + std::unordered_set expected = { + e[0], e[1], e[2], e[3], e[4]}; + CHECK(queried_edges == expected); - queried_edges = g.query_edges( - DirectedEdgeQuery{query_set{{n[0]}}, query_set{{n[1]}}}); - expected = std::unordered_set{e[0]}; - CHECK(queried_edges == expected); + queried_edges = g.query_edges(DirectedEdgeQuery{ + query_set{{n[0]}}, query_set{{n[1]}}}); + expected = std::unordered_set{e[0]}; + CHECK(queried_edges == expected); + } } SUBCASE("remove_node_unsafe") { g.remove_node_unsafe(n[0]); diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index 51606bc9d6..50a3ea0fc9 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -198,8 +198,8 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(1), n.at(2)}, {n.at(2), n.at(3)}, {n.at(2), n.at(3)}, - {n.at(3), n.at(4)}, //* - {n.at(4), n.at(5)}, //* + {n.at(3), n.at(4)}, + {n.at(4), n.at(5)}, {n.at(5), n.at(6)}, {n.at(5), n.at(7)}, }); @@ -261,7 +261,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set result = find_all_extended_series_reductions(g); std::unordered_set correct = { - ExtendedSeriesReduction({e.at(0), e.at(1), e.at(2)})}; + ExtendedSeriesReduction{{e.at(0), e.at(1), e.at(2)}}}; CHECK(result == correct); } @@ -276,8 +276,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set result = find_all_extended_series_reductions(g); std::unordered_set correct = { - ExtendedSeriesReduction({e.at(0), e.at(2)}), - ExtendedSeriesReduction({e.at(1), e.at(3)})}; + ExtendedSeriesReduction{{e.at(0), e.at(2)}}, + ExtendedSeriesReduction{{e.at(1), e.at(3)}}}; CHECK(result == correct); } @@ -299,9 +299,9 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set result = find_all_extended_series_reductions(g); std::unordered_set correct = { - ExtendedSeriesReduction({e.at(0), e.at(2), e.at(7)}), - ExtendedSeriesReduction({e.at(3), e.at(6)}), - ExtendedSeriesReduction({e.at(5), e.at(9)})}; + ExtendedSeriesReduction{{e.at(0), e.at(2), e.at(7)}}, + ExtendedSeriesReduction{{e.at(3), e.at(6)}}, + ExtendedSeriesReduction{{e.at(5), e.at(9)}}}; CHECK(result == correct); } } @@ -314,7 +314,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector e = add_edges( g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); - ExtendedSeriesReduction reduction({e.at(0), e.at(1), e.at(2)}); + ExtendedSeriesReduction reduction = + ExtendedSeriesReduction{{e.at(0), e.at(1), e.at(2)}}; MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); @@ -359,7 +360,8 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(5), n.at(7)}, }); - ExtendedSeriesReduction reduction({e.at(3), e.at(4), e.at(5)}); + ExtendedSeriesReduction reduction = + ExtendedSeriesReduction{{e.at(3), e.at(4), e.at(5)}}; MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc index e6b0575ff5..20b3eaa74a 100644 --- a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -2,6 +2,7 @@ #include "utils/fmt/unordered_set.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" #include "utils/graph/undirected/undirected_graph.h" #include @@ -29,10 +30,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 4); add_edges(g, { - UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(2), n.at(3)}}, - UndirectedEdge{{n.at(3), n.at(0)}}, + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(3)), + make_undirected_edge(n.at(3), n.at(0)), }); std::unordered_set> correct = { @@ -47,8 +48,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("2 components") { std::vector n = add_nodes(g, 4); add_edges(g, - {UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(2), n.at(1)}}}); + { + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(2), n.at(1)), + }); std::unordered_set> correct = { {n.at(0), n.at(1), n.at(2)}, @@ -64,10 +67,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 6); add_edges(g, { - UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(0), n.at(2)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(3), n.at(4)}}, + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(0), n.at(2)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(3), n.at(4)), }); std::unordered_set> correct = { diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc index 6454379118..10575b069d 100644 --- a/lib/utils/test/src/utils/graph/undirected/undirected.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -1,43 +1,85 @@ -#include "test/utils/rapidcheck.h" -#include "test/utils/rapidcheck/visitable.h" #include "utils/commutative_pair.h" #include "utils/containers/repeat.h" #include "utils/graph/instances/hashmap_undirected_graph.h" #include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.h" #include "utils/graph/undirected/undirected_graph.h" +#include using namespace FlexFlow; -using namespace rc; - TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE( "UndirectedGraph implementations", T, HashmapUndirectedGraph) { - RC_SUBCASE("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct( - gen::construct>(gen::elementOf(n), - gen::elementOf(n)))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); - }); + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector e = {make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(0), n.at(2)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(4)), + make_undirected_edge(n.at(1), n.at(3))}; + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{ + n.at(0), n.at(1), n.at(2), n.at(3), n.at(4)}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n.at(0), n.at(2)}}}) == + std::unordered_set{n.at(0), n.at(2)}); + } + + SUBCASE("query_edges") { + + std::unordered_set queried_edges = + g.query_edges(undirected_edge_query_all()); + std::unordered_set expected = { + e.at(0), e.at(1), e.at(2), e.at(3), e.at(4)}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges( + UndirectedEdgeQuery{query_set{{n.at(0), n.at(1)}}}); + expected = std::unordered_set{e.at(0)}; + CHECK(queried_edges == expected); + } + + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n.at(0)); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n.at(1), n.at(2), n.at(3), n.at(4)}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{e.at(2), e.at(3), e.at(4)}); + + g.remove_node_unsafe(n.at(1)); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n.at(2), n.at(3), n.at(4)}); + + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{e.at(3)}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e.at(0)); + + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{ + e.at(1), e.at(2), e.at(3), e.at(4)}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{ + n.at(0), n.at(1), n.at(2), n.at(3), n.at(4)}); + + g.remove_edge(e.at(1)); + g.remove_edge(e.at(3)); + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{e.at(2), e.at(4)}); + } } } diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index 8a6a44d1cc..183bab4cc2 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -7,6 +7,7 @@ #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/instances/hashmap_undirected_graph.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" #include "utils/graph/undirected/undirected_graph.h" #include "utils/graph/undirected/undirected_graph_view.h" #include @@ -18,12 +19,12 @@ TEST_SUITE(FF_TEST_SUITE) { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 5); add_edges(g, - {UndirectedEdge{{n.at(0), n.at(3)}}, - UndirectedEdge{{n.at(1), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(1), n.at(3)}}, - UndirectedEdge{{n.at(2), n.at(3)}}, - UndirectedEdge{{n.at(2), n.at(4)}}}); + {make_undirected_edge(n.at(0), n.at(3)), + make_undirected_edge(n.at(1), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(1), n.at(3)), + make_undirected_edge(n.at(2), n.at(3)), + make_undirected_edge(n.at(2), n.at(4))}); std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; UndirectedGraphView view = view_subgraph(g, sub_nodes); @@ -37,9 +38,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - UndirectedEdge{{n.at(0), n.at(3)}}, - UndirectedEdge{{n.at(1), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(3)}}, + make_undirected_edge(n.at(0), n.at(3)), + make_undirected_edge(n.at(1), n.at(1)), + make_undirected_edge(n.at(1), n.at(3)), }; std::unordered_set result = get_edges(view); @@ -106,9 +107,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(2), n.at(0)}}}; + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(0))}; std::unordered_set result = get_edges(view); @@ -120,10 +121,10 @@ TEST_SUITE(FF_TEST_SUITE) { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 3); add_edges(g, - {UndirectedEdge{{n.at(0), n.at(0)}}, - UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(2), n.at(0)}}}); + {make_undirected_edge(n.at(0), n.at(0)), + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(0))}); DiGraphView view = as_digraph(g); From 0af8b3b18324ad3bc3da0dc9fbd200ca251954f3 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 9 Feb 2025 15:35:09 -0800 Subject: [PATCH 15/16] CI fixes and PR comment fixes --- .../simulate_task_graph_execution.cc | 2 +- lib/pcg/src/pcg/machine_view.cc | 9 +- .../parallel_computation_graph.cc | 2 +- lib/utils/include/utils/containers/zip.h | 36 +------- lib/utils/include/utils/containers/zip3.h | 24 +++++ .../include/utils/containers/zip3_strict.h | 23 +++++ .../include/utils/containers/zip_strict.h | 22 +++++ lib/utils/include/utils/containers/zip_with.h | 20 ++++ .../utils/containers/zip_with_strict.h | 22 +++++ lib/utils/include/utils/fmt/tuple.h | 40 ++++++++ lib/utils/include/utils/tuple.h | 16 +--- lib/utils/include/utils/tuple/visit.h | 21 +++++ lib/utils/src/utils/containers/zip.cc | 12 +++ lib/utils/src/utils/containers/zip3.cc | 15 +++ lib/utils/src/utils/containers/zip3_strict.cc | 15 +++ lib/utils/src/utils/containers/zip_strict.cc | 12 +++ lib/utils/src/utils/containers/zip_with.cc | 14 +++ .../src/utils/containers/zip_with_strict.cc | 14 +++ lib/utils/src/utils/fmt/tuple.cc | 9 ++ .../get_series_parallel_decomposition.cc | 8 +- lib/utils/src/utils/tuple/visit.cc | 15 +++ .../include/test/utils/doctest/fmt/tuple.h | 18 ++++ .../common/src/test/utils/doctest/fmt/pair.cc | 13 +++ .../src/test/utils/doctest/fmt/tuple.cc | 15 +++ lib/utils/test/src/utils/containers/zip.cc | 81 ++++++++++++++++ lib/utils/test/src/utils/containers/zip3.cc | 92 +++++++++++++++++++ .../test/src/utils/containers/zip3_strict.cc | 80 ++++++++++++++++ .../test/src/utils/containers/zip_strict.cc | 28 ++++++ .../test/src/utils/containers/zip_with.cc | 77 ++++++++++++++++ .../src/utils/containers/zip_with_strict.cc | 53 +++++++++++ lib/utils/test/src/utils/fmt/tuple.cc | 70 ++++++++++++++ .../utils/graph/digraph/algorithms/digraph.cc | 2 +- .../algorithms/get_incoming_edges.cc | 2 +- .../algorithms/get_outgoing_edges.cc | 2 +- .../graph/series_parallel/series_reduction.cc | 10 +- .../src/utils/graph/undirected/undirected.cc | 2 +- lib/utils/test/src/utils/tuple/visit.cc | 40 ++++++++ 37 files changed, 869 insertions(+), 67 deletions(-) create mode 100644 lib/utils/include/utils/containers/zip3.h create mode 100644 lib/utils/include/utils/containers/zip3_strict.h create mode 100644 lib/utils/include/utils/containers/zip_strict.h create mode 100644 lib/utils/include/utils/containers/zip_with.h create mode 100644 lib/utils/include/utils/containers/zip_with_strict.h create mode 100644 lib/utils/include/utils/fmt/tuple.h create mode 100644 lib/utils/include/utils/tuple/visit.h create mode 100644 lib/utils/src/utils/containers/zip3.cc create mode 100644 lib/utils/src/utils/containers/zip3_strict.cc create mode 100644 lib/utils/src/utils/containers/zip_strict.cc create mode 100644 lib/utils/src/utils/containers/zip_with.cc create mode 100644 lib/utils/src/utils/containers/zip_with_strict.cc create mode 100644 lib/utils/src/utils/fmt/tuple.cc create mode 100644 lib/utils/src/utils/tuple/visit.cc create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc create mode 100644 lib/utils/test/src/utils/containers/zip.cc create mode 100644 lib/utils/test/src/utils/containers/zip3.cc create mode 100644 lib/utils/test/src/utils/containers/zip3_strict.cc create mode 100644 lib/utils/test/src/utils/containers/zip_strict.cc create mode 100644 lib/utils/test/src/utils/containers/zip_with.cc create mode 100644 lib/utils/test/src/utils/containers/zip_with_strict.cc create mode 100644 lib/utils/test/src/utils/fmt/tuple.cc create mode 100644 lib/utils/test/src/utils/tuple/visit.cc diff --git a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc index 974a70ddf5..63509f3534 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc @@ -30,7 +30,7 @@ TaskGraphExecutionTrace simulate_task_graph_execution( } TaskGraphExecutionState execution_state = - TaskGraphExecutionState{/*ready_tasks=*/set_of(get_sources(task_graph)), + TaskGraphExecutionState{/*ready_tasks=*/set_of(get_initial_nodes(task_graph)), /*in_progress_tasks=*/{}, /*finished_tasks=*/{}, /*current_time=*/0.0}; diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index fe319dc63c..c958fbb9b8 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -14,7 +14,8 @@ #include "utils/containers/scanl.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" -#include "utils/containers/zip.h" +#include "utils/containers/zip_with_strict.h" +#include "utils/containers/zip3_strict.h" #include "utils/exception.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/nonnegative_int/num_elements.h" @@ -53,8 +54,8 @@ MachineView machine_view_from_strides_and_machine_spec_dimensions( strides)); } std::vector dimensions = - transform(zip(strides, dims), [&](auto const &p) { - return MachineViewDimension{p.first, p.second}; + zip_with_strict(strides, dims, [](stride_t s, MachineSpecificationDimension d) { + return MachineViewDimension{s, d}; }); return MachineView{start, dimensions}; } @@ -109,7 +110,7 @@ std::optional get_machine_space_coordinate( nonnegative_int index = start_idx; for (auto [coeff, coord_point, stride] : - zip(coeffs, coord_points, strides)) { + zip3(coeffs, coord_points, strides)) { index += coeff * coord_point * stride; } return index; diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 4cc0500fa2..376cb0c19a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -117,7 +117,7 @@ std::unordered_set std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { - std::unordered_set raw_sources = get_sources(pcg.raw_graph); + std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); return transform(raw_sources, [](Node const &n) { return parallel_layer_guid_t{n}; }); } diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 65722e049f..7bfca5e8b1 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,24 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H -#include "fmt/format.h" -#include "utils/exception.h" -#include #include #include +#include namespace FlexFlow { template -std::vector> - zip(std::vector const &l, std::vector const &r, bool strict = false) { - if (strict && l.size() != r.size()) { - throw mk_runtime_error(fmt::format("When strict = true, vector sizes must " - "match. Got vectors of length {} and {}", - l.size(), - r.size())); - } - +std::vector> zip(std::vector const &l, + std::vector const &r) { std::vector> result; for (int i = 0; i < std::min(l.size(), r.size()); i++) { result.push_back(std::make_pair(l.at(i), r.at(i))); @@ -26,27 +17,6 @@ std::vector> return result; } -template -std::vector> zip(std::vector const &a, - std::vector const &b, - std::vector const &c, - bool strict = false) { - if (strict && (a.size() != b.size() || b.size() != c.size())) { - throw std::runtime_error( - fmt::format("When strict = true, vectors sizes must match. Got vectors " - "of length {}, {} and {}", - a.size(), - b.size(), - c.size())); - } - - std::vector> result; - for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { - result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/zip3.h b/lib/utils/include/utils/containers/zip3.h new file mode 100644 index 0000000000..18fcb28d03 --- /dev/null +++ b/lib/utils/include/utils/containers/zip3.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H + +#include +#include +#include +#include + +namespace FlexFlow { + +template +std::vector> zip3(std::vector const &a, + std::vector const &b, + std::vector const &c) { + std::vector> result; + for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { + result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip3_strict.h b/lib/utils/include/utils/containers/zip3_strict.h new file mode 100644 index 0000000000..92620e29ac --- /dev/null +++ b/lib/utils/include/utils/containers/zip3_strict.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H + +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include "utils/containers/zip3.h" + +namespace FlexFlow { + +template +std::vector> zip3_strict(std::vector const &as, + std::vector const &bs, + std::vector const &cs) { + if (!(as.size() == bs.size() && bs.size() == cs.size())) { + throw mk_runtime_error(fmt::format("zip3_strict requires as, bs, and cs to have the same length, but received as={} (length {}), bs={} (length {}), and cs={} (length {})", as, as.size(), bs, bs.size(), cs, cs.size())); + } + + return zip3(as, bs, cs); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_strict.h b/lib/utils/include/utils/containers/zip_strict.h new file mode 100644 index 0000000000..42f32e64d2 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_strict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H + +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include "utils/containers/zip.h" + +namespace FlexFlow { + +template +std::vector> zip_strict(std::vector const &lhs, + std::vector const &rhs) { + if (lhs.size() != rhs.size()) { + throw mk_runtime_error(fmt::format("zip_strict requires lhs and rhs to have the same length, but received lhs={} (length {}), rhs={} (length {})", lhs, lhs.size(), rhs, rhs.size())); + } + + return zip(lhs, rhs); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with.h b/lib/utils/include/utils/containers/zip_with.h new file mode 100644 index 0000000000..fb10f2a89e --- /dev/null +++ b/lib/utils/include/utils/containers/zip_with.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H + +#include + +namespace FlexFlow { + +template > +std::vector zip_with(std::vector const &l, std::vector const &r, F &&f) { + std::vector result; + for (int i = 0; i < l.size() && i < r.size(); i++) { + result.push_back(f(l.at(i), r.at(i))); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with_strict.h b/lib/utils/include/utils/containers/zip_with_strict.h new file mode 100644 index 0000000000..357d0e94c6 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_with_strict.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H + +#include "utils/exception.h" +#include +#include "utils/containers/zip_with.h" +#include "utils/fmt/vector.h" + +namespace FlexFlow { + +template > +std::vector zip_with_strict(std::vector const &lhs, std::vector const &rhs, F &&f) { + if (lhs.size() != rhs.size()) { + throw mk_runtime_error(fmt::format("zip_with_strict requires inputs to have the same length, but received lhs = {} (length {}) and rhs = {} (length {})", lhs, lhs.size(), rhs, rhs.size())); + } + + return zip_with(lhs, rhs, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/tuple.h b/lib/utils/include/utils/fmt/tuple.h new file mode 100644 index 0000000000..80054f8e5e --- /dev/null +++ b/lib/utils/include/utils/fmt/tuple.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H + +#include "utils/check_fmtable.h" +#include +#include +#include +#include +#include "utils/join_strings.h" +#include "utils/tuple/visit.h" + +namespace fmt { + +template +struct formatter, Char> + : formatter { + + template + auto format(std::tuple const &t, FormatContext &ctx) const + -> decltype(ctx.out()) { + + std::vector stringified_elements; + ::FlexFlow::visit_tuple(t, [&](auto const &element) -> void { stringified_elements.push_back(fmt::to_string(element)); }); + + return formatter::format("{" + ::FlexFlow::join_strings(stringified_elements, ", ") + "}", ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::tuple const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 0296e365a3..543d9fb6a8 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -8,6 +8,7 @@ #include #include #include +#include "utils/tuple/visit.h" // Adapted from // https://github.com/bitwizeshift/BackportCpp/blob/4f33a7f9b219f169e60d8ed2fd5731a3a23288e4/include/bpstd/tuple.hpp @@ -32,21 +33,6 @@ struct index_of : index_of_impl {}; } // namespace TupleUtils -template -void visit_tuple_impl(Visitor &v, std::tuple const &tup) { - v(Idx, std::get(tup)); - if (Idx >= std::tuple_size::value) { - return; - } else { - visit_tuple_impl<(Idx + 1)>(v, tup); - } -} - -template -void visit_tuple(Visitor &v, std::tuple const &tup) { - visit_tuple_impl<0>(v, tup); -} - struct tuple_get_visitor { tuple_get_visitor() = delete; tuple_get_visitor(int requested_idx, std::any &result) diff --git a/lib/utils/include/utils/tuple/visit.h b/lib/utils/include/utils/tuple/visit.h new file mode 100644 index 0000000000..741eac1e88 --- /dev/null +++ b/lib/utils/include/utils/tuple/visit.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H + +#include +#include + +namespace FlexFlow { + +template +void visit_tuple_impl(Tuple const &tuple, Visitor &&v, std::index_sequence) { + (v(std::get(tuple)), ...); +} + +template +void visit_tuple(std::tuple const &tuple, Visitor &&v) { + visit_tuple_impl(tuple, v, std::index_sequence_for{}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/zip.cc b/lib/utils/src/utils/containers/zip.cc index a02c9c8a35..7569640802 100644 --- a/lib/utils/src/utils/containers/zip.cc +++ b/lib/utils/src/utils/containers/zip.cc @@ -1 +1,13 @@ #include "utils/containers/zip.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L1 = value_type<0>; +using R1 = value_type<1>; + +template + std::vector> zip(std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3.cc b/lib/utils/src/utils/containers/zip3.cc new file mode 100644 index 0000000000..6f1d9e46ae --- /dev/null +++ b/lib/utils/src/utils/containers/zip3.cc @@ -0,0 +1,15 @@ +#include "utils/containers/zip3.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A1 = value_type<0>; +using B1 = value_type<1>; +using C1 = value_type<2>; + +template + std::vector> zip3(std::vector const &, + std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3_strict.cc b/lib/utils/src/utils/containers/zip3_strict.cc new file mode 100644 index 0000000000..b3e3047547 --- /dev/null +++ b/lib/utils/src/utils/containers/zip3_strict.cc @@ -0,0 +1,15 @@ +#include "utils/containers/zip3_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A1 = value_type<0>; +using B1 = value_type<1>; +using C1 = value_type<2>; + +template + std::vector> zip3_strict(std::vector const &, + std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_strict.cc b/lib/utils/src/utils/containers/zip_strict.cc new file mode 100644 index 0000000000..bbc31c708e --- /dev/null +++ b/lib/utils/src/utils/containers/zip_strict.cc @@ -0,0 +1,12 @@ +#include "utils/containers/zip_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template + std::vector> zip_strict(std::vector const &, std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with.cc b/lib/utils/src/utils/containers/zip_with.cc new file mode 100644 index 0000000000..499d6ac8b2 --- /dev/null +++ b/lib/utils/src/utils/containers/zip_with.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template + std::vector zip_with(std::vector const &, std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with_strict.cc b/lib/utils/src/utils/containers/zip_with_strict.cc new file mode 100644 index 0000000000..349ee9a37c --- /dev/null +++ b/lib/utils/src/utils/containers/zip_with_strict.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip_with_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template + std::vector zip_with_strict(std::vector const &, std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/tuple.cc b/lib/utils/src/utils/fmt/tuple.cc new file mode 100644 index 0000000000..3ba2c53d89 --- /dev/null +++ b/lib/utils/src/utils/fmt/tuple.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/tuple.h" + +namespace FlexFlow { + + +template + std::ostream &operator<<(std::ostream &s, std::tuple const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index a01c94a708..33bdd74787 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -43,7 +43,7 @@ std::optional .as_unordered_map(), [](Node const &n) { return SeriesParallelDecomposition{n}; }); - auto handle_parallel_reduction = + auto perform_extended_parallel_reduction = [&](ExtendedParallelReduction const ¶llel_reduction) { MultiDiEdge merged = apply_extended_parallel_reduction(ttsp, parallel_reduction); @@ -60,7 +60,7 @@ std::optional return new_tree; }; - auto handle_series_reduction = + auto perform_extended_series_reduction = [&](ExtendedSeriesReduction const &series_reduction) { MultiDiEdge merged = apply_extended_series_reduction(ttsp, series_reduction); @@ -86,7 +86,7 @@ std::optional if (!parallel_reductions.empty()) { for (ExtendedParallelReduction parallel_reduction : parallel_reductions) { - handle_parallel_reduction(parallel_reduction); + perform_extended_parallel_reduction(parallel_reduction); } reduction_has_happened = true; } @@ -95,7 +95,7 @@ std::optional find_all_extended_series_reductions(ttsp); if (!series_reductions.empty()) { for (ExtendedSeriesReduction series_reduction : series_reductions) { - handle_series_reduction(series_reduction); + perform_extended_series_reduction(series_reduction); } reduction_has_happened = true; } diff --git a/lib/utils/src/utils/tuple/visit.cc b/lib/utils/src/utils/tuple/visit.cc new file mode 100644 index 0000000000..f0d218b207 --- /dev/null +++ b/lib/utils/src/utils/tuple/visit.cc @@ -0,0 +1,15 @@ +#include "utils/tuple/visit.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; +using Visitor = std::function const &)>; + +template + void visit_tuple(std::tuple const &, Visitor &&); + +} // namespace FlexFlow diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h new file mode 100644 index 0000000000..6a11e7abcb --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_TUPLE_H + +#include "utils/fmt/tuple.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::tuple const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc index 106fb1c900..4ef1e451ac 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -1 +1,14 @@ #include "test/utils/doctest/fmt/pair.h" +#include "utils/archetypes/value_type.h" + +using ::FlexFlow::value_type; + +using L = value_type<0>; +using R = value_type<1>; + +namespace doctest { + +template + struct StringMaker>; + +} // namespace doctest diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc new file mode 100644 index 0000000000..f832d41e84 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc @@ -0,0 +1,15 @@ +#include "test/utils/doctest/fmt/tuple.h" +#include "utils/archetypes/value_type.h" + +using ::FlexFlow::value_type; + +using A = value_type<0>; +using B = value_type<1>; +using C = value_type<2>; + +namespace doctest { + +template + struct StringMaker>; + +} // namespace doctest diff --git a/lib/utils/test/src/utils/containers/zip.cc b/lib/utils/test/src/utils/containers/zip.cc new file mode 100644 index 0000000000..c305e53f69 --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip.cc @@ -0,0 +1,81 @@ +#include +#include "utils/containers/zip.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip(std::vector, std::vector)") { + SUBCASE("L and R types are the same") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}, {1, 4}, {2, 8}}; + + CHECK(result == correct); + } + + SUBCASE("L and R types are different") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{"a", 5}, {"b", 4}, {"b", 8}}; + + CHECK(result == correct); + } + + SUBCASE("left is longer than right") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {5, 4}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}, {1, 4}}; + + CHECK(result == correct); + } + + SUBCASE("right is longer than left") { + std::vector lhs = {2}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}}; + + CHECK(result == correct); + } + + SUBCASE("left is empty") { + std::vector lhs = {}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("right is empty") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("both are empty") { + std::vector lhs = {}; + std::vector rhs = {}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip3.cc b/lib/utils/test/src/utils/containers/zip3.cc new file mode 100644 index 0000000000..f1613105ee --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip3.cc @@ -0,0 +1,92 @@ +#include +#include "utils/containers/zip3.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/tuple.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip3(std::vector, std::vector, std::vector)") { + SUBCASE("types are same") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("types are different") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {"a", "d", "d"}; + std::vector> input_c = {{1, 2}, {}, {3, 1}}; + + std::vector>> result = zip3(input_a, input_b, input_c); + std::vector>> correct = { + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, + }; + + CHECK(result == correct); + } + + SUBCASE("A list is shortest") { + std::vector input_a = {2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("B list is shortest") { + std::vector input_a = {2, 1, 2, 4}; + std::vector input_b = {5, 4}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 4}}; + + CHECK(result == correct); + } + + SUBCASE("C list is shortest") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 3}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 3}}; + + CHECK(result == correct); + } + + SUBCASE("one list is empty") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("all lists are empty") { + std::vector input_a = {}; + std::vector input_b = {}; + std::vector input_c = {}; + + std::vector> result = zip3(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip3_strict.cc b/lib/utils/test/src/utils/containers/zip3_strict.cc new file mode 100644 index 0000000000..abfc4576d5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip3_strict.cc @@ -0,0 +1,80 @@ +#include +#include "utils/containers/zip3_strict.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/tuple.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip3_strict(std::vector, std::vector, std::vector)") { + SUBCASE("types are same") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = zip3_strict(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("types are different") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {"a", "d", "d"}; + std::vector> input_c = {{1, 2}, {}, {3, 1}}; + + std::vector>> result = zip3_strict(input_a, input_b, input_c); + std::vector>> correct = { + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, + }; + + CHECK(result == correct); + } + + SUBCASE("A list is shortest") { + std::vector input_a = {2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("B list is shortest") { + std::vector input_a = {2, 1, 2, 4}; + std::vector input_b = {5, 4}; + std::vector input_c = {3, 4, 3}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("C list is shortest") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 3}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("one list is empty") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("all lists are empty") { + std::vector input_a = {}; + std::vector input_b = {}; + std::vector input_c = {}; + + std::vector> result = zip3_strict(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_strict.cc b/lib/utils/test/src/utils/containers/zip_strict.cc new file mode 100644 index 0000000000..0b7d35e0f4 --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_strict.cc @@ -0,0 +1,28 @@ +#include +#include "utils/containers/zip_strict.h" +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_strict(std::vector, std::vector)") { + SUBCASE("input lengths are the same") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip_strict(lhs, rhs); + std::vector> correct = {{"a", 5}, {"b", 4}, {"b", 8}}; + + CHECK(result == correct); + } + + SUBCASE("input lengths are not the same") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4}; + + CHECK_THROWS(zip_strict(lhs, rhs)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_with.cc b/lib/utils/test/src/utils/containers/zip_with.cc new file mode 100644 index 0000000000..fe306bbe9e --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_with.cc @@ -0,0 +1,77 @@ +#include "utils/containers/zip_with.h" +#include +#include +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_with(std::vector, std::vector, F)") { + SUBCASE("result types and input types are all different") { + std::vector v1 = {1, 3, 4, 3}; + std::vector v2 = {"aa", "cc", "bb", "dd"}; + + std::vector> result = zip_with(v1, v2, [](int x1, std::string const &x2) { return std::make_pair(x1, x2); }); + std::vector> correct = { + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, + }; + + CHECK(result == correct); + } + + + SUBCASE("input lengths don't match") { + auto add = [](int x1, int x2) { return x1 + x2; }; + + std::vector shorter = {1, 2}; + std::vector longer = {1, 3, 5, 7}; + + SUBCASE("first input is shorter") { + std::vector result = zip_with(shorter, longer, add); + std::vector correct = {1+1, 2+3}; + + CHECK(result == correct); + } + + SUBCASE("second input is shorter") { + std::vector result = zip_with(longer, shorter, add); + std::vector correct = {1+1, 2+3}; + + CHECK(result == correct); + } + } + + SUBCASE("properly handles empty inputs") { + std::vector nonempty = {1, 2}; + std::vector empty = {}; + + auto throw_err = [](int x1, int x2) -> int { throw std::runtime_error("error"); }; + + SUBCASE("first input is empty") { + std::vector result = zip_with(empty, nonempty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + + SUBCASE("second input is empty") { + std::vector result = zip_with(nonempty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both inputs are empty") { + std::vector result = zip_with(empty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_with_strict.cc b/lib/utils/test/src/utils/containers/zip_with_strict.cc new file mode 100644 index 0000000000..e86a1f114a --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_with_strict.cc @@ -0,0 +1,53 @@ +#include "utils/containers/zip_with_strict.h" +#include +#include +#include +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_with_strict(std::vector, std::vector, F)") { + SUBCASE("result types and input types are all different") { + std::vector v1 = {1, 3, 4, 3}; + std::vector v2 = {"aa", "cc", "bb", "dd"}; + + std::vector> result = zip_with(v1, v2, [](int x1, std::string const &x2) { return std::make_pair(x1, x2); }); + std::vector> correct = { + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, + }; + + CHECK(result == correct); + } + + SUBCASE("input lengths don't match") { + auto add = [](int x1, int x2) { return x1 + x2; }; + + std::vector shorter = {1, 2}; + std::vector longer = {1, 3, 5, 7}; + + SUBCASE("first input is shorter") { + CHECK_THROWS(zip_with_strict(shorter, longer, add)); + } + + SUBCASE("second input is shorter") { + CHECK_THROWS(zip_with_strict(longer, shorter, add)); + } + } + + SUBCASE("properly handles empty inputs") { + std::vector empty = {}; + + auto throw_err = [](int x1, int x2) -> int { throw std::runtime_error("error"); }; + + std::vector result = zip_with(empty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/fmt/tuple.cc b/lib/utils/test/src/utils/fmt/tuple.cc new file mode 100644 index 0000000000..1ee7d63a1f --- /dev/null +++ b/lib/utils/test/src/utils/fmt/tuple.cc @@ -0,0 +1,70 @@ +#include "utils/fmt/tuple.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::tuple)") { + SUBCASE("types are different") { + std::tuple input = {3, false, "hello"}; + + std::string result = fmt::to_string(input); + std::string correct = "{3, false, hello}"; + + CHECK(result == correct); + } + + SUBCASE("types are the same") { + std::tuple input = {3, 5}; + + std::string result = fmt::to_string(input); + std::string correct = "{3, 5}"; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + std::string result = fmt::to_string(input); + std::string correct = "{}"; + + CHECK(result == correct); + } + } + + TEST_CASE("operator<<(ostream &, std::tuple)") { + auto through_ostringstream = [](auto const &t) { + std::ostringstream oss; + oss << t; + return oss.str(); + }; + + SUBCASE("types are different") { + std::tuple input = {3, false, "hello"}; + + std::string result = through_ostringstream(input); + std::string correct = "{3, false, hello}"; + + CHECK(result == correct); + } + + SUBCASE("types are the same") { + std::tuple input = {3, 5}; + + std::string result = through_ostringstream(input); + std::string correct = "{3, 5}"; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + std::string result = through_ostringstream(input); + std::string correct = "{}"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc index e3471a3031..e820ab8808 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc @@ -20,7 +20,7 @@ TEST_SUITE(FF_TEST_SUITE) { */ DiGraph g = DiGraph::create(); - std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector n = repeat(5_n, [&] { return g.add_node(); }); std::vector e = {DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[2]}, DirectedEdge{n[1], n[2]}, diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index b15b8a9d7d..ef5cf3c502 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -13,7 +13,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_incoming_edges") { MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 3); + std::vector n = add_nodes(g, 3_n); std::vector edges = add_edges(g, {{n.at(0), n.at(0)}, diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 69b38090d3..20011cb133 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -13,7 +13,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_outgoing_edges") { MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 3); + std::vector n = add_nodes(g, 3_n); std::vector edges = add_edges(g, { diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index 8252622215..78537a4342 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -250,7 +250,7 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiGraph g = MultiDiGraph::create(); SUBCASE("linear graph") { - std::vector n = add_nodes(g, 4); + std::vector n = add_nodes(g, 4_n); std::vector e = add_edges(g, { {n.at(0), n.at(1)}, @@ -266,7 +266,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("2 linear strands with a common terminal node") { - std::vector n = add_nodes(g, 4); + std::vector n = add_nodes(g, 4_n); std::vector e = add_edges(g, {{n.at(0), n.at(1)}, {n.at(0), n.at(2)}, @@ -282,7 +282,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("graph with multiple separate serial strands") { - std::vector n = add_nodes(g, 9); + std::vector n = add_nodes(g, 9_n); std::vector e = add_edges(g, {{n.at(0), n.at(1)}, {n.at(0), n.at(2)}, @@ -310,7 +310,7 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiGraph g = MultiDiGraph::create(); SUBCASE("base case") { - std::vector n = add_nodes(g, 4); + std::vector n = add_nodes(g, 4_n); std::vector e = add_edges( g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); @@ -347,7 +347,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("in larger graph") { - std::vector n = add_nodes(g, 8); + std::vector n = add_nodes(g, 8_n); std::vector e = add_edges(g, { {n.at(0), n.at(2)}, diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc index 10575b069d..77b74fdd20 100644 --- a/lib/utils/test/src/utils/graph/undirected/undirected.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -14,7 +14,7 @@ TEST_SUITE(FF_TEST_SUITE) { "UndirectedGraph implementations", T, HashmapUndirectedGraph) { UndirectedGraph g = UndirectedGraph::create(); - std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector n = repeat(5_n, [&] { return g.add_node(); }); std::vector e = {make_undirected_edge(n.at(0), n.at(1)), make_undirected_edge(n.at(0), n.at(2)), make_undirected_edge(n.at(1), n.at(2)), diff --git a/lib/utils/test/src/utils/tuple/visit.cc b/lib/utils/test/src/utils/tuple/visit.cc new file mode 100644 index 0000000000..7024f12e65 --- /dev/null +++ b/lib/utils/test/src/utils/tuple/visit.cc @@ -0,0 +1,40 @@ +#include "utils/tuple/visit.h" +#include +#include "utils/overload.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("visit(std::tuple, Visitor)") { + std::ostringstream oss; + auto visitor = overload { + [&](int const &i) -> void { oss << "int(" << i << "), "; }, + [&](bool const &b) -> void { oss << "bool(" << b << "), "; }, + [&](std::string const &s) -> void { oss << "string(" << s << "), "; }, + }; + + SUBCASE("repeated types") { + std::tuple input = {3, "hello", false, "world"}; + + visit_tuple(input, visitor); + + std::string result = oss.str(); + std::string correct = "int(3), string(hello), bool(0), string(world), "; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + visit_tuple(input, visitor); + + std::string result = oss.str(); + std::string correct = ""; + + CHECK(result == correct); + } + } +} From 97227c648e8d3610b352497a5ac4f6000533eedb Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sun, 9 Feb 2025 15:45:37 -0800 Subject: [PATCH 16/16] Format --- .../simulate_task_graph_execution.cc | 10 +++--- lib/pcg/src/pcg/machine_view.cc | 6 ++-- lib/utils/include/utils/containers/zip.h | 2 +- lib/utils/include/utils/containers/zip3.h | 6 ++-- .../include/utils/containers/zip3_strict.h | 16 ++++++--- .../include/utils/containers/zip_strict.h | 10 ++++-- lib/utils/include/utils/containers/zip_with.h | 8 +++-- .../utils/containers/zip_with_strict.h | 21 ++++++++--- lib/utils/include/utils/fmt/tuple.h | 18 +++++----- lib/utils/include/utils/tuple.h | 2 +- lib/utils/include/utils/tuple/visit.h | 6 ++-- lib/utils/src/utils/containers/zip.cc | 5 ++- lib/utils/src/utils/containers/zip3.cc | 7 ++-- lib/utils/src/utils/containers/zip3_strict.cc | 6 ++-- lib/utils/src/utils/containers/zip_strict.cc | 4 +-- lib/utils/src/utils/containers/zip_with.cc | 4 +-- .../src/utils/containers/zip_with_strict.cc | 4 +-- lib/utils/src/utils/fmt/tuple.cc | 5 ++- lib/utils/src/utils/tuple/visit.cc | 3 +- .../include/test/utils/doctest/fmt/tuple.h | 2 +- .../common/src/test/utils/doctest/fmt/pair.cc | 3 +- .../src/test/utils/doctest/fmt/tuple.cc | 3 +- lib/utils/test/src/utils/containers/zip.cc | 9 ++--- lib/utils/test/src/utils/containers/zip3.cc | 36 +++++++++++-------- .../test/src/utils/containers/zip3_strict.cc | 24 +++++++------ .../test/src/utils/containers/zip_strict.cc | 9 ++--- .../test/src/utils/containers/zip_with.cc | 26 ++++++++------ .../src/utils/containers/zip_with_strict.cc | 21 ++++++----- lib/utils/test/src/utils/tuple/visit.cc | 17 ++++----- 29 files changed, 171 insertions(+), 122 deletions(-) diff --git a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc index 63509f3534..7b317df31a 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc @@ -29,11 +29,11 @@ TaskGraphExecutionTrace simulate_task_graph_execution( "simulate_task_graph_execution cannot simulate cyclic directed graphs"); } - TaskGraphExecutionState execution_state = - TaskGraphExecutionState{/*ready_tasks=*/set_of(get_initial_nodes(task_graph)), - /*in_progress_tasks=*/{}, - /*finished_tasks=*/{}, - /*current_time=*/0.0}; + TaskGraphExecutionState execution_state = TaskGraphExecutionState{ + /*ready_tasks=*/set_of(get_initial_nodes(task_graph)), + /*in_progress_tasks=*/{}, + /*finished_tasks=*/{}, + /*current_time=*/0.0}; std::unordered_set task_profiles; diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index c958fbb9b8..88110f914a 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -14,8 +14,8 @@ #include "utils/containers/scanl.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" -#include "utils/containers/zip_with_strict.h" #include "utils/containers/zip3_strict.h" +#include "utils/containers/zip_with_strict.h" #include "utils/exception.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/nonnegative_int/num_elements.h" @@ -53,8 +53,8 @@ MachineView machine_view_from_strides_and_machine_spec_dimensions( start, strides)); } - std::vector dimensions = - zip_with_strict(strides, dims, [](stride_t s, MachineSpecificationDimension d) { + std::vector dimensions = zip_with_strict( + strides, dims, [](stride_t s, MachineSpecificationDimension d) { return MachineViewDimension{s, d}; }); return MachineView{start, dimensions}; diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 7bfca5e8b1..2ea049e0b7 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H +#include #include #include -#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/zip3.h b/lib/utils/include/utils/containers/zip3.h index 18fcb28d03..88b79f429d 100644 --- a/lib/utils/include/utils/containers/zip3.h +++ b/lib/utils/include/utils/containers/zip3.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H -#include -#include -#include #include +#include +#include +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/zip3_strict.h b/lib/utils/include/utils/containers/zip3_strict.h index 92620e29ac..40ad31d628 100644 --- a/lib/utils/include/utils/containers/zip3_strict.h +++ b/lib/utils/include/utils/containers/zip3_strict.h @@ -1,18 +1,26 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H +#include "utils/containers/zip3.h" #include "utils/exception.h" #include "utils/fmt/vector.h" -#include "utils/containers/zip3.h" namespace FlexFlow { template std::vector> zip3_strict(std::vector const &as, - std::vector const &bs, - std::vector const &cs) { + std::vector const &bs, + std::vector const &cs) { if (!(as.size() == bs.size() && bs.size() == cs.size())) { - throw mk_runtime_error(fmt::format("zip3_strict requires as, bs, and cs to have the same length, but received as={} (length {}), bs={} (length {}), and cs={} (length {})", as, as.size(), bs, bs.size(), cs, cs.size())); + throw mk_runtime_error(fmt::format( + "zip3_strict requires as, bs, and cs to have the same length, but " + "received as={} (length {}), bs={} (length {}), and cs={} (length {})", + as, + as.size(), + bs, + bs.size(), + cs, + cs.size())); } return zip3(as, bs, cs); diff --git a/lib/utils/include/utils/containers/zip_strict.h b/lib/utils/include/utils/containers/zip_strict.h index 42f32e64d2..64049042d4 100644 --- a/lib/utils/include/utils/containers/zip_strict.h +++ b/lib/utils/include/utils/containers/zip_strict.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H +#include "utils/containers/zip.h" #include "utils/exception.h" #include "utils/fmt/vector.h" -#include "utils/containers/zip.h" namespace FlexFlow { @@ -11,7 +11,13 @@ template std::vector> zip_strict(std::vector const &lhs, std::vector const &rhs) { if (lhs.size() != rhs.size()) { - throw mk_runtime_error(fmt::format("zip_strict requires lhs and rhs to have the same length, but received lhs={} (length {}), rhs={} (length {})", lhs, lhs.size(), rhs, rhs.size())); + throw mk_runtime_error( + fmt::format("zip_strict requires lhs and rhs to have the same length, " + "but received lhs={} (length {}), rhs={} (length {})", + lhs, + lhs.size(), + rhs, + rhs.size())); } return zip(lhs, rhs); diff --git a/lib/utils/include/utils/containers/zip_with.h b/lib/utils/include/utils/containers/zip_with.h index fb10f2a89e..7ae91a7336 100644 --- a/lib/utils/include/utils/containers/zip_with.h +++ b/lib/utils/include/utils/containers/zip_with.h @@ -5,8 +5,12 @@ namespace FlexFlow { -template > -std::vector zip_with(std::vector const &l, std::vector const &r, F &&f) { +template > +std::vector + zip_with(std::vector const &l, std::vector const &r, F &&f) { std::vector result; for (int i = 0; i < l.size() && i < r.size(); i++) { result.push_back(f(l.at(i), r.at(i))); diff --git a/lib/utils/include/utils/containers/zip_with_strict.h b/lib/utils/include/utils/containers/zip_with_strict.h index 357d0e94c6..fd1e2fa7fd 100644 --- a/lib/utils/include/utils/containers/zip_with_strict.h +++ b/lib/utils/include/utils/containers/zip_with_strict.h @@ -1,17 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H -#include "utils/exception.h" -#include #include "utils/containers/zip_with.h" +#include "utils/exception.h" #include "utils/fmt/vector.h" +#include namespace FlexFlow { -template > -std::vector zip_with_strict(std::vector const &lhs, std::vector const &rhs, F &&f) { +template > +std::vector zip_with_strict(std::vector const &lhs, + std::vector const &rhs, + F &&f) { if (lhs.size() != rhs.size()) { - throw mk_runtime_error(fmt::format("zip_with_strict requires inputs to have the same length, but received lhs = {} (length {}) and rhs = {} (length {})", lhs, lhs.size(), rhs, rhs.size())); + throw mk_runtime_error(fmt::format( + "zip_with_strict requires inputs to have the same length, but received " + "lhs = {} (length {}) and rhs = {} (length {})", + lhs, + lhs.size(), + rhs, + rhs.size())); } return zip_with(lhs, rhs, f); diff --git a/lib/utils/include/utils/fmt/tuple.h b/lib/utils/include/utils/fmt/tuple.h index 80054f8e5e..8248cc1cbf 100644 --- a/lib/utils/include/utils/fmt/tuple.h +++ b/lib/utils/include/utils/fmt/tuple.h @@ -2,27 +2,29 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H #include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include "utils/tuple/visit.h" +#include #include #include -#include #include -#include "utils/join_strings.h" -#include "utils/tuple/visit.h" namespace fmt { template -struct formatter, Char> - : formatter { +struct formatter, Char> : formatter { template auto format(std::tuple const &t, FormatContext &ctx) const -> decltype(ctx.out()) { std::vector stringified_elements; - ::FlexFlow::visit_tuple(t, [&](auto const &element) -> void { stringified_elements.push_back(fmt::to_string(element)); }); + ::FlexFlow::visit_tuple(t, [&](auto const &element) -> void { + stringified_elements.push_back(fmt::to_string(element)); + }); - return formatter::format("{" + ::FlexFlow::join_strings(stringified_elements, ", ") + "}", ctx); + return formatter::format( + "{" + ::FlexFlow::join_strings(stringified_elements, ", ") + "}", ctx); } }; @@ -30,7 +32,7 @@ struct formatter, Char> namespace FlexFlow { -template +template std::ostream &operator<<(std::ostream &s, std::tuple const &t) { return (s << fmt::to_string(t)); } diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 543d9fb6a8..c1fd774850 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_UTILS_TUPLE_H #include "utils/exception.h" +#include "utils/tuple/visit.h" #include "utils/type_traits_core.h" #include #include #include #include #include -#include "utils/tuple/visit.h" // Adapted from // https://github.com/bitwizeshift/BackportCpp/blob/4f33a7f9b219f169e60d8ed2fd5731a3a23288e4/include/bpstd/tuple.hpp diff --git a/lib/utils/include/utils/tuple/visit.h b/lib/utils/include/utils/tuple/visit.h index 741eac1e88..8c3892980a 100644 --- a/lib/utils/include/utils/tuple/visit.h +++ b/lib/utils/include/utils/tuple/visit.h @@ -1,13 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H -#include #include +#include namespace FlexFlow { template -void visit_tuple_impl(Tuple const &tuple, Visitor &&v, std::index_sequence) { +void visit_tuple_impl(Tuple const &tuple, + Visitor &&v, + std::index_sequence) { (v(std::get(tuple)), ...); } diff --git a/lib/utils/src/utils/containers/zip.cc b/lib/utils/src/utils/containers/zip.cc index 7569640802..80be287ed9 100644 --- a/lib/utils/src/utils/containers/zip.cc +++ b/lib/utils/src/utils/containers/zip.cc @@ -6,8 +6,7 @@ namespace FlexFlow { using L1 = value_type<0>; using R1 = value_type<1>; -template - std::vector> zip(std::vector const &, - std::vector const &); +template std::vector> zip(std::vector const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3.cc b/lib/utils/src/utils/containers/zip3.cc index 6f1d9e46ae..39219f2bbe 100644 --- a/lib/utils/src/utils/containers/zip3.cc +++ b/lib/utils/src/utils/containers/zip3.cc @@ -7,9 +7,8 @@ using A1 = value_type<0>; using B1 = value_type<1>; using C1 = value_type<2>; -template - std::vector> zip3(std::vector const &, - std::vector const &, - std::vector const &); +template std::vector> zip3(std::vector const &, + std::vector const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3_strict.cc b/lib/utils/src/utils/containers/zip3_strict.cc index b3e3047547..72d6d8f0f1 100644 --- a/lib/utils/src/utils/containers/zip3_strict.cc +++ b/lib/utils/src/utils/containers/zip3_strict.cc @@ -7,9 +7,7 @@ using A1 = value_type<0>; using B1 = value_type<1>; using C1 = value_type<2>; -template - std::vector> zip3_strict(std::vector const &, - std::vector const &, - std::vector const &); +template std::vector> zip3_strict( + std::vector const &, std::vector const &, std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_strict.cc b/lib/utils/src/utils/containers/zip_strict.cc index bbc31c708e..90faf520f9 100644 --- a/lib/utils/src/utils/containers/zip_strict.cc +++ b/lib/utils/src/utils/containers/zip_strict.cc @@ -6,7 +6,7 @@ namespace FlexFlow { using L = value_type<0>; using R = value_type<1>; -template - std::vector> zip_strict(std::vector const &, std::vector const &); +template std::vector> zip_strict(std::vector const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with.cc b/lib/utils/src/utils/containers/zip_with.cc index 499d6ac8b2..d58d7eed0f 100644 --- a/lib/utils/src/utils/containers/zip_with.cc +++ b/lib/utils/src/utils/containers/zip_with.cc @@ -8,7 +8,7 @@ using T2 = value_type<1>; using Result = value_type<2>; using F = std::function; -template - std::vector zip_with(std::vector const &, std::vector const &, F &&); +template std::vector + zip_with(std::vector const &, std::vector const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with_strict.cc b/lib/utils/src/utils/containers/zip_with_strict.cc index 349ee9a37c..8dbfea6c2b 100644 --- a/lib/utils/src/utils/containers/zip_with_strict.cc +++ b/lib/utils/src/utils/containers/zip_with_strict.cc @@ -8,7 +8,7 @@ using T2 = value_type<1>; using Result = value_type<2>; using F = std::function; -template - std::vector zip_with_strict(std::vector const &, std::vector const &, F &&); +template std::vector + zip_with_strict(std::vector const &, std::vector const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/tuple.cc b/lib/utils/src/utils/fmt/tuple.cc index 3ba2c53d89..0a8c29cb43 100644 --- a/lib/utils/src/utils/fmt/tuple.cc +++ b/lib/utils/src/utils/fmt/tuple.cc @@ -2,8 +2,7 @@ namespace FlexFlow { - -template - std::ostream &operator<<(std::ostream &s, std::tuple const &); +template std::ostream &operator<<(std::ostream &s, + std::tuple const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/tuple/visit.cc b/lib/utils/src/utils/tuple/visit.cc index f0d218b207..58e9398928 100644 --- a/lib/utils/src/utils/tuple/visit.cc +++ b/lib/utils/src/utils/tuple/visit.cc @@ -9,7 +9,6 @@ using T2 = value_type<1>; using T3 = value_type<2>; using Visitor = std::function const &)>; -template - void visit_tuple(std::tuple const &, Visitor &&); +template void visit_tuple(std::tuple const &, Visitor &&); } // namespace FlexFlow diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h index 6a11e7abcb..ed23d597d6 100644 --- a/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h @@ -6,7 +6,7 @@ namespace doctest { -template +template struct StringMaker> { static String convert(std::tuple const &m) { return toString(fmt::to_string(m)); diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc index 4ef1e451ac..eb1d54d242 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -8,7 +8,6 @@ using R = value_type<1>; namespace doctest { -template - struct StringMaker>; +template struct StringMaker>; } // namespace doctest diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc index f832d41e84..8f2f90bfc9 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc @@ -9,7 +9,6 @@ using C = value_type<2>; namespace doctest { -template - struct StringMaker>; +template struct StringMaker>; } // namespace doctest diff --git a/lib/utils/test/src/utils/containers/zip.cc b/lib/utils/test/src/utils/containers/zip.cc index c305e53f69..c29415d920 100644 --- a/lib/utils/test/src/utils/containers/zip.cc +++ b/lib/utils/test/src/utils/containers/zip.cc @@ -1,8 +1,8 @@ -#include #include "utils/containers/zip.h" -#include -#include "test/utils/doctest/fmt/vector.h" #include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include using namespace ::FlexFlow; @@ -23,7 +23,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector rhs = {5, 4, 8}; std::vector> result = zip(lhs, rhs); - std::vector> correct = {{"a", 5}, {"b", 4}, {"b", 8}}; + std::vector> correct = { + {"a", 5}, {"b", 4}, {"b", 8}}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/zip3.cc b/lib/utils/test/src/utils/containers/zip3.cc index f1613105ee..4268c41aaa 100644 --- a/lib/utils/test/src/utils/containers/zip3.cc +++ b/lib/utils/test/src/utils/containers/zip3.cc @@ -1,8 +1,8 @@ -#include #include "utils/containers/zip3.h" -#include -#include "test/utils/doctest/fmt/vector.h" #include "test/utils/doctest/fmt/tuple.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include using namespace ::FlexFlow; @@ -13,8 +13,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {5, 4, 5}; std::vector input_c = {3, 4, 3}; - std::vector> result = zip3(input_a, input_b, input_c); - std::vector> correct = {{2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = { + {2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; CHECK(result == correct); } @@ -24,11 +26,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {"a", "d", "d"}; std::vector> input_c = {{1, 2}, {}, {3, 1}}; - std::vector>> result = zip3(input_a, input_b, input_c); + std::vector>> result = + zip3(input_a, input_b, input_c); std::vector>> correct = { - {2, "a", {1, 2}}, - {1, "d", {}}, - {2, "d", {3, 1}}, + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, }; CHECK(result == correct); @@ -39,7 +42,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {5, 4, 5}; std::vector input_c = {3, 4}; - std::vector> result = zip3(input_a, input_b, input_c); + std::vector> result = + zip3(input_a, input_b, input_c); std::vector> correct = {{2, 5, 3}}; CHECK(result == correct); @@ -50,7 +54,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {5, 4}; std::vector input_c = {3, 4, 3}; - std::vector> result = zip3(input_a, input_b, input_c); + std::vector> result = + zip3(input_a, input_b, input_c); std::vector> correct = {{2, 5, 3}, {1, 4, 4}}; CHECK(result == correct); @@ -61,7 +66,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {5, 4, 5}; std::vector input_c = {3, 3}; - std::vector> result = zip3(input_a, input_b, input_c); + std::vector> result = + zip3(input_a, input_b, input_c); std::vector> correct = {{2, 5, 3}, {1, 4, 3}}; CHECK(result == correct); @@ -72,7 +78,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {5, 4, 5}; std::vector input_c = {}; - std::vector> result = zip3(input_a, input_b, input_c); + std::vector> result = + zip3(input_a, input_b, input_c); std::vector> correct = {}; CHECK(result == correct); @@ -83,7 +90,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {}; std::vector input_c = {}; - std::vector> result = zip3(input_a, input_b, input_c); + std::vector> result = + zip3(input_a, input_b, input_c); std::vector> correct = {}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/zip3_strict.cc b/lib/utils/test/src/utils/containers/zip3_strict.cc index abfc4576d5..1f69c91e3b 100644 --- a/lib/utils/test/src/utils/containers/zip3_strict.cc +++ b/lib/utils/test/src/utils/containers/zip3_strict.cc @@ -1,8 +1,8 @@ -#include #include "utils/containers/zip3_strict.h" -#include -#include "test/utils/doctest/fmt/vector.h" #include "test/utils/doctest/fmt/tuple.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include using namespace ::FlexFlow; @@ -13,8 +13,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {5, 4, 5}; std::vector input_c = {3, 4, 3}; - std::vector> result = zip3_strict(input_a, input_b, input_c); - std::vector> correct = {{2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + std::vector> result = + zip3_strict(input_a, input_b, input_c); + std::vector> correct = { + {2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; CHECK(result == correct); } @@ -24,11 +26,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {"a", "d", "d"}; std::vector> input_c = {{1, 2}, {}, {3, 1}}; - std::vector>> result = zip3_strict(input_a, input_b, input_c); + std::vector>> result = + zip3_strict(input_a, input_b, input_c); std::vector>> correct = { - {2, "a", {1, 2}}, - {1, "d", {}}, - {2, "d", {3, 1}}, + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, }; CHECK(result == correct); @@ -71,7 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector input_b = {}; std::vector input_c = {}; - std::vector> result = zip3_strict(input_a, input_b, input_c); + std::vector> result = + zip3_strict(input_a, input_b, input_c); std::vector> correct = {}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/zip_strict.cc b/lib/utils/test/src/utils/containers/zip_strict.cc index 0b7d35e0f4..ae0dc6747a 100644 --- a/lib/utils/test/src/utils/containers/zip_strict.cc +++ b/lib/utils/test/src/utils/containers/zip_strict.cc @@ -1,8 +1,8 @@ -#include #include "utils/containers/zip_strict.h" -#include -#include "test/utils/doctest/fmt/vector.h" #include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include using namespace ::FlexFlow; @@ -13,7 +13,8 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector rhs = {5, 4, 8}; std::vector> result = zip_strict(lhs, rhs); - std::vector> correct = {{"a", 5}, {"b", 4}, {"b", 8}}; + std::vector> correct = { + {"a", 5}, {"b", 4}, {"b", 8}}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/zip_with.cc b/lib/utils/test/src/utils/containers/zip_with.cc index fe306bbe9e..45cecec84b 100644 --- a/lib/utils/test/src/utils/containers/zip_with.cc +++ b/lib/utils/test/src/utils/containers/zip_with.cc @@ -1,9 +1,9 @@ #include "utils/containers/zip_with.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include -#include "test/utils/doctest/fmt/vector.h" -#include "test/utils/doctest/fmt/pair.h" using namespace ::FlexFlow; @@ -13,18 +13,20 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector v1 = {1, 3, 4, 3}; std::vector v2 = {"aa", "cc", "bb", "dd"}; - std::vector> result = zip_with(v1, v2, [](int x1, std::string const &x2) { return std::make_pair(x1, x2); }); + std::vector> result = + zip_with(v1, v2, [](int x1, std::string const &x2) { + return std::make_pair(x1, x2); + }); std::vector> correct = { - {1, "aa"}, - {3, "cc"}, - {4, "bb"}, - {3, "dd"}, + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, }; CHECK(result == correct); } - SUBCASE("input lengths don't match") { auto add = [](int x1, int x2) { return x1 + x2; }; @@ -33,14 +35,14 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("first input is shorter") { std::vector result = zip_with(shorter, longer, add); - std::vector correct = {1+1, 2+3}; + std::vector correct = {1 + 1, 2 + 3}; CHECK(result == correct); } SUBCASE("second input is shorter") { std::vector result = zip_with(longer, shorter, add); - std::vector correct = {1+1, 2+3}; + std::vector correct = {1 + 1, 2 + 3}; CHECK(result == correct); } @@ -50,7 +52,9 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector nonempty = {1, 2}; std::vector empty = {}; - auto throw_err = [](int x1, int x2) -> int { throw std::runtime_error("error"); }; + auto throw_err = [](int x1, int x2) -> int { + throw std::runtime_error("error"); + }; SUBCASE("first input is empty") { std::vector result = zip_with(empty, nonempty, throw_err); diff --git a/lib/utils/test/src/utils/containers/zip_with_strict.cc b/lib/utils/test/src/utils/containers/zip_with_strict.cc index e86a1f114a..0730442e59 100644 --- a/lib/utils/test/src/utils/containers/zip_with_strict.cc +++ b/lib/utils/test/src/utils/containers/zip_with_strict.cc @@ -1,9 +1,9 @@ #include "utils/containers/zip_with_strict.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include -#include "test/utils/doctest/fmt/vector.h" -#include "test/utils/doctest/fmt/pair.h" using namespace ::FlexFlow; @@ -13,12 +13,15 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector v1 = {1, 3, 4, 3}; std::vector v2 = {"aa", "cc", "bb", "dd"}; - std::vector> result = zip_with(v1, v2, [](int x1, std::string const &x2) { return std::make_pair(x1, x2); }); + std::vector> result = + zip_with(v1, v2, [](int x1, std::string const &x2) { + return std::make_pair(x1, x2); + }); std::vector> correct = { - {1, "aa"}, - {3, "cc"}, - {4, "bb"}, - {3, "dd"}, + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, }; CHECK(result == correct); @@ -42,7 +45,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("properly handles empty inputs") { std::vector empty = {}; - auto throw_err = [](int x1, int x2) -> int { throw std::runtime_error("error"); }; + auto throw_err = [](int x1, int x2) -> int { + throw std::runtime_error("error"); + }; std::vector result = zip_with(empty, empty, throw_err); std::vector correct = empty; diff --git a/lib/utils/test/src/utils/tuple/visit.cc b/lib/utils/test/src/utils/tuple/visit.cc index 7024f12e65..ada8b1e786 100644 --- a/lib/utils/test/src/utils/tuple/visit.cc +++ b/lib/utils/test/src/utils/tuple/visit.cc @@ -1,23 +1,24 @@ #include "utils/tuple/visit.h" -#include #include "utils/overload.h" -#include +#include #include +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("visit(std::tuple, Visitor)") { std::ostringstream oss; - auto visitor = overload { - [&](int const &i) -> void { oss << "int(" << i << "), "; }, - [&](bool const &b) -> void { oss << "bool(" << b << "), "; }, - [&](std::string const &s) -> void { oss << "string(" << s << "), "; }, + auto visitor = overload{ + [&](int const &i) -> void { oss << "int(" << i << "), "; }, + [&](bool const &b) -> void { oss << "bool(" << b << "), "; }, + [&](std::string const &s) -> void { oss << "string(" << s << "), "; }, }; SUBCASE("repeated types") { - std::tuple input = {3, "hello", false, "world"}; - + std::tuple input = { + 3, "hello", false, "world"}; + visit_tuple(input, visitor); std::string result = oss.str();