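I'm trying to use my model on Android through Firebase ML Kit.
I've tried specifying different inputs, but it hasn't worked. I need a way to run predictions on Android with a TensorFlow model fetched from Firebase.
Right now I can only pass a single input value on Android. How do I specify two inputs, so that one input is the user ID and the other is the movie ID?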
```java
private void setupModel() {
    FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("Recommender-Model").build();
    FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
            .requireWifi()
            .build();
    FirebaseModelManager.getInstance().download(remoteModel, conditions)
            .addOnCompleteListener(new OnCompleteListener<Void>() {
                @Override
                public void onComplete(@NonNull Task<Void> task) {
                    if (task.isSuccessful()) {
                        Toast.makeText(getApplicationContext(), "Downloaded", Toast.LENGTH_SHORT).show();
                    } else {
                        Toast.makeText(getApplicationContext(), "Download failure!", Toast.LENGTH_SHORT).show();
                    }
                }
            });

    FirebaseModelInputOutputOptions inputOutputOptions = null;
    try {
        inputOutputOptions = new FirebaseModelInputOutputOptions.Builder()
                .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
                .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
                .build();
    } catch (FirebaseMLException e) {
        e.printStackTrace();
    }

    float[][] input = new float[1][1];
    input[0][0] = 1f;
    FirebaseModelInputs inputs = null;
    try {
        inputs = new FirebaseModelInputs.Builder()
                .add(input)
                .build();
    } catch (FirebaseMLException e) {
        e.printStackTrace();
    }

    FirebaseModelInterpreterOptions interpreterOptions =
            new FirebaseModelInterpreterOptions.Builder(remoteModel).build();
    try {
        FirebaseModelInterpreter.getInstance(interpreterOptions).run(inputs, inputOutputOptions)
                .addOnSuccessListener(
                        new OnSuccessListener<FirebaseModelOutputs>() {
                            @Override
                            public void onSuccess(FirebaseModelOutputs result) {
                                float[][] predictedRating = result.getOutput(0);
                                Toast.makeText(getApplicationContext(), "Result Fetched", Toast.LENGTH_SHORT).show();
                            }
                        })
                .addOnFailureListener(
                        new OnFailureListener() {
                            @Override
                            public void onFailure(@NonNull Exception e) {
                                Toast.makeText(getApplicationContext(), "Failure", Toast.LENGTH_SHORT).show();
                            }
                        });
    } catch (FirebaseMLException e) {
        e.printStackTrace();
    }
}
```
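In the Firebase custom-model API each tensor is addressed by index: one `setInputFormat(index, type, dims)` call per input on `FirebaseModelInputOutputOptions.Builder`, and one `add(...)` call per input on `FirebaseModelInputs.Builder`, added in index order. For a two-input model that would mean `setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})` plus `setInputFormat(1, FirebaseModelDataType.FLOAT32, new int[]{1, 1})`, followed by `.add(userInput).add(movieInput)`. The plain-Java sketch below only prepares those two arrays. It assumes (this is not confirmed by the post) that the converted `.tflite` model keeps the Keras order `inputs=[u, m]`, i.e. user ID at index 0 and movie ID at index 1, and that both tensors are FLOAT32 with shape [1, 1]. The class `RecommenderInputs` and its members are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical helper: builds one [1][1] float array per model input,
// stored at the index the interpreter is assumed to expect
// (index 0 = user ID, index 1 = movie ID; verify against the .tflite model).
final class RecommenderInputs {
    static final int USER_INDEX = 0;
    static final int MOVIE_INDEX = 1;

    private RecommenderInputs() {}

    /** Returns {userInput, movieInput}, each shaped [1][1] to match dims {1, 1}. */
    static float[][][] build(int userId, int movieId) {
        float[][] userInput = new float[1][1];
        float[][] movieInput = new float[1][1];
        userInput[0][0] = userId;   // int -> float widening conversion
        movieInput[0][0] = movieId;

        float[][][] ordered = new float[2][][];
        ordered[USER_INDEX] = userInput;
        ordered[MOVIE_INDEX] = movieInput;
        return ordered;
    }

    public static void main(String[] args) {
        float[][][] inputs = build(42, 7);
        // Each element would be passed to FirebaseModelInputs.Builder.add(...) in this order.
        for (int i = 0; i < inputs.length; i++) {
            System.out.println("input " + i + " -> " + Arrays.deepToString(inputs[i]));
        }
    }
}
```

Running `main` prints `input 0 -> [[42.0]]` and `input 1 -> [[7.0]]`. A float represents integers exactly only up to 2^24 (16,777,216), so if the model was built with integer ID inputs (common when the IDs feed an embedding layer), declaring the inputs as `FirebaseModelDataType.INT32` and passing `int[][]` arrays may be the better fit; the model's actual input types should be checked first.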
The prediction code in TensorFlow looks like this:
```python
model = Model(inputs=[u, m], outputs=x)
model.predict([test_user, test_movie], batch_size=500)
```
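In Keras, `model.predict([test_user, test_movie], batch_size=500)` takes two parallel arrays (one per input, same length) and feeds them to the model in chunks of up to 500 rows. A rough Java counterpart is to split the parallel ID arrays into batches and build one `[n][1]` array per input for each batch. This is a sketch under assumptions: the class `BatchedRecommenderInputs` is hypothetical, and whether the Firebase interpreter accepts a batch dimension other than the one the model was converted with (declared as `new int[]{n, 1}` for both inputs) is something to verify; if it does not, the single-pair `[1][1]` approach has to be called once per pair.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper mirroring Keras' model.predict([users, movies], batch_size=...):
// splits parallel user/movie ID arrays into batches and builds one [n][1] array per input.
final class BatchedRecommenderInputs {
    private BatchedRecommenderInputs() {}

    /** One batch: userInput and movieInput always have the same row count n. */
    static final class Batch {
        final float[][] userInput;
        final float[][] movieInput;

        Batch(float[][] userInput, float[][] movieInput) {
            this.userInput = userInput;
            this.movieInput = movieInput;
        }

        int size() {
            return userInput.length;
        }
    }

    static List<Batch> split(int[] userIds, int[] movieIds, int batchSize) {
        if (userIds.length != movieIds.length) {
            throw new IllegalArgumentException("userIds and movieIds must have the same length");
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<Batch> batches = new ArrayList<>();
        for (int start = 0; start < userIds.length; start += batchSize) {
            // The last batch may be smaller than batchSize, as in Keras.
            int n = Math.min(batchSize, userIds.length - start);
            float[][] users = new float[n][1];
            float[][] movies = new float[n][1];
            for (int i = 0; i < n; i++) {
                users[i][0] = userIds[start + i];
                movies[i][0] = movieIds[start + i];
            }
            batches.add(new Batch(users, movies));
        }
        return batches;
    }
}
```

For each batch, both inputs would be declared with the same dims (`new int[]{batch.size(), 1}`) and added in the same order as the Keras input list; the output format would likewise need the batch row count, assuming the model emits one rating per row.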