fix func signature (#10271)
This commit is contained in:
@@ -146,7 +146,7 @@ if is_torch_available():
|
|||||||
self.double_output = double_output
|
self.double_output = double_output
|
||||||
self.config = None
|
self.config = None
|
||||||
|
|
||||||
def forward(self, input_x=None, labels=None, **kwargs):
|
def forward(self, input_x, labels=None, **kwargs):
|
||||||
y = input_x * self.a + self.b
|
y = input_x * self.a + self.b
|
||||||
if labels is None:
|
if labels is None:
|
||||||
return (y, y) if self.double_output else (y,)
|
return (y, y) if self.double_output else (y,)
|
||||||
@@ -160,7 +160,7 @@ if is_torch_available():
|
|||||||
self.b = torch.nn.Parameter(torch.tensor(b).float())
|
self.b = torch.nn.Parameter(torch.tensor(b).float())
|
||||||
self.config = None
|
self.config = None
|
||||||
|
|
||||||
def forward(self, input_x=None, labels=None, **kwargs):
|
def forward(self, input_x, labels=None, **kwargs):
|
||||||
y = input_x * self.a + self.b
|
y = input_x * self.a + self.b
|
||||||
result = {"output": y}
|
result = {"output": y}
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -177,7 +177,7 @@ if is_torch_available():
|
|||||||
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
|
self.b = torch.nn.Parameter(torch.tensor(config.b).float())
|
||||||
self.double_output = config.double_output
|
self.double_output = config.double_output
|
||||||
|
|
||||||
def forward(self, input_x=None, labels=None, **kwargs):
|
def forward(self, input_x, labels=None, **kwargs):
|
||||||
y = input_x * self.a + self.b
|
y = input_x * self.a + self.b
|
||||||
if labels is None:
|
if labels is None:
|
||||||
return (y, y) if self.double_output else (y,)
|
return (y, y) if self.double_output else (y,)
|
||||||
|
|||||||
Reference in New Issue
Block a user