개요
이 포스트는 이전 포스트에서 소개한 기타 사운드 분류 로직을 머신러닝 기반 분류기로 업그레이드하는 과정을 다룹니다.
기타 사운드 분류기에 대한 개념적 설명과 개발 동기에 대해서는 이전 포스트에서 상세히 다루었으니, 본 포스트에서는 YAMNet 모델을 플러터 앱에 온디바이스 형태로 탑재하고, 실시간 기타 사운드 분류 로직을 구현하는 과정을 주로 다루겠습니다.
YAMNet이 뭔가요?
YAMNet은 구글에서 공개한 범용 오디오 이벤트 분류 모델입니다.
온디바이스 탑재를 염두애 두어 가볍고 빠르게 만들어졌으며, 16 kHz 모노로 표준화된 입력 파형을 받아, 멜 스펙토그램(64 mel bands)으로 변환 후 분류를 진행하는 구조로 되어있습니다.
AudioSet은 구글에서 공개한 대규모 오디오 데이터셋으로, 100만개 이상의 오디오 클립을 포함하고 있으며, YAMNet은 이 데이터셋과 유튜브 음원 크롤링을 통해 학습되었습니다.
출력은 521 길이의 확률 벡터로 이루어지며, 이 521개의 클래스 중에는 guitar, acoustic guitar, electric guitar, bass guitar, strum, pluck, music 등 기타 사운드로 추측할수 있는 클래스들이 존재하여, 이를 통해 분류 로직을 설계합니다.
실시간성의 빠른 성능, 모바일 탑재 용이성, 그리고 이미 사전 학습된 기타 관련 데이터셋 등등 기존 룰 베이스 로직의 실패에서 얻은 요구사항을 상당수 만족하여, YAMNet을 기반으로 기타 사운드 분류기를 구현하기로 결정했습니다.
모델 설치 및 앱에 탑재
이제 본격적인 구현 과정을 설명드립니다.
캐글 페이지에서 모델 파일과 클래스 라벨 파일을 다운로드 받았다면(허깅 페이스에서는 YAMNet 모델을 찾지 못했습니다), 플러터 프로젝트 내 Assets으로 등록 후 다음의 과정을 진행합니다.
$ flutter pub add tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';
//...
// 모델 로드
Future<void> _initializeModel() async {
try {
_interpreter = await Interpreter.fromAsset("assets/models/yamnet.tflite");
await _loadClassMap();
_isModelLoaded.value = true;
} catch (e) {
_isModelLoaded.value = false;
logger.e("$_tag Failed to load model: $e");
}
}
// 클래스 라벨 로드
Future<void> _loadClassMap() async {
try {
final String csv = await rootBundle.loadString("assets/models/yamnet_class_map.csv");
final List<String> lines = csv.split("\n");
final List<String> names = <String>[];
for (int i = 1; i < lines.length; i++) {
final String line = lines[i].trim();
if (line.isEmpty) continue;
final List<String> parts = line.split(",");
if (parts.length >= 3) {
names.add(parts[2].trim().replaceAll('"', ""));
}
}
_classNames = names;
} catch (e) {
logger.e("$_tag Failed to load class map: $e");
}
}
위 모델 및 클래스 라벨 코드를 실행하면 플러터 앱 내에 모델 파일과 클래스 라벨 파일이 등록되지만, 일부 환경에서 실행하지 못하는 경우가 있습니다.
이 경우, 각 네이티브 설정 파일에서 오디오 관련 권한을 확인해보고, 특히 안드로이드 릴리즈 모드에서 실행이 불가할 경우 proguard 설정을 다음과 같이 변경해 해결할 수 있습니다.
# Keep TensorFlow Lite GPU delegate classes
-keep class org.tensorflow.lite.gpu.** { *; }
-keep class org.tensorflow.lite.nnapi.** { *; }
-keep class org.tensorflow.lite.support.** { *; }
-keep class org.tensorflow.lite.task.** { *; }
# Keep TFLite Java API
-keep class org.tensorflow.lite.** { *; }
# Keep Flutter plugins reflection-based classes
-keep class io.flutter.** { *; }
-keep class io.flutter.plugins.** { *; }
# Keep Play Core splitinstall classes used by Flutter deferred components manager
-keep class com.google.android.play.core.splitinstall.** { *; }
-keep class com.google.android.play.core.splitcompat.** { *; }
-keep class com.google.android.play.core.tasks.** { *; }
-dontwarn com.google.android.play.**
# Keep Kotlin metadata (avoid reflective issues)
-keep class kotlin.Metadata { *; }
-dontwarn org.tensorflow.**
데이터 전처리
활용 예시
이제 모델을 앱에서 사용할수 있게 되었으니, 기본적인 마이크를 통한 YAMNet 모델 테스트를 진행해봅니다.
import "dart:io";
import "package:record/record.dart";
Future<void> runMicOnce(YamnetRunner runner, {int ms = 1000}) async {
final recorder = Record();
if (!await recorder.hasPermission()) {
await recorder.requestPermission();
}
final String tmpPath = "${Directory.systemTemp.path}/yamnet_tmp.wav";
await recorder.start(
encoder: AudioEncoder.wav, // PCM16 WAV
samplingRate: 16000,
numChannels: 1,
path: tmpPath,
);
await Future.delayed(Duration(milliseconds: ms));
final path = await recorder.stop(); // path == tmpPath
if (path == null) return;
final bytes = await File(path).readAsBytes();
final pcm = runner._decodePcm16Mono16k(bytes); // 위 유틸 재사용
final input = runner._makeModelInput(pcm);
final scores = runner._infer(input);
final result = runner._topK(scores);
print(result["topK"]);
}
위 예시는 record 패키지를 통해 마이크 입력을 받아, 1초 동안 녹음 후 모델에 전달하는 예시입니다.
위 코드대로면, 1초간 녹음 -> YAMNet 모델에 전달 -> 결과 출력 순으로 간단하게 모델 사용이 진행됩니다.