PyTorch 模型可视化的例子

下面用 graphviz 绘制 PyTorch 模型的 autograd 计算图，代码如下：

一、visualize.py

from graphviz import Digraph

import torch

from torch.autograd import Variable

def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    The graph is walked backwards from ``var.grad_fn``.
    - Nodes exposing ``.variable`` (leaf tensors that require grad) are blue.
    - Tensors found in a node's ``saved_tensors`` are orange.
    - Every other node is labelled with its class name.

    Args:
        var: output Variable/Tensor whose backward graph is drawn. If it has
            no ``grad_fn`` (e.g. it does not require grad), the returned
            graph is empty.
        params: optional dict of ``(name, Variable)`` used to label the
            grad-requiring leaf nodes, e.g. ``dict(model.named_parameters())``.
            Leaves not found in the dict get an empty name.

    Returns:
        graphviz.Digraph: the graph; call ``.view()`` or ``.render()`` on it.

    Raises:
        TypeError: if ``params`` is non-empty and its values are not Variables.
    """
    param_map = {}
    if params:
        # dict views are not indexable in Python 3, so take the first value
        # through an iterator. Raise explicitly (not ``assert``) so the check
        # survives ``python -O``.
        first = next(iter(params.values()))
        if not isinstance(first, Variable):
            raise TypeError('params values must be Variables, got %s'
                            % type(first).__name__)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    def size_to_str(size):
        # e.g. torch.Size([1, 3]) -> "(1, 3)"
        return '(' + ', '.join(['%d' % v for v in size]) + ')'

    seen = set()
    root = getattr(var, 'grad_fn', None)
    # Explicit stack instead of recursion: deep graphs (e.g. long unrolled
    # RNNs) would otherwise hit Python's recursion limit. Each node is drawn
    # once; edges are emitted while processing the consumer node, as before.
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)

        if torch.is_tensor(node):
            dot.node(str(id(node)), size_to_str(node.size()), fillcolor='orange')
        elif hasattr(node, 'variable'):
            u = node.variable
            # .get: grad-requiring leaves absent from params (e.g. an input
            # created with requires_grad=True) previously raised KeyError.
            name = param_map.get(id(u), '')
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(node)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(node)), str(type(node).__name__))

        if hasattr(node, 'next_functions'):
            # next_functions holds (function, input_index) pairs; the
            # function is None for inputs that do not need grad.
            for child, _ in node.next_functions:
                if child is not None:
                    dot.edge(str(id(child)), str(id(node)))
                    stack.append(child)

        if hasattr(node, 'saved_tensors'):
            for t in node.saved_tensors:
                dot.edge(str(id(t)), str(id(node)))
                stack.append(t)

    return dot

二、使用步骤

# Example: draw the autograd graph of one GeneratorUNet forward pass.
import torch

# Explicit import instead of ``from models import *``, so it is clear where
# GeneratorUNet comes from and nothing else leaks into this namespace.
from models import GeneratorUNet

from visualize import make_dot

# Since PyTorch 0.4 plain tensors carry autograd history, so the deprecated
# ``Variable`` wrapper is unnecessary.
# NOTE(review): (1, 3, 256, 256) is presumably one 3-channel 256x256 image in
# NCHW layout; confirm it matches what GeneratorUNet expects.
x = torch.rand(1, 3, 256, 256)

model = GeneratorUNet()

y = model(x)

g = make_dot(y)

# Render the graph (graphviz's default output format is PDF) and open it
# with the system's default viewer.
g.view()

三、效果展示（原文效果图在转载时已丢失：运行上面的脚本后，`g.view()` 会把计算图渲染成文件，并用系统默认查看器打开。）

以上这篇 PyTorch 模型可视化的例子就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持。

以上是《PyTorch 模型可视化的例子》的全部内容，来源链接：utcz.com/z/350789.html

回到顶部