| # RUN: %PYTHON %s | FileCheck %s |
| |
| from mlir.ir import * |
| from mlir.dialects import quant |
| |
| |
def run(f):
    """Announce and immediately execute the test function ``f``.

    Meant to be applied as a decorator. A ``TEST: <name>`` banner is printed
    (preceded by a newline) so that the label directives in this file can
    anchor on it, then ``f`` is invoked with no arguments.

    Returns ``f`` itself, so the decorated name stays bound to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
| |
| |
| # CHECK-LABEL: TEST: test_type_hierarchy |
@run
def test_type_hierarchy():
    """Check which quant type classes recognise each parsed quantized type."""
    with Context():
        storage_i8 = IntegerType.get_signless(8)
        any_type = Type.parse("!quant.any<i8<-8:7>:f32>")
        uniform_type = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        per_axis_type = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        calibrated_type = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")

        # The plain integer type must be rejected by the base class ...
        assert not quant.QuantizedType.isinstance(storage_i8)
        # ... while every parsed quant type must be accepted by it.
        for quantized in (any_type, uniform_type, per_axis_type, calibrated_type):
            assert quant.QuantizedType.isinstance(quantized)

        # Each concrete class recognises the type parsed from its own syntax.
        for type_class, instance in (
            (quant.AnyQuantizedType, any_type),
            (quant.UniformQuantizedType, uniform_type),
            (quant.UniformQuantizedPerAxisType, per_axis_type),
            (quant.CalibratedQuantizedType, calibrated_type),
        ):
            assert type_class.isinstance(instance)

        # A type must not be recognised by a different concrete class.
        assert not quant.AnyQuantizedType.isinstance(uniform_type)
        assert not quant.UniformQuantizedType.isinstance(per_axis_type)
| |
| |
| # CHECK-LABEL: TEST: test_any_quantized_type |
@run
def test_any_quantized_type():
    """Build an AnyQuantizedType and print its properties for verification."""
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()
        signed_flag = quant.QuantizedType.FLAG_SIGNED
        # Positional arguments after the types are the storage minimum and
        # maximum, as reflected by the printed properties below.
        any_type = quant.AnyQuantizedType.get(signed_flag, storage, expressed, -8, 7)

        # CHECK: flags: 1
        print(f"flags: {any_type.flags}")
        # CHECK: signed: True
        print(f"signed: {any_type.is_signed}")
        # CHECK: storage type: i8
        print(f"storage type: {any_type.storage_type}")
        # CHECK: expressed type: f32
        print(f"expressed type: {any_type.expressed_type}")
        # CHECK: storage min: -8
        print(f"storage min: {any_type.storage_type_min}")
        # CHECK: storage max: 7
        print(f"storage max: {any_type.storage_type_max}")
        # CHECK: storage width: 8
        print(f"storage width: {any_type.storage_type_integral_width}")
        # CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
        print(f"quantized element type: {any_type.quantized_element_type}")
        # CHECK: !quant.any<i8<-8:7>:f32>
        print(any_type)

        # The constructed type must equal the one parsed from its textual form.
        parsed = Type.parse("!quant.any<i8<-8:7>:f32>")
        assert any_type == parsed
| |
| |
| # CHECK-LABEL: TEST: test_uniform_type |
@run
def test_uniform_type():
    """Build a UniformQuantizedType and print its scale / zero point."""
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()
        # FLAG_SIGNED is deliberately read through the subclass here.
        signed_flag = quant.UniformQuantizedType.FLAG_SIGNED
        scale, zero_point = 0.99872, 127
        storage_min, storage_max = -8, 7
        uniform_type = quant.UniformQuantizedType.get(
            signed_flag,
            storage,
            expressed,
            scale,
            zero_point,
            storage_min,
            storage_max,
        )

        # CHECK: scale: 0.99872
        print(f"scale: {uniform_type.scale}")
        # CHECK: zero point: 127
        print(f"zero point: {uniform_type.zero_point}")
        # CHECK: fixed point: False
        print(f"fixed point: {uniform_type.is_fixed_point}")
        # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
        print(uniform_type)

        # The constructed type must equal the one parsed from its textual form.
        parsed = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>")
        assert uniform_type == parsed
| |
| |
| # CHECK-LABEL: TEST: test_uniform_per_axis_type |
@run
def test_uniform_per_axis_type():
    """Build a UniformQuantizedPerAxisType and print its properties."""
    with Context():
        storage = IntegerType.get_signless(8)
        expressed = F32Type.get()

        # Use the default storage bounds for a signed 8-bit integer.
        default_min = quant.QuantizedType.default_minimum_for_integer(
            is_signed=True, integral_width=8
        )
        default_max = quant.QuantizedType.default_maximum_for_integer(
            is_signed=True, integral_width=8
        )

        # One entry per slice along the quantized dimension; the two lists
        # line up with the {scale:zero_point} pairs in the printed type.
        axis_scales = [200, 0.99872]
        axis_zero_points = [0, 120]
        per_axis_type = quant.UniformQuantizedPerAxisType.get(
            quant.QuantizedType.FLAG_SIGNED,
            storage,
            expressed,
            axis_scales,
            axis_zero_points,
            quantized_dimension=1,
            storage_type_min=default_min,
            storage_type_max=default_max,
        )

        # NOTE(review): the expected output for `scales` and `zero_points`
        # below is the literal None rather than the lists passed in; confirm
        # this matches the intended behaviour of the binding.
        # CHECK: scales: None
        print(f"scales: {per_axis_type.scales}")
        # CHECK: zero_points: None
        print(f"zero_points: {per_axis_type.zero_points}")
        # CHECK: quantized dim: 1
        print(f"quantized dim: {per_axis_type.quantized_dimension}")
        # CHECK: fixed point: False
        print(f"fixed point: {per_axis_type.is_fixed_point}")
        # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
        print(per_axis_type)

        # The constructed type must equal the one parsed from its textual form.
        parsed = Type.parse("!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>")
        assert per_axis_type == parsed
| |
| |
| # CHECK-LABEL: TEST: test_calibrated_type |
@run
def test_calibrated_type():
    """Build a CalibratedQuantizedType and print its min / max bounds."""
    with Context():
        expressed = F32Type.get()
        lower, upper = -0.998, 1.2321
        calibrated_type = quant.CalibratedQuantizedType.get(expressed, lower, upper)

        # CHECK: min: -0.998
        print(f"min: {calibrated_type.min}")
        # CHECK: max: 1.2321
        print(f"max: {calibrated_type.max}")
        # CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
        print(calibrated_type)

        # The constructed type must equal the one parsed from its textual form.
        parsed = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
        assert calibrated_type == parsed