
train & predict

Function

train

Trains the selected model using the provided training data loader. Training depends on the model type: for logistic regression the number of epochs is managed internally, so the epochs argument has no effect; for the other models ("cnn", "simple_cnn"), epochs sets the number of training epochs.
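The epoch handling described above can be pictured with a minimal, hypothetical sketch. The classes `_StubModel` and `TrainerSketch` are illustrative assumptions, not the CRISGI source; the stub only records which epoch count it received, and a plain list stands in for a DataLoader.

```python
# Hypothetical sketch: names are illustrative, not the actual CRISGI implementation.
class _StubModel:
    """Stand-in for a real model; records how it was asked to train."""

    def __init__(self, name):
        self.name = name

    def train(self, train_loader, epochs=None):
        # A real model would iterate over train_loader here.
        return {"model": self.name, "epochs": epochs}


class TrainerSketch:
    def __init__(self, model_type):
        self.model_type = model_type
        self.model = _StubModel(model_type)

    def train(self, train_loader, epochs=10):
        if self.model_type == "logistic":
            # Logistic regression manages its own iteration count,
            # so the epochs argument is not forwarded.
            return self.model.train(train_loader)
        # CNN variants ("cnn", "simple_cnn") receive the requested epochs.
        return self.model.train(train_loader, epochs=epochs)


print(TrainerSketch("cnn").train([], epochs=20))       # {'model': 'cnn', 'epochs': 20}
print(TrainerSketch("logistic").train([], epochs=20))  # {'model': 'logistic', 'epochs': None}
```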

predict

Generates predictions from the trained model for the data in the provided data loader. The same call works for every supported model type; only the format of the output is model-specific.
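One way to picture a single prediction interface over different models is structural typing with `typing.Protocol`. This is a hedged sketch under stated assumptions: `Predictor`, `ThresholdModel`, `ScoreModel`, and the module-level `predict` helper are hypothetical, and iterating a plain list of floats stands in for iterating a real DataLoader.

```python
from typing import Any, Iterable, Protocol


class Predictor(Protocol):
    """Anything with a predict(data_loader) method satisfies this interface."""

    def predict(self, data_loader: Iterable[Any]) -> Any: ...


class ThresholdModel:
    """Hypothetical model returning a list of 0/1 labels."""

    def predict(self, data_loader):
        return [1 if x > 0.5 else 0 for x in data_loader]


class ScoreModel:
    """Hypothetical model returning raw scores as a tuple."""

    def predict(self, data_loader):
        return tuple(float(x) for x in data_loader)


def predict(model: Predictor, data_loader: Iterable[Any]) -> Any:
    # One call site for every model type; the output type stays model-specific.
    return model.predict(data_loader)


print(predict(ThresholdModel(), [0.2, 0.9]))  # [0, 1]
print(predict(ScoreModel(), [0.2, 0.9]))      # (0.2, 0.9)
```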

Parameters

Name          Used by   Type           Description
train_loader  train     DataLoader     The data loader containing the training dataset.
epochs        train     int, optional  Number of training epochs (used only for non-logistic models). Default is 10.
data_loader   predict   DataLoader     The data loader containing the data to generate predictions for.

Return type

  • train: Depends on the underlying model's train method. Typically returns training metrics or status.
  • predict: Model-specific prediction output (e.g., numpy array, tensor, or list of predictions).

Returns

  • train: Returns the result of the model's training process, which may include training loss, accuracy, or a status indicator.
  • predict: Returns the predictions generated by the model for the input data.
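Because the prediction output may be a NumPy array, a tensor, or a list, downstream code may want a single format. The helper below is a hypothetical sketch, not part of CRISGI: it relies on the fact that both `numpy.ndarray` and `torch.Tensor` provide `.tolist()`, and copies any other iterable into a list. It assumes a flat output; a list of per-batch tensors would still need each element converted.

```python
from typing import Any, List


def predictions_to_list(predictions: Any) -> List[Any]:
    """Convert model-specific prediction output into a plain Python list."""
    if hasattr(predictions, "tolist"):
        result = predictions.tolist()
        # A 0-d array or tensor yields a bare scalar; wrap it for consistency.
        return result if isinstance(result, list) else [result]
    # Lists, tuples, and other iterables are copied into a new list.
    return list(predictions)


print(predictions_to_list((0, 1, 1)))  # [0, 1, 1]
```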

Attributes Set

  • self.model_type: The type of model being used (e.g., "cnn", "simple_cnn", "logistic").
  • self.model: The instantiated model object corresponding to the selected type.
  • self.device: The device (e.g., 'cpu' or 'cuda') on which the model will operate.
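The example below passes device='cpu' explicitly when selecting the logistic model. A common PyTorch idiom for choosing that value is sketched here; `pick_device` is a hypothetical helper, and whether every model type accepts a device argument is an assumption the example only confirms for the logistic model.

```python
def pick_device(prefer_cuda: bool = True) -> str:
    """Return "cuda" when requested and available, otherwise "cpu"."""
    try:
        import torch  # imported lazily so the sketch also runs without PyTorch
    except ImportError:
        return "cpu"
    if prefer_cuda and torch.cuda.is_available():
        return "cuda"
    return "cpu"


print(pick_device())
```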

Example

import pickle

from torch.utils.data import DataLoader
from torchvision import transforms

from crisgi import CRISGI
from src.util import ImageDataset

# Load a previously saved CRISGI object (the file is closed automatically)
with open("data/GSE30550_H3N2_crisgi_obj.pk", "rb") as f:
    crisgi_obj = pickle.load(f)

# Set up the training data loader
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = ImageDataset(
    image_dir='./out1',
    label_csv='./out1/labels.csv',
    transform=transform,
    return_label=True
)

loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Set the model type and initialize the model.
# Choose ONE of the options below; each set_model_type call replaces the
# previously configured model.

# 3L-CNN
crisgi_obj.set_model_type(
    model_type='cnn',
    ae_path="data/model/3L-CNN/GSE30550_H3N2_ae_model.pth",
    mlp_path="data/model/3L-CNN/GSE30550_H3N2_mlp_model.pth",
)

# 1L-CNN
# crisgi_obj.set_model_type(model_type='simple_cnn')

# Logistic Regression (the epochs argument of train() is not used)
# crisgi_obj.set_model_type(
#     model_type='logistic',
#     model_path="data/model/logistic/GSE30550_H3N2_log_model.pth",
#     device='cpu',
# )

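# Train the model (epochs applies only to the CNN models)
crisgi_obj.train(loader, epochs=20)

# Generate predictions with an unshuffled loader so the output order
# matches the order of the samples in the dataset
predict_loader = DataLoader(dataset, batch_size=16, shuffle=False)
predictions = crisgi_obj.predict(predict_loader)

# Print the predictions
print(predictions)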