Integrating physical units into high-performance AI-driven scientific computing
개요
특성
- 물리 단위를 코드에 네이티브로 통합하여, 연산 시점에 차원 검증을 수행하기 위해 기저 단위(dims) 직접 정의하고, 연산 시점의 차원 검증 (
__add__/__sub__에서 단위 맞춤) → 단위 오류를 컴파일 전(실행 직후)에 잡아내며, 조합 연산 (__mul__, __truediv__, __pow__) → 물리식 그대로 구현하여, 복잡한 물리 모델에도 적용 가능
from collections import Counter
# ——— 1. 단위(Unit) 클래스 정의 ———
class Unit:
def __init__(self, dims=None):
# dims: {'L':1, 'T':-1} 같은 형태로 길이(L), 시간(T), 질량(M) 등 기저 단위 지수 표현
self.dims = Counter(dims or {})
def __mul__(self, other):
return Unit(self.dims + other.dims)
def __truediv__(self, other):
return Unit(self.dims - other.dims)
def __pow__(self, power):
return Unit({k: v * power for k, v in self.dims.items()})
def __eq__(self, other):
# 0이 아닌 지수만 비교
return {k: v for k, v in self.dims.items() if v != 0} == \
{k: v for k, v in other.dims.items() if v != 0}
def __repr__(self):
if not self.dims:
return "1"
parts = []
for base, exp in sorted(self.dims.items()):
if exp:
parts.append(f"{base}" + (f"^{exp}" if exp != 1 else ""))
return "·".join(parts)
# ——— 2. 물리량(Quantity) 클래스 정의 ———
class Quantity:
def __init__(self, value, unit: Unit):
self.value = value
self.unit = unit
def __add__(self, other):
if self.unit != other.unit:
raise ValueError(f"단위 불일치: {self.unit} vs {other.unit}")
return Quantity(self.value + other.value, self.unit)
def __sub__(self, other):
if self.unit != other.unit:
raise ValueError(f"단위 불일치: {self.unit} vs {other.unit}")
return Quantity(self.value - other.value, self.unit)
def __mul__(self, other):
return Quantity(self.value * other.value, self.unit * other.unit)
def __truediv__(self, other):
return Quantity(self.value / other.value, self.unit / other.unit)
def __pow__(self, power):
return Quantity(self.value ** power, self.unit ** power)
def __repr__(self):
return f"{self.value} [{self.unit}]"
# ——— 3. 기저 단위 정의 ———
L = Unit({'L': 1}) # 길이 (Length)
T = Unit({'T': 1}) # 시간 (Time)
M = Unit({'M': 1}) # 질량 (Mass)
# ——— 4. 예제: 속도와 운동 에너지 계산 ———
if __name__ == "__main__":
# 100 미터, 9.58 초(우사인 볼트 기록), 80 kg
distance = Quantity(100, L)
time = Quantity(9.58, T)
mass = Quantity(80, M)
# 속도 = 거리 / 시간
velocity = distance / time
print("속도:", velocity)
# -> 약 10.438 [L·T^-1] (m/s)
# 운동 에너지 E = 0.5 * m * v^2
energy = Quantity(0.5, Unit()) * mass * (velocity ** 2)
print("운동 에너지:", energy)
# -> 단위: [M·L^2·T^-2] 즉 줄(J)
# 단위 불일치 예시: 거리 + 시간 시도하면 오류 발생
try:
bad = distance + time
except ValueError as e:
print("오류:", e)
# 출력 결과 예시:
# 속도: 10.438413361169102 [L·T^-1]
# 운동 에너지: 4361.325157412191 [M·L^2·T^-2]
# 오류: 단위 불일치: L vs T