63 lines
1.4 KiB
Python
63 lines
1.4 KiB
Python
import sys
|
||
import engine
|
||
|
||
|
||
def print_header():
    """Print the program banner: two title lines framed by '=' rules."""
    rule = "=" * 55
    banner = (
        rule,
        " Heat Equation PINN — ∂T/∂t = α ∂²T/∂x²",
        " Neumann BC (x=0) + Robin BC (x=L)",
        rule,
    )
    # One print of the newline-joined lines emits exactly the same text as
    # one print per line.
    print("\n".join(banner))
|
||
|
||
|
||
def _ask_float(prompt, default):
    """Ask the user for a floating-point value.

    Args:
        prompt: Text shown by ``input()``.
        default: Value returned when the reply is not a valid float literal
            (this includes an empty reply).

    Returns:
        ``float()`` of the whitespace-stripped reply, or *default*.
    """
    reply = input(prompt).strip()
    try:
        result = float(reply)
    except ValueError:
        result = default
    return result
|
||
|
||
|
||
def _ask_int(prompt, default):
    """Ask the user for a non-negative integer, falling back to *default*.

    The reply is whitespace-stripped and parsed with ``int()``, so forms
    such as ``"+8"`` or ``"1_000"`` are accepted. A reply that does not
    parse, or that parses to a negative number, yields *default*.

    Args:
        prompt: Text shown by ``input()``.
        default: Value returned when the reply is unusable.

    Returns:
        The parsed non-negative int, or *default*.
    """
    val = input(prompt).strip()
    # Parse with int() rather than pre-checking str.isdigit(): isdigit() is
    # True for characters such as "²" that int() rejects, so the pre-check
    # let those through to int() and crashed with ValueError.
    try:
        num = int(val)
    except ValueError:
        return default
    # Negative replies were never accepted by this prompt; keep that rule.
    return num if num >= 0 else default
|
||
|
||
|
||
def main_menu():
    """Run the interactive text menu.

    Calls ``engine.prepare_data()`` once and reports ``data['device']``.
    It then loops, passing that same ``data`` object to
    ``engine.train_model``, ``engine.evaluate_model`` or
    ``engine.generate_visualization`` according to the user's choice.
    Choosing ``0`` prints "Exiting." and calls ``sys.exit(0)``. So does
    closing stdin (Ctrl-D, or piped input running out) while a menu prompt
    is waiting.
    """
    print("\nInitializing collocation points...")
    data = engine.prepare_data()
    print(f"Ready — device: {data['device']}\n")

    while True:
        print("\n" + "-" * 30)
        print(" MAIN MENU")
        print("-" * 30)
        print("1. Train New Model")
        print("2. Evaluate Model (L2 vs analytical)")
        print("3. Visualize Temperature Field")
        print("0. Exit")
        print("-" * 30)

        # Only the prompts are inside this try. If input() hits end-of-file,
        # stdin is gone and the EOFError would otherwise escape as a
        # traceback; treat it as the "0. Exit" choice. EOFErrors raised by
        # the engine calls below are deliberately not caught.
        try:
            choice = input("Select an option (0-3): ").strip()
            if choice == '1':
                epochs = _ask_int("Epochs (default 5000): ", 5000)
        except EOFError:
            print()  # end the half-written prompt line before "Exiting."
            choice = '0'

        if choice == '1':
            engine.train_model(data, epochs=epochs)
        elif choice == '2':
            engine.evaluate_model(data)
        elif choice == '3':
            engine.generate_visualization(data)
        elif choice == '0':
            print("Exiting.")
            sys.exit(0)
        else:
            print("Invalid choice.")
|
||
|
||
|
||
if __name__ == "__main__":
    print_header()
    try:
        main_menu()
    except KeyboardInterrupt:
        # Ctrl-C (for example during a long training run) should end the
        # program quietly instead of dumping a traceback. 130 = 128 + SIGINT
        # is the conventional exit status for an interrupted process.
        print("\nInterrupted.")
        sys.exit(130)
|