diff --git a/src/base/slow_conv3d_forward.h b/src/base/slow_conv3d_forward.h new file mode 100644 index 00000000..66daa3ea --- /dev/null +++ b/src/base/slow_conv3d_forward.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_BASE_SLOW_CONV3D_FORWARD_H_ +#define INFINI_OPS_BASE_SLOW_CONV3D_FORWARD_H_ + +#include "operator.h" + +namespace infini::ops { + +class SlowConv3dForward : public Operator { + public: + SlowConv3dForward(const Tensor input, const Tensor weight, + const std::vector kernel_size, + const std::vector stride, + const std::vector padding, Tensor output) + : input_shape_{input.shape()}, + input_strides_{input.strides()}, + input_type_{input.dtype()}, + weight_shape_{weight.shape()}, + weight_strides_{weight.strides()}, + weight_type_{weight.dtype()}, + output_shape_{output.shape()}, + output_strides_{output.strides()}, + output_type_{output.dtype()}, + kernel_size_{kernel_size}, + stride_{stride}, + padding_{padding}, + device_index_{output.device().index()} {} + + virtual void operator()(const Tensor input, const Tensor weight, + const std::vector kernel_size, + const std::vector stride, + const std::vector padding, + Tensor output) const = 0; + + protected: + Tensor::Shape input_shape_; + + Tensor::Strides input_strides_; + + DataType input_type_; + + Tensor::Shape weight_shape_; + + Tensor::Strides weight_strides_; + + DataType weight_type_; + + Tensor::Shape output_shape_; + + Tensor::Strides output_strides_; + + DataType output_type_; + + std::vector kernel_size_{}; + + std::vector stride_{}; + + std::vector padding_{}; + + int device_index_{0}; +}; + +} // namespace infini::ops + +#endif