Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Upload Pre-trained Models for Fine Tuning. #896

Merged
merged 1 commit into from
Aug 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ def get_network_from_previous(self, previous_network, use_same_dataset):

return network

@override
def get_network_from_path(self, path):
"""
return network object from a file path
"""
network = caffe_pb2.NetParameter()

with open(path) as infile:
text_format.Merge(infile.read(), network)

return network

@override
def get_network_visualization(self, desc):
"""
Expand Down
6 changes: 6 additions & 0 deletions digits/frameworks/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def get_network_from_previous(self, previous_network, use_same_dataset):
"""
raise NotImplementedError('Please implement me')

def get_network_from_path(self, path):
"""
return network object from a file path
"""
raise NotImplementedError('Please implement me')

def get_network_visualization(self, desc):
"""
return visualization of network
Expand Down
9 changes: 9 additions & 0 deletions digits/frameworks/torch_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ def get_network_from_previous(self, previous_network, use_same_dataset):
# return the same network description
return previous_network

@override
def get_network_from_path(self,path):
"""
return network object from a file path
"""
with open(path, 'r') as f:
network=f.read()
return network

@override
def validate_network(self, data):
"""
Expand Down
9 changes: 9 additions & 0 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def validate_lr_multistep_values(form, field):
choices = [
('standard', 'Standard network'),
('previous', 'Previous network'),
('pretrained', 'Pretrained network'),
('custom', 'Custom network'),
],
default='standard',
Expand Down Expand Up @@ -259,6 +260,14 @@ def validate_lr_multistep_values(form, field):
],
)

pretrained_networks = wtforms.RadioField('Pretrained Networks',
choices = [],
validators = [
validate_required_iff(method='pretrained'),
selection_exists_in_choices,
],
)

custom_network = utils.forms.TextAreaField('Custom Network',
validators = [
validate_required_iff(method='custom'),
Expand Down
26 changes: 26 additions & 0 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from digits.config import config_value
from digits.dataset import ImageClassificationDatasetJob
from digits.inference import ImageInferenceJob
from digits.pretrained_model.job import PretrainedModelJob
from digits.status import Status
from digits.utils import filesystem as fs
from digits.utils.forms import fill_form_if_cloned, save_form_to_job
Expand Down Expand Up @@ -67,6 +68,7 @@ def new():
form.standard_networks.choices = get_standard_networks()
form.standard_networks.default = get_default_standard_network()
form.previous_networks.choices = get_previous_networks()
form.pretrained_networks.choices = get_pretrained_networks()

prev_network_snapshots = get_previous_network_snapshots()

Expand All @@ -78,6 +80,7 @@ def new():
frameworks = frameworks.get_frameworks(),
previous_network_snapshots = prev_network_snapshots,
previous_networks_fullinfo = get_previous_networks_fulldetails(),
pretrained_networks_fullinfo = get_pretrained_networks_fulldetails(),
multi_gpu = config_value('caffe_root')['multi_gpu'],
)

Expand All @@ -95,6 +98,7 @@ def create():
form.standard_networks.choices = get_standard_networks()
form.standard_networks.default = get_default_standard_network()
form.previous_networks.choices = get_previous_networks()
form.pretrained_networks.choices = get_pretrained_networks()

prev_network_snapshots = get_previous_network_snapshots()

Expand All @@ -110,6 +114,7 @@ def create():
frameworks = frameworks.get_frameworks(),
previous_network_snapshots = prev_network_snapshots,
previous_networks_fullinfo = get_previous_networks_fulldetails(),
pretrained_networks_fullinfo = get_pretrained_networks_fulldetails(),
multi_gpu = config_value('caffe_root')['multi_gpu'],
), 400

Expand Down Expand Up @@ -196,6 +201,14 @@ def create():
"Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break

elif form.method.data == 'pretrained':
pretrained_job = scheduler.get_job(form.pretrained_networks.data)
model_def_path = pretrained_job.get_model_def_path()
weights_path = pretrained_job.get_weights_path()

network = fw.get_network_from_path(model_def_path)
pretrained_model = weights_path

elif form.method.data == 'custom':
network = fw.get_network_from_desc(form.custom_network.data)
pretrained_model = form.custom_network_snapshot.data.strip()
Expand Down Expand Up @@ -700,3 +713,16 @@ def get_previous_network_snapshots():
prev_network_snapshots.append(e)
return prev_network_snapshots

def get_pretrained_networks():
return [(j.id(), j.name()) for j in sorted(
[j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)],
cmp=lambda x,y: cmp(y.id(), x.id())
)
]

def get_pretrained_networks_fulldetails():
return [(j) for j in sorted(
[j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)],
cmp=lambda x,y: cmp(y.id(), x.id())
)
]
27 changes: 26 additions & 1 deletion digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .forms import GenericImageModelForm
from .job import GenericImageModelJob
from digits.pretrained_model.job import PretrainedModelJob
from digits import extensions, frameworks, utils
from digits.config import config_value
from digits.dataset import GenericDatasetJob, GenericImageDatasetJob
Expand All @@ -33,7 +34,7 @@ def new(extension_id=None):
form.dataset.choices = get_datasets(extension_id)
form.standard_networks.choices = []
form.previous_networks.choices = get_previous_networks()

form.pretrained_networks.choices = get_pretrained_networks()
prev_network_snapshots = get_previous_network_snapshots()

## Is there a request to clone a job with ?clone=<job_id>
Expand All @@ -47,6 +48,7 @@ def new(extension_id=None):
frameworks=frameworks.get_frameworks(),
previous_network_snapshots=prev_network_snapshots,
previous_networks_fullinfo=get_previous_networks_fulldetails(),
pretrained_networks_fullinfo=get_pretrained_networks_fulldetails(),
multi_gpu=config_value('caffe_root')['multi_gpu'],
)

Expand All @@ -66,6 +68,7 @@ def create(extension_id=None):
form.dataset.choices = get_datasets(extension_id)
form.standard_networks.choices = []
form.previous_networks.choices = get_previous_networks()
form.pretrained_networks.choices = get_pretrained_networks()

prev_network_snapshots = get_previous_network_snapshots()

Expand All @@ -84,6 +87,7 @@ def create(extension_id=None):
frameworks=frameworks.get_frameworks(),
previous_network_snapshots=prev_network_snapshots,
previous_networks_fullinfo=get_previous_networks_fulldetails(),
pretrained_networks_fullinfo=get_pretrained_networks_fulldetails(),
multi_gpu=config_value('caffe_root')['multi_gpu'],
), 400

Expand Down Expand Up @@ -159,6 +163,13 @@ def create(extension_id=None):
raise werkzeug.exceptions.BadRequest(
"Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break
elif form.method.data == 'pretrained':
pretrained_job = scheduler.get_job(form.pretrained_networks.data)
model_def_path = pretrained_job.get_model_def_path()
weights_path = pretrained_job.get_weights_path()

network = fw.get_network_from_path(model_def_path)
pretrained_model = weights_path

elif form.method.data == 'custom':
network = fw.get_network_from_desc(form.custom_network.data)
Expand Down Expand Up @@ -638,6 +649,20 @@ def get_previous_network_snapshots():
return prev_network_snapshots


def get_pretrained_networks():
return [(j.id(), j.name()) for j in sorted(
[j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)],
cmp=lambda x,y: cmp(y.id(), x.id())
)
]

def get_pretrained_networks_fulldetails():
return [(j) for j in sorted(
[j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)],
cmp=lambda x,y: cmp(y.id(), x.id())
)
]

def get_view_extensions():
"""
return all enabled view extensions
Expand Down
35 changes: 22 additions & 13 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import caffe_pb2

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 4
PICKLE_VERSION = 5

# Constants
CAFFE_SOLVER_FILE = 'solver.prototxt'
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self, **kwargs):
self.solver = None

self.solver_file = CAFFE_SOLVER_FILE
self.original_file = CAFFE_ORIGINAL_FILE
self.model_file = CAFFE_ORIGINAL_FILE
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use network_file instead of model_file here? I try to be consistent in calling the .prototxt file a "network description" and not calling it a model unless it has weights attached to it.

The nomenclature I'd like to migrate to (but don't fully support yet) is:

  • "network" for a .prototxt file
  • "model" or "trained model" for a .prototxt and a corresponding .caffemodel file
  • "training" for a group of models

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nevermind. We do the same for Torch. Rats.

However, you will want to be more careful with the upgrade path here.

  • Bump the pickle version
  • In __setstate__, upgrade from original_file to model_file appropriately

self.train_val_file = CAFFE_TRAIN_VAL_FILE
self.snapshot_prefix = CAFFE_SNAPSHOT_PREFIX
self.deploy_file = CAFFE_DEPLOY_FILE
Expand Down Expand Up @@ -140,6 +140,13 @@ def __setstate__(self, state):
# So you can't need this upgrade and we can ignore the error.
pass

if state['pickver_task_caffe_train'] <= 4:
if hasattr(self,"original_file"):
self.model_file = self.original_file
del self.original_file
else:
self.model_file = None

self.pickver_task_caffe_train = PICKLE_VERSION

# Make changes to self
Expand Down Expand Up @@ -256,7 +263,7 @@ def save_files_classification(self):
Save solver, train_val and deploy files to disk
"""
# Save the origin network to file:
with open(self.path(self.original_file), 'w') as outfile:
with open(self.path(self.model_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

network = cleanedUpClassificationNetwork(self.network, len(self.get_labels()))
Expand Down Expand Up @@ -555,7 +562,7 @@ def save_files_generic(self):
assert train_feature_db_path is not None, 'Training images are required'

# Save the origin network to file:
with open(self.path(self.original_file), 'w') as outfile:
with open(self.path(self.model_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

### Split up train_val and deploy layers
Expand Down Expand Up @@ -1084,13 +1091,14 @@ def get_task_stats(self,epoch=-1):
}

# These attributes only available in more recent jobs:
if hasattr(self,"original_file"):
stats.update({
"caffe flavor": self.caffe_flavor,
"caffe version": self.caffe_version,
"network file": self.original_file,
"digits version": self.digits_version
})
if hasattr(self,"model_file"):
if self.model_file is not None:
stats.update({
"caffe flavor": self.caffe_flavor,
"caffe version": self.caffe_version,
"model file": self.model_file,
"digits version": self.digits_version
})

if hasattr(self.dataset,"resize_mode"):
stats.update({"image resize mode": self.dataset.resize_mode})
Expand Down Expand Up @@ -1485,8 +1493,9 @@ def get_model_files(self):
"Network (train/val)": self.train_val_file,
"Network (deploy)": self.deploy_file
}
if hasattr(self,"original_file"):
model_files.update({"Network (original)": self.original_file})
if hasattr(self,"model_file"):
if self.model_file is not None:
model_files.update({"Network (original)": self.model_file})
return model_files

@override
Expand Down
11 changes: 10 additions & 1 deletion digits/model/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from . import images as model_images
from . import ModelJob
from digits.pretrained_model.job import PretrainedModelJob
from digits import frameworks, extensions
from digits.utils import time_filters
from digits.utils.routing import request_wants_json
Expand Down Expand Up @@ -76,13 +77,21 @@ def customize():
elif epoch == -1:
snapshot = job.train_task().pretrained_model
else:

for filename, e in job.train_task().snapshots:
if e == epoch:
snapshot = job.path(filename)
break

if isinstance(job,PretrainedModelJob):
model_def = open(job.get_model_def_path(),'r')
network = model_def.read()
snapshot = job.get_weights_path()
else:
network = job.train_task().get_network_desc()

return json.dumps({
'network': job.train_task().get_network_desc(),
'network': network,
'snapshot': snapshot
})

Expand Down
3 changes: 3 additions & 0 deletions digits/pretrained_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import
from .job import PretrainedModelJob
57 changes: 57 additions & 0 deletions digits/pretrained_model/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from . import tasks
import digits.frameworks
from digits.job import Job
from digits.utils import subclass, override
from digits.pretrained_model.tasks import UploadPretrainedModelTask

@subclass
class PretrainedModelJob(Job):
"""
A Job that uploads a pretrained model
"""

def __init__(self, weights_path, model_def_path, labels_path=None,framework="caffe",**kwargs):
super(PretrainedModelJob, self).__init__(persistent = False, **kwargs)

self.has_labels = labels_path is not None
self.framework = framework
self.tasks = []
self.tasks.append(UploadPretrainedModelTask(
weights_path,
model_def_path,
labels_path,
framework,
job_dir=self.dir()
))

def get_weights_path(self):
if self.framework == "caffe":
return self.dir()+"/model.caffemodel"
else:
return self.dir()+"/_Model.t7"

def get_model_def_path(self):
if self.framework == "caffe":
return self.dir()+"/original.prototxt"
else:
return self.dir()+"/original.lua"

@override
def job_type(self):
return "Pretrained Model"

@override
def __getstate__(self):
fields_to_save = ['_id', '_name', 'username', 'tasks', 'status_history', 'has_labels', 'framework']
full_state = super(PretrainedModelJob, self).__getstate__()
state_to_save = {}
for field in fields_to_save:
state_to_save[field] = full_state[field]
return state_to_save

@override
def __setstate__(self, state):
super(PretrainedModelJob, self).__setstate__(state)
3 changes: 3 additions & 0 deletions digits/pretrained_model/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import
from .upload_pretrained import UploadPretrainedModelTask
Loading