pytorch打印网络结构的实例

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):

""" Produces Graphviz representation of PyTorch autograd graph

Blue nodes are the Variables that require grad, orange are Tensors

saved for backward in torch.autograd.Function

Args:

var: output Variable

params: dict of (name, Variable) to add names to node that

require grad (TODO: make optional)

"""

if params is not None:

assert isinstance(params.values()[0], Variable)

param_map = {id(v): k for k, v in params.items()}

node_attr = dict(style='filled',

shape='box',

align='left',

fontsize='12',

ranksep='0.1',

height='0.2')

dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

seen = set()

def size_to_str(size):

return '('+(', ').join(['%d' % v for v in size])+')'

def add_nodes(var):

if var not in seen:

if torch.is_tensor(var):

dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')

elif hasattr(var, 'variable'):

u = var.variable

name = param_map[id(u)] if params is not None else ''

node_name = '%s\n %s' % (name, size_to_str(u.size()))

dot.node(str(id(var)), node_name, fillcolor='lightblue')

else:

dot.node(str(id(var)), str(type(var).__name__))

seen.add(var)

if hasattr(var, 'next_functions'):

for u in var.next_functions:

if u[0] is not None:

dot.edge(str(id(u[0])), str(id(var)))

add_nodes(u[0])

if hasattr(var, 'saved_tensors'):

for t in var.saved_tensors:

dot.edge(str(id(t)), str(id(var)))

add_nodes(t)

add_nodes(var.grad_fn)

return dot

(3)打印网络结构:

import torch

from torch.autograd import Variable

import torch.nn as nn

from graphviz import Digraph

class CNN(nn.module):

def __init__(self):

******

def forward(self,x):

******

return out

*****************************

def make_dot(): #复制上面的代码

*****************************

if __name__ == '__main__':

net = CNN()

x = Variable(torch.randn(1, 1, 1024,1024))

y = net(x)

g = make_dot(y)

g.view()

params = list(net.parameters())

k = 0

for i in params:

l = 1

print("该层的结构:" + str(list(i.size())))

for j in i.size():

l *= j

print("该层参数和:" + str(l))

k = k + l

print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

以上是 pytorch打印网络结构的实例 的全部内容, 来源链接: utcz.com/z/328395.html

回到顶部