avfilter/dnn: get the data type of network output from dnn execution result
author Guo, Yejun <yejun.guo@intel.com>
Mon, 21 Oct 2019 12:38:10 +0000 (20:38 +0800)
committer Pedro Arthur <bygrandao@gmail.com>
Wed, 30 Oct 2019 14:00:41 +0000 (11:00 -0300)
This lets us make a filter more general, so it can accept different network
models, by adding a data type conversion after getting data from the network.

After adding the dt field to struct DNNData, it becomes identical to
DNNInputData, so merge the two into a single struct: DNNData.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
libavfilter/dnn/dnn_backend_native.c
libavfilter/dnn/dnn_backend_native_layer_conv2d.c
libavfilter/dnn/dnn_backend_native_layer_depth2space.c
libavfilter/dnn/dnn_backend_native_layer_pad.c
libavfilter/dnn/dnn_backend_tf.c
libavfilter/dnn_interface.h
libavfilter/vf_derain.c
libavfilter/vf_sr.c

index ff280b5..add1db4 100644 (file)
@@ -28,7 +28,7 @@
 #include "dnn_backend_native_layer_conv2d.h"
 #include "dnn_backend_native_layers.h"
 
-static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
     DnnOperand *oprd = NULL;
@@ -263,6 +263,7 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *output
         outputs[i].height = oprd->dims[1];
         outputs[i].width = oprd->dims[2];
         outputs[i].channels = oprd->dims[3];
+        outputs[i].dt = oprd->data_type;
     }
 
     return DNN_SUCCESS;
index 6ec0fa7..7b29697 100644 (file)
@@ -106,6 +106,7 @@ int dnn_execute_layer_conv2d(DnnOperand *operands, const int32_t *input_operand_
     output_operand->dims[1] = height - pad_size * 2;
     output_operand->dims[2] = width - pad_size * 2;
     output_operand->dims[3] = conv_params->output_num;
+    output_operand->data_type = operands[input_operand_index].data_type;
     output_operand->length = calculate_operand_data_length(output_operand);
     output_operand->data = av_realloc(output_operand->data, output_operand->length);
     if (!output_operand->data)
index 174676e..7dab19d 100644 (file)
@@ -69,6 +69,7 @@ int dnn_execute_layer_depth2space(DnnOperand *operands, const int32_t *input_ope
     output_operand->dims[1] = height * block_size;
     output_operand->dims[2] = width * block_size;
     output_operand->dims[3] = new_channels;
+    output_operand->data_type = operands[input_operand_index].data_type;
     output_operand->length = calculate_operand_data_length(output_operand);
     output_operand->data = av_realloc(output_operand->data, output_operand->length);
     if (!output_operand->data)
index 8fa35de..8e5959b 100644 (file)
@@ -105,6 +105,7 @@ int dnn_execute_layer_pad(DnnOperand *operands, const int32_t *input_operand_ind
     output_operand->dims[1] = new_height;
     output_operand->dims[2] = new_width;
     output_operand->dims[3] = new_channel;
+    output_operand->data_type = operands[input_operand_index].data_type;
     output_operand->length = calculate_operand_data_length(output_operand);
     output_operand->data = av_realloc(output_operand->data, output_operand->length);
     if (!output_operand->data)
index c8dff51..ed91d05 100644 (file)
@@ -83,7 +83,7 @@ static TF_Buffer *read_graph(const char *model_filename)
     return graph_buf;
 }
 
-static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
+static TF_Tensor *allocate_input_tensor(const DNNData *input)
 {
     TF_DataType dt;
     size_t size;
@@ -105,7 +105,7 @@ static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
                              input_dims[1] * input_dims[2] * input_dims[3] * size);
 }
 
-static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     TFModel *tf_model = (TFModel *)model;
     TF_SessionOptions *sess_opts;
@@ -603,6 +603,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, u
         outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
         outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
         outputs[i].data = TF_TensorData(tf_model->output_tensors[i]);
+        outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
     }
 
     return DNN_SUCCESS;
index 057005f..fdefcb7 100644 (file)
@@ -34,15 +34,10 @@ typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType;
 
 typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
 
-typedef struct DNNInputData{
+typedef struct DNNData{
     void *data;
     DNNDataType dt;
     int width, height, channels;
-} DNNInputData;
-
-typedef struct DNNData{
-    float *data;
-    int width, height, channels;
 } DNNData;
 
 typedef struct DNNModel{
@@ -50,7 +45,7 @@ typedef struct DNNModel{
     void *model;
     // Sets model input and output.
     // Should be called at least once before model execution.
-    DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output);
+    DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output);
 } DNNModel;
 
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
index b7bba09..89f9d5a 100644 (file)
@@ -39,7 +39,7 @@ typedef struct DRContext {
     DNNBackendType     backend_type;
     DNNModule         *dnn_module;
     DNNModel          *model;
-    DNNInputData       input;
+    DNNData            input;
     DNNData            output;
 } DRContext;
 
@@ -137,7 +137,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
             int t = i * out->width * 3 + j;
 
             int t_in =  (i + pad_size) * in->width * 3 + j + pad_size * 3;
-            out->data[0][k] = CLIP((int)((((float *)dr_context->input.data)[t_in] - dr_context->output.data[t]) * 255), 0, 255);
+            out->data[0][k] = CLIP((int)((((float *)dr_context->input.data)[t_in] - ((float *)dr_context->output.data)[t]) * 255), 0, 255);
         }
     }
 
index 0433246..fff19ea 100644 (file)
@@ -41,7 +41,7 @@ typedef struct SRContext {
     DNNBackendType backend_type;
     DNNModule *dnn_module;
     DNNModel *model;
-    DNNInputData input;
+    DNNData input;
     DNNData output;
     int scale_factor;
     struct SwsContext *sws_contexts[3];