diff --git a/model.py b/model.py index 47bfb5e..5f003a2 100644 --- a/model.py +++ b/model.py @@ -14,8 +14,7 @@ class HeatPINN(nn.Module): self.net = nn.Sequential(*layers) def forward(self, xt): - # Normalize t to [0,1] to avoid Tanh saturation for large t values (up to T_END) - x = xt[:, :1] + x = xt[:, :1] / config.L t = xt[:, 1:] / config.T_END return config.T_AMB + (config.Q_VAL * config.L / config.K) * self.net(torch.cat([x, t], dim=1))