pytorch 모바일 배포의 Helloworld 사용
Androidstudio 4.1 설치
이 항목 클론
git clone https://github.com/pytorch/android-demo-app.git
androidstudio를 사용하여android-demo-app의 HelloWordApp을 엽니다.열면androidstudio가 자동으로 의존을 생성합니다. 기다리기만 하면 됩니다.
이 코드는 이미 정부에서 쓴 것이기 때문에
공식 튜토리얼의 코드가 어디에 있는지 열어주세요.
이 문장
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
위치HelloWorldApp\app\build.gradle
안에 있는 모든 코드
apply plugin: 'com.android.application'
repositories {
jcenter()
}
android {
compileSdkVersion 28
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "org.pytorch.helloworld"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
}
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
이 문장
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
다 여기 있어요.HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java
모든 코드
package org.pytorch.helloworld;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
Build에서 Build Bundile APK의 Build APK를 선택하시면 됩니다.생성된 apk는
HelloWorldApp\app\build\outputs\apk\debug
이거는 바로 설치할 수 있어요.설치 후 고정된 사진인데 고정된 사진을 검출한 거예요.
이것은 하나의 예입니다. 만약에 당신이 단지 자신의 모델 호출이 이 프로젝트의 수정 모델과 모델 불러오는 데 성공할 수 있는지 테스트하고 싶을 뿐이라면
이 프로젝트 모델은resnet18입니다. 이어서resnet50으로 바꿉니다.
모델 변환 코드는 다음과 같습니다.
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("test.jpg") # build
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()
model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000
resnet.save('model.pt')
if __name__ == '__main__':
pass
이 저장된 모형을 아래 경로의 모형을 덮어씁니다(덮어쓰기 전에 원래 모델을 백업하는 것이 좋습니다. 여기서 원래 모델의 이름을 model_1.pt로 수정하는 것을 선택합니다.)
HelloWorldApp\app\src\main\assets\model.pt
덮어쓰기에 성공하면 다시 패키지 작업을 수행합니다. (Build에서 Build Bundile APK의 Build APK를 선택하면 됩니다.생성된 apk는
HelloWorldApp\app\build\outputs\apk\debug)
파일을 열고 123M의 apk를 발견했습니다. 이전의 apk는 73M이었습니다.설치 및 테스트
완벽하게 열어서 모든 resnet 시리즈가 이 프로젝트를 통해 진화할 수 있다는 거예요.
이는pytorch이동단배치의helloworld에 대한 사용에 관한 글을 소개합니다. 더 많은pytorch이동단배치의helloworld에 대한 내용은 이전의 글을 검색하거나 아래의 관련 글을 계속 훑어보십시오. 앞으로 많은 응원 부탁드립니다!
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
정확도에서 스케일링의 영향데이터셋 스케일링은 데이터 전처리의 주요 단계 중 하나이며, 데이터 변수의 범위를 줄이기 위해 수행됩니다. 이미지와 관련하여 가능한 최소-최대 값 범위는 항상 0-255이며, 이는 255가 최대값임을 의미합니다. 따...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.