63 lines
1.4 KiB
Python
63 lines
1.4 KiB
Python
|
|
import sys
|
|||
|
|
import engine
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_header():
    """Print the program banner: the PDE being solved and its boundary conditions."""
    rule = "=" * 55
    banner = (
        rule,
        " Heat Equation PINN — ∂T/∂t = α ∂²T/∂x²",
        " Neumann BC (x=0) + Robin BC (x=L)",
        rule,
    )
    # One print() per line, exactly as a sequence of individual print calls.
    for line in banner:
        print(line)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _ask_float(prompt, default):
    """Prompt the user for a floating-point value, falling back to *default*.

    Args:
        prompt: Text shown to the user by ``input()``.
        default: Value returned when the reply is empty, is not a valid float
            literal, or parses to a non-finite value (NaN or +/-infinity).

    Returns:
        The parsed finite float, or *default*.
    """
    import math  # local import keeps this helper self-contained

    val = input(prompt).strip()
    try:
        number = float(val)
    except ValueError:
        return default
    # float() also accepts "nan", "inf" and "infinity"; a non-finite number is
    # not a usable parameter value, so treat it like any other bad reply.
    return number if math.isfinite(number) else default
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _ask_int(prompt, default):
    """Prompt the user for a non-negative integer, falling back to *default*.

    Args:
        prompt: Text shown to the user by ``input()``.
        default: Value returned when the reply is empty, is not a valid
            integer literal, or is negative.

    Returns:
        The parsed non-negative int, or *default*.
    """
    val = input(prompt).strip()
    # EAFP instead of str.isdigit(): isdigit() is True for characters such as
    # "²" or "①" that int() rejects, which used to crash with ValueError.
    try:
        number = int(val)
    except ValueError:
        return default
    # Only non-negative counts are accepted (e.g. epochs); anything below zero
    # falls back to the default.
    return number if number >= 0 else default
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _print_main_menu():
    """Print the framed list of main-menu options."""
    print("\n" + "-" * 30)
    print(" MAIN MENU")
    print("-" * 30)
    print("1. Train New Model")
    print("2. Evaluate Model (L2 vs analytical)")
    print("3. Visualize Temperature Field")
    print("0. Exit")
    print("-" * 30)


def main_menu():
    """Run the interactive main menu until the user exits.

    The collocation data is prepared once with ``engine.prepare_data()`` and
    reused by every action. It is indexed with a ``'device'`` key for the
    status line. Each iteration shows the menu and dispatches the choice:

    * ``1``: ask for an epoch count (default 5000) and call
      ``engine.train_model``;
    * ``2``: call ``engine.evaluate_model``;
    * ``3``: call ``engine.generate_visualization``;
    * ``0``: print "Exiting." and terminate with ``sys.exit(0)``.

    Any other reply prints "Invalid choice." and shows the menu again. If an
    ``EOFError`` reaches the loop (stdin closed, e.g. Ctrl-D or the end of
    piped input), the program also exits with status 0 instead of dumping a
    traceback.
    """
    print("\nInitializing collocation points...")
    data = engine.prepare_data()
    print(f"Ready — device: {data['device']}\n")

    try:
        while True:
            _print_main_menu()
            choice = input("Select an option (0-3): ").strip()

            if choice == '1':
                epochs = _ask_int("Epochs (default 5000): ", 5000)
                engine.train_model(data, epochs=epochs)
            elif choice == '2':
                engine.evaluate_model(data)
            elif choice == '3':
                engine.generate_visualization(data)
            elif choice == '0':
                print("Exiting.")
                sys.exit(0)
            else:
                print("Invalid choice.")
    except EOFError:
        # No more input can arrive, so further prompts are pointless. Finish
        # the dangling prompt line and leave the same way option 0 does.
        print("\nExiting.")
        sys.exit(0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: show the banner once, then hand control to the
    # interactive loop, which exits via sys.exit(0) when the user picks 0.
    print_header()
    main_menu()
|