Added additional check for url and path in discriminator model params

This commit is contained in:
w4nderlust
2019-11-29 20:00:43 -08:00
committed by Julien Chaumond
parent f10b925015
commit f42816e7fc

View File

@@ -317,8 +317,11 @@ def get_classifier(
).to(device) ).to(device)
if "url" in params: if "url" in params:
resolved_archive_file = cached_path(params["url"]) resolved_archive_file = cached_path(params["url"])
else: elif "path" in params:
resolved_archive_file = params["path"] resolved_archive_file = params["path"]
else:
raise ValueError("Either url or path have to be specified "
"in the discriminator model parameters")
classifier.load_state_dict( classifier.load_state_dict(
torch.load(resolved_archive_file, map_location=device)) torch.load(resolved_archive_file, map_location=device))
classifier.eval() classifier.eval()