diff --git a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc index 974a70ddf5..7b317df31a 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/simulate_task_graph_execution.cc @@ -29,11 +29,11 @@ TaskGraphExecutionTrace simulate_task_graph_execution( "simulate_task_graph_execution cannot simulate cyclic directed graphs"); } - TaskGraphExecutionState execution_state = - TaskGraphExecutionState{/*ready_tasks=*/set_of(get_sources(task_graph)), - /*in_progress_tasks=*/{}, - /*finished_tasks=*/{}, - /*current_time=*/0.0}; + TaskGraphExecutionState execution_state = TaskGraphExecutionState{ + /*ready_tasks=*/set_of(get_initial_nodes(task_graph)), + /*in_progress_tasks=*/{}, + /*finished_tasks=*/{}, + /*current_time=*/0.0}; std::unordered_set task_profiles; diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index fe319dc63c..88110f914a 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -14,7 +14,8 @@ #include "utils/containers/scanl.h" #include "utils/containers/sum.h" #include "utils/containers/transform.h" -#include "utils/containers/zip.h" +#include "utils/containers/zip3_strict.h" +#include "utils/containers/zip_with_strict.h" #include "utils/exception.h" #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/nonnegative_int/num_elements.h" @@ -52,9 +53,9 @@ MachineView machine_view_from_strides_and_machine_spec_dimensions( start, strides)); } - std::vector dimensions = - transform(zip(strides, dims), [&](auto const &p) { - return MachineViewDimension{p.first, p.second}; + std::vector dimensions = zip_with_strict( + strides, dims, [](stride_t s, MachineSpecificationDimension d) { + return MachineViewDimension{s, d}; }); return MachineView{start, dimensions}; } @@ -109,7 +110,7 @@ 
std::optional get_machine_space_coordinate( nonnegative_int index = start_idx; for (auto [coeff, coord_point, stride] : - zip(coeffs, coord_points, strides)) { + zip3(coeffs, coord_points, strides)) { index += coeff * coord_point * stride; } return index; diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 4cc0500fa2..376cb0c19a 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -117,7 +117,7 @@ std::unordered_set std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { - std::unordered_set raw_sources = get_sources(pcg.raw_graph); + std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); return transform(raw_sources, [](Node const &n) { return parallel_layer_guid_t{n}; }); } diff --git a/lib/utils/include/utils/commutative_pair.h b/lib/utils/include/utils/commutative_pair.h index 12cc16f90e..1a4009fedb 100644 --- a/lib/utils/include/utils/commutative_pair.h +++ b/lib/utils/include/utils/commutative_pair.h @@ -13,7 +13,7 @@ template struct commutative_pair { public: commutative_pair() = delete; - commutative_pair(T const &x, T const &y) : first(x), second(y) {} + explicit commutative_pair(T const &x, T const &y) : first(x), second(y) {} bool operator==(commutative_pair const &other) const { return this->tie() == other.tie() || this->rtie() == other.tie(); diff --git a/lib/utils/include/utils/containers/find.h b/lib/utils/include/utils/containers/find.h index eed5f8453c..7b103fed16 100644 --- a/lib/utils/include/utils/containers/find.h +++ b/lib/utils/include/utils/containers/find.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FIND_H #include +#include namespace FlexFlow { @@ -11,6 +12,12 @@ typename Container::const_iterator return std::find(c.cbegin(), c.cend(), e); } +template 
+typename std::unordered_set::const_iterator + find(std::unordered_set const &c, V const &e) { + return c.find(e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 0f6dbed1d3..2ea049e0b7 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H -#include +#include #include #include @@ -17,17 +17,6 @@ std::vector> zip(std::vector const &l, return result; } -template -std::vector> zip(std::vector const &a, - std::vector const &b, - std::vector const &c) { - std::vector> result; - for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { - result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/zip3.h b/lib/utils/include/utils/containers/zip3.h new file mode 100644 index 0000000000..88b79f429d --- /dev/null +++ b/lib/utils/include/utils/containers/zip3.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_H + +#include +#include +#include +#include + +namespace FlexFlow { + +template +std::vector> zip3(std::vector const &a, + std::vector const &b, + std::vector const &c) { + std::vector> result; + for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { + result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip3_strict.h b/lib/utils/include/utils/containers/zip3_strict.h new file mode 100644 index 0000000000..40ad31d628 --- /dev/null +++ b/lib/utils/include/utils/containers/zip3_strict.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H 
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP3_STRICT_H + +#include "utils/containers/zip3.h" +#include "utils/exception.h" +#include "utils/fmt/vector.h" + +namespace FlexFlow { + +template +std::vector> zip3_strict(std::vector const &as, + std::vector const &bs, + std::vector const &cs) { + if (!(as.size() == bs.size() && bs.size() == cs.size())) { + throw mk_runtime_error(fmt::format( + "zip3_strict requires as, bs, and cs to have the same length, but " + "received as={} (length {}), bs={} (length {}), and cs={} (length {})", + as, + as.size(), + bs, + bs.size(), + cs, + cs.size())); + } + + return zip3(as, bs, cs); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_strict.h b/lib/utils/include/utils/containers/zip_strict.h new file mode 100644 index 0000000000..64049042d4 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_strict.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_STRICT_H + +#include "utils/containers/zip.h" +#include "utils/exception.h" +#include "utils/fmt/vector.h" + +namespace FlexFlow { + +template +std::vector> zip_strict(std::vector const &lhs, + std::vector const &rhs) { + if (lhs.size() != rhs.size()) { + throw mk_runtime_error( + fmt::format("zip_strict requires lhs and rhs to have the same length, " + "but received lhs={} (length {}), rhs={} (length {})", + lhs, + lhs.size(), + rhs, + rhs.size())); + } + + return zip(lhs, rhs); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with.h b/lib/utils/include/utils/containers/zip_with.h new file mode 100644 index 0000000000..7ae91a7336 --- /dev/null +++ b/lib/utils/include/utils/containers/zip_with.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_H + +#include + +namespace FlexFlow { + +template > 
+std::vector + zip_with(std::vector const &l, std::vector const &r, F &&f) { + std::vector result; + for (int i = 0; i < l.size() && i < r.size(); i++) { + result.push_back(f(l.at(i), r.at(i))); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_with_strict.h b/lib/utils/include/utils/containers/zip_with_strict.h new file mode 100644 index 0000000000..fd1e2fa7fd --- /dev/null +++ b/lib/utils/include/utils/containers/zip_with_strict.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_WITH_STRICT_H + +#include "utils/containers/zip_with.h" +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include + +namespace FlexFlow { + +template > +std::vector zip_with_strict(std::vector const &lhs, + std::vector const &rhs, + F &&f) { + if (lhs.size() != rhs.size()) { + throw mk_runtime_error(fmt::format( + "zip_with_strict requires inputs to have the same length, but received " + "lhs = {} (length {}) and rhs = {} (length {})", + lhs, + lhs.size(), + rhs, + rhs.size())); + } + + return zip_with(lhs, rhs, f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/tuple.h b/lib/utils/include/utils/fmt/tuple.h new file mode 100644 index 0000000000..8248cc1cbf --- /dev/null +++ b/lib/utils/include/utils/fmt/tuple.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_TUPLE_H + +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include "utils/tuple/visit.h" +#include +#include +#include +#include + +namespace fmt { + +template +struct formatter, Char> : formatter { + + template + auto format(std::tuple const &t, FormatContext &ctx) const + -> decltype(ctx.out()) { + + std::vector stringified_elements; + ::FlexFlow::visit_tuple(t, [&](auto const &element) -> void { + 
stringified_elements.push_back(fmt::to_string(element)); + }); + + return formatter::format( + "{" + ::FlexFlow::join_strings(stringified_elements, ", ") + "}", ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::tuple const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 25b0103f9c..5cf0c88015 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -15,9 +15,10 @@ There is no single type of graph. Should it be directed? Allow multiple edges be Because there is no single answer to this question, similar to [networkx](https://networkx.org/) we provide a number of different graph variants. At their core, they are as follows: -- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected -- `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. +- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. +- `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) +- `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. +- `DataflowGraph`: used to model computation graphs. See the [DataflowGraph](#dataflowgraph) section for a detailed explanation. 
Examples of the different graph variants are shown below. @@ -37,7 +38,7 @@ flowchart TD D --- B ``` -Example of `DirectedGraph`: +Example of `DiGraph`: ```mermaid flowchart TD A(" ") @@ -58,98 +59,36 @@ flowchart TD Example of `MultiDiGraph`: ```mermaid flowchart TD - A("A") - B("B") - C("C") - D("D") - E("E") - F("F") - - A -->|"(■, ★)"| B - B -->|"(●, ★)"| C - C -->|"(♥, ▲)"| D - D -->|"(●, ■)"| A - B -->|"(★, ●)"| E - E -->|"(■, ■)"| B - D -->|"(●, ●)"| A - A -->|"(●, ■)"| E - D -->|"(■, ●)"| D - E -->|"(■, ■)"| E -``` -or visualized a different way, -```mermaid -flowchart TD - Acirc("●") - Asqua("■") - Bcirc("●") - Bstar("★") - Bsqua("■") - Chear("♥") - Cstar("★") - Dsqua("■") - Dcirc("●") - Dtria("▲") - Ecirc("●") - Esqua("■") - Fplaceholder(" ") - - style Fplaceholder fill:#0000,stroke:#0000 - - subgraph "A" - Acirc - Asqua - end - - subgraph "B" - Bsqua - Bcirc - Bstar - end - - subgraph "C" - Chear - Cstar - end - - subgraph "D" - Dsqua - Dcirc - Dtria - end - - subgraph "E" - Ecirc - Esqua - end - - subgraph "F" - Fplaceholder - end - - Asqua --> Bstar - Bcirc --> Cstar - Chear --> Dtria - Dcirc --> Asqua - Bstar --> Ecirc - Esqua --> Bsqua - Dcirc --> Acirc - Acirc --> Esqua - Dsqua --> Dcirc - Esqua --> Esqua + A + B + C + D + E + F + + A --> B + B --> C + C --> D + D --> A + B --> E + E --> B + D --> A + A --> E + D --> D + E --> E ``` -Note that the nodes and source/destination indices are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. -This is the case as well with `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`. +Note that the node names are completely arbitrary: they have no apparent ordering or other meaning besides representing the topology of the graph. +This is the case with all of the 4 core graph classes. 
Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. -To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). +To add a node to an `UndirectedGraph g`, simply call `g.add_node()`, which will return a `Node` object. +For semantics closer to `networkx`'s method of adding nodes, `g.add_node_unsafe(my_node)` can be used. This is useful when constructing a modified copy of an existing graph (given that it maintains node bijection), though it is not generally recommended. +The interface for node addition is identical for `DiGraph` and `MultiDiGraph`. To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. -`MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. -Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. -`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. 
+In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -158,11 +97,13 @@ The argument to `query_nodes` is a `NodeQuery` (which is simply a set of `Node`s The set of nodes in the query is actually an `optional`, so `nullopt` could also be passed, which would simply retrieve all nodes from the target graph (essentially `nullopt` acts as the set of all nodes that could ever exist). `query_edges` functions similarly, but as with `add_edge` its behavior is differs slightly between the three graph variants. `UndirectedGraph::query_edges` simply takes an optional set of nodes and returns all edges that touch any of those nodes. -`DirectedGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. +`DiGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. In practice you will rarely ever use `query_nodes` and `query_edges` as the graph library provides a large number of algorithms that do that work for you, but it can be helpful to understand this base layer if you ever need to implement your own algorithms. 
-The layer users will most commonly interact with is the interface provided by [algorithms.h](./algorithms.h), which provides a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. -You may notice that the most of the functions declared in `algorithms.h` take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but actually operator on `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. +The layer users will most commonly interact with is the interface provided within either the `algorithms.h` header files or the `algorithms` folders, present in their respective graph class folders. +They provide a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. +Note that, due to the internal virtual inheritance structure, some functions for more primitive classes can be employed by the derived classes. (For example, `get_nodes` present in `node/algorithms.h` can be used by `DiGraph`). +You may notice that most of the algorithms present take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but rather `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. These `GraphView` objects represent read-only (i.e., immutable) graphs. Similar to C++'s `const` semantics, `Graph`s can be coerced[^2] to `GraphView`s but not the other way around. To transform a `GraphView` to a `Graph`, we can perform an explicit copy with `materialize_view`. @@ -170,41 +111,135 @@ Both `Graph` and `GraphView` types follow normal value semantics.
This may seem wasteful (oftentimes graphs are large objects that are passed around via reference to avoid making additional copies), but the `Graph` and `GraphView` types internally implement copy-on-write optimizations to only perform the minimum number of actual copies while maintaining immutability and lifetime safety (if you allocate a `DiGraph` use for example `get_subgraph` to get a `DiGraphView` representing a part of this graph, modifications to the underlying `DiGraph` will not be mirrored in the `DiGraphView` and the `DiGraphView` will remain valid even after the base `DiGraph` leaves scope. At this point, however, we still have not discussed how to create a graph. -The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGraph` which internally uses a representation `MyDiGraphImpl`: +The user-facing graph interface is intentionally separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. +For example, to construct a `DiGraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp -DiGraph g = DiGraph::create(); +DiGraph g = DiGraph::create(); ``` Generally users will use underlying representations provided by the graph library, but advanced users can create their own implementations (see the [Internals](#internals) section). [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. 
[^2]: See if you're not familiar with the term _type coercion_ -### Open, Upward, Downward +### DataflowGraph + +The primary abstraction for representing computation graphs / task graphs is the `DataflowGraph` interface (along with its variants, `OpenDataflowGraph`, `LabelledDataflowGraph` and `OpenLabelledDataflowGraph`). +At a high level, nodes represent multivariate functions (from tuples of inputs to tuples of outputs), while edges represent value uses of such functions. -`Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. -We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). +`DataflowGraph` is similar to `MultiDiGraph`, but with the following important differences: + - The edges entering and exiting a given node have a well-defined order. + - The outputs of a given node also have a well-defined order. + - `DataflowGraph`s are directed acyclic graphs. This is enforced by the interface used to construct them, since a node can only be added to the graph after all of its predecessor nodes have already been added. -![Open graphs inheritance diagram](docs/open.svg) +The main components of `DataflowGraph` are as follows: +- `DataflowInput`: used to denote an entry in the ordered sequence of incoming dependencies (arguments) of a given node (operator). +- `DataflowOutput`: used to denote an entry in the ordered sequence of outgoing results (value uses) from a given node (operator). +- `DataflowEdge`: wrapper around a `DataflowInput`, `DataflowOutput` pair between 2 nodes. +- `NodeAddedResult`: returned upon adding a new node. Contains the newly generated `Node` and the vector of `DataflowOutput`s for the given node.
-Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info, see [cow_ptr](#cow_ptr-and-interfaces)) +`DataflowGraph`s are constructed as follows: +```cpp + auto g = DataflowGraph::create(); + + // Node with no inputs and 2 outputs + NodeAddedResult n1_result = g.add_node({}, 2); + Node n1 = n1_result.node; + DataflowOutput n1_o1 = n1_result.outputs[0]; + DataflowOutput n1_o2 = n1_result.outputs[1]; + + // Node with 2 inputs and 1 output + NodeAddedResult n2_result = g.add_node({n1_o1, n1_o2}, 1); + Node n2 = n2_result.node; + DataflowOutput n2_o1 = n2_result.outputs[0]; + + // Node with 1 input and 2 outputs + NodeAddedResult n3_result = g.add_node({n1_o2}, 2); + Node n3 = n3_result.node; + DataflowOutput n3_o1 = n3_result.outputs[0]; + DataflowOutput n3_o2 = n3_result.outputs[1]; + + // Node with 2 inputs and 1 output + NodeAddedResult n4_result = g.add_node({n2_o1, n3_o1}, 1); + Node n4 = n4_result.node; + DataflowOutput n4_o1 = n4_result.outputs[0]; +``` -### Labelled Graphs +which generates the following graph + +```mermaid +flowchart TD + subgraph Node1[ ] + direction TB + N1Process[n1] + n1_o1((n1_o1)) + n1_o2((n1_o2)) + N1Process --> n1_o1 + N1Process --> n1_o2 + end + + subgraph Node2[ ] + direction TB + n2_i1((n2_i1)) + n2_i2((n2_i2)) + N2Process[n2] + n2_o1((n2_o1)) + n2_i1 --> N2Process + n2_i2 --> N2Process + N2Process --> n2_o1 + end + + subgraph Node3[ ] + direction TB + n3_i1((n3_i1)) + N3Process[n3] + n3_o1((n3_o1)) + n3_o2((n3_o2)) + n3_i1 --> N3Process + N3Process --> n3_o1 + N3Process --> n3_o2 + end + + subgraph Node4[ ] + direction TB + n4_i1((n4_i1)) + n4_i2((n4_i2)) + N4Process[n4] + n4_o1((n4_o1)) + n4_i1 --> N4Process + n4_i2 --> N4Process + N4Process --> n4_o1 + end + + n1_o1 --> n2_i1 + n1_o2 --> n2_i2 + n1_o2 --> n3_i1 + n2_o1 --> n4_i1 + n3_o1 --> n4_i2 +``` + + +### Open Dataflow Variant + +`Open` should be interpreted in
the topological sense: that is, a graph that contains some edges where one of the edge's 2 nodes is not present in the graph itself. +This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. +`DataflowGraphInput` is used to represent the open (incoming) inputs to the graph. Note that, unlike `DataflowInput`, `DataflowGraphInput`s are unordered (given that they are inputs to possibly several different nodes within the graph). + +### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels via [labelled\_graphs.h](./labelled_graphs.h): examples include `NodeLabelledMultiDiGraph` (nodes have labels of type `T` and edges are unlabelled) and `OutputLabelledMultiDiGraph` (nodes have labels of type `T` and source indices have labels of type `U`). -While the interfaces of these graphs differ slightly from the core graph variants, they still have corresponding `GraphView` types, `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. -Note that all of the labelled graph types require that each element of the labelled types have a label (e.g., every node in a `NodeLabelledMultiDiGraph` must have a label of type `T`)., which is enforced via the interfaces they provide. +Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelledDataflowGraph` and `OpenLabelledDataflowGraph`, which allow users to label different components of the graph. +- `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s.
+- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s. + +While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) +Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. -Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes (or other types depending in which labelled graph type is used) to labels. -As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants for use in functions provided by `algorithms.h`, etc. +Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map that maps nodes/edges to labels. +As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants. [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible.
-![Labelled Graphs Inheritance Diagram](docs/labelled.svg) - - ## Internals @@ -236,12 +271,7 @@ To address this, graph classes store a `cow_ptr` as a member variable, which poi All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. -To create graphs within the library, we thus use the following syntax: -`BaseGraph obj = BaseGraph::create();` - -Resulting in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph` - ### Virtual Inheritance -Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, which in turn inherit from both `NodeLabelledMultiDiGraphView`). -In the case of a diamond inheritance pattern C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +Due to the complexity of the graph library, diamond-style inheritance patterns emerge. +In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. 
diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 3f170b5652..ca59f997c7 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -138,12 +138,6 @@ std::unordered_set get_neighbors(DiGraphView const &, Node const &); // std::unordered_set get_neighbors(MultiDiGraphView const &, Node const // &); -// return the set of nodes without incoming edges -std::unordered_set get_sources(DiGraphView const &); - -// return the set of nodes without outgoing edges -std::unordered_set get_sinks(DiGraphView const &); - // std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g); // std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g); // std::unordered_set get_open_sources(OpenMultiDiGraphView const &g); diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 58c28aaff6..043187208c 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -43,8 +43,6 @@ struct DataflowGraph : virtual public DataflowGraphView { private: IDataflowGraph &get_interface(); IDataflowGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms.h b/lib/utils/include/utils/graph/digraph/algorithms.h index 370f181c37..67cfba13ff 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms.h +++ b/lib/utils/include/utils/graph/digraph/algorithms.h @@ -6,8 +6,16 @@ namespace FlexFlow { std::unordered_set get_edges(DiGraphView const &); -std::unordered_set get_sources(DiGraphView const &); -std::unordered_set get_sinks(DiGraphView const &); + +/** + * @brief Returns the set of nodes in the graph with no incoming edges. 
+ */ +std::unordered_set get_initial_nodes(DiGraphView const &graph); + +/** + * @brief Returns the set of nodes in the graph with no outgoing edges. + */ +std::unordered_set get_terminal_nodes(DiGraphView const &graph); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index 1e4d09d3ae..96e8864bc1 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -5,7 +5,22 @@ namespace FlexFlow { +/** + * @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory) + * + * @note By definition, the root node dominates every node and every node + * dominates itself. + * + */ std::unordered_set get_dominators(DiGraphView const &, Node const &); + +/** + * @brief Returns the intersection of the dominators of the given set of nodes. + * @note This is conceptually equivalent to merging the given set of nodes and + * then finding the set of dominators of the new merged node (where merged means + * that all edges belonging to the set of nodes now pass through a single + * unified node). 
+ */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/digraph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h index e36b90d4bf..3d320b1c06 100644 --- a/lib/utils/include/utils/graph/digraph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView { private: IDiGraph &get_ptr(); IDiGraph const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h index 54f84f8d2c..0380751c55 100644 --- a/lib/utils/include/utils/graph/digraph/digraph_view.h +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView { private: IDiGraphView const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h index df5662804a..471a12a44b 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_incoming_edges.h @@ -8,6 +8,10 @@ namespace FlexFlow { std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); +std::unordered_map> + get_incoming_edges(MultiDiGraphView const &g, + std::unordered_set const &nodes); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h index 6bc73533e7..bd8c364f7e 100644 --- a/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h +++ b/lib/utils/include/utils/graph/multidigraph/algorithms/get_outgoing_edges.h 
@@ -2,12 +2,17 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_MULTIDIGRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H #include "utils/graph/multidigraph/multidigraph_view.h" +#include namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &, Node const &); +std::unordered_map> + get_outgoing_edges(MultiDiGraphView const &g, + std::unordered_set const &ns); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h index 69080b9348..692ee33783 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView { private: IMultiDiGraph &get_interface(); IMultiDiGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h index bddefdacb3..1d94d1a65e 100644 --- a/lib/utils/include/utils/graph/node/graph.h +++ b/lib/utils/include/utils/graph/node/graph.h @@ -31,8 +31,6 @@ struct Graph : virtual GraphView { private: IGraph const &get_ptr() const; IGraph &get_ptr(); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index fce3177ef1..8d904e05f2 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -22,8 +22,6 @@ struct GraphView { GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml new file mode 100644 index 0000000000..ca43a987e2 --- /dev/null +++ 
b/lib/utils/include/utils/graph/series_parallel/extended_parallel_reduction.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ExtendedParallelReduction" +features = [ + "eq", + "hash", + "fmt", +] + +docstring = """\ +@brief An ExtendedParallelReduction is an unordered collection of +`MultiDiEdge`s such that they share a common source and destination node. +""" + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "<unordered_set>" +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml new file mode 100644 index 0000000000..ed999a22df --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/extended_series_reduction.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "ExtendedSeriesReduction" + +docstring = """\ +@details An `ExtendedSeriesReduction` is an ordered collection of +`MultiDiEdges` such that: +- The destination node of the nth edge is the same as the source node of the + (n+1)th edge. +- Such a node (intermediate node) has exactly two edges: one incoming (nth + edge) and one outgoing ((n+1)th edge). 
+""" + +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/multidigraph/multidiedge.dtg.h", + "<vector>" +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "edges" +type = "std::vector<::FlexFlow::MultiDiEdge>" diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h index f2a006d899..7e9aa5606d 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -4,14 +4,19 @@ #include "utils/graph/digraph/digraph.h" #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/optional.h" -#include -#include namespace FlexFlow { std::optional get_series_parallel_decomposition(DiGraphView const &); +/** + * @brief Unoptimized version of get_series_parallel_decomposition, used for + * reference. 
+ */ +std::optional + get_series_parallel_decomposition_unoptimized(DiGraphView const &g); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 3fc1347ee5..598548bec1 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -2,18 +2,38 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h" #include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include +#include namespace FlexFlow { ParallelReduction make_parallel_reduction(MultiDiEdge const &, MultiDiEdge const &); + std::optional find_parallel_reduction(MultiDiGraphView const &); +/** + * @brief Finds all ExtendedParallelReduction for a given MultiDiGraph + */ +std::unordered_set + find_all_extended_parallel_reductions(MultiDiGraphView const &); + MultiDiEdge apply_parallel_reduction(MultiDiGraph &, ParallelReduction const &); +/** + * @brief Applies a given ExtendedParallelReduction in place to a given + * MultiDiGraph + * @details The reduction removes all but one `MultiDiEdge`, so that the source, + * destination nodes associated with the reduction become connected by a single + * edge. 
+ */ +MultiDiEdge + apply_extended_parallel_reduction(MultiDiGraph &, + ExtendedParallelReduction const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index 52d2cb7236..b3fc201ca5 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -17,6 +17,25 @@ std::unordered_multiset get_nodes(SeriesSplit const &); std::unordered_multiset get_nodes(ParallelSplit const &); std::unordered_multiset get_nodes(Node const &); +bool is_empty(Node const &node); +bool is_empty(SeriesSplit const &serial); +bool is_empty(ParallelSplit const &parallel); +bool is_empty(SeriesParallelDecomposition const &sp); + +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, + Node const &node); + +// duplicate nodes within `sp` are counted multiple times +size_t num_nodes(SeriesParallelDecomposition const &sp); + +SeriesParallelDecomposition series_composition( + std::vector const &sp_compositions); +SeriesParallelDecomposition parallel_composition( + std::unordered_multiset const + &sp_compositions); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h index a7d53fecfc..9d11e2bdfb 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -3,7 +3,9 @@ #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/graph/series_parallel/series_reduction.dtg.h" +#include "utils/hash/vector.h" 
namespace FlexFlow { @@ -14,8 +16,44 @@ Node get_center_node(MultiDiGraphView const &, SeriesReduction const &); SeriesReduction make_series_reduction(MultiDiEdge const &, MultiDiEdge const &); std::optional find_series_reduction(MultiDiGraphView const &); +/** + * @brief Finds all the ExtendedSeriesReduction structures in a given graph. + * + * For example, in the following graph: + * + * A -> B -> D -> E + * \ / + * -> C -> + * + * We have that [(A,B), (B,D), (D,E)] and [(A,C), (C,E)] both constitute + * `ExtendedSeriesReduction`. + */ +std::unordered_set + find_all_extended_series_reductions(MultiDiGraphView const &g); + MultiDiEdge apply_series_reduction(MultiDiGraph &, SeriesReduction const &); +/** + * @brief Applies a given ExtendedSeriesReduction in-place to a given graph. + * + * For example, in the following graph: + * + * A -> B -> D -> E + * \ / + * -> C -> + * + * Given the ExtendedSeriesReduction [(A,B), (B,D), (D,E)], the intermediate + *nodes B, D, will be deleted, and the resulting graph will be: + * + * A ----> E + * \ / + * -> C -> + * + **/ +MultiDiEdge + apply_extended_series_reduction(MultiDiGraph &g, + ExtendedSeriesReduction const &reduction); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h b/lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h new file mode 100644 index 0000000000..e6c834f4fb --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/make_undirected_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_MAKE_UNDIRECTED_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_MAKE_UNDIRECTED_EDGE_H + +#include "utils/graph/undirected/undirected_edge.dtg.h" + +namespace FlexFlow { + +UndirectedEdge make_undirected_edge(Node const &, Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h 
b/lib/utils/include/utils/graph/undirected/undirected_edge.h index 33d50192cb..d051413faa 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -2,33 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H #include "utils/graph/node/node.dtg.h" -namespace FlexFlow { - -struct UndirectedEdge { -public: - UndirectedEdge() = delete; - UndirectedEdge(Node const &src, Node const &dst); +#include "utils/graph/undirected/undirected_edge.dtg.h" - bool operator==(UndirectedEdge const &) const; - bool operator!=(UndirectedEdge const &) const; - bool operator<(UndirectedEdge const &) const; - -public: - Node smaller; - Node bigger; -}; +namespace FlexFlow { -bool is_connected_to(UndirectedEdge const &, Node const &); +bool is_connected_to(UndirectedEdge const &e, Node const &n); } // namespace FlexFlow -namespace std { - -template <> -struct hash<::FlexFlow::UndirectedEdge> { - size_t operator()(::FlexFlow::UndirectedEdge const &) const; -}; - -} // namespace std - #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml new file mode 100644 index 0000000000..0ad8232339 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "UndirectedEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/commutative_pair.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "endpoints" +type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h index 69975991ce..09b6495699 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -34,8 +34,6 @@ 
struct UndirectedGraph : virtual UndirectedGraphView { using UndirectedGraphView::UndirectedGraphView; - friend struct GraphInternal; - private: IUndirectedGraph const &get_ptr() const; IUndirectedGraph &get_ptr(); diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h index c2df96abc0..90dd5dd5d8 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView { using GraphView::GraphView; - friend struct GraphInternal; - private: IUndirectedGraphView const &get_ptr() const; }; diff --git a/lib/utils/include/utils/graph/views/views.h b/lib/utils/include/utils/graph/views/views.h index aaa1e033f4..5e0109ed5b 100644 --- a/lib/utils/include/utils/graph/views/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -41,104 +41,11 @@ struct DiSubgraphView : public IDiGraphView { std::unordered_set subgraph_nodes; }; -struct JoinedNodeView { -public: - JoinedNodeView() = delete; - explicit JoinedNodeView(GraphView const &lhs, GraphView const &rhs); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::pair, std::unordered_set> - trace_nodes(std::unordered_set const &) const; - - Node at_join_key(JoinNodeKey const &) const; - JoinNodeKey at_node(Node const &) const; - -private: - bidict mapping; - NodeSource node_source; -}; - -struct JoinedUndirectedGraphView : public IUndirectedGraphView { -public: - JoinedUndirectedGraphView() = delete; - explicit JoinedUndirectedGraphView(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs); - - std::unordered_set - query_edges(UndirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedUndirectedGraphView *clone() const override; - -private: - UndirectedEdge fix_lhs_edge(UndirectedEdge const &) const; - 
UndirectedEdge fix_rhs_edge(UndirectedEdge const &) const; - -private: - UndirectedGraphView lhs; - UndirectedGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct JoinedDigraphView : virtual public IDiGraphView { -public: - JoinedDigraphView() = delete; - explicit JoinedDigraphView(DiGraphView const &lhs, DiGraphView const &rhs); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedNodeView const &joined_nodes_view() const; +UndirectedGraphView view_subgraph(UndirectedGraphView const &, + std::unordered_set const &); - JoinedDigraphView *clone() const override; - -private: - DirectedEdge fix_lhs_edge(DirectedEdge const &) const; - DirectedEdge fix_rhs_edge(DirectedEdge const &) const; - -private: - DiGraphView lhs; - DiGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct AddDirectedEdgesView : public IDiGraphView { -public: - AddDirectedEdgesView() = delete; - - explicit AddDirectedEdgesView(DiGraphView const &g, - std::unordered_set const &edges); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - AddDirectedEdgesView *clone() const override; - -private: - DiGraphView g; - std::unordered_set edges; -}; - -struct SingleSourceNodeView : public IDiGraphView { -public: - SingleSourceNodeView() = delete; - - explicit SingleSourceNodeView(DiGraphView const &g) : g(g) {} - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - SingleSourceNodeView *clone() const override; - -private: - DiGraphView g; - std::optional singleton_src; - std::optional joined_view; - std::unique_ptr added_edges_view; -}; +DiGraphView view_subgraph(DiGraphView const &, + std::unordered_set const &); UndirectedEdge to_undirected_edge(DirectedEdge const &); std::unordered_set @@ -176,31 
+83,6 @@ struct ViewUndirectedGraphAsDiGraph : public IDiGraphView { UndirectedGraphView g; }; -std::unordered_map - flatten_contraction(std::unordered_map const &); - -template -Impl materialize_view(View const &g) { - Impl result; - for (Node const &n : get_nodes(g)) { - result.add_node_unsafe(n); - } - for (auto const &e : get_edges(g)) { - result.add_edge(e); - } - return result; -} - -template -Impl materialize_undirected_graph_view(IUndirectedGraphView const &g) { - return materialize_view(g); -} - -template -Impl materialize_digraph_view(IDiGraphView const &g) { - return materialize_view(g); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 0296e365a3..c1fd774850 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_UTILS_TUPLE_H #include "utils/exception.h" +#include "utils/tuple/visit.h" #include "utils/type_traits_core.h" #include #include @@ -32,21 +33,6 @@ struct index_of : index_of_impl {}; } // namespace TupleUtils -template -void visit_tuple_impl(Visitor &v, std::tuple const &tup) { - v(Idx, std::get(tup)); - if (Idx >= std::tuple_size::value) { - return; - } else { - visit_tuple_impl<(Idx + 1)>(v, tup); - } -} - -template -void visit_tuple(Visitor &v, std::tuple const &tup) { - visit_tuple_impl<0>(v, tup); -} - struct tuple_get_visitor { tuple_get_visitor() = delete; tuple_get_visitor(int requested_idx, std::any &result) diff --git a/lib/utils/include/utils/tuple/visit.h b/lib/utils/include/utils/tuple/visit.h new file mode 100644 index 0000000000..8c3892980a --- /dev/null +++ b/lib/utils/include/utils/tuple/visit.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_TUPLE_VISIT_H + +#include +#include + +namespace FlexFlow { + +template +void visit_tuple_impl(Tuple const &tuple, + Visitor &&v, + std::index_sequence) { + (v(std::get(tuple)), ...); +} + 
+template +void visit_tuple(std::tuple const &tuple, Visitor &&v) { + visit_tuple_impl(tuple, v, std::index_sequence_for{}); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/zip.cc b/lib/utils/src/utils/containers/zip.cc index a02c9c8a35..80be287ed9 100644 --- a/lib/utils/src/utils/containers/zip.cc +++ b/lib/utils/src/utils/containers/zip.cc @@ -1 +1,12 @@ #include "utils/containers/zip.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L1 = value_type<0>; +using R1 = value_type<1>; + +template std::vector> zip(std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3.cc b/lib/utils/src/utils/containers/zip3.cc new file mode 100644 index 0000000000..39219f2bbe --- /dev/null +++ b/lib/utils/src/utils/containers/zip3.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip3.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A1 = value_type<0>; +using B1 = value_type<1>; +using C1 = value_type<2>; + +template std::vector> zip3(std::vector const &, + std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip3_strict.cc b/lib/utils/src/utils/containers/zip3_strict.cc new file mode 100644 index 0000000000..72d6d8f0f1 --- /dev/null +++ b/lib/utils/src/utils/containers/zip3_strict.cc @@ -0,0 +1,13 @@ +#include "utils/containers/zip3_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using A1 = value_type<0>; +using B1 = value_type<1>; +using C1 = value_type<2>; + +template std::vector> zip3_strict( + std::vector const &, std::vector const &, std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_strict.cc b/lib/utils/src/utils/containers/zip_strict.cc new file mode 100644 index 0000000000..90faf520f9 --- /dev/null +++ b/lib/utils/src/utils/containers/zip_strict.cc @@ -0,0 +1,12 @@ +#include 
"utils/containers/zip_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; + +template std::vector> zip_strict(std::vector const &, + std::vector const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with.cc b/lib/utils/src/utils/containers/zip_with.cc new file mode 100644 index 0000000000..d58d7eed0f --- /dev/null +++ b/lib/utils/src/utils/containers/zip_with.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template std::vector + zip_with(std::vector const &, std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_with_strict.cc b/lib/utils/src/utils/containers/zip_with_strict.cc new file mode 100644 index 0000000000..8dbfea6c2b --- /dev/null +++ b/lib/utils/src/utils/containers/zip_with_strict.cc @@ -0,0 +1,14 @@ +#include "utils/containers/zip_with_strict.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using Result = value_type<2>; +using F = std::function; + +template std::vector + zip_with_strict(std::vector const &, std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/tuple.cc b/lib/utils/src/utils/fmt/tuple.cc new file mode 100644 index 0000000000..0a8c29cb43 --- /dev/null +++ b/lib/utils/src/utils/fmt/tuple.cc @@ -0,0 +1,8 @@ +#include "utils/fmt/tuple.h" + +namespace FlexFlow { + +template std::ostream &operator<<(std::ostream &s, + std::tuple const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 6ed41daf43..d7cd979f14 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc 
@@ -184,7 +184,8 @@ bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; + UndirectedEdgeQuery q = + UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; return contains(g.query_edges(q), e); } @@ -212,7 +213,7 @@ void remove_edges(UndirectedGraph &g, } std::unordered_set get_endpoints(UndirectedEdge const &e) { - return {e.smaller, e.bigger}; + return {e.endpoints.min(), e.endpoints.max()}; } // std::unordered_set get_edges(MultiDiGraphView const &g) { @@ -480,15 +481,6 @@ DiGraphView get_subgraph(DiGraphView const &g, // return MultiDiGraphView::create(lhs, rhs); // } -DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs) { - return DiGraphView::create(lhs, rhs); -} - -UndirectedGraphView join(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs) { - return UndirectedGraphView::create(lhs, rhs); -} - UndirectedGraphView as_undirected(DiGraphView const &g) { return UndirectedGraphView::create(g); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms.cc b/lib/utils/src/utils/graph/digraph/algorithms.cc index 8cd685e5c6..84798b2f62 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms.cc @@ -15,11 +15,11 @@ std::unordered_set get_edges(DiGraphView const &g) { return g.query_edges(directed_edge_query_all()); } -std::unordered_set get_sinks(DiGraphView const &g) { - return get_sources(flipped(g)); +std::unordered_set get_terminal_nodes(DiGraphView const &g) { + return get_initial_nodes(flipped(g)); } -std::unordered_set get_sources(DiGraphView const &g) { +std::unordered_set get_initial_nodes(DiGraphView const &g) { std::unordered_set all_nodes = get_nodes(g); std::unordered_set with_incoming_edge = transform(get_edges(g), [](DirectedEdge const &e) { return e.dst; }); diff --git 
a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 92bd1e32ca..9a2f9cb019 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -71,8 +71,8 @@ std::optional extend(already_in_a_tail, tail); } - assert(already_in_a_head == set_minus(get_nodes(g), get_sinks(g))); - assert(already_in_a_tail == set_minus(get_nodes(g), get_sources(g))); + assert(already_in_a_head == set_minus(get_nodes(g), get_terminal_nodes(g))); + assert(already_in_a_tail == set_minus(get_nodes(g), get_initial_nodes(g))); return result; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc index ccd2808603..bf428ed26b 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -6,7 +6,7 @@ namespace FlexFlow { bool is_complete_bipartite_digraph(DiGraphView const &g) { - return is_complete_bipartite_digraph(g, get_sources(g)); + return is_complete_bipartite_digraph(g, get_initial_nodes(g)); } bool is_complete_bipartite_digraph(DiGraphView const &g, diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index 3dd9de73f0..1d909150cc 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -15,11 +15,11 @@ namespace FlexFlow { std::unordered_map> 
get_dominators_map(DiGraphView const &g) { - std::unordered_set sources = get_sources(g); + std::unordered_set initial_nodes = get_initial_nodes(g); std::queue queue; - for (Node src : get_sources(g)) { + for (Node src : get_initial_nodes(g)) { queue.push(src); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index 41fe3b67d5..fea799b3e9 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -9,7 +9,7 @@ namespace FlexFlow { static std::vector get_unchecked_topological_ordering(DiGraphView const &g) { - auto dfs_view = unchecked_dfs(g, get_sources(g)); + auto dfs_view = unchecked_dfs(g, get_initial_nodes(g)); std::vector order; std::unordered_set seen; std::unordered_map> predecessors = diff --git a/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index 77f04f2efd..ccf943c4d3 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -42,11 +42,11 @@ std::optional return get_component_containing_node_in_tail(cbc_decomposition, n).value(); }; - std::unordered_set sources = get_sources(view); - std::unordered_set sinks = get_sinks(view); + std::unordered_set initial_nodes = get_initial_nodes(view); + std::unordered_set terminal_nodes = get_terminal_nodes(view); auto src_for_node = [&](Node const &v) -> Node { - if (contains(sources, v)) { + if (contains(initial_nodes, v)) { return alpha; } else { return component_nodes.at_l(t(v)); @@ -54,7 +54,7 @@ std::optional }; auto dst_for_node = [&](Node const &v) -> Node { - if (contains(sinks, v)) { + if (contains(terminal_nodes, 
v)) { return omega; } else { return component_nodes.at_l(h(v)); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index dd660f193d..018d07163d 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -9,11 +9,11 @@ std::optional is_acyclic(DiGraphView const &g) { if (num_nodes(g) == 0) { return std::nullopt; } - std::unordered_set sources = get_sources(g); - if (sources.size() == 0) { + std::unordered_set initial_nodes = get_initial_nodes(g); + if (initial_nodes.size() == 0) { return false; } - auto dfs_view = unchecked_dfs(g, sources); + auto dfs_view = unchecked_dfs(g, initial_nodes); std::unordered_set seen; for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); it++) { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 61c4f80763..6713fafe41 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -1,7 +1,9 @@ #include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/commutative_pair.h" #include "utils/containers/contains_key.h" #include "utils/containers/keys.h" #include "utils/exception.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" namespace FlexFlow { @@ -18,35 +20,41 @@ void HashmapUndirectedGraph::add_node_unsafe(Node const &node) { } void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { + for (Node const &neighbor : this->adjacency.at(n)) { + this->adjacency.at(neighbor).erase(n); + } this->adjacency.erase(n); } void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { - if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.bigger)); + 
if (!contains_key(this->adjacency, e.endpoints.max())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.max())); } - if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.smaller)); + if (!contains_key(this->adjacency, e.endpoints.min())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.min())); } - this->adjacency.at(e.bigger).insert(e.smaller); - this->adjacency.at(e.smaller).insert(e.bigger); + this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min()); + this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max()); } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.bigger); - m.erase(e.smaller); - m.erase(e.bigger); + std::unordered_set &max_map = this->adjacency.at(e.endpoints.max()); + max_map.erase(e.endpoints.min()); + std::unordered_set &min_map = this->adjacency.at(e.endpoints.min()); + min_map.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( UndirectedEdgeQuery const &query) const { std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { - for (auto const &dst : src_kv.second) { - result.insert({src_kv.first, dst}); + for (auto const &dst : apply_query(query.nodes, src_kv.second)) { + result.insert(make_undirected_edge(src_kv.first, dst)); } } return result; diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc index 6f6722f635..cb44f4636d 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -27,8 +27,8 @@ void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { } void 
UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { - assert(contains(this->nodes, e.bigger)); - assert(contains(this->nodes, e.smaller)); + assert(contains(this->nodes, e.endpoints.min())); + assert(contains(this->nodes, e.endpoints.max())); this->edges.insert(e); } diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index be39dc158f..7a5ba695f9 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -1,4 +1,10 @@ #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/containers/group_by.h" +#include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidiedge_query.dtg.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/query_set.h" namespace FlexFlow { @@ -7,4 +13,19 @@ std::unordered_set get_incoming_edges(MultiDiGraphView const &g, return g.query_edges(MultiDiEdgeQuery{query_set::matchall(), {n}}); } +std::unordered_map> + get_incoming_edges(MultiDiGraphView const &g, + std::unordered_set const &ns) { + std::unordered_map> result = + group_by(g.query_edges(MultiDiEdgeQuery{query_set::matchall(), + query_set{ns}}), + [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }); + + for (Node const &n : ns) { + result[n]; + } + + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index f98c599614..d183b44137 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -1,5 +1,8 @@ #include 
"utils/graph/multidigraph/algorithms/get_outgoing_edges.h" - +#include "utils/containers/group_by.h" +#include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/node/algorithms.h" +#include namespace FlexFlow { std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, @@ -7,4 +10,19 @@ std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, return g.query_edges(MultiDiEdgeQuery{{n}, query_set::matchall()}); } +std::unordered_map> + get_outgoing_edges(MultiDiGraphView const &g, + std::unordered_set const &ns) { + std::unordered_map> result = + group_by(g.query_edges(MultiDiEdgeQuery{query_set{ns}, + query_set::matchall()}), + [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }); + + for (Node const &n : ns) { + result[n]; + } + + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index fa17678943..375aaa3762 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -34,14 +34,14 @@ static std::optional { std::unordered_set already_mapped_src_nodes = left_entries(sink_node_mapping); - std::unordered_set src_g_sink_nodes = get_sinks(src_g); + std::unordered_set src_g_sink_nodes = get_terminal_nodes(src_g); assert(already_mapped_src_nodes == src_g_sink_nodes); } { std::unordered_set already_mapped_dst_nodes = right_entries(sink_node_mapping); - std::unordered_set dst_g_sink_nodes = get_sinks(dst_g); + std::unordered_set dst_g_sink_nodes = get_terminal_nodes(dst_g); assert(already_mapped_dst_nodes == dst_g_sink_nodes); } @@ -201,8 +201,8 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = vector_of(get_sinks(src)); - std::unordered_set dst_sink_nodes = get_sinks(dst); + 
std::vector src_sink_nodes = vector_of(get_terminal_nodes(src)); + std::unordered_set dst_sink_nodes = get_terminal_nodes(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { return {}; diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index cd29af59a0..33bdd74787 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -2,14 +2,19 @@ #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" #include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/series_reduction.h" @@ -20,6 +25,101 @@ std::optional DiGraphView transitively_reduced = transitive_reduction(g); + InverseLineGraphResult inverse_line_graph_result = ({ + std::optional maybe_line_graph = + 
get_inverse_line_graph(transitively_reduced); + if (!maybe_line_graph.has_value()) { + return std::nullopt; + } + maybe_line_graph.value(); + }); + + MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( + inverse_line_graph_result.graph); + + std::unordered_map + ttsp_edge_to_sp_tree = map_values( + inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return SeriesParallelDecomposition{n}; }); + + auto perform_extended_parallel_reduction = + [&](ExtendedParallelReduction const ¶llel_reduction) { + MultiDiEdge merged = + apply_extended_parallel_reduction(ttsp, parallel_reduction); + + SeriesParallelDecomposition new_tree = parallel_composition(transform( + unordered_multiset_of(parallel_reduction.edges), + [&](MultiDiEdge const &e) { return ttsp_edge_to_sp_tree.at(e); })); + + for (MultiDiEdge const &e : parallel_reduction.edges) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + return new_tree; + }; + + auto perform_extended_series_reduction = + [&](ExtendedSeriesReduction const &series_reduction) { + MultiDiEdge merged = + apply_extended_series_reduction(ttsp, series_reduction); + + SeriesParallelDecomposition new_tree = series_composition( + transform(series_reduction.edges, [&](MultiDiEdge const &e) { + return ttsp_edge_to_sp_tree.at(e); + })); + + for (MultiDiEdge const &e : series_reduction.edges) { + ttsp_edge_to_sp_tree.erase(e); + } + ttsp_edge_to_sp_tree.insert({merged, new_tree}); + + return new_tree; + }; + + while (true) { + bool reduction_has_happened = false; + + std::unordered_set parallel_reductions = + find_all_extended_parallel_reductions(ttsp); + + if (!parallel_reductions.empty()) { + for (ExtendedParallelReduction parallel_reduction : parallel_reductions) { + perform_extended_parallel_reduction(parallel_reduction); + } + reduction_has_happened = true; + } + + std::unordered_set series_reductions = + find_all_extended_series_reductions(ttsp); + if 
(!series_reductions.empty()) { + for (ExtendedSeriesReduction series_reduction : series_reductions) { + perform_extended_series_reduction(series_reduction); + } + reduction_has_happened = true; + } + + if (reduction_has_happened) { + continue; + } + + if (get_nodes(ttsp).size() != 2 || get_edges(ttsp).size() != 1) { + return std::nullopt; + } + + MultiDiEdge e = get_only(get_edges(ttsp)); + if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { + return ttsp_edge_to_sp_tree.at(e); + } + } +} + +std::optional + get_series_parallel_decomposition_unoptimized(DiGraphView const &g) { + + DiGraphView transitively_reduced = transitive_reduction(g); + InverseLineGraphResult inverse_line_graph_result = ({ std::optional maybe_line_graph = get_inverse_line_graph(transitively_reduced); diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 12a6630bf0..cf03db0e8a 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,33 +1,79 @@ #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/commutative_pair.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/get_one_of.h" +#include "utils/containers/group_by.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/extended_parallel_reduction.dtg.h" +#include "utils/hash/unordered_set.h" +#include +#include namespace FlexFlow { ParallelReduction 
make_parallel_reduction(MultiDiEdge const &e1, MultiDiEdge const &e2) { - return ParallelReduction{{e1, e2}}; + return ParallelReduction{commutative_pair{e1, e2}}; } std::optional find_parallel_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 != e2 && g.get_multidiedge_src(e1) == g.get_multidiedge_src(e2) && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); - } + + std::unordered_map seen; + for (MultiDiEdge const &edge : get_edges(g)) { + DirectedEdge diedge = get_directed_edge(g, edge); + if (contains_key(seen, diedge)) { + return make_parallel_reduction(seen.at(diedge), edge); } + seen.emplace(diedge, edge); } - return std::nullopt; } +std::unordered_set + find_all_extended_parallel_reductions(MultiDiGraphView const &g) { + std::unordered_map> + reduction_groups; + for (MultiDiEdge const &edge : get_edges(g)) { + reduction_groups[get_directed_edge(g, edge)].insert(edge); + } + + std::unordered_set> reductions = filter( + unordered_set_of(values(reduction_groups)), + [](std::unordered_set const &s) { return s.size() > 1; }); + + return transform(reductions, + [&](std::unordered_set const &edges) { + return ExtendedParallelReduction{edges}; + }); +} + MultiDiEdge apply_parallel_reduction(MultiDiGraph &g, ParallelReduction const &r) { g.remove_edge(r.edges.max()); return r.edges.min(); } +MultiDiEdge apply_extended_parallel_reduction( + MultiDiGraph &g, ExtendedParallelReduction const &reduction) { + + MultiDiEdge keep_edge = get_one_of(reduction.edges); + + for (MultiDiEdge const ¶llel_edge : reduction.edges) { + if (parallel_edge != keep_edge) { + g.remove_edge(parallel_edge); + } + } + + return keep_edge; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc 
b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index b7a84b871a..937fc1254e 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,12 +1,17 @@ #include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/all_of.h" +#include "utils/containers/extend.h" #include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" +#include "utils/containers/sum.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -74,4 +79,60 @@ std::unordered_multiset get_nodes(Node const &node) { return {node}; } +bool is_empty(Node const &node) { + return false; +} + +bool is_empty(SeriesSplit const &serial) { + return all_of(serial.children, [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(ParallelSplit const ¶llel) { + return all_of(parallel.get_children(), [](auto const &child) { + return is_empty(widen(child)); + }); +} + +bool is_empty(SeriesParallelDecomposition const &sp) { + return sp.visit([](auto const &t) { return is_empty(t); }); +} + +SeriesParallelDecomposition series_composition( + std::vector const &sp_compositions) { + std::vector> composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition, sp_comp.get().children); + } else if (sp_comp.has()) { + composition.push_back(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.push_back(sp_comp.get()); + } + } + return SeriesParallelDecomposition{SeriesSplit{composition}}; +} + +SeriesParallelDecomposition 
parallel_composition( + std::unordered_multiset const + &sp_compositions) { + std::unordered_multiset< + std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>> + composition{}; + for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + composition = multiset_union(composition, + sp_comp.get().get_children()); + } else if (sp_comp.has()) { + composition.insert(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.insert(sp_comp.get()); + } + } + return SeriesParallelDecomposition(ParallelSplit{composition}); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc index 7300c93fb0..5b9b592444 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,8 +1,23 @@ #include "utils/graph/series_parallel/series_reduction.h" +#include "utils/containers/contains.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/get_only.h" #include "utils/containers/require_same.h" +#include "utils/containers/subvec.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" #include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/graph/multidigraph/multidigraph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/extended_series_reduction.dtg.h" +#include "utils/hash/unordered_set.h" +#include namespace FlexFlow { @@ -26,31 +41,49 @@ SeriesReduction 
make_series_reduction(MultiDiEdge const &e1, std::optional find_series_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); + for (Node const &node : get_nodes(g)) { + if (get_incoming_edges(g, node).size() == 1 && + get_outgoing_edges(g, node).size() == 1) { + return make_series_reduction(get_only(get_incoming_edges(g, node)), + get_only(get_outgoing_edges(g, node))); + } + } + return std::nullopt; +} - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 == e2) { - continue; - } - Node e1_dst = g.get_multidiedge_dst(e1); - Node e2_src = g.get_multidiedge_src(e2); - if (e1_dst != e2_src) { - continue; - } +std::unordered_set + find_all_extended_series_reductions(MultiDiGraphView const &g) { - std::unordered_set outgoing = get_outgoing_edges(g, e1_dst); - std::unordered_set incoming = get_incoming_edges(g, e1_dst); + auto incoming_edges_map = get_incoming_edges(g, get_nodes(g)); + auto outgoing_edges_map = get_outgoing_edges(g, get_nodes(g)); - if (outgoing.size() > 1 || incoming.size() > 1) { - continue; - } + std::unordered_map> strands; + std::unordered_map node_to_head_of_strand; + + for (Node const &n : get_topological_ordering(g)) { + if ((incoming_edges_map.at(n).size() == 1) && + (outgoing_edges_map.at(n).size() == 1)) { + + MultiDiEdge incoming = get_only(incoming_edges_map.at(n)); + MultiDiEdge outgoing = get_only(outgoing_edges_map.at(n)); + Node pre = g.get_multidiedge_src(incoming); - return SeriesReduction{e1, e2}; + if (contains_key(node_to_head_of_strand, pre)) { + Node head = node_to_head_of_strand.at(pre); + node_to_head_of_strand.emplace(n, head); + strands.at(head).push_back(outgoing); + + } else { + node_to_head_of_strand.emplace(n, n); + strands[n].push_back(incoming); + strands[n].push_back(outgoing); + } } } - return std::nullopt; + return transform(unordered_set_of(values(strands)), [&](auto const &edges) { + return ExtendedSeriesReduction{edges}; + }); } MultiDiEdge 
apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { @@ -62,4 +95,21 @@ MultiDiEdge apply_series_reduction(MultiDiGraph &g, SeriesReduction const &r) { return g.add_edge(pre_node, post_node); } +MultiDiEdge + apply_extended_series_reduction(MultiDiGraph &g, + ExtendedSeriesReduction const &reduction) { + + Node first = g.get_multidiedge_src(reduction.edges.at(0)); + Node last = g.get_multidiedge_dst(reduction.edges.back()); + + std::vector internal_nodes; + for (MultiDiEdge const &e : subvec(reduction.edges, std::nullopt, -1)) { + internal_nodes.push_back(g.get_multidiedge_dst(e)); + } + + for (Node const &n : internal_nodes) { + g.remove_node(n); + } + return g.add_edge(first, last); +} } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc index 3c05b9d5d5..726fda8af7 100644 --- a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -10,7 +10,7 @@ std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, std::unordered_set result = set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { - return std::unordered_set{e.bigger, e.smaller}; + return std::unordered_set{e.endpoints.max(), e.endpoints.max()}; })); result.erase(n); return result; diff --git a/lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc b/lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc new file mode 100644 index 0000000000..1c1eb4ae07 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/make_undirected_edge.cc @@ -0,0 +1,10 @@ +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" +#include "utils/commutative_pair.h" + +namespace FlexFlow { + +UndirectedEdge make_undirected_edge(Node const &n1, Node const &n2) { + return UndirectedEdge{commutative_pair{n1, 
n2}}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 0a575e115c..4cfc6aaaa8 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,40 +1,11 @@ #include "utils/graph/undirected/undirected_edge.h" #include "utils/hash/tuple.h" +#include namespace FlexFlow { -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -static std::tuple tie(UndirectedEdge const &e) { - return std::tie(e.smaller, e.bigger); -} - -bool UndirectedEdge::operator==(UndirectedEdge const &other) const { - return tie(*this) == tie(other); -} - -bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { - return tie(*this) != tie(other); -} - -bool UndirectedEdge::operator<(UndirectedEdge const &other) const { - return tie(*this) < tie(other); -} - bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; + return e.endpoints.min() == n || e.endpoints.max() == n; } } // namespace FlexFlow - -namespace std { - -using namespace FlexFlow; - -size_t hash::operator()(UndirectedEdge const &e) const { - std::tuple members = ::FlexFlow::tie(e); - return std::hash{}(members); -} - -} // namespace std diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 3cccf1c6eb..e9e948aa40 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -7,7 +7,8 @@ UndirectedEdgeQuery undirected_edge_query_all() { } bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { - return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); + return includes(q.nodes, e.endpoints.max()) && + includes(q.nodes, 
e.endpoints.min()); } UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 9b5353de9f..74234033b3 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,14 +1,12 @@ #include "utils/graph/views/views.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" -#include "utils/disjoint_set.h" -#include "utils/exception.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/directed_edge_query.h" -#include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" +#include "utils/graph/query_set.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.h" - namespace FlexFlow { UndirectedSubgraphView::UndirectedSubgraphView( @@ -65,150 +63,8 @@ DiGraphView view_subgraph(DiGraphView const &g, return DiGraphView::create(g, subgraph_nodes); } -JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { - for (Node const &n : get_nodes(lhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::LEFT}, - this->node_source.new_node()); - } - for (Node const &n : get_nodes(rhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::RIGHT}, - this->node_source.new_node()); - } -} - -std::unordered_set - JoinedNodeView::query_nodes(NodeQuery const &query) const { - // TODO @lockshaw this is going to be reimplemented in 984, so don't bother - // fixing it for now - NOT_IMPLEMENTED(); -} - -std::pair, std::unordered_set> - JoinedNodeView::trace_nodes(std::unordered_set const &nodes) const { - std::unordered_set left_nodes, right_nodes; - - for (Node const &n : nodes) { - JoinNodeKey k = this->at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - return 
{left_nodes, right_nodes}; -} - -Node JoinedNodeView::at_join_key(JoinNodeKey const &k) const { - return this->mapping.at_l(k); -} - -JoinNodeKey JoinedNodeView::at_node(Node const &n) const { - return this->mapping.at_r(n); -} - -JoinedUndirectedGraphView::JoinedUndirectedGraphView( - UndirectedGraphView const &lhs, UndirectedGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -std::unordered_set - JoinedUndirectedGraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set JoinedUndirectedGraphView::query_edges( - UndirectedEdgeQuery const &query) const { - std::unordered_set nodes = this->query_nodes(NodeQuery{query.nodes}); - std::unordered_set left_nodes, right_nodes; - for (Node const &n : nodes) { - JoinNodeKey k = this->joined_nodes.at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - std::unordered_set result; - for (UndirectedEdge const &e : - this->lhs.query_edges(UndirectedEdgeQuery{left_nodes})) { - result.insert(this->fix_lhs_edge(e)); - } - for (UndirectedEdge const &e : - this->rhs.query_edges(UndirectedEdgeQuery{right_nodes})) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return { - this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key( - JoinNodeKey{e.smaller, LRDirection::RIGHT}), - this->joined_nodes.at_join_key( - JoinNodeKey{e.bigger, LRDirection::RIGHT})}; -} - -JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, - DiGraphView const &rhs) - : lhs(lhs), rhs(rhs), 
joined_nodes(lhs, rhs) {} - -JoinedDigraphView *JoinedDigraphView::clone() const { - return new JoinedDigraphView(lhs, rhs); -} - -std::unordered_set - JoinedDigraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set - JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { - - std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); - std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); - auto traced_srcs = this->joined_nodes.trace_nodes(srcs); - auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - DirectedEdgeQuery left_query = - DirectedEdgeQuery{traced_srcs.first, traced_dsts.first}; - DirectedEdgeQuery right_query = - DirectedEdgeQuery{traced_srcs.second, traced_dsts.second}; - - std::unordered_set result; - for (DirectedEdge const &e : this->lhs.query_edges(left_query)) { - result.insert(this->fix_lhs_edge(e)); - } - for (DirectedEdge const &e : this->rhs.query_edges(right_query)) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -DirectedEdge JoinedDigraphView::fix_lhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT})}; -} - -DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT})}; -} - UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return {e.src, e.dst}; + return make_undirected_edge(e.src, e.dst); } std::unordered_set to_undirected_edges( @@ -218,8 +74,9 @@ std::unordered_set to_undirected_edges( } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return std::unordered_set{DirectedEdge{e.smaller, e.bigger}, - DirectedEdge{e.bigger, e.smaller}}; + return 
std::unordered_set{ + DirectedEdge{e.endpoints.min(), e.endpoints.max()}, + DirectedEdge{e.endpoints.max(), e.endpoints.min()}}; } std::unordered_set to_directed_edges( @@ -258,8 +115,7 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + g.query_edges(UndirectedEdgeQuery{query_union(q.srcs, q.dsts)}); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); @@ -272,8 +128,4 @@ std::unordered_set return g.query_nodes(q); } -JoinedUndirectedGraphView *JoinedUndirectedGraphView::clone() const { - return new JoinedUndirectedGraphView(lhs, rhs); -} - } // namespace FlexFlow diff --git a/lib/utils/src/utils/tuple/visit.cc b/lib/utils/src/utils/tuple/visit.cc new file mode 100644 index 0000000000..58e9398928 --- /dev/null +++ b/lib/utils/src/utils/tuple/visit.cc @@ -0,0 +1,14 @@ +#include "utils/tuple/visit.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T1 = value_type<0>; +using T2 = value_type<1>; +using T3 = value_type<2>; +using Visitor = std::function const &)>; + +template void visit_tuple(std::tuple const &, Visitor &&); + +} // namespace FlexFlow diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h new file mode 100644 index 0000000000..ed23d597d6 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/tuple.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_TUPLE_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_TUPLE_H + +#include "utils/fmt/tuple.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static 
String convert(std::tuple const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc index 106fb1c900..eb1d54d242 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -1 +1,13 @@ #include "test/utils/doctest/fmt/pair.h" +#include "utils/archetypes/value_type.h" + +using ::FlexFlow::value_type; + +using L = value_type<0>; +using R = value_type<1>; + +namespace doctest { + +template struct StringMaker>; + +} // namespace doctest diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc new file mode 100644 index 0000000000..8f2f90bfc9 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/tuple.cc @@ -0,0 +1,14 @@ +#include "test/utils/doctest/fmt/tuple.h" +#include "utils/archetypes/value_type.h" + +using ::FlexFlow::value_type; + +using A = value_type<0>; +using B = value_type<1>; +using C = value_type<2>; + +namespace doctest { + +template struct StringMaker>; + +} // namespace doctest diff --git a/lib/utils/test/src/utils/commutative_pair.cc b/lib/utils/test/src/utils/commutative_pair.cc index 2b91c8b843..af015e2b8c 100644 --- a/lib/utils/test/src/utils/commutative_pair.cc +++ b/lib/utils/test/src/utils/commutative_pair.cc @@ -7,9 +7,9 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("commutative_pair") { - commutative_pair x = {1, 2}; - commutative_pair y = {2, 1}; - commutative_pair z = {1, 1}; + commutative_pair x = commutative_pair{1, 2}; + commutative_pair y = commutative_pair{2, 1}; + commutative_pair z = commutative_pair{1, 1}; SUBCASE("max and min") { SUBCASE("max") { diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc index 6e0a84c7ab..9d686ab814 
100644 --- a/lib/utils/test/src/utils/containers/contains.cc +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -1,13 +1,22 @@ #include "utils/containers/contains.h" #include +#include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); + SUBCASE("std::vector") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK_FALSE(contains(v, 6)); + } + + SUBCASE("std::unordered_set") { + std::unordered_set s = {1, 2, 3, 4, 5}; + CHECK(contains(s, 3)); + CHECK_FALSE(contains(s, 6)); + } } } diff --git a/lib/utils/test/src/utils/containers/zip.cc b/lib/utils/test/src/utils/containers/zip.cc new file mode 100644 index 0000000000..c29415d920 --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip.cc @@ -0,0 +1,82 @@ +#include "utils/containers/zip.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip(std::vector, std::vector)") { + SUBCASE("L and R types are the same") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}, {1, 4}, {2, 8}}; + + CHECK(result == correct); + } + + SUBCASE("L and R types are different") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = { + {"a", 5}, {"b", 4}, {"b", 8}}; + + CHECK(result == correct); + } + + SUBCASE("left is longer than right") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {5, 4}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}, {1, 4}}; + + CHECK(result == correct); + } + + SUBCASE("right is longer than left") { + std::vector lhs = {2}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {{2, 5}}; + + CHECK(result 
== correct); + } + + SUBCASE("left is empty") { + std::vector lhs = {}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("right is empty") { + std::vector lhs = {2, 1, 2}; + std::vector rhs = {}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("both are empty") { + std::vector lhs = {}; + std::vector rhs = {}; + + std::vector> result = zip(lhs, rhs); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip3.cc b/lib/utils/test/src/utils/containers/zip3.cc new file mode 100644 index 0000000000..4268c41aaa --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip3.cc @@ -0,0 +1,100 @@ +#include "utils/containers/zip3.h" +#include "test/utils/doctest/fmt/tuple.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip3(std::vector, std::vector, std::vector)") { + SUBCASE("types are same") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = { + {2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("types are different") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {"a", "d", "d"}; + std::vector> input_c = {{1, 2}, {}, {3, 1}}; + + std::vector>> result = + zip3(input_a, input_b, input_c); + std::vector>> correct = { + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, + }; + + CHECK(result == correct); + } + + SUBCASE("A list is shortest") { + std::vector input_a = {2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4}; + + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}}; + + CHECK(result == correct); + } + + 
SUBCASE("B list is shortest") { + std::vector input_a = {2, 1, 2, 4}; + std::vector input_b = {5, 4}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 4}}; + + CHECK(result == correct); + } + + SUBCASE("C list is shortest") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 3}; + + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = {{2, 5, 3}, {1, 4, 3}}; + + CHECK(result == correct); + } + + SUBCASE("one list is empty") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {}; + + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("all lists are empty") { + std::vector input_a = {}; + std::vector input_b = {}; + std::vector input_c = {}; + + std::vector> result = + zip3(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip3_strict.cc b/lib/utils/test/src/utils/containers/zip3_strict.cc new file mode 100644 index 0000000000..1f69c91e3b --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip3_strict.cc @@ -0,0 +1,84 @@ +#include "utils/containers/zip3_strict.h" +#include "test/utils/doctest/fmt/tuple.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip3_strict(std::vector, std::vector, std::vector)") { + SUBCASE("types are same") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4, 3}; + + std::vector> result = + zip3_strict(input_a, input_b, input_c); + std::vector> correct = { + {2, 5, 3}, {1, 4, 4}, {2, 5, 3}}; + + CHECK(result == correct); + } + + SUBCASE("types are different") { + std::vector input_a = {2, 1, 2}; + std::vector 
input_b = {"a", "d", "d"}; + std::vector> input_c = {{1, 2}, {}, {3, 1}}; + + std::vector>> result = + zip3_strict(input_a, input_b, input_c); + std::vector>> correct = { + {2, "a", {1, 2}}, + {1, "d", {}}, + {2, "d", {3, 1}}, + }; + + CHECK(result == correct); + } + + SUBCASE("A list is shortest") { + std::vector input_a = {2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 4}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("B list is shortest") { + std::vector input_a = {2, 1, 2, 4}; + std::vector input_b = {5, 4}; + std::vector input_c = {3, 4, 3}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("C list is shortest") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {3, 3}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("one list is empty") { + std::vector input_a = {2, 1, 2}; + std::vector input_b = {5, 4, 5}; + std::vector input_c = {}; + + CHECK_THROWS(zip3_strict(input_a, input_b, input_c)); + } + + SUBCASE("all lists are empty") { + std::vector input_a = {}; + std::vector input_b = {}; + std::vector input_c = {}; + + std::vector> result = + zip3_strict(input_a, input_b, input_c); + std::vector> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_strict.cc b/lib/utils/test/src/utils/containers/zip_strict.cc new file mode 100644 index 0000000000..ae0dc6747a --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_strict.cc @@ -0,0 +1,29 @@ +#include "utils/containers/zip_strict.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_strict(std::vector, std::vector)") { + SUBCASE("input lengths are the same") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4, 8}; + + std::vector> result = zip_strict(lhs, rhs); + 
std::vector> correct = { + {"a", 5}, {"b", 4}, {"b", 8}}; + + CHECK(result == correct); + } + + SUBCASE("input lengths are not the same") { + std::vector lhs = {"a", "b", "b"}; + std::vector rhs = {5, 4}; + + CHECK_THROWS(zip_strict(lhs, rhs)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_with.cc b/lib/utils/test/src/utils/containers/zip_with.cc new file mode 100644 index 0000000000..45cecec84b --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_with.cc @@ -0,0 +1,81 @@ +#include "utils/containers/zip_with.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_with(std::vector, std::vector, F)") { + SUBCASE("result types and input types are all different") { + std::vector v1 = {1, 3, 4, 3}; + std::vector v2 = {"aa", "cc", "bb", "dd"}; + + std::vector> result = + zip_with(v1, v2, [](int x1, std::string const &x2) { + return std::make_pair(x1, x2); + }); + std::vector> correct = { + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, + }; + + CHECK(result == correct); + } + + SUBCASE("input lengths don't match") { + auto add = [](int x1, int x2) { return x1 + x2; }; + + std::vector shorter = {1, 2}; + std::vector longer = {1, 3, 5, 7}; + + SUBCASE("first input is shorter") { + std::vector result = zip_with(shorter, longer, add); + std::vector correct = {1 + 1, 2 + 3}; + + CHECK(result == correct); + } + + SUBCASE("second input is shorter") { + std::vector result = zip_with(longer, shorter, add); + std::vector correct = {1 + 1, 2 + 3}; + + CHECK(result == correct); + } + } + + SUBCASE("properly handles empty inputs") { + std::vector nonempty = {1, 2}; + std::vector empty = {}; + + auto throw_err = [](int x1, int x2) -> int { + throw std::runtime_error("error"); + }; + + SUBCASE("first input is empty") { + std::vector result = zip_with(empty, nonempty, throw_err); + std::vector correct = empty; + + 
CHECK(result == correct); + } + + SUBCASE("second input is empty") { + std::vector result = zip_with(nonempty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both inputs are empty") { + std::vector result = zip_with(empty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/containers/zip_with_strict.cc b/lib/utils/test/src/utils/containers/zip_with_strict.cc new file mode 100644 index 0000000000..0730442e59 --- /dev/null +++ b/lib/utils/test/src/utils/containers/zip_with_strict.cc @@ -0,0 +1,58 @@ +#include "utils/containers/zip_with_strict.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip_with_strict(std::vector, std::vector, F)") { + SUBCASE("result types and input types are all different") { + std::vector v1 = {1, 3, 4, 3}; + std::vector v2 = {"aa", "cc", "bb", "dd"}; + + std::vector> result = + zip_with_strict(v1, v2, [](int x1, std::string const &x2) { + return std::make_pair(x1, x2); + }); + std::vector> correct = { + {1, "aa"}, + {3, "cc"}, + {4, "bb"}, + {3, "dd"}, + }; + + CHECK(result == correct); + } + + SUBCASE("input lengths don't match") { + auto add = [](int x1, int x2) { return x1 + x2; }; + + std::vector shorter = {1, 2}; + std::vector longer = {1, 3, 5, 7}; + + SUBCASE("first input is shorter") { + CHECK_THROWS(zip_with_strict(shorter, longer, add)); + } + + SUBCASE("second input is shorter") { + CHECK_THROWS(zip_with_strict(longer, shorter, add)); + } + } + + SUBCASE("properly handles empty inputs") { + std::vector empty = {}; + + auto throw_err = [](int x1, int x2) -> int { + throw std::runtime_error("error"); + }; + + std::vector result = zip_with_strict(empty, empty, throw_err); + std::vector correct = empty; + + CHECK(result == correct); + } + } +} diff --git 
a/lib/utils/test/src/utils/fmt/tuple.cc b/lib/utils/test/src/utils/fmt/tuple.cc new file mode 100644 index 0000000000..1ee7d63a1f --- /dev/null +++ b/lib/utils/test/src/utils/fmt/tuple.cc @@ -0,0 +1,70 @@ +#include "utils/fmt/tuple.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt::to_string(std::tuple)") { + SUBCASE("types are different") { + std::tuple input = {3, false, "hello"}; + + std::string result = fmt::to_string(input); + std::string correct = "{3, false, hello}"; + + CHECK(result == correct); + } + + SUBCASE("types are the same") { + std::tuple input = {3, 5}; + + std::string result = fmt::to_string(input); + std::string correct = "{3, 5}"; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + std::string result = fmt::to_string(input); + std::string correct = "{}"; + + CHECK(result == correct); + } + } + + TEST_CASE("operator<<(ostream &, std::tuple)") { + auto through_ostringstream = [](auto const &t) { + std::ostringstream oss; + oss << t; + return oss.str(); + }; + + SUBCASE("types are different") { + std::tuple input = {3, false, "hello"}; + + std::string result = through_ostringstream(input); + std::string correct = "{3, false, hello}"; + + CHECK(result == correct); + } + + SUBCASE("types are the same") { + std::tuple input = {3, 5}; + + std::string result = through_ostringstream(input); + std::string correct = "{3, 5}"; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + std::string result = through_ostringstream(input); + std::string correct = "{}"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc new file mode 100644 index 0000000000..f17f0cb106 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -0,0 +1,126 @@ +#include "utils/graph/digraph/algorithms.h" 
+#include "utils/containers/unordered_set_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_edges(DiGraphView)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n.at(3), n.at(1)}); + std::unordered_set correct = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(3), n.at(1)}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n.at(0), n.at(3)}); + std::unordered_set correct = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + } + + TEST_CASE("get_terminal_nodes(DiGraphView)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = {n.at(2), n.at(3)}; + std::unordered_set result = get_terminal_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a terminal node") { + g.add_edge(DirectedEdge{n.at(3), n.at(2)}); + std::unordered_set correct = {n.at(2)}; + std::unordered_set result = 
get_terminal_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set result = get_terminal_nodes(g); + std::unordered_set correct = {n.at(3)}; + CHECK(result == correct); + } + } + + TEST_CASE("get_initial_nodes(DiGraphView)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("Base") { + std::unordered_set correct = {n.at(0)}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set correct = {}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n.at(0), n.at(1)}); + std::unordered_set correct = {n.at(0), n.at(1)}; + std::unordered_set result = get_initial_nodes(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n.at(2), n.at(0)}); + std::unordered_set result = get_initial_nodes(g); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc new file mode 100644 index 0000000000..e820ab8808 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc @@ -0,0 +1,87 @@ +#include "utils/graph/digraph/digraph.h" +#include "utils/containers/repeat.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/node_query.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DiGraph 
implementations", T, AdjacencyDiGraph) { + /* + graph TD + + n0 --> n1 + n0 --> n2 + n1 --> n2 + n2 --> n4 + n1 --> n3 + */ + + DiGraph g = DiGraph::create(); + std::vector n = repeat(5_n, [&] { return g.add_node(); }); + std::vector e = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[1], n[3]}}; + for (DirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); + + SUBCASE("query_edges") { + + std::unordered_set queried_edges = + g.query_edges(directed_edge_query_all()); + std::unordered_set expected = { + e[0], e[1], e[2], e[3], e[4]}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges(DirectedEdgeQuery{ + query_set{{n[0]}}, query_set{{n[1]}}}); + expected = std::unordered_set{e[0]}; + CHECK(queried_edges == expected); + } + } + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n[0]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[1], n[2], n[3], n[4]}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[3], e[4]}); + + g.remove_node_unsafe(n[1]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[2], n[3], n[4]}); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[3]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[1], e[2], e[3], e[4]}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + g.remove_edge(e[1]); + g.remove_edge(e[3]); + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[4]}); + } + } +} diff --git 
a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc new file mode 100644 index 0000000000..ee7ead009e --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc @@ -0,0 +1,68 @@ +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/digraph/directed_edge_query.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("directed_edge_query_all") { + Node n1{0}, n2{1}, n3{2}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; + + DirectedEdgeQuery result = directed_edge_query_all(); + + CHECK(matches_edge(result, e1)); + CHECK(matches_edge(result, e2)); + } + + TEST_CASE("matches_edge") { + Node n1{0}, n2{1}, n3{2}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; + + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{n1}, query_set{n2}}; + + CHECK(matches_edge(query, e1)); + CHECK_FALSE(matches_edge(query, e2)); + + DirectedEdge flipped_edge = DirectedEdge{n2, n1}; + CHECK_FALSE(matches_edge(query, flipped_edge)); + } + + TEST_CASE("query_intersection") { + Node n1{0}, n2{1}, n3{2}, n4{3}; + DirectedEdge e1 = DirectedEdge{n1, n2}; + DirectedEdge e2 = DirectedEdge{n2, n3}; + DirectedEdge e3 = DirectedEdge{n3, n4}; + + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n1, n2}, query_set{n2, n3}}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{query_set{n2, n3}, query_set{n3, n4}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery expected = + DirectedEdgeQuery{query_set{n2}, query_set{n3}}; + + CHECK(result == expected); + } + + SUBCASE("intersection with matchall") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n1, n2}, matchall()}; + DirectedEdgeQuery 
q2 = + DirectedEdgeQuery{matchall(), query_set{n3, n4}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery expected = + DirectedEdgeQuery{query_set{n1, n2}, query_set{n3, n4}}; + + CHECK(result == expected); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc new file mode 100644 index 0000000000..17bea2210f --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -0,0 +1,69 @@ +#include "utils/graph/digraph/algorithms/get_dominators.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dominators") { + DiGraph g = DiGraph::create(); + SUBCASE("acyclic graph") { + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("get_dominators(DiGraph, Node)") { + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); + } + + SUBCASE("get_dominators(DiGraph, std::unordered_set)") { + std::unordered_set nodes = {n.at(1), n.at(3)}; + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n.at(0)}; + CHECK(correct == result); + } + } + + SUBCASE("graph with cycles") { + // example from + // https://en.wikipedia.org/w/index.php?title=Dominator_(graph_theory)&oldid=1189814332 + + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 6); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(4)}, + 
DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(1)}, + }); + + SUBCASE("node 1") { + std::unordered_set result = get_dominators(g, n.at(1)); + std::unordered_set correct = {n.at(0), n.at(1)}; + CHECK(result == correct); + } + + SUBCASE("node 3") { + std::unordered_set result = get_dominators(g, n.at(3)); + std::unordered_set correct = {n.at(0), n.at(1), n.at(3)}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc new file mode 100644 index 0000000000..5adc0cc4df --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -0,0 +1,36 @@ +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/containers/index_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc new file mode 100644 index 0000000000..f778cfbd22 --- /dev/null 
+++ b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc @@ -0,0 +1,135 @@ +#include "utils/graph/traversal.h" +#include "utils/containers/contains.h" +#include "utils/fmt/vector.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/hash/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_unchecked_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + + SUBCASE("linear path") { + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("diamond path") { + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(3), n.at(2), n.at(3)}, + {n.at(0), n.at(2), n.at(3), n.at(1), n.at(3)}}; + std::vector result = get_unchecked_dfs_ordering(g, {n.at(0)}); + CHECK(contains(corrects, result)); + } + } + + TEST_CASE("get_bfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}); + + SUBCASE("branching path") { + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(2), n.at(3), n.at(4), n.at(5)}, + {n.at(0), n.at(2), n.at(1), n.at(3), n.at(4), n.at(5)}}; + std::vector result = get_bfs_ordering(g, {n.at(0)}); + CHECK(contains(corrects, result)); + } + + SUBCASE("isolated node") { + std::vector correct = {n.at(5)}; + std::vector result = get_bfs_ordering(g, 
{n.at(5)}); + CHECK(correct == result); + } + + SUBCASE("graph with cycle") { + g = DiGraph::create(); + n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(2), n.at(1)}}); + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(2)}, {n.at(0), n.at(2), n.at(1)}}; + std::vector result = get_bfs_ordering(g, {n.at(0)}); + CHECK(contains(corrects, result)); + } + } + + TEST_CASE("get_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); + + SUBCASE("simple path") { + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("start from non-initial node") { + std::vector correct = {n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(1)}); + CHECK(correct == result); + } + + SUBCASE("with cycle") { + g.add_edge(DirectedEdge{n.at(3), n.at(1)}); + std::vector correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("branching") { + g.add_edge(DirectedEdge{n.at(1), n.at(3)}); + std::unordered_set> corrects = { + {n.at(0), n.at(1), n.at(2), n.at(3)}, + {n.at(0), n.at(1), n.at(3), n.at(2)}}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); + CHECK(contains(corrects, result)); + } + + SUBCASE("disconnected") { + g.remove_edge(DirectedEdge{n.at(2), n.at(3)}); + std::vector correct = {n.at(0), n.at(1), n.at(2)}; + std::vector result = get_dfs_ordering(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("isolated node") { + g.remove_edge(DirectedEdge{n.at(2), n.at(3)}); + std::vector correct = {n.at(3)}; + std::vector 
result = get_dfs_ordering(g, {n.at(3)}); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..ef5cf3c502 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -0,0 +1,52 @@ +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_edges") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3_n); + + std::vector edges = add_edges(g, + {{n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}, + {n.at(2), n.at(0)}}); + + SUBCASE("get_incoming_edges(MultiDiGraphView, Node)") { + + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } + + SUBCASE("get_incoming_edges(MultiDiGraphView, std::unordered_set)") { + + std::unordered_set ns = {n.at(0), n.at(2)}; + std::unordered_map> result = + get_incoming_edges(g, ns); + + std::unordered_map> correct = { + {n.at(0), {edges.at(0), edges.at(3), edges.at(4)}}, {n.at(2), {}}}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc 
b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc new file mode 100644 index 0000000000..20011cb133 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -0,0 +1,56 @@ +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_outgoing_edges") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3_n); + + std::vector edges = add_edges(g, + { + {n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(0)}, + }); + + SUBCASE("get_outgoing_edges(MultiDiGraphView, Node)") { + + SUBCASE("node has outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2), edges.at(3)}; + CHECK(result == correct); + } + + SUBCASE("node has no outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } + + SUBCASE("get_outgoing_edges(MultiDiGraphView, std::unordered_set)") { + + std::unordered_set ns = {n.at(0), n.at(1)}; + std::unordered_map> result = + get_outgoing_edges(g, ns); + + std::unordered_map> correct = { + {n.at(0), {edges.at(0), edges.at(1), edges.at(2), edges.at(3)}}, + {n.at(1), {edges.at(4)}}}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index 4bb57aeb0d..78537a4342 100644 --- 
a/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -1,9 +1,12 @@ #include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/set_minus.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/node/algorithms.h" #include @@ -195,8 +198,8 @@ TEST_SUITE(FF_TEST_SUITE) { {n.at(1), n.at(2)}, {n.at(2), n.at(3)}, {n.at(2), n.at(3)}, - {n.at(3), n.at(4)}, //* - {n.at(4), n.at(5)}, //* + {n.at(3), n.at(4)}, + {n.at(4), n.at(5)}, {n.at(5), n.at(6)}, {n.at(5), n.at(7)}, }); @@ -242,4 +245,157 @@ TEST_SUITE(FF_TEST_SUITE) { } } } + + TEST_CASE("find_all_extended_series_reductions") { + MultiDiGraph g = MultiDiGraph::create(); + + SUBCASE("linear graph") { + std::vector n = add_nodes(g, 4_n); + std::vector e = add_edges(g, + { + {n.at(0), n.at(1)}, + {n.at(1), n.at(2)}, + {n.at(2), n.at(3)}, + }); + + std::unordered_set result = + find_all_extended_series_reductions(g); + std::unordered_set correct = { + ExtendedSeriesReduction{{e.at(0), e.at(1), e.at(2)}}}; + CHECK(result == correct); + } + + SUBCASE("2 linear strands with a common terminal node") { + std::vector n = add_nodes(g, 4_n); + std::vector e = add_edges(g, + {{n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(3)}, + {n.at(2), n.at(3)}}); + + std::unordered_set result = + find_all_extended_series_reductions(g); + std::unordered_set correct = { + ExtendedSeriesReduction{{e.at(0), e.at(2)}}, + ExtendedSeriesReduction{{e.at(1), e.at(3)}}}; + CHECK(result == correct); + } + + SUBCASE("graph with multiple separate serial strands") { + std::vector n = add_nodes(g, 9_n); + 
std::vector e = add_edges(g, + {{n.at(0), n.at(1)}, + {n.at(0), n.at(2)}, + {n.at(1), n.at(4)}, + {n.at(2), n.at(3)}, + {n.at(2), n.at(5)}, + {n.at(2), n.at(6)}, + {n.at(3), n.at(5)}, + {n.at(4), n.at(7)}, + {n.at(5), n.at(7)}, + {n.at(6), n.at(8)}, + {n.at(7), n.at(8)}}); + + std::unordered_set result = + find_all_extended_series_reductions(g); + std::unordered_set correct = { + ExtendedSeriesReduction{{e.at(0), e.at(2), e.at(7)}}, + ExtendedSeriesReduction{{e.at(3), e.at(6)}}, + ExtendedSeriesReduction{{e.at(5), e.at(9)}}}; + CHECK(result == correct); + } + } + + TEST_CASE("apply_extended_series_reduction") { + MultiDiGraph g = MultiDiGraph::create(); + + SUBCASE("base case") { + std::vector n = add_nodes(g, 4_n); + std::vector e = add_edges( + g, {{n.at(0), n.at(1)}, {n.at(1), n.at(2)}, {n.at(2), n.at(3)}}); + + ExtendedSeriesReduction reduction = + ExtendedSeriesReduction{{e.at(0), e.at(1), e.at(2)}}; + + MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(g); + std::unordered_set correct_nodes = {n.at(0), n.at(3)}; + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(g); + std::unordered_set correct_edges = {returned_edge}; + CHECK(result_edges == correct_edges); + } + + SUBCASE("returned edge") { + SUBCASE("src") { + Node returned_edge_src = g.get_multidiedge_src(returned_edge); + Node correct_src = n.at(0); + CHECK(returned_edge_src == correct_src); + } + + SUBCASE("dst") { + Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); + Node correct_dst = n.at(3); + CHECK(returned_edge_dst == correct_dst); + } + } + } + + SUBCASE("in larger graph") { + std::vector n = add_nodes(g, 8_n); + std::vector e = add_edges(g, + { + {n.at(0), n.at(2)}, + {n.at(1), n.at(2)}, + {n.at(2), n.at(5)}, + {n.at(2), n.at(3)}, + {n.at(3), n.at(4)}, + {n.at(4), n.at(5)}, + {n.at(5), n.at(6)}, + {n.at(5), n.at(7)}, + }); + + 
ExtendedSeriesReduction reduction = + ExtendedSeriesReduction{{e.at(3), e.at(4), e.at(5)}}; + + MultiDiEdge returned_edge = apply_extended_series_reduction(g, reduction); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(g); + std::unordered_set correct_nodes = + set_minus(unordered_set_of(n), {n.at(4), n.at(3)}); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(g); + std::unordered_set correct_edges = [&] { + std::unordered_set new_edges = unordered_set_of(e); + new_edges = set_minus(new_edges, {e.at(3), e.at(4), e.at(5)}); + new_edges.insert(returned_edge); + return new_edges; + }(); + CHECK(result_edges == correct_edges); + } + + SUBCASE("returned edge") { + SUBCASE("src") { + Node returned_edge_src = g.get_multidiedge_src(returned_edge); + Node correct_src = n.at(2); + CHECK(returned_edge_src == correct_src); + } + + SUBCASE("dst") { + Node returned_edge_dst = g.get_multidiedge_dst(returned_edge); + Node correct_dst = n.at(5); + CHECK(returned_edge_dst == correct_dst); + } + } + } + } } diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc new file mode 100644 index 0000000000..20b3eaa74a --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -0,0 +1,95 @@ +#include "utils/graph/undirected/algorithms/get_connected_components.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" +#include "utils/graph/undirected/undirected_graph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + + SUBCASE("disjoint nodes") { + std::vector n = add_nodes(g, 3); + + 
std::unordered_set> correct = { + {n.at(0)}, + {n.at(1)}, + {n.at(2)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("1 component") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(3)), + make_undirected_edge(n.at(3), n.at(0)), + }); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2), n.at(3)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("2 components") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(2), n.at(1)), + }); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2)}, + {n.at(3)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("3 components") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(0), n.at(2)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(3), n.at(4)), + }); + + std::unordered_set> correct = { + {n.at(0), n.at(1), n.at(2)}, + {n.at(3), n.at(4)}, + {n.at(5)}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("empty graph") { + std::unordered_set> correct = {}; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc new file mode 100644 index 0000000000..77b74fdd20 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -0,0 +1,85 @@ +#include "utils/commutative_pair.h" +#include "utils/containers/repeat.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" 
+#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" +#include "utils/graph/undirected/undirected_edge_query.h" +#include "utils/graph/undirected/undirected_graph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = repeat(5_n, [&] { return g.add_node(); }); + std::vector e = {make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(0), n.at(2)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(4)), + make_undirected_edge(n.at(1), n.at(3))}; + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{ + n.at(0), n.at(1), n.at(2), n.at(3), n.at(4)}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n.at(0), n.at(2)}}}) == + std::unordered_set{n.at(0), n.at(2)}); + } + + SUBCASE("query_edges") { + + std::unordered_set queried_edges = + g.query_edges(undirected_edge_query_all()); + std::unordered_set expected = { + e.at(0), e.at(1), e.at(2), e.at(3), e.at(4)}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges( + UndirectedEdgeQuery{query_set{{n.at(0), n.at(1)}}}); + expected = std::unordered_set{e.at(0)}; + CHECK(queried_edges == expected); + } + + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n.at(0)); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n.at(1), n.at(2), n.at(3), n.at(4)}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{e.at(2), e.at(3), e.at(4)}); + + g.remove_node_unsafe(n.at(1)); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n.at(2), n.at(3), n.at(4)}); + + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{e.at(3)}); + 
} + + SUBCASE("remove_edge") { + g.remove_edge(e.at(0)); + + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{ + e.at(1), e.at(2), e.at(3), e.at(4)}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{ + n.at(0), n.at(1), n.at(2), n.at(3), n.at(4)}); + + g.remove_edge(e.at(1)); + g.remove_edge(e.at(3)); + CHECK(g.query_edges(undirected_edge_query_all()) == + std::unordered_set{e.at(2), e.at(4)}); + } + } +} diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc new file mode 100644 index 0000000000..183bab4cc2 --- /dev/null +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -0,0 +1,154 @@ +#include "utils/graph/views/views.h" +#include "utils/containers/set_union.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/algorithms/make_undirected_edge.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/undirected/undirected_graph_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("UndirectedSubgraphView") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {make_undirected_edge(n.at(0), n.at(3)), + make_undirected_edge(n.at(1), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(1), n.at(3)), + make_undirected_edge(n.at(2), n.at(3)), + make_undirected_edge(n.at(2), n.at(4))}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + UndirectedGraphView view = view_subgraph(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + std::unordered_set result = get_nodes(view); + + 
CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + make_undirected_edge(n.at(0), n.at(3)), + make_undirected_edge(n.at(1), n.at(1)), + make_undirected_edge(n.at(1), n.at(3)), + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("DiSubgraphView") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + DiGraphView view = view_subgraph(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("ViewDiGraphAsUndirectedGraph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}); + + UndirectedGraphView view = as_undirected(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(0))}; + + std::unordered_set result = get_edges(view); + 
+ CHECK(result == expected); + } + } + + TEST_CASE("ViewUndirectedGraphAsDiGraph") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {make_undirected_edge(n.at(0), n.at(0)), + make_undirected_edge(n.at(0), n.at(1)), + make_undirected_edge(n.at(1), n.at(2)), + make_undirected_edge(n.at(2), n.at(0))}); + + DiGraphView view = as_digraph(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(0)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } +} diff --git a/lib/utils/test/src/utils/tuple/visit.cc b/lib/utils/test/src/utils/tuple/visit.cc new file mode 100644 index 0000000000..ada8b1e786 --- /dev/null +++ b/lib/utils/test/src/utils/tuple/visit.cc @@ -0,0 +1,41 @@ +#include "utils/tuple/visit.h" +#include "utils/overload.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("visit(std::tuple, Visitor)") { + std::ostringstream oss; + auto visitor = overload{ + [&](int const &i) -> void { oss << "int(" << i << "), "; }, + [&](bool const &b) -> void { oss << "bool(" << b << "), "; }, + [&](std::string const &s) -> void { oss << "string(" << s << "), "; }, + }; + + SUBCASE("repeated types") { + std::tuple input = { + 3, "hello", false, "world"}; + + visit_tuple(input, visitor); + + std::string result = oss.str(); + std::string correct = "int(3), string(hello), bool(0), string(world), "; + + CHECK(result == correct); + } + + SUBCASE("empty tuple") { + std::tuple<> input = {}; + + visit_tuple(input, visitor); + + std::string result 
= oss.str(); + std::string correct = ""; + + CHECK(result == correct); + } + } +}