Dragon - C++ API
A Computation Graph Virtual Machine Based Deep Learning Framework
op_kernel.h
Go to the documentation of this file.
1 
13 #ifndef DRAGON_UTILS_OP_KERNEL_H_
14 #define DRAGON_UTILS_OP_KERNEL_H_
15 
16 #include "core/context.h"
17 
18 namespace dragon {
19 
20 class Tensor;
21 
22 namespace kernel {
23 
26 template <typename T, class Context>
27 void Dropout(
28  const int count,
29  const float prob,
30  const float scale,
31  const T* x,
32  uint32_t* mask32,
33  uint8_t* mask8,
34  T* y,
35  Context* ctx);
36 
37 template <typename Tx, typename Tm, class Context>
38 void ApplyMask(
39  const int count,
40  const float scale,
41  const Tx* x,
42  const Tm* mask,
43  Tx* y,
44  Context* ctx);
45 
48 template <typename T, class Context>
49 void DropPath(
50  const int rows,
51  const int cols,
52  const float scale,
53  const T* x,
54  const float* mask,
55  T* y,
56  Context* ctx);
57 
60 template <typename T, class Context>
61 void Elu(
62  const int count,
63  const float alpha,
64  const T* x,
65  T* y,
66  Context* ctx);
67 
68 template <typename T, class Context>
69 void EluGrad(
70  const int count,
71  const float alpha,
72  const T* dy,
73  const T* y,
74  T* dx,
75  Context* ctx);
76 
79 template <typename T, class Context>
80 void PRelu(
81  const int count,
82  const int channels,
83  const int dim,
84  const bool channel_shared,
85  const string& data_format,
86  const T* x,
87  const T* w,
88  T* y,
89  Context* ctx);
90 
91 template <typename T, class Context>
92 void PReluGrad(
93  const int count,
94  const int channels,
95  const int dim,
96  const bool channel_shared,
97  const string& data_format,
98  const T* dy,
99  const T* x,
100  const T* w,
101  T* dx,
102  Context* ctx);
103 
104 template <typename T, class Context>
105 void PReluWGrad(
106  const int rows,
107  const int row_offset,
108  const int channels,
109  const int dim,
110  const bool channel_shared,
111  const string& data_format,
112  const T* dy,
113  const T* x,
114  const T* multiplier,
115  T* bcast_dw,
116  T* dw,
117  Context* ctx);
118 
121 template <typename T, class Context>
122 void Relu(
123  const int count,
124  const float slope,
125  const T* x,
126  T* y,
127  Context* ctx);
128 
129 template <typename T, class Context>
130 void ReluGrad(
131  const int count,
132  const float slope,
133  const T* dy,
134  const T* y,
135  T* dx,
136  Context* ctx);
137 
140 template <typename T, class Context>
141 void SElu(
142  const int count,
143  const T* x,
144  T* y,
145  Context* ctx);
146 
147 template <typename T, class Context>
148 void SEluGrad(
149  const int count,
150  const T* dy,
151  const T* y,
152  T* dx,
153  Context* ctx);
154 
157 template <typename T, class Context>
158 void Sigmoid(
159  const int count,
160  const T* x,
161  T* y,
162  Context* ctx);
163 
164 template <typename T, class Context>
165 void SigmoidGrad(
166  const int count,
167  const T* dy,
168  const T* y,
169  T* dx,
170  Context* ctx);
171 
174 template <typename T, class Context>
175 void Softmax(
176  const int outer_dim,
177  const int axis_dim,
178  const int inner_dim,
179  const T* multiplier,
180  const T* x,
181  T* scale,
182  T* y,
183  Context* ctx);
184 
185 template <typename T, class Context>
186 void SoftmaxGrad(
187  const int outer_dim,
188  const int axis_dim,
189  const int inner_dim,
190  const T* multiplier,
191  const T* dy,
192  const T* y,
193  T* scale,
194  T* dx,
195  Context* ctx);
196 
199 template <typename T, class Context>
200 void Tanh(
201  const int count,
202  const T* x,
203  T* y,
204  Context* ctx);
205 
206 template <typename T, class Context>
207 void TanhGrad(
208  const int count,
209  const T* dy,
210  const T* y,
211  T* dx,
212  Context* ctx);
213 
216 template <typename T, class Context>
217 void Affine(
218  const int outer_dim,
219  const int axis_dim,
220  const int inner_dim,
221  const T* x,
222  const T* alpha,
223  const T* beta,
224  T* y,
225  Context* ctx);
226 
227 template <typename T, class Context>
228 void AffineGrad(
229  const int outer_dim,
230  const int axis_dim,
231  const int inner_dim,
232  const T* dy,
233  const T* alpha,
234  T* dx,
235  Context* ctx);
236 
239 template <typename T, class Context>
240 void Clip(
241  const int count,
242  const float low,
243  const float high,
244  const T* x,
245  T* y,
246  Context* ctx);
247 
248 template <typename T, class Context>
249 void ClipGrad(
250  const int count,
251  const float low,
252  const float high,
253  const T* x,
254  const T* dy,
255  T* dx,
256  Context* ctx);
257 
260 template <typename T, class Context>
261 void Maximum(
262  const int count,
263  const T* a,
264  const T* b,
265  T* y,
266  Context* ctx);
267 
268 template <typename T, class Context>
269 void BroadcastMaximum(
270  const int count,
271  const T* a,
272  const T b,
273  T* y,
274  Context* ctx);
275 
276 template <typename T, class Context>
277 void MaximumGrad(
278  const int count,
279  const T* a,
280  const T* b,
281  const T* dy,
282  T* da,
283  T* db,
284  Context* ctx);
285 
286 template <typename T, class Context>
287 void BroadcastMaximumGrad(
288  const int count,
289  const T* a,
290  const T b,
291  const T* dy,
292  T* da,
293  T* db,
294  Context* ctx);
295 
298 template <typename T, class Context>
299 void Minimum(
300  const int count,
301  const T* a,
302  const T* b,
303  T* y,
304  Context* ctx);
305 
306 template <typename T, class Context>
307 void BroadcastMinimum(
308  const int count,
309  const T* a,
310  const T b,
311  T* y,
312  Context* ctx);
313 
314 template <typename T, class Context>
315 void MinimumGrad(
316  const int count,
317  const T* a,
318  const T* b,
319  const T* dy,
320  T* da,
321  T* db,
322  Context* ctx);
323 
324 template <typename T, class Context>
325 void BroadcastMinimumGrad(
326  const int count,
327  const T* a,
328  const T b,
329  const T* dy,
330  T* da,
331  T* db,
332  Context* ctx);
333 
336 template <typename Tx, typename Ty, class Context>
337 void Moments(
338  const int ndims,
339  const int* dims,
340  const int naxes,
341  const int* axes,
342  const Tx* x,
343  Ty* mean,
344  Ty* var,
345  Context* ctx);
346 
349 template <typename T, class Context>
350 void Arange(
351  const int count,
352  const int start,
353  const int step,
354  T* y,
355  Context* ctx);
356 
359 template <typename T, class Context>
360 void ArgMax(
361  const int outer_dim,
362  const int inner_dim,
363  const int axis_dim,
364  const int top_k,
365  const T* x,
366  int64_t* indices,
367  T* values,
368  Context* ctx);
369 
370 template <typename T, class Context>
371 void ArgMin(
372  const int outer_dim,
373  const int inner_dim,
374  const int axis_dim,
375  const int top_k,
376  const T* x,
377  int64_t* indices,
378  T* values,
379  Context* ctx);
380 
383 template <typename T, class Context>
384 void ChannelShuffle(
385  const int outer_dim,
386  const int inner_dim,
387  const int axis_dim,
388  const int group,
389  const T* x,
390  T* y,
391  Context* ctx);
392 
395 template <typename T, class Context>
396 void Concat(
397  const int outer_dim,
398  const int inner_dim,
399  const int axis_dim,
400  const int cat_dim,
401  const int cat_ofs,
402  const T* x,
403  T* y,
404  Context* ctx);
405 
408 template <typename T, class Context>
409 void Crop(
410  const int count,
411  const int ndims,
412  const int* x_strides,
413  const int* y_dims,
414  const int* starts,
415  const T* x,
416  T* y,
417  Context* ctx);
418 
419 template <typename T, class Context>
420 void CropGrad(
421  const int count,
422  const int ndims,
423  const int* x_strides,
424  const int* y_dims,
425  const int* starts,
426  const T* dy,
427  T* dx,
428  Context* ctx);
429 
432 template <typename T, class Context>
433 void IndexSelect(
434  const int outer_dim,
435  const int inner_dim,
436  const int axis_dim,
437  const int num_indices,
438  const int64_t* indices,
439  const T* x,
440  T* y,
441  Context* ctx);
442 
443 template <typename T, class Context>
444 void IndexSelectGrad(
445  const int outer_dim,
446  const int inner_dim,
447  const int axis_dim,
448  const int num_indices,
449  const int64_t* indices,
450  const T* dy,
451  T* dx,
452  Context* ctx);
453 
456 template <typename T, class Context>
457 void MaskedSelect(
458  const int count,
459  const uint8_t* mask,
460  const T* x,
461  Tensor* indices,
462  Tensor* scratch,
463  Tensor* y,
464  Context* ctx);
465 
466 template <typename T, class Context>
467 void MaskedSelectGrad(
468  const int count,
469  const int num_indices,
470  const int64_t* indices,
471  const T* dy,
472  T* dx,
473  Context* ctx);
474 
477 template <class Context>
478 void UnravelIndex(
479  const int count,
480  const int ndims,
481  const int* dims,
482  const int64_t* x,
483  int64_t* y,
484  Context* ctx);
485 
488 template <typename T, class Context>
489 void ConstPad(
490  const int count,
491  const int ndims,
492  const int* x_dims,
493  const int* x_strides,
494  const int* y_dims,
495  const int* l_pads,
496  const float value,
497  const T* x,
498  T* y,
499  Context* ctx);
500 
501 template <typename T, class Context>
502 void ReflectPad(
503  const int count,
504  const int ndims,
505  const int* x_dims,
506  const int* x_strides,
507  const int* y_dims,
508  const int* l_pads,
509  const T* x,
510  T* y,
511  Context* ctx);
512 
513 template <typename T, class Context>
514 void EdgePad(
515  const int count,
516  const int ndims,
517  const int* x_dims,
518  const int* x_strides,
519  const int* y_dims,
520  const int* l_pads,
521  const T* x,
522  T* y,
523  Context* ctx);
524 
527 template <typename T, class Context>
528 void OneHot(
529  const int count,
530  const int depth,
531  const int on_value,
532  const T* x,
533  T* y,
534  Context* ctx);
535 
538 template <typename T, class Context>
539 void ReduceSum(
540  const int ndims,
541  const int* dims,
542  const int naxes,
543  const int* axes,
544  const float scale,
545  const T* x,
546  T* y,
547  Context* ctx);
548 
549 template <typename T, class Context>
550 void ReduceSumGrad(
551  const int count,
552  const int ndims,
553  const int* x_dims,
554  const int* y_dims,
555  const int* y_strides,
556  const float scale,
557  const T* dy,
558  T* dx,
559  Context* ctx);
560 
563 template <typename T, class Context>
564 void Repeat(
565  const int outer_dim,
566  const int inner_dim,
567  const int axis_dim,
568  const int repeats,
569  const T* x,
570  T* y,
571  Context* ctx);
572 
573 template <typename T, class Context>
574 void RepeatGrad(
575  const int outer_dim,
576  const int inner_dim,
577  const int axis_dim,
578  const int repeats,
579  const T* dy,
580  T* dx,
581  Context* ctx);
582 
585 template <typename T, class Context>
586 void Slice(
587  const int outer_dim,
588  const int inner_dim,
589  const int axis_dim,
590  const int slice_dim,
591  const int slice_ofs,
592  const T* x,
593  T* y,
594  Context* ctx);
595 
596 template <typename T, class Context>
597 void SliceGrad(
598  const int outer_dim,
599  const int inner_dim,
600  const int axis_dim,
601  const int slice_dim,
602  const int slice_ofs,
603  const T* dy,
604  T* x,
605  Context* ctx);
606 
609 template <typename T, class Context>
610 void Tile(
611  const int count,
612  const int ndims,
613  const int* x_dims,
614  const int* x_strides,
615  const int* y_dims,
616  const T* x,
617  T* y,
618  Context* ctx);
619 
620 template <typename T, class Context>
621 void TileGrad(
622  const int rows,
623  const int cols,
624  const int multiple,
625  const T* dy,
626  T* dx,
627  Context* ctx);
628 
631 template <typename T, class Context>
632 void Transpose(
633  const int count,
634  const int ndims,
635  const int* x_strides,
636  const int* y_dims,
637  const T* x,
638  T* y,
639  Context* ctx);
640 
641 template <typename T, class Context>
642 void TransposeGrad(
643  const int count,
644  const int ndims,
645  const int* x_strides,
646  const int* y_dims,
647  const T* dy,
648  T* dx,
649  Context* ctx);
650 
653 template <typename T, class Context>
654 void Where(
655  const int count,
656  const uint8_t* mask,
657  const T* a,
658  const T* b,
659  T* y,
660  Context* ctx);
661 
662 template <typename T, class Context>
663 void WhereGrad(
664  const int count,
665  const uint8_t* mask,
666  const T* dy,
667  T* da,
668  T* db,
669  Context* ctx);
670 
673 template <typename T, class Context>
674 void Assign(
675  const int count,
676  const int ndims,
677  const int* x_dims,
678  const int* y_strides,
679  const int* starts,
680  const T* x,
681  T* y,
682  Context* ctx);
683 
686 template <typename T, class Context>
687 void NotZero(
688  const int count,
689  const T* x,
690  bool* y,
691  Context* ctx);
692 
693 template <typename T, class Context>
694 void Equal(
695  const int count,
696  const T* a,
697  const T* b,
698  bool* y,
699  Context* ctx);
700 
701 template <typename T, class Context>
702 void NotEqual(
703  const int count,
704  const T* a,
705  const T* b,
706  bool* y,
707  Context* ctx);
708 
709 template <typename T, class Context>
710 void Less(
711  const int count,
712  const T* a,
713  const T* b,
714  bool* y,
715  Context* ctx);
716 
717 template <typename T, class Context>
718 void LessEqual(
719  const int count,
720  const T* a,
721  const T* b,
722  bool* y,
723  Context* ctx);
724 
725 template <typename T, class Context>
726 void Greater(
727  const int count,
728  const T* a,
729  const T* b,
730  bool* y,
731  Context* ctx);
732 
733 template <typename T, class Context>
734 void GreaterEqual(
735  const int count,
736  const T* a,
737  const T* b,
738  bool* y,
739  Context* ctx);
740 
743 template <typename T, class Context>
744 void AbsGrad(
745  const int count,
746  const T* dy,
747  T* dx,
748  Context* ctx);
749 
752 template <typename Tx, typename Ty, class Context>
753 void NLLLoss(
754  const int outer_dim,
755  const int axis_dim,
756  const int inner_dim,
757  const int nignores,
758  const int* ignore,
759  const Tx* log_prob,
760  const Ty* target,
761  Tx* loss,
762  int* flag,
763  Context* ctx);
764 
765 template <typename Tx, typename Ty, class Context>
766 void NLLLossGrad(
767  const int outer_dim,
768  const int axis_dim,
769  const int inner_dim,
770  const int nignores,
771  const int* ignore,
772  const Tx* log_prob,
773  const Ty* target,
774  Tx* dx,
775  int* flag,
776  Context* ctx);
777 
780 template <typename T, class Context>
781 void SigmoidCrossEntropy(
782  const int count,
783  const T* logit,
784  const T* target,
785  T* loss,
786  int* flag,
787  Context* ctx);
788 
789 template <typename T, class Context>
790 void SigmoidCrossEntropyGrad(
791  const int count,
792  const T* logit,
793  const T* target,
794  T* dlogit,
795  int* flag,
796  Context* ctx);
797 
800 template <typename Tx, typename Ty, class Context>
801 void SigmoidFocalLoss(
802  const int outer_dim,
803  const int axis_dim,
804  const int inner_dim,
805  const float pos_alpha,
806  const float neg_alpha,
807  const float gamma,
808  const int neg_id,
809  const Tx* logit,
810  const Ty* target,
811  Tx* loss,
812  int* flag,
813  Context* ctx);
814 
815 template <typename Tx, typename Ty, class Context>
816 void SigmoidFocalLossGrad(
817  const int outer_dim,
818  const int axis_dim,
819  const int inner_dim,
820  const float pos_alpha,
821  const float neg_alpha,
822  const float gamma,
823  const int neg_id,
824  const Tx* logit,
825  const Ty* target,
826  Tx* dlogit,
827  int* flag,
828  Context* ctx);
829 
832 template <typename T, class Context>
833 void SmoothL1(
834  const int count,
835  const float beta,
836  const T* x,
837  T* y,
838  Context* ctx);
839 
840 template <typename T, class Context>
841 void SmoothL1Grad(
842  const int count,
843  const float beta,
844  const T* dy,
845  T* dx,
846  Context* ctx);
847 
850 template <typename T, class Context>
851 void SoftmaxCrossEntropy(
852  const int count,
853  const T* prob,
854  const T* targets,
855  T* losses,
856  Context* ctx);
857 
860 template <typename Tx, typename Ty, class Context>
861 void SoftmaxFocalLoss(
862  const int outer_dim,
863  const int axis_dim,
864  const int inner_dim,
865  const float pos_alpha,
866  const float neg_alpha,
867  const float gamma,
868  const int neg_id,
869  const int nignores,
870  const int* ignores,
871  const Tx* prob,
872  const Ty* labels,
873  Tx* losses,
874  int* flags,
875  Context* ctx);
876 
877 template <typename Tx, typename Ty, class Context>
878 void SoftmaxFocalLossGrad(
879  const int outer_dim,
880  const int axis_dim,
881  const int inner_dim,
882  const float pos_alpha,
883  const float neg_alpha,
884  const float gamma,
885  const int neg_id,
886  const int nignores,
887  const int* ignores,
888  const Tx* prob,
889  const Ty* labels,
890  Tx* dx,
891  int* flags,
892  Context* ctx);
893 
896 template <typename Tx, typename Ty, class Context>
897 void SparseSoftmaxCrossEntropy(
898  const int outer_dim,
899  const int axis_dim,
900  const int inner_dim,
901  const int nignores,
902  const int* ignore,
903  const Tx* prob,
904  const Ty* target,
905  Tx* loss,
906  int* flag,
907  Context* ctx);
908 
909 template <typename Tx, typename Ty, class Context>
910 void SparseSoftmaxCrossEntropyGrad(
911  const int outer_dim,
912  const int axis_dim,
913  const int inner_dim,
914  const int nignores,
915  const int* ignore,
916  const Tx* prob,
917  const Ty* target,
918  Tx* dx,
919  int* flag,
920  Context* ctx);
921 
924 template <typename Ta, typename Tb, class Context>
925 void TypeA2B(
926  const int count,
927  const Ta* a,
928  Tb* b,
929  Context* ctx);
930 
933 template <typename T, class Context>
934 void GradientTwoSum(
935  const int count,
936  const T* dy1,
937  const T* dy2,
938  T* dx,
939  Context* ctx);
940 
943 template <typename Tx, typename Ty, class Context>
944 void ImageData(
945  const int N,
946  const int C,
947  const int H,
948  const int W,
949  const string& data_format,
950  const float* mean,
951  const float* std,
952  const Tx* x,
953  Ty* y,
954  Context* ctx);
955 
958 template <typename Tx, typename Tp, class Context>
959 void BatchNormBackwardTraining(
960  const int N,
961  const int C,
962  const int S,
963  const string& data_format,
964  const Tx* x,
965  const Tp* mu,
966  const Tp* rsig,
967  const Tp* gamma,
968  const Tx* dy,
969  Tp* ds,
970  Tp* db,
971  Tx* dx,
972  Tp* dgamma,
973  Tp* dbeta,
974  Context* ctx);
975 
976 template <typename Tx, typename Tp, class Context>
977 void BatchNormBackwardInference(
978  const int N,
979  const int C,
980  const int S,
981  const string& data_format,
982  const Tx* x,
983  const Tp* mu,
984  const Tp* rsig,
985  const Tp* gamma,
986  const Tx* dy,
987  Tx* dx,
988  Tp* dgamma,
989  Tp* dbeta,
990  Context* ctx);
991 
994 template <typename Tx, typename Tp, class Context>
995 void GroupNormForward(
996  const int N,
997  const int G,
998  const int D,
999  const int S,
1000  const string& data_format,
1001  const Tx* x,
1002  const Tp* mu,
1003  const Tp* rsig,
1004  const Tp* gamma,
1005  const Tp* beta,
1006  Tp* scale,
1007  Tp* bias,
1008  Tx* y,
1009  Context* ctx);
1010 
1011 template <typename Tx, typename Tp, class Context>
1012 void GroupNormBackward(
1013  const int N,
1014  const int G,
1015  const int D,
1016  const int S,
1017  const string& data_format,
1018  const Tx* x,
1019  const Tp* mu,
1020  const Tp* rsig,
1021  const Tp* gamma,
1022  const Tx* dy,
1023  Tp* ds,
1024  Tp* db,
1025  Tx* dx,
1026  Tp* dgamma,
1027  Tp* dbeta,
1028  Context* ctx);
1029 
1032 template <typename T, class Context>
1033 void LSTMCell(
1034  const int N,
1035  const int C,
1036  const T* cx,
1037  T* actx,
1038  T* c,
1039  T* h,
1040  Context* ctx);
1041 
1042 template <typename T, class Context>
1043 void LSTMCellGrad(
1044  const int N,
1045  const int C,
1046  const T* cx,
1047  const T* actx,
1048  const T* c,
1049  const T* dc,
1050  const T* dh,
1051  T* dcx,
1052  T* dx,
1053  Context* ctx);
1054 
1057 template <typename T, class Context>
1058 void AdamUpdate(
1059  const int count,
1060  const float lr,
1061  const float beta1,
1062  const float beta2,
1063  const float eps,
1064  T* g,
1065  T* m,
1066  T* v,
1067  Context* ctx);
1068 
1071 template <typename T, class Context>
1072 void NesterovUpdate(
1073  const int count,
1074  const float lr,
1075  const float momentum,
1076  T* g,
1077  T* h,
1078  Context* ctx);
1079 
1082 template <typename T, class Context>
1083 void RMSPropUpdate(
1084  const int count,
1085  const float lr,
1086  const float decay,
1087  const float eps,
1088  T* g,
1089  T* h,
1090  Context* ctx);
1091 
1094 template <typename T, class Context>
1095 void SGDUpdate(
1096  const int count,
1097  const float lr,
1098  const float momentum,
1099  T* g,
1100  T* h,
1101  Context* ctx);
1102 
1105 template <typename T, class Context>
1106 void MixedPrecL2Decay(
1107  const int count,
1108  const float alpha,
1109  const T* w,
1110  float* dx,
1111  Context* ctx);
1112 
1113 template <typename T, class Context>
1114 void MixedPrecUpdate(
1115  const int count,
1116  const float* updates,
1117  T* w,
1118  Context* ctx);
1119 
1122 template <typename T, class Context>
1123 void BiasAdd(
1124  const int outer_dim,
1125  const int axis_dim,
1126  const int inner_dim,
1127  const string& data_format,
1128  const T* bias,
1129  const T* multiplier,
1130  T* y,
1131  Context* ctx);
1132 
1135 template <typename T, class Context>
1136 void BilinearResize(
1137  const int N,
1138  const int C,
1139  const int H,
1140  const int W,
1141  const int out_h,
1142  const int out_w,
1143  const string& data_format,
1144  const T* x,
1145  T* y,
1146  Context* ctx);
1147 
1148 template <typename T, class Context>
1149 void BilinearResizeGrad(
1150  const int N,
1151  const int C,
1152  const int H,
1153  const int W,
1154  const int out_h,
1155  const int out_w,
1156  const string& data_format,
1157  const T* dy,
1158  T* dx,
1159  Context* ctx);
1160 
1163 template <typename T, class Context>
1164 void Im2Col2d(
1165  const int C,
1166  const int H,
1167  const int W,
1168  const int out_h,
1169  const int out_w,
1170  const int kernel_h,
1171  const int kernel_w,
1172  const int stride_h,
1173  const int stride_w,
1174  const int pad_h,
1175  const int pad_w,
1176  const int dilation_h,
1177  const int dilation_w,
1178  const string& data_format,
1179  const T* im,
1180  T* col,
1181  Context* ctx);
1182 
1183 template <typename T, class Context>
1184 void Col2Im2d(
1185  const int C,
1186  const int H,
1187  const int W,
1188  const int out_h,
1189  const int out_w,
1190  const int kernel_h,
1191  const int kernel_w,
1192  const int stride_h,
1193  const int stride_w,
1194  const int pad_h,
1195  const int pad_w,
1196  const int dilation_h,
1197  const int dilation_w,
1198  const string& data_format,
1199  const T* col,
1200  T* im,
1201  Context* ctx);
1202 
1205 template <typename T, class Context>
1206 void DepthwiseConv2d(
1207  const int N,
1208  const int C,
1209  const int H,
1210  const int W,
1211  const int out_h,
1212  const int out_w,
1213  const int kernel_h,
1214  const int kernel_w,
1215  const int stride_h,
1216  const int stride_w,
1217  const int pad_h,
1218  const int pad_w,
1219  const int dilation_h,
1220  const int dilation_w,
1221  const string& data_format,
1222  const T* x,
1223  const T* w,
1224  T* y,
1225  Context* ctx);
1226 
1227 template <typename T, class Context>
1228 void DepthwiseConv2dGrad(
1229  const int N,
1230  const int C,
1231  const int H,
1232  const int W,
1233  const int out_h,
1234  const int out_w,
1235  const int kernel_h,
1236  const int kernel_w,
1237  const int stride_h,
1238  const int stride_w,
1239  const int pad_h,
1240  const int pad_w,
1241  const int dilation_h,
1242  const int dilation_w,
1243  const string& data_format,
1244  const T* dy,
1245  const T* d,
1246  T* dx,
1247  Context* ctx);
1248 
1249 template <typename T, class Context>
1250 void DepthwiseConv2dWGrad(
1251  const int N,
1252  const int C,
1253  const int H,
1254  const int W,
1255  const int out_h,
1256  const int out_w,
1257  const int kernel_h,
1258  const int kernel_w,
1259  const int stride_h,
1260  const int stride_w,
1261  const int pad_h,
1262  const int pad_w,
1263  const int dilation_h,
1264  const int dilation_w,
1265  const string& data_format,
1266  const T* dy,
1267  const T* x,
1268  T* dw,
1269  Context* ctx);
1270 
1273 template <class Context>
1274 void DropBlock2d(
1275  const int N,
1276  const int C,
1277  const int H,
1278  const int W,
1279  const int seed_h,
1280  const int seed_w,
1281  const int block_size,
1282  const float gamma,
1283  const string& data_format,
1284  uint32_t* seed,
1285  int* mask,
1286  Context* ctx);
1287 
1290 template <typename T, class Context>
1291 void NNResize(
1292  const int N,
1293  const int C,
1294  const int H,
1295  const int W,
1296  const int out_h,
1297  const int out_w,
1298  const string& data_format,
1299  const T* x,
1300  T* y,
1301  Context* ctx);
1302 
1303 template <typename T, class Context>
1304 void NNResizeGrad(
1305  const int N,
1306  const int C,
1307  const int H,
1308  const int W,
1309  const int out_h,
1310  const int out_w,
1311  const string& data_format,
1312  const T* dy,
1313  T* dx,
1314  Context* ctx);
1315 
1318 template <typename T, class Context>
1319 void MaxPool2d(
1320  const int N,
1321  const int C,
1322  const int H,
1323  const int W,
1324  const int pool_h,
1325  const int pool_w,
1326  const int kernel_h,
1327  const int kernel_w,
1328  const int stride_h,
1329  const int stride_w,
1330  const int pad_h,
1331  const int pad_w,
1332  const string& data_format,
1333  const T* x,
1334  int* mask,
1335  T* y,
1336  Context* ctx);
1337 
1338 template <typename T, class Context>
1339 void AvgPool2d(
1340  const int N,
1341  const int C,
1342  const int H,
1343  const int W,
1344  const int pool_h,
1345  const int pool_w,
1346  const int kernel_h,
1347  const int kernel_w,
1348  const int stride_h,
1349  const int stride_w,
1350  const int pad_h,
1351  const int pad_w,
1352  const string& data_format,
1353  const T* x,
1354  T* y,
1355  Context* ctx);
1356 
1357 template <typename T, class Context>
1358 void MaxPool2dGrad(
1359  const int N,
1360  const int C,
1361  const int H,
1362  const int W,
1363  const int pool_h,
1364  const int pool_w,
1365  const int kernel_h,
1366  const int kernel_w,
1367  const int stride_h,
1368  const int stride_w,
1369  const int pad_h,
1370  const int pad_w,
1371  const string& data_format,
1372  const T* dy,
1373  const int* mask,
1374  T* dx,
1375  Context* ctx);
1376 
1377 template <typename T, class Context>
1378 void AvgPool2dGrad(
1379  const int N,
1380  const int C,
1381  const int H,
1382  const int W,
1383  const int pool_h,
1384  const int pool_w,
1385  const int kernel_h,
1386  const int kernel_w,
1387  const int stride_h,
1388  const int stride_w,
1389  const int pad_h,
1390  const int pad_w,
1391  const string& data_format,
1392  const T* dy,
1393  T* dx,
1394  Context* ctx);
1395 
1398 template <typename T, class Context>
1399 void ROIPool(
1400  const int C,
1401  const int H,
1402  const int W,
1403  const int pool_h,
1404  const int pool_w,
1405  const int num_rois,
1406  const float spatial_scale,
1407  const T* x,
1408  const float* rois,
1409  int* mask,
1410  T* y,
1411  Context* ctx);
1412 
1413 template <typename T, class Context>
1414 void ROIPoolGrad(
1415  const int N,
1416  const int C,
1417  const int H,
1418  const int W,
1419  const int pool_h,
1420  const int pool_w,
1421  const int num_rois,
1422  const float spatial_scale,
1423  const T* dy,
1424  const T* rois,
1425  const int* mask,
1426  T* dx,
1427  Context* ctx);
1428 
1431 template <typename T, class Context>
1432 void ROIAlign(
1433  const int C,
1434  const int H,
1435  const int W,
1436  const int pool_h,
1437  const int pool_w,
1438  const int num_rois,
1439  const float spatial_scale,
1440  const int sampling_ratio,
1441  const T* x,
1442  const float* rois,
1443  T* y,
1444  Context* ctx);
1445 
1446 template <typename T, class Context>
1447 void ROIAlignGrad(
1448  const int C,
1449  const int H,
1450  const int W,
1451  const int pool_h,
1452  const int pool_w,
1453  const int num_rois,
1454  const float spatial_scale,
1455  const int sampling_ratio,
1456  const T* dy,
1457  const float* rois,
1458  T* dx,
1459  Context* ctx);
1460 
1461 } // namespace kernel
1462 
1463 } // namespace dragon
1464 
1465 #endif // DRAGON_UTILS_OP_KERNEL_H_
void Where(const int count, const uint8_t *mask, const T *a, const T *b, T *y, Context *ctx)
void EluGrad(const int count, const float alpha, const T *dy, const T *y, T *dx, Context *ctx)
void MinimumGrad(const int count, const T *a, const T *b, const T *dy, T *da, T *db, Context *ctx)
void WhereGrad(const int count, const uint8_t *mask, const T *dy, T *da, T *db, Context *ctx)
void GradientTwoSum(const int count, const T *dy1, const T *dy2, T *dx, Context *ctx)
void DropPath(const int rows, const int cols, const float scale, const T *x, const float *mask, T *y, Context *ctx)
void MaskedSelect(const int count, const uint8_t *mask, const T *x, Tensor *indices, Tensor *scratch, Tensor *y, Context *ctx)
void PRelu(const int count, const int channels, const int dim, const bool channel_shared, const string &data_format, const T *x, const T *w, T *y, Context *ctx)
void SigmoidCrossEntropyGrad(const int count, const T *logit, const T *target, T *dlogit, int *flag, Context *ctx)
void NesterovUpdate(const int count, const float lr, const float momentum, T *g, T *h, Context *ctx)
void Concat(const int outer_dim, const int inner_dim, const int axis_dim, const int cat_dim, const int cat_ofs, const T *x, T *y, Context *ctx)
void TransposeGrad(const int count, const int ndims, const int *x_strides, const int *y_dims, const T *dy, T *dx, Context *ctx)
void BroadcastMinimum(const int count, const T *a, const T b, T *y, Context *ctx)
void TanhGrad(const int count, const T *dy, const T *y, T *dx, Context *ctx)
void TypeA2B(const int count, const Ta *a, Tb *b, Context *ctx)
void ApplyMask(const int count, const float scale, const Tx *x, const Tm *mask, Tx *y, Context *ctx)
void SGDUpdate(const int count, const float lr, const float momentum, T *g, T *h, Context *ctx)
void BatchNormBackwardInference(const int N, const int C, const int S, const string &data_format, const Tx *x, const Tp *mu, const Tp *rsig, const Tp *gamma, const Tx *dy, Tx *dx, Tp *dgamma, Tp *dbeta, Context *ctx)
void SigmoidFocalLossGrad(const int outer_dim, const int axis_dim, const int inner_dim, const float pos_alpha, const float neg_alpha, const float gamma, const int neg_id, const Tx *logit, const Ty *target, Tx *dlogit, int *flag, Context *ctx)
void BilinearResizeGrad(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const string &data_format, const T *dy, T *dx, Context *ctx)
void ROIPoolGrad(const int N, const int C, const int H, const int W, const int pool_h, const int pool_w, const int num_rois, const float spatial_scale, const T *dy, const T *rois, const int *mask, T *dx, Context *ctx)
void Maximum(const int count, const T *a, const T *b, T *y, Context *ctx)
void ArgMin(const int outer_dim, const int inner_dim, const int axis_dim, const int top_k, const T *x, int64_t *indices, T *values, Context *ctx)
void EdgePad(const int count, const int ndims, const int *x_dims, const int *x_strides, const int *y_dims, const int *l_pads, const T *x, T *y, Context *ctx)
void AffineGrad(const int outer_dim, const int axis_dim, const int inner_dim, const T *dy, const T *alpha, T *dx, Context *ctx)
void IndexSelectGrad(const int outer_dim, const int inner_dim, const int axis_dim, const int num_indices, const int64_t *indices, const T *dy, T *dx, Context *ctx)
void SoftmaxGrad(const int outer_dim, const int axis_dim, const int inner_dim, const T *multiplier, const T *dy, const T *y, T *scale, T *dx, Context *ctx)
void Tile(const int count, const int ndims, const int *x_dims, const int *x_strides, const int *y_dims, const T *x, T *y, Context *ctx)
void NNResize(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const string &data_format, const T *x, T *y, Context *ctx)
void Softmax(const int outer_dim, const int axis_dim, const int inner_dim, const T *multiplier, const T *x, T *scale, T *y, Context *ctx)
void MaskedSelectGrad(const int count, const int num_indices, const int64_t *indices, const T *dy, T *dx, Context *ctx)
void OneHot(const int count, const int depth, const int on_value, const T *x, T *y, Context *ctx)
void Elu(const int count, const float alpha, const T *x, T *y, Context *ctx)
void Repeat(const int outer_dim, const int inner_dim, const int axis_dim, const int repeats, const T *x, T *y, Context *ctx)
void Dropout(const int count, const float prob, const float scale, const T *x, uint32_t *mask32, uint8_t *mask8, T *y, Context *ctx)
void ArgMax(const int outer_dim, const int inner_dim, const int axis_dim, const int top_k, const T *x, int64_t *indices, T *values, Context *ctx)
void DepthwiseConv2dWGrad(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const string &data_format, const T *dy, const T *x, T *dw, Context *ctx)
dragon::Tensor — Definition: tensor.h:21
void SigmoidGrad(const int count, const T *dy, const T *y, T *dx, Context *ctx)
void Sigmoid(const int count, const T *x, T *y, Context *ctx)
void Im2Col2d(const int C, const int H, const int W, const int out_h, const int out_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const string &data_format, const T *im, T *col, Context *ctx)
void LSTMCell(const int N, const int C, const T *cx, T *actx, T *c, T *h, Context *ctx)
void SElu(const int count, const T *x, T *y, Context *ctx)
void ReluGrad(const int count, const float slope, const T *dy, const T *y, T *dx, Context *ctx)
void MaxPool2d(const int N, const int C, const int H, const int W, const int pool_h, const int pool_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const string &data_format, const T *x, int *mask, T *y, Context *ctx)
void BilinearResize(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const string &data_format, const T *x, T *y, Context *ctx)
void ClipGrad(const int count, const float low, const float high, const T *x, const T *dy, T *dx, Context *ctx)
void MaximumGrad(const int count, const T *a, const T *b, const T *dy, T *da, T *db, Context *ctx)
void ImageData(const int N, const int C, const int H, const int W, const string &data_format, const float *mean, const float *std, const Tx *x, Ty *y, Context *ctx)
void GroupNormBackward(const int N, const int G, const int D, const int S, const string &data_format, const Tx *x, const Tp *mu, const Tp *rsig, const Tp *gamma, const Tx *dy, Tp *ds, Tp *db, Tx *dx, Tp *dgamma, Tp *dbeta, Context *ctx)
void Tanh(const int count, const T *x, T *y, Context *ctx)
void NLLLossGrad(const int outer_dim, const int axis_dim, const int inner_dim, const int nignores, const int *ignore, const Tx *log_prob, const Ty *target, Tx *dx, int *flag, Context *ctx)
void BiasAdd(const int outer_dim, const int axis_dim, const int inner_dim, const string &data_format, const T *bias, const T *multiplier, T *y, Context *ctx)
void ROIPool(const int C, const int H, const int W, const int pool_h, const int pool_w, const int num_rois, const float spatial_scale, const T *x, const float *rois, int *mask, T *y, Context *ctx)
void Equal(const int count, const T *a, const T *b, bool *y, Context *ctx)
void DropBlock2d(const int N, const int C, const int H, const int W, const int seed_h, const int seed_w, const int block_size, const float gamma, const string &data_format, uint32_t *seed, int *mask, Context *ctx)
void Arange(const int count, const int start, const int step, T *y, Context *ctx)
void SmoothL1(const int count, const float beta, const T *x, T *y, Context *ctx)
void MixedPrecUpdate(const int count, const float *updates, T *w, Context *ctx)
void ReflectPad(const int count, const int ndims, const int *x_dims, const int *x_strides, const int *y_dims, const int *l_pads, const T *x, T *y, Context *ctx)
void Clip(const int count, const float low, const float high, const T *x, T *y, Context *ctx)
void ReduceSumGrad(const int count, const int ndims, const int *x_dims, const int *y_dims, const int *y_strides, const float scale, const T *dy, T *dx, Context *ctx)
void PReluWGrad(const int rows, const int row_offset, const int channels, const int dim, const bool channel_shared, const string &data_format, const T *dy, const T *x, const T *multiplier, T *bcast_dw, T *dw, Context *ctx)
void NotZero(const int count, const T *x, bool *y, Context *ctx)
void Less(const int count, const T *a, const T *b, bool *y, Context *ctx)
void AvgPool2dGrad(const int N, const int C, const int H, const int W, const int pool_h, const int pool_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const string &data_format, const T *dy, T *dx, Context *ctx)
void NLLLoss(const int outer_dim, const int axis_dim, const int inner_dim, const int nignores, const int *ignore, const Tx *log_prob, const Ty *target, Tx *loss, int *flag, Context *ctx)
void Transpose(const int count, const int ndims, const int *x_strides, const int *y_dims, const T *x, T *y, Context *ctx)
void SmoothL1Grad(const int count, const float beta, const T *dy, T *dx, Context *ctx)
void Crop(const int count, const int ndims, const int *x_strides, const int *y_dims, const int *starts, const T *x, T *y, Context *ctx)
void Col2Im2d(const int C, const int H, const int W, const int out_h, const int out_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const string &data_format, const T *col, T *im, Context *ctx)
void SoftmaxFocalLoss(const int outer_dim, const int axis_dim, const int inner_dim, const float pos_alpha, const float neg_alpha, const float gamma, const int neg_id, const int nignores, const int *ignores, const Tx *prob, const Ty *labels, Tx *losses, int *flags, Context *ctx)
void SigmoidFocalLoss(const int outer_dim, const int axis_dim, const int inner_dim, const float pos_alpha, const float neg_alpha, const float gamma, const int neg_id, const Tx *logit, const Ty *target, Tx *loss, int *flag, Context *ctx)
void AdamUpdate(const int count, const float lr, const float beta1, const float beta2, const float eps, T *g, T *m, T *v, Context *ctx)
void MixedPrecL2Decay(const int count, const float alpha, const T *w, float *dx, Context *ctx)
void UnravelIndex(const int count, const int ndims, const int *dims, const int64_t *x, int64_t *y, Context *ctx)
void NotEqual(const int count, const T *a, const T *b, bool *y, Context *ctx)
void IndexSelect(const int outer_dim, const int inner_dim, const int axis_dim, const int num_indices, const int64_t *indices, const T *x, T *y, Context *ctx)
void ROIAlignGrad(const int C, const int H, const int W, const int pool_h, const int pool_w, const int num_rois, const float spatial_scale, const int sampling_ratio, const T *dy, const float *rois, T *dx, Context *ctx)
void PReluGrad(const int count, const int channels, const int dim, const bool channel_shared, const string &data_format, const T *dy, const T *x, const T *w, T *dx, Context *ctx)
void BroadcastMaximum(const int count, const T *a, const T b, T *y, Context *ctx)
void CropGrad(const int count, const int ndims, const int *x_strides, const int *y_dims, const int *starts, const T *dy, T *dx, Context *ctx)
void BatchNormBackwardTraining(const int N, const int C, const int S, const string &data_format, const Tx *x, const Tp *mu, const Tp *rsig, const Tp *gamma, const Tx *dy, Tp *ds, Tp *db, Tx *dx, Tp *dgamma, Tp *dbeta, Context *ctx)
void AvgPool2d(const int N, const int C, const int H, const int W, const int pool_h, const int pool_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const string &data_format, const T *x, T *y, Context *ctx)
void LessEqual(const int count, const T *a, const T *b, bool *y, Context *ctx)
void Slice(const int outer_dim, const int inner_dim, const int axis_dim, const int slice_dim, const int slice_ofs, const T *x, T *y, Context *ctx)
void SEluGrad(const int count, const T *dy, const T *y, T *dx, Context *ctx)
void RepeatGrad(const int outer_dim, const int inner_dim, const int axis_dim, const int repeats, const T *dy, T *dx, Context *ctx)
void BroadcastMaximumGrad(const int count, const T *a, const T b, const T *dy, T *da, T *db, Context *ctx)
void MaxPool2dGrad(const int N, const int C, const int H, const int W, const int pool_h, const int pool_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const string &data_format, const T *dy, const int *mask, T *dx, Context *ctx)
void BroadcastMinimumGrad(const int count, const T *a, const T b, const T *dy, T *da, T *db, Context *ctx)
void ROIAlign(const int C, const int H, const int W, const int pool_h, const int pool_w, const int num_rois, const float spatial_scale, const int sampling_ratio, const T *x, const float *rois, T *y, Context *ctx)
void ReduceSum(const int ndims, const int *dims, const int naxes, const int *axes, const float scale, const T *x, T *y, Context *ctx)
void Moments(const int ndims, const int *dims, const int naxes, const int *axes, const Tx *x, Ty *mean, Ty *var, Context *ctx)
void SliceGrad(const int outer_dim, const int inner_dim, const int axis_dim, const int slice_dim, const int slice_ofs, const T *dy, T *x, Context *ctx)
void RMSPropUpdate(const int count, const float lr, const float decay, const float eps, T *g, T *h, Context *ctx)
void NNResizeGrad(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const string &data_format, const T *dy, T *dx, Context *ctx)
void TileGrad(const int rows, const int cols, const int multiple, const T *dy, T *dx, Context *ctx)
void SoftmaxCrossEntropy(const int count, const T *prob, const T *targets, T *losses, Context *ctx)
void AbsGrad(const int count, const T *dy, T *dx, Context *ctx)
void SoftmaxFocalLossGrad(const int outer_dim, const int axis_dim, const int inner_dim, const float pos_alpha, const float neg_alpha, const float gamma, const int neg_id, const int nignores, const int *ignores, const Tx *prob, const Ty *labels, Tx *dx, int *flags, Context *ctx)
Definition: common.h:41
void Minimum(const int count, const T *a, const T *b, T *y, Context *ctx)
void DepthwiseConv2d(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const string &data_format, const T *x, const T *w, T *y, Context *ctx)
void DepthwiseConv2dGrad(const int N, const int C, const int H, const int W, const int out_h, const int out_w, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, const int dilation_h, const int dilation_w, const string &data_format, const T *dy, const T *d, T *dx, Context *ctx)
void GroupNormForward(const int N, const int G, const int D, const int S, const string &data_format, const Tx *x, const Tp *mu, const Tp *rsig, const Tp *gamma, const Tp *beta, Tp *scale, Tp *bias, Tx *y, Context *ctx)
void Relu(const int count, const float slope, const T *x, T *y, Context *ctx)
void Affine(const int outer_dim, const int axis_dim, const int inner_dim, const T *x, const T *alpha, const T *beta, T *y, Context *ctx)
void Greater(const int count, const T *a, const T *b, bool *y, Context *ctx)
void Assign(const int count, const int ndims, const int *x_dims, const int *y_strides, const int *starts, const T *x, T *y, Context *ctx)
void SparseSoftmaxCrossEntropy(const int outer_dim, const int axis_dim, const int inner_dim, const int nignores, const int *ignore, const Tx *prob, const Ty *target, Tx *loss, int *flag, Context *ctx)
void LSTMCellGrad(const int N, const int C, const T *cx, const T *actx, const T *c, const T *dc, const T *dh, T *dcx, T *dx, Context *ctx)
void ConstPad(const int count, const int ndims, const int *x_dims, const int *x_strides, const int *y_dims, const int *l_pads, const float value, const T *x, T *y, Context *ctx)
void SparseSoftmaxCrossEntropyGrad(const int outer_dim, const int axis_dim, const int inner_dim, const int nignores, const int *ignore, const Tx *prob, const Ty *target, Tx *dx, int *flag, Context *ctx)
void SigmoidCrossEntropy(const int count, const T *logit, const T *target, T *loss, int *flag, Context *ctx)
void ChannelShuffle(const int outer_dim, const int inner_dim, const int axis_dim, const int group, const T *x, T *y, Context *ctx)
void GreaterEqual(const int count, const T *a, const T *b, bool *y, Context *ctx)