Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing time signature encoding/decoding #76

Merged
merged 5 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,17 @@ def _create_midi_events(self, midi: MidiFile) -> List[Event]:
"""
events = []

# First, add time signature tokens if specified
if self.config.use_time_signatures:
for time_signature_change in midi.time_signature_changes:
events.append(
Event(
type="TimeSig",
value=f"{time_signature_change.numerator}/{time_signature_change.denominator}",
time=time_signature_change.time,
)
)

# Adds tempo events if specified
if self.config.use_tempos:
for tempo_change in midi.tempo_changes:
Expand All @@ -861,17 +872,6 @@ def _create_midi_events(self, midi: MidiFile) -> List[Event]:
)
)

# Add time signature tokens if specified
if self.config.use_time_signatures:
for time_signature_change in midi.time_signature_changes:
events.append(
Event(
type="TimeSig",
value=f"{time_signature_change.numerator}/{time_signature_change.denominator}",
time=time_signature_change.time,
)
)

return events

def _add_time_events(self, events: List[Event]) -> List[Event]:
Expand Down
135 changes: 105 additions & 30 deletions miditok/tokenizations/cp_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CPWord(MIDITokenizer):
* (+ Optional) Chord: chords occurring with position tokens
* (+ Optional) Rest: rest acting as a TimeShift token
* (+ Optional) Tempo: occurring with position tokens
* (+ Optional) TimeSig: occurring with bar tokens

The output hidden states of the model will then be fed to several output layers
(one per token type). This means that training requires adding multiple losses.
Expand Down Expand Up @@ -68,10 +69,16 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
# Add time events
all_events = []
current_bar = -1
bar_at_last_ts_change = 0
previous_tick = -1
previous_note_end = 0
tick_at_last_ts_change = tick_at_current_bar = 0
current_time_sig = TIME_SIGNATURE
current_tempo = TEMPO
if self.config.log_tempos:
# pick the closest to the default value
current_tempo = float(self.tempos[(np.abs(self.tempos - TEMPO)).argmin()])
else:
current_tempo = TEMPO
current_program = None
ticks_per_bar = self._compute_ticks_per_bar(
TimeSignature(*current_time_sig, 0), time_division
Expand All @@ -85,16 +92,31 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
TimeSignature(*current_time_sig, event.time), time_division
)
break
elif event.type in ["Pitch", "Velocity", "Duration", "PitchBend", "Pedal"]:
elif event.type in [
"Pitch",
"Velocity",
"Duration",
"PitchBend",
"Pedal",
]:
break
# Then look for a Tempo token, if any is given at tick 0, to update current_tempo
if self.config.use_tempos:
for event in events:
if event.type == "Tempo":
current_tempo = event.value
break
elif event.type in [
"Pitch",
"Velocity",
"Duration",
"PitchBend",
"Pedal",
]:
break
# Add the time events
for e, event in enumerate(events):
if event.type == "TimeSig":
current_time_sig = list(map(int, event.value.split("/")))
ticks_per_bar = self._compute_ticks_per_bar(
TimeSignature(*current_time_sig, event.time), time_division
)
elif event.type == "Tempo":
if event.type == "Tempo":
current_tempo = event.value
elif event.type == "Program":
current_program = event.value
Expand All @@ -108,6 +130,7 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
rest_values = self._ticks_to_duration_tokens(
event.time - previous_tick, rest=True
)
# Add Rest events and increment previous_tick
for dur_value, dur_ticks in zip(*rest_values):
all_events.append(
self.__create_cp_token(
Expand All @@ -117,27 +140,60 @@ def _add_time_events(self, events: List[Event]) -> List[List[Event]]:
)
)
previous_tick += dur_ticks
current_bar = previous_tick // ticks_per_bar
# We update current_bar and tick_at_current_bar here without creating Bar tokens
real_current_bar = (
bar_at_last_ts_change
+ (previous_tick - tick_at_last_ts_change) // ticks_per_bar
)
if real_current_bar > current_bar:
tick_at_current_bar += (
real_current_bar - current_bar
) * ticks_per_bar
current_bar = real_current_bar

# Bar
nb_new_bars = event.time // ticks_per_bar - current_bar
for i in range(nb_new_bars):
if self.config.use_time_signatures:
time_sig_arg = f"{current_time_sig[0]}/{current_time_sig[1]}"
else:
time_sig_arg = None
all_events.append(
self.__create_cp_token(
(current_bar + i + 1) * ticks_per_bar,
bar=True,
desc="Bar",
time_signature=time_sig_arg,
nb_new_bars = (
bar_at_last_ts_change
+ (event.time - tick_at_last_ts_change) // ticks_per_bar
- current_bar
)
if nb_new_bars >= 1:
for i in range(nb_new_bars):
# Update time signature time variables before adding the last bar
if self.config.use_time_signatures:
if event.type == "TimeSig" and i + 1 == nb_new_bars:
current_time_sig = list(
map(int, event.value.split("/"))
)
bar_at_last_ts_change += (
event.time - tick_at_last_ts_change
) // ticks_per_bar
tick_at_last_ts_change = event.time
ticks_per_bar = self._compute_ticks_per_bar(
TimeSignature(*current_time_sig, event.time),
time_division,
)
time_sig_arg = (
f"{current_time_sig[0]}/{current_time_sig[1]}"
)
else:
time_sig_arg = None
all_events.append(
self.__create_cp_token(
(current_bar + i + 1) * ticks_per_bar,
bar=True,
desc="Bar",
time_signature=time_sig_arg,
)
)
current_bar += nb_new_bars
tick_at_current_bar = (
tick_at_last_ts_change
+ (current_bar - bar_at_last_ts_change) * ticks_per_bar
)
current_bar += nb_new_bars

# Position
pos_index = int((event.time % ticks_per_bar) / ticks_per_sample)
pos_index = int((event.time - tick_at_current_bar) / ticks_per_sample)
all_events.append(
self.__create_cp_token(
event.time,
Expand Down Expand Up @@ -295,15 +351,17 @@ def tokens_to_midi(
time_signature_changes[0], time_division
) # init

current_tick = 0
current_tick = tick_at_last_ts_change = tick_at_current_bar = 0
current_bar = -1
bar_at_last_ts_change = 0
current_program = 0
previous_note_end = 0
for si, seq in enumerate(tokens):
# Set track / sequence program if needed
if not self.one_token_stream:
current_tick = 0
current_tick = tick_at_last_ts_change = tick_at_current_bar = 0
current_bar = -1
bar_at_last_ts_change = 0
ticks_per_bar = self._compute_ticks_per_bar(
time_signature_changes[0], time_division
)
Expand All @@ -316,7 +374,10 @@ def tokens_to_midi(
token_family = compound_token[0].split("_")[1]
if token_family == "Note":
pad_range_idx = 6 if self.config.use_programs else 5
if any(tok.split("_")[1] == "None" for tok in compound_token[2:pad_range_idx]):
if any(
tok.split("_")[1] == "None"
for tok in compound_token[2:pad_range_idx]
):
continue
pitch = int(compound_token[2].split("_")[1])
vel = int(compound_token[3].split("_")[1])
Expand All @@ -342,7 +403,9 @@ def tokens_to_midi(
bar_pos = compound_token[1].split("_")[0]
if bar_pos == "Bar":
current_bar += 1
current_tick = current_bar * ticks_per_bar
if current_bar > 0:
current_tick = tick_at_current_bar + ticks_per_bar
tick_at_current_bar = current_tick
# Add new TS only if different from the last one
if self.config.use_time_signatures and si == 0:
num, den = self._parse_token_time_signature(
Expand All @@ -352,11 +415,15 @@ def tokens_to_midi(
)
if (
num != time_signature_changes[-1].numerator
and den != time_signature_changes[-1].denominator
or den != time_signature_changes[-1].denominator
):
time_sig = TimeSignature(num, den, current_tick)
if si == 0:
time_signature_changes.append(time_sig)
tick_at_last_ts_change = (
tick_at_current_bar # == current_tick
)
bar_at_last_ts_change = current_bar
ticks_per_bar = self._compute_ticks_per_bar(
time_sig, time_division
)
Expand All @@ -365,7 +432,7 @@ def tokens_to_midi(
# in case this Position token comes before any Bar token
current_bar = 0
current_tick = (
current_bar * ticks_per_bar
tick_at_current_bar
+ int(compound_token[1].split("_")[1]) * ticks_per_sample
)
# Add new tempo change only if different from the last one
Expand All @@ -392,7 +459,15 @@ def tokens_to_midi(
compound_token[self.vocab_types_idx["Rest"]].split("_")[1],
time_division,
)
current_bar = current_tick // ticks_per_bar
real_current_bar = (
bar_at_last_ts_change
+ (current_tick - tick_at_last_ts_change) // ticks_per_bar
)
if real_current_bar > current_bar:
tick_at_current_bar += (
real_current_bar - current_bar
) * ticks_per_bar
current_bar = real_current_bar

if len(tempo_changes) > 1:
del tempo_changes[0] # delete mocked tempo change
Expand Down
9 changes: 4 additions & 5 deletions miditok/tokenizations/midi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ def check_inst(prog: int):
elif si == 0 and tok_type == "TimeSig":
num, den = self._parse_token_time_signature(tok_val)
current_time_signature = time_signature_changes[-1]
if (
si == 0
and num != current_time_signature.numerator
and den != current_time_signature.denominator
if si == 0 and (
num != current_time_signature.numerator
or den != current_time_signature.denominator
):
time_signature_changes.append(
TimeSignature(num, den, current_tick)
Expand Down Expand Up @@ -399,7 +398,7 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]:
if self.config.use_rests:
dic["TimeSig"].append("Rest") # only for first token
if self.config.use_tempos:
dic["Tempo"].append("TimeSig")
dic["TimeSig"].append("Tempo")

if self.config.use_sustain_pedals:
dic["TimeShift"].append("Pedal")
Expand Down
Loading