branch fixes

This commit is contained in:
Anthony Samms
2025-11-12 19:12:25 -05:00
parent f15e700c3a
commit 9fe356cde7
2 changed files with 50 additions and 20 deletions

View File

@@ -7,6 +7,7 @@ from collections import deque
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Optional
from libs.global_data import Modifiers from libs.global_data import Modifiers
from libs.utils import get_pixels_per_frame, strip_comments from libs.utils import get_pixels_per_frame, strip_comments
@@ -358,7 +359,7 @@ class TJAParser:
self.metadata.bpm = float(data) self.metadata.bpm = float(data)
elif item.startswith('WAVE'): elif item.startswith('WAVE'):
data = item.split(':')[1] data = item.split(':')[1]
if not data: if not Path(self.file_path.parent / data.strip()).exists():
logger.warning(f"Invalid WAVE value: {data} in TJA file {self.file_path}") logger.warning(f"Invalid WAVE value: {data} in TJA file {self.file_path}")
self.metadata.wave = Path() self.metadata.wave = Path()
else: else:
@@ -627,11 +628,16 @@ class TJAParser:
branch_balloon_count = 0 branch_balloon_count = 0
is_branching = False is_branching = False
prev_note = None prev_note = None
is_section_start = False
section_bar = None
for bar in notes: for bar in notes:
#Length of the bar is determined by number of notes excluding commands #Length of the bar is determined by number of notes excluding commands
bar_length = sum(len(part) for part in bar if '#' not in part) bar_length = sum(len(part) for part in bar if '#' not in part)
barline_added = False barline_added = False
for part in bar: for part in bar:
if part.startswith('#SECTION'):
is_section_start = True
continue
if part.startswith('#BRANCHSTART'): if part.startswith('#BRANCHSTART'):
start_branch_ms = self.current_ms start_branch_ms = self.current_ms
start_branch_bpm = bpm start_branch_bpm = bpm
@@ -667,25 +673,41 @@ class TJAParser:
if branch_n and len(branch_n) > 0: if branch_n and len(branch_n) > 0:
set_drumroll_branch_params(branch_n[-1].play_notes, branch_n[-1].bars) set_drumroll_branch_params(branch_n[-1].play_notes, branch_n[-1].bars)
else: else:
if len(curr_bar_list) > 1: def set_branch_params(bar_list: list[Note], branch_params: str, section_bar: Optional[Note]):
curr_bar_list[-2].branch_params = branch_params if bar_list and len(bar_list) > 1:
elif len(curr_bar_list) > 0: section_index = -2
curr_bar_list[-1].branch_params = branch_params if section_bar and section_bar.hit_ms < self.current_ms:
if section_bar in bar_list:
section_index = bar_list.index(section_bar)
bar_list[section_index].branch_params = branch_params
elif bar_list:
section_index = -1
bar_list[section_index].branch_params = branch_params
elif bar_list == []:
bar_line = Note()
bar_line.pixels_per_frame_x = get_pixels_per_frame(bpm * time_signature * x_scroll_modifier, time_signature*4, self.distance)
bar_line.pixels_per_frame_y = get_pixels_per_frame(bpm * time_signature * y_scroll_modifier, time_signature*4, self.distance)
pixels_per_ms = get_pixels_per_ms(bar_line.pixels_per_frame_x)
if branch_m and len(branch_m[-1].bars) > 1: bar_line.hit_ms = self.current_ms
branch_m[-1].bars[-2].branch_params = branch_params if pixels_per_ms == 0:
elif branch_m and len(branch_m[-1].bars) > 0: bar_line.load_ms = bar_line.hit_ms
branch_m[-1].bars[-1].branch_params = branch_params else:
if branch_e and len(branch_e[-1].bars) > 1: bar_line.load_ms = bar_line.hit_ms - (self.distance / pixels_per_ms)
branch_e[-1].bars[-2].branch_params = branch_params bar_line.type = 0
elif branch_e and len(branch_e[-1].bars) > 0: bar_line.display = False
branch_e[-1].bars[-1].branch_params = branch_params bar_line.gogo_time = gogo_time
if branch_n and len(branch_n[-1].bars) > 1: bar_line.bpm = bpm
branch_n[-1].bars[-2].branch_params = branch_params bar_line.branch_params = branch_params
elif branch_n and len(branch_n[-1].bars) > 0: bar_list.append(bar_line)
branch_n[-1].bars[-1].branch_params = branch_params
if branch_m and len(branch_m[-1].bars) > 0: for bars in [curr_bar_list,
branch_m[-1].bars[-1].branch_params = branch_params branch_m[-1].bars if branch_m else None,
branch_e[-1].bars if branch_e else None,
branch_n[-1].bars if branch_n else None]:
set_branch_params(bars, branch_params, section_bar)
if section_bar:
section_bar = None
continue continue
elif part.startswith('#BRANCHEND'): elif part.startswith('#BRANCHEND'):
curr_note_list = master_notes.play_notes curr_note_list = master_notes.play_notes
@@ -751,6 +773,7 @@ class TJAParser:
scroll_value = part[7:] scroll_value = part[7:]
if 'i' in scroll_value: if 'i' in scroll_value:
normalized = scroll_value.replace('.i', 'j').replace('i', 'j') normalized = scroll_value.replace('.i', 'j').replace('i', 'j')
normalized = normalized.replace(',', '')
c = complex(normalized) c = complex(normalized)
x_scroll_modifier = c.real x_scroll_modifier = c.real
y_scroll_modifier = c.imag y_scroll_modifier = c.imag
@@ -775,6 +798,7 @@ class TJAParser:
continue continue
#Unrecognized commands will be skipped for now #Unrecognized commands will be skipped for now
elif len(part) > 0 and not part[0].isdigit(): elif len(part) > 0 and not part[0].isdigit():
logger.warning(f"Unrecognized command: {part}")
continue continue
ms_per_measure = get_ms_per_measure(bpm, time_signature) ms_per_measure = get_ms_per_measure(bpm, time_signature)
@@ -803,6 +827,10 @@ class TJAParser:
bar_line.is_branch_start = True bar_line.is_branch_start = True
is_branching = False is_branching = False
if is_section_start:
section_bar = bar_line
is_section_start = False
bisect.insort(curr_bar_list, bar_line, key=lambda x: x.load_ms) bisect.insort(curr_bar_list, bar_line, key=lambda x: x.load_ms)
barline_added = True barline_added = True

View File

@@ -477,6 +477,7 @@ class Player:
delattr(self.current_bars[-1], 'branch_params') delattr(self.current_bars[-1], 'branch_params')
e_req = float(e_req) e_req = float(e_req)
m_req = float(m_req) m_req = float(m_req)
logger.info(f'branch condition measures started with conditions {self.branch_condition}, {e_req}, {m_req}, {self.current_bars[-1].hit_ms}')
if not self.is_branch: if not self.is_branch:
self.is_branch = True self.is_branch = True
if self.branch_condition == 'r': if self.branch_condition == 'r':
@@ -874,7 +875,7 @@ class Player:
if current_ms >= end_time: if current_ms >= end_time:
self.is_branch = False self.is_branch = False
if self.branch_condition == 'p': if self.branch_condition == 'p':
self.branch_condition_count = min((self.branch_condition_count/total_notes)*100, 100) self.branch_condition_count = max(min((self.branch_condition_count/total_notes)*100, 100), 0)
if self.branch_condition_count >= e_req and self.branch_condition_count < m_req: if self.branch_condition_count >= e_req and self.branch_condition_count < m_req:
self.merge_branch_section(self.branch_e.pop(0), current_ms) self.merge_branch_section(self.branch_e.pop(0), current_ms)
if self.branch_indicator is not None and self.branch_indicator.difficulty != 'expert': if self.branch_indicator is not None and self.branch_indicator.difficulty != 'expert':
@@ -896,6 +897,7 @@ class Player:
self.branch_indicator.level_down('normal') self.branch_indicator.level_down('normal')
self.branch_m.pop(0) self.branch_m.pop(0)
self.branch_e.pop(0) self.branch_e.pop(0)
logger.info(f"Branch set to {self.branch_indicator.difficulty} based on conditions {self.branch_condition_count}, {e_req, m_req}")
self.branch_condition_count = 0 self.branch_condition_count = 0
def update(self, ms_from_start: float, current_time: float, background: Optional[Background]): def update(self, ms_from_start: float, current_time: float, background: Optional[Background]):