JavaScript 库
楷模和层数是重要的组成部分机器学习。
对于不同的机器学习任务,您必须将不同类型的层组合到一个模型中,该模型可以使用数据进行训练以预测未来值。
TensorFlow.js 支持不同类型楷模以及不同类型的层数。
TensorFlow模型是一个神经网络与一个或多个层数。
Tensorflow 项目具有以下典型工作流程:
假设您知道一个定义直线的函数:
Y = 1.2X + 5
然后您可以使用 JavaScript 公式计算任何 y 值:
y = 1.2 * x + 5;
为了演示 Tensorflow.js,我们可以训练 Tensorflow.js 模型来根据 X 输入预测 Y 值。
TensorFlow 模型不知道该函数。
// Create Training Data
const xs = tf.tensor([0, 1, 2, 3, 4]);
const ys = xs.mul(1.2).add(5);
// Define a Linear Regression Model
const model = tf.sequential();
model.add(tf.layers.dense({units:1, inputShape:[1]}));
// Specify Loss and Optimizer
model.compile({loss:'meanSquaredError', optimizer:'sgd'});
// Train the Model
model.fit(xs, ys, {epochs:500}).then(() => {myFunction()});
// Use the Model
function myFunction() {
const xArr = [];
const yArr = [];
for (let x = 0; x <= 10; x++) {
xArr.push(x);
let result = model.predict(tf.tensor([Number(x)]));
result.data().then(y => {
yArr.push(Number(y));
if (x == 10) {plot(xArr, yArr)};
});
}
}
创建具有 5 个 x 值的张量 (xs):
const xs = tf.tensor([0, 1, 2, 3, 4]);
创建具有 5 个正确 y 答案的张量 (ys)(将 xs 乘以 1.2 再加上 5):
const ys = xs.mul(1.2).add(5);
创建顺序模式:.
const model = tf.sequential();
在顺序模型中,一层的输出是下一层的输入。
向模型添加一层致密层。
该层只有一个单位(张量),形状为 1(一维):
model.add(tf.layers.dense({units:1, inputShape:[1]}));
在密集层中,每个节点都连接到前一层中的每个节点。
使用meanSquaredError作为损失函数并使用sgd(随机梯度下降)作为优化器函数来编译模型:
model.compile({loss:'meanSquaredError', optimizer:'sgd'});
使用 500 次重复(epoch)训练模型(使用 xs 和 ys):
model.fit(xs, ys, {epochs:500}).then(() => {myFunction()});
模型训练完成后,您可以将其用于许多不同的目的。
此示例在给定 10 个 x 值的情况下预测 10 个 y 值,并调用函数在图表中绘制预测结果:
function myFunction() {
const xArr = [];
const yArr = [];
for (let x = 0; x <= 10; x++) {
let result = model.predict(tf.tensor([Number(x)]));
result.data().then(y => {
xArr.push(x);
yArr.push(Number(y));
if (x == 10) {display(xArr, yArr)};
});
}
}
此示例在给定 10 个 x 值的情况下预测 10 个 y 值,并调用函数来显示这些值:
function myFunction() {
const xArr = [];
const yArr = [];
for (let x = 0; x <= 10; x++) {
let result = model.predict(tf.tensor([Number(x)]));
result.data().then(y => {
xArr.push(x);
yArr.push(Number(y));
if (x == 10) {display(xArr, yArr)};
});
}
}
截取页面反馈部分,让我们更快修复内容!也可以直接跳过填写反馈内容!