SightX: Teaching the Model to Learn - The Training Loop
Day 12 & 13
The architecture is built. The preprocessing pipeline is done. The model knows how to accept an image and produce five output logits. What it does not know yet is what those logits should mean. That is what the training loop is for. This phase is where the model stops being a static structure and starts being something that learns.
It is also the phase with the most moving parts, the most ways to silently fail, and the most satisfying moment when the loss curve finally starts going down.
The Dataset Class: Teaching PyTorch to Read Your Data
Before the training loop can run, PyTorch needs to know how to load your data. That is not automatic. You must write a Dataset class that tells it exactly how to find an image, open it, and pair it with the correct label. It is essentially a contract: given an index, return a sample.
The EyePACS dataset gives you a CSV file mapping image filenames to DR grades, and a folder of JPEG retinal scans. The Dataset class reads the CSV with pandas, constructs the file path for each image, opens it with PIL, and returns the transformed image tensor alongside its integer label. Three methods. One responsibility.
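A minimal sketch of such a class is shown here. The class name, the CSV column names ("image" and "level"), and the ".jpeg" extension are assumptions made for illustration and should be checked against the actual files on disk.

```python
import os

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class EyePACSDataset(Dataset):
    """Pairs each retinal image with its integer DR grade (0 to 4)."""

    def __init__(self, csv_path, image_dir, transform=None):
        # Assumed CSV layout: an "image" column holding the filename without
        # its extension and a "level" column holding the DR grade.
        self.labels = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        # One sample per CSV row.
        return len(self.labels)

    def __getitem__(self, idx):
        row = self.labels.iloc[idx]
        # Assumed extension; adjust if the scans are stored differently.
        path = os.path.join(self.image_dir, f"{row['image']}.jpeg")
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, int(row["level"])
```

The three methods map directly onto the contract: `__init__` reads the CSV, `__len__` reports how many samples exist, and `__getitem__` turns an index into an (image, label) pair.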
The DataLoader wraps this class and handles everything else: batching images together, shuffling the order between epochs, and feeding data into the training loop efficiently. You define how to load one sample. PyTorch figures out how to load thirty-two at once.
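To illustrate the batching behaviour, the sketch below wraps a stand-in dataset of random tensors rather than real retinal scans. The tensor sizes and the sample count of 100 are arbitrary assumptions chosen to keep the example light.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data: 100 random "images" (3 x 64 x 64) with random grades 0 to 4.
images = torch.randn(100, 3, 64, 64)
grades = torch.randint(0, 5, (100,))
dataset = TensorDataset(images, grades)

# shuffle=True reshuffles the sample order at the start of every epoch.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Collect the shape of every batch in one pass over the data.
batch_shapes = [(x.shape, y.shape) for x, y in train_loader]
```

With 100 samples and a batch size of 32, the loader yields three full batches and a final partial batch of four. Passing `drop_last=True` would discard that partial batch if a fixed batch size is needed.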
Class Imbalance: The 73% Trap
The EyePACS dataset has a problem that is not obvious until you look at the label distribution: roughly 73% of images are Grade 0, meaning no diabetic retinopathy at all. The remaining 27% is split across four increasingly severe grades.
If you train without accounting for this, the model will take the path of least resistance. It will learn to predict Grade 0 for everything and sit comfortably at 73% accuracy while being completely useless for the cases that matter. A model that misses every Grade 3 and Grade 4 case is not a medical tool. It is a liability.
The fix is weighted loss. You calculate the inverse frequency of each grade in the training set and pass those weights into the loss function. Now a wrong prediction on a rare Grade 4 case costs the model significantly more than a wrong prediction on a common Grade 0. The model is forced to pay attention to the minority classes instead of ignoring them.
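One way to compute inverse-frequency weights is sketched here. The helper name and the normalisation (total count divided by number of classes times class count) are assumptions, and the label counts are made up for illustration apart from the roughly 73% share of Grade 0.

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(labels, num_classes=5):
    """Return one weight per class, larger for rarer classes."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    counts = counts.clamp(min=1)  # avoid division by zero for empty classes
    return counts.sum() / (num_classes * counts)


# Illustrative label distribution (hypothetical counts per grade).
train_labels = torch.tensor([0] * 73 + [1] * 7 + [2] * 15 + [3] * 3 + [4] * 2)
class_weights = inverse_frequency_weights(train_labels)

# The weights are passed straight into the loss function.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Same uninformative logits, different true grades: the rare grade costs more.
per_sample = nn.CrossEntropyLoss(weight=class_weights, reduction="none")
logits = torch.zeros(2, 5)
targets = torch.tensor([0, 4])
losses = per_sample(logits, targets)
```

Here `losses[1]` (true Grade 4) is much larger than `losses[0]` (true Grade 0) even though the model's output is identical, which is exactly the pressure the weighting is meant to apply.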
This is one of the most important decisions in the entire training setup. Getting the architecture right matters. Getting the class weights right might matter more.
The Training Loop: What Actually Happens Each Epoch
An epoch is one full pass through the training data. Each epoch, the model sees every image in the dataset, makes a prediction, measures how wrong it was, and updates its weights to be slightly less wrong next time. Repeat this twenty times and the model has seen every image twenty times, each time learning from its mistakes.
Within each epoch, the data arrives in batches of 32 images. For each batch, four steps lead up to the weight update. The optimizer clears the gradients from the previous batch so they do not accumulate incorrectly. The batch passes through the model, producing 32 sets of five logits. The loss function compares those predictions to the correct labels and produces a single number representing how wrong the model was. Then backpropagation computes how much each of the 16 million trainable weights contributed to that wrongness.
The optimizer then uses those gradients to nudge every weight in the direction that reduces the error. That nudge is tiny; the learning rate controls how small. Too large and the weights overshoot. Too small and training takes forever. A learning rate of 0.0001 is a reasonable starting point.
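The per-batch steps described above can be sketched as a single function. The function name is hypothetical, and the tiny linear model and random data are stand-ins so the example runs quickly; they are not the SightX architecture or dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """Run one pass over the loader and return the mean training loss."""
    model.train()
    running_loss = 0.0
    seen = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()             # 1. clear gradients from the last batch
        logits = model(images)            # 2. forward pass: five logits per image
        loss = criterion(logits, labels)  # 3. one number for how wrong the batch was
        loss.backward()                   # 4. backpropagation fills in the gradients
        optimizer.step()                  # then nudge every weight
        running_loss += loss.item() * images.size(0)
        seen += images.size(0)
    return running_loss / seen


# Stand-in model and data (assumptions for illustration only).
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
data = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 5, (64,)))
loader = DataLoader(data, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

epoch_loss = train_one_epoch(model, loader, criterion, optimizer)
```

Forgetting `optimizer.zero_grad()` is one of the silent failures mentioned earlier: the code still runs, but gradients from earlier batches keep adding up.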
After each epoch, the model switches to evaluation mode and runs through the validation set without updating any weights. This is the honest performance check. Training accuracy tells you how well the model memorized the training data. Validation accuracy tells you how well it generalized to images it has never seen before. The gap between them is where overfitting lives.
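A validation pass along these lines might look like the following sketch. The function name is hypothetical and the model and data are random stand-ins, so the accuracy it reports is meaningless beyond showing the mechanics.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device="cpu"):
    """Return accuracy on the loader without touching the weights."""
    model.eval()               # switch dropout / batch norm to inference behaviour
    correct = 0
    total = 0
    with torch.no_grad():      # no gradients are tracked, so nothing can update
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)  # highest logit wins
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Stand-in model and validation data (assumptions for illustration only).
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
val_data = TensorDataset(torch.randn(40, 3, 8, 8), torch.randint(0, 5, (40,)))
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

val_acc = evaluate(model, val_loader)
```

Remember to call `model.train()` again before the next training epoch; the training function has to undo what `model.eval()` set.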
The Learning Rate Scheduler: Knowing When to Slow Down
Early in training, large weight updates are fine. The model is far from optimal and needs to move quickly. But as training progresses and the weights settle closer to a good solution, large updates become dangerous. They can overshoot the optimal point and destabilize what was already working.
The StepLR scheduler handles this automatically. Every seven epochs, the learning rate is multiplied by 0.1, dropping it by a factor of ten. The model starts training aggressively and finishes with fine-grained adjustments. It is the difference between roughing out a shape with a chisel and finishing it with sandpaper.
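The resulting schedule can be checked with a short sketch. It assumes Adam with a starting rate of 0.0001 and 20 epochs, as described elsewhere in this post, and uses a tiny stand-in model in place of the real network.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 5)  # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

lr_per_epoch = []
for epoch in range(20):
    # Record the learning rate in effect for this epoch.
    lr_per_epoch.append(optimizer.param_groups[0]["lr"])

    # Placeholder for the real training pass: one dummy update.
    optimizer.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()

    # Advance the schedule once per epoch, after the optimizer has stepped.
    scheduler.step()
```

Counting epochs from zero, epochs 0 to 6 run at 1e-4, epochs 7 to 13 at 1e-5, and epochs 14 to 19 at 1e-6, so a 20-epoch run sees two drops.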
Checkpoint Saving: Only Keeping What Matters
Not every epoch produces a better model. Sometimes validation accuracy plateaus. Sometimes it dips. The checkpoint logic only saves when the validation accuracy beats the previous best. This means the file on disk always represents the peak of what the model achieved, not just the state after the final epoch.
What gets saved is the model’s state dictionary, the numerical values of every weight and bias in the network. The architecture is defined in code. The learned knowledge lives in that file. When the inference server loads the model, it reconstructs the architecture and fills it with those saved weights.
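A sketch of the save-only-when-better logic, together with the reload step, is shown here. The helper name, the file name, and the validation accuracies are hypothetical values chosen for illustration, and a small linear layer stands in for the real model.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_if_best(model, val_acc, best_acc, path):
    """Save the state dict only when val_acc beats best_acc; return the new best."""
    if val_acc > best_acc:
        torch.save(model.state_dict(), path)
        return val_acc
    return best_acc


model = nn.Linear(4, 5)  # stand-in model
checkpoint_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")

best_acc = 0.0
# Made-up validation accuracies: improves, dips, improves, dips.
for val_acc in [0.61, 0.70, 0.68, 0.74, 0.73]:
    best_acc = save_if_best(model, val_acc, best_acc, checkpoint_path)

# At inference time: rebuild the architecture in code, then fill in the weights.
restored = nn.Linear(4, 5)
restored.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
restored.eval()
```

Only three of the five epochs trigger a write here, and the file ends up holding the weights from the 0.74 epoch rather than the final one.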
What This Phase Delivered:
A complete, runnable training pipeline. The Dataset class loads EyePACS images correctly. Class imbalance is handled with weighted loss so the model cannot cheat by ignoring rare grades. The training loop runs for 20 epochs with Adam optimization, a decaying learning rate, and automatic checkpoint saving. The validation loop provides an honest accuracy measure after every epoch.
The model has not been trained yet. That comes once the full dataset is staged and the environment is confirmed stable on M4. But the loop is ready. Every component is in place.
Next phase:
Wrapping the trained model in a FastAPI server so the Node.js backend can send it an image and get a DR grade back. The model learns here. It gets deployed next.