Ir2Torch and Torch2Ir

The Ir2Torch module converts an intermediate representation (IR) back into a PyTorch model. This is useful when you need to reconstruct a PyTorch model from its IR, for example an IR generated by the Torch2Ir module, which performs the opposite conversion from a PyTorch model to an IR.

Concepts

Intermediate Representation (IR)

The IR is a structured format that represents a PyTorch model in a way that is independent of the PyTorch framework. It captures the model’s architecture, including layers, connections, and attributes.
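As a rough illustration of "layers, connections, and attributes", the sketch below treats the IR as a plain Python dictionary and lists its nodes in data-flow order together with their attributes. It assumes the dictionary layout used in the step-by-step guide below: a "root" graph with "nodes" and "edges", plus one entry per implementation that holds the layer attributes. The helper summarize_ir is hypothetical and not part of elasticai.creator.

```python
# Hypothetical helper (not part of elasticai.creator); assumes the dict-based
# IR layout used in the step-by-step guide: a "root" graph plus one entry
# per implementation that stores the layer attributes.
from collections import deque

STRUCTURE_KEYS = {"name", "type", "nodes", "edges"}


def summarize_ir(ir: dict) -> list:
    """Return one line per root-graph node, in data-flow (topological) order."""
    root = ir["root"]
    successors = {name: [] for name in root["nodes"]}
    indegree = {name: 0 for name in root["nodes"]}
    # Edges are stored as {(src, dst): {"src": ..., "dst": ...}}.
    for edge in root["edges"].values():
        successors[edge["src"]].append(edge["dst"])
        indegree[edge["dst"]] += 1
    # Kahn's algorithm: emit a node once all of its inputs have been emitted.
    ready = deque(name for name, deg in indegree.items() if deg == 0)
    lines = []
    while ready:
        name = ready.popleft()
        node = root["nodes"][name]
        # In this example, input/output nodes have no implementation entry.
        impl = ir.get(node["implementation"], {})
        attrs = {k: v for k, v in impl.items() if k not in STRUCTURE_KEYS}
        lines.append(f"{name}: {node['type']} {attrs}")
        for succ in successors[name]:
            indegree[succ] -= 1
            if indegree[succ] == 0:
                ready.append(succ)
    return lines


ir = {
    "root": {
        "name": "root",
        "type": "module",
        "nodes": {
            "input_1": {"name": "input_1", "type": "input", "implementation": "input"},
            "_0": {"name": "_0", "type": "linear", "implementation": "0"},
            "output": {"name": "output", "type": "output", "implementation": "output"},
        },
        "edges": {
            ("input_1", "_0"): {"src": "input_1", "dst": "_0"},
            ("_0", "output"): {"src": "_0", "dst": "output"},
        },
    },
    "0": {"name": "0", "type": "linear", "in_features": 1, "out_features": 2,
          "bias": True, "edges": {}, "nodes": {}},
}

for line in summarize_ir(ir):
    print(line)
# input_1: input {}
# _0: linear {'in_features': 1, 'out_features': 2, 'bias': True}
# output: output {}
```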

Extending Ir2Torch

To extend Ir2Torch with support for new modules, you need to register custom handlers that define how to convert specific IR nodes back into PyTorch modules.
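This section does not spell out how the converter selects a handler. The toy registry below is a hypothetical sketch, not elasticai.creator's implementation. It assumes that handlers are looked up by the implementation's "type" string and shows the difference between registering a handler for a new type and overriding an existing one, as register_override does in the guide below. To keep the sketch runnable without PyTorch, its handlers return strings instead of nn.Module instances.

```python
from typing import Any, Callable, Dict

Handler = Callable[[Dict[str, Any]], Any]


class HandlerRegistry:
    """Toy registry mapping an IR type string to a handler (hypothetical)."""

    def __init__(self) -> None:
        self._handlers: Dict[str, Handler] = {}

    def register(self, type_name: str) -> Callable[[Handler], Handler]:
        # Adding a handler for a type that is already handled is an error.
        def decorator(fn: Handler) -> Handler:
            if type_name in self._handlers:
                raise ValueError(f"handler for {type_name!r} already registered")
            self._handlers[type_name] = fn
            return fn

        return decorator

    def register_override(self, type_name: str) -> Callable[[Handler], Handler]:
        # Overriding explicitly replaces any previously registered handler.
        def decorator(fn: Handler) -> Handler:
            self._handlers[type_name] = fn
            return fn

        return decorator

    def convert(self, impl: Dict[str, Any]) -> Any:
        # Dispatch on the "type" field of the implementation entry.
        handler = self._handlers.get(impl["type"])
        if handler is None:
            raise KeyError(f"no handler registered for type {impl['type']!r}")
        return handler(impl)


registry = HandlerRegistry()


@registry.register("linear")
def default_linear(impl: Dict[str, Any]) -> str:
    return f"Linear({impl['in_features']}, {impl['out_features']})"


@registry.register_override("linear")
def custom_linear(impl: Dict[str, Any]) -> str:
    return f"CustomLinear({impl['in_features']}, {impl['out_features']}, bias={impl['bias']})"


print(registry.convert({"type": "linear", "in_features": 1, "out_features": 2, "bias": True}))
# CustomLinear(1, 2, bias=True)
```

Separating "register" from "register_override" in this toy makes accidental replacement of a handler an error, while deliberate replacement is still possible.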

Step-by-Step Guide

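  1. Define a custom linear handler function. Create a function that takes an Implementation object and returns the corresponding PyTorch module. Then register it with register_override so that it replaces the existing handler for the "linear" type.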

    import torch.nn as nn
    from elasticai.creator.torch2ir import Implementation
    from elasticai.creator.ir2torch import get_default_converter as get_ir2torch

    # Start from the default converter and customize it.
    my_ir2torch = get_ir2torch()

    # Replace the handler for IR implementations of type "linear".
    @my_ir2torch.register_override("linear")
    def custom_linear_handler(impl: Implementation) -> nn.Module:
        # impl.data holds the attributes stored in the IR for this implementation.
        return nn.Linear(
            in_features=impl.data["in_features"],
            out_features=impl.data["out_features"],
            bias=impl.data["bias"],
        )
    
  2. Convert the IR. Use the Ir2Torch instance to convert the IR back into a PyTorch model.

    ir = {
        "root": {
            "name": "root",
            "type": "module",
            "nodes": {
                "input_1": {"name": "input_1", "type": "input", "implementation": "input"},
                # "linear" matches the type passed to register_override in step 1.
                "_0": {"name": "_0", "type": "linear", "implementation": "0"},
                "output": {"name": "output", "type": "output", "implementation": "output"},
            },
            "edges": {
                ("input_1", "_0"): {"src": "input_1", "dst": "_0"},
                ("_0", "output"): {"src": "_0", "dst": "output"},
            },
        },
        # The implementation "0" stores the attributes read by custom_linear_handler.
        "0": {
            "name": "0",
            "type": "linear",
            "in_features": 1,
            "out_features": 2,
            "bias": True,
            "edges": {},
            "nodes": {},
        },
    }

    # my_ir2torch is the converter configured in step 1.
    reconstructed_model = my_ir2torch.convert(ir)
    

By following these steps, you can extend Ir2Torch to support custom modules and ensure that your IR can be accurately converted back into a PyTorch model.
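The IR in step 2 is written out by hand. For a simple chain of layers, you can generate the same dictionary layout programmatically. The helper chain_ir below is a hypothetical sketch and not part of elasticai.creator. It assumes that every layer is a single implementation without nested nodes or edges, and it copies the naming scheme of the example ("_0" for the node, "0" for its implementation).

```python
from typing import Any, Dict, List


def chain_ir(layers: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Build a dict IR for input_1 -> _0 -> ... -> output (hypothetical helper).

    Each layer dict needs a "type" key plus its attributes, e.g.
    {"type": "linear", "in_features": 1, "out_features": 2, "bias": True}.
    """
    nodes = {"input_1": {"name": "input_1", "type": "input", "implementation": "input"}}
    implementations: Dict[str, Any] = {}
    order = ["input_1"]
    for i, layer in enumerate(layers):
        impl_name, node_name = str(i), f"_{i}"
        nodes[node_name] = {"name": node_name, "type": layer["type"], "implementation": impl_name}
        # Flat layers only: each implementation has empty nodes/edges.
        implementations[impl_name] = {"name": impl_name, **layer, "edges": {}, "nodes": {}}
        order.append(node_name)
    nodes["output"] = {"name": "output", "type": "output", "implementation": "output"}
    order.append("output")
    # Connect consecutive nodes; edges are keyed by (src, dst) tuples.
    edges = {(src, dst): {"src": src, "dst": dst} for src, dst in zip(order, order[1:])}
    root = {"name": "root", "type": "module", "nodes": nodes, "edges": edges}
    return {"root": root, **implementations}


ir = chain_ir([{"type": "linear", "in_features": 1, "out_features": 2, "bias": True}])
print(sorted(ir))
# ['0', 'root']
print(list(ir["root"]["edges"]))
# [('input_1', '_0'), ('_0', 'output')]
```

For a single linear layer, the result matches the hand-written IR from step 2. You could then pass it to my_ir2torch.convert(...) in the same way.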

Important

The reconstructed model does not contain the state (parameters and buffers) of the original PyTorch model. Often, the generated module hierarchy is compatible with the original model's state dict, which you can obtain by calling torch.nn.Module.state_dict(). In that case, you can load the dict after rebuilding the model with torch.nn.Module.load_state_dict(state). In some cases, loading the state dict directly will fail. This happens, for example, if the IR graph was altered or if the original module contained and used more than one reference to the same submodule. In these cases, you have to modify the state dict so that the paths to the parameters from the original model match the paths to the parameters in the rebuilt model.
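One way to modify the state dict is to rename its keys by prefix before loading. The helper below is a hypothetical sketch. It renames keys by prefix, or drops them when the new prefix is None, and works on any mapping, including the ordered dict returned by state_dict(). Which prefixes need mapping depends on your model and on the rebuilt hierarchy; the keys used here are made up.

```python
from collections import OrderedDict
from typing import Mapping, Optional, TypeVar

V = TypeVar("V")


def remap_state_dict(
    state: Mapping[str, V], prefix_map: Mapping[str, Optional[str]]
) -> "OrderedDict[str, V]":
    """Rename or drop state-dict keys by prefix (hypothetical helper).

    prefix_map maps an old key prefix to a new one; a value of None drops
    every key with that prefix. The longest matching prefix wins, and keys
    that match no prefix are copied unchanged.
    """
    prefixes = sorted(prefix_map, key=len, reverse=True)
    result: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state.items():
        match = next((p for p in prefixes if key.startswith(p)), None)
        if match is None:
            result[key] = value
            continue
        new_prefix = prefix_map[match]
        if new_prefix is not None:
            result[new_prefix + key[len(match):]] = value
    return result


# Made-up keys: drop the "1." entries and move "0." to "layer.".
state = {"0.weight": 1.0, "0.bias": 0.5, "1.weight": 1.0, "1.bias": 0.5, "10.weight": 2.0}
print(list(remap_state_dict(state, {"1.": None, "0.": "layer."})))
# ['layer.weight', 'layer.bias', '10.weight']
```

Include the trailing "." in each prefix so that "1." cannot match keys such as "10.weight". You can then pass the result to load_state_dict() on the rebuilt model.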

    import torch.nn as nn

    lin = nn.Linear(1, 1)
    seq = nn.Sequential(lin, lin)  # the same module instance appears twice

In the example above, the module seq contains two references to the same module. Its state dict therefore contains a separate set of entries for each of these references, with names like "0.weight" and "1.weight". When the IR is generated, the duplicate reference is ignored and the graph simply contains two calls to the lin module. Thus, you have to remove the entries starting with "1." from the state dict before loading it into the rebuilt model.
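You can also find such duplicate entries automatically. Calling state_dict(keep_vars=True) returns the parameter and buffer objects themselves, so every path that leads to the same submodule yields the very same objects. The hypothetical helper below keeps only the first key for each object. With the default keep_vars=False, each entry is a separately detached tensor, and the identity check finds nothing.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")


def drop_aliased_entries(state: Mapping[str, V]) -> "OrderedDict[str, V]":
    """Keep only the first key for each distinct value object (hypothetical helper).

    Intended for state_dict(keep_vars=True), where a submodule referenced
    twice contributes the same Parameter/buffer objects under two paths.
    """
    seen_ids = set()
    result: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state.items():
        # All values stay referenced by `state`, so id() is unique per object here.
        if id(value) in seen_ids:
            continue  # a later path to an object that was already kept
        seen_ids.add(id(value))
        result[key] = value
    return result


# Stand-ins for the weight/bias objects of the shared `lin` module.
weight, bias = object(), object()
state = {"0.weight": weight, "0.bias": bias, "1.weight": weight, "1.bias": bias}
print(list(drop_aliased_entries(state)))
# ['0.weight', '0.bias']
```

For the example above, drop_aliased_entries(seq.state_dict(keep_vars=True)) keeps only "0.weight" and "0.bias". Parameters that are deliberately tied across different modules are deduplicated as well, which may not match the rebuilt model, so check the resulting keys before loading.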

Extending Torch2Ir

Extending Torch2Ir works very similarly, but with a few differences:

  • There is currently no support for overriding handlers.

  • You register new handlers with the register or register_handlers method.
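On the Torch2Ir side, a handler works in the opposite direction: it reads attributes from a PyTorch module and produces the data that ends up in the IR. The sketch below shows only this extraction step for a linear layer and mirrors the fields that custom_linear_handler reads back in the Ir2Torch guide. This section does not describe how such a function is passed to register or register_handlers (decorator form, type name, return type), so the function is shown unregistered, and its name is hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict


def linear_to_ir_data(module: Any) -> Dict[str, Any]:
    """Extract the attributes that custom_linear_handler reads back (hypothetical).

    Works with torch.nn.Linear or any object exposing in_features,
    out_features and bias; nn.Linear sets bias to None when bias=False.
    """
    return {
        "in_features": module.in_features,
        "out_features": module.out_features,
        "bias": module.bias is not None,
    }


# Stand-in for nn.Linear(1, 2) so the sketch runs without PyTorch.
fake_linear = SimpleNamespace(in_features=1, out_features=2, bias=object())
print(linear_to_ir_data(fake_linear))
# {'in_features': 1, 'out_features': 2, 'bias': True}
```

Storing bias as a boolean rather than the bias tensor keeps the IR free of model state, which is consistent with the note above that the reconstructed model does not carry the original parameters.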