Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse and generate style for SVG #1717

Merged
merged 8 commits into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions manimlib/mobject/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ def set_points_by_ends(self, start, end, buff=0, path_arc=0):
self.insert_tip_anchor()
return self

def init_colors(self):
super().init_colors()
def init_colors(self, override=True):
super().init_colors(override)
self.create_tip_with_stroke_width()

def get_arc_length(self):
Expand Down
4 changes: 2 additions & 2 deletions manimlib/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def init_uniforms(self):
"reflectiveness": self.reflectiveness,
}

def init_colors(self):
self.set_color(self.color, self.opacity)
def init_colors(self, override=True):
self.set_color(self.color, self.opacity, override)

def init_points(self):
# Typically implemented in subclass, unlpess purposefully left blank
Expand Down
211 changes: 142 additions & 69 deletions manimlib/mobject/svg/svg_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import os
import hashlib

from colour import web2hex
from xml.dom import minidom

from manimlib.constants import DEFAULT_STROKE_WIDTH
from manimlib.constants import ORIGIN, UP, DOWN, LEFT, RIGHT, IN
from manimlib.constants import BLACK
from manimlib.constants import WHITE
from manimlib.constants import DEGREES, PI

from manimlib.mobject.geometry import Circle
Expand All @@ -25,14 +24,82 @@
from manimlib.logger import log


def string_to_numbers(num_string):
num_string = num_string.replace("-", ",-")
num_string = num_string.replace("e,-", "e-")
return [
float(s)
for s in re.split("[ ,]", num_string)
if s != ""
]
DEFAULT_STYLE = {
"fill": "black",
"stroke": "none",
"fill-opacity": "1",
"stroke-opacity": "1",
"stroke-width": 0,
}


def cascade_element_style(element, inherited):
style = inherited.copy()

for attr in DEFAULT_STYLE:
if element.hasAttribute(attr):
style[attr] = element.getAttribute(attr)

if element.hasAttribute("style"):
for style_spec in element.getAttribute("style").split(";"):
style_spec = style_spec.strip()
try:
key, value = style_spec.split(":")
except ValueError as e:
if not style_spec.strip():
pass
else:
raise e
else:
style[key.strip()] = value.strip()

return style


def parse_color(color):
color = color.strip()

if color[0:3] == "rgb":
splits = color[4:-1].strip().split(",")
if splits[0].strip()[-1] == "%":
parsed_rgbs = [float(i.strip()[:-1]) / 100.0 for i in splits]
else:
parsed_rgbs = [int(i) / 255.0 for i in splits]
return rgb_to_hex(parsed_rgbs)

else:
return web2hex(color)


def fill_default_values(style, default_style):
default = DEFAULT_STYLE.copy()
default.update(default_style)
for attr in default:
if attr not in style:
style[attr] = default[attr]


def parse_style(style, default_style):
manim_style = {}
fill_default_values(style, default_style)

manim_style["fill_opacity"] = float(style["fill-opacity"])
manim_style["stroke_opacity"] = float(style["stroke-opacity"])
manim_style["stroke_width"] = float(style["stroke-width"])

if style["fill"] == "none":
manim_style["fill_opacity"] = 0
else:
manim_style["fill_color"] = parse_color(style["fill"])

if style["stroke"] == "none":
manim_style["stroke_width"] = 0
if "fill_color" in manim_style:
manim_style["stroke_color"] = manim_style["fill_color"]
else:
manim_style["stroke_color"] = parse_color(style["stroke"])

return manim_style


class SVGMobject(VMobject):
Expand All @@ -43,7 +110,6 @@ class SVGMobject(VMobject):
# Must be filled in in a subclass, or when called
"file_name": None,
"unpack_groups": True, # if False, creates a hierarchy of VGroups
# TODO, style components should be read in, not defaulted
"stroke_width": DEFAULT_STROKE_WIDTH,
"fill_opacity": 1.0,
"path_string_config": {}
Expand All @@ -66,6 +132,9 @@ def move_into_position(self):
self.set_height(self.height)
if self.width is not None:
self.set_width(self.width)

def init_colors(self, override=False):
super().init_colors(override=False)

def init_points(self):
doc = minidom.parse(self.file_path)
Expand All @@ -74,69 +143,84 @@ def init_points(self):
for child in doc.childNodes:
if not isinstance(child, minidom.Element): continue
if child.tagName != 'svg': continue
mobjects = self.get_mobjects_from(child)
mobjects = self.get_mobjects_from(child, dict())
if self.unpack_groups:
self.add(*mobjects)
else:
self.add(*mobjects[0].submobjects)
doc.unlink()

def get_mobjects_from(self, element):
def get_mobjects_from(self, element, style):
result = []
if not isinstance(element, minidom.Element):
return result
style = cascade_element_style(element, style)

if element.tagName == 'defs':
self.update_ref_to_element(element)
self.update_ref_to_element(element, style)
elif element.tagName == 'style':
pass # TODO, handle style
elif element.tagName in ['g', 'svg', 'symbol']:
result += it.chain(*(
self.get_mobjects_from(child)
self.get_mobjects_from(child, style)
for child in element.childNodes
))
elif element.tagName == 'path':
result.append(self.path_string_to_mobject(
element.getAttribute('d')
element.getAttribute('d'), style
))
elif element.tagName == 'use':
result += self.use_to_mobjects(element)
result += self.use_to_mobjects(element, style)
elif element.tagName == 'rect':
result.append(self.rect_to_mobject(element))
result.append(self.rect_to_mobject(element, style))
elif element.tagName == 'circle':
result.append(self.circle_to_mobject(element))
result.append(self.circle_to_mobject(element, style))
elif element.tagName == 'ellipse':
result.append(self.ellipse_to_mobject(element))
result.append(self.ellipse_to_mobject(element, style))
elif element.tagName in ['polygon', 'polyline']:
result.append(self.polygon_to_mobject(element))
result.append(self.polygon_to_mobject(element, style))
else:
log.warning(f"Unsupported element type: {element.tagName}")
pass # TODO
result = [m for m in result if m is not None]
result = [m.insert_n_curves(0) for m in result if m is not None]
self.handle_transforms(element, VGroup(*result))
if len(result) > 1 and not self.unpack_groups:
result = [VGroup(*result)]

return result

def g_to_mobjects(self, g_element):
mob = VGroup(*self.get_mobjects_from(g_element))
self.handle_transforms(g_element, mob)
return mob.submobjects

def path_string_to_mobject(self, path_string):

def generate_default_style(self):
style = {
"fill-opacity": self.fill_opacity,
"stroke-width": self.stroke_width,
"stroke-opacity": self.stroke_opacity,
}
if self.color:
style["fill"] = style["stroke"] = self.color
if self.fill_color:
style["fill"] = self.fill_color
if self.stroke_color:
style["stroke"] = self.stroke_color
return style

def path_string_to_mobject(self, path_string, style):
return VMobjectFromSVGPathstring(
path_string,
**self.path_string_config,
**parse_style(style, self.generate_default_style()),
)

def use_to_mobjects(self, use_element):
def use_to_mobjects(self, use_element, local_style):
# Remove initial "#" character
ref = use_element.getAttribute("xlink:href")[1:]
if ref not in self.ref_to_element:
log.warning(f"{ref} not recognized")
return VGroup()
def_element, def_style = self.ref_to_element[ref]
style = local_style.copy()
style.update(def_style)
return self.get_mobjects_from(
self.ref_to_element[ref]
def_element, style
)

def attribute_to_float(self, attr):
Expand All @@ -146,57 +230,46 @@ def attribute_to_float(self, attr):
])
return float(stripped_attr)

def polygon_to_mobject(self, polygon_element):
def polygon_to_mobject(self, polygon_element, style):
path_string = polygon_element.getAttribute("points")
for digit in string.digits:
path_string = path_string.replace(f" {digit}", f"L {digit}")
path_string = path_string.replace("L", "M", 1)
return self.path_string_to_mobject(path_string)
return self.path_string_to_mobject(path_string, style)

def circle_to_mobject(self, circle_element):
x, y, r = [
def circle_to_mobject(self, circle_element, style):
x, y, r = (
self.attribute_to_float(
circle_element.getAttribute(key)
)
if circle_element.hasAttribute(key)
else 0.0
for key in ("cx", "cy", "r")
]
return Circle(radius=r).shift(x * RIGHT + y * DOWN)
)
return Circle(
radius=r,
**parse_style(style, self.generate_default_style())
).shift(x * RIGHT + y * DOWN)

def ellipse_to_mobject(self, circle_element):
x, y, rx, ry = [
def ellipse_to_mobject(self, circle_element, style):
x, y, rx, ry = (
self.attribute_to_float(
circle_element.getAttribute(key)
)
if circle_element.hasAttribute(key)
else 0.0
for key in ("cx", "cy", "rx", "ry")
]
result = Circle()
)
result = Circle(**parse_style(style, self.generate_default_style()))
result.stretch(rx, 0)
result.stretch(ry, 1)
result.shift(x * RIGHT + y * DOWN)
return result

def rect_to_mobject(self, rect_element):
fill_color = rect_element.getAttribute("fill")
stroke_color = rect_element.getAttribute("stroke")
def rect_to_mobject(self, rect_element, style):
stroke_width = rect_element.getAttribute("stroke-width")
corner_radius = rect_element.getAttribute("rx")

# input preprocessing
fill_opacity = 1
if fill_color in ["", "none", "#FFF", "#FFFFFF"] or Color(fill_color) == Color(WHITE):
fill_opacity = 0
fill_color = BLACK # shdn't be necessary but avoids error msgs
if fill_color in ["#000", "#000000"]:
fill_color = WHITE
if stroke_color in ["", "none", "#FFF", "#FFFFFF"] or Color(stroke_color) == Color(WHITE):
stroke_width = 0
stroke_color = BLACK
if stroke_color in ["#000", "#000000"]:
stroke_color = WHITE
if stroke_width in ["", "none", "0"]:
stroke_width = 0

Expand All @@ -205,6 +278,9 @@ def rect_to_mobject(self, rect_element):

corner_radius = float(corner_radius)

parsed_style = parse_style(style, self.generate_default_style())
parsed_style["stroke_width"] = stroke_width

if corner_radius == 0:
mob = Rectangle(
width=self.attribute_to_float(
Expand All @@ -213,10 +289,7 @@ def rect_to_mobject(self, rect_element):
height=self.attribute_to_float(
rect_element.getAttribute("height")
),
stroke_width=stroke_width,
stroke_color=stroke_color,
fill_color=fill_color,
fill_opacity=fill_opacity
**parsed_style,
)
else:
mob = RoundedRectangle(
Expand All @@ -226,11 +299,8 @@ def rect_to_mobject(self, rect_element):
height=self.attribute_to_float(
rect_element.getAttribute("height")
),
stroke_width=stroke_width,
stroke_color=stroke_color,
fill_color=fill_color,
fill_opacity=fill_opacity,
corner_radius=corner_radius
corner_radius=corner_radius,
**parsed_style
)

mob.shift(mob.get_center() - mob.get_corner(UP + LEFT))
Expand Down Expand Up @@ -345,8 +415,8 @@ def get_all_childNodes_have_id(self, element):
all_childNodes_have_id.append(self.get_all_childNodes_have_id(e))
return self.flatten([e for e in all_childNodes_have_id if e])

def update_ref_to_element(self, defs):
new_refs = dict([(e.getAttribute('id'), e) for e in self.get_all_childNodes_have_id(defs)])
def update_ref_to_element(self, defs, style):
new_refs = dict([(e.getAttribute('id'), (e, style)) for e in self.get_all_childNodes_have_id(defs)])
self.ref_to_element.update(new_refs)


Expand Down Expand Up @@ -404,14 +474,15 @@ def handle_commands(self):
upper_command = command.upper()
if upper_command == "Z":
func() # `close_path` takes no arguments
relative_point = self.get_last_point()
continue

number_types = np.array(list(number_types_str))
n_numbers = len(number_types_str)
number_list = _PathStringParser(coord_string, number_types_str).args
number_groups = np.array(number_list).reshape((-1, n_numbers))

for numbers in number_groups:
for ind, numbers in enumerate(number_groups):
if command.islower():
# Treat it as a relative command
numbers[number_types == "x"] += relative_point[0]
Expand All @@ -427,10 +498,12 @@ def handle_commands(self):
args = list(np.hstack((
numbers.reshape((-1, 2)), np.zeros((n_numbers // 2, 1))
)))
if upper_command == "M" and ind != 0:
# M x1 y1 x2 y2 is equal to M x1 y1 L x2 y2
func, _ = self.command_to_function("L")
func(*args)
relative_point = self.get_last_point()


def add_elliptical_arc_to(self, rx, ry, x_axis_rotation, large_arc_flag, sweep_flag, point):
def close_to_zero(a, threshold=1e-5):
return abs(a) < threshold
Expand Down
Loading