"""Interactive command-line front end for the heat-equation PINN.

All numerical work (collocation data, training, evaluation, plotting) is
delegated to the project-local ``engine`` module; this script only renders
the text menu and reads user input.
"""
import math
import sys

import engine

DEFAULT_EPOCHS = 5000


def print_header():
    """Print the program banner describing the PDE and boundary conditions."""
    print("=" * 55)
    print(" Heat Equation PINN — ∂T/∂t = α ∂²T/∂x²")
    print(" Robin BC (x=0, x=L) + point source @ X_SRC")
    print("=" * 55)


def _ask_float(prompt, default):
    """Prompt for a float and return it, or *default* on invalid input.

    Blank, unparsable, and non-finite entries ("nan", "inf") all yield
    *default*.
    """
    raw = input(prompt).strip()
    try:
        value = float(raw)
    except ValueError:
        return default
    # float() happily parses "nan"/"inf"; treat them as invalid input.
    return value if math.isfinite(value) else default


def _ask_int(prompt, default, minimum=0):
    """Prompt for an integer and return it, or *default* on invalid input.

    Args:
        prompt: Text shown to the user.
        default: Value returned when the entry is blank, not an integer,
            or smaller than *minimum*.
        minimum: Smallest accepted value. The default of 0 keeps the
            previous behaviour of rejecting negative numbers.
    """
    raw = input(prompt).strip()
    try:
        value = int(raw)
    except ValueError:
        # Parse directly rather than pre-checking with str.isdigit(): that
        # returns True for characters such as "²" which int() rejects.
        return default
    return value if value >= minimum else default


def _print_menu():
    """Print the main-menu options."""
    print("\n" + "-" * 30)
    print(" MAIN MENU")
    print("-" * 30)
    print("1. Train New Model")
    print("2. Evaluate Model (L2 vs analytical)")
    print("3. Visualize Temperature Field")
    print("0. Exit")
    print("-" * 30)


def main_menu():
    """Prepare the data once, then loop over the interactive menu.

    The object returned by ``engine.prepare_data()`` is passed unchanged to
    every engine action. This function never returns normally: it ends the
    process via ``sys.exit`` with status 0 when the user picks "0" or stdin
    reaches end-of-file, and with status 130 on Ctrl-C.
    """
    print("\nInitializing collocation points...")
    data = engine.prepare_data()
    print(f"Ready — device: {data['device']}\n")

    try:
        while True:
            _print_menu()
            choice = input("Select an option (0-3): ").strip()

            if choice == '1':
                # Zero or negative epoch counts fall back to the default.
                epochs = _ask_int(
                    f"Epochs (default {DEFAULT_EPOCHS}): ",
                    DEFAULT_EPOCHS,
                    minimum=1,
                )
                engine.train_model(data, epochs=epochs)
            elif choice == '2':
                engine.evaluate_model(data)
            elif choice == '3':
                engine.generate_visualization(data)
            elif choice == '0':
                print("Exiting.")
                sys.exit(0)
            else:
                print("Invalid choice.")
    except EOFError:
        # stdin was closed (e.g. piped input ran out): leave without a traceback.
        print("\nExiting.")
        sys.exit(0)
    except KeyboardInterrupt:
        # Conventional shell status for termination by SIGINT (128 + 2).
        print("\nInterrupted.")
        sys.exit(130)


if __name__ == "__main__":
    print_header()
    main_menu()