Shortcuts

Source code for torch.fx.node

# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
import types
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair

if TYPE_CHECKING:
    from .graph import Graph

BaseArgumentTypes = Union[str, int, float, bool, torch.dtype, torch.Tensor, torch.device]
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

Target = Union[Callable[..., Any], str]

Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    'Node',
    BaseArgumentTypes
]]

_side_effectful_functions: Set[Callable] = {torch._assert}

# this is fixed on master, WAR for 1.5
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
    name = orig_method.__name__
    module = orig_method.__module__
    if module is not None:
        return module
    for guess in [torch, torch.nn.functional]:
        if getattr(guess, name, None) is orig_method:
            return guess.__name__
    raise RuntimeError(f'cannot find module for {orig_method}')

# Borrowed from CPython typing module
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
def _type_repr(obj):
    """Return the repr() of an object, special-casing types (internal helper).
    If obj is a type, we return a shorter version than the default
    type.__repr__, based on the module and qualified name, which is
    typically enough to uniquely identify a type.  For everything
    else, we fall back on repr(obj).
    """
    # HACK: In Python 3.6, type aliases from ``typing`` are instances of ``type``, but in
    # later Python versions, type aliases are not instances of ``type``!! We want
    # all type aliases to fall through to ``repr``, so if we have a type that is
    # in the module typing, don't go down this path.
    if isinstance(obj, type) and obj.__module__ != 'typing':
        if obj.__module__ == 'builtins':
            return obj.__qualname__
        return f'{obj.__module__}.{obj.__qualname__}'
    if obj is ...:
        return('...')
    if isinstance(obj, types.FunctionType):
        return obj.__name__
    return repr(obj)

def _get_qualified_name(func: Callable[..., Any]) -> str:
    # things like getattr just appear in builtins
    if getattr(builtins, func.__name__, None) is func:
        return func.__name__
    name = func.__name__
    module = _find_module_of_method(func)
    module = module.replace('torch._ops', 'torch.ops')  # WAR for bug in how torch.ops assigns module
    return f'{module}.{name}'

def _format_arg(arg) -> str:
    if isinstance(arg, list):
        items = ', '.join(_format_arg(a) for a in arg)
        return f'[{items}]'
    elif isinstance(arg, tuple):
        items = ', '.join(_format_arg(a) for a in arg)
        maybe_comma = ',' if len(arg) == 1 else ''
        return f'({items}{maybe_comma})'
    elif isinstance(arg, dict):
        items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items())
        return f'{{{items_str}}}'

    if isinstance(arg, Node):
        return '%' + str(arg)
    else:
        return str(arg)

class Node:
    """
    ``Node`` is the data structure that represents individual operations within
    a ``Graph``. For the most part, Nodes represent callsites to various entities,
    such as operators, methods, and Modules (some exceptions include nodes that
    specify function inputs and outputs). Each ``Node`` has a function specified
    by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:

    - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
      ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single
      argument denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders
      correspond to the function parameters (e.g. ``x``) in the graph printout.
    - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
      fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
      ``args`` and ``kwargs`` are don't-care.
    - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
      to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
      following the Python calling convention.
    - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
      as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
      ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*
      (the module itself is identified by ``target``).
    - ``call_method`` calls a method on a value. ``name`` is as previous. ``target`` is the string name of the method
      to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the method on,
      *including the self argument*.
    - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the
      "return" statement in the Graph printout.

    The constructor additionally accepts ``op == 'root'``.
    """
    def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
                 args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
                 type : Optional[Any] = None) -> None:
        self.graph = graph
        self.name = name  # unique name of value being created
        assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
        self.op = op  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
        if op in ['call_method', 'call_module']:
            assert isinstance(target, str)
        self.target = target  # for method/module/function, the name of the method/module/function/attr
        # being invoked, e.g add, layer1, or torch.add

        # All `Node`-valued inputs. Key is the Node, value is don't-care.
        # The public API for this is `all_input_nodes`, this private attribute
        # should not be accessed directly.
        self._input_nodes : Dict[Node, None] = {}
        # map_arg with an identity fn copies args/kwargs (lists and dicts become
        # immutable containers) and registers this node as a user of its inputs.
        self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x))  # type: ignore[arg-type]

        # All of the nodes that use the value produced by this Node
        # Note one user may correspond to several uses, e.g. the node for ``x + x``
        # would appear once here, but represents two uses.
        #
        # Is a dict to act as an "ordered set". Keys are significant, values are don't-care.
        self.users : Dict['Node', None] = {}
        # Type expression representing the output value of this node.
        # This should contain the same class of Type objects that would appear
        # as type annotations for function inputs/outputs.
        #
        # For placeholder nodes, this value will be used to type-annotate the
        # generated function parameters.
        # For the return node, this value will be used to type-annotate the
        # generated function return type. (Note this is a special case. ``return``
        # does not produce a value, it's more of a notation. Thus, this value
        # describes the type of args[0] in the ``return`` node.)
        self.type : Optional[Any] = type
        # Neighbours in the graph's doubly-linked node list; a fresh node links to itself.
        self._prev = self
        self._next = self
        self._erased = False

        # If set, use this fn to print this node
        self._repr_fn : Optional[Callable[[Node], str]] = None
        self._stack_trace : Optional[str] = None

        # Dictionary to store metadata passes need to do their
        # transformations. This metadata is preserved across node copies
        self.meta : Dict[str, Any] = {}

    @property
    def next(self) -> 'Node':
        """
        Returns the next ``Node`` in the linked list of Nodes.

        Returns:

            The next ``Node`` in the linked list of Nodes.
        """
        return self._next

    @property
    def prev(self) -> 'Node':
        """
        Returns the previous ``Node`` in the linked list of Nodes.

        Returns:

            The previous ``Node`` in the linked list of Nodes.
        """
        return self._prev

    def prepend(self, x: 'Node') -> None:
        """
        Insert x before this node in the list of nodes in the graph. Example::

            Before: p -> self
                    bx -> x -> ax
            After:  p -> x -> self
                    bx -> ax

        Args:
            x (Node): The node to put before this node. Must be a member of the same graph.
        """
        assert self.graph == x.graph, "Attempting to move a Node into a different Graph"
        x._remove_from_list()
        p = self._prev
        p._next, x._prev = x, p
        x._next, self._prev = self, x

    def append(self, x: 'Node') -> None:
        """
        Insert x after this node in the list of nodes in the graph.
        Equivalent to ``self.next.prepend(x)``

        Args:
            x (Node): The node to put after this node. Must be a member of the same graph.
        """
        self._next.prepend(x)

    def _remove_from_list(self):
        # Link this node's neighbours to each other. This node's own
        # _prev/_next are not reset here.
        p, n = self._prev, self._next
        p._next, n._prev = n, p

    @property
    def args(self) -> Tuple[Argument, ...]:
        """
        The tuple of arguments to this ``Node``. The interpretation of arguments
        depends on the node's opcode. See the :class:`Node` docstring for more
        information.

        Assignment to this property is allowed. All accounting of uses and users
        is updated automatically on assignment.
        """
        return self._args

    @args.setter
    def args(self, a : Tuple[Argument, ...]):
        """
        Set the tuple of arguments to this Node. The interpretation of arguments
        depends on the node's opcode. See the ``fx.Graph`` docstring for more
        information.
        """
        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
        # set `args` is via direct assignment, i.e. `node.args = new_args`
        self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs)  # type: ignore[arg-type]

    @property
    def kwargs(self) -> Dict[str, Argument]:
        """
        The dict of keyword arguments to this ``Node``. The interpretation of arguments
        depends on the node's opcode. See the :class:`Node` docstring for more
        information.

        Assignment to this property is allowed. All accounting of uses and users
        is updated automatically on assignment.
        """
        return self._kwargs

    @kwargs.setter
    def kwargs(self, k : Dict[str, Argument]):
        """
        Set the dict of kwargs to this Node. The interpretation of arguments
        depends on the node's opcode. See the ``fx.Graph`` docstring for more
        information.
        """
        # DO NOT CALL `__update_args_kwargs` directly. The correct way to
        # set `kwargs` is via direct assignment, i.e. `node.kwargs = new_kwargs`
        self.__update_args_kwargs(self._args, map_arg(k, lambda x: x))  # type: ignore[arg-type]

    @property
    def all_input_nodes(self) -> List['Node']:
        """
        Return all Nodes that are inputs to this Node. This is equivalent to
        iterating over ``args`` and ``kwargs`` and only collecting the values that
        are Nodes.

        Returns:

            List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
            ``Node``, in that order.
        """
        return list(self._input_nodes.keys())

    @property
    def stack_trace(self) -> Optional[str]:
        """
        Return the Python stack trace that was recorded during tracing, if any.
        This property is usually populated by `Tracer.create_proxy`. To record
        stack traces during tracing for debug purposes, set
        `record_stack_traces = True` on the `Tracer` instance.
        """
        return self._stack_trace

    @stack_trace.setter
    def stack_trace(self, trace : Optional[str]):
        self._stack_trace = trace

    def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']):
        """
        This API is internal. Do *not* call it directly.

        Replaces args/kwargs and rebuilds the use/user bookkeeping: this node is
        removed from the ``users`` of every previous input and added to the
        ``users`` of every Node found in the new args/kwargs.
        """
        self._args = new_args
        self._kwargs = new_kwargs

        for old_use in self._input_nodes.keys():
            old_use.users.pop(self)

        self._input_nodes = {}
        map_arg(self._args, lambda n: self._input_nodes.setdefault(n))
        map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n))

        for new_use in self._input_nodes.keys():
            new_use.users.setdefault(self)

    def __repr__(self) -> str:
        if self._repr_fn:
            return self._repr_fn(self)
        return self.name

    def _pretty_print_target(self, target):
        """
        Make target printouts more user-friendly.
        1) builtins will be printed as `builtins.xyz`
        2) operators will be printed as `operator.xyz`
        3) other callables will be printed with qualified name, e.g. torch.add
        """
        if isinstance(target, str):
            return target
        if hasattr(target, '__module__'):
            if not hasattr(target, '__name__'):
                # Just to be defensive, if we don't have `__name__`, get the
                # qualname. Not sure if this happens for any members of `operator`
                # or `builtins`. This fallback path is not as good, since e.g.
                # things in `operator` have `_operator` as their __module__.
                return _get_qualified_name(target)
            if target.__module__ == 'builtins':
                return f'builtins.{target.__name__}'
            elif target.__module__ == '_operator':
                return f'operator.{target.__name__}'
        return _get_qualified_name(target)

    def format_node(self,
                    placeholder_names: Optional[List[str]] = None,
                    maybe_return_typename: Optional[List[str]] = None) -> Optional[str]:
        """
        Return a descriptive string representation of ``self``.

        This method can be used with no arguments as a debugging
        utility.

        This function is also used internally in the ``__str__`` method
        of ``Graph``. Together, the strings in ``placeholder_names``
        and ``maybe_return_typename`` make up the signature of the
        autogenerated ``forward`` function in this Graph's surrounding
        GraphModule. ``placeholder_names`` and ``maybe_return_typename``
        should not be used otherwise.

        Args:
            placeholder_names: A list that will store formatted strings
                representing the placeholders in the generated
                ``forward`` function. Internal use only.
            maybe_return_typename: A single-element list that will store
                a formatted string representing the output of the
                generated ``forward`` function. Internal use only.

        Returns:
            str: If 1) we're using ``format_node`` as an internal helper
                in the ``__str__`` method of ``Graph``, and 2) ``self``
                is a placeholder Node, return ``None``. Otherwise,
                return a descriptive string representation of the
                current Node.
        """
        if self.op == 'placeholder':
            assert isinstance(self.target, str)
            # Signature entry for this parameter: ``name`` or ``name: Type``.
            arg_str = self.target
            if self.type:
                arg_str += f': {_type_repr(self.type)}'
            # NOTE(review): an empty list is falsy, so a caller that passes
            # ``[]`` gets the descriptive string below rather than having the
            # placeholder recorded. Confirm whether Graph printouts rely on
            # this before switching to an ``is not None`` check.
            if placeholder_names:
                placeholder_names.append(arg_str)
                return None
            maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
            default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
            return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
        elif self.op == 'get_attr':
            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
            return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \
                   f'{self.op}[target={self._pretty_print_target(self.target)}]'
        elif self.op == 'output':
            if self.type and maybe_return_typename:
                maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
            return f'return {self.args[0]}'
        else:
            maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
            return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \
                   f'{self.op}[target={self._pretty_print_target(self.target)}](' \
                   f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'

    def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']:
        """
        Replace all uses of ``self`` in the Graph with the Node ``replace_with``.

        Args:

            replace_with (Node): The node to replace all uses of ``self`` with.

        Returns:

            The list of Nodes on which this change was made.
        """
        # Snapshot the users: rewriting each user's args mutates self.users.
        to_process = list(self.users)

        def maybe_replace_node(n : Node) -> Node:
            return replace_with if n == self else n

        for use_node in to_process:
            new_args = map_arg(use_node.args, maybe_replace_node)
            new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
            assert isinstance(new_args, tuple)
            assert isinstance(new_kwargs, dict)
            use_node.__update_args_kwargs(new_args, new_kwargs)

        assert len(self.users) == 0
        return to_process

    def is_impure(self):
        """
        Returns whether this op is impure, i.e. if its op is a placeholder or
        output, or if a call_function or call_module which is impure.

        Returns:

            bool: If the op is impure or not.
        """
        if self.op in {"placeholder", "output"}:
            return True

        # Check if an impure function.
        if self.op == "call_function":
            return self.target in _side_effectful_functions

        # Check if an impure module: a submodule opts in via `_is_impure`.
        if self.op == "call_module":
            assert (
                self.graph.owning_module is not None
            ), "self.graph.owning_module not set for purity check"
            target_mod = self.graph.owning_module.get_submodule(self.target)
            assert (
                target_mod is not None
            ), f"Did not find expected submodule target {self.target}"
            return getattr(target_mod, "_is_impure", False)

        return False

    def normalized_arguments(
            self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
            kwarg_types : Optional[Dict[str, Any]] = None,
            normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
        """
        Returns normalized arguments to Python targets. This means that
        `args/kwargs` will be matched up to the module/functional's
        signature and return exclusively kwargs in positional order
        if `normalize_to_only_use_kwargs` is true.
        Also populates default values. Does not support positional-only
        parameters or varargs parameters.

        Supports module calls.

        May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

        Args:
            root (torch.nn.Module): Module upon which to resolve module targets.
            arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
            kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
            normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

        Returns:

            Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
        """
        # The kwargs-only flag is forwarded so the documented behavior actually
        # takes effect; it was previously accepted but ignored.
        if self.op == 'call_function':
            assert callable(self.target)
            return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types,  # type: ignore[arg-type]
                                      normalize_to_only_use_kwargs=normalize_to_only_use_kwargs)
        elif self.op == 'call_module':
            assert isinstance(self.target, str)
            return normalize_module(root, self.target, self.args, self.kwargs,  # type: ignore[arg-type]
                                    normalize_to_only_use_kwargs=normalize_to_only_use_kwargs)

        return None

    def replace_input_with(self, old_input: 'Node', new_input: 'Node'):
        """
        Loop through input nodes of ``self``, and if `old_input` is one of
        those, replace `old_input` node with new input node `new_input`.

        Args:

            old_input (Node): The old input node to be replaced.
            new_input (Node): The new input node to replace `old_input`.
        """
        def maybe_replace_node(n : Node) -> Node:
            return new_input if n == old_input else n

        new_args = map_arg(self.args, maybe_replace_node)
        new_kwargs = map_arg(self.kwargs, maybe_replace_node)
        assert isinstance(new_args, tuple)
        assert isinstance(new_kwargs, dict)
        self.__update_args_kwargs(new_args, new_kwargs)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
    """
    Apply ``fn`` to every ``Node`` found inside ``a``; all other leaf values
    are kept as-is. ``a`` may be a list, tuple, slice, or dict with string
    keys, nested arbitrarily (the traversal is done by ``map_aggregate``).
    """
    assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"

    def _on_leaf(leaf):
        if isinstance(leaf, Node):
            return fn(leaf)
        return leaf

    return map_aggregate(a, _on_leaf)


def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument:
    """
    Rebuild ``a`` with ``fn`` applied to every leaf value.

    Tuples are rebuilt as plain ``tuple``, lists as ``immutable_list``, dicts
    as ``immutable_dict``, and slices as new ``slice`` objects whose
    ``start``/``stop``/``step`` were mapped (so ``fn`` also sees ``None``
    bounds). Any other value is a leaf and is passed to ``fn``.
    """
    if isinstance(a, tuple):
        return tuple([map_aggregate(item, fn) for item in a])
    if isinstance(a, list):
        return immutable_list([map_aggregate(item, fn) for item in a])
    if isinstance(a, dict):
        return immutable_dict({key: map_aggregate(value, fn) for key, value in a.items()})
    if isinstance(a, slice):
        start = map_aggregate(a.start, fn)
        stop = map_aggregate(a.stop, fn)
        step = map_aggregate(a.step, fn)
        return slice(start, stop, step)
    return fn(a)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources