Continuing from last time, I'm studying quantization of AI models. This time I'll look at the gemmlowp library, which appears in Google's QAT quantization paper.
Last time, I ran gemmlowp's sample code (doc/quantization_example.cc). It printed a long execution log that reads like a tutorial on quantized computation, so I walked through that log.
This time, I'll use a debugger to dig into the parts I couldn't fully understand from the log alone, looking at the execution results and the source code side by side.
Let's get started!
Introduction
Here is the list of my articles on quantizing AI models. I hope you find it useful.
List of articles on quantizing AI models
The gemmlowp repository is here:
https://github.com/google/gemmlowp
Now, let's dive in.
Computing the matrix product with quantization
First, a recap of last time.
The sample source computed the scale and zero point and quantized the LHS and RHS. It then computed the product of the quantized LHS and RHS.
That matrix product computation is a bit tricky, so this time I'll step through it in order with a debugger.
Source code for the quantized matrix product
First, here is the relevant source code. At a glance, I couldn't make heads or tails of it (lol).
So let's go through it a little at a time.
// Output stage 1: scale the int32 accumulators down by a fixed-point multiplier
// and a right shift, then add the result zero point.
gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage;
quantize_down_stage.result_offset_after_shift = result_offset;
quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier;
quantize_down_stage.result_shift = right_shift;
// Output stage 2: saturate to [0, 255] and cast to uint8.
gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
const auto& output_pipeline =
    std::make_tuple(quantize_down_stage, saturating_cast_stage);
auto actual_uint8_result_map = actual_uint8_result.Map();
gemmlowp::GemmContext gemm_context;
// Run the uint8 x uint8 GEMM; lhs_offset and rhs_offset account for the input zero points.
gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t,
                                 gemmlowp::DefaultL8R8BitDepthParams>(
    &gemm_context, uint8_lhs.ConstMap(), uint8_rhs.ConstMap(),
    &actual_uint8_result_map, lhs_offset, rhs_offset, output_pipeline);
std::cout << "Quantized uint8 result matrix obtained by quantized "
          << "multiplication:\n"
          << actual_uint8_result << std::endl;
First, the code defines a variable of the struct type gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint.
The struct is defined as follows.
struct OutputStageQuantizeDownInt32ByFixedPoint {
  std::int32_t result_fixedpoint_multiplier;
  std::int32_t result_shift;
  std::int32_t result_offset_after_shift;
};
Next, the code sets the struct's members. Here is the part that computes those values.
real_multiplier is computed from the scales of the LHS, the RHS, and the result. In the paper, this is written as $M := S_1 S_2 / S_3$. real_multiplier is then used to compute quantized_multiplier and right_shift.
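For reference, here are the corresponding relations from section 2.2 of the paper, as I read it. $S_1$, $S_2$, and $S_3$ are the scales of the LHS, RHS, and result. In the code, real_multiplier is $M$, right_shift is $n$, and quantized_multiplier is $M_0$ rounded to 31 fractional bits.

```latex
M := \frac{S_1 S_2}{S_3}, \qquad
M = 2^{-n} M_0, \quad M_0 \in [0.5, 1), \quad n \ge 0,
\qquad
\text{quantized\_multiplier} = \operatorname{round}\!\left(M_0 \cdot 2^{31}\right)
```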
const float real_multiplier = lhs_qparams.scale * rhs_qparams.scale / result_qparams.scale;
std::int32_t quantized_multiplier;
int right_shift;
QuantizeMultiplierSmallerThanOne(real_multiplier, &quantized_multiplier, &right_shift);
Now let's look at QuantizeMultiplierSmallerThanOne().
QuantizeMultiplierSmallerThanOne()
The source code is shown below. Here is what it does.
The assert() calls at the beginning correspond to the paper's statement that $0 < M < 1$.
First, it doubles real_multiplier (a left shift) s times until the value falls in the range $[0.5, 1)$.
Next, it multiplies by 1 << 31 and rounds to the nearest integer. This converts the fractional part into a fixed-point number with 31 fractional bits. The remaining 1 bit is the integer part, or more precisely, the sign bit.
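The resulting q becomes quantized_multiplier, and s becomes right_shift. The rest of the function consists of sanity checks and an adjustment for rounding. If rounding pushes q up to exactly $2^{31}$, which does not fit in int32, the function halves q and decrements s.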
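#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Splits real_multiplier (0 < M < 1) into a Q0.31 fixed-point multiplier and a right shift,
// so that M is approximately quantized_multiplier * 2^-31 * 2^-right_shift.
void QuantizeMultiplierSmallerThanOne(float real_multiplier,
                                      std::int32_t* quantized_multiplier,
                                      int* right_shift) {
  assert(real_multiplier > 0.f);
  assert(real_multiplier < 1.f);
  // Double M until it lies in [0.5, 1); s counts the doublings.
  int s = 0;
  while (real_multiplier < 0.5f) {
    real_multiplier *= 2.0f;
    s++;
  }
  // Convert the normalized value to Q0.31 (31 fractional bits), rounding to nearest.
  std::int64_t q = static_cast<std::int64_t>(std::round(real_multiplier * (1ll << 31)));
  assert(q <= (1ll << 31));
  // Rounding can yield exactly 2^31, which does not fit in int32: halve q and undo one doubling.
  if (q == (1ll << 31)) {
    q /= 2;
    s--;
  }
  assert(s >= 0);
  assert(q <= std::numeric_limits<std::int32_t>::max());
  *quantized_multiplier = static_cast<std::int32_t>(q);
  *right_shift = s;
}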
Inspecting the actual values in the debugger, real_multiplier was 0.00436593033, quantized_multiplier was 1200097792, and right_shift was 7.
This is the same as multiplying real_multiplier by $2^{31} \times 2^{7} = 2^{38}$ and rounding.
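A quick way to confirm the $2^{38}$ relationship is to scale the float value of real_multiplier by $2^{31+7}$ and round it. This should reproduce the quantized_multiplier seen in the debugger. The helper names below are hypothetical. std::ldexp is the standard library function that multiplies by a power of two.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Recomputes quantized_multiplier directly as round(M * 2^(31 + right_shift)).
// Assumes right_shift was already chosen so that M * 2^right_shift lies in [0.5, 1).
std::int64_t MultiplierAsQ31(float real_multiplier, int right_shift) {
  // ldexp only changes the exponent here, so the scaling itself is exact.
  return static_cast<std::int64_t>(
      std::round(std::ldexp(static_cast<double>(real_multiplier), 31 + right_shift)));
}

// Maps the fixed-point pair back to a real number: q / 2^(31 + right_shift).
double MultiplierFromQ31(std::int64_t q, int right_shift) {
  return std::ldexp(static_cast<double>(q), -(31 + right_shift));
}
```

With the debugger values, MultiplierAsQ31(0.00436593033f, 7) gives 1200097792, and mapping it back recovers real_multiplier to within about 1e-11.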
That gives us the values of the members of quantize_down_stage.
gemmlowp::GemmWithOutputPipeline()
All that remains is to compute the quantized matrix product with gemmlowp::GemmWithOutputPipeline(). However, when I stepped through it in the debugger, it turned out to be very complicated, and explaining all of it would be difficult.
So instead, I'll explain the flow of the computation.
The quantized LHS and RHS are as follows.
Quantized uint8 LHS matrix:
208 236 0 238
3 214 255 29
Quantized uint8 RHS matrix:
152 51 244
60 26 255
0 127 246
127 254 247
The result of the computation is as follows.
Quantized uint8 result matrix obtained by quantized multiplication:
168 115 255
0 66 151
The 168 in the result starts from an ordinary dot product of the first LHS row and the first RHS column: 208 * 152 + 236 * 60 + 0 * 0 + 238 * 127. Working it out gives 31616 + 14160 + 0 + 30226 = 76002.
This is the $\sum_{j} q_1^{(i,j)} q_2^{(j,k)}$ term of equation (7) in section 2.3 of the paper.
Evaluating equation (7) should give 168. Let's try it.
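Here is equation (7) written out, transcribed from my reading of the paper. $Z_1$, $Z_2$, and $Z_3$ are the zero points of the LHS, RHS, and result, and $N$ is the inner dimension.

```latex
q_3^{(i,k)} = Z_3 + M \left( N Z_1 Z_2 - Z_1 a_2^{(k)} - Z_2 \bar{a}_1^{(i)}
              + \sum_{j=1}^{N} q_1^{(i,j)} q_2^{(j,k)} \right),
\qquad
a_2^{(k)} := \sum_{j=1}^{N} q_2^{(j,k)}, \quad
\bar{a}_1^{(i)} := \sum_{j=1}^{N} q_1^{(i,j)}
```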
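Here $N$ is 4, $Z_1$ is 113, $Z_2$ is 114, and $Z_3$ is 118. $a_2^{(k)}$, the sum of the RHS column, is 152 + 60 + 0 + 127 = 339. $\bar{a}_1^{(i)}$, the sum of the LHS row, is 208 + 236 + 0 + 238 = 682.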
The part of equation (7) inside the parentheses is then 4 * 113 * 114 - 113 * 339 - 114 * 682 + 76002 = 51528 - 38307 - 77748 + 76002 = 11475.
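The following C++ sketch double-checks the hand arithmetic. It evaluates the parenthesized part of equation (7) for the first row and column in two ways: in the expanded form above, and as $\sum_j (q_1 - Z_1)(q_2 - Z_2)$, the form the expansion comes from. The helper names (Dot, Sum, BracketExpanded, BracketCentered) are my own, not gemmlowp's.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

using Vec4 = std::array<std::int32_t, 4>;

// First LHS row and first RHS column from the log above (uint8 values widened to int32).
const Vec4 kLhsRow0 = {208, 236, 0, 238};
const Vec4 kRhsCol0 = {152, 60, 0, 127};

// Plain integer dot product: sum_j q1(i,j) * q2(j,k).
std::int32_t Dot(const Vec4& a, const Vec4& b) {
  std::int32_t acc = 0;
  for (std::size_t j = 0; j < a.size(); ++j) acc += a[j] * b[j];
  return acc;
}

// Sum of the elements: a2 (RHS column sum) or a1-bar (LHS row sum).
std::int32_t Sum(const Vec4& v) {
  std::int32_t s = 0;
  for (std::int32_t x : v) s += x;
  return s;
}

// Bracket of equation (7): N*Z1*Z2 - Z1*a2 - Z2*a1bar + sum(q1*q2).
std::int32_t BracketExpanded(const Vec4& lhs_row, const Vec4& rhs_col,
                             std::int32_t z1, std::int32_t z2) {
  const std::int32_t n = static_cast<std::int32_t>(lhs_row.size());
  return n * z1 * z2 - z1 * Sum(rhs_col) - z2 * Sum(lhs_row) + Dot(lhs_row, rhs_col);
}

// The same quantity written directly as sum((q1 - Z1) * (q2 - Z2)).
std::int32_t BracketCentered(const Vec4& lhs_row, const Vec4& rhs_col,
                             std::int32_t z1, std::int32_t z2) {
  std::int32_t acc = 0;
  for (std::size_t j = 0; j < lhs_row.size(); ++j) {
    acc += (lhs_row[j] - z1) * (rhs_col[j] - z2);
  }
  return acc;
}
```

Both forms give 11475. The expanded form lets an implementation compute the row and column sums once, instead of subtracting zero points inside the innermost loop.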
All that remains is to multiply by $M$ and add $Z_3$. $M$ is a fraction, but we want to compute with integers only, so we use the following values instead.
quantized_multiplier was 1200097792 and right_shift was 7.
11475 * 1200097792 = 13,771,122,163,200
Dividing this by $2^{31}$ and then by $2^{7}$ gives about 50.1, which rounds to 50.
Finally, adding $Z_3$ (118) gives 50 + 118 = 168.
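The integer-only steps above can be sketched in C++. This is a simplified scalar re-implementation of the idea, based on my understanding of how gemmlowp's fixed-point output stage works: multiply by a Q0.31 value with rounding, apply a rounding right shift, then add the zero point and clamp. The function names are hypothetical, and this is not the library's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>

// round(a * b / 2^31): multiply by a Q0.31 fraction, rounding to nearest (ties away from zero).
std::int32_t RoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
  if (a == b && a == std::numeric_limits<std::int32_t>::min()) {
    return std::numeric_limits<std::int32_t>::max();  // the only case that overflows
  }
  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  const std::int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  return static_cast<std::int32_t>((ab + nudge) / (1ll << 31));  // division truncates toward zero
}

// x / 2^exponent, rounded to nearest with ties away from zero.
std::int32_t RoundingRightShift(std::int32_t x, int exponent) {
  const std::int32_t mask = static_cast<std::int32_t>((1ll << exponent) - 1);
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// One output element: scale the int32 accumulator, add the result zero point, clamp to uint8.
std::uint8_t Requantize(std::int32_t acc, std::int32_t multiplier, int right_shift,
                        std::int32_t result_offset) {
  const std::int32_t scaled =
      RoundingRightShift(RoundingDoublingHighMul(acc, multiplier), right_shift) + result_offset;
  return static_cast<std::uint8_t>(std::max<std::int32_t>(0, std::min<std::int32_t>(255, scaled)));
}
```

Requantize(11475, 1200097792, 7, 118) returns 168. The first step rounds 11475 * 1200097792 / 2^31 to 6413. The second rounds 6413 / 2^7 (about 50.1) to 50.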
We divide by $2^{31}$ because quantized_multiplier is a fixed-point number with 31 fractional bits, so that scaling has to be undone.
We divide by $2^{7}$ because real_multiplier was shifted left right_shift times to bring it into $[0.5, 1)$, and that shift also has to be undone.
Put more simply, we multiply real_multiplier by $2^{38}$ to get an integer, multiply that by the result in the parentheses (11475), and divide the product by $2^{38}$.
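Why go through such a roundabout process? To get as close as possible to the result of floating-point arithmetic, that is, to carry as many of real_multiplier's significant fractional digits into the computation as possible. Normalizing into $[0.5, 1)$ first is what makes this work. Without it, a value like 0.00437 would fill the first 7 of the 31 fractional bits with zeros.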
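To illustrate that the fixed-point path tracks floating-point arithmetic, here is a sketch that applies equation (7) to all six output elements but multiplies by M in double precision instead of fixed point. It assumes the matrices from the log and the values seen in the debugger (zero points 113, 114, 118 and real_multiplier 0.00436593033). The function and constant names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>

using LhsMatrix = std::array<std::array<std::int32_t, 4>, 2>;     // 2 x 4, row-major
using RhsMatrix = std::array<std::array<std::int32_t, 3>, 4>;     // 4 x 3, row-major
using ResultMatrix = std::array<std::array<std::int32_t, 3>, 2>;  // 2 x 3, row-major

const LhsMatrix kLhs = {{{208, 236, 0, 238}, {3, 214, 255, 29}}};
const RhsMatrix kRhs = {{{152, 51, 244}, {60, 26, 255}, {0, 127, 246}, {127, 254, 247}}};
// The uint8 result printed by the sample, copied from the log.
const ResultMatrix kExpected = {{{168, 115, 255}, {0, 66, 151}}};

// Equation (7) for every element, with M applied in double precision (not gemmlowp's path).
ResultMatrix QuantizedMatMulFloatM(const LhsMatrix& q1, const RhsMatrix& q2,
                                   std::int32_t z1, std::int32_t z2, std::int32_t z3,
                                   double m) {
  ResultMatrix q3{};
  for (std::size_t i = 0; i < q1.size(); ++i) {
    for (std::size_t k = 0; k < q2[0].size(); ++k) {
      std::int32_t acc = 0;
      for (std::size_t j = 0; j < q2.size(); ++j) {
        acc += (q1[i][j] - z1) * (q2[j][k] - z2);  // centered form of the bracket
      }
      const long scaled = std::lround(m * acc) + z3;  // round to nearest, ties away from zero
      q3[i][k] = static_cast<std::int32_t>(scaled < 0 ? 0 : (scaled > 255 ? 255 : scaled));
    }
  }
  return q3;
}
```

For this data, QuantizedMatMulFloatM(kLhs, kRhs, 113, 114, 118, 0.00436593033f) returns exactly kExpected. Agreement is not guaranteed in general, because a scaled value very close to .5 could round differently in the two paths. Element (1,0), at about -117.505 before rounding, is the closest call here.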
Dequantizing the result of the quantized matrix product
One open question from last time: dequantizing by hand gave a different value from the sample source.
The sample source's dequantized result is shown below. By hand, I had computed r = 0.0107 * (168 - 118) = 0.535.
Here is the actual float product (LHS * RHS) matrix obtained by dequantizing the above uint8 result, i.e. as far as we are concerned, the ACTUAL RESULT:
0.533 -0.032 1.46
-1.26 -0.554 0.352
In the debugger, the scale was 0.0106628919. Using that value, r = 0.0106628919 * (168 - 118) = 0.533144595, which matches the sample source's result (0.533). The mismatch came from rounding the scale to 0.0107.
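Here is the paper's dequantization formula, $r = S(q - Z)$, as a few lines of C++. The function name is my own.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Dequantization: r = S * (q - Z), with the uint8 value widened before subtracting.
double Dequantize(std::uint8_t q, double scale, std::int32_t zero_point) {
  return scale * (static_cast<std::int32_t>(q) - zero_point);
}
```

With the rounded scale, Dequantize(168, 0.0107, 118) gives about 0.535. With the full scale 0.0106628919, it gives 0.533144595, which matches the log.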
Comparing the quantized matrix product with the floating-point result
The other open question from last time: computing the difference by hand gave a different value from the sample source.
The sample source's differences are shown below. By hand, I had computed 0.533 - 0.534 = -0.001.
Difference between ACTUAL and REFERENCE float results:
-0.000675 0.00764 -0.000674
-0.000674 0.0022 0.00369
In the debugger, the floating-point result was 0.533819675 and the quantized result was 0.533144595. Subtracting gives 0.533144595 - 0.533819675 = -0.00067508, which matches the log (-0.000675).
Conclusion
This time, I examined in detail how the gemmlowp sample source executes. It gave me a good understanding of how TensorFlow Lite C++ quantizes values and performs multiply-accumulate operations.
That's all for this time.
Thank you for reading to the end.