How to specify Firebase model inputs for a TensorFlow Lite model to make item-recommendation predictions on Android


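I'm trying to use my model on Android through Firebase ML Kit.

I've tried specifying different inputs, but it didn't work. I need a way to make predictions on Android with a TensorFlow model fetched from Firebase.

Right now I can only pass a single input on Android. How do I specify two inputs, so that one is the user ID and the other is the movie ID?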

    private void setupModel() {
        FirebaseCustomRemoteModel remoteModel = new FirebaseCustomRemoteModel.Builder("Recommender-Model").build();
        FirebaseModelDownloadConditions conditions = new FirebaseModelDownloadConditions.Builder()
                .requireWifi()
                .build();
        FirebaseModelManager.getInstance().download(remoteModel, conditions)
                .addOnCompleteListener(new OnCompleteListener<Void>() {
                    @Override
                    public void onComplete(@NonNull Task<Void> task) {
                        if (task.isSuccessful()) {
                            Toast.makeText(getApplicationContext(), "Downloaded", Toast.LENGTH_SHORT).show();
                        } else {
                            Toast.makeText(getApplicationContext(), "Download failure!", Toast.LENGTH_SHORT).show();
                        }
                    }
                });

        FirebaseModelInputOutputOptions inputOutputOptions = null;
        try {
            inputOutputOptions = new FirebaseModelInputOutputOptions.Builder()
                    .setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
                    .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 1})
                    .build();
        } catch (FirebaseMLException e) {
            e.printStackTrace();
        }

        float[][] input = new float[1][1];
        input[0][0] = 1f;

        FirebaseModelInputs inputs = null;
        try {
            inputs = new FirebaseModelInputs.Builder()
                    .add(input)
                    .build();
        } catch (FirebaseMLException e) {
            e.printStackTrace();
        }

        FirebaseModelInterpreterOptions interpreterOptions =
                new FirebaseModelInterpreterOptions.Builder(remoteModel).build();

        try {
            FirebaseModelInterpreter.getInstance(interpreterOptions).run(inputs, inputOutputOptions)
                    .addOnSuccessListener(
                            new OnSuccessListener<FirebaseModelOutputs>() {
                                @Override
                                public void onSuccess(FirebaseModelOutputs result) {
                                    float[][] predictedRating = result.getOutput(0);
                                    Toast.makeText(getApplicationContext(), "Result Fetched", Toast.LENGTH_SHORT).show();
                                }
                            })
                    .addOnFailureListener(
                            new OnFailureListener() {
                                @Override
                                public void onFailure(@NonNull Exception e) {
                                    Toast.makeText(getApplicationContext(), "Failure", Toast.LENGTH_SHORT).show();
                                }
                            });
        } catch (FirebaseMLException e) {
            e.printStackTrace();
        }
    }

The model definition and prediction call in TensorFlow look like this:

    model = Model(inputs=[u, m], outputs=x)
    model.predict([test_user, test_movie], batch_size=500)
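In Keras, `model.predict([test_user, test_movie])` takes one array per model input, and row i of each array together forms one (user, movie) pair. A minimal Java sketch of the same layout, assuming both inputs are FLOAT32 tensors of shape [batch, 1] (the question's `setInputFormat` call uses `{1, 1}`): each input becomes its own `float[n][1]` array, and `{n, 1}` is the shape you would declare for it. The class `RecommenderBatch` is hypothetical, and whether the converted TensorFlow Lite model accepts a batch size other than 1 depends on how it was exported; if it does not, use n = 1 per call.

```java
import java.util.Arrays;

/** Hypothetical helper: turns parallel ID arrays into one [n][1] tensor per model input. */
public class RecommenderBatch {
    final float[][] users;
    final float[][] movies;

    RecommenderBatch(int[] userIds, int[] movieIds) {
        if (userIds.length != movieIds.length) {
            throw new IllegalArgumentException("userIds and movieIds must have the same length");
        }
        int n = userIds.length;
        users = new float[n][1];
        movies = new float[n][1];
        for (int i = 0; i < n; i++) {
            // Row i of both arrays is one (user, movie) pair, like row i of test_user/test_movie.
            users[i][0] = userIds[i];
            movies[i][0] = movieIds[i];
        }
    }

    /** Shape to declare for each input (assumption: [batch, 1] for both). */
    int[] dims() {
        return new int[]{users.length, 1};
    }

    public static void main(String[] args) {
        RecommenderBatch batch = new RecommenderBatch(new int[]{1, 2, 3}, new int[]{10, 20, 30});
        System.out.println(Arrays.toString(batch.dims()));      // [3, 1]
        System.out.println(Arrays.deepToString(batch.movies));  // [[10.0], [20.0], [30.0]]
    }
}
```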
1 Answer


You should be able to use:

    FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
        .add(u)  // add() as many input arrays as your model requires
        .add(m)
        .build();
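The `add()` calls only fill `FirebaseModelInputs`. The `FirebaseModelInputOutputOptions` in the question declares a format for index 0 only, so it would presumably also need a matching `.setInputFormat(1, FirebaseModelDataType.FLOAT32, new int[]{1, 1})` for the movie ID; inputs should be matched to those indices in the order they are added. Note that the converted TensorFlow Lite model's input order is not guaranteed to match the Keras order `[u, m]`; in Python, `tf.lite.Interpreter(model_path=...).get_input_details()` lists the name and shape of each input index. The pure-Java sketch that follows uses a hypothetical class `PairInputs` and assumes user ID at index 0 and movie ID at index 1, both FLOAT32 with shape [1, 1]. It builds the two arrays in add order and checks each one against the declared dims, which is one way to rule out a shape mismatch before calling `run()`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: the inputs for one (user, movie) prediction, in add() order. */
public class PairInputs {
    // Assumption: both inputs are declared with setInputFormat(index, FLOAT32, {1, 1}).
    static final int[] INPUT_DIMS = {1, 1};

    /** Returns [user tensor, movie tensor]; list position = input index (assumed order). */
    static List<float[][]> build(int userId, int movieId) {
        List<float[][]> inputs = new ArrayList<>();
        inputs.add(new float[][]{{userId}});   // index 0: user ID
        inputs.add(new float[][]{{movieId}});  // index 1: movie ID
        return inputs;
    }

    /** True if a 2-D array has exactly the declared [rows, cols] shape. */
    static boolean hasShape(float[][] array, int[] dims) {
        if (dims.length != 2 || array.length != dims[0]) {
            return false;
        }
        for (float[] row : array) {
            if (row.length != dims[1]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        List<float[][]> inputs = build(42, 7);
        for (int i = 0; i < inputs.size(); i++) {
            System.out.println("input " + i + " shape ok: " + hasShape(inputs.get(i), INPUT_DIMS)
                    + ", value: " + inputs.get(i)[0][0]);
        }
    }
}
```

On Android, each array in this list would be passed to `FirebaseModelInputs.Builder.add()` in list order, mirroring the `.add(u)` / `.add(m)` calls in the answer.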
