Unanswered
Hi! Is There Something Happening With The
Hey AgitatedDove14, does this work for you?
` from argparse import ArgumentParser
from tensorflow.keras import utils as np_utils
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow as tf
from clearml import Task
class Linear(tf.keras.Model):
    """Single-layer softmax classifier.

    The model holds exactly one ``Dense`` layer with a softmax activation
    and forwards its inputs straight through it.
    """

    # The initializer must use the dunder name: a method called plain
    # ``init`` is never invoked on construction, which would leave
    # ``self.linear`` undefined when ``call`` runs.
    def __init__(self, in_shape=(784,), num_classes=10):
        """Create the dense softmax layer.

        Args:
            in_shape: Forwarded to ``Dense`` as ``input_shape``.
            num_classes: Number of output units of the dense layer.
        """
        super().__init__()
        self.linear = Dense(num_classes, input_shape=in_shape, activation="softmax")

    def call(self, inputs, training=None, mask=None):
        """Return the dense layer applied to ``inputs``.

        ``training`` and ``mask`` are accepted but not used by this model.
        """
        return self.linear(inputs)
def main():
    """Train the ``Linear`` model on MNIST with a ClearML task attached.

    Parses one optional command-line flag, ``--output-uri``, and forwards
    its value to ``Task.init`` as ``output_uri``. Loads MNIST, flattens and
    rescales the images, one-hot encodes the labels, then fits the model for
    5 epochs with a ``ModelCheckpoint`` callback writing ``best_model.hdf5``.
    """
    parser = ArgumentParser()
    parser.add_argument("--output-uri", type=str, required=False)
    args = parser.parse_args()

    # the data, shuffled and split between train and test sets
    nb_classes = 10
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Flatten every image to a 784-element row and divide by 255.0. The row
    # count comes from the array itself rather than a hard-coded dataset size.
    X_train = X_train.reshape(X_train.shape[0], 784).astype("float32") / 255.0
    X_test = X_test.reshape(X_test.shape[0], 784).astype("float32") / 255.0
    print(X_train.shape[0], "train samples")
    print(X_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    Y_train = np_utils.to_categorical(y_train, nb_classes)
    Y_test = np_utils.to_categorical(y_test, nb_classes)

    model = Linear()
    model.compile(
        loss="categorical_crossentropy", optimizer=Adam(), metrics=["accuracy"]
    )

    # Only weights are written, and only when the monitored metric improves.
    model_checkpoint = ModelCheckpoint(
        "best_model.hdf5", save_best_only=True, save_weights_only=True
    )

    # Connecting ClearML
    # NOTE(review): the task is created after the model and checkpoint
    # callback are built; order kept as in the original script.
    task = Task.init(
        project_name="examples", task_name="Upload problem", output_uri=args.output_uri
    )

    # The returned History object is not needed here, so it is not bound.
    model.fit(
        X_train,
        Y_train,
        batch_size=128,
        epochs=5,
        callbacks=[model_checkpoint],
        verbose=1,
        validation_data=(X_test, Y_test),
    )
if name == "main":
main() `
156 Views
0
Answers
3 years ago
one year ago
Tags