LeNet#

import torch
from torch.nn import functional as F
from torch import nn, fx
from torch.ao.quantization.observer import HistogramObserver
from torch_book.data.simple_vision import load_data_fashion_mnist
from torch_book.tools import train, try_gpu
from observer import histogram_observer

class LeNet(nn.Module):
    def __init__(self, activation=nn.Sigmoid):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.sigmoid1 = activation()
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.sigmoid2 = activation()
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(16 * 5 * 5, 120)
        self.sigmoid3 = activation()
        self.linear2 = nn.Linear(120, 84)
        self.sigmoid4 = activation()
        self.linear3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.sigmoid1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.sigmoid2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.sigmoid3(x)
        x = self.linear2(x)
        x = self.sigmoid4(x)
        x = self.linear3(x)
        return x
# batch_size = 256
# lr, num_epochs = 0.9, 320

# train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
# train(net, train_iter, test_iter, num_epochs, lr, try_gpu())
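
Before any transformation, it is worth confirming the activation shapes layer by layer. A minimal sanity check, assuming 28 × 28 Fashion-MNIST inputs (the padding=2 on conv1 is what makes the 16 * 5 * 5 flatten size work out):

X = torch.randn(1, 1, 28, 28)
net = LeNet()
for name, layer in net.named_children():
    X = layer(X)
    print(f"{name}: {tuple(X.shape)}")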

Module Transformation#

# m = LeNet()
# mod = fx.symbolic_trace(m)
# graph = fx.Graph()
# new_mod = fx.GraphModule(m, graph)
# for node in mod.graph.nodes:
#     if node.op == 'call_module' and "sigmoid" in node.target:
#         break
# mod.graph.lint()  # run some checks to make sure the Graph is well-formed
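
To see what symbolic tracing actually produced, the node table can be printed directly; fx.Graph.print_tabular requires the tabulate package to be installed:

m = LeNet()
mod = fx.symbolic_trace(m)
mod.graph.print_tabular()  # columns: opcode, name, target, args, kwargs
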
def sigmoid_linear(x):
    # Piecewise-linear stand-in for sigmoid: x * (x > 0) is exactly ReLU.
    return x * (x > 0)
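
A quick check that the replacement matches torch.relu elementwise:

xs = torch.linspace(-3, 3, 7)
assert torch.allclose(sigmoid_linear(xs), torch.relu(xs))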

m = LeNet()
mod = fx.symbolic_trace(m)
# Iterate over all nodes in the Graph
for node in mod.graph.nodes:
    # If the node matches our target
    if node.op == "call_module":
        if "sigmoid" in node.target:
            # Set the insertion point, add the new node, and replace all
            # uses of `node` with the new node
            with mod.graph.inserting_after(node):
                new_node = mod.graph.call_function(sigmoid_linear, node.args, node.kwargs)
                node.replace_all_uses_with(new_node)
            # Remove the old node from the graph
            mod.graph.erase_node(node)
# Don't forget to recompile!
mod.recompile()
new_mod = fx.GraphModule(m, mod.graph)
print(mod)
LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (sigmoid1): Sigmoid()
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (sigmoid2): Sigmoid()
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=400, out_features=120, bias=True)
  (sigmoid3): Sigmoid()
  (linear2): Linear(in_features=120, out_features=84, bias=True)
  (sigmoid4): Sigmoid()
  (linear3): Linear(in_features=84, out_features=10, bias=True)
)



def forward(self, x):
    conv1 = self.conv1(x);  x = None
    sigmoid_linear = __main___sigmoid_linear(conv1);  conv1 = None
    pool1 = self.pool1(sigmoid_linear);  sigmoid_linear = None
    conv2 = self.conv2(pool1);  pool1 = None
    sigmoid_linear_1 = __main___sigmoid_linear(conv2);  conv2 = None
    pool2 = self.pool2(sigmoid_linear_1);  sigmoid_linear_1 = None
    flatten = self.flatten(pool2);  pool2 = None
    linear1 = self.linear1(flatten);  flatten = None
    sigmoid_linear_2 = __main___sigmoid_linear(linear1);  linear1 = None
    linear2 = self.linear2(sigmoid_linear_2);  sigmoid_linear_2 = None
    sigmoid_linear_3 = __main___sigmoid_linear(linear2);  linear2 = None
    linear3 = self.linear3(sigmoid_linear_3);  sigmoid_linear_3 = None
    return linear3
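
A minimal equivalence sketch: since x * (x > 0) is ReLU, the transformed module should match a LeNet built with nn.ReLU that shares the same weights:

relu_net = LeNet(activation=nn.ReLU)
relu_net.load_state_dict(m.state_dict())  # same conv/linear weights; the activations are parameter-free
x = torch.randn(2, 1, 28, 28)
assert torch.allclose(new_mod(x), relu_net(x))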
    
def sigmoid_linear(x):
    return x * (x > 0)

env = {}
new_graph = fx.Graph()
decomposition_rules = {"sigmoid": sigmoid_linear}

m = LeNet()
mod = fx.symbolic_trace(m)
graph = mod.graph
# The tracer must append to the *new* graph, so that nodes created while
# evaluating a decomposition rule land in the graph being built.
tracer = fx.proxy.GraphAppendingTracer(new_graph)

for node in graph.nodes:
    if node.op == 'call_module' and "sigmoid" in node.target:
        # By wrapping the arguments in proxies, we can dispatch to the
        # appropriate decomposition rule and have it implicitly added to
        # the Graph via symbolic tracing.
        proxy_args = [fx.Proxy(env[x.name], tracer)
                      if isinstance(x, fx.Node) else x for x in node.args]
        output_proxy = decomposition_rules["sigmoid"](*proxy_args)

        # Operations on a `Proxy` always yield a new `Proxy`, and the
        # return value of a decomposition rule is no exception. We extract
        # the underlying `Node` from the `Proxy` so we can use it in later
        # iterations of this transform.
        new_node = output_proxy.node
        env[node.name] = new_node
    else:
        # Default case: there is no decomposition rule for this node, so
        # just copy it over into the new Graph.
        new_node = new_graph.node_copy(node, lambda x: env[x.name])
        env[node.name] = new_node
        
new_mod = fx.GraphModule(m, new_graph)
print(new_mod.code)
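
The retraced module is runnable as well; a quick shape check, assuming the 28 × 28 inputs used above:

x = torch.randn(2, 1, 28, 28)
print(new_mod(x).shape)  # torch.Size([2, 10])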