-
Notifications
You must be signed in to change notification settings - Fork 38
/
modeltest.js
79 lines (61 loc) · 2.29 KB
/
modeltest.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
const tf = require('@tensorflow/tfjs-node');
const fs = require('fs');
// Cached model instance; stays undefined until loadModel() assigns the result
// of tf.node.loadSavedModel('./model').
let mirNetModel;
// MetaGraph information for './model', assigned (and logged) by loadModel().
let modelInfo;
// Side length, in pixels, of the square image that predict() resizes its input to.
const imageSize = 512;
/**
 * Lazily loads the TensorFlow SavedModel stored in './model' and caches it in
 * the module-level `mirNetModel` binding.
 *
 * The first successful call reads the SavedModel's MetaGraph information into
 * `modelInfo`, logs it, and then loads the model. Every later call skips the
 * loading work and hands back the cached instance.
 *
 * @returns {Promise<object>} The cached model, i.e. whatever
 *   `tf.node.loadSavedModel('./model')` resolved to.
 * @throws Rethrows (after logging) any error raised while reading the
 *   metadata or loading the model, so callers never continue with an
 *   undefined model.
 */
async function loadModel() {
  // Already loaded: return the cached instance instead of falling through
  // and resolving to undefined.
  if (mirNetModel) {
    return mirNetModel;
  }
  try {
    modelInfo = await tf.node.getMetaGraphsFromSavedModel('./model');
    console.log(modelInfo);
    mirNetModel = await tf.node.loadSavedModel('./model');
    return mirNetModel;
  } catch (error) {
    console.error(error);
    // Propagate the failure; swallowing it here only defers the crash to the
    // first use of the (still undefined) model.
    throw error;
  }
}
/**
 * Runs the model once on "input.PNG" and writes the result to "output.Png".
 *
 * Steps:
 *   1. Make sure the model is loaded (via loadModel()).
 *   2. Decode the PNG as a 3-channel image, resize it to
 *      imageSize x imageSize, scale the values by 1/255 and add a leading
 *      batch dimension.
 *   3. Run `mirNetModel.predict` on that tensor and log the elapsed time
 *      reported by `tf.util.now()`.
 *   4. Reshape the output to [imageSize, imageSize, 3], scale by 255, clip to
 *      [0, 255], encode as PNG and write it to disk.
 *
 * Any error is logged and swallowed, so the returned promise always resolves.
 * Tensors created along the way are disposed before returning.
 *
 * @returns {Promise<void>}
 */
const predict = async () => {
  let input;
  let rawOutput;
  let outputImage;
  try {
    await loadModel();
    console.log("Inside predict");
    const image = fs.readFileSync("input.PNG");

    // tf.tidy frees every intermediate tensor created inside the callback and
    // keeps only the returned one.
    input = tf.tidy(() => {
      // Decode the image into a tensor.
      const decoded = tf.node.decodePng(image, 3);
      // Pass the target size positionally; the previous `size = [...]` form
      // assigned to an undeclared global.
      const resized = tf.image.resizeBilinear(decoded, [imageSize, imageSize]);
      const normalized = tf.div(tf.cast(resized, "float32"), tf.scalar(255.0));
      return normalized.expandDims(0);
    });

    // Feed the image tensor into the model for inference.
    const startTime = tf.util.now();
    rawOutput = await mirNetModel.predict(input);
    const endTime = tf.util.now();
    console.log(endTime - startTime);
    console.log("After Predict");

    outputImage = tf.tidy(() => {
      // Use imageSize so the output shape stays in sync with the resize above.
      const reshaped = tf.reshape(rawOutput, [imageSize, imageSize, 3]);
      const scaled = tf.mul(reshaped, tf.scalar(255.0));
      return tf.clipByValue(scaled, 0, 255);
    });

    const encoded = await tf.node.encodePng(outputImage);
    fs.writeFileSync("output.Png", encoded);
  } catch (error) {
    console.error(error);
  } finally {
    // Release the tensors that outlive the tidy scopes; entries that were
    // never assigned (because of an earlier failure) are skipped by tf.dispose.
    tf.dispose([input, rawOutput, outputImage]);
  }
};
/**
 * Script entry point: waits for a single predict() run to finish, then
 * prints "DONE".
 */
async function main() {
  await predict();
  console.log("DONE");
}

main();
// fs.writeFileSync("./uploads/NEW-1.png", a);
// fs.readFileSync()