LeNet
Contents
LeNet#
import torch
from torch.nn import functional as F
from torch import nn, fx
from torch.ao.quantization.observer import HistogramObserver
from torch_book.data.simple_vision import load_data_fashion_mnist
from torch_book.tools import train, try_gpu
from observer import histogram_observer
class LeNet(nn.Module):
    """LeNet-5 style CNN for 28x28 single-channel images, emitting 10 logits.

    The activation is configurable (``nn.Sigmoid`` by default) and is
    instantiated once per site so each activation is a distinct submodule
    (``sigmoid1`` .. ``sigmoid4``) — later FX passes match these by name.
    """

    def __init__(self, activation=nn.Sigmoid):
        super().__init__()
        # Feature extractor: two conv -> activation -> avg-pool stages.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.sigmoid1 = activation()
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.sigmoid2 = activation()
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Classifier head: flatten, then three fully-connected layers.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(16 * 5 * 5, 120)
        self.sigmoid3 = activation()
        self.linear2 = nn.Linear(120, 84)
        self.sigmoid4 = activation()
        self.linear3 = nn.Linear(84, 10)

    def forward(self, x):
        """Feed ``x`` through every stage in order; returns (batch, 10) logits."""
        stages = (
            self.conv1, self.sigmoid1, self.pool1,
            self.conv2, self.sigmoid2, self.pool2,
            self.flatten,
            self.linear1, self.sigmoid3,
            self.linear2, self.sigmoid4,
            self.linear3,
        )
        # The loop unrolls under fx.symbolic_trace, yielding the same
        # call_module node sequence as explicit statements would.
        for stage in stages:
            x = stage(x)
        return x
# batch_size = 256
# lr, num_epochs = 0.9, 320
# train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
# train(net, train_iter, test_iter, num_epochs, lr, try_gpu())
模块变换#
# m = LeNet()
# mod = fx.symbolic_trace(m)
# graph = fx.Graph()
# new_mod = fx.GraphModule(m, graph)
# for node in mod.graph.nodes:
# if node.op == 'call_module' and "sigmoid" in node.target:
# break
# mod.graph.lint() # 做一些检查,以确保 Graph 是格式良好的。
def sigmoid_linear(x):
    """Piecewise-linear stand-in for sigmoid: ``x * (x > 0)``, i.e. ReLU.

    Works on plain numbers, tensors, and fx Proxies (only uses ``>`` and ``*``).
    """
    positive_mask = x > 0
    return x * positive_mask
m = LeNet()
mod = fx.symbolic_trace(m)
# Walk every node of the traced graph and swap activation submodules
# ("sigmoid1" .. "sigmoid4") for the functional sigmoid_linear.
for node in mod.graph.nodes:
    if node.op == "call_module" and "sigmoid" in node.target:
        # Insert the functional replacement right after the matched node,
        # redirect every user to it, then drop the original node.
        with mod.graph.inserting_after(node):
            replacement = mod.graph.call_function(
                sigmoid_linear, node.args, node.kwargs)
        node.replace_all_uses_with(replacement)
        mod.graph.erase_node(node)
# Don't forget to recompile so the generated Python code matches the graph!
new_code = mod.recompile()
new_mod = fx.GraphModule(m, mod.graph)
print(mod)
LeNet(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(sigmoid1): Sigmoid()
(pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(sigmoid2): Sigmoid()
(pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear1): Linear(in_features=400, out_features=120, bias=True)
(sigmoid3): Sigmoid()
(linear2): Linear(in_features=120, out_features=84, bias=True)
(sigmoid4): Sigmoid()
(linear3): Linear(in_features=84, out_features=10, bias=True)
)
def forward(self, x):
conv1 = self.conv1(x); x = None
sigmoid_linear = __main___sigmoid_linear(conv1); conv1 = None
pool1 = self.pool1(sigmoid_linear); sigmoid_linear = None
conv2 = self.conv2(pool1); pool1 = None
sigmoid_linear_1 = __main___sigmoid_linear(conv2); conv2 = None
pool2 = self.pool2(sigmoid_linear_1); sigmoid_linear_1 = None
flatten = self.flatten(pool2); pool2 = None
linear1 = self.linear1(flatten); flatten = None
sigmoid_linear_2 = __main___sigmoid_linear(linear1); linear1 = None
linear2 = self.linear2(sigmoid_linear_2); sigmoid_linear_2 = None
sigmoid_linear_3 = __main___sigmoid_linear(linear2); linear2 = None
linear3 = self.linear3(sigmoid_linear_3); sigmoid_linear_3 = None
return linear3
def sigmoid_linear(x):
    """Hard (linear) sigmoid replacement: keeps positive values, zeroes the rest."""
    keep = x > 0
    return keep * x
# Decompose every sigmoid call_module node into sigmoid_linear by rebuilding
# the traced graph node-by-node into ``new_graph``.
env = {}
new_graph = fx.Graph()
decomposition_rules = {"sigmoid": sigmoid_linear}
m = LeNet()
mod = fx.symbolic_trace(m)
graph = mod.graph
# BUGFIX: the tracer must append the decomposed ops to ``new_graph`` (the
# graph being built), not to the source ``graph`` — otherwise Proxy math
# mutates the old graph while node_copy targets the new one, leaving
# cross-graph node references.
tracer = fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
    if node.op == 'call_module' and "sigmoid" in node.target:
        print(node)
        # Wrap the already-copied argument nodes in Proxies so that calling
        # the decomposition rule is symbolically traced into ``new_graph``.
        proxy_args = [fx.Proxy(env[x.name], tracer)
                      if isinstance(x, fx.Node) else x for x in node.args]
        output_proxy = decomposition_rules["sigmoid"](*proxy_args)
        # Operations on a Proxy always yield a new Proxy; extract the
        # underlying Node so later iterations can reference it.
        # BUGFIX: key the env by node.name — that is what the
        # ``env[x.name]`` lookups below use (node.target only matched by
        # coincidence for these call_module nodes).
        env[node.name] = output_proxy.node
    else:
        # Default case: no decomposition rule, so copy the node verbatim
        # into the new graph, remapping its inputs through ``env``.
        new_node = new_graph.node_copy(node, lambda x: env[x.name])
        env[node.name] = new_node
new_mod = fx.GraphModule(m, new_graph)
print(new_mod.code)