Skip to content

Instantly share code, notes, and snippets.

@argman
Created February 25, 2021 09:48
Show Gist options
  • Save argman/f378b5020e2ec7e2b621dcd1bf2f61f7 to your computer and use it in GitHub Desktop.
Save argman/f378b5020e2ec7e2b621dcd1bf2f61f7 to your computer and use it in GitHub Desktop.
onnx c++ code to extract a sub-graph
#include "onnx_util.h"
// Builds the name -> proto lookup tables used by the other extraction steps.
//
// Fills `wmap` with every initializer of the model graph and `vimap` with every
// value_info entry, both keyed by the proto's name. An entry whose name is
// already present is overwritten, so calling this more than once is harmless.
void Extractor::prepare() {
    const auto& modelGraph = mModel.graph();
    for (int i = 0; i < modelGraph.initializer_size(); ++i) {
        // Bind by reference: initializer tensors can be large, so the only copy
        // made is the one that ends up stored in the map.
        const auto& initializerI = modelGraph.initializer(i);
        spdlog::info("Initializer {} : name={}", i, initializerI.name());
        wmap[initializerI.name()] = initializerI;
    }
    for (int i = 0; i < modelGraph.value_info_size(); ++i) {
        const auto& valueInfoI = modelGraph.value_info(i);
        vimap[valueInfoI.name()] = valueInfoI;
    }
}
// Resolves the requested I/O tensor names to their ValueInfoProto descriptions.
//
// originalIO       - the graph's existing inputs (or outputs), keyed by name.
// ioNamesToExtract - names the extracted graph should expose; duplicates are
//                    collapsed.
//
// A name that is already present in `originalIO` keeps its original description.
// Any other name is looked up in `vimap` (value_info gathered by prepare()); if
// it is not found there either, an error is logged and the name is left out.
//
// Returns the resolved descriptions keyed by name.
OnnxValueInfoProtoMap Extractor::CollectNewIOCore(OnnxValueInfoProtoMap originalIO, StringVector ioNamesToExtract) {
    // A sorted, de-duplicated view of the request keeps the log order stable.
    const StringSet requestedNames(ioNamesToExtract.begin(), ioNamesToExtract.end());

    OnnxValueInfoProtoMap newIOValueInfo;
    for (const auto& name : requestedNames) {
        // find() rather than operator[] so a lookup never inserts an empty entry.
        const auto existing = originalIO.find(name);
        if (existing != originalIO.end()) {
            newIOValueInfo[name] = existing->second;
            continue;
        }
        // Not a graph input/output: it may be an intermediate activation.
        const auto activation = vimap.find(name);
        if (activation == vimap.end()) {
            spdlog::error("Cannot find io name {} in activations", name);
            continue;
        }
        newIOValueInfo[name] = activation->second;
    }
    return newIOValueInfo;
}
void Extractor::DFSSearchReachableNodes(std::string nodeOutputName, StringVector graphInputNames, std::vector<onnx::NodeProto>& reachableNodes) {
if(CT::indexOfStringItems(nodeOutputName, graphInputNames) != -1) {
return;
}
for(int i=0; i < mModel.graph().node_size(); i++) {
auto nodeI = mModel.graph().node(i);
bool nodeInReachableNodes = false;
for(int j=0; j < reachableNodes.size(); ++j) {
if(reachableNodes[j].name() == nodeI.name()) {
nodeInReachableNodes = true;
break;
}
}
if(nodeInReachableNodes) {
continue;
}
bool findInNodeOutput = false;
for(int j=0; j < nodeI.output_size(); j++) {
if(nodeI.output(j) == nodeOutputName) {
findInNodeOutput = true;
break;
}
}
if(!findInNodeOutput) {
continue;
}
reachableNodes.push_back(nodeI);
for(int j=0; j < nodeI.input_size(); j++) {
DFSSearchReachableNodes(nodeI.input(j), graphInputNames, reachableNodes);
}
}
}
// Collects every node needed to compute `outputNames` from `inputNames`.
//
// Runs DFSSearchReachableNodes() from each requested output, then returns the
// collected nodes in the order they appear in the source graph. Each source
// node appears at most once in the result.
std::vector<onnx::NodeProto> Extractor::CollectReachableNodes(StringVector inputNames, StringVector outputNames) {
    std::vector<onnx::NodeProto> reachableNodes;
    for (const auto& name : outputNames) {
        DFSSearchReachableNodes(name, inputNames, reachableNodes);
    }

    // Node names are optional in ONNX, so a node is identified by its name
    // together with its output list rather than by name alone.
    const auto isSameNode = [](const onnx::NodeProto& a, const onnx::NodeProto& b) {
        if (a.name() != b.name() || a.output_size() != b.output_size()) {
            return false;
        }
        for (int j = 0; j < a.output_size(); ++j) {
            if (a.output(j) != b.output(j)) {
                return false;
            }
        }
        return true;
    };

    // The result needs to be topologically sorted; walking the source graph in
    // its own order and keeping the reachable nodes preserves that ordering.
    const auto& modelGraph = mModel.graph();
    std::vector<onnx::NodeProto> nodes;
    nodes.reserve(reachableNodes.size());
    for (int i = 0; i < modelGraph.node_size(); ++i) {
        const auto& graphNode = modelGraph.node(i);
        for (const auto& reached : reachableNodes) {
            if (isSameNode(reached, graphNode)) {
                nodes.push_back(graphNode);
                break;  // one match is enough; never emit a node twice
            }
        }
    }
    return nodes;
}
// Returns the initializers (from `wmap`, filled by prepare()) whose name is used
// as an input or output of any node in `nodes`.
//
// Each initializer is returned once, however many nodes reference it, and the
// result follows the key order of `wmap`.
std::vector<onnx::TensorProto> Extractor::CollectReachableInitializer(std::vector<onnx::NodeProto>& nodes) {
    // A set gives one membership test per initializer and removes the repeated
    // names that would otherwise make a shared weight show up several times.
    StringSet usedTensorNames;
    for (const auto& node : nodes) {
        for (int i = 0; i < node.input_size(); i++) {
            usedTensorNames.insert(node.input(i));
        }
        for (int i = 0; i < node.output_size(); i++) {
            usedTensorNames.insert(node.output(i));
        }
    }

    std::vector<onnx::TensorProto> initializer;
    for (const auto& entry : wmap) {  // by reference: tensors can be large
        if (usedTensorNames.count(entry.first) != 0) {
            initializer.push_back(entry.second);
        }
    }
    return initializer;
}
// Returns the value_info entries (from `vimap`, filled by prepare()) whose name
// is used as an input or output of any node in `nodes`.
//
// Each entry is returned once, however many nodes reference the tensor, and the
// result follows the key order of `vimap`.
std::vector<onnx::ValueInfoProto> Extractor::CollectReachableValueInfo(std::vector<onnx::NodeProto>& nodes) {
    // De-duplicate the referenced names so a tensor consumed by several nodes
    // does not produce several copies of its value_info.
    StringSet usedTensorNames;
    for (const auto& node : nodes) {
        for (int i = 0; i < node.input_size(); i++) {
            usedTensorNames.insert(node.input(i));
        }
        for (int i = 0; i < node.output_size(); i++) {
            usedTensorNames.insert(node.output(i));
        }
    }

    std::vector<onnx::ValueInfoProto> valueInfo;
    for (const auto& entry : vimap) {
        if (usedTensorNames.count(entry.first) != 0) {
            valueInfo.push_back(entry.second);
        }
    }
    return valueInfo;
}
// Builds a new model containing only the part of the graph needed to compute
// `outputNames` from `inputNames`.
//
// inputNames  - tensors that become the inputs of the extracted graph.
// outputNames - tensors that become the outputs of the extracted graph.
//
// The extracted graph is named "ExtractedGraph". Its inputs and outputs follow
// the order of `inputNames` / `outputNames` (a repeated name is emitted once; a
// name that could not be resolved is omitted). ir_version and opset imports are
// taken from the source model; producer_name is the source producer name with
// "SubModel" appended, and a non-empty producer_version is carried over.
onnx::ModelProto Extractor::ExtractModel(std::vector<std::string> inputNames, std::vector<std::string> outputNames) {
    prepare();
    const auto& sourceGraph = mModel.graph();

    OnnxValueInfoProtoMap originalInput;
    for (int i = 0; i < sourceGraph.input_size(); i++) {
        originalInput[sourceGraph.input(i).name()] = sourceGraph.input(i);
    }
    OnnxValueInfoProtoMap originalOutput;
    for (int i = 0; i < sourceGraph.output_size(); i++) {
        originalOutput[sourceGraph.output(i).name()] = sourceGraph.output(i);
    }

    OnnxValueInfoProtoMap newInputs = CollectNewIOCore(originalInput, inputNames);
    spdlog::info("Collect {} new input tensors", newInputs.size());
    OnnxValueInfoProtoMap newOutputs = CollectNewIOCore(originalOutput, outputNames);
    spdlog::info("Collect {} new ouptut tensors", newOutputs.size());
    std::vector<onnx::NodeProto> newNodes = CollectReachableNodes(inputNames, outputNames);
    spdlog::info("Collect {} reachable nodes in graph", newNodes.size());
    std::vector<onnx::TensorProto> newInitializer = CollectReachableInitializer(newNodes);
    spdlog::info("Collect {} reachable initializers in graph", newInitializer.size());
    std::vector<onnx::ValueInfoProto> newValueInfo = CollectReachableValueInfo(newNodes);
    spdlog::info("Collect {} reachable value info in graph", newValueInfo.size());

    onnx::ModelProto newModel;
    newModel.set_ir_version(mModel.ir_version());

    // Fill the graph in place inside the model instead of building a separate
    // GraphProto and deep-copying it afterwards.
    onnx::GraphProto* newGraph = newModel.mutable_graph();
    newGraph->set_name("ExtractedGraph");
    for (const auto& node : newNodes) {
        newGraph->add_node()->CopyFrom(node);
    }

    // The maps are sorted by name, so iterating them would reorder the graph's
    // inputs/outputs. Walk the caller's lists instead to keep the requested order.
    StringSet emittedInputs;
    for (const auto& name : inputNames) {
        const auto found = newInputs.find(name);
        if (found == newInputs.end() || !emittedInputs.insert(name).second) {
            continue;  // unresolved name, or already emitted
        }
        newGraph->add_input()->CopyFrom(found->second);
    }
    StringSet emittedOutputs;
    for (const auto& name : outputNames) {
        const auto found = newOutputs.find(name);
        if (found == newOutputs.end() || !emittedOutputs.insert(name).second) {
            continue;
        }
        newGraph->add_output()->CopyFrom(found->second);
    }

    for (const auto& initializer : newInitializer) {
        newGraph->add_initializer()->CopyFrom(initializer);
    }
    for (const auto& valueInfo : newValueInfo) {
        newGraph->add_value_info()->CopyFrom(valueInfo);
    }

    for (int i = 0; i < mModel.opset_import_size(); i++) {
        newModel.add_opset_import()->CopyFrom(mModel.opset_import(i));
    }

    // The "<producer>SubModel" tag is a producer *name*; it used to be written
    // into the producer_version field, leaving producer_name empty.
    newModel.set_producer_name(mModel.producer_name() + "SubModel");
    if (!mModel.producer_version().empty()) {
        newModel.set_producer_version(mModel.producer_version());
    }
    return newModel;
}
#ifndef __onnx_util_H__
#define __onnx_util_H__
#include <algorithm>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include <onnx/onnx_pb.h>
#include "spdlog/spdlog.h"
#include "string_util.h"
// Ordered list of names, as supplied by callers.
using StringVector = std::vector<std::string>;
// Sorted, duplicate-free collection of names.
using StringSet = std::set<std::string>;
// ValueInfoProto objects keyed by name.
using OnnxValueInfoProtoMap = std::map<std::string, onnx::ValueInfoProto>;
// Extracts a sub-graph from an ONNX model: given the names of the tensors that
// should become the new inputs and outputs, produces a model that contains only
// the nodes, initializers and value_info needed between them.
// Modelled on the Python implementation at:
// https://github.com/onnx/onnx/blob/1cbd549193661c7e9c870c4e9595144d0c6971c5/onnx/utils.py
class Extractor {
public:
    // Takes ownership of a copy of `model`. The parameter is moved into place so
    // that passing a temporary (or std::move-ing a model) avoids a deep copy.
    Extractor(onnx::ModelProto model): mModel(std::move(model)) {}

    // Fills the name -> initializer and name -> value_info lookup tables.
    void prepare();

    // Resolves the requested input (or output) names to ValueInfoProto objects,
    // using `originalIO` first and the value_info table second.
    OnnxValueInfoProtoMap CollectNewIOCore(OnnxValueInfoProtoMap originalIO, StringVector ioNamesToExtract);

    // Appends to `reachableNodes` every node that contributes to `nodeOutputName`,
    // stopping at the tensors listed in `graphInputNames`.
    void DFSSearchReachableNodes(std::string nodeOutputName, StringVector graphInputNames, std::vector<onnx::NodeProto>& reachableNodes);

    // Returns the nodes needed to compute `outputNames` from `inputNames`, in the
    // order they appear in the source graph.
    std::vector<onnx::NodeProto> CollectReachableNodes(StringVector inputNames, StringVector outputNames);

    // Returns the initializers referenced by `nodes`.
    std::vector<onnx::TensorProto> CollectReachableInitializer(std::vector<onnx::NodeProto>& nodes);

    // Returns the value_info entries referenced by `nodes`.
    std::vector<onnx::ValueInfoProto> CollectReachableValueInfo(std::vector<onnx::NodeProto>& nodes);

    // Entry point: builds and returns the extracted model.
    onnx::ModelProto ExtractModel(StringVector inputNames, StringVector outputNames);

private:
    onnx::ModelProto mModel;            // source model the sub-graph is cut from
    // Not assigned anywhere in the visible implementation; initialised so that
    // it never holds an indeterminate value.
    onnx::GraphProto* graph = nullptr;
    std::map<std::string, onnx::TensorProto> wmap;  // initializers keyed by name
    OnnxValueInfoProtoMap vimap;                    // value_info keyed by name
};
#endif // !__onnx_util_H__
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment