Skip to content

Commit c277800

Browse files
Finish from_v1 for LabelledKwargDataflowGraph.
1 parent 25b0041 commit c277800

6 files changed

Lines changed: 58 additions & 67 deletions

File tree

lib/pcg/include/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.dtg.h"
66
#include "utils/bidict/algorithms/bidict_from_enumerating.h"
77
#include "utils/containers/enumerate.h"
8+
#include "utils/containers/generate_map.h"
89
#include "utils/containers/sorted.h"
910
#include "utils/containers/transform.h"
1011
#include "utils/containers/unordered_set_of.h"
@@ -47,15 +48,17 @@ V1KwargDataflowGraph<SlotName>
4748
}
4849

4950
template <typename SlotName>
50-
KwargDataflowGraphView<SlotName>
51-
from_v1(V1KwargDataflowGraph<SlotName> const &v1_g) {
52-
std::unordered_set<Node> graph_nodes =
53-
unordered_set_of(transform(v1_g.nodes, [](nonnegative_int n) {
51+
std::pair<KwargDataflowGraphView<SlotName>,
52+
std::unordered_map<nonnegative_int, Node>>
53+
from_v1(V1KwargDataflowGraph<SlotName> const &v1) {
54+
std::unordered_map<nonnegative_int, Node> node_map =
55+
generate_map(v1.nodes, [](nonnegative_int n) {
5456
return Node{n.size_t_from_nonnegative_int()};
55-
}));
57+
});
58+
std::unordered_set<Node> node_set = unordered_set_of(values(node_map));
5659

57-
std::unordered_set<OpenKwargDataflowEdge<int, SlotName>> graph_edges =
58-
transform(v1_g.edges, [](V1GraphEdge<SlotName> const &e) {
60+
std::unordered_set<OpenKwargDataflowEdge<int, SlotName>> edges =
61+
transform(v1.edges, [](V1GraphEdge<SlotName> const &e) {
5962
Node srcNode = Node{e.srcNode.size_t_from_nonnegative_int()};
6063
Node dstNode = Node{e.dstNode.size_t_from_nonnegative_int()};
6164
return OpenKwargDataflowEdge<int, SlotName>{KwargDataflowEdge<SlotName>{
@@ -66,12 +69,13 @@ KwargDataflowGraphView<SlotName>
6669

6770
OpenKwargDataflowGraphData<int, SlotName> graph_data =
6871
OpenKwargDataflowGraphData<int, SlotName>{
69-
/*nodes=*/graph_nodes,
70-
/*edges=*/graph_edges,
72+
/*nodes=*/node_set,
73+
/*edges=*/edges,
7174
/*inputs=*/{},
7275
/*outputs=*/{},
7376
};
74-
return view_from_open_kwarg_dataflow_graph_data(graph_data);
77+
return std::pair{view_from_open_kwarg_dataflow_graph_data(graph_data),
78+
node_map};
7579
}
7680

7781
} // namespace FlexFlow

lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.h"
55
#include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.dtg.h"
66
#include "utils/bidict/algorithms/bidict_from_enumerating.h"
7+
#include "utils/containers/map_keys.h"
78
#include "utils/containers/map_values.h"
89
#include "utils/containers/transform.h"
910
#include "utils/containers/unordered_map_from_pairs.h"
@@ -14,9 +15,10 @@
1415
#include "utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h"
1516
#include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h"
1617
#include "utils/graph/kwarg_dataflow_graph/kwarg_node_added_result.dtg.h"
17-
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h"
18+
#include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/kwarg_dataflow_graph_view_with_labelling.h"
1819
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph_view.h"
1920
#include "utils/graph/node/algorithms.h"
21+
#include "utils/nonnegative_int/nonnegative_int.h"
2022

2123
namespace FlexFlow {
2224

@@ -59,53 +61,19 @@ V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> to_v1(
5961
}
6062

6163
template <typename NodeLabel, typename OutputLabel, typename SlotName>
62-
LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> from_v1(
63-
V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> const &v1) {
64-
// Build incoming-edge map
65-
std::unordered_map<nonnegative_int, std::vector<V1GraphEdge<SlotName>>>
66-
incoming;
67-
for (nonnegative_int const &n : v1.graph.nodes) {
68-
incoming[n] = {};
69-
}
70-
for (V1GraphEdge<SlotName> const &e : v1.graph.edges) {
71-
incoming[e.dstNode].push_back(e);
72-
}
73-
74-
// Build a DiGraph with V1 indices as Node raw_uids to get topological order
75-
DiGraph dg = DiGraph::create<AdjacencyDiGraph>();
76-
for (nonnegative_int const &n : v1.graph.nodes) {
77-
dg.add_node_unsafe(Node{n.size_t_from_nonnegative_int()});
78-
}
79-
for (V1GraphEdge<SlotName> const &e : v1.graph.edges) {
80-
dg.add_edge(DirectedEdge{Node{e.srcNode.size_t_from_nonnegative_int()},
81-
Node{e.dstNode.size_t_from_nonnegative_int()}});
82-
}
83-
84-
auto g = LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName>::
85-
template create<UnorderedSetLabelledOpenKwargDataflowGraph<NodeLabel,
86-
OutputLabel,
87-
int,
88-
SlotName>>();
89-
90-
std::unordered_map<nonnegative_int, Node> node_map;
91-
for (Node const &topo_node : get_topological_ordering(dg)) {
92-
nonnegative_int v1_idx{topo_node.raw_uid};
93-
94-
std::unordered_map<SlotName, KwargDataflowOutput<SlotName>> inputs =
95-
unordered_map_from_pairs(
96-
transform(incoming.at(v1_idx), [&](V1GraphEdge<SlotName> const &e) {
97-
return std::pair{e.dstSlot,
98-
KwargDataflowOutput<SlotName>{
99-
node_map.at(e.srcNode), e.srcSlot}};
100-
}));
101-
102-
KwargNodeAddedResult<SlotName> result = g.add_node(
103-
v1.node_labels.at(v1_idx), inputs, v1.output_labels.at(v1_idx));
104-
105-
node_map.insert(std::pair{v1_idx, result.node});
106-
}
107-
108-
return g;
64+
std::pair<LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName>,
65+
std::unordered_map<nonnegative_int, Node>>
66+
from_v1(V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> const
67+
&v1) {
68+
auto [graph_view, node_map] = from_v1(v1.graph);
69+
70+
std::unordered_map<Node, NodeLabel> node_labels = map_keys(
71+
v1.node_labels, [&](nonnegative_int n) { return node_map.at(n); });
72+
std::unordered_map<KwargDataflowOutput<SlotName>, OutputLabel> value_labels;
73+
74+
return std::pair{kwarg_dataflow_graph_view_with_labelling(
75+
graph_view, node_labels, value_labels),
76+
node_map};
10977
}
11078

11179
} // namespace FlexFlow

lib/pcg/src/pcg/file_format/v1/graphs/v1_kwarg_dataflow_graph.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ template V1KwargDataflowGraph<SlotName>
1212
to_v1(KwargDataflowGraphView<SlotName> const &,
1313
std::unordered_map<Node, nonnegative_int> const &);
1414

15-
template KwargDataflowGraphView<SlotName>
15+
template std::pair<KwargDataflowGraphView<SlotName>,
16+
std::unordered_map<nonnegative_int, Node>>
1617
from_v1(V1KwargDataflowGraph<SlotName> const &);
1718

1819
} // namespace FlexFlow

lib/pcg/src/pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ template std::pair<
1818
template V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> to_v1(
1919
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName> const &);
2020

21-
template LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> from_v1(
22-
V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> const &);
21+
template std::pair<
22+
LabelledKwargDataflowGraphView<NodeLabel, OutputLabel, SlotName>,
23+
std::unordered_map<nonnegative_int, Node>>
24+
from_v1(
25+
V1LabelledKwargDataflowGraph<NodeLabel, OutputLabel, SlotName> const &);
2326

2427
} // namespace FlexFlow

lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "pcg/file_format/v1/v1_computation_graph.h"
22
#include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h"
33
#include "utils/bidict/algorithms/transform_values.h"
4+
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h"
45

56
namespace FlexFlow {
67

@@ -11,9 +12,14 @@ V1ComputationGraph to_v1(ComputationGraph const &g) {
1112
}
1213

1314
ComputationGraph from_v1(V1ComputationGraph const &v1) {
14-
return ComputationGraph{
15-
from_v1(v1.raw_graph),
16-
};
15+
LabelledKwargDataflowGraph<LayerAttrs, TensorAttrs, TensorSlotName>
16+
raw_graph =
17+
LabelledKwargDataflowGraph<LayerAttrs, TensorAttrs, TensorSlotName>::
18+
create_copy_of<LabelledKwargDataflowGraph<LayerAttrs,
19+
TensorAttrs,
20+
TensorSlotName>>(
21+
from_v1(v1.raw_graph).first);
22+
return ComputationGraph{raw_graph};
1723
}
1824

1925
std::pair<V1ComputationGraph, bidict<nonnegative_int, layer_guid_t>>

lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "pcg/file_format/v1/v1_parallel_computation_graph.h"
22
#include "pcg/file_format/v1/graphs/v1_labelled_kwarg_dataflow_graph.h"
3+
#include "utils/graph/labelled_kwarg_dataflow_graph/labelled_kwarg_dataflow_graph.h"
34

45
namespace FlexFlow {
56

@@ -11,9 +12,17 @@ V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) {
1112
}
1213

1314
ParallelComputationGraph from_v1(V1ParallelComputationGraph const &v1) {
14-
return ParallelComputationGraph{
15-
from_v1(v1.raw_graph),
16-
};
15+
LabelledKwargDataflowGraph<ParallelLayerAttrs,
16+
ParallelTensorAttrs,
17+
TensorSlotName>
18+
raw_graph = LabelledKwargDataflowGraph<ParallelLayerAttrs,
19+
ParallelTensorAttrs,
20+
TensorSlotName>::
21+
create_copy_of<LabelledKwargDataflowGraph<ParallelLayerAttrs,
22+
ParallelTensorAttrs,
23+
TensorSlotName>>(
24+
from_v1(v1.raw_graph).first);
25+
return ParallelComputationGraph{raw_graph};
1726
}
1827

1928
} // namespace FlexFlow

0 commit comments

Comments
 (0)