Added additional check for url and path in discriminator model params

This commit is contained in:
w4nderlust
2019-11-29 20:00:43 -08:00
committed by Julien Chaumond
parent f10b925015
commit f42816e7fc

View File

@@ -317,8 +317,11 @@ def get_classifier(
).to(device)
if "url" in params:
resolved_archive_file = cached_path(params["url"])
else:
elif "path" in params:
resolved_archive_file = params["path"]
else:
raise ValueError("Either url or path have to be specified "
"in the discriminator model parameters")
classifier.load_state_dict(
torch.load(resolved_archive_file, map_location=device))
classifier.eval()