Add branching. why not

This commit is contained in:
Anthony Samms
2025-10-11 18:49:55 -04:00
parent 78b1b31e0c
commit 34dd2adca7
5 changed files with 432 additions and 166 deletions

View File

@@ -33,10 +33,21 @@ class Note:
bpm: float = field(init=False)
gogo_time: bool = field(init=False)
moji: int = field(init=False)
is_branch_start: bool = field(init=False)
branch_params: str = field(init=False)
def __lt__(self, other):
return self.hit_ms < other.hit_ms
def __le__(self, other):
return self.hit_ms <= other.hit_ms
def __gt__(self, other):
return self.hit_ms > other.hit_ms
def __ge__(self, other):
return self.hit_ms >= other.hit_ms
def __eq__(self, other):
return self.hit_ms == other.hit_ms
@@ -112,12 +123,32 @@ class Balloon(Note):
hash_string = str(field_values)
return hash_string.encode('utf-8')
@dataclass
class NoteList:
play_notes: list[Note | Drumroll | Balloon] = field(default_factory=lambda: [])
draw_notes: list[Note | Drumroll | Balloon] = field(default_factory=lambda: [])
bars: list[Note] = field(default_factory=lambda: [])
def __add__(self, other: 'NoteList') -> 'NoteList':
return NoteList(
play_notes=self.play_notes + other.play_notes,
draw_notes=self.draw_notes + other.draw_notes,
bars=self.bars + other.bars
)
def __iadd__(self, other: 'NoteList') -> 'NoteList':
self.play_notes += other.play_notes
self.draw_notes += other.draw_notes
self.bars += other.bars
return self
@dataclass
class CourseData:
level: int = 0
balloon: list[int] = field(default_factory=lambda: [])
scoreinit: list[int] = field(default_factory=lambda: [])
scorediff: int = 0
is_branching: bool = False
@dataclass
class TJAMetadata:
@@ -141,16 +172,16 @@ class TJAEXData:
new: bool = False
def calculate_base_score(play_notes: deque[Note | Drumroll | Balloon]) -> int:
def calculate_base_score(notes: NoteList) -> int:
total_notes = 0
balloon_count = 0
drumroll_msec = 0
for i in range(len(play_notes)):
note = play_notes[i]
if i < len(play_notes)-1:
next_note = play_notes[i+1]
for i in range(len(notes.play_notes)):
note = notes.play_notes[i]
if i < len(notes.play_notes)-1:
next_note = notes.play_notes[i+1]
else:
next_note = play_notes[len(play_notes)-1]
next_note = notes.play_notes[len(notes.play_notes)-1]
if isinstance(note, Drumroll):
drumroll_msec += (next_note.hit_ms - note.hit_ms)
elif isinstance(note, Balloon):
@@ -198,7 +229,9 @@ class TJAParser:
current_diff = None # Track which difficulty we're currently processing
for item in self.data:
if item.startswith("#") or item[0].isdigit():
if item.startswith('#BRANCH') and current_diff is not None:
self.metadata.course_data[current_diff].is_branching = True
elif item.startswith("#") or item[0].isdigit():
continue
elif item.startswith('SUBTITLE'):
region_code = 'en'
@@ -407,9 +440,10 @@ class TJAParser:
play_note_list[-3].moji = se_notes[1][2]
def notes_to_position(self, diff: int):
play_note_list: list[Note | Drumroll | Balloon] = []
draw_note_list: list[Note | Drumroll | Balloon] = []
bar_list: list[Note] = []
master_notes = NoteList()
branch_m: list[NoteList] = []
branch_e: list[NoteList] = []
branch_n: list[NoteList] = []
notes = self.data_to_notes(diff)
balloon = self.metadata.course_data[diff].balloon.copy()
count = 0
@@ -420,19 +454,127 @@ class TJAParser:
y_scroll_modifier = 0
barline_display = True
gogo_time = False
skip_branch = False
curr_note_list = master_notes.play_notes
curr_draw_list = master_notes.draw_notes
curr_bar_list = master_notes.bars
start_branch_ms = 0
start_branch_bpm = bpm
start_branch_time_sig = time_signature
start_branch_x_scroll = x_scroll_modifier
start_branch_y_scroll = y_scroll_modifier
start_branch_barline = barline_display
start_branch_gogo = gogo_time
branch_balloon_count = 0
is_branching = False
for bar in notes:
#Length of the bar is determined by number of notes excluding commands
bar_length = sum(len(part) for part in bar if '#' not in part)
barline_added = False
for part in bar:
if part.startswith('#BRANCHSTART'):
skip_branch = True
start_branch_ms = self.current_ms
start_branch_bpm = bpm
start_branch_time_sig = time_signature
start_branch_x_scroll = x_scroll_modifier
start_branch_y_scroll = y_scroll_modifier
start_branch_barline = barline_display
start_branch_gogo = gogo_time
branch_balloon_count = count
branch_params = part[13:]
if branch_params[0] == 'r':
# Helper function to find and set drumroll branch params
def set_drumroll_branch_params(note_list, bar_list):
for i in range(len(note_list)-1, -1, -1):
if 5 <= note_list[i].type <= 7 or note_list[i].type == 9:
drumroll_ms = note_list[i].hit_ms
for bar_idx in range(len(bar_list)-1, -1, -1):
if bar_list[bar_idx].hit_ms <= drumroll_ms:
bar_list[bar_idx].branch_params = branch_params
return True
break
return False
# Always try to set in master notes
set_drumroll_branch_params(master_notes.play_notes, master_notes.bars)
# If we have existing branches, also apply to them
if branch_m and len(branch_m) > 0:
set_drumroll_branch_params(branch_m[-1].play_notes, branch_m[-1].bars)
if branch_e and len(branch_e) > 0:
set_drumroll_branch_params(branch_e[-1].play_notes, branch_e[-1].bars)
if branch_n and len(branch_n) > 0:
set_drumroll_branch_params(branch_n[-1].play_notes, branch_n[-1].bars)
else:
if len(curr_bar_list) > 1:
curr_bar_list[-2].branch_params = branch_params
elif len(curr_bar_list) > 0:
curr_bar_list[-1].branch_params = branch_params
if branch_m and len(branch_m[-1].bars) > 1:
branch_m[-1].bars[-2].branch_params = branch_params
elif branch_m and len(branch_m[-1].bars) > 0:
branch_m[-1].bars[-1].branch_params = branch_params
if branch_e and len(branch_e[-1].bars) > 1:
branch_e[-1].bars[-2].branch_params = branch_params
elif branch_e and len(branch_e[-1].bars) > 0:
branch_e[-1].bars[-1].branch_params = branch_params
if branch_n and len(branch_n[-1].bars) > 1:
branch_n[-1].bars[-2].branch_params = branch_params
elif branch_n and len(branch_n[-1].bars) > 0:
branch_n[-1].bars[-1].branch_params = branch_params
if branch_m and len(branch_m[-1].bars) > 0:
branch_m[-1].bars[-1].branch_params = branch_params
continue
elif part.startswith('#BRANCHEND'):
curr_note_list = master_notes.play_notes
curr_draw_list = master_notes.draw_notes
curr_bar_list = master_notes.bars
continue
if part == '#M':
skip_branch = False
branch_m.append(NoteList())
curr_note_list = branch_m[-1].play_notes
curr_draw_list = branch_m[-1].draw_notes
curr_bar_list = branch_m[-1].bars
self.current_ms = start_branch_ms
bpm = start_branch_bpm
time_signature = start_branch_time_sig
x_scroll_modifier = start_branch_x_scroll
y_scroll_modifier = start_branch_y_scroll
barline_display = start_branch_barline
gogo_time = start_branch_gogo
count = branch_balloon_count
is_branching = True
continue
if skip_branch:
elif part == '#E':
branch_e.append(NoteList())
curr_note_list = branch_e[-1].play_notes
curr_draw_list = branch_e[-1].draw_notes
curr_bar_list = branch_e[-1].bars
self.current_ms = start_branch_ms
bpm = start_branch_bpm
time_signature = start_branch_time_sig
x_scroll_modifier = start_branch_x_scroll
y_scroll_modifier = start_branch_y_scroll
barline_display = start_branch_barline
gogo_time = start_branch_gogo
count = branch_balloon_count
is_branching = True
continue
elif part == '#N':
branch_n.append(NoteList())
curr_note_list = branch_n[-1].play_notes
curr_draw_list = branch_n[-1].draw_notes
curr_bar_list = branch_n[-1].bars
self.current_ms = start_branch_ms
bpm = start_branch_bpm
time_signature = start_branch_time_sig
x_scroll_modifier = start_branch_x_scroll
y_scroll_modifier = start_branch_y_scroll
barline_display = start_branch_barline
gogo_time = start_branch_gogo
count = branch_balloon_count
is_branching = True
continue
if '#LYRIC' in part:
continue
@@ -445,74 +587,15 @@ class TJAParser:
time_signature = float(part[9:divisor]) / float(part[divisor+1:])
continue
elif '#SCROLL' in part:
# Extract the value after '#SCROLL '
scroll_value = part[7:].strip() # Remove '#SCROLL' and whitespace
# Initialize default values
x_scroll_modifier = 0
y_scroll_modifier = 0
# Handle empty value
if not scroll_value:
continue
# Check if it's a complex number (contains 'i')
scroll_value = part[7:]
if 'i' in scroll_value:
# Handle different imaginary number formats
if scroll_value == 'i':
x_scroll_modifier = 0
y_scroll_modifier = 1
elif scroll_value == '-i':
x_scroll_modifier = 0
y_scroll_modifier = -1
elif scroll_value.endswith('i') or scroll_value.endswith('.i'):
# Remove the 'i' or '.i' suffix
if scroll_value.endswith('.i'):
complex_part = scroll_value[:-2]
else:
complex_part = scroll_value[:-1]
# Look for + or - that separates real and imaginary parts
# Find the rightmost + or - (excluding position 0 for negative numbers)
plus_pos = complex_part.rfind('+')
minus_pos = complex_part.rfind('-')
separator_pos = -1
if plus_pos > 0: # Ignore + at position 0
separator_pos = plus_pos
if minus_pos > 0 and minus_pos > separator_pos: # Ignore - at position 0
separator_pos = minus_pos
if separator_pos > 0:
# Complex number like '1+i', '3+4i', '2-5i', '-1+2i', etc.
real_part = complex_part[:separator_pos]
imag_part = complex_part[separator_pos:]
x_scroll_modifier = float(real_part) if real_part else 0
# Handle imaginary part
if imag_part == '+' or imag_part == '':
y_scroll_modifier = 1
elif imag_part == '-':
y_scroll_modifier = -1
else:
y_scroll_modifier = float(imag_part)
else:
# Pure imaginary like '5i', '-3i', '2.5i'
if complex_part == '' or complex_part == '+':
y_scroll_modifier = 1
elif complex_part == '-':
y_scroll_modifier = -1
else:
y_scroll_modifier = float(complex_part)
x_scroll_modifier = 0
else:
# 'i' is somewhere in the middle - invalid format
continue
normalized = scroll_value.replace('.i', 'j').replace('i', 'j')
c = complex(normalized)
x_scroll_modifier = c.real
y_scroll_modifier = c.imag
else:
# Pure real number
x_scroll_modifier = float(scroll_value)
y_scroll_modifier = 0
y_scroll_modifier = 0.0
continue
elif '#BPMCHANGE' in part:
bpm = float(part[11:])
@@ -555,7 +638,11 @@ class TJAParser:
if barline_added:
bar_line.display = False
bisect.insort(bar_list, bar_line, key=lambda x: x.load_ms)
if is_branching:
bar_line.is_branch_start = True
is_branching = False
bisect.insort(curr_bar_list, bar_line, key=lambda x: x.load_ms)
barline_added = True
#Empty bar is still a bar, otherwise start increment
@@ -571,7 +658,7 @@ class TJAParser:
if item == '0' or (not item.isdigit()):
self.current_ms += increment
continue
if item == '9' and play_note_list and play_note_list[-1].type == 9:
if item == '9' and curr_note_list and curr_note_list[-1].type == 9:
self.current_ms += increment
continue
note = Note()
@@ -600,33 +687,29 @@ class TJAParser:
note = Balloon(note)
note.count = 1 if not balloon else balloon.pop(0)
elif item == '8':
new_pixels_per_ms = play_note_list[-1].pixels_per_frame_x / (1000 / 60)
new_pixels_per_ms = curr_note_list[-1].pixels_per_frame_x / (1000 / 60)
if new_pixels_per_ms == 0:
note.load_ms = note.hit_ms
else:
note.load_ms = note.hit_ms - (self.distance / new_pixels_per_ms)
note.pixels_per_frame_x = play_note_list[-1].pixels_per_frame_x
note.pixels_per_frame_x = curr_note_list[-1].pixels_per_frame_x
self.current_ms += increment
play_note_list.append(note)
bisect.insort(draw_note_list, note, key=lambda x: x.load_ms)
self.get_moji(play_note_list, ms_per_measure)
curr_note_list.append(note)
bisect.insort(curr_draw_list, note, key=lambda x: x.load_ms)
self.get_moji(curr_note_list, ms_per_measure)
index += 1
if len(play_note_list) > 3:
if isinstance(play_note_list[-2], Drumroll) and play_note_list[-1].type != 8:
print(self.file_path, diff)
print(bar)
continue
raise Exception(f"{play_note_list[-2]}")
if hasattr(curr_bar_list[-1], 'branch_params'):
print(curr_note_list[-1])
# https://stackoverflow.com/questions/72899/how-to-sort-a-list-of-dictionaries-by-a-value-of-the-dictionary-in-python
# Sorting by load_ms is necessary for drawing, as some notes appear on the
# screen slower regardless of when they reach the judge circle
# Bars can be sorted like this because they don't need hit detection
return deque(play_note_list), deque(draw_note_list), deque(bar_list)
return master_notes, branch_m, branch_e, branch_n
def hash_note_data(self, play_notes: deque[Note | Drumroll | Balloon], bars: deque[Note]):
def hash_note_data(self, notes: NoteList):
n = hashlib.sha256()
list1 = list(play_notes)
list2 = list(bars)
list1 = notes.play_notes
list2 = notes.bars
merged: list[Note | Drumroll | Balloon] = []
i = 0
j = 0
@@ -644,46 +727,47 @@ class TJAParser:
return n.hexdigest()
def modifier_speed(notes: deque[Note | Balloon | Drumroll], bars, value: float):
notes = notes.copy()
for note in notes:
def modifier_speed(notes: NoteList, value: float):
modded_notes = notes.draw_notes.copy()
modded_bars = notes.bars.copy()
for note in modded_notes:
note.pixels_per_frame_x *= value
note.load_ms = note.hit_ms - (866 / get_pixels_per_ms(note.pixels_per_frame_x))
for bar in bars:
for bar in modded_bars:
bar.pixels_per_frame_x *= value
bar.load_ms = bar.hit_ms - (866 / get_pixels_per_ms(bar.pixels_per_frame_x))
return notes, bars
return modded_notes, modded_bars
def modifier_display(notes: deque[Note | Balloon | Drumroll]):
notes = notes.copy()
for note in notes:
def modifier_display(notes: NoteList):
modded_notes = notes.draw_notes.copy()
for note in modded_notes:
note.display = False
return notes
return modded_notes
def modifier_inverse(notes: deque[Note | Balloon | Drumroll]):
notes = notes.copy()
def modifier_inverse(notes: NoteList):
modded_notes = notes.play_notes.copy()
type_mapping = {1: 2, 2: 1, 3: 4, 4: 3}
for note in notes:
for note in modded_notes:
if note.type in type_mapping:
note.type = type_mapping[note.type]
return notes
return modded_notes
def modifier_random(notes: deque[Note | Balloon | Drumroll], value: int):
def modifier_random(notes: NoteList, value: int):
#value: 1 == kimagure, 2 == detarame
notes = notes.copy()
percentage = int(len(notes) / 5) * value
selected_notes = random.sample(range(len(notes)), percentage)
modded_notes = notes.play_notes.copy()
percentage = int(len(modded_notes) / 5) * value
selected_notes = random.sample(range(len(modded_notes)), percentage)
type_mapping = {1: 2, 2: 1, 3: 4, 4: 3}
for i in selected_notes:
if notes[i].type in type_mapping:
notes[i].type = type_mapping[notes[i].type]
return notes
if modded_notes[i].type in type_mapping:
modded_notes[i].type = type_mapping[modded_notes[i].type]
return modded_notes
def apply_modifiers(notes: deque[Note | Balloon | Drumroll], draw_notes: deque[Note | Balloon | Drumroll], bars: deque[Note]):
def apply_modifiers(notes: NoteList):
if global_data.modifiers.display:
draw_notes = modifier_display(draw_notes)
draw_notes = modifier_display(notes)
if global_data.modifiers.inverse:
notes = modifier_inverse(notes)
notes = modifier_random(notes, global_data.modifiers.random)
draw_notes, bars = modifier_speed(draw_notes, bars, global_data.modifiers.speed)
return notes, draw_notes, bars
play_notes = modifier_inverse(notes)
play_notes = modifier_random(notes, global_data.modifiers.random)
draw_notes, bars = modifier_speed(notes, global_data.modifiers.speed)
return deque(play_notes), deque(draw_notes), deque(bars)