Added additional check for url and path in discriminator model params
This commit is contained in:
committed by
Julien Chaumond
parent
f10b925015
commit
f42816e7fc
@@ -317,8 +317,11 @@ def get_classifier(
|
|||||||
).to(device)
|
).to(device)
|
||||||
if "url" in params:
|
if "url" in params:
|
||||||
resolved_archive_file = cached_path(params["url"])
|
resolved_archive_file = cached_path(params["url"])
|
||||||
else:
|
elif "path" in params:
|
||||||
resolved_archive_file = params["path"]
|
resolved_archive_file = params["path"]
|
||||||
|
else:
|
||||||
|
raise ValueError("Either url or path have to be specified "
|
||||||
|
"in the discriminator model parameters")
|
||||||
classifier.load_state_dict(
|
classifier.load_state_dict(
|
||||||
torch.load(resolved_archive_file, map_location=device))
|
torch.load(resolved_archive_file, map_location=device))
|
||||||
classifier.eval()
|
classifier.eval()
|
||||||
|
|||||||
Reference in New Issue
Block a user