TVM Relay IR计算图可视化
本文主要介绍如何将Relay IR的计算图(computational graph)/数据流图(dataflow graph)进行可视化输出。
参照TVM #3259的Pull Request,将下列代码复制到python/tvm/relay/visualize.py
中,注意代码做了一定的适应性修改。
from.expr_functorimportExprFunctorfrom.importexpras_expr
importnetworkxasnx
classVisualizeExpr(ExprFunctor):
def__init__(self):
super().__init__()
self.graph=nx.DiGraph()
self.counter=0
defviz(self,expr):
assertisinstance(expr,_expr.Function)
forparaminexpr.params:
self.visit(param)
returnself.visit(expr.body)
defvisit_constant(self,const):# overload this!
pass
defvisit_var(self,var):
name=var.name_hint
self.graph.add_node(name)
self.graph.nodes[name]['style']='filled'
self.graph.nodes[name]['fillcolor']='mistyrose'
returnvar.name_hint
defvisit_tuple_getitem(self,get_item):
tuple=self.visit(get_item.tuple_value)
# self.graph.nodes[tuple]
index=get_item.index
# import pdb; pdb.set_trace()
returntuple
defvisit_call(self,call):
parents=[]
forargincall.args:
parents.append(self.visit(arg))
# assert isinstance(call.op, _expr.Op)
name="{}({})".format(call.op.name,self.counter)
self.counter+=1
self.graph.add_node(name)
self.graph.nodes[name]['style']='filled'
self.graph.nodes[name]['fillcolor']='turquoise'
self.graph.nodes[name]['shape']='diamond'
edges=[]
fori,parentinenumerate(parents):
edges.append((parent,name,{'label':'arg{}'.format(i)}))
self.graph.add_edges_from(edges)
returnname
defvisualize(expr,mydir="relay_ir.png"):
viz_expr=VisualizeExpr()
viz_expr.viz(expr)
graph=viz_expr.graph
dotg=nx.nx_pydot.to_pydot(graph)
dotg.write_png(mydir)
注意传入的参数需要时一个ExprFunctor
实例,因此原文给出的测试实例调用relay.testing.renet.getworkload()
得到模型并输出对v0.6版本并不可行。
下面复用上次GCN的例子,来生成计算图。
fromtvm.relay.visualizeimportvisualizefunc=relay.Function(relay.analysis.free_vars(output),output)
visualize(func)
执行上述代码之前需要先安装pydot和graphviz
pip install pydotapt-get install graphviz
最后会生成对应的relay_ir.png
图片,如下。
实际上VisualizeExpr
就是一个计算图的遍历器(以visit
对结点进行访问),因此只要重载对应的结点函数,就可以实现对应的功能。
以上是 TVM Relay IR计算图可视化 的全部内容, 来源链接: utcz.com/a/128586.html