
Commit 0c18b7c

fused moe layer, edit grok
1 parent 580fd6b commit 0c18b7c

30 files changed (+4061, -256 lines)

analyze_server_log.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
"""Analyze server logs to extract token count information.

Usage:
    # Analyze from stdin
    cat server.log | python analyze_server_log.py

    # Or analyze from file
    python analyze_server_log.py server.log
"""

import re
import sys
from collections import defaultdict


def parse_prefill_line(line):
    """Parse prefill batch log line."""
    # Example: Prefill batch. #new-seq: 8, #new-token: 65536, #cached-token: 0, token usage: 0.21, #running-req: 8, #queue-req: 8,
    match = re.search(
        r"#new-seq:\s*(\d+).*#new-token:\s*(\d+).*#cached-token:\s*(\d+).*token usage:\s*([\d.]+).*#running-req:\s*(\d+).*#queue-req:\s*(\d+)",
        line,
    )
    if match:
        return {
            "type": "prefill",
            "new_seq": int(match.group(1)),
            "new_token": int(match.group(2)),
            "cached_token": int(match.group(3)),
            "token_usage": float(match.group(4)),
            "running_req": int(match.group(5)),
            "queue_req": int(match.group(6)),
        }
    return None


def parse_decode_line(line):
    """Parse decode batch log line."""
    # Example: Decode batch. #running-req: 8, #token: 286848, token usage: 0.96, gen throughput (token/s): 232.27, #queue-req: 0,
    match = re.search(
        r"#running-req:\s*(\d+).*#token:\s*(\d+).*token usage:\s*([\d.]+).*gen throughput.*:\s*([\d.]+).*#queue-req:\s*(\d+)",
        line,
    )
    if match:
        return {
            "type": "decode",
            "running_req": int(match.group(1)),
            "token": int(match.group(2)),
            "token_usage": float(match.group(3)),
            "gen_throughput": float(match.group(4)),
            "queue_req": int(match.group(5)),
        }
    return None


def main():
    if len(sys.argv) > 1:
        # Read from file
        with open(sys.argv[1], "r") as f:
            lines = f.readlines()
    else:
        # Read from stdin
        lines = sys.stdin.readlines()

    print("=" * 80)
    print("SERVER LOG ANALYSIS")
    print("=" * 80)

    prefill_logs = []
    decode_logs = []

    for line in lines:
        if "Prefill batch" in line:
            data = parse_prefill_line(line)
            if data:
                prefill_logs.append(data)
        elif "Decode batch" in line:
            data = parse_decode_line(line)
            if data:
                decode_logs.append(data)

    # Analyze prefill logs
    if prefill_logs:
        print(f"\nPREFILL BATCHES: {len(prefill_logs)} found")
        print("-" * 80)
        print(
            f"{'#':>4} {'NewSeq':>8} {'NewTok':>10} {'CacheTok':>10} {'Usage':>8} {'RunReq':>8} {'QueueReq':>10}"
        )
        print("-" * 80)

        total_new_tokens = 0
        for i, log in enumerate(prefill_logs[:20]):  # Show first 20
            print(
                f"{i+1:4d} {log['new_seq']:8d} {log['new_token']:10d} {log['cached_token']:10d} "
                f"{log['token_usage']:8.2f} {log['running_req']:8d} {log['queue_req']:10d}"
            )
            total_new_tokens += log["new_token"]

        if len(prefill_logs) > 20:
            print(f"... and {len(prefill_logs) - 20} more")

        print(f"\nPrefill Statistics:")
        print(f" Total new tokens across all prefills: {total_new_tokens:,}")
        print(f" Average tokens per prefill: {total_new_tokens / len(prefill_logs):,.0f}")
        if prefill_logs:
            print(f" First prefill new tokens: {prefill_logs[0]['new_token']:,}")

    # Analyze decode logs
    if decode_logs:
        print(f"\n{'=' * 80}")
        print(f"DECODE BATCHES: {len(decode_logs)} found")
        print("-" * 80)
        print(
            f"{'#':>4} {'RunReq':>8} {'#Token':>12} {'Usage':>8} {'Throughput':>12} {'QueueReq':>10}"
        )
        print("-" * 80)

        for i, log in enumerate(decode_logs[:20]):  # Show first 20
            print(
                f"{i+1:4d} {log['running_req']:8d} {log['token']:12,d} {log['token_usage']:8.2f} "
                f"{log['gen_throughput']:12.2f} {log['queue_req']:10d}"
            )

        if len(decode_logs) > 20:
            print(f"... and {len(decode_logs) - 20} more")

        print(f"\nDecode Statistics:")
        tokens_by_running_req = defaultdict(list)
        for log in decode_logs:
            tokens_by_running_req[log["running_req"]].append(log["token"])

        for running_req in sorted(tokens_by_running_req.keys()):
            tokens = tokens_by_running_req[running_req]
            avg_token = sum(tokens) / len(tokens)
            min_token = min(tokens)
            max_token = max(tokens)
            print(
                f" Running {running_req:2d} requests: avg={avg_token:10.0f}, min={min_token:10,d}, max={max_token:10,d} tokens ({len(tokens)} samples)"
            )
            if running_req > 0:
                avg_per_req = avg_token / running_req
                print(f" → Avg per request: {avg_per_req:,.0f} tokens")

    # Key findings
    print(f"\n{'=' * 80}")
    print("KEY FINDINGS")
    print("=" * 80)

    if prefill_logs and decode_logs:
        first_prefill_tokens = prefill_logs[0]["new_token"]
        first_decode_tokens = decode_logs[0]["token"] if decode_logs else 0

        print(f"\n1. First prefill batch:")
        print(f" New tokens: {first_prefill_tokens:,}")
        print(f" New sequences: {prefill_logs[0]['new_seq']}")
        print(f" Average per sequence: {first_prefill_tokens / prefill_logs[0]['new_seq']:,.0f}")

        print(f"\n2. Decode batch token count:")
        print(f" Token count: {first_decode_tokens:,}")
        print(f" Running requests: {decode_logs[0]['running_req']}")
        if decode_logs[0]["running_req"] > 0:
            print(
                f" Average per request: {first_decode_tokens / decode_logs[0]['running_req']:,.0f}"
            )

        print(
            f"\n3. Ratio (decode/prefill): {first_decode_tokens / first_prefill_tokens if first_prefill_tokens > 0 else 0:.2f}x"
        )

    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()

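A quick way to sanity-check the prefill regex is to feed it the sample line quoted in the parser's own comment (a minimal sketch, assuming analyze_server_log.py is importable from the working directory):

from analyze_server_log import parse_prefill_line

# Sample line copied from the comment inside parse_prefill_line()
sample = (
    "Prefill batch. #new-seq: 8, #new-token: 65536, #cached-token: 0, "
    "token usage: 0.21, #running-req: 8, #queue-req: 8,"
)

# Expected result: new_seq=8, new_token=65536, cached_token=0,
# token_usage=0.21, running_req=8, queue_req=8
print(parse_prefill_line(sample))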
check_actual_tokens.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
"""Check actual token count in benchmark requests."""

import sys

sys.path.insert(0, "/Users/ramezes/job/sgl-project/sgl-jax/python")

import numpy as np

from sgl_jax.srt.hf_transformers_utils import get_tokenizer

# Initialize tokenizer
tokenizer_path = "/models/xai-grok-2/tokenizer.tok.json"
tokenizer = get_tokenizer(tokenizer_path)

print("=" * 60)
print("Analyzing Actual Token Counts")
print("=" * 60)

# Benchmark parameters
num_prompts = 16
random_input_len = 8192
random_output_len = 1024
random_range_ratio = 1.0

# Generate random input lengths (same logic as bench_serving.py)
input_lens = np.random.randint(
    max(int(random_input_len * random_range_ratio), 1),
    random_input_len + 1,
    size=num_prompts,
)
output_lens = np.random.randint(
    int(random_output_len * random_range_ratio),
    random_output_len + 1,
    size=num_prompts,
)

print(f"\nGenerated input lengths (first 8):")
for i in range(min(8, len(input_lens))):
    print(
        f" Request {i}: input={input_lens[i]}, output={output_lens[i]}, total={input_lens[i] + output_lens[i]}"
    )

total_input_tokens = sum(input_lens)
total_output_tokens = sum(output_lens)
total_tokens = total_input_tokens + total_output_tokens

print(f"\nTotal across all {num_prompts} requests:")
print(f" Total input tokens: {total_input_tokens:,}")
print(f" Total output tokens: {total_output_tokens:,}")
print(f" Grand total: {total_tokens:,}")

# With page_size=128 alignment
page_size = 128
print(f"\n{'=' * 60}")
print(f"With PAGE_SIZE={page_size} alignment:")
print(f"{'=' * 60}")

aligned_tokens_per_req = []
for i in range(min(8, len(input_lens))):
    req_total = input_lens[i] + output_lens[i]
    pages_needed = (req_total + page_size - 1) // page_size
    aligned_tokens = pages_needed * page_size
    aligned_tokens_per_req.append(aligned_tokens)
    print(f" Request {i}:")
    print(f"  Actual tokens: {req_total}")
    print(f"  Pages needed: {pages_needed}")
    print(f"  Aligned tokens: {aligned_tokens} (+{aligned_tokens - req_total} overhead)")

# For 8 concurrent requests
print(f"\n{'=' * 60}")
print(f"Concurrent execution (8 requests):")
print(f"{'=' * 60}")
concurrent_aligned = sum(aligned_tokens_per_req)
print(f" Total aligned tokens for 8 concurrent requests: {concurrent_aligned:,}")

# Server reported value
server_reported = 286848
print(f"\nServer reported: {server_reported:,} tokens")
print(f"Ratio: {server_reported / concurrent_aligned:.2f}x")
print(f"Difference: {server_reported - concurrent_aligned:,} tokens")

# Check if this might be context_len related
avg_per_req = server_reported / 8
print(f"\nAverage tokens per request (server): {avg_per_req:,.0f}")
print(f"This suggests each request might be using ~{avg_per_req:,.0f} tokens")

# Possible page-aligned context_len
possible_context_lens = [32768, 36864, 40960, 49152]
print(f"\nChecking common page-aligned context lengths:")
for ctx_len in possible_context_lens:
    pages = ctx_len // page_size
    actual = pages * page_size
    total_8_req = actual * 8
    print(
        f" context_len={ctx_len:,} -> {pages} pages -> {actual:,} tokens/req -> {total_8_req:,} total"
    )
    if abs(total_8_req - server_reported) < 10000:
        print(f" ^^^ MATCH! This might be it!")

print("\n" + "=" * 60)
print("Recommendation:")
print("=" * 60)
print("1. Check if requests are pre-allocating a fixed context length")
print("2. Verify the actual input token count in server logs")
print("3. Look for prefill logs showing '#new-token' count")

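With --random-range-ratio 1.0 every request is exactly 8192 input + 1024 output = 9216 tokens, which is already a multiple of the 128-token page size, so alignment adds no overhead; the page math the script performs can be checked by hand (a short sketch restating its arithmetic, not new behavior):

page_size = 128
req_total = 8192 + 1024                           # 9216 tokens per request at range ratio 1.0
pages = (req_total + page_size - 1) // page_size  # 72 pages, no rounding overhead
aligned = pages * page_size                       # 9216 aligned tokens
print(aligned * 8)                                # 73728 for 8 concurrent requests

That leaves roughly 213,000 tokens between the 73,728 accounted for here and the 286,848 the server reports, which is the gap the script's context-length checks are probing.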
debug_token_count.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
"""Debug script to understand token counting in sgl-jax server."""

# Based on the benchmark parameters:
# --num-prompts 16
# --random-input 8192
# --random-output 1024
# --max-concurrency 8

# Expected token count calculation:
num_prompts = 16
input_len = 8192
output_len = 1024
max_concurrency = 8

print("=" * 60)
print("Expected Token Count Analysis")
print("=" * 60)

# Calculation 1: If counting only currently running requests
running_reqs = 8  # From log: #running-req: 8
tokens_per_req_input = input_len
tokens_per_req_total = input_len + output_len  # Assuming all generated

expected_tokens_input_only = running_reqs * tokens_per_req_input
expected_tokens_full = running_reqs * tokens_per_req_total

print(f"\nScenario 1: Only counting input tokens")
print(f" Running requests: {running_reqs}")
print(f" Input tokens per request: {tokens_per_req_input}")
print(f" Expected total: {expected_tokens_input_only:,} tokens")

print(f"\nScenario 2: Counting input + output tokens (if all generated)")
print(f" Running requests: {running_reqs}")
print(f" Total tokens per request: {tokens_per_req_total}")
print(f" Expected total: {expected_tokens_full:,} tokens")

# Actual server report
actual_tokens = 286848

print(f"\nActual server report: {actual_tokens:,} tokens")
print(f"\nRatio analysis:")
print(f" Actual / Expected (input only): {actual_tokens / expected_tokens_input_only:.2f}x")
print(f" Actual / Expected (input + output): {actual_tokens / expected_tokens_full:.2f}x")

# Average tokens per request based on actual count
avg_tokens_per_req = actual_tokens / running_reqs
print(f"\nAverage tokens per running request: {avg_tokens_per_req:,.0f}")

# Check if this matches context_len or some other value
print(f"\nPossible explanations:")
print(f"1. If each request pre-allocated max_context_len space:")
print(f"   - Implied context_len: ~{avg_tokens_per_req:,.0f} tokens per request")

print(f"\n2. If using paged allocation with large page_size:")
print(f"   - Each request might be allocated in page-sized chunks")

print(f"\n3. If there's a multiplier in the token counting:")
print(f"   - Check if token count includes multiple layers or other factors")

# Grok-2 model info (from HuggingFace)
grok_num_layers = 64
grok_num_kv_heads = 8
grok_head_dim = 128

print(f"\nGrok-2 Model Configuration:")
print(f" num_hidden_layers: {grok_num_layers}")
print(f" num_key_value_heads: {grok_num_kv_heads}")
print(f" head_dim: {grok_head_dim}")

print("\n" + "=" * 60)
print("Recommendation:")
print("=" * 60)
print("Check the following in your server logs/code:")
print("1. What is the max_context_len for each request?")
print("2. What is the page_size setting?")
print("3. Are requests pre-allocating their maximum length?")
print("4. Check scheduler.py:_get_token_info() for the actual calculation")

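To put the reported #token value in perspective, it can be translated into KV-cache memory using the Grok-2 shape the script prints (a rough sketch; the bf16 KV cache, i.e. 2 bytes per element, and the 2x factor for keys plus values are assumptions rather than values read from the server config):

# Approximate KV-cache footprint implied by the reported token count
layers, kv_heads, head_dim = 64, 8, 128
dtype_bytes = 2                                                    # assuming bf16 KV cache
bytes_per_token = 2 * layers * kv_heads * head_dim * dtype_bytes   # K and V across all layers
tokens = 286848
print(f"{bytes_per_token / 1024:.0f} KiB per token, "
      f"{tokens * bytes_per_token / 1024**3:.1f} GiB total")

Under these assumptions the pool works out to about 256 KiB per token and roughly 70 GiB overall. Layers and heads scale the memory behind each token slot, not the slot count itself, so if the logged #token already included a per-layer multiplier the implied cache would be implausibly large, which suggests the log is counting allocated token slots directly.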