242
242
< div class ="pytorch-left-menu-search ">
243
243
244
244
< div class ="version ">
245
- < a href ='https://pytorch.org/docs/versions.html '> main (2.2.0a0+git5a96a42 ) ▼</ a >
245
+ < a href ='https://pytorch.org/docs/versions.html '> main (2.2.0a0+gite7f12b1 ) ▼</ a >
246
246
</ div >
247
247
248
248
@@ -529,6 +529,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
529
529
< span class ="s1 "> 'set_warn_always'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'is_warn_always_enabled'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'SymInt'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'SymFloat'</ span > < span class ="p "> ,</ span >
530
530
< span class ="s1 "> 'SymBool'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_not'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'unravel_index'</ span > < span class ="p "> ,</ span >
531
531
< span class ="s1 "> 'sym_int'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_float'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_max'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_min'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_ite'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'compile'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'vmap'</ span > < span class ="p "> ,</ span >
532
+ < span class ="s1 "> 'sym_sqrt'</ span > < span class ="p "> ,</ span >
532
533
< span class ="s1 "> 'export'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'autocast'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'cond'</ span > < span class ="p "> ,</ span >
533
534
< span class ="p "> ]</ span >
534
535
@@ -887,8 +888,15 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
887
888
< span class ="sd "> Args:</ span >
888
889
< span class ="sd "> a (SymBool or bool): Object to negate</ span >
889
890
< span class ="sd "> """</ span >
891
+ < span class ="kn "> import</ span > < span class ="nn "> sympy</ span >
892
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
893
+
894
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
895
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_not</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
890
896
< span class ="k "> if</ span > < span class ="nb "> hasattr</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="s1 "> '__sym_not__'</ span > < span class ="p "> ):</ span >
891
897
< span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_not__</ span > < span class ="p "> ()</ span >
898
+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> sympy</ span > < span class ="o "> .</ span > < span class ="n "> Basic</ span > < span class ="p "> ):</ span >
899
+ < span class ="k "> return</ span > < span class ="o "> ~</ span > < span class ="n "> a</ span > < span class ="c1 "> # type: ignore[operator]</ span >
892
900
< span class ="k "> return</ span > < span class ="ow "> not</ span > < span class ="n "> a</ span > </ div >
893
901
894
902
< div class ="viewcode-block " id ="sym_float "> < a class ="viewcode-back " href ="../generated/torch.sym_float.html#torch.sym_float "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_float</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
@@ -897,6 +905,10 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
897
905
< span class ="sd "> Args:</ span >
898
906
< span class ="sd "> a (SymInt, SymFloat, or object): Object to cast</ span >
899
907
< span class ="sd "> """</ span >
908
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
909
+
910
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
911
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_float</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
900
912
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> ):</ span >
901
913
< span class ="k "> return</ span > < span class ="n "> a</ span >
902
914
< span class ="k "> elif</ span > < span class ="nb "> hasattr</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="s1 "> '__sym_float__'</ span > < span class ="p "> ):</ span >
@@ -910,6 +922,10 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
910
922
< span class ="sd "> Args:</ span >
911
923
< span class ="sd "> a (SymInt, SymFloat, or object): Object to cast</ span >
912
924
< span class ="sd "> """</ span >
925
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
926
+
927
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
928
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_int</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
913
929
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> SymInt</ span > < span class ="p "> ):</ span >
914
930
< span class ="k "> return</ span > < span class ="n "> a</ span >
915
931
< span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> ):</ span >
@@ -918,6 +934,10 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
918
934
919
935
< div class ="viewcode-block " id ="sym_max "> < a class ="viewcode-back " href ="../generated/torch.sym_max.html#torch.sym_max "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_max</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ):</ span >
920
936
< span class ="w "> </ span > < span class ="sd "> """ SymInt-aware utility for max()."""</ span >
937
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
938
+
939
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ((</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )):</ span >
940
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_max</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ),</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
921
941
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
922
942
< span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_max__</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
923
943
< span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
@@ -929,13 +949,31 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
929
949
930
950
< div class ="viewcode-block " id ="sym_min "> < a class ="viewcode-back " href ="../generated/torch.sym_min.html#torch.sym_min "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_min</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ):</ span >
931
951
< span class ="w "> </ span > < span class ="sd "> """ SymInt-aware utility for max()."""</ span >
952
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
953
+
954
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ((</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )):</ span >
955
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_min</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ),</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
932
956
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
933
957
< span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_min__</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
934
958
< span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
935
959
< span class ="k "> return</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> __sym_min__</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
936
960
< span class ="k "> return</ span > < span class ="n "> builtins</ span > < span class ="o "> .</ span > < span class ="n "> min</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )</ span > < span class ="c1 "> # type: ignore[operator]</ span > </ div >
937
961
962
+ < span class ="c1 "> # Drop in replacement for math.sqrt</ span >
963
+ < span class ="k "> def</ span > < span class ="nf "> sym_sqrt</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
964
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
965
+
966
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
967
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_sqrt</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
968
+ < span class ="k "> if</ span > < span class ="nb "> hasattr</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="s2 "> "__sym_sqrt__"</ span > < span class ="p "> ):</ span >
969
+ < span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_sqrt__</ span > < span class ="p "> ()</ span >
970
+ < span class ="k "> return</ span > < span class ="n "> math</ span > < span class ="o "> .</ span > < span class ="n "> sqrt</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
971
+
938
972
< div class ="viewcode-block " id ="sym_ite "> < a class ="viewcode-back " href ="../generated/torch.sym_ite.html#torch.sym_ite "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_ite</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> ):</ span >
973
+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
974
+
975
+ < span class ="k "> if</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ((</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> )):</ span >
976
+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_ite</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> ),</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
939
977
< span class ="k "> assert</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymBool</ span > < span class ="p "> ,</ span > < span class ="n "> builtins</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ))</ span > < span class ="ow "> and</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> )</ span > < span class ="o "> ==</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
940
978
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> SymBool</ span > < span class ="p "> ):</ span >
941
979
< span class ="k "> return</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> __sym_ite__</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
0 commit comments