Spaces:
Running
Running
| import random | |
| import gradio as gr | |
| import numpy as np | |
| import rtmidi | |
| import MIDI | |
| import base64 | |
| import io | |
| from huggingface_hub import hf_hub_download | |
| from midi_synthesizer import MidiSynthesizer | |
# Largest signed 32-bit integer (2**31 - 1), taken from NumPy's int32 limits.
MAX_SEED = int(np.iinfo(np.int32).max)
class MIDIManager:
    """Hold the app's MIDI state: uploads, generated files, playback and live input.

    Every generated or modified file is stored base64-encoded in
    ``modified_files`` so the UI can offer it for download.

    NOTE(review): the calls made on the ``MIDI`` module (``MIDI.load``,
    ``MIDI.MIDIFile``, ``.tracks``/``.events``, ``addTrack``, ``addNote``,
    ``writeFile``) and ``MidiSynthesizer.play_midi`` resemble mido/midiutil
    APIs; confirm the installed ``MIDI`` and ``midi_synthesizer`` modules
    really provide them.
    """

    # Length given to every note written by this class (same time units as the
    # note start times); also the spacing between notes saved from live input.
    NOTE_DURATION = 100

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}  # midi_id -> object returned by MIDI.load
        self.modified_files = []  # base64 strings of every generated file
        self.is_playing = False
        # Must exist before the callback is registered: midi_callback appends
        # to it and can run as soon as it is installed on an open port.
        self.live_notes = []
        self.midi_in = rtmidi.MidiIn()
        if self.midi_in.get_ports():
            self.midi_in.open_port(0)
        self.midi_in.set_callback(self.midi_callback)

    def midi_callback(self, event, data=None):
        """Record Note On messages from the MIDI input port into ``live_notes``.

        ``event`` is unpacked as ``(message, timing)`` (presumably the
        python-rtmidi ``(message_bytes, delta_time)`` pair); the timing part is
        ignored. ``data`` is unused.
        """
        message, _ = event
        # Status byte 0x9n = Note On on any channel n.
        if len(message) >= 3 and message[0] & 0xF0 == 0x90:
            note, velocity = message[1], message[2]
            # A Note On with velocity 0 acts as Note Off, so it is not recorded.
            if velocity > 0:
                self.live_notes.append((note, velocity, 0))  # 0 = time placeholder

    def load_midi(self, file_path):
        """Load a MIDI file, register it, and return its new ``midi_<n>`` id."""
        midi = MIDI.load(file_path)
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = midi
        return midi_id

    def extract_notes(self, midi):
        """Return ``(note, velocity, time)`` for every sounding ``note_on`` event.

        NOTE(review): ``time`` is copied straight from ``event.time``; whether
        that is an absolute or a delta time depends on the ``MIDI`` library.
        """
        return [
            (event.note, event.velocity, event.time)
            for track in midi.tracks
            for event in track.events
            if event.type == 'note_on' and event.velocity > 0
        ]

    @classmethod
    def _repeat_with_variation(cls, notes, length_factor, variation):
        """Repeat ``notes`` ``int(length_factor)`` times, randomly nudging some.

        Each note is varied with probability ``variation``: pitch by -2..+2
        (clamped to 0..127) and velocity by -10..+10 (clamped to 1..127, since
        velocity 0 would be a silent Note Off). Each repetition is shifted to
        start after the previous one ends, so the output is actually longer
        rather than copies stacked at the same times.
        """
        if not notes:
            return []
        span = max(time for _, _, time in notes) + cls.NOTE_DURATION
        result = []
        for repeat in range(int(length_factor)):
            offset = repeat * span
            for note, vel, time in notes:
                if random.random() < variation:
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    new_vel = min(127, max(1, vel + random.randint(-10, 10)))
                else:
                    new_note, new_vel = note, vel
                result.append((new_note, new_vel, time + offset))
        return result

    def _encode_and_store(self, midi):
        """Serialize ``midi`` to base64, record it in ``modified_files``, return it."""
        buffer = io.BytesIO()
        midi.writeFile(buffer)
        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
        self.modified_files.append(encoded)
        return encoded

    @staticmethod
    def _decode(midi_data):
        """Load a MIDI object from a base64-encoded MIDI string."""
        return MIDI.load(io.BytesIO(base64.b64decode(midi_data)))

    def generate_variation(self, midi_id, length_factor=2, variation=0.3):
        """Build a longer, randomly varied copy of a loaded MIDI file.

        Args:
            midi_id: id returned by ``load_midi``.
            length_factor: number of repetitions (truncated to int).
            variation: probability in [0, 1] that each note is nudged.

        Returns:
            The new file as a base64 string, or None if ``midi_id`` is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None
        notes = self.extract_notes(self.loaded_midi[midi_id])
        new_midi = MIDI.MIDIFile(1)
        new_midi.addTrack()
        for note, vel, time in self._repeat_with_variation(notes, length_factor, variation):
            new_midi.addNote(0, 0, note, time, self.NOTE_DURATION, vel)
        return self._encode_and_store(new_midi)

    def apply_synth_effect(self, midi_data, effect, intensity):
        """Apply ``effect`` to base64 MIDI data and return the new base64 file.

        Only ``"tempo"`` is implemented: every event time is scaled by
        ``1 + (intensity - 0.5) * 0.4`` (0.8x..1.2x for intensity 0..1), so a
        higher intensity stretches the timing. Other effects re-save unchanged.
        """
        midi = self._decode(midi_data)
        if effect == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in midi.tracks:
                for event in track.events:
                    event.time = int(event.time * factor)
        return self._encode_and_store(midi)

    def play_with_loop(self, midi_data):
        """Play base64 MIDI data repeatedly until ``stop_playback`` is called.

        Blocks the calling thread while looping. Returns a status string.

        NOTE(review): if ``synthesizer.play_midi`` returns before playback has
        finished, this loop restarts playback continuously — confirm it blocks.
        """
        if not midi_data:
            return "Generate a MIDI first"
        midi_file = self._decode(midi_data)
        self.is_playing = True
        try:
            while self.is_playing:
                self.synthesizer.play_midi(midi_file)
        finally:
            # Leave a consistent state even if playback raised.
            self.is_playing = False
        return "Stopped"

    def stop_playback(self):
        """Ask a running ``play_with_loop`` to exit after its current pass."""
        self.is_playing = False
        return "Stopping..."

    def save_live_midi(self):
        """Write the notes recorded from the keyboard to a new MIDI file.

        Notes are laid out one after another, ``NOTE_DURATION`` apart, because
        the recorded timing is only a placeholder. The recording buffer is
        emptied.

        Returns:
            The base64-encoded file, or None if nothing was recorded.
        """
        # Swap in a fresh buffer first, so notes the callback appends while
        # the file is being written go into the next recording instead of
        # being wiped out by a reset at the end.
        notes, self.live_notes = self.live_notes, []
        if not notes:
            return None
        midi = MIDI.MIDIFile(1)
        midi.addTrack()
        for index, (note, vel, _) in enumerate(notes):
            midi.addNote(0, 0, note, index * self.NOTE_DURATION, self.NOTE_DURATION, vel)
        return self._encode_and_store(midi)
# Single module-level manager: every Gradio callback in this file uses this
# one instance, so its loaded files, generated files, playback flag and live
# note buffer are shared by all users/sessions of the app. Constructing it
# downloads the soundfont and opens MIDI input at import time.
midi_manager = MIDIManager()
def create_download_list(files=None):
    """Return an HTML fragment of data-URL download links, one per MIDI file.

    Args:
        files: base64-encoded MIDI strings to list. Defaults to every file
            recorded in ``midi_manager.modified_files``.

    Returns:
        A "Downloads" heading followed by a ``<ul>`` whose links download as
        ``midi_<index>.mid``. Base64 data contains no HTML-special characters,
        so it is embedded without escaping.
    """
    if files is None:
        files = midi_manager.modified_files
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{data}" download="midi_{i}.mid">MIDI {i}</a></li>'
        for i, data in enumerate(files)
    )
    return f"<h3>Downloads</h3><ul>{items}</ul>"
def _midi_preview_file(midi_data):
    """Write base64-encoded MIDI to a temporary ``.mid`` file and return its path.

    Returns None for empty/None input so the receiving component is cleared.
    Files are created with ``delete=False`` and are not cleaned up here.
    """
    if not midi_data:
        return None
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as handle:
        handle.write(base64.b64decode(midi_data))
    return handle.name


with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
    # Base64 string of the MIDI most recently generated or modified in this
    # browser session. Effect and playback handlers read it from here; the
    # preview component only holds a file for the user to download/inspect.
    current_midi = gr.State(None)
    with gr.Tabs():
        # Tab 1: Load MIDI Files
        with gr.Tab("Load MIDI"):
            midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
            midi_list = gr.State({})
            file_display = gr.HTML(value="No files loaded")

            def load_files(files):
                """Register uploaded files; return ({midi_id: path}, HTML listing)."""
                from html import escape

                midi_list_val = {}
                html = "<h3>Loaded Files</h3>"
                for file in files or []:
                    # Older Gradio passes tempfile wrappers (with .name), newer
                    # versions pass plain path strings; accept both.
                    path = getattr(file, "name", file)
                    midi_id = midi_manager.load_midi(path)
                    midi_list_val[midi_id] = path
                    # Filenames come from the user, so escape before embedding.
                    html += f"<div>{escape(path)}</div>"
                return midi_list_val, html

            midi_files.change(load_files, inputs=[midi_files], outputs=[midi_list, file_display])
        # Tab 2: Generate & Perform
        with gr.Tab("Generate & Perform"):
            midi_select = gr.Dropdown(label="Select MIDI", choices=[])
            length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
            variation = gr.Slider(0, 1, value=0.3, label="Variation")
            generate_btn = gr.Button("Generate")
            effect = gr.Radio(["tempo"], label="Effect", value="tempo")
            intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
            apply_btn = gr.Button("Apply Effect")
            play_btn = gr.Button("Play Loop")
            stop_btn = gr.Button("Stop")
            # gr.Audio only supports type="numpy"/"filepath" and cannot play raw
            # MIDI, so the generated .mid file is offered as a file instead.
            output = gr.File(label="Preview")
            status = gr.Textbox(label="Status", value="Ready")
        # Tab 3: MIDI Input
        with gr.Tab("MIDI Input"):
            gr.Markdown("Play your MIDI keyboard to record notes")
            save_btn = gr.Button("Save Live MIDI")
            live_output = gr.File(label="Live MIDI")
        # Tab 4: Downloads
        with gr.Tab("Downloads"):
            downloads = gr.HTML(value="No files yet")

    gr.Markdown("""
<div style='text-align: center; margin-top: 20px;'>
<img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
<strong>Hugging Face</strong><br>
<a href='https://huggingface.co/models'>Models</a> |
<a href='https://huggingface.co/datasets'>Datasets</a> |
<a href='https://huggingface.co/spaces'>Spaces</a> |
<a href='https://huggingface.co/posts'>Posts</a> |
<a href='https://huggingface.co/docs'>Docs</a> |
<a href='https://huggingface.co/enterprise'>Enterprise</a> |
<a href='https://huggingface.co/pricing'>Pricing</a>
</div>
""")

    def update_dropdown(midi_list):
        """Offer the ids of the currently loaded files in the dropdown."""
        return gr.update(choices=list(midi_list.keys()))

    def generate(midi_id, length, var):
        """Generate a variation; return (preview file, base64 state, status)."""
        if not midi_id:
            return None, None, "Select a MIDI file"
        midi_data = midi_manager.generate_variation(midi_id, length, var)
        if midi_data is None:
            return None, None, "Selected MIDI not found"
        return _midi_preview_file(midi_data), midi_data, "Generated"

    def apply_effect(midi_data, fx, inten):
        """Apply an effect to the session's current MIDI (a base64 string)."""
        if not midi_data:
            return None, None, "Generate a MIDI first"
        new_data = midi_manager.apply_synth_effect(midi_data, fx, inten)
        return _midi_preview_file(new_data), new_data, "Effect Applied"

    def play(midi_data):
        """Loop-play the session's current MIDI until Stop is pressed."""
        if not midi_data:
            return "Generate a MIDI first"
        return midi_manager.play_with_loop(midi_data)

    def save_live():
        """Save the keyboard recording and return it as a .mid file (or None)."""
        return _midi_preview_file(midi_manager.save_live_midi())

    def update_downloads(*args):
        """Re-render the download list from every file generated so far."""
        return create_download_list()

    midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])
    # The download list is refreshed with .then() so it runs only after the
    # handler that appends the new file has finished.
    generate_btn.click(
        generate, inputs=[midi_select, length_factor, variation],
        outputs=[output, current_midi, status],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    apply_btn.click(
        apply_effect, inputs=[current_midi, effect, intensity],
        outputs=[output, current_midi, status],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    play_btn.click(play, inputs=[current_midi], outputs=[status])
    stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])
    save_btn.click(
        save_live, inputs=None, outputs=[live_output],
    ).then(update_downloads, inputs=None, outputs=[downloads])
# Launch only when executed as a script, so importing this module (e.g. for
# tests or Gradio's reload mode) does not start a server. queue() enables
# Gradio's request queue; inbrowser=True asks Gradio to open a browser tab.
if __name__ == "__main__":
    app.queue().launch(inbrowser=True)