import json
import numpy as np
import sapien
from sapien import internal_renderer as R
from .plugin import Plugin
class Path:
    """A trajectory owned by one scene entity: the entity plus an ordered list of poses."""

    def __init__(self, entity):
        # The entity this path belongs to; its ``.name`` is used as the path label.
        self.entity = entity
        # Ordered waypoints (pose objects with ``.p`` / ``.q``); a fresh list per path.
        self.poses = []
class PathWindow(Plugin):
    """Viewer plugin for authoring pose trajectories ("paths") for scene entities.

    Each path belongs to one entity and holds an ordered list of poses captured
    from the TransformWindow gizmo. Positions are interpolated with a parametric
    spline and orientations with Slerp; the resulting curve is drawn as a line
    set in the render scene and can be saved to / loaded from JSON.
    ``self.viewer`` is presumably provided by the Plugin base class / viewer.
    """

    # Minimum distance between neighbouring path points. Coincident neighbours
    # give the spline parameterization repeated values, which scipy's
    # splprep/Slerp reject, so such points are refused up front.
    MIN_POINT_DISTANCE = 1e-4

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all paths, UI handles and selection state.

        The drawn-curve handle is forgotten but the render node itself is not
        removed here.
        """
        self._paths = []
        self._current_path_index = -1
        self.ui_window = None
        self.file_chooser_load = None
        self.file_chooser_save = None
        self._selected_point_index = -1
        self._curve = None
        self._curve_time = 0
        self.curve_segments = 128

    @property
    def transform_window(self):
        """The viewer's TransformWindow plugin, whose gizmo supplies path poses.

        Raises:
            RuntimeError: if the viewer has no TransformWindow plugin.
        """
        from .transform_window import TransformWindow

        tw = next(
            (p for p in self.viewer.plugins if isinstance(p, TransformWindow)), None
        )
        if tw is None:
            # A bare next() would leak StopIteration, which callers can mistake
            # for end-of-iteration; report the real problem instead.
            raise RuntimeError("PathWindow requires a TransformWindow plugin")
        return tw

    @property
    def scene(self):
        return self.viewer.scene

    @property
    def paths(self):
        """Entity names of all paths; bound as the items of the path selector."""
        return [p.entity.name for p in self._paths]

    @paths.setter
    def paths(self, _):
        # Intentionally a no-op: the list is derived from self._paths. The
        # setter presumably exists only so the UI item binding can write back.
        pass

    def add_path(self, _):
        """Create a path for the selected entity, select it, and seed its first point."""
        entity = self.viewer.selected_entity
        if entity is None:
            return
        self._paths.append(Path(entity))
        self.current_path_index = len(self._paths) - 1
        self.add_point(None)

    def remove_path(self, _):
        """Delete the current path and select the one now at its position (or the last)."""
        index = self._current_path_index
        if not 0 <= index < len(self._paths):
            return
        del self._paths[index]
        # Assign through the property even if the index is unchanged so the
        # drawn curve is refreshed for the newly current path (-1 if none left).
        self.current_path_index = min(index, len(self._paths) - 1)

    @property
    def current_path_index(self):
        return self._current_path_index

    @current_path_index.setter
    def current_path_index(self, p):
        self._current_path_index = p
        self.show_curve()

    @property
    def current_path(self) -> "Path | None":
        """The selected Path, or None when no valid path is selected."""
        if 0 <= self._current_path_index < len(self._paths):
            return self._paths[self._current_path_index]
        return None

    # ------------------------------------------------------------------ helpers

    def _editable_path(self):
        """Return the current path if the selected entity owns it; else print why and return None."""
        path = self.current_path
        if path is None:
            print("no path selected")
            return None
        if self.viewer.selected_entity != path.entity:
            print("selected entity is not path entity")
            return None
        return path

    def _current_gizmo_pose(self):
        """Enable the TransformWindow if needed and return its current gizmo pose."""
        tw = self.transform_window
        if not tw.enabled:
            tw.enabled = True
        return tw._gizmo_pose

    def _too_close(self, pose, others):
        """Whether ``pose`` lies within MIN_POINT_DISTANCE of any pose in ``others``."""
        return any(
            np.linalg.norm(other.p - pose.p) < self.MIN_POINT_DISTANCE
            for other in others
        )

    @staticmethod
    def _is_valid_point_index(path, index):
        """Whether ``index`` addresses an existing point of ``path`` (-1 means no selection)."""
        return 0 <= index < len(path.poses)

    # ------------------------------------------------------------ point editing

    def add_point(self, _):
        """Append the gizmo pose to the current path and select the new point."""
        path = self._editable_path()
        if path is None:
            return
        pose = self._current_gizmo_pose()
        # poses[-1:] is empty for a new path, so the first point is always accepted.
        if self._too_close(pose, path.poses[-1:]):
            print("path points too close!!!")
            return
        path.poses.append(pose)
        self.show_curve()
        self.selected_point_index = len(path.poses) - 1

    def insert_point(self, _):
        """Insert the gizmo pose right after the selected point and select it.

        Without a valid selection (e.g. a stale index after switching paths)
        the pose is appended at the end instead.
        """
        path = self._editable_path()
        if path is None:
            return
        pose = self._current_gizmo_pose()
        index = self.selected_point_index
        if not self._is_valid_point_index(path, index):
            index = len(path.poses) - 1  # -1 for an empty path -> insert at 0
        # The new point will sit between poses[index] and poses[index + 1];
        # it must not coincide with either neighbour.
        if self._too_close(pose, path.poses[index:index + 2] if index >= 0 else []):
            print("path points too close!!!")
            return
        path.poses.insert(index + 1, pose)
        self.show_curve()
        self.selected_point_index = index + 1

    def set_point(self, _):
        """Overwrite the selected point with the gizmo pose."""
        path = self._editable_path()
        if path is None:
            return
        pose = self._current_gizmo_pose()
        index = self.selected_point_index
        if not self._is_valid_point_index(path, index):
            print("no point selected")
            return
        neighbours = path.poses[max(index - 1, 0):index] + path.poses[index + 1:index + 2]
        if self._too_close(pose, neighbours):
            print("path points too close!!!")
            return
        path.poses[index] = pose
        self.show_curve()

    def del_point(self, _):
        """Remove the selected point and select the first remaining one (if any)."""
        path = self._editable_path()
        if path is None:
            return
        index = self.selected_point_index
        if not self._is_valid_point_index(path, index):
            print("no point selected")
            return
        del path.poses[index]
        # -1 once the path is empty: there is no point left to move the gizmo to.
        self.selected_point_index = 0 if path.poses else -1
        self.show_curve()

    @property
    def selected_point_index(self):
        return self._selected_point_index

    @selected_point_index.setter
    def selected_point_index(self, index):
        self._selected_point_index = index
        path = self.current_path
        # Only move the gizmo when the index addresses an existing point.
        if path is not None and self._is_valid_point_index(path, index):
            self.transform_window.gizmo_matrix = path.poses[
                index
            ].to_transformation_matrix()

    @property
    def points(self):
        """Labels for the current path's points; bound as the items of the point selector."""
        path = self.current_path
        if path is None:
            return []
        return [f"Point {i}" for i in range(len(path.poses))]

    # -------------------------------------------------------------------- curve

    def _fit_curve(self):
        """Fit the current path.

        Returns ``(tck, slerp)``: a parametric spline over the positions and a
        Slerp over the orientations, both parameterized on [0, 1]. Returns None
        when there is no path or it has fewer than two points.
        """
        from scipy.interpolate import splprep
        from scipy.spatial.transform import Rotation, Slerp

        path = self.current_path
        if path is None or len(path.poses) <= 1:
            return None
        points = np.array([pose.p for pose in path.poses])
        # Pose quaternions are (w, x, y, z); scipy expects scalar-last (x, y, z, w).
        scipy_quats = np.array([pose.q[[1, 2, 3, 0]] for pose in path.poses])
        # Interpolating spline (s=0); degree capped so it never exceeds len - 1.
        tck, u = splprep(
            [points[:, 0], points[:, 1], points[:, 2]], s=0, k=min(3, len(points) - 1)
        )
        # Reuse the spline's parameter values so orientation and position share t.
        return tck, Slerp(u, Rotation.from_quat(scipy_quats))

    def get_curve(self, knots=128):
        """Sample the current path's curve at ``knots`` evenly spaced t in [0, 1].

        Returns ``(ts, points, quats)`` with quaternions in (w, x, y, z) order,
        or ``(None, None, None)`` when the path has fewer than two points.
        """
        from scipy.interpolate import splev

        fit = self._fit_curve()
        if fit is None:
            return None, None, None
        tck, slerp = fit
        ts = np.linspace(0, 1, knots)
        points = np.stack(splev(ts, tck), axis=-1)
        quats = slerp(ts).as_quat()[:, [3, 0, 1, 2]]
        return ts, points, quats

    def eval_curve(self, ts):
        """Evaluate the current path's curve at parameter(s) ``ts`` in [0, 1].

        ``ts`` may be a scalar or a sequence. Returns ``(points, quats)`` as
        (len(ts), 3) and (len(ts), 4) arrays with quaternions in (w, x, y, z)
        order, or ``(None, None)`` when the path has fewer than two points.
        """
        from scipy.interpolate import splev

        fit = self._fit_curve()
        if fit is None:
            return None, None
        tck, slerp = fit
        ts = np.atleast_1d(np.asarray(ts, dtype=float))
        points = np.stack(splev(ts, tck), axis=-1)
        quats = slerp(ts).as_quat()[:, [3, 0, 1, 2]]
        return points, quats

    def show_curve(self, _=None):
        """(Re)draw the current path's curve as a line set in the render scene.

        Any previously drawn curve is removed first. Nothing is drawn when the
        viewer has no render scene or the path has fewer than two points.
        """
        render_scene = self.viewer.render_scene
        if not render_scene:
            self._curve = None
            return
        self.hide_curve()
        # Segment count comes from a free-form UI integer; at least two samples
        # are needed to form a segment (negative counts would make linspace raise).
        _ts, points, _quats = self.get_curve(max(2, int(self.curve_segments)))
        if points is None:
            return
        # Each consecutive pair of samples becomes one independent segment,
        # flattened as p0, p1, p1, p2, ... with (x, y, z) per vertex.
        segs = np.stack([points[:-1], points[1:]], axis=1).reshape(-1).tolist()
        # One (0, 0, 1, 1) color per vertex — presumably RGBA blue.
        colors = [0, 0, 1, 1] * (len(segs) // 3)
        self._curve = render_scene.add_line_set(
            self.viewer.renderer_context.create_line_set(segs, colors)
        )

    def hide_curve(self, _=None):
        """Remove the drawn curve from the render scene, if there is one."""
        render_scene = self.viewer.render_scene
        if not render_scene or self._curve is None:
            return
        render_scene.remove_node(self._curve)
        self._curve = None

    @property
    def curve_time(self):
        """Curve parameter in [0, 1]; setting it moves the gizmo onto the curve."""
        return self._curve_time

    @curve_time.setter
    def curve_time(self, t):
        self._curve_time = t
        ps, qs = self.eval_curve([t])
        if ps is None:
            return
        self.transform_window.gizmo_matrix = sapien.Pose(
            ps[0], qs[0]
        ).to_transformation_matrix()

    # ---------------------------------------------------------------- save/load

    def save(self, _=None):
        """Open the save file chooser; the chosen file is handled by save_confirm."""
        self.file_chooser_save.open()

    def save_confirm(self, _, name, path):
        """Write the current path to the JSON file ``name``.

        File layout: ``{"version": 0, "name": <entity name>, "trajectory":
        [[px, py, pz, qw, qx, qy, qz], ...]}``. ``path`` is unused.
        """
        current = self.current_path
        if current is None:
            print("no path selected")
            return
        if not name.endswith(".json"):
            print("path file must end with .json")
            return
        data = {
            "version": 0,
            "name": current.entity.name,
            "trajectory": [
                [float(x) for x in p.p] + [float(x) for x in p.q]
                for p in current.poses
            ],
        }
        with open(name, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def load(self, _=None):
        """Open the load file chooser; the chosen file is handled by load_confirm."""
        self.file_chooser_load.open()

    def load_confirm(self, _, name, path):
        """Replace the current path's poses with those stored in JSON file ``name``.

        A mismatching entity name only prints a warning; the poses are still
        loaded. ``path`` is unused.
        """
        current = self.current_path
        if current is None:
            print("no path selected")
            return
        if not name.endswith(".json"):
            print("path file must end with .json")
            return
        with open(name, "r", encoding="utf-8") as f:
            data = json.load(f)
        if data.get("version") != 0:
            print("unsupported path file version")
            return
        if data["name"] != current.entity.name:
            print("name does not match!")
        current.poses = [sapien.Pose(p[:3], p[3:]) for p in data["trajectory"]]
        self.show_curve()

    # ----------------------------------------------------------------------- UI

    def build(self):
        """Create the UI window on first use; clear it while there is no scene."""
        if self.scene is None:
            self.ui_window = None
            return
        if self.ui_window is None:
            self.file_chooser_save = R.UIFileChooser().Label("Save")
            self.file_chooser_load = R.UIFileChooser().Label("Load")
            self.ui_window = (
                R.UIWindow()
                .Label("Path")
                .append(
                    R.UISameLine().append(
                        R.UIOptions()
                        .Id("paths")
                        .Style("select")
                        .BindItems(self, "paths")
                        .BindIndex(self, "current_path_index"),
                        R.UIButton().Label("+").Width(40).Callback(self.add_path),
                        R.UIButton().Label("-").Width(40).Callback(self.remove_path),
                    ),
                    R.UIOptions()
                    .Style("radio")
                    .BindItems(self, "points")
                    .BindIndex(self, "selected_point_index"),
                    R.UIConditional()
                    .Bind(lambda: self.current_path is not None)
                    .append(
                        R.UISameLine().append(
                            R.UIButton().Label("Add Point").Callback(self.add_point),
                            R.UIButton()
                            .Label("Insert Point")
                            .Callback(self.insert_point),
                            R.UIButton().Label("Set Point").Callback(self.set_point),
                            R.UIButton().Label("Del Point").Callback(self.del_point),
                        ),
                        R.UISliderFloat()
                        .Label("t")
                        .Min(0)
                        .Max(1)
                        .Bind(self, "curve_time"),
                        R.UIInputInt().Label("Segments").Bind(self, "curve_segments"),
                        R.UISameLine().append(
                            R.UIButton().Label("Save").Callback(self.save),
                            self.file_chooser_save.Filter(".json").Callback(
                                self.save_confirm
                            ),
                            R.UIButton().Label("Load").Callback(self.load),
                            self.file_chooser_load.Filter(".json").Callback(
                                self.load_confirm
                            ),
                        ),
                    ),
                )
            )

    def get_ui_windows(self):
        """Build the UI if needed and return it as a list (empty without a scene)."""
        self.build()
        if self.ui_window:
            return [self.ui_window]
        return []

    def close(self):
        """Reset all plugin state."""
        self.reset()