Magicode logo
Magicode
1
15 min read

raspberry pi pico でTensorFlow Lite for Microcontrollersのmagic wandに挑戦した

はじめに

以前 の記事で、raspberry pi pico(以下、pico)上でTensorflow Lite for Microcontrollersを動かしてみた。
今回は、加速度センサを用いたmagic wandというサンプルを動かすのに挑戦した。

結論を言うとうまくジェスチャを検出までできず、再学習に挑戦することになった。

環境

  • PC
    • windows 10
  • pico
  • USBケーブル
    • 家に転がってたものを
  • 加速度センサー(MMA8452Q)
  • 330Ωの抵抗 x 2
    • 動作確認では手元になかったので使ってないが、ドキュメントを見た感じ合った方がよさそう。
  • ブレッドボードやジャンパー
    • VDD/GND用 2 or 4本
    • SCL/SDA用 2本
  • 開発環境
    • Visual Studio Code
  • 動作確認環境
    • TeraTerm

magic wandとは?

こちらのyoutube動画がわかりやすいが
WING(wの文字を書く), RING(円を描く), SLOPE(∠←こんな感じの記号を描く)の3種類のジェスチャ検知を行う推論器のサンプルコード

サンプル実行方法

githubにサンプルコードがあがっているが、具体的にどうするか書いていない。このリポジトリはpico向けなので、tensorflow lite for microcontorollers公式のリポジトリも確認したが、詳細は書いていなかった。
どの加速度センサ使っているかで変わるだろうからなと思い調べていたら
こちらの記事がまとめてくださっていたので参考にした。
なお、参考リンクはM5 Stack fireを使っているので、差分は自分で調整する必要がある。

このサンプルコードでは、accelerometer_handler.cppで加速度センサーの値読み取り、output_handler.cppで出力処理を実施する構成になっていて、
デモ用にaccelerometer_handler.cppではダミーデータの挿入、output_handler.cppではシリアルへの出力となっている。
このため、自身が使う加速度センサの処理をaccelerometer_handler.cppに下記、検出結果をどう処理するかをoutput_handler.cppに書けばよい。

picoへの移植

移植の流れ

今回は動作を確認したいだけなので、出力のoutput_handler.cppはいじらないこととする。
その場合、下記の3つを実施する。

  1. USBシリアル出力に変更(usb給電と合わせて出力を得たかったから)
  2. 加速度センサ(MMA8452Q)の値を読み取る処理を追加
  3. 推論器の入力形式に合わせる

なおaccelerometer_handler.cppのコードは記事の末尾に挙げる。

1. USBシリアル出力に変更

以前の投稿でもあったようにCMakeLists.txtでusbを有効にする。

pico_enable_stdio_usb(magic_wand 1)
pico_enable_stdio_uart(magic_wand 0)

今回はこれだけではだめで、初期化時にstdio_init_all();を呼ぶ必要がある。 今回は、main_function.cpp内のsetup()内で実行することにした。

#include "pico/stdlib.h"
// ...

void setup() {
  stdio_init_all();
  //...
}

// ...

こうすることによりPCとUSB経由でシリアル接続できるようになる。

ちなみに以前の投稿に書いた通り、build allだとエラーになるので、ターゲットを指定してビルドすること。

2. 加速度センサ(MMA8452Q)の値を読み取る処理を追加

picoでMMA8452Qを使う方法はこちらに書いたので、詳細はそちらにまかせる。

SetupAccelerometer()内で初期化、ReadAccelerometer()内でセンサ値の読み取りを行うように切り分けた。
なお、errorハンドリングは省いた。
詳細は末尾のコードを参照。

3. 推論器の入力形式に合わせる

ReadAccelerometer()の第二引数にfloat* 型のinputという引数が渡されていて、このinputに推論器に入力するデータを詰めることになる。
参考リンクには下記のように書かれており、加速度はmg単位のよう。
picoのexampleではmma8451_convert_accel()内で9.81かけている個所があるので、そこをコメントアウトし、1000倍することでmgになる。

ReadAccelerometer()はX、Y、Z方向の加速度を25Hzの頻度で計測し、結果を、X、Y、Zそれぞれmg単位(計測結果の1000倍)で時系列で内部のリングバッファに記録していきます。リングバッファの位置はbegin_indexが保持しています。 そして、リングバッファから配列inputへデータを最新のものからlength個コピーします。

また、x,y,zの座標軸の定義に関しては、同様に参考リンクによると

つまり、M5Stackのディスプレイを手前に向けて立てたときx, y, zが(0, 0, 1)、左へ90度傾けたとき(0, 1, 0)、ディスプレイを上に向けて机に置いた状態のとき(1, 0, 0)となるようにする必要があります。

とあり、
下図のように配置した場合に、usbポートを下にして持つ想定とすると (MMA8452Q.z, MMA8452Q.x, MMA8452Q.y) の順にすると、推論器が想定している 座標系と一致しそう。

リングバッファの処理は、参考リンクのコードを流用した。ただ、動作確認では同じジェスチャが何度も認識するのは困らないので、内部バッファ初期化のための処理は外した。また、40ms経過の判断もさぼってsleepを入れた。

結果

ジェスチャ認識せず。すごくたまにwを検知したので、推論器へのデータ挿入が致命的に間違っているわけではないと思う。
constant.hで閾値の調整ができるが、色々変えてもダメだった。(ちなみに参考リンクの時と構造が大きく変わっている。)

// These control the sensitivity of the detection algorithm. If you're seeing
// too many false positives or not enough true positives, you can try tweaking
// these thresholds. Often, increasing the size of the training set will give
// more robust results though, so consider retraining if you are seeing poor
// predictions.
constexpr float kDetectionThreshold = 0.8f;
constexpr int kPredictionHistoryLength = 5;
constexpr int kPredictionSuppressionDuration = 25;

終わりに

ひとまず自分の環境で動かせるところまでもっていけたが、うまく検出できなくて残念だ。
座標系が違うのかもしれないが、自分で学習もしてみたいので、自分の環境で学習して挑戦してみようと思う。
学習方法について調べた次の記事

accelerometer_hander.h

// reset_bufferを追加
extern bool ReadAccelerometer(tflite::ErrorReporter* error_reporter,
                              float* input, int length, bool reset_buffer);

accelerometer_handler.cpp

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "accelerometer_handler.h"
const int data_len = 600;
float save_data[data_len] = {0.0};
int begin_index = 0;
bool pending_initial_data = true;

// MMA8452Q用のコード。
// raspberry pi picoのexapmleから流用
// 本当はクラスで分離したいが、動作確認を優先してまとめている。

/**
 * Copyright (c) 2020 Raspberry Pi (Trading) Ltd.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <stdio.h>
#include <string.h>
#include "pico/stdlib.h"
#include "pico/binary_info.h"
#include "hardware/i2c.h"
/* Example code to talk to a MMA8451 triple-axis accelerometer.
   
   This reads and writes to registers on the board. 

   Connections on Raspberry Pi Pico board, other boards may vary.

   GPIO PICO_DEFAULT_I2C_SDA_PIN (On Pico this is GP4 (physical pin 6)) -> SDA on MMA8451 board
   GPIO PICO_DEFAULT_I2C_SCK_PIN (On Pico this is GP5 (physcial pin 7)) -> SCL on MMA8451 board
   VSYS (physical pin 39) -> VDD on MMA8451 board
   GND (physical pin 38)  -> GND on MMA8451 board

*/

const uint8_t ADDRESS = 0x1D;

//hardware registers

const uint8_t REG_X_MSB = 0x01;
const uint8_t REG_X_LSB = 0x02;
const uint8_t REG_Y_MSB = 0x03;
const uint8_t REG_Y_LSB = 0x04;
const uint8_t REG_Z_MSB = 0x05;
const uint8_t REG_Z_LSB = 0x06;
const uint8_t REG_DATA_CFG = 0x0E;
const uint8_t REG_CTRL_REG1 = 0x2A;

// Set the range and precision for the data 
const uint8_t range_config = 0x01; // 0x00 for ±2g, 0x01 for ±4g, 0x02 for ±8g
const float count = 2048; // 4096 for ±2g, 2048 for ±4g, 1024 for ±8g

uint8_t buf[2];

float mma8451_convert_accel(uint16_t raw_accel) {
    float acceleration;
    // Acceleration is read as a multiple of g (gravitational acceleration on the Earth's surface)
    // Check if acceleration < 0 and convert to decimal accordingly
    if ((raw_accel & 0x2000) == 0x2000) {
        raw_accel &= 0x1FFF;
        acceleration = (-8192 + (float) raw_accel) / count;
    } else {
        acceleration = (float) raw_accel / count;
    }
    //acceleration *= 9.81f;
    return acceleration;
}

#ifdef i2c_default
void mma8451_set_state(uint8_t state) {
    buf[0] = REG_CTRL_REG1;
    buf[1] = state; // Set RST bit to 1
    i2c_write_blocking(i2c_default, ADDRESS, buf, 2, false);
}
#endif

TfLiteStatus mma8452_setup(){
#if !defined(i2c_default) || !defined(PICO_DEFAULT_I2C_SDA_PIN) || !defined(PICO_DEFAULT_I2C_SCL_PIN)
#warning i2c/mma8451_i2c example requires a board with I2C pins
  puts("Default I2C pins were not defined");
  return kTfLiteError;
#else
  printf("Hello, MMA8451! Reading raw data from registers...\n");

  // This example will use I2C0 on the default SDA and SCL pins (4, 5 on a Pico)
  i2c_init(i2c_default, 400 * 1000);
  gpio_set_function(PICO_DEFAULT_I2C_SDA_PIN, GPIO_FUNC_I2C);
  gpio_set_function(PICO_DEFAULT_I2C_SCL_PIN, GPIO_FUNC_I2C);
  gpio_pull_up(PICO_DEFAULT_I2C_SDA_PIN);
  gpio_pull_up(PICO_DEFAULT_I2C_SCL_PIN);
  // Make the I2C pins available to picotool
  bi_decl(bi_2pins_with_func(PICO_DEFAULT_I2C_SDA_PIN, PICO_DEFAULT_I2C_SCL_PIN, GPIO_FUNC_I2C));

  // Enable standby mode
  mma8451_set_state(0x00);

  // Edit configuration while in standby mode
  buf[0] = REG_DATA_CFG;
  buf[1] = range_config;
  i2c_write_blocking(i2c_default, ADDRESS, buf, 2, false);

  // Enable active mode
  mma8451_set_state(0x01);

  return kTfLiteOk;
#endif
}

void update(){
  float x_acceleration;
  float y_acceleration;
  float z_acceleration;

  // Start reading acceleration registers for 2 bytes
  i2c_write_blocking(i2c_default, ADDRESS, &REG_X_MSB, 1, true);
  i2c_read_blocking(i2c_default, ADDRESS, buf, 2, false);
  float x_raw = buf[0] << 6 | buf[1] >> 2;
  x_acceleration = mma8451_convert_accel(x_raw);

  i2c_write_blocking(i2c_default, ADDRESS, &REG_Y_MSB, 1, true);
  i2c_read_blocking(i2c_default, ADDRESS, buf, 2, false);
  float y_raw = buf[0] << 6 | buf[1] >> 2;
  y_acceleration = mma8451_convert_accel(y_raw);

  i2c_write_blocking(i2c_default, ADDRESS, &REG_Z_MSB, 1, true);
  i2c_read_blocking(i2c_default, ADDRESS, buf, 2, false);
  float z_raw = buf[0] << 6 | buf[1] >> 2;
  z_acceleration = mma8451_convert_accel(z_raw);

  save_data[begin_index++] = z_acceleration * 1000;
  save_data[begin_index++] = x_acceleration * 1000;
  save_data[begin_index++] = y_acceleration * 1000;

  if (begin_index >= data_len) {
    begin_index = 0;
  }
}

// ここからaccelerometer_handlerの処理

TfLiteStatus SetupAccelerometer(tflite::ErrorReporter* error_reporter) {
  return mma8452_setup();
}

bool ReadAccelerometer(tflite::ErrorReporter* error_reporter, float* input,
                       int length, bool reset_buffer) {
  // Clear the buffer if required, e.g. after a successful prediction
  if (reset_buffer) {
    memset(save_data, 0, data_len * sizeof(float));
    begin_index = 0;
    pending_initial_data = true;
    // Wait 10ms after a reset to avoid hang
    sleep_ms(10);
  }
  
  update(); 

  if (pending_initial_data && begin_index >= length) {
    pending_initial_data = false;
  }

  if (pending_initial_data) {
    return false;
  }
  
  for (int i = 0; i < length; ++i) {
    int ring_array_index = begin_index + i - length;
    if (ring_array_index < 0) {
      ring_array_index += data_len;
    }
    input[i] = save_data[ring_array_index];
  }
  return true;
}

main_function.cpp should_clear_bufferを追加している。

// ...

namespace {
// ...
  
bool should_clear_buffer = false;
}

// ... 

void loop() {
  // Attempt to read new data from the accelerometer.
  bool got_data =
      ReadAccelerometer(error_reporter, model_input->data.f, input_length, should_clear_buffer);
      
  // Don't try to clear the buffer again
  should_clear_buffer = false;
  // If there was no new data, wait until next time.
  if (!got_data){
    sleep_ms(40);
    return;
  }

  // Run inference, and report any error.
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on index: %d\n",
                         begin_index);
    return;
  }
  // Analyze the results to obtain a prediction
  int gesture_index = PredictGesture(interpreter->output(0)->data.f);
  should_clear_buffer = gesture_index < 3;
  // Produce an output
  HandleOutput(error_reporter, gesture_index);
}

思ったより長かったのでgistあたりにあげればよかった

Discussion

コメントにはログインが必要です。