import re
from dataclasses import dataclass
from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple

from torchgen import local

from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and cpp type of the attribute.
    # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str
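
    # Illustrative sketch (hypothetical values, not used at runtime): saving
    # `other.scalar_type()` as in the example above would roughly correspond to
    #   SavedAttribute(nctype=NamedCType("other_scalar_type", <cpp type for ScalarType>),
    #                  expr="other.scalar_type()")
    # where the exact NamedCType spelling is an assumption for illustration only.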


# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: Set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives
    var_names: Tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives
    var_types: Tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula
    required_inputs_fw_grad: Optional[Tuple[str, ...]]

    # Inputs for which the primal is required for this formula
    required_inputs_primal: Optional[Tuple[str, ...]]

    # Flag to specify if this formula requires the original value of self
    # This is only used by inplace operations
    required_original_self_value: bool

    # If this formula is specified in derivatives.yaml or if we are re-using the
    # out of place formula for inplace
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: Set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions
    output_differentiability_conditions: Optional[List[str]]

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op
    # See Note [Codegen'd {view}_copy Operators]
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> Optional["DifferentiabilityInfo"]:
        if g.view_copy is None:
            return None
        f = g.view_copy

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"
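
        # Illustration (hypothetical entry): a view entry named "slice.Tensor" with op
        # "SliceBackward0" would map to name "slice_copy.Tensor" and op "SliceBackward0_copy".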

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )


def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
    return uses_ident(info, "grad")
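
# Note (illustrative): IDENT_REGEX matches whole identifiers only, so a formula containing
# "grad * 2" makes uses_single_grad() return True, while one that only mentions "grads[0]"
# or "grad_input" does not.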


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It represents processed Arguments that are differentiable and are only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It represents processed Returns that are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: Optional[Dict[str, DifferentiabilityInfo]]
    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration?  There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance).  Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors.  If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.

        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.

        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
        # to understand).

        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return "use_type"


def is_foreach_func(f: NativeFunction) -> bool:
    return f.func.name.name.base.startswith("_foreach_")


# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# such a function would be found in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}


# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function,
# references to generate derivatives.
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    return (
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
        and (
            not function_schema.name.name.inplace
            or str(f.func.name) in _foreach_with_inplace_ref
        )
        and all(
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
            for arg, ref_arg in zip(
                f.func.arguments.flat_non_out,
                function_schema.arguments.flat_non_out,
            )
        )
    )
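
# Illustration (rough, hedged example): `_foreach_add.List(Tensor[] self, Tensor[] other, ...)`
# would treat the functional `add.Tensor(Tensor self, Tensor other, ...)` as its reference,
# since the base name matches after stripping "_foreach_" and each foreach argument's type
# (or its list element type) matches the corresponding reference argument's type.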


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: Dict[
        FunctionSchema, Dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: Dict[
        FunctionSchema, Dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> Tuple[Optional[DifferentiabilityInfo], bool]:
    """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.

    The second return value indicates whether the info is generated in this function.
    """
    ref_diff_info: Optional[DifferentiabilityInfo] = None

    for function_schema, diff_info in functional_info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            break
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        for function_schema, diff_info in non_functional_info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, function_schema):
                continue
            ref_diff_info = diff_info[dispatch_key]
            if ref_diff_info is not None:
                break
    if ref_diff_info is None:
        return None, False

    # In-place foreach functions reuse the existing Derivative as is.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    map_refarg2foreacharg, map_name2arg = {}, {}
    for i, (arg, ref_arg) in enumerate(
        zip(
            foreach_function.func.arguments.flat_non_out,
            function_schema.arguments.flat_non_out,
        )
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for i, derivative in enumerate(ref_diff_info.derivatives):
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
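        # Illustration (hypothetical reference formula): "grad * other" becomes
        # "grads[i] * other" here; the saved-input rewriting below would then map
        # "other" to the matching foreach argument, e.g. "other[i]" for a TensorList.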
        saved_inputs, saved_outputs = [], []
        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError(
                        "unexpected saved output while generating foreach derivative info: "
                        f"{ref_output.nctype.name}"
                    )
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative = Derivative(
            formula=modified_formula,
            original_formula=derivative.formula,
            var_names=tuple(var_names),
            saved_inputs=tuple(saved_inputs),
            saved_outputs=tuple(saved_outputs),
            named_gradients=set(),
        )
        modified_derivative_formulas.append(modified_derivative)

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: List[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
        var_types: List[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: List[str] = []
        required_inputs_primal: List[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # note(crcrpar): there should exist a cooler way...
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )


def match_differentiability_info(
    native_functions: List[NativeFunction],
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to matching autograd function
    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e. mul() for mul_() or mul_out()
        # note(crcrpar): Check whether `f` is a foreach function because in-place foreach functions
        # use the backward defined for the existing native functions instead of their out-place counterparts.
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
 to be used automatically by its functional variant ("{str(f.func)}").
 This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
        if is_foreach_func(f):
            assert f.func not in differentiability_infos
            diff_info, is_generated = gen_foreach_derivativeinfo(
                f,
                functional_info_by_signature,
                non_functional_info_by_signature,
            )
            if diff_info is None:
                return None, False
            # TODO(crcrpar): Avoid hard coding "Default" ideally.
            diff_info_dict = {"Default": diff_info}
            if is_generated:
                differentiability_infos[f.func] = diff_info_dict
                functional_info_by_signature[f.func] = diff_info_dict
            return diff_info_dict, is_generated

        return None, False

    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified is not used:
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
                #      trigger a clone of the original input.
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #      populated by cloned version of the original input (either the clone done by the backward AD
                #      logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self"
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
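
                # Illustration: with postfix="_p", a hypothetical formula
                # "self_p.conj() * other_t" becomes "original_self_p.conj() * other_t".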

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: Optional[str] = None
                    between_parens: Optional[str] = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        #   Match: self_t.op1(other_p.op2(arg))
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result


def is_differentiable(
    name: str, type: Type, info: Optional[DifferentiabilityInfo]
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs