#!/usr/bin/env python3
import os
import sys
import re

from ffmpeg import FFmpeg
from gradio_client import Client

API_URL = "https://sanchit-gandhi-whisper-jax.hf.space/"
TMP_AUDIO_FILE = "tmp_audio.aac"
SUB_EXT = "ja.srt"


def parse_time(time_str):
    """Convert a time string in the format mm:ss.sss to hh:mm:ss,sss."""
    minutes, seconds = map(float, time_str.split(":"))
    hours = int(minutes // 60)
    minutes = int(minutes % 60)
    seconds, milliseconds = divmod(seconds * 1000, 1000)
    return f"{hours:02}:{minutes:02}:{int(seconds):02},{int(milliseconds):03}"


def time_to_seconds(time_str):
    """Convert hh:mm:ss,sss or mm:ss,sss to seconds."""
    parts = list(map(float, re.split("[:.,]", time_str)))
    if len(parts) == 3:
        minutes, seconds, milliseconds = parts
        hours = 0
    elif len(parts) == 4:
        hours, minutes, seconds, milliseconds = parts
    else:
        raise ValueError(f"Invalid time format: {time_str}")
    return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000


def seconds_to_time(seconds):
    """Convert seconds to hh:mm:ss,sss."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds, milliseconds = divmod(seconds % 60 * 1000, 1000)
    return f"{hours:02}:{minutes:02}:{int(seconds):02},{int(milliseconds):03}"


def apply_shift(start, end, shift):
    """Shift the start time earlier; the end time is left unchanged."""
    start_seconds = max(0, time_to_seconds(start) - shift)
    # end_seconds = max(start_seconds, time_to_seconds(end) - shift)
    end_seconds = time_to_seconds(end)
    return seconds_to_time(start_seconds), seconds_to_time(end_seconds)


def calculate_shift(text, duration):
    """Calculate the shift from the subtitle length (duration is currently unused)."""
    # Approximate Japanese speech rate: 11.76 characters per second
    speech_rate = 11.76
    text_length = len(text)
    shift = max(0.5, text_length / speech_rate)
    return shift


def convert_to_srt(input_text):
    """Convert the timestamped transcript ("[mm:ss.sss -> mm:ss.sss] text" lines) to SRT with adjusted start times."""
    pattern = re.compile(r"\[(\d+:\d+\.\d+) -> (\d+:\d+\.\d+)\] (.+)")
    srt_output = []
    srt_index = 1
    previous_end = 0
    for match in pattern.finditer(input_text):
        start_time, end_time, text = match.groups()
        original_start = parse_time(start_time)
        original_end = parse_time(end_time)
        duration = time_to_seconds(original_end) - time_to_seconds(original_start)
        shift = calculate_shift(text, duration)
        adjusted_start, adjusted_end = apply_shift(original_start, original_end, shift)
        # Ensure the current start does not overlap with the previous end
        adjusted_start_seconds = max(previous_end, time_to_seconds(adjusted_start))
        adjusted_end_seconds = max(
            adjusted_start_seconds, time_to_seconds(adjusted_end)
        )
        srt_output.append(f"{srt_index}")
        srt_output.append(
            f"{seconds_to_time(adjusted_start_seconds)} --> {seconds_to_time(adjusted_end_seconds)}"
        )
        srt_output.append(text)
        srt_output.append("")
        previous_end = adjusted_end_seconds
        srt_index += 1
    return "\n".join(srt_output)


client = Client(API_URL)


def transcribe_audio(audio_path, task="transcribe", return_timestamps=False):
    """Transcribe an audio file using the Whisper JAX endpoint."""
    if task not in ["transcribe", "translate"]:
        raise ValueError("task should be one of 'transcribe' or 'translate'.")
    text, runtime = client.predict(
        audio_path,
        task,
        return_timestamps,
        api_name="/predict_1",
    )
    return text


def main():
    if len(sys.argv) < 2:
        print("Usage: Pass path to video file as first argument")
        return
    video_path = sys.argv[1]
    # Extract the audio track only (-vn), overwriting any existing temp file (-y)
    ffmpeg = FFmpeg().input(video_path).option("vn").option("y").output(TMP_AUDIO_FILE)
    ffmpeg.execute()
    transcript = transcribe_audio(TMP_AUDIO_FILE, return_timestamps=True)
    os.remove(TMP_AUDIO_FILE)

    srt_text = convert_to_srt(transcript)
    sub_filename = os.path.splitext(video_path)[0] + "." + SUB_EXT
    with open(sub_filename, "w", encoding="utf-8") as file:
        file.write(srt_text)
    print(f"\nSaved {sub_filename}")


if __name__ == "__main__":
    main()
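# Example invocation (a sketch; the script filename is illustrative, and a local
# ffmpeg binary plus network access to the Whisper JAX space are assumed):
#
#   python generate_ja_subs.py /path/to/video.mkv
#
# This extracts the audio track, transcribes it with timestamps, and writes the
# subtitles to /path/to/video.ja.srt next to the input video.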