Ir2Torch and Torch2Ir#
The `Ir2Torch` module is responsible for converting an intermediate representation (IR) back into a PyTorch model. This is useful for scenarios where you need to reconstruct a PyTorch model from its IR, which might have been generated by the `Torch2Ir` module.
Concepts#
Intermediate Representation (IR)#
The IR is a structured format that represents a PyTorch model in a way that is independent of the PyTorch framework. It captures the model’s architecture, including layers, connections, and attributes.
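For orientation, the sketch below shows what a single implementation entry in the IR can look like. The keys simply mirror the example used later in this guide and are not a complete schema; the variable name is only for illustration.

```python
# A minimal sketch of one IR implementation entry (keys taken from the
# step-by-step example further down; not an exhaustive schema).
linear_impl = {
    "name": "0",              # name of the submodule in the original model
    "type": "custom_linear",  # type tag that selects the conversion handler
    "in_features": 1,         # attributes captured from the original module
    "out_features": 2,
    "bias": True,
    "nodes": {},              # a leaf module has no internal graph
    "edges": {},
}
```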
Extending Ir2Torch#
To extend `Ir2Torch` with support for new modules, you need to register custom handlers that define how to convert specific IR nodes back into PyTorch modules.
Step-by-Step Guide#
Step 1: Define a custom linear handler function. Create a function that takes an `Implementation` object and returns a corresponding PyTorch module.

```python
import torch.nn as nn

from elasticai.creator.torch2ir import Implementation
from elasticai.creator.ir2torch import get_default_converter as get_ir2torch

my_ir2torch = get_ir2torch()


@my_ir2torch.register_override("linear")
def custom_linear_handler(impl: Implementation) -> nn.Module:
    return nn.Linear(
        in_features=impl.data["in_features"],
        out_features=impl.data["out_features"],
        bias=impl.data["bias"],
    )
```
Step 2: Convert the IR. Use the `Ir2Torch` instance to convert the IR back into a PyTorch model.

```python
ir = {
    "root": {
        "name": "root",
        "type": "module",
        "nodes": {
            "input_1": {"name": "input_1", "type": "input", "implementation": "input"},
            "_0": {"name": "_0", "type": "custom_linear", "implementation": "0"},
            "output": {"name": "output", "type": "output", "implementation": "output"},
        },
        "edges": {
            ("input_1", "_0"): {"src": "input_1", "dst": "_0"},
            ("_0", "output"): {"src": "_0", "dst": "output"},
        },
    },
    "0": {
        "name": "0",
        "type": "custom_linear",
        "in_features": 1,
        "out_features": 2,
        "bias": True,
        "edges": {},
        "nodes": {},
    },
}

reconstructed_model = my_ir2torch.convert(ir)
```
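As a quick sanity check (a hypothetical usage sketch that assumes the converter returns a regular `torch.nn.Module`, as the step above suggests), you can inspect the rebuilt hierarchy and its state dict keys:

```python
# Inspect the rebuilt module hierarchy; the exact submodule names depend on
# the IR passed to convert().
for name, module in reconstructed_model.named_modules():
    print(name or "<root>", "->", type(module).__name__)

# These keys are what a transferred state dict has to match
# (see the note on state dicts below).
print(list(reconstructed_model.state_dict().keys()))
```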
By following these steps, you can extend `Ir2Torch` to support custom modules and ensure that your IR can be accurately converted back into a PyTorch model.
Important

The reconstructed model will not contain the state of the original PyTorch model. Often the generated module hierarchy will be compatible with the state dict that you can obtain by calling `torch.nn.Module.state_dict()`, and you can load that dict after rebuilding the model with `torch.nn.Module.load_state_dict(state_dict)`.
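For example, carrying the weights over could look like the following sketch (hypothetical variable names: `original_model` is the model the IR was generated from, `reconstructed_model` is the result of the conversion above):

```python
# Transfer the trained parameters from the original model to the rebuilt one.
# This works when both module hierarchies use the same parameter paths.
state = original_model.state_dict()
reconstructed_model.load_state_dict(state)
```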
In some cases, e.g., if the IR graph is altered or if the original module contained and used more than a single reference to the same submodule, loading the state dict directly will fail. In these cases you have to modify the state dict so that the parameter paths from the original model match the parameter paths in the rebuilt model.
```python
lin = nn.Linear(1, 1)
seq = nn.Sequential(lin, lin)
```
In the example above the module `seq` contains two references to the same module. When you read the state dict you will obtain a set of parameters for each of these duplicate submodules, with names like "0.weight" and "1.weight". During the generation of our IR the duplicate instance will be ignored and the graph will just contain two calls to the `lin` module. Thus, you will have to remove the parameters starting with "1." from the state dict before loading.
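A minimal sketch of that cleanup, assuming (hypothetically) that `rebuilt_seq` is the model reconstructed from the IR and only contains the submodule under the name 0:

```python
state = seq.state_dict()  # contains "0.weight", "0.bias", "1.weight", "1.bias"

# Drop the entries of the duplicate reference; the "1." prefix is specific to
# this example and has to be adapted to your own model.
filtered_state = {k: v for k, v in state.items() if not k.startswith("1.")}
rebuilt_seq.load_state_dict(filtered_state)
```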
Extending Torch2Ir#
Extending `Torch2Ir` works very similarly, but with a few differences:
- There is currently no support for overriding handlers.
- You register new handlers with the `register` or `register_handlers` method.
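As a rough, hedged sketch only: the handler signature, the registration call shape, and the `get_default_converter` import for `torch2ir` below are assumptions, not documented API; only the method names `register` and `register_handlers` come from this guide. Consult the `Torch2Ir` API for the exact interface.

```python
import torch.nn as nn

# Assumption: torch2ir exposes a default converter factory analogous to ir2torch.
from elasticai.creator.torch2ir import get_default_converter as get_torch2ir

my_torch2ir = get_torch2ir()


class ScaledLinear(nn.Linear):
    """Hypothetical custom module we want Torch2Ir to understand."""


def scaled_linear_handler(module: ScaledLinear) -> dict:
    # Assumption: a handler maps the torch module to a dict of IR attributes.
    return {
        "type": "scaled_linear",
        "in_features": module.in_features,
        "out_features": module.out_features,
        "bias": module.bias is not None,
    }


# Assumption about the call shape; the exact signature of register may differ.
my_torch2ir.register(ScaledLinear, scaled_linear_handler)
```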