Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Upload Pretrained Models for Training
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucaszw committed Aug 5, 2016
1 parent 19f8ecf commit 89bbd5f
Show file tree
Hide file tree
Showing 35 changed files with 1,788 additions and 456 deletions.
12 changes: 12 additions & 0 deletions digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,18 @@ def get_network_from_previous(self, previous_network, use_same_dataset):

return network

@override
def get_network_from_path(self, path):
    """
    Build a network object from the network description stored in a file

    Arguments:
    path -- location of the file whose text is merged into the network

    Returns a caffe_pb2.NetParameter populated via text_format.Merge
    """
    # Read the whole description first, then merge it into a fresh message
    with open(path) as prototxt:
        description = prototxt.read()

    net_param = caffe_pb2.NetParameter()
    text_format.Merge(description, net_param)
    return net_param

@override
def get_network_visualization(self, desc):
"""
Expand Down
6 changes: 6 additions & 0 deletions digits/frameworks/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def get_network_from_previous(self, previous_network, use_same_dataset):
"""
raise NotImplementedError('Please implement me')

def get_network_from_path(self, path):
    """
    Placeholder that framework subclasses are expected to replace

    Arguments:
    path -- file path a concrete framework would load the network from

    Always raises NotImplementedError in this base implementation
    """
    message = 'Please implement me'
    raise NotImplementedError(message)

def get_network_visualization(self, desc):
"""
return visualization of network
Expand Down
9 changes: 9 additions & 0 deletions digits/frameworks/torch_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ def get_network_from_previous(self, previous_network, use_same_dataset):
# return the same network description
return previous_network

@override
def get_network_from_path(self, path):
    """
    Load a network description from a file

    Arguments:
    path -- location of the file to read

    Returns the complete text content of the file
    """
    with open(path, 'r') as description_file:
        return description_file.read()

@override
def validate_network(self, data):
"""
Expand Down
9 changes: 9 additions & 0 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def validate_lr_multistep_values(form, field):
choices = [
('standard', 'Standard network'),
('previous', 'Previous network'),
('pretrained', 'Pretrained network'),
('custom', 'Custom network'),
],
default='standard',
Expand Down Expand Up @@ -259,6 +260,14 @@ def validate_lr_multistep_values(form, field):
],
)

# Radio-button field for picking a pretrained network.
# The choice list is declared empty here; it has to be assigned on the form
# instance before validation (selection_exists_in_choices is one of the validators).
# NOTE(review): validate_required_iff(method='pretrained') presumably makes the
# field mandatory only when the 'pretrained' method is selected -- confirm
# against the validator's definition.
pretrained_networks = wtforms.RadioField('Pretrained Networks',
choices = [],
validators = [
validate_required_iff(method='pretrained'),
selection_exists_in_choices,
],
)

custom_network = utils.forms.TextAreaField('Custom Network',
validators = [
validate_required_iff(method='custom'),
Expand Down
26 changes: 26 additions & 0 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from digits.config import config_value
from digits.dataset import ImageClassificationDatasetJob
from digits.inference import ImageInferenceJob
from digits.pretrained_model.job import PretrainedModelJob
from digits.status import Status
from digits.utils import filesystem as fs
from digits.utils.forms import fill_form_if_cloned, save_form_to_job
Expand Down Expand Up @@ -67,6 +68,7 @@ def new():
form.standard_networks.choices = get_standard_networks()
form.standard_networks.default = get_default_standard_network()
form.previous_networks.choices = get_previous_networks()
form.pretrained_networks.choices = get_pretrained_networks()

prev_network_snapshots = get_previous_network_snapshots()

Expand All @@ -78,6 +80,7 @@ def new():
frameworks = frameworks.get_frameworks(),
previous_network_snapshots = prev_network_snapshots,
previous_networks_fullinfo = get_previous_networks_fulldetails(),
pretrained_networks_fullinfo = get_pretrained_networks_fulldetails(),
multi_gpu = config_value('caffe_root')['multi_gpu'],
)

Expand All @@ -95,6 +98,7 @@ def create():
form.standard_networks.choices = get_standard_networks()
form.standard_networks.default = get_default_standard_network()
form.previous_networks.choices = get_previous_networks()
form.pretrained_networks.choices = get_pretrained_networks()

prev_network_snapshots = get_previous_network_snapshots()

Expand All @@ -110,6 +114,7 @@ def create():
frameworks = frameworks.get_frameworks(),
previous_network_snapshots = prev_network_snapshots,
previous_networks_fullinfo = get_previous_networks_fulldetails(),
pretrained_networks_fullinfo = get_pretrained_networks_fulldetails(),
multi_gpu = config_value('caffe_root')['multi_gpu'],
), 400

Expand Down Expand Up @@ -196,6 +201,14 @@ def create():
"Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break

elif form.method.data == 'pretrained':
pretrained_job = scheduler.get_job(form.pretrained_networks.data)
model_def_path = pretrained_job.get_model_def_path()
weights_path = pretrained_job.get_weights_path()

network = fw.get_network_from_path(model_def_path)
pretrained_model = weights_path

elif form.method.data == 'custom':
network = fw.get_network_from_desc(form.custom_network.data)
pretrained_model = form.custom_network_snapshot.data.strip()
Expand Down Expand Up @@ -700,3 +713,16 @@ def get_previous_network_snapshots():
prev_network_snapshots.append(e)
return prev_network_snapshots

def get_pretrained_networks():
    """
    Return a list of (job id, job name) tuples, one for each
    PretrainedModelJob held by the scheduler, ordered by descending job id
    """
    # key/reverse replaces the cmp= argument: cmp= is not accepted by
    # Python 3's sorted(), and key= evaluates id() once per job instead of
    # once per comparison. reverse=True keeps the sort stable, so the
    # resulting order matches the previous cmp-based descending sort.
    pretrained_jobs = sorted(
        (j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)),
        key=lambda j: j.id(),
        reverse=True
    )
    return [(j.id(), j.name()) for j in pretrained_jobs]

def get_pretrained_networks_fulldetails():
    """
    Return the PretrainedModelJob objects held by the scheduler as a list,
    ordered by descending job id
    """
    # key/reverse replaces the cmp= argument: cmp= is not accepted by
    # Python 3's sorted(), and key= evaluates id() once per job instead of
    # once per comparison. reverse=True keeps the sort stable, so the
    # resulting order matches the previous cmp-based descending sort.
    return sorted(
        (j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)),
        key=lambda j: j.id(),
        reverse=True
    )
27 changes: 26 additions & 1 deletion digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .forms import GenericImageModelForm
from .job import GenericImageModelJob
from digits.pretrained_model.job import PretrainedModelJob
from digits import extensions, frameworks, utils
from digits.config import config_value
from digits.dataset import GenericDatasetJob, GenericImageDatasetJob
Expand All @@ -33,7 +34,7 @@ def new(extension_id=None):
form.dataset.choices = get_datasets(extension_id)
form.standard_networks.choices = []
form.previous_networks.choices = get_previous_networks()

form.pretrained_networks.choices = get_pretrained_networks()
prev_network_snapshots = get_previous_network_snapshots()

## Is there a request to clone a job with ?clone=<job_id>
Expand All @@ -47,6 +48,7 @@ def new(extension_id=None):
frameworks=frameworks.get_frameworks(),
previous_network_snapshots=prev_network_snapshots,
previous_networks_fullinfo=get_previous_networks_fulldetails(),
pretrained_networks_fullinfo=get_pretrained_networks_fulldetails(),
multi_gpu=config_value('caffe_root')['multi_gpu'],
)

Expand All @@ -66,6 +68,7 @@ def create(extension_id=None):
form.dataset.choices = get_datasets(extension_id)
form.standard_networks.choices = []
form.previous_networks.choices = get_previous_networks()
form.pretrained_networks.choices = get_pretrained_networks()

prev_network_snapshots = get_previous_network_snapshots()

Expand All @@ -84,6 +87,7 @@ def create(extension_id=None):
frameworks=frameworks.get_frameworks(),
previous_network_snapshots=prev_network_snapshots,
previous_networks_fullinfo=get_previous_networks_fulldetails(),
pretrained_networks_fullinfo=get_pretrained_networks_fulldetails(),
multi_gpu=config_value('caffe_root')['multi_gpu'],
), 400

Expand Down Expand Up @@ -159,6 +163,13 @@ def create(extension_id=None):
raise werkzeug.exceptions.BadRequest(
"Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break
elif form.method.data == 'pretrained':
pretrained_job = scheduler.get_job(form.pretrained_networks.data)
model_def_path = pretrained_job.get_model_def_path()
weights_path = pretrained_job.get_weights_path()

network = fw.get_network_from_path(model_def_path)
pretrained_model = weights_path

elif form.method.data == 'custom':
network = fw.get_network_from_desc(form.custom_network.data)
Expand Down Expand Up @@ -638,6 +649,20 @@ def get_previous_network_snapshots():
return prev_network_snapshots


def get_pretrained_networks():
    """
    Return a list of (job id, job name) tuples, one for each
    PretrainedModelJob held by the scheduler, ordered by descending job id
    """
    # key/reverse replaces the cmp= argument: cmp= is not accepted by
    # Python 3's sorted(), and key= evaluates id() once per job instead of
    # once per comparison. reverse=True keeps the sort stable, so the
    # resulting order matches the previous cmp-based descending sort.
    pretrained_jobs = sorted(
        (j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)),
        key=lambda j: j.id(),
        reverse=True
    )
    return [(j.id(), j.name()) for j in pretrained_jobs]

def get_pretrained_networks_fulldetails():
    """
    Return the PretrainedModelJob objects held by the scheduler as a list,
    ordered by descending job id
    """
    # key/reverse replaces the cmp= argument: cmp= is not accepted by
    # Python 3's sorted(), and key= evaluates id() once per job instead of
    # once per comparison. reverse=True keeps the sort stable, so the
    # resulting order matches the previous cmp-based descending sort.
    return sorted(
        (j for j in scheduler.jobs.values() if isinstance(j, PretrainedModelJob)),
        key=lambda j: j.id(),
        reverse=True
    )

def get_view_extensions():
"""
return all enabled view extensions
Expand Down
35 changes: 22 additions & 13 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import caffe_pb2

# NOTE: Increment this every time the pickled object changes
PICKLE_VERSION = 4
PICKLE_VERSION = 5

# Constants
CAFFE_SOLVER_FILE = 'solver.prototxt'
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(self, **kwargs):
self.solver = None

self.solver_file = CAFFE_SOLVER_FILE
self.original_file = CAFFE_ORIGINAL_FILE
self.model_file = CAFFE_ORIGINAL_FILE
self.train_val_file = CAFFE_TRAIN_VAL_FILE
self.snapshot_prefix = CAFFE_SNAPSHOT_PREFIX
self.deploy_file = CAFFE_DEPLOY_FILE
Expand Down Expand Up @@ -140,6 +140,13 @@ def __setstate__(self, state):
# So you can't need this upgrade and we can ignore the error.
pass

if state['pickver_task_caffe_train'] <= 4:
if hasattr(self,"original_file"):
self.model_file = self.original_file
del self.original_file
else:
self.model_file = None

self.pickver_task_caffe_train = PICKLE_VERSION

# Make changes to self
Expand Down Expand Up @@ -256,7 +263,7 @@ def save_files_classification(self):
Save solver, train_val and deploy files to disk
"""
# Save the original network to file:
with open(self.path(self.original_file), 'w') as outfile:
with open(self.path(self.model_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

network = cleanedUpClassificationNetwork(self.network, len(self.get_labels()))
Expand Down Expand Up @@ -555,7 +562,7 @@ def save_files_generic(self):
assert train_feature_db_path is not None, 'Training images are required'

# Save the original network to file:
with open(self.path(self.original_file), 'w') as outfile:
with open(self.path(self.model_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

### Split up train_val and deploy layers
Expand Down Expand Up @@ -1084,13 +1091,14 @@ def get_task_stats(self,epoch=-1):
}

# These attributes only available in more recent jobs:
if hasattr(self,"original_file"):
stats.update({
"caffe flavor": self.caffe_flavor,
"caffe version": self.caffe_version,
"network file": self.original_file,
"digits version": self.digits_version
})
if hasattr(self,"model_file"):
if self.model_file is not None:
stats.update({
"caffe flavor": self.caffe_flavor,
"caffe version": self.caffe_version,
"model file": self.model_file,
"digits version": self.digits_version
})

if hasattr(self.dataset,"resize_mode"):
stats.update({"image resize mode": self.dataset.resize_mode})
Expand Down Expand Up @@ -1485,8 +1493,9 @@ def get_model_files(self):
"Network (train/val)": self.train_val_file,
"Network (deploy)": self.deploy_file
}
if hasattr(self,"original_file"):
model_files.update({"Network (original)": self.original_file})
if hasattr(self,"model_file"):
if self.model_file is not None:
model_files.update({"Network (original)": self.model_file})
return model_files

@override
Expand Down
11 changes: 10 additions & 1 deletion digits/model/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from . import images as model_images
from . import ModelJob
from digits.pretrained_model.job import PretrainedModelJob
from digits import frameworks, extensions
from digits.utils import time_filters
from digits.utils.routing import request_wants_json
Expand Down Expand Up @@ -76,13 +77,21 @@ def customize():
elif epoch == -1:
snapshot = job.train_task().pretrained_model
else:

for filename, e in job.train_task().snapshots:
if e == epoch:
snapshot = job.path(filename)
break

if isinstance(job,PretrainedModelJob):
model_def = open(job.get_model_def_path(),'r')
network = model_def.read()
snapshot = job.get_weights_path()
else:
network = job.train_task().get_network_desc()

return json.dumps({
'network': job.train_task().get_network_desc(),
'network': network,
'snapshot': snapshot
})

Expand Down
3 changes: 3 additions & 0 deletions digits/pretrained_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import
from .job import PretrainedModelJob
57 changes: 57 additions & 0 deletions digits/pretrained_model/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

from . import tasks
import digits.frameworks
from digits.job import Job
from digits.utils import subclass, override
from digits.pretrained_model.tasks import UploadPretrainedModelTask

@subclass
class PretrainedModelJob(Job):
    """
    Job wrapping a single task that uploads a pretrained model
    (weights, model definition and optional labels)
    """

    def __init__(self, weights_path, model_def_path, labels_path=None,framework="caffe",**kwargs):
        """
        Arguments:
        weights_path -- forwarded to UploadPretrainedModelTask
        model_def_path -- forwarded to UploadPretrainedModelTask

        Keyword arguments:
        labels_path -- forwarded to the task; has_labels records whether it was given
        framework -- framework name, "caffe" selects the caffe file names
        """
        super(PretrainedModelJob, self).__init__(persistent = False, **kwargs)

        self.has_labels = labels_path is not None
        self.framework = framework
        self.tasks = []

        upload_task = UploadPretrainedModelTask(
            weights_path,
            model_def_path,
            labels_path,
            framework,
            job_dir=self.dir()
        )
        self.tasks.append(upload_task)

    def get_weights_path(self):
        """
        Return the path of the weights file inside the job directory
        """
        filename = "model.caffemodel" if self.framework == "caffe" else "_Model.t7"
        return self.dir() + "/" + filename

    def get_model_def_path(self):
        """
        Return the path of the model definition file inside the job directory
        """
        filename = "original.prototxt" if self.framework == "caffe" else "original.lua"
        return self.dir() + "/" + filename

    @override
    def job_type(self):
        return "Pretrained Model"

    @override
    def __getstate__(self):
        """
        Return only the whitelisted fields of the parent class's state
        """
        parent_state = super(PretrainedModelJob, self).__getstate__()
        kept_fields = ('_id', '_name', 'username', 'tasks', 'status_history', 'has_labels', 'framework')
        return {name: parent_state[name] for name in kept_fields}

    @override
    def __setstate__(self, state):
        super(PretrainedModelJob, self).__setstate__(state)
3 changes: 3 additions & 0 deletions digits/pretrained_model/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import
from .upload_pretrained import UploadPretrainedModelTask
Loading

0 comments on commit 89bbd5f

Please sign in to comment.