fix multiple issues
Signed-off-by: zehao-intel <[email protected]>
zehao-intel committed Jul 12, 2024
commit b78e20b0e9f80ef082b2ac38ba8376be5da6f0c5
@@ -103,6 +103,8 @@ Now we support both pb and ckpt formats.
## 2. Benchmark
```shell
bash run_benchmark.sh --input_model=./tensorflow-mask_rcnn_inception_v2-tune.pb --dataset_location=/path/to/dataset/coco_val.record --mode=performance

bash run_benchmark.sh --input_model=./tensorflow-mask_rcnn_inception_v2-tune.pb --dataset_location=/path/to/dataset/coco_val.record --mode=accuracy
```

Details of enabling Intel® Neural Compressor on mask_rcnn_inception_v2 for Tensorflow.
@@ -136,40 +136,135 @@ def __call__(self, sample):
return sample


class ResizeTFTransform(object):
"""Resize the input image to the given size.
class ResizeWithRatio():
"""Resize image with aspect ratio and pad it to max shape(optional).

If the image is padded, the label will be processed at the same time.
The input image should be np.array.

Args:
size (list or int): Size of the result
interpolation (str, default='bilinear'):Desired interpolation type,
support 'bilinear', 'nearest', 'bicubic'
min_dim (int, default=800):
Resizes the image such that its smaller dimension == min_dim
max_dim (int, default=1365):
Ensures that the image longest side doesn't exceed this value
padding (bool, default=False):
If true, pads image with zeros so its size is max_dim x max_dim

Returns:
tuple of processed image and label
"""

def __init__(self, size, interpolation="bilinear"):
"""Initialize `ResizeTFTransform` class."""
if isinstance(size, int):
self.size = size, size
elif isinstance(size, list):
if len(size) == 1:
self.size = size[0], size[0]
elif len(size) == 2:
self.size = size[0], size[1]
self.interpolation = interpolation
def __init__(self, min_dim=800, max_dim=1365, padding=False, constant_value=0):
"""Initialize `ResizeWithRatio` class."""
self.min_dim = min_dim
self.max_dim = max_dim
self.padding = padding
self.constant_value = constant_value

if self.interpolation not in ["bilinear", "nearest", "bicubic"]:
raise ValueError("Unsupported interpolation type!")
def __call__(self, sample):
"""Resize the image with ratio in sample."""
image, label = sample
height, width = image.shape[:2]
scale = 1
if self.min_dim:
scale = max(1, self.min_dim / min(height, width))
if self.max_dim:
image_max = max(height, width)
if round(image_max * scale) > self.max_dim:
scale = self.max_dim / image_max
if scale != 1:
image = cv2.resize(image, (round(height * scale), round(width * scale)))

bbox, str_label, int_label, image_id = label

if self.padding:
h, w = image.shape[:2]
pad_param = [
[(self.max_dim - h) // 2, self.max_dim - h - (self.max_dim - h) // 2],
[(self.max_dim - w) // 2, self.max_dim - w - (self.max_dim - w) // 2],
[0, 0],
]
if not isinstance(bbox, np.ndarray):
bbox = np.array(bbox)
resized_box = bbox * [height, width, height, width] * scale
moved_box = resized_box + [
(self.max_dim - h) // 2,
(self.max_dim - w) // 2,
(self.max_dim - h) // 2,
(self.max_dim - w) // 2,
]
bbox = moved_box / [self.max_dim, self.max_dim, self.max_dim, self.max_dim]
image = np.pad(image, pad_param, mode="constant", constant_values=self.constant_value)
return image, (bbox, str_label, int_label, image_id)
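
For illustration only (this is not part of the committed change), a minimal usage sketch of the `ResizeWithRatio` transform above, assuming NumPy and OpenCV are installed and using a made-up square image and label tuple:

```python
import numpy as np

# ResizeWithRatio is the class defined above; import it from wherever this file
# lives in the example (path intentionally omitted here).

# Made-up sample: a 600x600 RGB image and a (bbox, str_label, int_label, image_id) label,
# where bbox holds normalized [y1, x1, y2, x2] coordinates.
image = np.random.rand(600, 600, 3).astype(np.float32)
label = (np.array([[0.1, 0.2, 0.5, 0.6]]), ["person"], [1], 42)

# scale = max(1, 800 / 600) ~= 1.333, so the image is resized to 800x800 and then
# zero-padded to 1365x1365; the bbox is shifted and re-normalized to the padded canvas.
transform = ResizeWithRatio(min_dim=800, max_dim=1365, padding=True)
padded_image, (bbox, str_label, int_label, image_id) = transform((image, label))
print(padded_image.shape)  # (1365, 1365, 3)
```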


class TensorflowResizeWithRatio():
"""Resize image with aspect ratio and pad it to max shape(optional).

If the image is padded, the label will be processed at the same time.
The input image should be np.array or tf.Tensor.

Args:
min_dim (int, default=800):
Resizes the image such that its smaller dimension == min_dim
max_dim (int, default=1365):
Ensures that the image longest side doesn't exceed this value
padding (bool, default=False):
If true, pads image with zeros so its size is max_dim x max_dim

Returns:
tuple of processed image and label
"""

def __init__(self, min_dim=800, max_dim=1365, padding=False, constant_value=0):
"""Initialize `TensorflowResizeWithRatio` class."""
self.min_dim = min_dim
self.max_dim = max_dim
self.padding = padding
self.constant_value = constant_value

def __call__(self, sample):
"""Resize the input image in sample to the given size."""
"""Resize the image with ratio in sample."""
image, label = sample
if isinstance(image, tf.Tensor):
image = tf.image.resize(image, self.size, method=self.interpolation)
shape = tf.shape(input=image)
height = tf.cast(shape[0], dtype=tf.float32)
width = tf.cast(shape[1], dtype=tf.float32)
scale = 1
if self.min_dim:
scale = tf.maximum(1.0, tf.cast(self.min_dim / tf.math.minimum(height, width), dtype=tf.float32))
if self.max_dim:
image_max = tf.cast(tf.maximum(height, width), dtype=tf.float32)
scale = tf.cond(
pred=tf.greater(tf.math.round(image_max * scale), self.max_dim),
true_fn=lambda: self.max_dim / image_max,
false_fn=lambda: scale,
)
image = tf.image.resize(image, (tf.math.round(height * scale), tf.math.round(width * scale)))
bbox, str_label, int_label, image_id = label

if self.padding:
shape = tf.shape(input=image)
h = tf.cast(shape[0], dtype=tf.float32)
w = tf.cast(shape[1], dtype=tf.float32)
pad_param = [
[(self.max_dim - h) // 2, self.max_dim - h - (self.max_dim - h) // 2],
[(self.max_dim - w) // 2, self.max_dim - w - (self.max_dim - w) // 2],
[0, 0],
]
resized_box = bbox * [height, width, height, width] * scale
moved_box = resized_box + [
(self.max_dim - h) // 2,
(self.max_dim - w) // 2,
(self.max_dim - h) // 2,
(self.max_dim - w) // 2,
]
bbox = moved_box / [self.max_dim, self.max_dim, self.max_dim, self.max_dim]
image = tf.pad(image, pad_param, constant_values=self.constant_value)
else:
image = cv2.resize(image, self.size, interpolation=interpolation_map[self.interpolation])
return (image, label)
transform = ResizeWithRatio(self.min_dim, self.max_dim, self.padding)
image, (bbox, str_label, int_label, image_id) = transform(sample)
return image, (bbox, str_label, int_label, image_id)
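
One design note, with a small illustrative sketch (again not part of the committed change): when the incoming image is not a `tf.Tensor`, `TensorflowResizeWithRatio` simply delegates to the NumPy `ResizeWithRatio` above, so the sample from the previous sketch works unchanged; the TensorFlow branch rebuilds the same resize-and-pad logic with TF ops so the transform can also be used on tensor inputs.

```python
# Reuses the made-up NumPy sample from the previous sketch; since the image is a
# NumPy array (not a tf.Tensor), the call falls through to ResizeWithRatio.
tf_transform = TensorflowResizeWithRatio(min_dim=800, max_dim=1365, padding=True)
padded_image, (bbox, str_label, int_label, image_id) = tf_transform((image, label))
print(padded_image.shape)  # (1365, 1365, 3), same as the ResizeWithRatio result
```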


class BaseMetric(object):
@@ -28,9 +28,9 @@
COCOmAPv2,
COCORecordDataset,
ComposeTransform,
ResizeTFTransform,
TFDataLoader,
LabelBalanceCOCORecordFilter,
TensorflowResizeWithRatio,
)

arg_parser = ArgumentParser(description='Parse args')
@@ -92,11 +92,12 @@ def eval_func(dataloader):
latency = np.array(latency_list[warmup:]).mean() / args.batch_size
return latency

use_padding = True if args.mode == 'performance' else False
eval_dataset = COCORecordDataset(root=args.dataset_location, filter=None, \
transform=ComposeTransform(transform_list=[TensorflowResizeWithRatio(
min_dim=800, max_dim=1356, padding=False)]))
min_dim=800, max_dim=1356, padding=use_padding)]))
batch_size = 1 if args.mode == 'accuracy' else args.batch_size
eval_dataloader=TFDataLoader(dataset=eval_dataset, batch_size=args.batch_size)
eval_dataloader=TFDataLoader(dataset=eval_dataset, batch_size=batch_size)

latency = eval_func(eval_dataloader)
if args.benchmark and args.mode == 'performance':
@@ -111,7 +111,7 @@ def _parse_function(proto):
return dataset

def evaluation_func(model, measurer=None):
evaluate_opt_graph.eval_inference(model)
return evaluate_opt_graph.eval_inference(model)

class eval_classifier_optimized_graph:
"""Evaluate image classifier with optimized TensorFlow graph"""
@@ -294,8 +294,7 @@ def eval_inference(self, infer_graph):
print('Throughput: %.3f records/sec' % throughput)
print('--------------------------------------------------')

if self.args.accuracy:
return accuracy
return accuracy

def run(self):
""" This is neural_compressor function include tuning and benchmark option """
@@ -14,9 +14,9 @@ This example can run on Intel CPUs and GPUs.
pip install neural-compressor
```

### Install Intel Tensorflow
### Install requirements
```shell
pip install intel-tensorflow
pip install -r requirements.txt
```
> Note: Validated TensorFlow [Version](/docs/source/installation_guide.md#validated-software-environment).

@@ -59,15 +59,18 @@ pip install --upgrade intel-extension-for-tensorflow[cpu]


# Run command
Please set the following environment variables before running the quantization or benchmark commands (a Python alternative is sketched right after this list):

* `export nnUNet_preprocessed=<path/to/build>/build/preprocessed_data`
* `export nnUNet_raw_data_base=<path/to/build>/build/raw_data`
* `export RESULTS_FOLDER=<path/to/build>/build/result`
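
If you drive the example from Python instead of a shell, the same variables can be set via `os.environ` at the very top of the script, before anything from nnU-Net is imported — a minimal sketch with a placeholder build path:

```python
import os

# Placeholder path: point this at the build directory created during data preparation.
build_dir = "/path/to/build/build"

os.environ["nnUNet_preprocessed"] = os.path.join(build_dir, "preprocessed_data")
os.environ["nnUNet_raw_data_base"] = os.path.join(build_dir, "raw_data")
os.environ["RESULTS_FOLDER"] = os.path.join(build_dir, "result")
```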

## Quantization

`bash run_quant.sh --input_model=3dunet_dynamic_ndhwc.pb --dataset_location=<path/to/build>/build --output_model=3dunet_dynamic_ndhwc_int8.pb`

## Benchmark

* `export nnUNet_preprocessed=<path/to/build>/build/preprocessed_data`
* `export nnUNet_raw_data_base=<path/to/build>/build/raw_data`
* `export RESULTS_FOLDER=<path/to/build>/build/result`
* `pip install -r requirements.txt`
* `python run_accuracy.py --input-model=<path/to/model_file> --data-location=<path/to/dataset> --calib-preprocess=<path/to/calibrationset> --iters=100 --batch-size=1 --mode=benchmark --bfloat16 0`
`bash run_benchmark.sh --input_model=3dunet_dynamic_ndhwc_int8.pb --dataset_location=<path/to/build>/build --batch_size=100 --mode=benchmark`

`bash run_benchmark.sh --input_model=3dunet_dynamic_ndhwc_int8.pb --dataset_location=<path/to/build>/build --batch_size=1 --mode=accuracy`
@@ -51,7 +51,7 @@ def get_args():
help="One of three options: 'benchmark'/'accuracy'/'tune'.")
arg_parser.add_argument('-n', "--iters",
help='The number of iteration. shall > warmup num(10)',
type=int, default=20)
type=int, default=100)
arg_parser.add_argument('-e', "--num-inter-threads",
help='The number of inter-thread.',
dest='num_inter_threads', type=int, default=0)
@@ -209,7 +209,7 @@ def __len__(self):

set_random_seed(9527)
quant_config = StaticQuantConfig()
calib_dataloader=BaseDataloader(dataset=CalibrationDL())
calib_dataloader=BaseDataLoader(dataset=CalibrationDL())
q_model = quantize_model(graph, quant_config, calib_dataloader)
try:
q_model.save(args.output_model)
@@ -1 +1,2 @@
nnunet
tensorflow
@@ -61,7 +61,7 @@ optional arguments:

```shell
wget https://storage.googleapis.com/download.magenta.tensorflow.org/models/arbitrary_style_transfer.tar.gz
tar -xvzf arbitrary_style_transfer.tar.gz ./model
tar -xvzf arbitrary_style_transfer.tar.gz
```

### 3. Prepare Dataset
@@ -70,22 +70,12 @@ There are two folders named style_images and content_images in the current folder. P

# Run Command
```shell
python style_tune.py --output_dir=./result --style_images_paths=./style_images --content_images_paths=./content_images --input_model=./model/model.ckpt
python main.py --output_dir=./result --style_images_paths=./style_images --content_images_paths=./content_images --input_model=./model/model.ckpt
```


## Quantization Config

The Quantization Config class has default parameters setting for running on Intel CPUs. If running this example on Intel GPUs, the 'backend' parameter should be set to 'itex' and the 'device' parameter should be set to 'gpu'.

```
config = PostTrainingQuantConfig(
device="gpu",
backend="itex",
...
)
```

## Quantization
```shell
bash run_quant.sh --dataset_location=style_images/,content_images/ --input_model=./model/model.ckpt --output_model=saved_model
@@ -119,13 +109,9 @@ Here we set the input tensor and output tensor names to *inputs* and *outputs*

After the prepare step is done, we just need to add a few lines to get the quantized model.
```python
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig
conf = PostTrainingQuantConfig(inputs=['style_input', 'content_input'],
                               outputs=['transformer/expand/conv3/conv/Sigmoid'],
                               calibration_sampling_size=[50, 100])
quantized_model = quantization.fit(args.input_graph, conf=conf, calib_dataloader=dataloader,
                                   eval_dataloader=dataloader)
```
The Intel® Neural Compressor quantization.fit() function will return the best quantized model found within the timeout constraint.
```python
from neural_compressor.tensorflow import StaticQuantConfig, quantize_model

quant_config = StaticQuantConfig()
q_model = quantize_model(graph, quant_config, calib_dataloader)
q_model.save(FLAGS.output_model)
```